"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "e6110f68569c7b620306e678c3a3d9eee1a293e2"
Unverified Commit 35b50b61 authored by Mingbang Wang's avatar Mingbang Wang Committed by GitHub
Browse files

[Misc] using `index_select` to avoid IndexError (#6777)


Co-authored-by: default avatarRhett Ying <85214957+Rhett-Ying@users.noreply.github.com>
parent 6e76bf3f
......@@ -93,7 +93,9 @@ def evaluate(model, dataset, device):
# Forward.
y = model(data.blocks, x)
logit = (
model.predictor(y[compacted_pairs[0]] * y[compacted_pairs[1]])
model.predictor(
y[compacted_pairs[0].long()] * y[compacted_pairs[1].long()]
)
.squeeze()
.detach()
)
......@@ -132,7 +134,7 @@ def train(model, dataset, device):
# Forward.
y = model(data.blocks, x)
logits = model.predictor(
y[compacted_pairs[0]] * y[compacted_pairs[1]]
y[compacted_pairs[0].long()] * y[compacted_pairs[1].long()]
).squeeze()
# Compute loss.
......
......@@ -417,9 +417,11 @@ class FusedCSCSamplingGraph(SamplingGraph):
and ORIGINAL_EDGE_ID in self.edge_attributes
)
if has_original_eids:
original_edge_ids = self.edge_attributes[ORIGINAL_EDGE_ID][
original_edge_ids
]
original_edge_ids = torch.index_select(
self.edge_attributes[ORIGINAL_EDGE_ID],
dim=0,
index=original_edge_ids,
)
if type_per_edge is None:
# The sampled graph is already a homogeneous graph.
sampled_csc = CSCFormatBase(indptr=indptr, indices=indices)
......@@ -1138,14 +1140,20 @@ def from_dglgraph(
)
# Assign edge type according to the order of CSC matrix.
type_per_edge = None if is_homogeneous else homo_g.edata[ETYPE][edge_ids]
type_per_edge = (
None
if is_homogeneous
else torch.index_select(homo_g.edata[ETYPE], dim=0, index=edge_ids)
)
node_attributes = {}
edge_attributes = {}
if include_original_edge_id:
# Assign edge attributes according to the original eids mapping.
edge_attributes[ORIGINAL_EDGE_ID] = homo_g.edata[EID][edge_ids]
edge_attributes[ORIGINAL_EDGE_ID] = torch.index_select(
homo_g.edata[EID], dim=0, index=edge_ids
)
return FusedCSCSamplingGraph(
torch.ops.graphbolt.fused_csc_sampling_graph(
......
......@@ -148,7 +148,7 @@ class ItemShufflerAndBatcher:
if isinstance(buffer, torch.Tensor):
# For item set that's initialized with integer or single tensor,
# `buffer` is a tensor.
return buffer[indices]
return torch.index_select(buffer, dim=0, index=indices)
elif isinstance(buffer, list) and isinstance(buffer[0], DGLGraph):
# For item set that's initialized with a list of
# DGLGraphs, `buffer` is a list of DGLGraphs.
......
......@@ -223,7 +223,9 @@ def _to_reverse_ids(node_pair, original_row_node_ids, original_column_node_ids):
indptr = node_pair.indptr
indices = node_pair.indices
if original_row_node_ids is not None:
indices = original_row_node_ids[indices]
indices = torch.index_select(
original_row_node_ids, dim=0, index=indices
)
if original_column_node_ids is not None:
indptr = original_column_node_ids.repeat_interleave(
indptr[1:] - indptr[:-1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment