"...models/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "7fa267e80efb02a1f2b43699a38e8498c74303df"
Unverified Commit 1789972c authored by Hengrui Zhang's avatar Hengrui Zhang Committed by GitHub
Browse files

[Bugfix] add default value for BlockSampler (#2771)



* [Bugfix] add default value for BlockSampler

* [Doc] modify user_guide description about MultiLayerDropoutSample
Co-authored-by: default avatarQuan (Andy) Gan <coin2028@hotmail.com>
parent 17f86356
...@@ -361,10 +361,9 @@ nodes with a probability, one can simply define the sampler as follows: ...@@ -361,10 +361,9 @@ nodes with a probability, one can simply define the sampler as follows:
.. code:: python .. code:: python
class MultiLayerDropoutSampler(dgl.dataloading.BlockSampler): class MultiLayerDropoutSampler(dgl.dataloading.BlockSampler):
def __init__(self, p, n_layers): def __init__(self, p, num_layers):
super().__init__() super().__init__(num_layers)
self.n_layers = n_layers
self.p = p self.p = p
def sample_frontier(self, block_id, g, seed_nodes, *args, **kwargs): def sample_frontier(self, block_id, g, seed_nodes, *args, **kwargs):
...@@ -380,7 +379,7 @@ nodes with a probability, one can simply define the sampler as follows: ...@@ -380,7 +379,7 @@ nodes with a probability, one can simply define the sampler as follows:
return frontier return frontier
def __len__(self): def __len__(self):
return self.n_layers return self.num_layers
After implementing your sampler, you can create a data loader that takes After implementing your sampler, you can create a data loader that takes
in your sampler and it will keep generating lists of MFGs while in your sampler and it will keep generating lists of MFGs while
...@@ -422,10 +421,9 @@ all edge types, so that it can work on heterogeneous graphs as well. ...@@ -422,10 +421,9 @@ all edge types, so that it can work on heterogeneous graphs as well.
.. code:: python .. code:: python
class MultiLayerDropoutSampler(dgl.dataloading.BlockSampler): class MultiLayerDropoutSampler(dgl.dataloading.BlockSampler):
def __init__(self, p, n_layers): def __init__(self, p, num_layers):
super().__init__() super().__init__(num_layers)
self.n_layers = n_layers
self.p = p self.p = p
def sample_frontier(self, block_id, g, seed_nodes, *args, **kwargs): def sample_frontier(self, block_id, g, seed_nodes, *args, **kwargs):
...@@ -445,7 +443,4 @@ all edge types, so that it can work on heterogeneous graphs as well. ...@@ -445,7 +443,4 @@ all edge types, so that it can work on heterogeneous graphs as well.
return frontier return frontier
def __len__(self): def __len__(self):
return self.n_layers return self.num_layers
...@@ -308,10 +308,9 @@ DGL确保块的输出节点将始终出现在输入节点中。如下代码所 ...@@ -308,10 +308,9 @@ DGL确保块的输出节点将始终出现在输入节点中。如下代码所
.. code:: python .. code:: python
class MultiLayerDropoutSampler(dgl.dataloading.BlockSampler): class MultiLayerDropoutSampler(dgl.dataloading.BlockSampler):
def __init__(self, p, n_layers): def __init__(self, p, num_layers):
super().__init__() super().__init__(num_layers)
self.n_layers = n_layers
self.p = p self.p = p
def sample_frontier(self, block_id, g, seed_nodes, *args, **kwargs): def sample_frontier(self, block_id, g, seed_nodes, *args, **kwargs):
...@@ -326,7 +325,7 @@ DGL确保块的输出节点将始终出现在输入节点中。如下代码所 ...@@ -326,7 +325,7 @@ DGL确保块的输出节点将始终出现在输入节点中。如下代码所
return frontier return frontier
def __len__(self): def __len__(self):
return self.n_layers return self.num_layers
在实现自定义采样器后,用户可以创建一个数据加载器。这个数据加载器使用用户自定义的采样器, 在实现自定义采样器后,用户可以创建一个数据加载器。这个数据加载器使用用户自定义的采样器,
并且遍历种子节点生成一系列的块。 并且遍历种子节点生成一系列的块。
...@@ -365,10 +364,9 @@ DGL确保块的输出节点将始终出现在输入节点中。如下代码所 ...@@ -365,10 +364,9 @@ DGL确保块的输出节点将始终出现在输入节点中。如下代码所
.. code:: python .. code:: python
class MultiLayerDropoutSampler(dgl.dataloading.BlockSampler): class MultiLayerDropoutSampler(dgl.dataloading.BlockSampler):
def __init__(self, p, n_layers): def __init__(self, p, num_layers):
super().__init__() super().__init__(num_layers)
self.n_layers = n_layers
self.p = p self.p = p
def sample_frontier(self, block_id, g, seed_nodes, *args, **kwargs): def sample_frontier(self, block_id, g, seed_nodes, *args, **kwargs):
...@@ -387,4 +385,5 @@ DGL确保块的输出节点将始终出现在输入节点中。如下代码所 ...@@ -387,4 +385,5 @@ DGL确保块的输出节点将始终出现在输入节点中。如下代码所
return frontier return frontier
def __len__(self): def __len__(self):
return self.n_layers return self.num_layers
\ No newline at end of file
\ No newline at end of file
...@@ -158,7 +158,7 @@ class BlockSampler(object): ...@@ -158,7 +158,7 @@ class BlockSampler(object):
:ref:`User Guide Section 6 <guide-minibatch>` and :ref:`User Guide Section 6 <guide-minibatch>` and
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`. :doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
""" """
def __init__(self, num_layers, return_eids): def __init__(self, num_layers, return_eids=False):
self.num_layers = num_layers self.num_layers = num_layers
self.return_eids = return_eids self.return_eids = return_eids
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment