Unverified commit ce223b36 authored by Muhammed Fatih BALIN, committed by GitHub

[GraphBolt][CUDA] Add `type_per_edge` reading support and ensure same types are consecutive (#6865)


Co-authored-by: Rhett Ying <85214957+Rhett-Ying@users.noreply.github.com>
parent 1e34f664
@@ -114,6 +114,15 @@ struct IteratorFuncAddOffset {
   }
 };

+template <typename indptr_t, typename in_degree_iterator_t>
+struct SegmentEndFunc {
+  indptr_t* indptr;
+  in_degree_iterator_t in_degree;
+  __host__ __device__ auto operator()(int64_t i) {
+    return indptr[i] + in_degree[i];
+  }
+};
+
 c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
     torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes,
     const std::vector<int64_t>& fanouts, bool replace, bool layer,
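Note on the added `SegmentEndFunc`: `cub::DeviceSegmentedSort` takes one iterator of begin offsets and one of end offsets per segment. Each node's segment begins at `sub_indptr[i]`, but only the first `sampled_degree[i]` slots appear to hold valid picked edge IDs, so the segment end is `sub_indptr[i] + sampled_degree[i]` rather than `sub_indptr[i + 1]`. Wrapping the functor in a `thrust::make_transform_iterator` computes these end offsets lazily instead of materializing an extra array. A minimal standalone sketch of the pattern; the names `SegmentEnd`, `d_indptr`, and `d_degree` are hypothetical and not part of the patch:

```cpp
// Sketch: segment i spans [indptr[i], indptr[i] + degree[i]); the end
// offsets are computed on access rather than stored in device memory.
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>

struct SegmentEnd {
  const int64_t* indptr;  // begin offset of each segment
  const int64_t* degree;  // number of valid items in each segment
  __host__ __device__ int64_t operator()(int64_t i) const {
    return indptr[i] + degree[i];
  }
};

// Usage, given device pointers d_indptr and d_degree:
//   auto iota = thrust::make_counting_iterator(int64_t{0});
//   auto seg_end = thrust::make_transform_iterator(
//       iota, SegmentEnd{d_indptr, d_degree});
//   // seg_end[i] == d_indptr[i] + d_degree[i], usable as the
//   // end-offsets argument of cub::DeviceSegmentedSort below.
```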
@@ -155,6 +164,7 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
       static_cast<int64_t>(0), std::numeric_limits<int64_t>::max());
   torch::Tensor picked_eids;
   torch::Tensor output_indices;
+  torch::optional<torch::Tensor> output_type_per_edge;
   AT_DISPATCH_INTEGRAL_TYPES(
       indptr.scalar_type(), "SampleNeighborsIndptr", ([&] {
@@ -261,6 +271,30 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
             static_cast<indptr_t>(num_sampled_edges),
             nodes.options().dtype(indptr.scalar_type()));
+        // The sampled edges need to be sorted only when fanouts.size() == 1,
+        // since the multiple-fanout sampling case is already sorted by
+        // construction.
+        if (type_per_edge && fanouts.size() == 1) {
+          // Ensure the sort result still ends up in sorted_edge_id_segments.
+          std::swap(edge_id_segments, sorted_edge_id_segments);
+          auto sampled_segment_end_it = thrust::make_transform_iterator(
+              iota, SegmentEndFunc<indptr_t, decltype(sampled_degree)>{
+                        sub_indptr.data_ptr<indptr_t>(), sampled_degree});
+          size_t tmp_storage_size = 0;
+          CUDA_CALL(cub::DeviceSegmentedSort::SortKeys(
+              nullptr, tmp_storage_size, edge_id_segments.get(),
+              sorted_edge_id_segments.get(), picked_eids.size(0),
+              num_rows, sub_indptr.data_ptr<indptr_t>(),
+              sampled_segment_end_it, stream));
+          auto tmp_storage =
+              allocator.AllocateStorage<char>(tmp_storage_size);
+          CUDA_CALL(cub::DeviceSegmentedSort::SortKeys(
+              tmp_storage.get(), tmp_storage_size, edge_id_segments.get(),
+              sorted_edge_id_segments.get(), picked_eids.size(0),
+              num_rows, sub_indptr.data_ptr<indptr_t>(),
+              sampled_segment_end_it, stream));
+        }
         auto input_buffer_it = thrust::make_transform_iterator(
             iota, IteratorFunc<indptr_t, edge_id_t>{
                 sub_indptr.data_ptr<indptr_t>(),
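Two notes on this block. First, the purpose: GraphBolt's fused CSC format keeps `type_per_edge` sorted within each vertex's neighborhood, so edge IDs of the same type are consecutive in the original graph; sorting the picked edge IDs within each segment therefore makes sampled edges of the same type consecutive again, which the multi-fanout path already guarantees by picking type by type. Second, the call convention: CUB device algorithms are invoked twice, once with a null workspace pointer to query the temporary storage size, and once to do the work. A self-contained sketch of that convention; the wrapper name and raw `cudaMalloc` are illustrative only, as the patch draws its workspace from a caching allocator:

```cpp
// Two-phase cub::DeviceSegmentedSort::SortKeys: the first call only
// computes the required temporary storage size, the second call sorts
// the keys of each segment independently (ascending).
#include <cub/device/device_segmented_sort.cuh>
#include <cuda_runtime.h>

void SortKeysPerSegment(
    const int64_t* d_keys_in, int64_t* d_keys_out, int num_items,
    int num_segments, const int64_t* d_begin_offsets,
    const int64_t* d_end_offsets, cudaStream_t stream) {
  size_t tmp_bytes = 0;
  // Phase 1: null workspace pointer, so no sorting happens yet.
  cub::DeviceSegmentedSort::SortKeys(
      nullptr, tmp_bytes, d_keys_in, d_keys_out, num_items, num_segments,
      d_begin_offsets, d_end_offsets, stream);
  void* d_tmp = nullptr;
  cudaMalloc(&d_tmp, tmp_bytes);
  // Phase 2: perform the per-segment sort with the allocated workspace.
  cub::DeviceSegmentedSort::SortKeys(
      d_tmp, tmp_bytes, d_keys_in, d_keys_out, num_items, num_segments,
      d_begin_offsets, d_end_offsets, stream);
  cudaStreamSynchronize(stream);
  cudaFree(d_tmp);
}
```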
@@ -305,6 +339,27 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
                   indices.data_ptr<indices_t>(),
                   output_indices.data_ptr<indices_t>());
             }));
+        if (type_per_edge) {
+          // output_type_per_edge = type_per_edge.gather(0, picked_eids);
+          // The torch equivalent commented out above does not work when
+          // type_per_edge is in pinned memory. That is why we reimplement
+          // it with thrust, similar to the indices gather operation above.
+          auto types = type_per_edge.value();
+          output_type_per_edge = torch::empty(
+              picked_eids.size(0),
+              picked_eids.options().dtype(types.scalar_type()));
+          AT_DISPATCH_INTEGRAL_TYPES(
+              types.scalar_type(), "SampleNeighborsOutputTypePerEdge", ([&] {
+                const auto exec_policy =
+                    thrust::cuda::par_nosync(allocator).on(stream);
+                thrust::gather(
+                    exec_policy, picked_eids.data_ptr<indptr_t>(),
+                    picked_eids.data_ptr<indptr_t>() + picked_eids.size(0),
+                    types.data_ptr<scalar_t>(),
+                    output_type_per_edge.value().data_ptr<scalar_t>());
+              }));
+        }
       }));
   torch::optional<torch::Tensor> subgraph_reverse_edge_ids = torch::nullopt;
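`thrust::gather` computes `out[i] = values[map[i]]` for each picked edge. The hand-rolled gather is needed because, as the comment notes, the torch equivalent fails when `type_per_edge` lives in pinned host memory; a CUDA kernel, by contrast, can dereference pinned memory directly through unified virtual addressing, so the same thrust call covers both device-resident and pinned inputs. A toy, hypothetical example of the semantics, with all names and values illustrative:

```cpp
// out[i] = values[map[i]] for every picked index. Here all buffers are
// device vectors; in the patch, `values` may instead point at pinned
// host memory, which the GPU can read directly via unified addressing.
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/gather.h>

#include <cstdint>
#include <vector>

int main() {
  const std::vector<int64_t> h_map = {3, 0, 2};            // picked edge IDs
  const std::vector<int32_t> h_values = {10, 11, 12, 13};  // per-edge types
  thrust::device_vector<int64_t> map(h_map.begin(), h_map.end());
  thrust::device_vector<int32_t> values(h_values.begin(), h_values.end());
  thrust::device_vector<int32_t> out(map.size());
  thrust::gather(
      thrust::device, map.begin(), map.end(), values.begin(), out.begin());
  // out now holds {13, 10, 12}.
  return 0;
}
```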
@@ -312,7 +367,7 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
   return c10::make_intrusive<sampling::FusedSampledSubgraph>(
       output_indptr, output_indices, nodes, torch::nullopt,
-      subgraph_reverse_edge_ids, torch::nullopt);
+      subgraph_reverse_edge_ids, output_type_per_edge);
 }
 }  // namespace ops
@@ -148,10 +148,6 @@ def get_hetero_graph():
 )


-@unittest.skipIf(
-    F._default_context_str != "cpu",
-    reason="Heterogenous sampling not yet supported on GPU.",
-)
 @pytest.mark.parametrize("labor", [False, True])
 def test_SubgraphSampler_Node_Hetero(labor):
     graph = get_hetero_graph().to(F.ctx())
@@ -168,10 +164,6 @@ def test_SubgraphSampler_Node_Hetero(labor):
     assert len(minibatch.sampled_subgraphs) == num_layer


-@unittest.skipIf(
-    F._default_context_str != "cpu",
-    reason="Heterogenous sampling not yet supported on GPU.",
-)
 @pytest.mark.parametrize("labor", [False, True])
 def test_SubgraphSampler_Link_Hetero(labor):
     graph = get_hetero_graph().to(F.ctx())
@@ -197,10 +189,6 @@ def test_SubgraphSampler_Link_Hetero(labor):
     assert len(list(datapipe)) == 5


-@unittest.skipIf(
-    F._default_context_str != "cpu",
-    reason="Heterogenous sampling not yet supported on GPU.",
-)
 @pytest.mark.parametrize("labor", [False, True])
 def test_SubgraphSampler_Link_Hetero_With_Negative(labor):
     graph = get_hetero_graph().to(F.ctx())
@@ -346,10 +334,6 @@ def test_SubgraphSampler_without_dedpulication_Homo(labor):
 )


-@unittest.skipIf(
-    F._default_context_str != "cpu",
-    reason="Heterogenous sampling not yet supported on GPU.",
-)
 @pytest.mark.parametrize("labor", [False, True])
 def test_SubgraphSampler_without_dedpulication_Hetero(labor):
     graph = get_hetero_graph().to(F.ctx())
@@ -409,20 +393,20 @@ def test_SubgraphSampler_without_dedpulication_Hetero(labor):
         for ntype in ["n1", "n2"]:
             assert torch.equal(
                 sampled_subgraph.original_row_node_ids[ntype],
-                original_row_node_ids[step][ntype],
+                original_row_node_ids[step][ntype].to(F.ctx()),
             )
             assert torch.equal(
                 sampled_subgraph.original_column_node_ids[ntype],
-                original_column_node_ids[step][ntype],
+                original_column_node_ids[step][ntype].to(F.ctx()),
             )
         for etype in ["n1:e1:n2", "n2:e2:n1"]:
             assert torch.equal(
                 sampled_subgraph.sampled_csc[etype].indices,
-                csc_formats[step][etype].indices,
+                csc_formats[step][etype].indices.to(F.ctx()),
             )
             assert torch.equal(
                 sampled_subgraph.sampled_csc[etype].indptr,
-                csc_formats[step][etype].indptr,
+                csc_formats[step][etype].indptr.to(F.ctx()),
             )
@@ -486,10 +470,6 @@ def test_SubgraphSampler_unique_csc_format_Homo(labor):
 )


-@unittest.skipIf(
-    F._default_context_str != "cpu",
-    reason="Heterogenous sampling not yet supported on GPU.",
-)
 @pytest.mark.parametrize("labor", [False, True])
 def test_SubgraphSampler_unique_csc_format_Hetero(labor):
     graph = get_hetero_graph().to(F.ctx())
@@ -554,18 +534,18 @@ def test_SubgraphSampler_unique_csc_format_Hetero(labor):
         for ntype in ["n1", "n2"]:
             assert torch.equal(
                 sampled_subgraph.original_row_node_ids[ntype],
-                original_row_node_ids[step][ntype],
+                original_row_node_ids[step][ntype].to(F.ctx()),
             )
             assert torch.equal(
                 sampled_subgraph.original_column_node_ids[ntype],
-                original_column_node_ids[step][ntype],
+                original_column_node_ids[step][ntype].to(F.ctx()),
             )
         for etype in ["n1:e1:n2", "n2:e2:n1"]:
             assert torch.equal(
                 sampled_subgraph.sampled_csc[etype].indices,
-                csc_formats[step][etype].indices,
+                csc_formats[step][etype].indices.to(F.ctx()),
             )
             assert torch.equal(
                 sampled_subgraph.sampled_csc[etype].indptr,
-                csc_formats[step][etype].indptr,
+                csc_formats[step][etype].indptr.to(F.ctx()),
             )