Unverified Commit a272efed authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[GraphBolt] Implement labor dependent minibatching - python side. (#7208)

parent 93990a90
...@@ -45,6 +45,9 @@ namespace ops { ...@@ -45,6 +45,9 @@ namespace ops {
* @param probs_or_mask An optional tensor with (unnormalized) probabilities * @param probs_or_mask An optional tensor with (unnormalized) probabilities
* corresponding to each neighboring edge of a node. It must be * corresponding to each neighboring edge of a node. It must be
* a 1D tensor, with the number of elements equaling the total number of edges. * a 1D tensor, with the number of elements equaling the total number of edges.
* @param random_seed An optional tensor holding the random seed for the
* sampler. Only used when layer=True.
* @param seed2_contribution The contribution of the second random seed, a
* value in the range [0, 1). Only used when layer=True.
* *
* @return An intrusive pointer to a FusedSampledSubgraph object containing * @return An intrusive pointer to a FusedSampledSubgraph object containing
* the sampled graph's information. * the sampled graph's information.
...@@ -54,7 +57,9 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors( ...@@ -54,7 +57,9 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
torch::optional<torch::Tensor> nodes, const std::vector<int64_t>& fanouts, torch::optional<torch::Tensor> nodes, const std::vector<int64_t>& fanouts,
bool replace, bool layer, bool return_eids, bool replace, bool layer, bool return_eids,
torch::optional<torch::Tensor> type_per_edge = torch::nullopt, torch::optional<torch::Tensor> type_per_edge = torch::nullopt,
torch::optional<torch::Tensor> probs_or_mask = torch::nullopt); torch::optional<torch::Tensor> probs_or_mask = torch::nullopt,
torch::optional<torch::Tensor> random_seed = torch::nullopt,
float seed2_contribution = .0f);
/** /**
* @brief Return the subgraph induced on the inbound edges of the given nodes. * @brief Return the subgraph induced on the inbound edges of the given nodes.
......
...@@ -314,6 +314,9 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { ...@@ -314,6 +314,9 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
* probabilities corresponding to each neighboring edge of a node. It must be * probabilities corresponding to each neighboring edge of a node. It must be
* a 1D floating-point or boolean tensor, with the number of elements * a 1D floating-point or boolean tensor, with the number of elements
* equalling the total number of edges. * equalling the total number of edges.
* @param random_seed An optional tensor holding the random seed for the
* sampler. Only used when layer=True.
* @param seed2_contribution The contribution of the second random seed, a
* value in the range [0, 1). Only used when layer=True.
* *
* @return An intrusive pointer to a FusedSampledSubgraph object containing * @return An intrusive pointer to a FusedSampledSubgraph object containing
* the sampled graph's information. * the sampled graph's information.
...@@ -321,7 +324,9 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { ...@@ -321,7 +324,9 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
c10::intrusive_ptr<FusedSampledSubgraph> SampleNeighbors( c10::intrusive_ptr<FusedSampledSubgraph> SampleNeighbors(
torch::optional<torch::Tensor> nodes, const std::vector<int64_t>& fanouts, torch::optional<torch::Tensor> nodes, const std::vector<int64_t>& fanouts,
bool replace, bool layer, bool return_eids, bool replace, bool layer, bool return_eids,
torch::optional<std::string> probs_name) const; torch::optional<std::string> probs_name,
torch::optional<torch::Tensor> random_seed,
double seed2_contribution) const;
/** /**
* @brief Sample neighboring edges of the given nodes with a temporal * @brief Sample neighboring edges of the given nodes with a temporal
......
...@@ -125,7 +125,9 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors( ...@@ -125,7 +125,9 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
torch::optional<torch::Tensor> nodes, const std::vector<int64_t>& fanouts, torch::optional<torch::Tensor> nodes, const std::vector<int64_t>& fanouts,
bool replace, bool layer, bool return_eids, bool replace, bool layer, bool return_eids,
torch::optional<torch::Tensor> type_per_edge, torch::optional<torch::Tensor> type_per_edge,
torch::optional<torch::Tensor> probs_or_mask) { torch::optional<torch::Tensor> probs_or_mask,
torch::optional<torch::Tensor> random_seed_tensor,
float seed2_contribution) {
TORCH_CHECK(!replace, "Sampling with replacement is not supported yet!"); TORCH_CHECK(!replace, "Sampling with replacement is not supported yet!");
// Assume that indptr, indices, nodes, type_per_edge and probs_or_mask // Assume that indptr, indices, nodes, type_per_edge and probs_or_mask
// are all resident on the GPU. If not, it is better to first extract them // are all resident on the GPU. If not, it is better to first extract them
...@@ -202,8 +204,14 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors( ...@@ -202,8 +204,14 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
auto coo_rows = ExpandIndptrImpl( auto coo_rows = ExpandIndptrImpl(
sub_indptr, indices.scalar_type(), torch::nullopt, num_edges); sub_indptr, indices.scalar_type(), torch::nullopt, num_edges);
num_edges = coo_rows.size(0); num_edges = coo_rows.size(0);
const continuous_seed random_seed(RandomEngine::ThreadLocal()->RandInt( const continuous_seed random_seed = [&] {
static_cast<int64_t>(0), std::numeric_limits<int64_t>::max())); if (random_seed_tensor.has_value()) {
return continuous_seed(random_seed_tensor.value(), seed2_contribution);
} else {
return continuous_seed{RandomEngine::ThreadLocal()->RandInt(
static_cast<int64_t>(0), std::numeric_limits<int64_t>::max())};
}
}();
auto output_indptr = torch::empty_like(sub_indptr); auto output_indptr = torch::empty_like(sub_indptr);
torch::Tensor picked_eids; torch::Tensor picked_eids;
torch::Tensor output_indices; torch::Tensor output_indices;
......
...@@ -618,7 +618,9 @@ FusedCSCSamplingGraph::SampleNeighborsImpl( ...@@ -618,7 +618,9 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors( c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
torch::optional<torch::Tensor> nodes, const std::vector<int64_t>& fanouts, torch::optional<torch::Tensor> nodes, const std::vector<int64_t>& fanouts,
bool replace, bool layer, bool return_eids, bool replace, bool layer, bool return_eids,
torch::optional<std::string> probs_name) const { torch::optional<std::string> probs_name,
torch::optional<torch::Tensor> random_seed,
double seed2_contribution) const {
auto probs_or_mask = this->EdgeAttribute(probs_name); auto probs_or_mask = this->EdgeAttribute(probs_name);
// If nodes does not have a value, then we expect all arguments to be resident // If nodes does not have a value, then we expect all arguments to be resident
...@@ -642,7 +644,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors( ...@@ -642,7 +644,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
c10::DeviceType::CUDA, "SampleNeighbors", { c10::DeviceType::CUDA, "SampleNeighbors", {
return ops::SampleNeighbors( return ops::SampleNeighbors(
indptr_, indices_, nodes, fanouts, replace, layer, return_eids, indptr_, indices_, nodes, fanouts, replace, layer, return_eids,
type_per_edge_, probs_or_mask); type_per_edge_, probs_or_mask, random_seed, seed2_contribution);
}); });
} }
TORCH_CHECK(nodes.has_value(), "Nodes can not be None on the CPU."); TORCH_CHECK(nodes.has_value(), "Nodes can not be None on the CPU.");
...@@ -658,9 +660,20 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors( ...@@ -658,9 +660,20 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
} }
if (layer) { if (layer) {
const int64_t random_seed = RandomEngine::ThreadLocal()->RandInt( SamplerArgs<SamplerType::LABOR> args = [&] {
static_cast<int64_t>(0), std::numeric_limits<int64_t>::max()); if (random_seed.has_value()) {
SamplerArgs<SamplerType::LABOR> args{indices_, random_seed, NumNodes()}; return SamplerArgs<SamplerType::LABOR>{
indices_,
{random_seed.value(), static_cast<float>(seed2_contribution)},
NumNodes()};
} else {
return SamplerArgs<SamplerType::LABOR>{
indices_,
RandomEngine::ThreadLocal()->RandInt(
static_cast<int64_t>(0), std::numeric_limits<int64_t>::max()),
NumNodes()};
}
}();
return SampleNeighborsImpl( return SampleNeighborsImpl(
nodes.value(), return_eids, nodes.value(), return_eids,
GetNumPickFn(fanouts, replace, type_per_edge_, probs_or_mask), GetNumPickFn(fanouts, replace, type_per_edge_, probs_or_mask),
......
...@@ -735,9 +735,11 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -735,9 +735,11 @@ class FusedCSCSamplingGraph(SamplingGraph):
nodes, nodes,
fanouts.tolist(), fanouts.tolist(),
replace, replace,
False, False, # is_labor
return_eids, return_eids,
probs_name, probs_name,
None, # random_seed, labor parameter
0, # seed2_contribution, labor_parameter
) )
def sample_layer_neighbors( def sample_layer_neighbors(
...@@ -746,6 +748,8 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -746,6 +748,8 @@ class FusedCSCSamplingGraph(SamplingGraph):
fanouts: torch.Tensor, fanouts: torch.Tensor,
replace: bool = False, replace: bool = False,
probs_name: Optional[str] = None, probs_name: Optional[str] = None,
random_seed: torch.Tensor = None,
seed2_contribution: float = 0.0,
) -> SampledSubgraphImpl: ) -> SampledSubgraphImpl:
"""Sample neighboring edges of the given nodes and return the induced """Sample neighboring edges of the given nodes and return the induced
subgraph via layer-neighbor sampling from the NeurIPS 2023 paper subgraph via layer-neighbor sampling from the NeurIPS 2023 paper
...@@ -833,6 +837,8 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -833,6 +837,8 @@ class FusedCSCSamplingGraph(SamplingGraph):
True, True,
has_original_eids, has_original_eids,
probs_name, probs_name,
random_seed,
seed2_contribution,
) )
return self._convert_to_sampled_subgraph(C_sampled_subgraph) return self._convert_to_sampled_subgraph(C_sampled_subgraph)
......
...@@ -146,12 +146,17 @@ class SamplePerLayerFromFetchedSubgraph(MiniBatchTransformer): ...@@ -146,12 +146,17 @@ class SamplePerLayerFromFetchedSubgraph(MiniBatchTransformer):
def _sample_per_layer_from_fetched_subgraph(self, minibatch): def _sample_per_layer_from_fetched_subgraph(self, minibatch):
subgraph = minibatch.sampled_subgraphs[0] subgraph = minibatch.sampled_subgraphs[0]
kwargs = {
key[1:]: getattr(minibatch, key)
for key in ["_random_seed", "_seed2_contribution"]
if hasattr(minibatch, key)
}
sampled_subgraph = getattr(subgraph, self.sampler_name)( sampled_subgraph = getattr(subgraph, self.sampler_name)(
minibatch._subgraph_seed_nodes, minibatch._subgraph_seed_nodes,
self.fanout, self.fanout,
self.replace, self.replace,
self.prob_name, self.prob_name,
**kwargs,
) )
delattr(minibatch, "_subgraph_seed_nodes") delattr(minibatch, "_subgraph_seed_nodes")
sampled_subgraph.original_column_node_ids = minibatch._seed_nodes sampled_subgraph.original_column_node_ids = minibatch._seed_nodes
...@@ -172,8 +177,17 @@ class SamplePerLayer(MiniBatchTransformer): ...@@ -172,8 +177,17 @@ class SamplePerLayer(MiniBatchTransformer):
self.prob_name = prob_name self.prob_name = prob_name
def _sample_per_layer(self, minibatch): def _sample_per_layer(self, minibatch):
kwargs = {
key[1:]: getattr(minibatch, key)
for key in ["_random_seed", "_seed2_contribution"]
if hasattr(minibatch, key)
}
subgraph = self.sampler( subgraph = self.sampler(
minibatch._seed_nodes, self.fanout, self.replace, self.prob_name minibatch._seed_nodes,
self.fanout,
self.replace,
self.prob_name,
**kwargs,
) )
minibatch.sampled_subgraphs.insert(0, subgraph) minibatch.sampled_subgraphs.insert(0, subgraph)
return minibatch return minibatch
...@@ -244,11 +258,57 @@ class NeighborSamplerImpl(SubgraphSampler): ...@@ -244,11 +258,57 @@ class NeighborSamplerImpl(SubgraphSampler):
prob_name, prob_name,
deduplicate, deduplicate,
sampler, sampler,
layer_dependency=None,
batch_dependency=None,
): ):
if sampler.__name__ == "sample_layer_neighbors":
self._init_seed(batch_dependency)
super().__init__( super().__init__(
datapipe, graph, fanouts, replace, prob_name, deduplicate, sampler datapipe,
graph,
fanouts,
replace,
prob_name,
deduplicate,
sampler,
layer_dependency,
) )
def _init_seed(self, batch_dependency):
self.rng = torch.random.manual_seed(
torch.randint(0, int(1e18), size=tuple())
)
self.cnt = [-1, int(batch_dependency)]
self.random_seed = torch.empty(
2 if self.cnt[1] > 1 else 1, dtype=torch.int64
)
self.random_seed.random_(generator=self.rng)
def _set_seed(self, minibatch):
self.cnt[0] += 1
if self.cnt[1] > 0 and self.cnt[0] % self.cnt[1] == 0:
self.random_seed[0] = self.random_seed[-1]
self.random_seed[-1:].random_(generator=self.rng)
minibatch._random_seed = self.random_seed.clone()
minibatch._seed2_contribution = (
0.0
if self.cnt[1] <= 1
else (self.cnt[0] % self.cnt[1]) / self.cnt[1]
)
minibatch._iter = self.cnt[0]
return minibatch
@staticmethod
def _increment_seed(minibatch):
minibatch._random_seed = 1 + minibatch._random_seed
return minibatch
@staticmethod
def _delattr_dependency(minibatch):
delattr(minibatch, "_random_seed")
delattr(minibatch, "_seed2_contribution")
return minibatch
@staticmethod @staticmethod
def _prepare(node_type_to_id, minibatch): def _prepare(node_type_to_id, minibatch):
seeds = minibatch._seed_nodes seeds = minibatch._seed_nodes
...@@ -277,11 +337,22 @@ class NeighborSamplerImpl(SubgraphSampler): ...@@ -277,11 +337,22 @@ class NeighborSamplerImpl(SubgraphSampler):
# pylint: disable=arguments-differ # pylint: disable=arguments-differ
def sampling_stages( def sampling_stages(
self, datapipe, graph, fanouts, replace, prob_name, deduplicate, sampler self,
datapipe,
graph,
fanouts,
replace,
prob_name,
deduplicate,
sampler,
layer_dependency,
): ):
datapipe = datapipe.transform( datapipe = datapipe.transform(
partial(self._prepare, graph.node_type_to_id) partial(self._prepare, graph.node_type_to_id)
) )
is_labor = sampler.__name__ == "sample_layer_neighbors"
if is_labor:
datapipe = datapipe.transform(self._set_seed)
for fanout in reversed(fanouts): for fanout in reversed(fanouts):
# Convert fanout to tensor. # Convert fanout to tensor.
if not isinstance(fanout, torch.Tensor): if not isinstance(fanout, torch.Tensor):
...@@ -290,7 +361,10 @@ class NeighborSamplerImpl(SubgraphSampler): ...@@ -290,7 +361,10 @@ class NeighborSamplerImpl(SubgraphSampler):
sampler, fanout, replace, prob_name sampler, fanout, replace, prob_name
) )
datapipe = datapipe.compact_per_layer(deduplicate) datapipe = datapipe.compact_per_layer(deduplicate)
if is_labor and not layer_dependency:
datapipe = datapipe.transform(self._increment_seed)
if is_labor:
datapipe = datapipe.transform(self._delattr_dependency)
return datapipe.transform(self._set_input_nodes) return datapipe.transform(self._set_input_nodes)
...@@ -504,6 +578,8 @@ class LayerNeighborSampler(NeighborSamplerImpl): ...@@ -504,6 +578,8 @@ class LayerNeighborSampler(NeighborSamplerImpl):
replace=False, replace=False,
prob_name=None, prob_name=None,
deduplicate=True, deduplicate=True,
layer_dependency=False,
batch_dependency=1,
): ):
super().__init__( super().__init__(
datapipe, datapipe,
...@@ -513,4 +589,6 @@ class LayerNeighborSampler(NeighborSamplerImpl): ...@@ -513,4 +589,6 @@ class LayerNeighborSampler(NeighborSamplerImpl):
prob_name, prob_name,
deduplicate, deduplicate,
graph.sample_layer_neighbors, graph.sample_layer_neighbors,
layer_dependency,
batch_dependency,
) )
...@@ -75,3 +75,59 @@ def test_NeighborSampler_GraphFetch(hetero, prob_name, sorted): ...@@ -75,3 +75,59 @@ def test_NeighborSampler_GraphFetch(hetero, prob_name, sorted):
assert len(expected_results) == len(new_results) assert len(expected_results) == len(new_results)
for a, b in zip(expected_results, new_results): for a, b in zip(expected_results, new_results):
assert repr(a) == repr(b) assert repr(a) == repr(b)
# Checks layer-neighbor sampling with layer_dependency / batch_dependency
# enabled, both with and without overlap_graph_fetch in the DataLoader.
@pytest.mark.parametrize("layer_dependency", [False, True])
@pytest.mark.parametrize("overlap_graph_fetch", [False, True])
def test_labor_dependent_minibatching(layer_dependency, overlap_graph_fetch):
num_edges = 200
# indptr is [0, 200, 200, ..., 200]; presumably (standard CSC layout) node 0
# owns all 200 edges and every other node has none.
csc_indptr = torch.cat(
(
torch.zeros(1, dtype=torch.int64),
torch.ones(num_edges + 1, dtype=torch.int64) * num_edges,
)
)
# The 200 edges point at nodes 1..200.
indices = torch.arange(1, num_edges + 1)
graph = gb.fused_csc_sampling_graph(
csc_indptr.int(),
indices.int(),
).to(F.ctx())
# Fix torch's global RNG state so the run is reproducible.
torch.random.set_rng_state(torch.manual_seed(123).get_state())
batch_dependency = 100
# batch_dependency + 1 seed items, all equal to node 0, one per minibatch.
itemset = gb.ItemSet(
torch.zeros(batch_dependency + 1).int(), names="seed_nodes"
)
datapipe = gb.ItemSampler(itemset, batch_size=1).copy_to(F.ctx())
fanouts = [5, 5]
datapipe = datapipe.sample_layer_neighbor(
graph,
fanouts,
layer_dependency=layer_dependency,
batch_dependency=batch_dependency,
)
dataloader = gb.DataLoader(
datapipe, overlap_graph_fetch=overlap_graph_fetch
)
res = list(dataloader)
assert len(res) == batch_dependency + 1
# With layer_dependency the first minibatch's input nodes must equal the row
# node ids of sampled_subgraphs[1]; without it, input_nodes must be strictly
# larger.
if layer_dependency:
assert torch.equal(
res[0].input_nodes,
res[0].sampled_subgraphs[1].original_row_node_ids,
)
else:
assert res[0].input_nodes.size(0) > res[0].sampled_subgraphs[
1
].original_row_node_ids.size(0)
# Compare the last sampled subgraph of each pair of consecutive minibatches.
delta = 0
for i in range(batch_dependency):
res_current = (
res[i].sampled_subgraphs[-1].original_row_node_ids.tolist()
)
res_next = (
res[i + 1].sampled_subgraphs[-1].original_row_node_ids.tolist()
)
intersect_len = len(set(res_current).intersection(set(res_next)))
# Consecutive minibatches must share at least fanouts[-1] row node ids.
assert intersect_len >= fanouts[-1]
# Accumulate how far each overlap falls short of 1 + fanouts[-1].
delta += 1 + fanouts[-1] - intersect_len
# The accumulated shortfall must reach at least fanouts[-1].
assert delta >= fanouts[-1]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment