Unverified Commit a848aa3e authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Graphbolt] Reorder the parameters of UniformNegativeSampler. (#6322)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
parent b8886900
...@@ -124,7 +124,7 @@ def create_dataloader(args, graph, features, itemset, is_train=True): ...@@ -124,7 +124,7 @@ def create_dataloader(args, graph, features, itemset, is_train=True):
# negative edges information. # negative edges information.
############################################################################ ############################################################################
if is_train: if is_train:
datapipe = datapipe.sample_uniform_negative(args.neg_ratio, graph) datapipe = datapipe.sample_uniform_negative(graph, args.neg_ratio)
############################################################################ ############################################################################
# [Input]: # [Input]:
......
...@@ -20,8 +20,8 @@ class UniformNegativeSampler(NegativeSampler): ...@@ -20,8 +20,8 @@ class UniformNegativeSampler(NegativeSampler):
def __init__( def __init__(
self, self,
datapipe, datapipe,
negative_ratio,
graph, graph,
negative_ratio,
): ):
""" """
Initlization for a uniform negative sampler. Initlization for a uniform negative sampler.
...@@ -30,10 +30,10 @@ class UniformNegativeSampler(NegativeSampler): ...@@ -30,10 +30,10 @@ class UniformNegativeSampler(NegativeSampler):
---------- ----------
datapipe : DataPipe datapipe : DataPipe
The datapipe. The datapipe.
negative_ratio : int
The proportion of negative samples to positive samples.
graph : CSCSamplingGraph graph : CSCSamplingGraph
The graph on which to perform negative sampling. The graph on which to perform negative sampling.
negative_ratio : int
The proportion of negative samples to positive samples.
Examples Examples
-------- --------
...@@ -47,7 +47,7 @@ class UniformNegativeSampler(NegativeSampler): ...@@ -47,7 +47,7 @@ class UniformNegativeSampler(NegativeSampler):
...item_set, batch_size=1, ...item_set, batch_size=1,
...) ...)
>>> neg_sampler = gb.UniformNegativeSampler( >>> neg_sampler = gb.UniformNegativeSampler(
...item_sampler, 2, graph) ...item_sampler, graph, 2)
>>> for minibatch in neg_sampler: >>> for minibatch in neg_sampler:
... print(minibatch.negative_srcs) ... print(minibatch.negative_srcs)
... print(minibatch.negative_dsts) ... print(minibatch.negative_dsts)
......
...@@ -53,15 +53,15 @@ def test_UniformNegativeSampler_invoke(): ...@@ -53,15 +53,15 @@ def test_UniformNegativeSampler_invoke():
# Invoke UniformNegativeSampler via class constructor. # Invoke UniformNegativeSampler via class constructor.
negative_sampler = gb.UniformNegativeSampler( negative_sampler = gb.UniformNegativeSampler(
item_sampler, item_sampler,
negative_ratio,
graph, graph,
negative_ratio,
) )
_verify(negative_sampler) _verify(negative_sampler)
# Invoke UniformNegativeSampler via functional form. # Invoke UniformNegativeSampler via functional form.
negative_sampler = item_sampler.sample_uniform_negative( negative_sampler = item_sampler.sample_uniform_negative(
negative_ratio,
graph, graph,
negative_ratio,
) )
_verify(negative_sampler) _verify(negative_sampler)
...@@ -79,8 +79,8 @@ def test_Uniform_NegativeSampler(negative_ratio): ...@@ -79,8 +79,8 @@ def test_Uniform_NegativeSampler(negative_ratio):
# Construct NegativeSampler. # Construct NegativeSampler.
negative_sampler = gb.UniformNegativeSampler( negative_sampler = gb.UniformNegativeSampler(
item_sampler, item_sampler,
negative_ratio,
graph, graph,
negative_ratio,
) )
# Perform Negative sampling. # Perform Negative sampling.
for data in negative_sampler: for data in negative_sampler:
...@@ -135,5 +135,5 @@ def test_NegativeSampler_Hetero_Data(): ...@@ -135,5 +135,5 @@ def test_NegativeSampler_Hetero_Data():
) )
item_sampler = gb.ItemSampler(itemset, batch_size=2) item_sampler = gb.ItemSampler(itemset, batch_size=2)
negative_dp = gb.UniformNegativeSampler(item_sampler, 1, graph) negative_dp = gb.UniformNegativeSampler(item_sampler, graph, 1)
assert len(list(negative_dp)) == 5 assert len(list(negative_dp)) == 5
...@@ -101,7 +101,7 @@ def test_SubgraphSampler_Link_With_Negative(labor): ...@@ -101,7 +101,7 @@ def test_SubgraphSampler_Link_With_Negative(labor):
item_sampler = gb.ItemSampler(itemset, batch_size=2) item_sampler = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2 num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)] fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
negative_dp = gb.UniformNegativeSampler(item_sampler, 1, graph) negative_dp = gb.UniformNegativeSampler(item_sampler, graph, 1)
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
neighbor_dp = Sampler(negative_dp, graph, fanouts) neighbor_dp = Sampler(negative_dp, graph, fanouts)
assert len(list(neighbor_dp)) == 5 assert len(list(neighbor_dp)) == 5
...@@ -172,7 +172,7 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(labor): ...@@ -172,7 +172,7 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(labor):
item_sampler = gb.ItemSampler(itemset, batch_size=2) item_sampler = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2 num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)] fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
negative_dp = gb.UniformNegativeSampler(item_sampler, 1, graph) negative_dp = gb.UniformNegativeSampler(item_sampler, graph, 1)
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
neighbor_dp = Sampler(negative_dp, graph, fanouts) neighbor_dp = Sampler(negative_dp, graph, fanouts)
assert len(list(neighbor_dp)) == 5 assert len(list(neighbor_dp)) == 5
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment