Unverified commit cadcc1c2 authored by Da Zheng, committed by GitHub

[Feature] add sparse embedding. (#1497)

* add sparse embedding.

* fix

* add test.

* man fixes.

* many fixes

* fix sparse emb.

* fix.

* fix lint.

* fix lint.

* fix kvstore.

* expose DistTensor.

* test sparse embeddings.

* add attach_grad to the backends.

* remove part_id

* fix.

* move backward computation.

* move more computation to backend.

* fix a bug when applying learning rate.

* fix a few things.

* fix a few things.

* add docstring

* fix.

* apply no_grad.

* fix tests.

* fix for other frameworks.

* add examples in docstring.
parent 17701174
@@ -531,6 +531,21 @@ def exp(input):
"""
pass
def sqrt(input):
"""Returns a new tensor with the square root of the elements of the input tensor `input`.
Parameters
----------
input : Tensor
The input tensor.
Returns
-------
Tensor
The output tensor.
"""
pass
def softmax(input, dim=-1):
"""Apply the softmax function on given dimension.
@@ -718,6 +733,31 @@ def scatter_row(data, row_index, value):
"""
pass
def index_add_inplace(data, row_idx, value):
"""Add the values into the data tensor using the row index inplace.
If two row indices are the same, the corresponding values are summed up before
being added to the data tensor.
Examples
--------
>>> import torch as th
>>> arr = th.zeros((10))
>>> F.index_add_inplace(arr, th.tensor([0, 1, 1]), th.tensor([1.0, 1.0, 1.0]))
>>> arr
tensor([1., 2., 0., 0., 0., 0., 0., 0., 0., 0.])
Parameters
----------
data : Tensor
The data tensor to be updated.
row_idx : Tensor
A 1-D integer tensor specifying the rows to be updated.
value : Tensor
The values to be added to the specified rows.
"""
pass
def scatter_row_inplace(data, row_index, value):
"""Write the value into the data tensor using the row index inplace.
@@ -1325,3 +1365,46 @@ def sync():
that all computation is complete after this function call.
"""
pass
def attach_grad(tensor):
""" Attach gradients to the input tensor
"""
pass
def backward(x, head_gradient=None):
"""Invoke backward computation with an optional head gradient.
"""
pass
def grad(x):
"""Fetches the gradient from the tensor after backward computation.
"""
pass
def is_no_grad(x):
""" Test if the input tensor has gradient
"""
pass
class record_grad(object):
"""Context manager that records the gradients"""
def __init__(self):
pass
def __enter__(self):
pass
def __exit__(self, exc_type, exc_value, exc_traceback):
pass
class no_grad(object):
"""Context manager that explicitly disables gradient computation"""
def __init__(self):
pass
def __enter__(self):
pass
def __exit__(self, exc_type, exc_value, exc_traceback):
pass
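As a rough illustration of how these new autograd primitives fit together, here is a minimal sketch assuming the PyTorch backend is active (the tensor values and the `backend as F` alias are for illustration only):

```
import torch as th
from dgl import backend as F

x = F.attach_grad(th.ones(3))    # flag the tensor for gradient computation
with F.record_grad():            # record operations for the backward pass
    y = (x * x).sum()
F.backward(y)                    # invoke backward computation on the scalar output
print(F.grad(x))                 # tensor([2., 2., 2.])
assert not F.is_no_grad(x)       # x now carries a non-zero gradient
```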
@@ -166,6 +166,9 @@ def argsort(input, dim, descending):
def exp(input):
return nd.exp(input)
def sqrt(input):
return nd.sqrt(input)
def softmax(input, dim=-1):
return nd.softmax(input, axis=dim)
@@ -224,6 +227,9 @@ def take(data, indices, dim):
def narrow_row(data, start, stop):
return data[start:stop]
def index_add_inplace(data, row_idx, value):
raise NotImplementedError("MXNet doesn't support inplace index_add")
def scatter_row(data, row_index, value):
return mx.nd.contrib.index_copy(data, row_index, value)
@@ -576,3 +582,28 @@ def sync():
that all computation is complete after this function call.
"""
mx.nd.waitall()
def attach_grad(tensor):
tensor.attach_grad()
return tensor
def backward(x, head_gradient=None):
x.backward(head_gradient)
def grad(x):
return x.grad
def is_no_grad(x):
return (x != 0).sum() == 0
record_grad = mx.autograd.record
# MXNet only records gradients inside autograd.record(), so no_grad can be a no-op context manager.
class no_grad(object):
def __init__(self):
pass
def __enter__(self):
pass
def __exit__(self, exc_type, exc_value, exc_traceback):
pass
@@ -145,6 +145,9 @@ def argtopk(input, k, dim, descending=True):
def exp(input):
return th.exp(input)
def sqrt(input):
return th.sqrt(input)
def softmax(input, dim=-1):
return th.softmax(input, dim=dim)
@@ -176,6 +179,9 @@ def take(data, indices, dim):
def narrow_row(x, start, stop):
return x[start:stop]
def index_add_inplace(data, row_idx, value):
data.index_add_(0, row_idx, value)
def scatter_row(data, row_index, value):
return data.index_copy(0, row_index.long(), value)
@@ -486,3 +492,34 @@ def _reduce_grad(grad, shape):
def sync():
# Pytorch performs computation synchronously, so no need for synchronization.
pass
def attach_grad(x):
if x.grad is not None:
x.grad.zero_()
return x
else:
return x.requires_grad_()
def backward(x, head_gradient=None):
if head_gradient is not None and head_gradient.shape[0] == 1 and len(head_gradient.shape) == 1:
# Fix for torch 1.3.1
head_gradient = th.tensor(head_gradient.item()).to(head_gradient.device)
x.backward(head_gradient)
def grad(x):
return x.grad
def is_no_grad(x):
return x.grad is None or (x.grad == 0).all()
class record_grad(object):
def __init__(self):
pass
def __enter__(self):
pass
def __exit__(self, exc_type, exc_value, exc_traceback):
pass
no_grad = th.no_grad
@@ -213,6 +213,8 @@ def argtopk(input, k, dim, descending=True):
def exp(input):
return tf.exp(input)
def sqrt(input):
return tf.sqrt(input)
def softmax(input, dim=-1):
return tf.math.softmax(input, axis=dim)
@@ -260,6 +262,9 @@ def scatter_row(data, row_index, value):
row_index = tf.expand_dims(row_index, 1)
return tf.tensor_scatter_nd_update(data, row_index, value)
def index_add_inplace(data, row_idx, value):
raise NotImplementedError("Tensorflow doesn't support inplace index_add")
def scatter_row_inplace(data, row_index, value):
raise NotImplementedError("Tensorflow doesn't support inplace update")
@@ -593,4 +598,90 @@ def sync():
context = context().context()
context.async_wait()
initialize_context()
\ No newline at end of file
class GradContext:
def __init__(self):
self.tensor_for_grad = []
self.grad_list = []
self.tape = None
def set_tape(self, tape):
self.tape = tape
def add_tensor(self, x):
idx_pop = []
for idx, ele in enumerate(self.tensor_for_grad):
if ele._id == x._id:
idx_pop.append(idx)
if len(idx_pop) > 0:
self.tensor_for_grad.pop(idx_pop[0])
if self.tape is not None:
self.tape.watch(x)
self.tensor_for_grad.append(x)
def backward(self, x, head_gradient=None):
if head_gradient is not None:
x = x * head_gradient
self.grad_list = self.tape.gradient(x, self.tensor_for_grad)
def is_no_grad(self, x):
idx_pop = []
for idx, ele in enumerate(self.tensor_for_grad):
if ele._id == x._id:
idx_pop.append(idx)
if len(idx_pop) == 0:
return True
else:
return self.grad_list[idx_pop[0]] is None
def grad(self, x):
idx_pop = []
for idx, ele in enumerate(self.tensor_for_grad):
if ele._id == x._id:
idx_pop.append(idx)
assert len(idx_pop) == 1
t = self.grad_list[idx_pop[0]]
return tf.convert_to_tensor(t)
cgrad = GradContext()
def get_cgrad():
return cgrad
class record_grad:
def __init__(self):
self.tape = tf.GradientTape()
def __enter__(self):
cgrad.set_tape(self.tape)
self.tape.__enter__()
for x in cgrad.tensor_for_grad:
self.tape.watch(x)
def __exit__(self, exc_type, exc_value, exc_traceback):
self.tape.__exit__(exc_type, exc_value, exc_traceback)
cgrad.tape = None
def attach_grad(x):
cgrad.add_tensor(x)
return x
def backward(x, head_gradient=None):
cgrad.backward(x, head_gradient)
def grad(x):
return cgrad.grad(x)
def is_no_grad(x):
return cgrad.is_no_grad(x)
# TensorFlow only computes gradients inside a GradientTape, so an explicit no_grad context is not needed here.
no_grad = None
initialize_context()
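Since TensorFlow only exposes gradients through a tape, the helpers above emulate the same API with a module-level GradContext. A minimal sketch of the intended usage (assuming DGL runs with the TensorFlow backend; the values are illustrative). Note that `backward` has to be called while the tape is still active, because `record_grad.__exit__` clears it:

```
import tensorflow as tf
from dgl import backend as F    # assumes DGLBACKEND=tensorflow

x = F.attach_grad(tf.ones(3))   # register x with the global GradContext
with F.record_grad():           # opens a tf.GradientTape and watches the registered tensors
    y = tf.reduce_sum(x * x)
    F.backward(y)               # must run inside the context; __exit__ sets the tape to None
print(F.grad(x))                # tf.Tensor([2. 2. 2.], shape=(3,), dtype=float32)
```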
"""DGL distributed.""" """DGL distributed."""
from .dist_graph import DistGraphServer, DistGraph, node_split, edge_split from .dist_graph import DistGraphServer, DistGraph, DistTensor, node_split, edge_split
from .partition import partition_graph, load_partition, load_partition_book from .partition import partition_graph, load_partition, load_partition_book
from .graph_partition_book import GraphPartitionBook, RangePartitionBook, PartitionPolicy from .graph_partition_book import GraphPartitionBook, RangePartitionBook, PartitionPolicy
from .sparse_emb import SparseAdagrad, SparseNodeEmbedding
from .rpc import * from .rpc import *
from .rpc_server import start_server from .rpc_server import start_server
......
@@ -115,32 +115,40 @@ class DistTensor:
Parameters
----------
-kv : DistGraph
+g : DistGraph
The distributed graph object.
name : string
The name of the tensor.
part_policy : PartitionPolicy
The partition policy of the tensor
'''
-def __init__(self, g, name):
+def __init__(self, g, name, part_policy):
self.kvstore = g._client
-self.name = name
+self._name = name
dtype, shape, _ = g._client.get_data_meta(name)
self._shape = shape
self._dtype = dtype
self._part_policy = part_policy
def __getitem__(self, idx):
idx = utils.toindex(idx)
idx = idx.tousertensor()
-return self.kvstore.pull(name=self.name, id_tensor=idx)
+return self.kvstore.pull(name=self._name, id_tensor=idx)
def __setitem__(self, idx, val):
idx = utils.toindex(idx)
idx = idx.tousertensor()
# TODO(zhengda) how do we want to support broadcast (e.g., G.ndata['h'][idx] = 1).
-self.kvstore.push(name=self.name, id_tensor=idx, data_tensor=val)
+self.kvstore.push(name=self._name, id_tensor=idx, data_tensor=val)
def __len__(self):
return self._shape[0]
@property
def part_policy(self):
''' Return the partition policy '''
return self._part_policy
@property
def shape(self):
''' Return the shape of the distributed tensor. '''
@@ -151,6 +159,11 @@ class DistTensor:
''' Return the data type of the distributed tensor. '''
return self._dtype
@property
def name(self):
''' Return the name of the distributed tensor '''
return self._name
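For illustration, a rough sketch of how a DistTensor is created and used with the new `part_policy` argument (assuming `g` is a DistGraph already connected to the kvstore servers and that node data named `'feat'` exists; the `'node:'` prefix follows the kvstore naming convention used elsewhere in this change):

```
from dgl.distributed import DistTensor, PartitionPolicy

policy = PartitionPolicy('node', g.get_partition_book())  # no part_id argument anymore
feat = DistTensor(g, 'node:feat', policy)                  # metadata is fetched from the kvstore

print(feat.shape, feat.dtype, feat.name, feat.part_policy)
rows = feat[[0, 1, 2]]        # __getitem__ pulls the rows from the distributed store
feat[[0, 1, 2]] = rows        # __setitem__ pushes them back
```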
class NodeDataView(MutableMapping):
"""The data view class when dist_graph.ndata[...].data is called.
@@ -162,13 +175,15 @@ class NodeDataView(MutableMapping):
# When this is created, the server may already load node data. We need to
# initialize the node data in advance.
names = g._get_all_ndata_names()
-self._data = {name: DistTensor(g, _get_ndata_name(name)) for name in names}
+policy = PartitionPolicy("node", g.get_partition_book())
+self._data = {name: DistTensor(g, _get_ndata_name(name), policy) for name in names}
def _get_names(self):
return list(self._data.keys())
def _add(self, name):
-self._data[name] = DistTensor(self._graph, _get_ndata_name(name))
+policy = PartitionPolicy("node", self._graph.get_partition_book())
+self._data[name] = DistTensor(self._graph, _get_ndata_name(name), policy)
def __getitem__(self, key):
return self._data[key]
@@ -207,13 +222,15 @@ class EdgeDataView(MutableMapping):
# When this is created, the server may already load edge data. We need to
# initialize the edge data in advance.
names = g._get_all_edata_names()
-self._data = {name: DistTensor(g, _get_edata_name(name)) for name in names}
+policy = PartitionPolicy("edge", g.get_partition_book())
+self._data = {name: DistTensor(g, _get_edata_name(name), policy) for name in names}
def _get_names(self):
return list(self._data.keys())
def _add(self, name):
-self._data[name] = DistTensor(self._graph, _get_edata_name(name))
+policy = PartitionPolicy("edge", self._graph.get_partition_book())
+self._data[name] = DistTensor(self._graph, _get_edata_name(name), policy)
def __getitem__(self, key):
return self._data[key]
@@ -287,8 +304,9 @@ class DistGraphServer(KVServer):
# Init kvstore.
if not disable_shared_mem:
self.gpb.shared_memory(graph_name)
-self.add_part_policy(PartitionPolicy('node', server_id, self.gpb))
-self.add_part_policy(PartitionPolicy('edge', server_id, self.gpb))
+assert self.gpb.partid == server_id
+self.add_part_policy(PartitionPolicy('node', self.gpb))
+self.add_part_policy(PartitionPolicy('edge', self.gpb))
if not self.is_backup_server():
for name in node_feats:
@@ -351,8 +369,6 @@ class DistGraph:
self._client.map_shared_data(self._gpb)
self._ndata = NodeDataView(self)
self._edata = EdgeDataView(self)
-self._default_init_ndata = _default_init_data
-self._default_init_edata = _default_init_data
self._num_nodes = 0
self._num_edges = 0
@@ -361,10 +377,18 @@ class DistGraph:
self._num_edges += int(part_md['num_edges'])
-def init_ndata(self, ndata_name, shape, dtype):
+def init_ndata(self, name, shape, dtype, init_func=None):
'''Initialize node data
This initializes the node data in the distributed graph storage.
Users can provide an init function to initialize the data. The signature of
the init function is
```
def init_func(shape, dtype)
```
The inputs are the shape and data type and the output is a tensor with
the initialized values.
Parameters
----------
@@ -374,16 +398,27 @@ class DistGraph:
The shape of the node data.
dtype : dtype
The data type of the node data.
init_func : callable
The function to initialize the data
'''
assert shape[0] == self.number_of_nodes()
-self._client.init_data(_get_ndata_name(ndata_name), shape, dtype, 'node', self._gpb,
-self._default_init_ndata)
-self._ndata._add(ndata_name)
+if init_func is None:
+init_func = _default_init_data
+self._client.init_data(_get_ndata_name(name), shape, dtype, 'node', self._gpb, init_func)
+self._ndata._add(name)
-def init_edata(self, edata_name, shape, dtype):
+def init_edata(self, name, shape, dtype, init_func=None):
'''Initialize edge data
This initializes the edge data in the distributed graph storage.
Users can provide an init function to initialize the data. The signature of
the init function is
```
def init_func(shape, dtype)
```
The inputs are the shape and data type and the output is a tensor with
the initialized values.
Parameters
----------
@@ -393,41 +428,14 @@ class DistGraph:
The shape of the edge data.
dtype : dtype
The data type of the edge data.
init_func : callable
The function to initialize the data
'''
assert shape[0] == self.number_of_edges()
-self._client.init_data(_get_edata_name(edata_name), shape, dtype, 'edge', self._gpb,
-self._default_init_edata)
-self._edata._add(edata_name)
+if init_func is None:
+init_func = _default_init_data
+self._client.init_data(_get_edata_name(name), shape, dtype, 'edge', self._gpb, init_func)
+self._edata._add(name)
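As a hedged illustration of the `init_func` hook described in the two docstrings above (the data names and shapes are made up, and `g` is assumed to be a connected DistGraph):

```
from dgl import backend as F

def ones_init(shape, dtype):
    # Custom initializer: fill the new distributed tensor with ones instead of zeros.
    return F.ones(shape, dtype, F.cpu())

g.init_ndata('feat', (g.number_of_nodes(), 16), F.float32, init_func=ones_init)
g.init_edata('weight', (g.number_of_edges(), 1), F.float32)  # falls back to _default_init_data
```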
def init_node_emb(self, name, shape, dtype, initializer):
''' Initialize node embeddings.
This initializes the node embeddings in the distributed graph storage.
Parameters
----------
name : string
The name of the node embeddings.
shape : tuple
The shape of the node embeddings.
dtype : string
The data type of the node embeddings.
initializer : callable
The initializer.
'''
# TODO(zhengda)
raise NotImplementedError("init_node_emb isn't supported yet")
def get_node_embeddings(self):
''' Return node embeddings
Returns
-------
a dict of SparseEmbedding
All node embeddings in the graph store.
'''
# TODO(zhengda)
raise NotImplementedError("get_node_embeddings isn't supported yet")
@property
def local_partition(self):
......
@@ -632,17 +632,14 @@ class PartitionPolicy(object):
----------
policy_str : str
partition-policy string, e.g., 'edge' or 'node'.
-part_id : int
-partition ID
partition_book : GraphPartitionBook or RangePartitionBook
Main class storing the partition information
"""
-def __init__(self, policy_str, part_id, partition_book):
+def __init__(self, policy_str, partition_book):
# TODO(chao): support more policies for HeteroGraph
assert policy_str in ('edge', 'node'), 'policy_str must be \'edge\' or \'node\'.'
-assert part_id >= 0, 'part_id %d cannot be a negative number.' % part_id
self._policy_str = policy_str
-self._part_id = part_id
+self._part_id = partition_book.partid
self._partition_book = partition_book
@property
......
@@ -7,6 +7,7 @@ from . import rpc
from .graph_partition_book import PartitionPolicy
from .. import backend as F
from .. import utils
from .._ffi.ndarray import empty_shared_mem
############################ Register KVStore Requsts and Responses ###############################
@@ -900,7 +901,7 @@ class KVClient(object):
raise RuntimeError("Data %s has already exists!" % name)
if self._full_data_shape.__contains__(name):
raise RuntimeError("Data shape %s has already exists!" % name)
-self._part_policy[name] = PartitionPolicy(policy_str, self._part_id, partition_book)
+self._part_policy[name] = PartitionPolicy(policy_str, partition_book)
shared_data = empty_shared_mem(name+'-kvdata-', False, \
local_shape, F.reverse_data_type_dict[dtype])
dlpack = shared_data.to_dlpack()
@@ -930,7 +931,7 @@ class KVClient(object):
shared_data = empty_shared_mem(name+'-kvdata-', False, shape, dtype)
dlpack = shared_data.to_dlpack()
self._data_store[name] = F.zerocopy_from_dlpack(dlpack)
-self._part_policy[name] = PartitionPolicy(policy_str, self._part_id, partition_book)
+self._part_policy[name] = PartitionPolicy(policy_str, partition_book)
self._pull_handlers[name] = default_pull_handler
self._push_handlers[name] = default_push_handler
# Get full data shape across servers
@@ -992,6 +993,8 @@ class KVClient(object):
a tensor with the same row size of data ID
"""
assert len(name) > 0, 'name cannot be empty.'
id_tensor = utils.toindex(id_tensor)
id_tensor = id_tensor.tousertensor()
assert F.ndim(id_tensor) == 1, 'ID must be a vector.'
assert F.shape(id_tensor)[0] == F.shape(data_tensor)[0], \
'The data must has the same row size with ID.'
@@ -1040,6 +1043,8 @@ class KVClient(object):
a data tensor with the same row size of id_tensor.
"""
assert len(name) > 0, 'name cannot be empty.'
id_tensor = utils.toindex(id_tensor)
id_tensor = id_tensor.tousertensor()
assert F.ndim(id_tensor) == 1, 'ID must be a vector.'
if self._pull_handlers[name] is default_pull_handler: # Use fast-pull
part_id = self._part_policy[name].to_partid(id_tensor)
......
"""Define sparse embedding and optimizer."""
from .. import backend as F
class SparseNodeEmbedding:
''' Sparse embeddings in the distributed KVStore.
The sparse embeddings are only used as node embeddings.
Parameters
----------
g : DistGraph
The distributed graph object.
name : str
The name of the embeddings
shape : tuple of int
The shape of the embedding. The first dimension should be the number of nodes.
initializer : callable
The function to create the initial data.
Examples
--------
>>> emb_init = lambda shape, dtype: F.zeros(shape, dtype, F.cpu())
>>> shape = (g.number_of_nodes(), 1)
>>> emb = dgl.distributed.SparseNodeEmbedding(g, 'emb1', shape, emb_init)
>>> optimizer = dgl.distributed.SparseAdagrad([emb], lr=0.001)
>>> for blocks in dataloader:
>>> feats = emb(nids)
>>> loss = F.sum(feats + 1, 0)
>>> loss.backward()
>>> optimizer.step()
'''
def __init__(self, g, name, shape, initializer):
assert shape[0] == g.number_of_nodes()
g.init_ndata(name, shape, F.float32, initializer)
self._tensor = g.ndata[name]
self._trace = []
def __call__(self, idx):
emb = F.attach_grad(self._tensor[idx])
self._trace.append((idx, emb))
return emb
class SparseAdagradUDF:
''' The UDF to update the embeddings with sparse Adagrad.
Parameters
----------
lr : float
The learning rate.
'''
def __init__(self, lr):
self._lr = lr
def __call__(self, data_store, name, indices, data):
''' Update the embeddings with sparse Adagrad.
This function runs on the KVStore server. It accumulates the squared gradients into
the state sum and uses it to scale the gradients before applying them to the embeddings.
Parameters
----------
data_store : dict of data
all data in the kvstore.
name : str
data name
indices : tensor
the indices in the local tensor.
data : tensor (mx.ndarray or torch.tensor)
a tensor with the same row size of id
'''
grad_indices = indices
grad_values = data
embs = data_store[name]
state_sum = data_store[name + "_sum"]
with F.no_grad():
grad_sum = F.mean(grad_values * grad_values, 1)
F.index_add_inplace(state_sum, grad_indices, grad_sum)
std = state_sum[grad_indices] # _sparse_mask
std_values = F.unsqueeze((F.sqrt(std) + 1e-10), 1)
F.index_add_inplace(embs, grad_indices, grad_values / std_values * (-self._lr))
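For reference, a minimal dense NumPy sketch of the same update rule with made-up values (the code above applies it only to the pushed rows on the server):

```
import numpy as np

lr = 0.01
embs = np.zeros((3, 2))                    # embedding rows being updated
state_sum = np.zeros(3)                    # per-row running sum of squared gradients
grad = np.ones((3, 2))                     # gradients pushed for those rows

state_sum += (grad * grad).mean(axis=1)    # mirrors F.index_add_inplace(state_sum, ...)
std = np.sqrt(state_sum) + 1e-10           # per-row Adagrad scaling factor
embs += -lr * grad / std[:, None]          # mirrors F.index_add_inplace(embs, ...)
print(embs)                                # every entry is approximately -lr
```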
def _init_state(shape, dtype):
return F.zeros(shape, dtype, F.cpu())
class SparseAdagrad:
''' The Adagrad optimizer for sparse embeddings.
This optimizer collects gradients for the sparse embeddings and updates
the embeddings in the distributed KVStore.
Parameters
----------
params : list of SparseNodeEmbedding
The list of sparse embeddings.
lr : float
The learning rate.
'''
def __init__(self, params, lr):
self._params = params
self._lr = lr
# We need to register a state sum for each embedding in the kvstore.
for emb in params:
name = emb._tensor.name
kvstore = emb._tensor.kvstore
policy = emb._tensor.part_policy
kvstore.init_data(name + "_sum",
(emb._tensor.shape[0],), emb._tensor.dtype,
policy.policy_str, policy.partition_book, _init_state)
kvstore.register_push_handler(name, SparseAdagradUDF(self._lr))
def step(self):
''' The step function.
The step function is invoked at the end of every batch to push the gradients
of the sparse embeddings to the distributed kvstore and update the embeddings
in the kvstore.
'''
with F.no_grad():
for emb in self._params:
name = emb._tensor.name
kvstore = emb._tensor.kvstore
trace = emb._trace
if len(trace) == 1:
kvstore.push(name, trace[0][0], F.grad(trace[0][1]))
else:
# TODO(zhengda) we need to merge the gradients of the same embeddings first.
idxs = [t[0] for t in trace]
grads = [F.grad(t[1]) for t in trace]
idxs = F.cat(idxs, 0)
# The learning rate and the state-sum scaling are applied on the kvstore
# server by SparseAdagradUDF after the gradients are pushed.
grads = F.cat(grads, 0)
kvstore.push(name, idxs, grads)
# Clean up the old traces.
emb._trace = []
@@ -31,28 +31,6 @@ def randn(shape):
"""Generate a tensor with elements from standard normal distribution."""
pass
def attach_grad(x):
"""Flag the tensor *in-place* to have its gradient computed in backward
pass.
If the flag is already set, reset the gradient buffer as well.
"""
pass
def backward(x, head_gradient=None):
"""Invoke backward computation with an optional head gradient.
Returns nothing."""
pass
def grad(x):
"""Fetches the gradient from the tensor after backward computation."""
pass
def is_no_grad(x):
"""Check whether a tensor has its gradient computed."""
pass
def full(shape, fill_value, dtype, ctx):
pass
@@ -132,26 +110,3 @@ def dot(a, b):
# ----------------
# These are not related to tensors. Some of them are temporary workarounds that
# should be included in DGL in the future.
class record_grad(object):
"""Context manager that records the gradients"""
def __init__(self):
pass
def __enter__(self):
pass
def __exit__(self, exc_type, exc_value, exc_traceback):
pass
class no_grad(object):
"""Context manager that explicitly disables gradient computation"""
def __init__(self):
pass
def __enter__(self):
pass
def __exit__(self, exc_type, exc_value, exc_traceback):
pass
@@ -3,7 +3,6 @@ from __future__ import absolute_import
import numpy as np
import mxnet as mx
import mxnet.ndarray as nd
import mxnet.autograd as autograd
def cuda():
return mx.gpu()
@@ -25,19 +24,6 @@ def allclose(a, b, rtol=1e-4, atol=1e-4):
def randn(shape):
return nd.random.randn(*shape)
def attach_grad(x):
x.attach_grad()
return x
def backward(x, head_gradient=None):
x.backward(head_gradient)
def grad(x):
return x.grad
def is_no_grad(x):
return (x != 0).sum() == 0
def full(shape, fill_value, dtype, ctx):
return nd.full(shape, fill_value, dtype=dtype, ctx=ctx)
@@ -88,16 +74,3 @@ def matmul(a, b):
def dot(a, b):
return nd.sum(mul(a, b), axis=-1)
record_grad = autograd.record
class no_grad(object):
def __init__(self):
pass
def __enter__(self):
pass
def __exit__(self, exc_type, exc_value, exc_traceback):
pass
@@ -18,25 +18,6 @@ def allclose(a, b, rtol=1e-4, atol=1e-4):
def randn(shape):
return th.randn(*shape)
def attach_grad(x):
if x.grad is not None:
x.grad.zero_()
return x
else:
return x.requires_grad_()
def backward(x, head_gradient=None):
if head_gradient is not None and head_gradient.shape[0] == 1 and len(head_gradient.shape) == 1:
# Fix for torch 1.3.1
head_gradient = th.tensor(head_gradient.item()).to(head_gradient.device)
x.backward(head_gradient)
def grad(x):
return x.grad
def is_no_grad(x):
return x.grad is None or (x.grad == 0).all()
def full(shape, fill_value, dtype, ctx):
return th.full(shape, fill_value, dtype=dtype, device=ctx)
@@ -87,15 +68,3 @@ def matmul(a, b):
def dot(a, b):
return sum(mul(a, b), dim=-1)
class record_grad(object):
def __init__(self):
pass
def __enter__(self):
pass
def __exit__(self, exc_type, exc_value, exc_traceback):
pass
no_grad = th.no_grad
@@ -26,90 +26,6 @@ def randn(shape):
return tf.random.normal(shape)
class GradContext:
def __init__(self):
self.tensor_for_grad = []
self.grad_list = []
self.tape = None
def set_tape(self, tape):
self.tape = tape
def add_tensor(self, x):
idx_pop = []
for idx, ele in enumerate(self.tensor_for_grad):
if ele._id == x._id:
idx_pop.append(idx)
if len(idx_pop) > 0:
self.tensor_for_grad.pop(idx_pop[0])
if self.tape is not None:
self.tape.watch(x)
self.tensor_for_grad.append(x)
def backward(self, x, head_gradient=None):
if head_gradient is not None:
x = x * head_gradient
self.grad_list = self.tape.gradient(x, self.tensor_for_grad)
def is_no_grad(self, x):
idx_pop = []
for idx, ele in enumerate(self.tensor_for_grad):
if ele._id == x._id:
idx_pop.append(idx)
if len(idx_pop) == 0:
return True
else:
return self.grad_list[idx_pop[0]] is None
def grad(self, x):
idx_pop = []
for idx, ele in enumerate(self.tensor_for_grad):
if ele._id == x._id:
idx_pop.append(idx)
assert len(idx_pop) == 1
t = self.grad_list[idx_pop[0]]
return tf.convert_to_tensor(t)
cgrad = GradContext()
def get_cgrad():
return cgrad
class record_grad:
def __init__(self):
self.tape = tf.GradientTape()
def __enter__(self):
cgrad.set_tape(self.tape)
self.tape.__enter__()
for x in cgrad.tensor_for_grad:
self.tape.watch(x)
def __exit__(self, exc_type, exc_value, exc_traceback):
# pass
self.tape.__exit__(exc_type, exc_value, exc_traceback)
cgrad.tape = None
def attach_grad(x):
cgrad.add_tensor(x)
return x
def backward(x, head_gradient=None):
cgrad.backward(x, head_gradient)
def grad(x):
return cgrad.grad(x)
def is_no_grad(x):
return cgrad.is_no_grad(x)
def full(shape, fill_value, dtype, ctx):
with tf.device(ctx):
t = tf.constant(fill_value, shape=shape, dtype=dtype)
@@ -180,6 +96,3 @@ def matmul(a, b):
def dot(a, b):
return sum(mul(a, b), dim=-1)
no_grad = None
@@ -13,7 +13,10 @@ from dgl.graph_index import create_graph_index
from dgl.data.utils import load_graphs, save_graphs
from dgl.distributed import DistGraphServer, DistGraph
from dgl.distributed import partition_graph, load_partition, load_partition_book, node_split, edge_split
from dgl.distributed import SparseAdagrad, SparseNodeEmbedding
from numpy.testing import assert_almost_equal
import backend as F
import math
import unittest
import pickle
@@ -58,6 +61,9 @@ def run_server(graph_name, server_id, num_clients, shared_mem):
print('start server', server_id)
g.start()
def emb_init(shape, dtype):
return F.zeros(shape, dtype, F.cpu())
def run_client(graph_name, part_id, num_nodes, num_edges):
time.sleep(5)
gpb = load_partition_book('/tmp/dist_graph/{}.json'.format(graph_name),
@@ -92,6 +98,47 @@ def run_client(graph_name, part_id, num_nodes, num_edges):
feats = g.edata['test1'][eids]
assert np.all(F.asnumpy(feats) == 0)
# Test sparse emb
try:
new_shape = (g.number_of_nodes(), 1)
emb = SparseNodeEmbedding(g, 'emb1', new_shape, emb_init)
lr = 0.001
optimizer = SparseAdagrad([emb], lr=lr)
with F.record_grad():
feats = emb(nids)
assert np.all(F.asnumpy(feats) == np.zeros((len(nids), 1)))
loss = F.sum(feats + 1, 0)
loss.backward()
optimizer.step()
feats = emb(nids)
assert_almost_equal(F.asnumpy(feats), np.ones((len(nids), 1)) * -lr)
rest = np.setdiff1d(np.arange(g.number_of_nodes()), F.asnumpy(nids))
feats1 = emb(rest)
assert np.all(F.asnumpy(feats1) == np.zeros((len(rest), 1)))
policy = dgl.distributed.PartitionPolicy('node', g.get_partition_book())
grad_sum = dgl.distributed.DistTensor(g, 'node:emb1_sum', policy)
assert np.all(F.asnumpy(grad_sum[nids]) == np.ones((len(nids), 1)))
assert np.all(F.asnumpy(grad_sum[rest]) == np.zeros((len(rest), 1)))
emb = SparseNodeEmbedding(g, 'emb2', new_shape, emb_init)
optimizer = SparseAdagrad([emb], lr=lr)
with F.record_grad():
feats1 = emb(nids)
feats2 = emb(nids)
feats = F.cat([feats1, feats2], 0)
assert np.all(F.asnumpy(feats) == np.zeros((len(nids) * 2, 1)))
loss = F.sum(feats + 1, 0)
loss.backward()
optimizer.step()
feats = emb(nids)
assert_almost_equal(F.asnumpy(feats), np.ones((len(nids), 1)) * math.sqrt(2) * -lr)
rest = np.setdiff1d(np.arange(g.number_of_nodes()), F.asnumpy(nids))
feats1 = emb(rest)
assert np.all(F.asnumpy(feats1) == np.zeros((len(rest), 1)))
except NotImplementedError as e:
pass
# Test write data
new_feats = F.ones((len(nids), 2), F.int32, F.cpu())
g.ndata['test1'][nids] = new_feats
......
@@ -70,11 +70,9 @@ gpb = dgl.distributed.GraphPartitionBook(part_id=0,
part_graph=g)
node_policy = dgl.distributed.PartitionPolicy(policy_str='node',
-part_id=0,
partition_book=gpb)
edge_policy = dgl.distributed.PartitionPolicy(policy_str='edge',
-part_id=0,
partition_book=gpb)
data_0 = F.tensor([[1.,1.],[1.,1.],[1.,1.],[1.,1.],[1.,1.],[1.,1.]], F.float32)
......