Unverified Commit a848aa3e authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Graphbolt] Reorder the parameters of UniformNegativeSampler. (#6322)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
parent b8886900
......@@ -124,7 +124,7 @@ def create_dataloader(args, graph, features, itemset, is_train=True):
# negative edges information.
############################################################################
if is_train:
datapipe = datapipe.sample_uniform_negative(args.neg_ratio, graph)
datapipe = datapipe.sample_uniform_negative(graph, args.neg_ratio)
############################################################################
# [Input]:
......
......@@ -20,8 +20,8 @@ class UniformNegativeSampler(NegativeSampler):
def __init__(
self,
datapipe,
negative_ratio,
graph,
negative_ratio,
):
"""
Initlization for a uniform negative sampler.
......@@ -30,10 +30,10 @@ class UniformNegativeSampler(NegativeSampler):
----------
datapipe : DataPipe
The datapipe.
negative_ratio : int
The proportion of negative samples to positive samples.
graph : CSCSamplingGraph
The graph on which to perform negative sampling.
negative_ratio : int
The proportion of negative samples to positive samples.
Examples
--------
......@@ -47,7 +47,7 @@ class UniformNegativeSampler(NegativeSampler):
...item_set, batch_size=1,
...)
>>> neg_sampler = gb.UniformNegativeSampler(
...item_sampler, 2, graph)
...item_sampler, graph, 2)
>>> for minibatch in neg_sampler:
... print(minibatch.negative_srcs)
... print(minibatch.negative_dsts)
......
......@@ -53,15 +53,15 @@ def test_UniformNegativeSampler_invoke():
# Invoke UniformNegativeSampler via class constructor.
negative_sampler = gb.UniformNegativeSampler(
item_sampler,
negative_ratio,
graph,
negative_ratio,
)
_verify(negative_sampler)
# Invoke UniformNegativeSampler via functional form.
negative_sampler = item_sampler.sample_uniform_negative(
negative_ratio,
graph,
negative_ratio,
)
_verify(negative_sampler)
......@@ -79,8 +79,8 @@ def test_Uniform_NegativeSampler(negative_ratio):
# Construct NegativeSampler.
negative_sampler = gb.UniformNegativeSampler(
item_sampler,
negative_ratio,
graph,
negative_ratio,
)
# Perform Negative sampling.
for data in negative_sampler:
......@@ -135,5 +135,5 @@ def test_NegativeSampler_Hetero_Data():
)
item_sampler = gb.ItemSampler(itemset, batch_size=2)
negative_dp = gb.UniformNegativeSampler(item_sampler, 1, graph)
negative_dp = gb.UniformNegativeSampler(item_sampler, graph, 1)
assert len(list(negative_dp)) == 5
......@@ -101,7 +101,7 @@ def test_SubgraphSampler_Link_With_Negative(labor):
item_sampler = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
negative_dp = gb.UniformNegativeSampler(item_sampler, 1, graph)
negative_dp = gb.UniformNegativeSampler(item_sampler, graph, 1)
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
neighbor_dp = Sampler(negative_dp, graph, fanouts)
assert len(list(neighbor_dp)) == 5
......@@ -172,7 +172,7 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(labor):
item_sampler = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
negative_dp = gb.UniformNegativeSampler(item_sampler, 1, graph)
negative_dp = gb.UniformNegativeSampler(item_sampler, graph, 1)
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
neighbor_dp = Sampler(negative_dp, graph, fanouts)
assert len(list(neighbor_dp)) == 5
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment