"vscode:/vscode.git/clone" did not exist on "65676b4ba1a9fd4417293cb16f690d06a4b2fb4b"
Unverified Commit ce223b36 authored by Muhammed Fatih BALIN, committed by GitHub

[GraphBolt][CUDA] Add `type_per_edge` reading support and ensure same types are consecutive (#6865)


Co-authored-by: Rhett Ying <85214957+Rhett-Ying@users.noreply.github.com>
parent 1e34f664
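
Background for the change: in GraphBolt's fused CSC representation of a heterogeneous graph, each vertex's edges are stored grouped by edge type, so type_per_edge is nondecreasing within every vertex's neighborhood. Sorting the edge IDs sampled for each vertex therefore keeps edges of the same type consecutive in the output, which is what per-edge-type slicing downstream relies on. A minimal CPU-side sketch of that invariant (the data and names below are illustrative, not from the commit):

#include <algorithm>
#include <cstdint>
#include <vector>

int main() {
  // Two vertices; types are sorted inside each [indptr[v], indptr[v + 1]).
  std::vector<int64_t> indptr = {0, 4, 7};
  std::vector<int64_t> type_per_edge = {0, 0, 1, 1, 0, 1, 1};
  // Unsorted per-vertex picks: vertex 0 -> {2, 0}, vertex 1 -> {6, 4}.
  std::vector<int64_t> picked = {2, 0, 6, 4};
  // Sort each vertex's segment of picked edge IDs; the commit does this
  // on the GPU with cub::DeviceSegmentedSort.
  std::sort(picked.begin(), picked.begin() + 2);
  std::sort(picked.begin() + 2, picked.end());
  std::vector<int64_t> out_types;
  for (auto eid : picked) out_types.push_back(type_per_edge[eid]);
  // out_types == {0, 1, 0, 1}: within each vertex's segment, equal types
  // are now consecutive, so per-type ranges can be found by binary search.
  return 0;
}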
@@ -114,6 +114,15 @@ struct IteratorFuncAddOffset {
   }
 };
 
+template <typename indptr_t, typename in_degree_iterator_t>
+struct SegmentEndFunc {
+  indptr_t* indptr;
+  in_degree_iterator_t in_degree;
+  __host__ __device__ auto operator()(int64_t i) {
+    return indptr[i] + in_degree[i];
+  }
+};
+
 c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
     torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes,
     const std::vector<int64_t>& fanouts, bool replace, bool layer,
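
The new SegmentEndFunc exists to feed cub::DeviceSegmentedSort, which takes a begin-offsets iterator and an end-offsets iterator. Wrapped in thrust::make_transform_iterator over a counting iterator, it yields indptr[i] + in_degree[i] lazily, so no end-offsets array has to be materialized in device memory. A small sketch of the idiom (illustrative names and data; host-side evaluation shown only to make it self-contained):

#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>

#include <cstdint>
#include <cstdio>

// Same shape as the diff's SegmentEndFunc: segment i spans
// [indptr[i], indptr[i] + in_degree[i]).
struct SegmentEnd {
  const int64_t* indptr;
  const int64_t* in_degree;
  __host__ __device__ int64_t operator()(int64_t i) const {
    return indptr[i] + in_degree[i];
  }
};

int main() {
  int64_t indptr[] = {0, 5, 9};
  int64_t degree[] = {3, 2, 4};  // sampled degrees, <= full degrees
  auto iota = thrust::make_counting_iterator(int64_t{0});
  auto seg_end =
      thrust::make_transform_iterator(iota, SegmentEnd{indptr, degree});
  for (int64_t i = 0; i < 3; ++i)  // prints 3, 7, 13; nothing is allocated
    std::printf("segment %lld ends at %lld\n", (long long)i,
                (long long)seg_end[i]);
  return 0;
}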
@@ -155,6 +164,7 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
       static_cast<int64_t>(0), std::numeric_limits<int64_t>::max());
   torch::Tensor picked_eids;
   torch::Tensor output_indices;
+  torch::optional<torch::Tensor> output_type_per_edge;
   AT_DISPATCH_INTEGRAL_TYPES(
       indptr.scalar_type(), "SampleNeighborsIndptr", ([&] {
@@ -261,6 +271,30 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
             static_cast<indptr_t>(num_sampled_edges),
             nodes.options().dtype(indptr.scalar_type()));
 
+        // Need to sort the sampled edges only when fanouts.size() == 1,
+        // since the multiple fanout sampling case is automatically going
+        // to be sorted.
+        if (type_per_edge && fanouts.size() == 1) {
+          // Ensure the sort result still ends up in sorted_edge_id_segments.
+          std::swap(edge_id_segments, sorted_edge_id_segments);
+          auto sampled_segment_end_it = thrust::make_transform_iterator(
+              iota, SegmentEndFunc<indptr_t, decltype(sampled_degree)>{
+                        sub_indptr.data_ptr<indptr_t>(), sampled_degree});
+          size_t tmp_storage_size = 0;
+          CUDA_CALL(cub::DeviceSegmentedSort::SortKeys(
+              nullptr, tmp_storage_size, edge_id_segments.get(),
+              sorted_edge_id_segments.get(), picked_eids.size(0),
+              num_rows, sub_indptr.data_ptr<indptr_t>(),
+              sampled_segment_end_it, stream));
+          auto tmp_storage =
+              allocator.AllocateStorage<char>(tmp_storage_size);
+          CUDA_CALL(cub::DeviceSegmentedSort::SortKeys(
+              tmp_storage.get(), tmp_storage_size, edge_id_segments.get(),
+              sorted_edge_id_segments.get(), picked_eids.size(0),
+              num_rows, sub_indptr.data_ptr<indptr_t>(),
+              sampled_segment_end_it, stream));
+        }
+
         auto input_buffer_it = thrust::make_transform_iterator(
             iota, IteratorFunc<indptr_t, edge_id_t>{
                 sub_indptr.data_ptr<indptr_t>(),
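
A note on the two identical SortKeys calls above: CUB's Device-scope algorithms follow a two-pass convention. Called with a null temporary-storage pointer, they only write the required scratch size into tmp_storage_size; the caller then allocates that much and repeats the exact same call to run the sort. A self-contained sketch of the pattern, assuming raw device pointers and using cudaMallocAsync in place of GraphBolt's allocator (error checking omitted for brevity):

#include <cub/cub.cuh>
#include <cuda_runtime.h>

// Sort keys within [d_begin[i], d_end[i]) independently for each segment.
void SegmentedSortSketch(
    const int64_t* d_keys_in, int64_t* d_keys_out, int num_items,
    int num_segments, const int64_t* d_begin, const int64_t* d_end,
    cudaStream_t stream) {
  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  // Pass 1: d_temp == nullptr, so only temp_bytes is computed.
  cub::DeviceSegmentedSort::SortKeys(
      d_temp, temp_bytes, d_keys_in, d_keys_out, num_items, num_segments,
      d_begin, d_end, stream);
  cudaMallocAsync(&d_temp, temp_bytes, stream);
  // Pass 2: identical arguments, real scratch space; the sort runs now.
  cub::DeviceSegmentedSort::SortKeys(
      d_temp, temp_bytes, d_keys_in, d_keys_out, num_items, num_segments,
      d_begin, d_end, stream);
  cudaFreeAsync(d_temp, stream);
}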
@@ -305,6 +339,27 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
                   indices.data_ptr<indices_t>(),
                   output_indices.data_ptr<indices_t>());
             }));
+
+        if (type_per_edge) {
+          // output_type_per_edge = type_per_edge.gather(0, picked_eids);
+          // The commented-out torch equivalent above does not work when
+          // type_per_edge is on pinned memory. That is why we have to
+          // reimplement it, similar to the indices gather operation above.
+          auto types = type_per_edge.value();
+          output_type_per_edge = torch::empty(
+              picked_eids.size(0),
+              picked_eids.options().dtype(types.scalar_type()));
+          AT_DISPATCH_INTEGRAL_TYPES(
+              types.scalar_type(), "SampleNeighborsOutputTypePerEdge", ([&] {
+                const auto exec_policy =
+                    thrust::cuda::par_nosync(allocator).on(stream);
+                thrust::gather(
+                    exec_policy, picked_eids.data_ptr<indptr_t>(),
+                    picked_eids.data_ptr<indptr_t>() + picked_eids.size(0),
+                    types.data_ptr<scalar_t>(),
+                    output_type_per_edge.value().data_ptr<scalar_t>());
+              }));
+        }
       }));
 
   torch::optional<torch::Tensor> subgraph_reverse_edge_ids = torch::nullopt;
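
On the hand-rolled gather above: as the diff's comment notes, the torch one-liner cannot be used when type_per_edge resides in pinned host memory, so the commit reproduces it with Thrust. thrust::gather writes output[i] = input[map[i]] for every i in the map range; a minimal illustration of those semantics with toy device vectors (names are illustrative):

#include <thrust/device_vector.h>
#include <thrust/gather.h>

#include <cstdint>
#include <vector>

int main() {
  // map plays the role of picked_eids, input that of type_per_edge.
  std::vector<int64_t> h_map = {4, 0, 2};
  std::vector<int64_t> h_input = {10, 11, 12, 13, 14};
  thrust::device_vector<int64_t> map(h_map.begin(), h_map.end());
  thrust::device_vector<int64_t> input(h_input.begin(), h_input.end());
  thrust::device_vector<int64_t> output(map.size());
  // output[i] = input[map[i]]  ->  {14, 10, 12}
  thrust::gather(map.begin(), map.end(), input.begin(), output.begin());
  return 0;
}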
@@ -312,7 +367,7 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
 
   return c10::make_intrusive<sampling::FusedSampledSubgraph>(
       output_indptr, output_indices, nodes, torch::nullopt,
-      subgraph_reverse_edge_ids, torch::nullopt);
+      subgraph_reverse_edge_ids, output_type_per_edge);
 }
 
 }  // namespace ops
@@ -148,10 +148,6 @@ def get_hetero_graph():
     )
 
 
-@unittest.skipIf(
-    F._default_context_str != "cpu",
-    reason="Heterogenous sampling not yet supported on GPU.",
-)
 @pytest.mark.parametrize("labor", [False, True])
 def test_SubgraphSampler_Node_Hetero(labor):
     graph = get_hetero_graph().to(F.ctx())
@@ -168,10 +164,6 @@ def test_SubgraphSampler_Node_Hetero(labor):
         assert len(minibatch.sampled_subgraphs) == num_layer
 
 
-@unittest.skipIf(
-    F._default_context_str != "cpu",
-    reason="Heterogenous sampling not yet supported on GPU.",
-)
 @pytest.mark.parametrize("labor", [False, True])
 def test_SubgraphSampler_Link_Hetero(labor):
     graph = get_hetero_graph().to(F.ctx())
@@ -197,10 +189,6 @@ def test_SubgraphSampler_Link_Hetero(labor):
     assert len(list(datapipe)) == 5
 
 
-@unittest.skipIf(
-    F._default_context_str != "cpu",
-    reason="Heterogenous sampling not yet supported on GPU.",
-)
 @pytest.mark.parametrize("labor", [False, True])
 def test_SubgraphSampler_Link_Hetero_With_Negative(labor):
     graph = get_hetero_graph().to(F.ctx())
@@ -346,10 +334,6 @@ def test_SubgraphSampler_without_dedpulication_Homo(labor):
     )
 
 
-@unittest.skipIf(
-    F._default_context_str != "cpu",
-    reason="Heterogenous sampling not yet supported on GPU.",
-)
 @pytest.mark.parametrize("labor", [False, True])
 def test_SubgraphSampler_without_dedpulication_Hetero(labor):
     graph = get_hetero_graph().to(F.ctx())
@@ -409,20 +393,20 @@ def test_SubgraphSampler_without_dedpulication_Hetero(labor):
         for ntype in ["n1", "n2"]:
             assert torch.equal(
                 sampled_subgraph.original_row_node_ids[ntype],
-                original_row_node_ids[step][ntype],
+                original_row_node_ids[step][ntype].to(F.ctx()),
             )
             assert torch.equal(
                 sampled_subgraph.original_column_node_ids[ntype],
-                original_column_node_ids[step][ntype],
+                original_column_node_ids[step][ntype].to(F.ctx()),
             )
         for etype in ["n1:e1:n2", "n2:e2:n1"]:
             assert torch.equal(
                 sampled_subgraph.sampled_csc[etype].indices,
-                csc_formats[step][etype].indices,
+                csc_formats[step][etype].indices.to(F.ctx()),
             )
             assert torch.equal(
                 sampled_subgraph.sampled_csc[etype].indptr,
-                csc_formats[step][etype].indptr,
+                csc_formats[step][etype].indptr.to(F.ctx()),
             )
@@ -486,10 +470,6 @@ def test_SubgraphSampler_unique_csc_format_Homo(labor):
     )
 
 
-@unittest.skipIf(
-    F._default_context_str != "cpu",
-    reason="Heterogenous sampling not yet supported on GPU.",
-)
 @pytest.mark.parametrize("labor", [False, True])
 def test_SubgraphSampler_unique_csc_format_Hetero(labor):
     graph = get_hetero_graph().to(F.ctx())
@@ -554,18 +534,18 @@ def test_SubgraphSampler_unique_csc_format_Hetero(labor):
         for ntype in ["n1", "n2"]:
             assert torch.equal(
                 sampled_subgraph.original_row_node_ids[ntype],
-                original_row_node_ids[step][ntype],
+                original_row_node_ids[step][ntype].to(F.ctx()),
            )
             assert torch.equal(
                 sampled_subgraph.original_column_node_ids[ntype],
-                original_column_node_ids[step][ntype],
+                original_column_node_ids[step][ntype].to(F.ctx()),
             )
         for etype in ["n1:e1:n2", "n2:e2:n1"]:
             assert torch.equal(
                 sampled_subgraph.sampled_csc[etype].indices,
-                csc_formats[step][etype].indices,
+                csc_formats[step][etype].indices.to(F.ctx()),
             )
             assert torch.equal(
                 sampled_subgraph.sampled_csc[etype].indptr,
-                csc_formats[step][etype].indptr,
+                csc_formats[step][etype].indptr.to(F.ctx()),
             )