Unverified Commit a2e1c796 authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[GraphBolt][CUDA] Pipelined sampling accuracy fix (#7088)

parent 4ee0a8bd
...@@ -48,22 +48,33 @@ class FetchInsubgraphData(Mapper): ...@@ -48,22 +48,33 @@ class FetchInsubgraphData(Mapper):
with torch.cuda.stream(self.stream): with torch.cuda.stream(self.stream):
index = minibatch._seed_nodes index = minibatch._seed_nodes
if isinstance(index, dict): if isinstance(index, dict):
for idx in index.values():
idx.record_stream(torch.cuda.current_stream())
index = self.graph._convert_to_homogeneous_nodes(index) index = self.graph._convert_to_homogeneous_nodes(index)
else:
index.record_stream(torch.cuda.current_stream())
def record_stream(tensor):
if stream is not None and tensor.is_cuda:
tensor.record_stream(stream)
return tensor
if self.graph.node_type_offset is None:
# sorting not needed.
minibatch._subgraph_seed_nodes = None
else:
index, original_positions = index.sort() index, original_positions = index.sort()
if (original_positions.diff() == 1).all().item(): # is_sorted if (original_positions.diff() == 1).all().item():
# already sorted.
minibatch._subgraph_seed_nodes = None minibatch._subgraph_seed_nodes = None
else: else:
minibatch._subgraph_seed_nodes = original_positions minibatch._subgraph_seed_nodes = record_stream(
index.record_stream(torch.cuda.current_stream()) original_positions.sort()[1]
)
index_select_csc_with_indptr = partial( index_select_csc_with_indptr = partial(
torch.ops.graphbolt.index_select_csc, self.graph.csc_indptr torch.ops.graphbolt.index_select_csc, self.graph.csc_indptr
) )
def record_stream(tensor):
if stream is not None and tensor.is_cuda:
tensor.record_stream(stream)
indptr, indices = index_select_csc_with_indptr( indptr, indices = index_select_csc_with_indptr(
self.graph.indices, index, None self.graph.indices, index, None
) )
......
...@@ -41,8 +41,12 @@ def get_hetero_graph(): ...@@ -41,8 +41,12 @@ def get_hetero_graph():
@unittest.skipIf(F._default_context_str != "gpu", reason="Enabled only on GPU.") @unittest.skipIf(F._default_context_str != "gpu", reason="Enabled only on GPU.")
@pytest.mark.parametrize("hetero", [False, True]) @pytest.mark.parametrize("hetero", [False, True])
@pytest.mark.parametrize("prob_name", [None, "weight", "mask"]) @pytest.mark.parametrize("prob_name", [None, "weight", "mask"])
def test_NeighborSampler_GraphFetch(hetero, prob_name): @pytest.mark.parametrize("sorted", [False, True])
def test_NeighborSampler_GraphFetch(hetero, prob_name, sorted):
if sorted:
items = torch.arange(3) items = torch.arange(3)
else:
items = torch.tensor([2, 0, 1])
names = "seed_nodes" names = "seed_nodes"
itemset = gb.ItemSet(items, names=names) itemset = gb.ItemSet(items, names=names)
graph = get_hetero_graph().to(F.ctx()) graph = get_hetero_graph().to(F.ctx())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment