Unverified Commit 1e5fb155 authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Misc] Update variable names. (#6071)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
parent 5da7d391
......@@ -15,7 +15,7 @@ from .impl import *
from .dataloader import *
from .subgraph_sampler import *
from .sampled_subgraph import *
from .link_data_format import *
from .data_format import *
from .negative_sampler import *
from .utils import unique_and_compact_node_pairs
......
"""Linked data format."""
"""Data format enums for graphbolt."""
from enum import Enum
__all__ = ["LinkDataFormat"]
__all__ = ["LinkPredictionEdgeFormat"]
class LinkDataFormat(Enum):
class LinkPredictionEdgeFormat(Enum):
"""
An Enum class representing the two data formats used in link prediction:
An Enum class representing the formats of positive and negative edges used
in link prediction:
Attributes:
CONDITIONED: Represents the 'conditioned' format where data is
......@@ -15,7 +16,7 @@ class LinkDataFormat(Enum):
indicating the source and destination nodes of positive and negative edges.
INDEPENDENT: Represents the 'independent' format where data is structured
as triples `[u, v, label]` indicating the source and destination nodes of
as triples `[u, v, label]` indicating the source and destination nodes of
an edge, with a label (0 or 1) denoting it as negative or positive.
"""
......
......@@ -18,7 +18,7 @@ class UniformNegativeSampler(NegativeSampler):
self,
datapipe,
negative_ratio,
link_data_format,
output_format,
graph,
):
"""
......@@ -30,11 +30,11 @@ class UniformNegativeSampler(NegativeSampler):
The datapipe.
negative_ratio : int
The proportion of negative samples to positive samples.
link_data_format : LinkDataFormat
output_format : LinkPredictionEdgeFormat
Determines the format of the output data:
- Conditioned format: Outputs data as quadruples
`[u, v, [negative heads], [negative tails]]`. Here, 'u' and 'v'
are the source and destination nodes of positive edges, while
are the source and destination nodes of positive edges, while
'negative heads' and 'negative tails' refer to the source and
destination nodes of negative edges.
- Independent format: Outputs data as triples `[u, v, label]`.
......@@ -50,14 +50,14 @@ class UniformNegativeSampler(NegativeSampler):
>>> indptr = torch.LongTensor([0, 2, 4, 5])
>>> indices = torch.LongTensor([1, 2, 0, 2, 0])
>>> graph = gb.from_csc(indptr, indices)
>>> link_data_format = gb.LinkDataFormat.INDEPENDENT
>>> output_format = gb.LinkPredictionEdgeFormat.INDEPENDENT
>>> node_pairs = (torch.tensor([0, 1]), torch.tensor([1, 2]))
>>> item_set = gb.ItemSet(node_pairs)
>>> minibatch_sampler = gb.MinibatchSampler(
...item_set, batch_size=1,
...)
>>> neg_sampler = gb.UniformNegativeSampler(
...minibatch_sampler, 2, link_data_format, graph)
...minibatch_sampler, 2, output_format, graph)
>>> for data in neg_sampler:
... print(data)
...
......@@ -68,21 +68,21 @@ class UniformNegativeSampler(NegativeSampler):
>>> indptr = torch.LongTensor([0, 2, 4, 5])
>>> indices = torch.LongTensor([1, 2, 0, 2, 0])
>>> graph = gb.from_csc(indptr, indices)
>>> link_data_format = gb.LinkDataFormat.CONDITIONED
>>> output_format = gb.LinkPredictionEdgeFormat.CONDITIONED
>>> node_pairs = (torch.tensor([0, 1]), torch.tensor([1, 2]))
>>> item_set = gb.ItemSet(node_pairs)
>>> minibatch_sampler = gb.MinibatchSampler(
...item_set, batch_size=1,
...)
>>> neg_sampler = gb.UniformNegativeSampler(
...minibatch_sampler, 2, link_data_format, graph)
...minibatch_sampler, 2, output_format, graph)
>>> for data in neg_sampler:
... print(data)
...
(tensor([0]), tensor([1]), tensor([[0, 0]]), tensor([[2, 1]]))
(tensor([1]), tensor([2]), tensor([[1, 1]]), tensor([[1, 2]]))
"""
super().__init__(datapipe, negative_ratio, link_data_format)
super().__init__(datapipe, negative_ratio, output_format)
self.graph = graph
def _sample_with_etype(self, node_pairs, etype=None):
......
......@@ -5,7 +5,7 @@ from _collections_abc import Mapping
import torch
from torchdata.datapipes.iter import Mapper
from .link_data_format import LinkDataFormat
from .data_format import LinkPredictionEdgeFormat
class NegativeSampler(Mapper):
......@@ -18,7 +18,7 @@ class NegativeSampler(Mapper):
self,
datapipe,
negative_ratio,
link_data_format,
output_format,
):
"""
Initlization for a negative sampler.
......@@ -29,8 +29,8 @@ class NegativeSampler(Mapper):
The datapipe.
negative_ratio : int
The proportion of negative samples to positive samples.
link_data_format : LinkDataFormat
Determines the format of the output data:
output_format : LinkPredictionEdgeFormat
Determines the edge format of the output data:
- Conditioned format: Outputs data as quadruples
`[u, v, [negative heads], [negative tails]]`. Here, 'u' and 'v'
are the source and destination nodes of positive edges, while
......@@ -44,7 +44,7 @@ class NegativeSampler(Mapper):
super().__init__(datapipe, self._sample)
assert negative_ratio > 0, "Negative_ratio should be positive Integer."
self.negative_ratio = negative_ratio
self.link_data_format = link_data_format
self.output_format = output_format
def _sample(self, node_pairs):
"""
......@@ -113,7 +113,7 @@ class NegativeSampler(Mapper):
Tuple[Tensor]
A mixed collection of positive and negative node pairs.
"""
if self.link_data_format == LinkDataFormat.INDEPENDENT:
if self.output_format == LinkPredictionEdgeFormat.INDEPENDENT:
pos_src, pos_dst = pos_pairs
neg_src, neg_dst = neg_pairs
pos_label = torch.ones_like(pos_src)
......@@ -122,11 +122,11 @@ class NegativeSampler(Mapper):
dst = torch.cat([pos_dst, neg_dst])
label = torch.cat([pos_label, neg_label])
return (src, dst, label)
elif self.link_data_format == LinkDataFormat.CONDITIONED:
elif self.output_format == LinkPredictionEdgeFormat.CONDITIONED:
pos_src, pos_dst = pos_pairs
neg_src, neg_dst = neg_pairs
neg_src = neg_src.view(-1, self.negative_ratio)
neg_dst = neg_dst.view(-1, self.negative_ratio)
return (pos_src, pos_dst, neg_src, neg_dst)
else:
raise ValueError("Unsupported link data format.")
raise ValueError("Unsupported output format.")
......@@ -21,7 +21,7 @@ def test_NegativeSampler_Independent_Format(negative_ratio):
negative_sampler = gb.UniformNegativeSampler(
minibatch_sampler,
negative_ratio,
gb.LinkDataFormat.INDEPENDENT,
gb.LinkPredictionEdgeFormat.INDEPENDENT,
graph,
)
# Perform Negative sampling.
......@@ -52,7 +52,7 @@ def test_NegativeSampler_Conditioned_Format(negative_ratio):
negative_sampler = gb.UniformNegativeSampler(
minibatch_sampler,
negative_ratio,
gb.LinkDataFormat.CONDITIONED,
gb.LinkPredictionEdgeFormat.CONDITIONED,
graph,
)
# Perform Negative sampling.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment