Unverified Commit 1e5fb155 authored by Hongzhi (Steve), Chen; committed by GitHub
Browse files

[Misc] Update variable names. (#6071)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
parent 5da7d391
...@@ -15,7 +15,7 @@ from .impl import * ...@@ -15,7 +15,7 @@ from .impl import *
from .dataloader import * from .dataloader import *
from .subgraph_sampler import * from .subgraph_sampler import *
from .sampled_subgraph import * from .sampled_subgraph import *
from .link_data_format import * from .data_format import *
from .negative_sampler import * from .negative_sampler import *
from .utils import unique_and_compact_node_pairs from .utils import unique_and_compact_node_pairs
......
"""Linked data format.""" """Data format enums for graphbolt."""
from enum import Enum from enum import Enum
__all__ = ["LinkDataFormat"] __all__ = ["LinkPredictionEdgeFormat"]
class LinkDataFormat(Enum): class LinkPredictionEdgeFormat(Enum):
""" """
An Enum class representing the two data formats used in link prediction: An Enum class representing the formats of positive and negative edges used
in link prediction:
Attributes: Attributes:
CONDITIONED: Represents the 'conditioned' format where data is CONDITIONED: Represents the 'conditioned' format where data is
......
...@@ -18,7 +18,7 @@ class UniformNegativeSampler(NegativeSampler): ...@@ -18,7 +18,7 @@ class UniformNegativeSampler(NegativeSampler):
self, self,
datapipe, datapipe,
negative_ratio, negative_ratio,
link_data_format, output_format,
graph, graph,
): ):
""" """
...@@ -30,7 +30,7 @@ class UniformNegativeSampler(NegativeSampler): ...@@ -30,7 +30,7 @@ class UniformNegativeSampler(NegativeSampler):
The datapipe. The datapipe.
negative_ratio : int negative_ratio : int
The proportion of negative samples to positive samples. The proportion of negative samples to positive samples.
link_data_format : LinkDataFormat output_format : LinkPredictionEdgeFormat
Determines the format of the output data: Determines the format of the output data:
- Conditioned format: Outputs data as quadruples - Conditioned format: Outputs data as quadruples
`[u, v, [negative heads], [negative tails]]`. Here, 'u' and 'v' `[u, v, [negative heads], [negative tails]]`. Here, 'u' and 'v'
...@@ -50,14 +50,14 @@ class UniformNegativeSampler(NegativeSampler): ...@@ -50,14 +50,14 @@ class UniformNegativeSampler(NegativeSampler):
>>> indptr = torch.LongTensor([0, 2, 4, 5]) >>> indptr = torch.LongTensor([0, 2, 4, 5])
>>> indices = torch.LongTensor([1, 2, 0, 2, 0]) >>> indices = torch.LongTensor([1, 2, 0, 2, 0])
>>> graph = gb.from_csc(indptr, indices) >>> graph = gb.from_csc(indptr, indices)
>>> link_data_format = gb.LinkDataFormat.INDEPENDENT >>> output_format = gb.LinkPredictionEdgeFormat.INDEPENDENT
>>> node_pairs = (torch.tensor([0, 1]), torch.tensor([1, 2])) >>> node_pairs = (torch.tensor([0, 1]), torch.tensor([1, 2]))
>>> item_set = gb.ItemSet(node_pairs) >>> item_set = gb.ItemSet(node_pairs)
>>> minibatch_sampler = gb.MinibatchSampler( >>> minibatch_sampler = gb.MinibatchSampler(
...item_set, batch_size=1, ...item_set, batch_size=1,
...) ...)
>>> neg_sampler = gb.UniformNegativeSampler( >>> neg_sampler = gb.UniformNegativeSampler(
...minibatch_sampler, 2, link_data_format, graph) ...minibatch_sampler, 2, output_format, graph)
>>> for data in neg_sampler: >>> for data in neg_sampler:
... print(data) ... print(data)
... ...
...@@ -68,21 +68,21 @@ class UniformNegativeSampler(NegativeSampler): ...@@ -68,21 +68,21 @@ class UniformNegativeSampler(NegativeSampler):
>>> indptr = torch.LongTensor([0, 2, 4, 5]) >>> indptr = torch.LongTensor([0, 2, 4, 5])
>>> indices = torch.LongTensor([1, 2, 0, 2, 0]) >>> indices = torch.LongTensor([1, 2, 0, 2, 0])
>>> graph = gb.from_csc(indptr, indices) >>> graph = gb.from_csc(indptr, indices)
>>> link_data_format = gb.LinkDataFormat.CONDITIONED >>> output_format = gb.LinkPredictionEdgeFormat.CONDITIONED
>>> node_pairs = (torch.tensor([0, 1]), torch.tensor([1, 2])) >>> node_pairs = (torch.tensor([0, 1]), torch.tensor([1, 2]))
>>> item_set = gb.ItemSet(node_pairs) >>> item_set = gb.ItemSet(node_pairs)
>>> minibatch_sampler = gb.MinibatchSampler( >>> minibatch_sampler = gb.MinibatchSampler(
...item_set, batch_size=1, ...item_set, batch_size=1,
...) ...)
>>> neg_sampler = gb.UniformNegativeSampler( >>> neg_sampler = gb.UniformNegativeSampler(
...minibatch_sampler, 2, link_data_format, graph) ...minibatch_sampler, 2, output_format, graph)
>>> for data in neg_sampler: >>> for data in neg_sampler:
... print(data) ... print(data)
... ...
(tensor([0]), tensor([1]), tensor([[0, 0]]), tensor([[2, 1]])) (tensor([0]), tensor([1]), tensor([[0, 0]]), tensor([[2, 1]]))
(tensor([1]), tensor([2]), tensor([[1, 1]]), tensor([[1, 2]])) (tensor([1]), tensor([2]), tensor([[1, 1]]), tensor([[1, 2]]))
""" """
super().__init__(datapipe, negative_ratio, link_data_format) super().__init__(datapipe, negative_ratio, output_format)
self.graph = graph self.graph = graph
def _sample_with_etype(self, node_pairs, etype=None): def _sample_with_etype(self, node_pairs, etype=None):
......
...@@ -5,7 +5,7 @@ from _collections_abc import Mapping ...@@ -5,7 +5,7 @@ from _collections_abc import Mapping
import torch import torch
from torchdata.datapipes.iter import Mapper from torchdata.datapipes.iter import Mapper
from .link_data_format import LinkDataFormat from .data_format import LinkPredictionEdgeFormat
class NegativeSampler(Mapper): class NegativeSampler(Mapper):
...@@ -18,7 +18,7 @@ class NegativeSampler(Mapper): ...@@ -18,7 +18,7 @@ class NegativeSampler(Mapper):
self, self,
datapipe, datapipe,
negative_ratio, negative_ratio,
link_data_format, output_format,
): ):
""" """
Initialization for a negative sampler. Initialization for a negative sampler.
...@@ -29,8 +29,8 @@ class NegativeSampler(Mapper): ...@@ -29,8 +29,8 @@ class NegativeSampler(Mapper):
The datapipe. The datapipe.
negative_ratio : int negative_ratio : int
The proportion of negative samples to positive samples. The proportion of negative samples to positive samples.
link_data_format : LinkDataFormat output_format : LinkPredictionEdgeFormat
Determines the format of the output data: Determines the edge format of the output data:
- Conditioned format: Outputs data as quadruples - Conditioned format: Outputs data as quadruples
`[u, v, [negative heads], [negative tails]]`. Here, 'u' and 'v' `[u, v, [negative heads], [negative tails]]`. Here, 'u' and 'v'
are the source and destination nodes of positive edges, while are the source and destination nodes of positive edges, while
...@@ -44,7 +44,7 @@ class NegativeSampler(Mapper): ...@@ -44,7 +44,7 @@ class NegativeSampler(Mapper):
super().__init__(datapipe, self._sample) super().__init__(datapipe, self._sample)
assert negative_ratio > 0, "Negative_ratio should be positive Integer." assert negative_ratio > 0, "Negative_ratio should be positive Integer."
self.negative_ratio = negative_ratio self.negative_ratio = negative_ratio
self.link_data_format = link_data_format self.output_format = output_format
def _sample(self, node_pairs): def _sample(self, node_pairs):
""" """
...@@ -113,7 +113,7 @@ class NegativeSampler(Mapper): ...@@ -113,7 +113,7 @@ class NegativeSampler(Mapper):
Tuple[Tensor] Tuple[Tensor]
A mixed collection of positive and negative node pairs. A mixed collection of positive and negative node pairs.
""" """
if self.link_data_format == LinkDataFormat.INDEPENDENT: if self.output_format == LinkPredictionEdgeFormat.INDEPENDENT:
pos_src, pos_dst = pos_pairs pos_src, pos_dst = pos_pairs
neg_src, neg_dst = neg_pairs neg_src, neg_dst = neg_pairs
pos_label = torch.ones_like(pos_src) pos_label = torch.ones_like(pos_src)
...@@ -122,11 +122,11 @@ class NegativeSampler(Mapper): ...@@ -122,11 +122,11 @@ class NegativeSampler(Mapper):
dst = torch.cat([pos_dst, neg_dst]) dst = torch.cat([pos_dst, neg_dst])
label = torch.cat([pos_label, neg_label]) label = torch.cat([pos_label, neg_label])
return (src, dst, label) return (src, dst, label)
elif self.link_data_format == LinkDataFormat.CONDITIONED: elif self.output_format == LinkPredictionEdgeFormat.CONDITIONED:
pos_src, pos_dst = pos_pairs pos_src, pos_dst = pos_pairs
neg_src, neg_dst = neg_pairs neg_src, neg_dst = neg_pairs
neg_src = neg_src.view(-1, self.negative_ratio) neg_src = neg_src.view(-1, self.negative_ratio)
neg_dst = neg_dst.view(-1, self.negative_ratio) neg_dst = neg_dst.view(-1, self.negative_ratio)
return (pos_src, pos_dst, neg_src, neg_dst) return (pos_src, pos_dst, neg_src, neg_dst)
else: else:
raise ValueError("Unsupported link data format.") raise ValueError("Unsupported output format.")
...@@ -21,7 +21,7 @@ def test_NegativeSampler_Independent_Format(negative_ratio): ...@@ -21,7 +21,7 @@ def test_NegativeSampler_Independent_Format(negative_ratio):
negative_sampler = gb.UniformNegativeSampler( negative_sampler = gb.UniformNegativeSampler(
minibatch_sampler, minibatch_sampler,
negative_ratio, negative_ratio,
gb.LinkDataFormat.INDEPENDENT, gb.LinkPredictionEdgeFormat.INDEPENDENT,
graph, graph,
) )
# Perform Negative sampling. # Perform Negative sampling.
...@@ -52,7 +52,7 @@ def test_NegativeSampler_Conditioned_Format(negative_ratio): ...@@ -52,7 +52,7 @@ def test_NegativeSampler_Conditioned_Format(negative_ratio):
negative_sampler = gb.UniformNegativeSampler( negative_sampler = gb.UniformNegativeSampler(
minibatch_sampler, minibatch_sampler,
negative_ratio, negative_ratio,
gb.LinkDataFormat.CONDITIONED, gb.LinkPredictionEdgeFormat.CONDITIONED,
graph, graph,
) )
# Perform Negative sampling. # Perform Negative sampling.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment