"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "7bc2fff1a552ea16de1bdfccdf5d865613f6a63f"
Unverified Commit 45da2b23 authored by Ramon Zhou's avatar Ramon Zhou Committed by GitHub
Browse files

[Graphbolt] Add num_pick_fn into neighbor sampling (#6101)

parent 6b047e4d
......@@ -222,9 +222,10 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
const std::string& shared_memory_name);
private:
template <typename PickFn>
template <typename NumPickFn, typename PickFn>
c10::intrusive_ptr<SampledSubgraph> SampleNeighborsImpl(
const torch::Tensor& nodes, bool return_eids, PickFn pick_fn) const;
const torch::Tensor& nodes, bool return_eids, NumPickFn num_pick_fn,
PickFn pick_fn) const;
/**
* @brief Build a CSCSamplingGraph from shared memory tensors.
......@@ -286,6 +287,41 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
SharedMemoryPtr tensor_meta_shm_, tensor_data_shm_;
};
/**
* @brief Calculate the number of the neighbors to be picked for the given node.
*
* @param fanout The number of edges to be sampled for each node. It should be
* >= 0 or -1.
* - When the value is -1, all neighbors (with non-zero probability, if
* weighted) will be chosen for sampling. It is equivalent to selecting all
* neighbors with non-zero probability when the fanout is >= the number of
* neighbors (and replacement is set to false).
* - When the value is a non-negative integer, it serves as a minimum
* threshold for selecting neighbors.
* @param replace Boolean indicating whether the sample is performed with or
* without replacement. If True, a value can be selected multiple times.
* Otherwise, each value can be selected only once.
* @param probs_or_mask Optional tensor containing the (unnormalized)
* probabilities associated with each neighboring edge of a node in the original
* graph. It must be a 1D floating-point tensor with the number of elements
* equal to the number of edges in the graph.
* @param offset The starting edge ID for the connected neighbors of the given
* node.
* @param num_neighbors The number of neighbors of this node.
*
* @return The pick number of the given node.
*/
int64_t NumPick(
int64_t fanout, bool replace,
const torch::optional<torch::Tensor>& probs_or_mask, int64_t offset,
int64_t num_neighbors);
int64_t NumPickByEtype(
const std::vector<int64_t>& fanouts, bool replace,
const torch::Tensor& type_per_edge,
const torch::optional<torch::Tensor>& probs_or_mask, int64_t offset,
int64_t num_neighbors);
/**
* @brief Picks a specified number of neighbors for a node, starting from the
* given offset and having the specified number of neighbors.
......
......@@ -131,6 +131,45 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::InSubgraph(
: torch::nullopt);
}
/**
* @brief Get a lambda function which counts the number of the neighbors to be
* sampled.
*
* @param fanouts The number of edges to be sampled for each node with or
* without considering edge types.
* @param replace Boolean indicating whether the sample is performed with or
* without replacement. If True, a value can be selected multiple times.
* Otherwise, each value can be selected only once.
* @param type_per_edge A tensor representing the type of each edge, if
* present.
* @param probs_or_mask Optional tensor containing the (unnormalized)
* probabilities associated with each neighboring edge of a node in the original
* graph. It must be a 1D floating-point tensor with the number of elements
* equal to the number of edges in the graph.
*
* @return A lambda function (int64_t offset, int64_t num_neighbors) ->
* torch::Tensor, which takes offset (the starting edge ID of the given node)
* and num_neighbors (number of neighbors) as params and returns the pick number
* of the given node.
*/
auto GetNumPickFn(
const std::vector<int64_t>& fanouts, bool replace,
const torch::optional<torch::Tensor>& type_per_edge,
const torch::optional<torch::Tensor>& probs_or_mask) {
// If fanouts.size() > 1, returns the total number of all edge types of the
// given node.
return [&fanouts, replace, &probs_or_mask, &type_per_edge](
int64_t offset, int64_t num_neighbors) {
if (fanouts.size() > 1) {
return NumPickByEtype(
fanouts, replace, type_per_edge.value(), probs_or_mask, offset,
num_neighbors);
} else {
return NumPick(fanouts[0], replace, probs_or_mask, offset, num_neighbors);
}
};
}
/**
* @brief Get a lambda function which contains the sampling process.
*
......@@ -149,8 +188,9 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::InSubgraph(
* @param args Contains sampling algorithm specific arguments.
*
* @return A lambda function: (int64_t offset, int64_t num_neighbors) ->
* torch::Tensor, which takes offset and num_neighbors as params and returns a
* tensor of picked neighbors.
* torch::Tensor, which takes offset (the starting edge ID of the given node)
* and num_neighbors (number of neighbors) as params and returns a tensor of
* picked neighbors.
*/
template <SamplerType S>
auto GetPickFn(
......@@ -174,9 +214,10 @@ auto GetPickFn(
};
}
template <typename PickFn>
template <typename NumPickFn, typename PickFn>
c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighborsImpl(
const torch::Tensor& nodes, bool return_eids, PickFn pick_fn) const {
const torch::Tensor& nodes, bool return_eids, NumPickFn num_pick_fn,
PickFn pick_fn) const {
const int64_t num_nodes = nodes.size(0);
const int64_t num_threads = torch::get_num_threads();
std::vector<torch::Tensor> picked_neighbors_per_thread(num_threads);
......@@ -198,8 +239,9 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighborsImpl(
std::vector<torch::Tensor> picked_neighbors_cur_thread(
local_grain_size);
const auto nodes_data_ptr = nodes.data_ptr<int64_t>();
for (scalar_t i = begin; i < end; ++i) {
const auto nid = nodes[i].item<int64_t>();
const auto nid = nodes_data_ptr[i];
TORCH_CHECK(
nid >= 0 && nid < NumNodes(),
"The seed nodes' IDs should fall within the range of the "
......@@ -221,6 +263,11 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighborsImpl(
// This number should be the same as the result of num_pick_fn.
num_picked_neighbors_per_node[i + 1] =
picked_neighbors_cur_thread[i - begin].size(0);
TORCH_CHECK(
*num_picked_neighbors_per_node[i + 1].data_ptr<int64_t>() ==
num_pick_fn(offset, num_neighbors),
"Return value of num_pick_fn doesn't match the actual "
"picked number.");
}
picked_neighbors_per_thread[thread_id] =
torch::cat(picked_neighbors_cur_thread);
......@@ -266,6 +313,7 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
SamplerArgs<SamplerType::LABOR> args{indices_, random_seed, NumNodes()};
return SampleNeighborsImpl(
nodes, return_eids,
GetNumPickFn(fanouts, replace, type_per_edge_, probs_or_mask),
GetPickFn(
fanouts, replace, indptr_.options(), type_per_edge_, probs_or_mask,
args));
......@@ -273,6 +321,7 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
SamplerArgs<SamplerType::NEIGHBOR> args;
return SampleNeighborsImpl(
nodes, return_eids,
GetNumPickFn(fanouts, replace, type_per_edge_, probs_or_mask),
GetPickFn(
fanouts, replace, indptr_.options(), type_per_edge_, probs_or_mask,
args));
......@@ -321,6 +370,50 @@ c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::LoadFromSharedMemory(
return BuildGraphFromSharedMemoryTensors(std::move(shared_memory_tensors));
}
int64_t NumPick(
int64_t fanout, bool replace,
const torch::optional<torch::Tensor>& probs_or_mask, int64_t offset,
int64_t num_neighbors) {
int64_t num_valid_neighbors =
probs_or_mask.has_value()
? *torch::count_nonzero(
probs_or_mask.value().slice(0, offset, offset + num_neighbors))
.data_ptr<int64_t>()
: num_neighbors;
if (num_valid_neighbors == 0 || fanout == -1) return num_valid_neighbors;
return replace ? fanout : std::min(fanout, num_valid_neighbors);
}
int64_t NumPickByEtype(
const std::vector<int64_t>& fanouts, bool replace,
const torch::Tensor& type_per_edge,
const torch::optional<torch::Tensor>& probs_or_mask, int64_t offset,
int64_t num_neighbors) {
int64_t etype_begin = offset;
const int64_t end = offset + num_neighbors;
int64_t total_count = 0;
AT_DISPATCH_INTEGRAL_TYPES(
type_per_edge.scalar_type(), "NumPickFnByEtype", ([&] {
const scalar_t* type_per_edge_data = type_per_edge.data_ptr<scalar_t>();
while (etype_begin < end) {
scalar_t etype = type_per_edge_data[etype_begin];
TORCH_CHECK(
etype >= 0 && etype < (int64_t)fanouts.size(),
"Etype values exceed the number of fanouts.");
auto etype_end_it = std::upper_bound(
type_per_edge_data + etype_begin, type_per_edge_data + end,
etype);
int64_t etype_end = etype_end_it - type_per_edge_data;
// Do sampling for one etype.
total_count += NumPick(
fanouts[etype], replace, probs_or_mask, etype_begin,
etype_end - etype_begin);
etype_begin = etype_end;
}
}));
return total_count;
}
/**
* @brief Perform uniform sampling of elements and return the sampled indices.
*
......
......@@ -861,3 +861,192 @@ def test_from_dglgraph_heterogeneous():
("n2", "r21", "n1"): 2,
("n2", "r23", "n3"): 3,
}
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("replace", [False, True])
@pytest.mark.parametrize("labor", [False, True])
@pytest.mark.parametrize(
"fanouts, probs_name",
[
([2], "mask"),
([3], "mask"),
([4], "mask"),
([-1], "mask"),
([7], "mask"),
([3], "all"),
([-1], "all"),
([7], "all"),
([3], "zero"),
([-1], "zero"),
([3], "none"),
([-1], "none"),
],
)
def test_sample_neighbors_homo_pick_number(fanouts, replace, labor, probs_name):
"""Original graph in COO:
1 1 1 1 1 1
0 0 0 0 0 0
0 0 0 0 0 0
0 0 0 0 0 0
0 0 0 0 0 0
0 0 0 0 0 0
"""
# Initialize data.
num_nodes = 6
num_edges = 6
indptr = torch.LongTensor([0, 6, 6, 6, 6, 6, 6])
indices = torch.LongTensor([0, 1, 2, 3, 4, 5])
assert indptr[-1] == num_edges
assert indptr[-1] == len(indices)
edge_attributes = {
"mask": torch.BoolTensor([1, 0, 0, 1, 0, 1]),
"all": torch.BoolTensor([1, 1, 1, 1, 1, 1]),
"zero": torch.BoolTensor([0, 0, 0, 0, 0, 0]),
}
# Construct CSCSamplingGraph.
graph = gb.from_csc(indptr, indices, edge_attributes=edge_attributes)
# Generate subgraph via sample neighbors.
nodes = torch.LongTensor([0, 1])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
# Make sure no exception will be thrown.
subgraph = sampler(
nodes,
fanouts=torch.LongTensor(fanouts),
replace=replace,
probs_name=probs_name if probs_name != "none" else None,
)
sampled_num = subgraph.node_pairs[0].size(0)
# Verify in subgraph.
if probs_name == "mask":
if fanouts[0] == -1:
assert sampled_num == 3
else:
if replace:
assert sampled_num == fanouts[0]
else:
assert sampled_num == min(fanouts[0], 3)
elif probs_name == "zero":
assert sampled_num == 0
else:
if fanouts[0] == -1:
assert sampled_num == 6
else:
if replace:
assert sampled_num == fanouts[0]
else:
assert sampled_num == min(fanouts[0], 6)
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("replace", [False, True])
@pytest.mark.parametrize("labor", [False, True])
@pytest.mark.parametrize(
"fanouts, probs_name",
[
([-1, -1, -1], "mask"),
([1, 1, 1], "mask"),
([2, 2, 2], "mask"),
([3, 3, 3], "mask"),
([4, 4, 4], "mask"),
([-1, 1, 3], "none"),
([2, -1, 4], "none"),
],
)
def test_sample_neighbors_hetero_pick_number(
fanouts, replace, labor, probs_name
):
# Initialize data.
num_nodes = 10
num_edges = 9
ntypes = {"N0": 0, "N1": 1, "N2": 2, "N3": 3}
etypes = {
("N0", "R0", "N1"): 0,
("N0", "R1", "N2"): 1,
("N0", "R2", "N3"): 2,
}
metadata = gb.GraphMetadata(ntypes, etypes)
indptr = torch.LongTensor([0, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9])
indices = torch.LongTensor([1, 2, 3, 4, 5, 6, 7, 8, 9])
node_type_offset = torch.LongTensor([0, 1, 4, 7, 10])
type_per_edge = torch.LongTensor([0, 0, 0, 1, 1, 1, 2, 2, 2])
assert indptr[-1] == num_edges
assert indptr[-1] == len(indices)
assert node_type_offset[-1] == num_nodes
assert all(type_per_edge < len(etypes))
edge_attributes = {
"mask": torch.BoolTensor([1, 1, 0, 1, 1, 1, 0, 0, 0]),
"all": torch.BoolTensor([1, 1, 1, 1, 1, 1, 1, 1, 1]),
"zero": torch.BoolTensor([0, 0, 0, 0, 0, 0, 0, 0, 0]),
}
# Construct CSCSamplingGraph.
graph = gb.from_csc(
indptr,
indices,
edge_attributes=edge_attributes,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
metadata=metadata,
)
# Generate subgraph via sample neighbors.
nodes = torch.LongTensor([0, 1])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
# Make sure no exception will be thrown.
subgraph = sampler(
nodes,
fanouts=torch.LongTensor(fanouts),
replace=replace,
probs_name=probs_name if probs_name != "none" else None,
)
if probs_name == "none":
for etype, pairs in subgraph.node_pairs.items():
fanout = fanouts[etypes[etype]]
if fanout == -1:
assert pairs[0].size(0) == 3
else:
if replace:
assert pairs[0].size(0) == fanout
else:
assert pairs[0].size(0) == min(fanout, 3)
else:
fanout = fanouts[0] # Here fanout is the same for all etypes.
for etype, pairs in subgraph.node_pairs.items():
if etypes[etype] == 0:
# Etype 0: 2 valid neighbors.
if fanout == -1:
assert pairs[0].size(0) == 2
else:
if replace:
assert pairs[0].size(0) == fanout
else:
assert pairs[0].size(0) == min(fanout, 2)
elif etypes[etype] == 1:
# Etype 1: 3 valid neighbors.
if fanout == -1:
assert pairs[0].size(0) == 3
else:
if replace:
assert pairs[0].size(0) == fanout
else:
assert pairs[0].size(0) == min(fanout, 3)
else:
# Etype 2: 0 valid neighbors.
assert pairs[0].size(0) == 0
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment