Unverified Commit 6963d796 authored by Da Zheng, committed by GitHub

[Distributed] update Embedding API (#1847)



* update.

* update Embedding.

* add comments.

* fix lint
Co-authored-by: Chao Ma <mctt90@gmail.com>
parent ec2e24be
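
In user code, the rename and signature change below amount to the following before/after sketch. It is reconstructed from the diff, not taken from the commit itself: g stands for an already-created DistGraph, F is the DGL backend module, and the embedding size 10 is arbitrary.

    # Before this commit: embeddings were specified by a name and a full shape.
    emb_init = lambda shape, dtype: F.zeros(shape, dtype, F.cpu())
    emb = dgl.distributed.SparseNodeEmbedding(g, 'emb1', (g.number_of_nodes(), 10), emb_init)

    # After this commit: DistEmbedding takes the number of embeddings and the
    # embedding dimension directly; name, init_func and part_policy are optional.
    emb = dgl.distributed.DistEmbedding(g, g.number_of_nodes(), 10)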
@@ -5,7 +5,7 @@ import sys
 from .dist_graph import DistGraphServer, DistGraph, DistTensor, node_split, edge_split
 from .partition import partition_graph, load_partition, load_partition_book
 from .graph_partition_book import GraphPartitionBook, RangePartitionBook, PartitionPolicy
-from .sparse_emb import SparseAdagrad, SparseNodeEmbedding
+from .sparse_emb import SparseAdagrad, DistEmbedding
 from .rpc import *
 from .rpc_server import start_server
@@ -44,14 +44,14 @@ class DistTensor:
         The dtype of the tensor
     name : string
         The name of the tensor.
-    part_policy : PartitionPolicy
-        The partition policy of the tensor
     init_func : callable
         The function to initialize data in the tensor.
+    part_policy : PartitionPolicy
+        The partition policy of the tensor
     persistent : bool
         Whether the created tensor is persistent.
     '''
-    def __init__(self, g, shape, dtype, name=None, part_policy=None, init_func=None,
+    def __init__(self, g, shape, dtype, name=None, init_func=None, part_policy=None,
                  persistent=False):
         self.kvstore = g._client
         self._shape = shape
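
For reference, a small sketch of how the reordered DistTensor constructor is called after this change. It is illustrative only: the tensor name 'feat1' and the zero initializer are made up, the PyTorch backend is assumed, and g is an already-connected DistGraph.

    import torch
    import dgl

    def zero_init(shape, dtype):
        # init_func receives (shape, dtype), matching the emb_init example in the docstring
        return torch.zeros(shape, dtype=dtype)

    # init_func now comes before part_policy; passing both by keyword avoids
    # silently swapping them when moving between the old and new signatures.
    feat = dgl.distributed.DistTensor(g, (g.number_of_nodes(), 10), torch.float32,
                                      name='feat1', init_func=zero_init)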
"""Define sparse embedding and optimizer.""" """Define sparse embedding and optimizer."""
from .. import backend as F from .. import backend as F
from .. import utils
from .dist_tensor import DistTensor from .dist_tensor import DistTensor
from .graph_partition_book import PartitionPolicy, NODE_PART_POLICY from .graph_partition_book import PartitionPolicy, NODE_PART_POLICY
class SparseNodeEmbedding: class DistEmbedding:
''' Sparse embeddings in the distributed KVStore. '''Embeddings in the distributed training.
The sparse embeddings are only used as node embeddings. By default, the embeddings are created for nodes in the graph.
Parameters Parameters
---------- ----------
g : DistGraph g : DistGraph
The distributed graph object. The distributed graph object.
num_embeddings : int
The number of embeddings
embedding_dim : int
The dimension size of embeddings.
name : str name : str
The name of the embeddings The name of the embeddings
shape : tuple of int init_func : callable
The shape of the embedding. The first dimension should be the number of nodes.
initializer : callable
The function to create the initial data. The function to create the initial data.
part_policy : PartitionPolicy
The partition policy.
Examples Examples
-------- --------
>>> emb_init = lambda shape, dtype: F.zeros(shape, dtype, F.cpu()) >>> emb_init = lambda shape, dtype: F.zeros(shape, dtype, F.cpu())
>>> shape = (g.number_of_nodes(), 1) >>> emb = dgl.distributed.DistEmbedding(g, g.number_of_nodes(), 10)
>>> emb = dgl.distributed.SparseNodeEmbedding(g, 'emb1', shape, emb_init)
>>> optimizer = dgl.distributed.SparseAdagrad([emb], lr=0.001) >>> optimizer = dgl.distributed.SparseAdagrad([emb], lr=0.001)
>>> for blocks in dataloader: >>> for blocks in dataloader:
>>> feats = emb(nids) >>> feats = emb(nids)
@@ -32,15 +36,17 @@ class SparseNodeEmbedding:
     >>>     loss.backward()
     >>>     optimizer.step()
     '''
-    def __init__(self, g, name, shape, initializer):
-        assert shape[0] == g.number_of_nodes()
-        part_policy = PartitionPolicy(NODE_PART_POLICY, g.get_partition_book())
-        g.ndata[name] = DistTensor(g, shape, F.float32, name, part_policy, initializer)
-        self._tensor = g.ndata[name]
+    def __init__(self, g, num_embeddings, embedding_dim, name=None,
+                 init_func=None, part_policy=None):
+        if part_policy is None:
+            part_policy = PartitionPolicy(NODE_PART_POLICY, g.get_partition_book())
+        self._tensor = DistTensor(g, (num_embeddings, embedding_dim), F.float32, name,
+                                  init_func, part_policy)
         self._trace = []
 
     def __call__(self, idx):
+        idx = utils.toindex(idx).tousertensor()
         emb = F.attach_grad(self._tensor[idx])
         self._trace.append((idx, emb))
         return emb
@@ -95,7 +101,7 @@ class SparseAdagrad:
 
     Parameters
     ----------
-    params : list of SparseNodeEmbeddings
+    params : list of DistEmbeddings
         The list of sparse embeddings.
     lr : float
         The learning rate.
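
Putting the pieces together, a training loop over the renamed class might look like the sketch below. It mirrors the docstring example above; the dataloader, the way node IDs are pulled from a sampled block, and compute_loss are placeholders, not part of this commit.

    emb = dgl.distributed.DistEmbedding(g, g.number_of_nodes(), 10)
    optimizer = dgl.distributed.SparseAdagrad([emb], lr=0.001)
    for blocks in dataloader:                  # placeholder sampler/dataloader
        nids = blocks[0].srcdata[dgl.NID]      # assumed way to get the input node IDs
        feats = emb(nids)                      # looks up rows and records (idx, emb) for the optimizer
        loss = compute_loss(feats)             # placeholder loss
        loss.backward()
        optimizer.step()                       # sparse Adagrad update of only the touched rows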
@@ -13,7 +13,7 @@ from dgl.graph_index import create_graph_index
 from dgl.data.utils import load_graphs, save_graphs
 from dgl.distributed import DistGraphServer, DistGraph
 from dgl.distributed import partition_graph, load_partition, load_partition_book, node_split, edge_split
-from dgl.distributed import SparseAdagrad, SparseNodeEmbedding
+from dgl.distributed import SparseAdagrad, DistEmbedding
 from numpy.testing import assert_almost_equal
 import backend as F
 import math
@@ -120,8 +120,7 @@ def check_dist_graph(g, num_nodes, num_edges):
 
     # Test sparse emb
     try:
-        new_shape = (g.number_of_nodes(), 1)
-        emb = SparseNodeEmbedding(g, 'emb1', new_shape, emb_init)
+        emb = DistEmbedding(g, g.number_of_nodes(), 1, 'emb1', emb_init)
         lr = 0.001
         optimizer = SparseAdagrad([emb], lr=lr)
         with F.record_grad():
@@ -142,7 +141,7 @@ def check_dist_graph(g, num_nodes, num_edges):
         assert np.all(F.asnumpy(grad_sum[nids]) == np.ones((len(nids), 1)))
         assert np.all(F.asnumpy(grad_sum[rest]) == np.zeros((len(rest), 1)))
 
-        emb = SparseNodeEmbedding(g, 'emb2', new_shape, emb_init)
+        emb = DistEmbedding(g, g.number_of_nodes(), 1, 'emb2', emb_init)
         optimizer = SparseAdagrad([emb], lr=lr)
         with F.record_grad():
             feats1 = emb(nids)