Unverified Commit 35b50b61 authored by Mingbang Wang's avatar Mingbang Wang Committed by GitHub
Browse files

[Misc] using `index_select` to avoid IndexError (#6777)


Co-authored-by: default avatarRhett Ying <85214957+Rhett-Ying@users.noreply.github.com>
parent 6e76bf3f
...@@ -93,7 +93,9 @@ def evaluate(model, dataset, device): ...@@ -93,7 +93,9 @@ def evaluate(model, dataset, device):
# Forward. # Forward.
y = model(data.blocks, x) y = model(data.blocks, x)
logit = ( logit = (
model.predictor(y[compacted_pairs[0]] * y[compacted_pairs[1]]) model.predictor(
y[compacted_pairs[0].long()] * y[compacted_pairs[1].long()]
)
.squeeze() .squeeze()
.detach() .detach()
) )
...@@ -132,7 +134,7 @@ def train(model, dataset, device): ...@@ -132,7 +134,7 @@ def train(model, dataset, device):
# Forward. # Forward.
y = model(data.blocks, x) y = model(data.blocks, x)
logits = model.predictor( logits = model.predictor(
y[compacted_pairs[0]] * y[compacted_pairs[1]] y[compacted_pairs[0].long()] * y[compacted_pairs[1].long()]
).squeeze() ).squeeze()
# Compute loss. # Compute loss.
......
...@@ -417,9 +417,11 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -417,9 +417,11 @@ class FusedCSCSamplingGraph(SamplingGraph):
and ORIGINAL_EDGE_ID in self.edge_attributes and ORIGINAL_EDGE_ID in self.edge_attributes
) )
if has_original_eids: if has_original_eids:
original_edge_ids = self.edge_attributes[ORIGINAL_EDGE_ID][ original_edge_ids = torch.index_select(
original_edge_ids self.edge_attributes[ORIGINAL_EDGE_ID],
] dim=0,
index=original_edge_ids,
)
if type_per_edge is None: if type_per_edge is None:
# The sampled graph is already a homogeneous graph. # The sampled graph is already a homogeneous graph.
sampled_csc = CSCFormatBase(indptr=indptr, indices=indices) sampled_csc = CSCFormatBase(indptr=indptr, indices=indices)
...@@ -1138,14 +1140,20 @@ def from_dglgraph( ...@@ -1138,14 +1140,20 @@ def from_dglgraph(
) )
# Assign edge type according to the order of CSC matrix. # Assign edge type according to the order of CSC matrix.
type_per_edge = None if is_homogeneous else homo_g.edata[ETYPE][edge_ids] type_per_edge = (
None
if is_homogeneous
else torch.index_select(homo_g.edata[ETYPE], dim=0, index=edge_ids)
)
node_attributes = {} node_attributes = {}
edge_attributes = {} edge_attributes = {}
if include_original_edge_id: if include_original_edge_id:
# Assign edge attributes according to the original eids mapping. # Assign edge attributes according to the original eids mapping.
edge_attributes[ORIGINAL_EDGE_ID] = homo_g.edata[EID][edge_ids] edge_attributes[ORIGINAL_EDGE_ID] = torch.index_select(
homo_g.edata[EID], dim=0, index=edge_ids
)
return FusedCSCSamplingGraph( return FusedCSCSamplingGraph(
torch.ops.graphbolt.fused_csc_sampling_graph( torch.ops.graphbolt.fused_csc_sampling_graph(
......
...@@ -148,7 +148,7 @@ class ItemShufflerAndBatcher: ...@@ -148,7 +148,7 @@ class ItemShufflerAndBatcher:
if isinstance(buffer, torch.Tensor): if isinstance(buffer, torch.Tensor):
# For item set that's initialized with integer or single tensor, # For item set that's initialized with integer or single tensor,
# `buffer` is a tensor. # `buffer` is a tensor.
return buffer[indices] return torch.index_select(buffer, dim=0, index=indices)
elif isinstance(buffer, list) and isinstance(buffer[0], DGLGraph): elif isinstance(buffer, list) and isinstance(buffer[0], DGLGraph):
# For item set that's initialized with a list of # For item set that's initialized with a list of
# DGLGraphs, `buffer` is a list of DGLGraphs. # DGLGraphs, `buffer` is a list of DGLGraphs.
......
...@@ -223,7 +223,9 @@ def _to_reverse_ids(node_pair, original_row_node_ids, original_column_node_ids): ...@@ -223,7 +223,9 @@ def _to_reverse_ids(node_pair, original_row_node_ids, original_column_node_ids):
indptr = node_pair.indptr indptr = node_pair.indptr
indices = node_pair.indices indices = node_pair.indices
if original_row_node_ids is not None: if original_row_node_ids is not None:
indices = original_row_node_ids[indices] indices = torch.index_select(
original_row_node_ids, dim=0, index=indices
)
if original_column_node_ids is not None: if original_column_node_ids is not None:
indptr = original_column_node_ids.repeat_interleave( indptr = original_column_node_ids.repeat_interleave(
indptr[1:] - indptr[:-1] indptr[1:] - indptr[:-1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment