Unverified Commit a208e886 authored by Hongzhi (Steve) Chen, committed by GitHub

[Misc] Black auto fix. (#4680)



* [Misc] Black auto fix.

* fix pylint disable
Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 29434e65
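The change itself is purely mechanical; as an illustrative sketch (not code from the DGL tree), the formatting pass normalizes quote style, adds trailing commas, and wraps long literals and call sites roughly like this:

# Hypothetical before/after pair, for illustration only.
before = {'float16': 16, 'float32': 32, 'float64': 64}  # single quotes, packed on one line

after = {
    "float16": 16,
    "float32": 32,
    "float64": 64,
}  # double quotes, one entry per line, trailing comma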
from __future__ import absolute_import

import builtins
import numbers
from distutils.version import LooseVersion

import numpy as np
import scipy  # Weird bug in new pytorch when import scipy after import torch
import torch as th
from torch.utils import dlpack

from ... import ndarray as nd
@@ -16,24 +16,33 @@ from ...function.base import TargetCode
if LooseVersion(th.__version__) < LooseVersion("1.9.0"):
    raise RuntimeError("DGL requires PyTorch >= 1.9.0")
def data_type_dict():
    return {
        "float16": th.float16,
        "float32": th.float32,
        "float64": th.float64,
        "uint8": th.uint8,
        "int8": th.int8,
        "int16": th.int16,
        "int32": th.int32,
        "int64": th.int64,
        "bool": th.bool,
    }


def cpu():
    return th.device("cpu")
def tensor(data, dtype=None):
    if isinstance(data, numbers.Number):
        data = [data]
    if (
        isinstance(data, list)
        and len(data) > 0
        and isinstance(data[0], th.Tensor)
    ):
        # prevent GPU->CPU->GPU copies
        if data[0].ndim == 0:
            # zero dimension scalar tensors
@@ -43,9 +52,11 @@ def tensor(data, dtype=None):
    else:
        return th.as_tensor(data, dtype=dtype)
def as_scalar(data):
    return data.item()


def get_preferred_sparse_format():
    """Get the preferred sparse matrix format supported by the backend.
@@ -54,187 +65,241 @@ def get_preferred_sparse_format():
    """
    return "coo"


def sparse_matrix(data, index, shape, force_format=False):
    fmt = index[0]
    if fmt != "coo":
        raise TypeError(
            "Pytorch backend only supports COO format. But got %s." % fmt
        )
    spmat = th.sparse_coo_tensor(index[1], data, shape)
    return spmat, None


def sparse_matrix_indices(spmat):
    return ("coo", spmat._indices())


def is_tensor(obj):
    return isinstance(obj, th.Tensor)


def shape(input):
    return input.shape


def dtype(input):
    return input.dtype


def ndim(input):
    return input.dim()


def context(input):
    return input.device


def device_type(ctx):
    return th.device(ctx).type


def device_id(ctx):
    ctx = th.device(ctx)
    if ctx.index is None:
        return 0 if ctx.type == "cpu" else th.cuda.current_device()
    else:
        return ctx.index


def to_backend_ctx(dglctx):
    dev_type = dglctx.device_type
    if dev_type == 1:
        return th.device("cpu")
    elif dev_type == 2:
        return th.device("cuda", dglctx.device_id)
    else:
        raise ValueError("Unsupported DGL device context:", dglctx)


def astype(input, ty):
    return input.type(ty)


def asnumpy(input):
    if isinstance(input, th.sparse.FloatTensor):
        return input.to_dense().cpu().detach().numpy()
    else:
        return input.cpu().detach().numpy()


def copy_to(input, ctx, **kwargs):
    ctx = th.device(ctx)
    if ctx.type == "cpu":
        return input.cpu()
    elif ctx.type == "cuda":
        if ctx.index is not None:
            th.cuda.set_device(ctx.index)
        return input.cuda(**kwargs)
    else:
        raise RuntimeError("Invalid context", ctx)


def is_pinned(input):
    return input.is_pinned()


def sum(input, dim, keepdims=False):
    return th.sum(input, dim=dim, keepdim=keepdims)


def floor_div(in1, in2):
    return in1 // in2


def reduce_sum(input):
    return input.sum()


def cumsum(input, dim):
    return th.cumsum(input, dim=dim)


def mean(input, dim):
    return th.mean(input, dim=dim)


def reduce_mean(input):
    return input.mean()


def max(input, dim):
    # NOTE: the second argmax array is not returned
    return th.max(input, dim=dim)[0]


def reduce_max(input):
    return input.max()


def min(input, dim):
    # NOTE: the second argmin array is not returned
    return th.min(input, dim=dim)[0]


def reduce_min(input):
    return input.min()


def argsort(input, dim, descending):
    return th.argsort(input, dim=dim, descending=descending)


def topk(input, k, dim, descending=True):
    return th.topk(input, k, dim, largest=descending)[0]


def argtopk(input, k, dim, descending=True):
    return th.topk(input, k, dim, largest=descending)[1]


def exp(input):
    return th.exp(input)


def inverse(input):
    return th.inverse(input)


def sqrt(input):
    return th.sqrt(input)


def softmax(input, dim=-1):
    return th.softmax(input, dim=dim)


def cat(seq, dim):
    return th.cat(seq, dim=dim)


def stack(seq, dim):
    return th.stack(seq, dim=dim)


def split(input, sizes_or_sections, dim):
    return th.split(input, sizes_or_sections, dim)


def repeat(input, repeats, dim):
    return th.repeat_interleave(input, repeats, dim)  # PyTorch 1.1


def gather_row(data, row_index):
    return th.index_select(data, 0, row_index.long())


def slice_axis(data, axis, begin, end):
    return th.narrow(data, axis, begin, end - begin)


def take(data, indices, dim):
    new_shape = data.shape[:dim] + indices.shape + data.shape[dim + 1 :]
    return th.index_select(data, dim, indices.view(-1)).view(new_shape)


def narrow_row(x, start, stop):
    return x[start:stop]


def index_add_inplace(data, row_idx, value):
    data.index_add_(0, row_idx, value)


def scatter_row(data, row_index, value):
    return data.index_copy(0, row_index.long(), value)


def scatter_row_inplace(data, row_index, value):
    data[row_index.long()] = value


def squeeze(input, dim):
    return th.squeeze(input, dim)


def unsqueeze(input, dim):
    return th.unsqueeze(input, dim)


def reshape(input, shape):
    return th.reshape(input, shape)


def swapaxes(input, axis1, axis2):
    return th.transpose(input, axis1, axis2)


def zeros(shape, dtype, ctx):
    return th.zeros(shape, dtype=dtype, device=ctx)


def zeros_like(input):
    return th.zeros_like(input)


def ones(shape, dtype, ctx):
    return th.ones(shape, dtype=dtype, device=ctx)


def uniform(shape, dtype, ctx, low, high):
    return th.empty(shape, dtype=dtype, device=ctx).uniform_(low, high)


def randint(shape, dtype, ctx, low, high):
    return th.randint(low, high, shape, dtype=dtype, device=ctx)


def pad_packed_tensor(input, lengths, value, l_min=None):
    old_shape = input.shape
    device = input.device
@@ -252,11 +317,12 @@ def pad_packed_tensor(input, lengths, value, l_min=None):
    x.fill_(value)
    index = th.ones(len(input), dtype=th.int64, device=device)
    cum_lengths = th.cumsum(lengths, 0)
    index[cum_lengths[:-1]] += max_len - lengths[:-1]
    index = th.cumsum(index, 0) - 1
    x[index] = input
    return x.view(batch_size, max_len, *old_shape[1:])


def pack_padded_tensor(input, lengths):
    max_len = input.shape[1]
    device = input.device
@@ -268,222 +334,377 @@ def pack_padded_tensor(input, lengths):
    out_len = lengths.sum().item()
    index = th.ones(out_len, dtype=th.int64, device=device)
    cum_lengths = th.cumsum(lengths, 0)
    index[cum_lengths[:-1]] += max_len - lengths[:-1]
    index = th.cumsum(index, 0) - 1
    return input[index]
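The index arithmetic shared by pad_packed_tensor and pack_padded_tensor is easier to see on a toy case; a minimal standalone sketch of the same cumsum trick (not DGL code), assuming two segments of lengths 2 and 3 padded to a maximum length of 3:

import torch as th

# Scatter a packed tensor with row lengths [2, 3] into a (2, 3)-padded layout.
lengths = th.tensor([2, 3])
packed = th.arange(5).float()  # rows 0-1 belong to segment 0, rows 2-4 to segment 1
max_len = 3
index = th.ones(5, dtype=th.int64)
cum_lengths = th.cumsum(lengths, 0)
index[cum_lengths[:-1]] += max_len - lengths[:-1]  # jump over the padding slots
index = th.cumsum(index, 0) - 1                    # -> tensor([0, 1, 3, 4, 5])
padded = th.zeros(2 * max_len)
padded[index] = packed                             # rows land at [0, 1] and [3, 4, 5]
print(padded.view(2, max_len))                     # [[0., 1., 0.], [2., 3., 4.]]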
def boolean_mask(input, mask):
    if "bool" not in str(mask.dtype):
        mask = th.tensor(mask, dtype=th.bool)
    return input[mask]


def equal(x, y):
    return x == y


def allclose(x, y, rtol=1e-4, atol=1e-4):
    return th.allclose(x, y, rtol=rtol, atol=atol)


def logical_not(input):
    return ~input


def logical_and(input1, input2):
    return input1 & input2


def clone(input):
    return input.clone()


def clamp(data, min_val, max_val):
    return th.clamp(data, min_val, max_val)


def replace_inf_with_zero(x):
    return th.masked_fill(x, th.isinf(x), 0)


def count_nonzero(input):
    # TODO: fallback to numpy for backward compatibility
    return np.count_nonzero(input)


def unique(input, return_inverse=False, return_counts=False):
    if input.dtype == th.bool:
        input = input.type(th.int8)
    return th.unique(
        input, return_inverse=return_inverse, return_counts=return_counts
    )


def full_1d(length, fill_value, dtype, ctx):
    return th.full((length,), fill_value, dtype=dtype, device=ctx)


def nonzero_1d(input):
    x = th.nonzero(input, as_tuple=False).squeeze()
    return x if x.dim() == 1 else x.view(-1)


def sort_1d(input):
    return th.sort(input)


def arange(start, stop, dtype=th.int64, ctx=None):
    return th.arange(start, stop, dtype=dtype, device=ctx)


def rand_shuffle(arr):
    idx = th.randperm(len(arr))
    return arr[idx]


def zerocopy_to_dlpack(input):
    return dlpack.to_dlpack(input.contiguous())


def zerocopy_from_dlpack(dlpack_tensor):
    return dlpack.from_dlpack(dlpack_tensor)


def zerocopy_to_numpy(input):
    # NOTE: not zerocopy
    return asnumpy(input)


def zerocopy_from_numpy(np_array):
    return th.as_tensor(np_array)


if LooseVersion(th.__version__) >= LooseVersion("1.10.0"):

    def zerocopy_to_dgl_ndarray(data):
        if data.dtype == th.bool:
            data = data.byte()
        return nd.from_dlpack(dlpack.to_dlpack(data.contiguous()))

else:

    def zerocopy_to_dgl_ndarray(data):
        return nd.from_dlpack(dlpack.to_dlpack(data.contiguous()))


def zerocopy_to_dgl_ndarray_for_write(input):
    assert input.is_contiguous(), (
        "Cannot convert non-contiguous tensors "
        "to dgl ndarray for write. Call .to_contiguous() first."
    )
    assert input.numel() == input.storage().size(), (
        "Cannot convert view " "tensors to dgl ndarray for write."
    )
    return zerocopy_to_dgl_ndarray(input)


def zerocopy_from_dgl_ndarray(data):
    if data.shape == (0,):
        # NOTE: PyTorch v1.5 does not accept DLPack object representing empty CUDA tensor.
        # Related issue: https://github.com/pytorch/pytorch/issues/41182
        # The issue will be fixed in v1.6 and later.
        return th.tensor(
            [], dtype=getattr(th, data.dtype), device=to_backend_ctx(data.ctx)
        )
    elif len(data.shape) == 0 or builtins.min(data.shape) == 0:
        # Workaround the same issue as above, but preserve the shape of the
        # empty tensor. This is needed by the sparse optimizer when one of
        # processors may receive no gradients to update, but we want to keep
        # the dimension of the embedding.
        return th.empty(
            data.shape,
            dtype=getattr(th, data.dtype),
            device=to_backend_ctx(data.ctx),
        )
    else:
        return dlpack.from_dlpack(data.to_dlpack())
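The zerocopy helpers above rely on DLPack sharing memory rather than copying, which is also why they insist on contiguous inputs; a small self-contained illustration of that round trip (toy tensor, not DGL code):

import torch as th
from torch.utils import dlpack

# Export to a DLPack capsule and re-import: both tensors view the same buffer.
t = th.arange(4)
capsule = dlpack.to_dlpack(t.contiguous())
t2 = dlpack.from_dlpack(capsule)
t2[0] = 99
print(t[0].item())  # 99, because no copy was made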
class BinaryReduce(th.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        reducer,
        binary_op,
        graph,
        lhs,
        rhs,
        lhs_data,
        rhs_data,
        out_data,
        out_size,
        lhs_map,
        rhs_map,
        out_map,
    ):
        lhs_data_nd = zerocopy_to_dgl_ndarray(lhs_data)
        rhs_data_nd = zerocopy_to_dgl_ndarray(rhs_data)
        feat_shape = K.infer_binary_feature_shape(
            binary_op, lhs_data_nd, rhs_data_nd
        )
        out_shape = feat_shape
        if binary_op == "dot":
            out_shape = feat_shape[:-1]
        out_data_nd = zerocopy_to_dgl_ndarray(out_data)
        K.binary_op_reduce(
            reducer if reducer != "mean" else "sum",
            binary_op,
            graph,
            lhs,
            rhs,
            lhs_data_nd,
            rhs_data_nd,
            out_data_nd,
            lhs_map[0],
            rhs_map[0],
            out_map[0],
        )
        # normalize if mean reducer
        # NOTE(zihao): this is a temporary hack and we should have better solution in the future.
        if reducer == "mean":
            degs = lhs_data.new_empty((out_data.shape[0],))
            degs_nd = zerocopy_to_dgl_ndarray(degs)
            if lhs != TargetCode.DST:  # src or edge
                target = lhs
                n = lhs_data.shape[0]
                in_map = lhs_map[0]
            else:  # rhs != TargetCode.DST
                target = rhs
                n = rhs_data.shape[0]
                in_map = rhs_map[0]
            in_ones = lhs_data.new_ones((n,))
            in_ones_nd = zerocopy_to_dgl_ndarray(in_ones)
            K.copy_reduce(
                "sum", graph, target, in_ones_nd, degs_nd, in_map, out_map[0]
            )
            # reshape
            degs = degs.reshape(
                (out_data.shape[0],) + (1,) * (out_data.dim() - 1)
            ).clamp(min=1)
            out_data = out_data / degs
        else:
            degs = None
        # save_for_backward can only save variables
        ctx.backward_cache = (
            reducer,
            binary_op,
            graph,
            lhs,
            rhs,
            lhs_map,
            rhs_map,
            out_map,
            feat_shape,
            degs,
        )
        ctx.save_for_backward(lhs_data, rhs_data, out_data)
        return out_data

    @staticmethod
    def backward(ctx, grad_out):
        (
            reducer,
            binary_op,
            graph,
            lhs,
            rhs,
            lhs_map,
            rhs_map,
            out_map,
            feat_shape,
            degs,
        ) = ctx.backward_cache
        lhs_data, rhs_data, out_data = ctx.saved_tensors
        lhs_data_nd = zerocopy_to_dgl_ndarray(lhs_data)
        rhs_data_nd = zerocopy_to_dgl_ndarray(rhs_data)
        out_data_nd = zerocopy_to_dgl_ndarray(out_data)
        grad_lhs = None
        grad_rhs = None
        if reducer == "mean":
            grad_out = grad_out / degs
        grad_out_nd = zerocopy_to_dgl_ndarray(grad_out)
        if ctx.needs_input_grad[5]:
            grad_lhs = grad_out.new_empty((lhs_data_nd.shape[0],) + feat_shape)
            K.backward_lhs_binary_op_reduce(
                reducer if reducer != "mean" else "sum",
                binary_op,
                graph,
                lhs,
                rhs,
                lhs_data_nd,
                rhs_data_nd,
                out_data_nd,
                grad_out_nd,
                zerocopy_to_dgl_ndarray(grad_lhs),
                lhs_map[1],
                rhs_map[1],
                out_map[1],
            )
            grad_lhs = _reduce_grad(grad_lhs, lhs_data_nd.shape)
        if ctx.needs_input_grad[6]:
            grad_rhs = grad_out.new_empty((rhs_data_nd.shape[0],) + feat_shape)
            K.backward_rhs_binary_op_reduce(
                reducer if reducer != "mean" else "sum",
                binary_op,
                graph,
                lhs,
                rhs,
                lhs_data_nd,
                rhs_data_nd,
                out_data_nd,
                grad_out_nd,
                zerocopy_to_dgl_ndarray(grad_rhs),
                lhs_map[1],
                rhs_map[1],
                out_map[1],
            )
            grad_rhs = _reduce_grad(grad_rhs, rhs_data_nd.shape)

        return (
            None,
            None,
            None,
            None,
            None,
            grad_lhs,
            grad_rhs,
            None,
            None,
            None,
            None,
            None,
        )
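BinaryReduce follows the usual custom-autograd pattern: non-tensor arguments are stashed in an ad-hoc ctx.backward_cache, tensors go through save_for_backward, and backward returns one gradient slot per forward argument. A minimal standalone sketch of that pattern (illustrative toy Function, not DGL's kernel):

import torch as th

class ScaleBy(th.autograd.Function):
    """Toy Function mirroring the backward_cache / save_for_backward split."""

    @staticmethod
    def forward(ctx, factor, x):
        ctx.backward_cache = factor  # plain Python value, not a tensor
        ctx.save_for_backward(x)     # tensors must go through save_for_backward
        return x * factor

    @staticmethod
    def backward(ctx, grad_out):
        factor = ctx.backward_cache
        (x,) = ctx.saved_tensors
        # one return slot per forward argument: None for `factor`, grad for `x`
        return None, grad_out * factor

x = th.ones(3, requires_grad=True)
ScaleBy.apply(2.0, x).sum().backward()
print(x.grad)  # tensor([2., 2., 2.])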
def binary_reduce(
    reducer,
    binary_op,
    graph,
    lhs,
    rhs,
    lhs_data,
    rhs_data,
    out_size,
    lhs_map=(None, None),
    rhs_map=(None, None),
    out_map=(None, None),
):
    lhs_data_nd = zerocopy_to_dgl_ndarray(lhs_data)
    rhs_data_nd = zerocopy_to_dgl_ndarray(rhs_data)
    feat_shape = K.infer_binary_feature_shape(
        binary_op, lhs_data_nd, rhs_data_nd
    )
    out_shape = feat_shape
    if binary_op == "dot":
        out_shape = feat_shape[:-1]
    out_data = lhs_data.new_empty((out_size,) + out_shape)
    return BinaryReduce.apply(
        reducer,
        binary_op,
        graph,
        lhs,
        rhs,
        lhs_data,
        rhs_data,
        out_data,
        out_size,
        lhs_map,
        rhs_map,
        out_map,
    )


class CopyReduce(th.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        reducer,
        graph,
        target,
        in_data,
        out_data,
        out_size,
        in_map,
        out_map,
    ):
        in_data_nd = zerocopy_to_dgl_ndarray(in_data)
        out_data_nd = zerocopy_to_dgl_ndarray(out_data)
        K.copy_reduce(
            reducer if reducer != "mean" else "sum",
            graph,
            target,
            in_data_nd,
            out_data_nd,
            in_map[0],
            out_map[0],
        )
        # normalize if mean reducer
        # NOTE(zihao): this is a temporary hack and we should have better solution in the future.
        if reducer == "mean":
            in_ones = in_data.new_ones((in_data.shape[0],))
            degs = in_data.new_empty((out_data.shape[0],))
            in_ones_nd = zerocopy_to_dgl_ndarray(in_ones)
            degs_nd = zerocopy_to_dgl_ndarray(degs)
            K.copy_reduce(
                "sum", graph, target, in_ones_nd, degs_nd, in_map[0], out_map[0]
            )
            # reshape
            degs = degs.reshape(
                (out_data.shape[0],) + (1,) * (out_data.dim() - 1)
            ).clamp(min=1)
            out_data = out_data / degs
        else:
            degs = None
@@ -499,22 +720,38 @@ class CopyReduce(th.autograd.Function):
        in_data_nd = zerocopy_to_dgl_ndarray(in_data)
        out_data_nd = zerocopy_to_dgl_ndarray(out_data)
        grad_in = None
        if reducer == "mean":
            grad_out = grad_out / degs
        grad_out_nd = zerocopy_to_dgl_ndarray(grad_out)
        if ctx.needs_input_grad[3]:
            grad_in = grad_out.new_empty(in_data_nd.shape)
            K.backward_copy_reduce(
                reducer if reducer != "mean" else "sum",
                graph,
                target,
                in_data_nd,
                out_data_nd,
                grad_out_nd,
                zerocopy_to_dgl_ndarray(grad_in),
                in_map[1],
                out_map[1],
            )
        return None, None, None, grad_in, None, None, None, None


def copy_reduce(
    reducer,
    graph,
    target,
    in_data,
    out_size,
    in_map=(None, None),
    out_map=(None, None),
):
    out_data = in_data.new_empty((out_size,) + in_data.shape[1:])
    return CopyReduce.apply(
        reducer, graph, target, in_data, out_data, out_size, in_map, out_map
    )


def _reduce_grad(grad, shape):
@@ -543,15 +780,19 @@ def _reduce_grad(grad, shape):
    num_to_squeeze = len(grad_shape) - len(in_shape)
    # pad inshape
    in_shape = (1,) * num_to_squeeze + in_shape
    reduce_idx = th.nonzero(
        th.tensor(grad_shape) - th.tensor(in_shape), as_tuple=False
    )
    reduce_idx += 1  # skip batch dim
    grad = grad.sum(dim=tuple(reduce_idx), keepdim=True)
    return grad.view(shape)
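_reduce_grad undoes broadcasting: dimensions that were expanded during the forward computation are summed back so the gradient matches the original parameter shape (the real helper also pads the shape and skips the leading node/edge dimension). A simplified standalone illustration with assumed shapes, not DGL data:

import torch as th

# Gradient arrives with the broadcast shape (4, 3); the parameter was (1, 3),
# so the broadcast dimension 0 has to be summed away, as _reduce_grad does.
grad = th.ones(4, 3)
param_shape = (1, 3)
reduce_dims = tuple(
    i for i, (g, p) in enumerate(zip(grad.shape, param_shape)) if g != p
)
reduced = grad.sum(dim=reduce_dims, keepdim=True)
print(reduced.shape)  # torch.Size([1, 3]), each entry equals 4.0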
def sync():
    # Pytorch performs computation synchronously, so no need for synchronization.
    pass


def attach_grad(x):
    if x.grad is not None:
        x.grad.zero_()
@@ -559,21 +800,30 @@ def attach_grad(x):
    else:
        return x.requires_grad_()


def backward(x, head_gradient=None):
    if (
        head_gradient is not None
        and head_gradient.shape[0] == 1
        and len(head_gradient.shape) == 1
    ):
        # Fix for torch 1.3.1
        head_gradient = th.tensor(head_gradient.item()).to(head_gradient.device)
    x.backward(head_gradient)


def grad(x):
    return x.grad


def is_no_grad(x):
    return x.grad is None or (x.grad == 0).all()


def is_recording():
    return th.is_grad_enabled()


class record_grad(object):
    def __init__(self):
        pass
@@ -584,4 +834,5 @@ class record_grad(object):
    def __exit__(self, exc_type, exc_value, exc_traceback):
        pass


no_grad = th.no_grad
import argparse
import json
import os


def set_default_backend(default_dir, backend_name):
    os.makedirs(default_dir, exist_ok=True)
    config_path = os.path.join(default_dir, "config.json")
    with open(config_path, "w") as config_file:
        json.dump({"backend": backend_name.lower()}, config_file)
    print(
        'Setting the default backend to "{}". You can change it in the '
        "~/.dgl/config.json file or export the DGLBACKEND environment variable. "
        "Valid options are: pytorch, mxnet, tensorflow (all lowercase)".format(
            backend_name
        )
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "default_dir",
        type=str,
        default=os.path.join(os.path.expanduser("~"), ".dgl"),
    )
    parser.add_argument(
        "backend",
        nargs=1,
        type=str,
        choices=["pytorch", "tensorflow", "mxnet"],
        help="Set default backend",
    )
    args = parser.parse_args()
    set_default_backend(args.default_dir, args.backend[0])
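For reference, the script only writes a one-key JSON file; the sketch below mimics what a hypothetical invocation such as `set_default_backend(os.path.join(os.path.expanduser("~"), ".dgl"), "pytorch")` would leave on disk, using a temporary directory instead of ~/.dgl so nothing on the machine is touched:

import json
import os
import tempfile

default_dir = tempfile.mkdtemp()
with open(os.path.join(default_dir, "config.json"), "w") as config_file:
    json.dump({"backend": "pytorch"}, config_file)  # same payload the script writes
print(open(os.path.join(default_dir, "config.json")).read())  # {"backend": "pytorch"}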
import os

os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"

from .sparse import *
from .tensor import *
import numpy as np
import tensorflow as tf

from ...base import ALL, is_all
from ...heterograph_index import create_unitgraph_from_csr
from ...sparse import (
    _bwd_segment_cmp,
    _csrmask,
    _csrmm,
    _csrsum,
    _gsddmm,
    _gspmm,
    _scatter_add,
    _segment_reduce,
)
from .tensor import asnumpy, context, copy_to, tensor, zerocopy_from_numpy

__all__ = [
    "gspmm",
    "gsddmm",
    "edge_softmax",
    "segment_reduce",
    "scatter_add",
    "csrmm",
    "csrsum",
    "csrmask",
]
def _scatter_nd(index, src, n_rows):
@@ -21,7 +38,10 @@ def _scatter_nd(index, src, n_rows):
        di = shp[i]
        offset_i = tf.range(di, dtype=index.dtype)
        offsets.append(
            tf.reshape(
                (stride * offset_i), (1,) * i + (di,) + (1,) * (ndim - 1 - i)
            )
        )
        stride *= di
    if ndim > 1:
        new_idx = index * stride + copy_to(sum(offsets), ctx)
@@ -29,7 +49,9 @@ def _scatter_nd(index, src, n_rows):
        new_idx = index
    src = tf.reshape(src, (-1,))
    new_idx = tf.reshape(new_idx, (-1, 1))
    rst = tf.reshape(
        tf.scatter_nd(new_idx, src, (stride * n_rows,)), (n_rows, *shp[1:])
    )
    return rst
@@ -43,7 +65,10 @@ def _gather_nd(index, src):
        di = shp[i]
        offset_i = tf.range(di, dtype=index.dtype)
        offsets.append(
            tf.reshape(
                (stride * offset_i), (1,) * i + (di,) + (1,) * (ndim - 1 - i)
            )
        )
        stride *= di
    if ndim > 1:
        new_idx = index * stride + copy_to(sum(offsets), ctx)
@@ -78,10 +103,13 @@ def _reduce_grad(grad, shape):
    num_to_squeeze = len(grad_shape) - len(in_shape)
    # pad inshape
    in_shape = (1,) * num_to_squeeze + in_shape
    reduce_idx = np.asarray(
        np.nonzero(np.asarray(grad_shape) - np.asarray(in_shape))
    )
    reduce_idx += 1  # skip batch dim
    reduce_idx_tensor = tf.constant(
        tuple(reduce_idx.flatten().tolist()), dtype=tf.int32
    )
    grad = tf.reduce_sum(grad, axis=reduce_idx_tensor, keepdims=True)
    return tf.reshape(grad, shape)
@@ -96,11 +124,11 @@ def _need_reduce_last_dim(ufeat, efeat):
def _muldiv(op, x):
    return 1.0 / x if op == "div" else x


def _addsub(op, x):
    return -x if op == "sub" else x


def _expand(x, shape):
@@ -112,49 +140,55 @@ def gspmm_real(gidx, op, reduce_op, X, Y):
    def grad(dZ):
        dZ = tensor(dZ)
        if op != "copy_rhs":
            g_rev = gidx.reverse()
            if reduce_op == "sum":
                if op in ["mul", "div"]:
                    dX = _gspmm(g_rev, "mul", "sum", dZ, _muldiv(op, Y))[0]
                elif op in ["add", "sub"]:
                    dX = _gspmm(g_rev, "copy_lhs", "sum", dZ, Y)[0]
                elif op == "copy_lhs":
                    dX = _gspmm(g_rev, "copy_lhs", "sum", dZ, None)[0]
            else:
                if op in ["mul", "div"]:
                    dX = _scatter_nd(
                        argX,
                        _muldiv(op, _gather_nd(argY, _expand(Y, dZ.shape[1:])))
                        * dZ,
                        X.shape[0],
                    )
                elif op in ["add", "sub", "copy_lhs"]:
                    dX = _scatter_nd(argX, dZ, X.shape[0])
            dX = _reduce_grad(dX, X.shape)
        else:
            dX = tf.zeros_like(X)
        if op != "copy_lhs":
            if reduce_op == "sum":
                if op == "mul" and _need_reduce_last_dim(X, Y):
                    dY = _gsddmm(gidx, "dot", X, dZ)
                elif op in ["mul", "div"]:
                    dY = _gsddmm(gidx, "mul", X, dZ)
                    if op == "div":
                        dY = -dY / (Y**2)
                elif op in ["add", "sub", "copy_rhs"]:
                    dY = _gsddmm(gidx, "copy_rhs", X, _addsub(op, dZ))
            else:
                out_shp = (Y.shape[0],) + dZ.shape[1:]
                if op in ["mul", "div"]:
                    dY = _scatter_nd(
                        argY,
                        _gather_nd(argX, _expand(X, dZ.shape[1:])) * dZ,
                        Y.shape[0],
                    )
                    if op == "div":
                        dY = -dY / (Y**2)
                elif op in ["add", "sub", "copy_rhs"]:
                    dY = _scatter_nd(argY, _addsub(op, dZ), Y.shape[0])
            dY = _reduce_grad(dY, Y.shape)
        else:
            dY = tf.zeros_like(Y)
        return dX, dY

    return out, grad
@@ -162,6 +196,7 @@ def gspmm(gidx, op, reduce_op, X, Y):
    @tf.custom_gradient
    def _lambda(X, Y):
        return gspmm_real(gidx, op, reduce_op, X, Y)

    if X is None:
        X = tf.zeros(())
    if Y is None:
@@ -173,58 +208,68 @@ def gsddmm_real(gidx, op, X, Y, lhs_target, rhs_target):
    out = _gsddmm(gidx, op, X, Y, lhs_target, rhs_target)

    def grad(dZ):
        if op != "copy_rhs":
            if lhs_target in ["u", "v"]:
                _gidx = gidx if lhs_target == "v" else gidx.reverse()
                if op in ["add", "sub", "copy_lhs"]:
                    dX = _gspmm(_gidx, "copy_rhs", "sum", None, dZ)[0]
                else:  # mul, div, dot
                    if rhs_target == lhs_target:
                        dX = _gspmm(_gidx, "copy_rhs", "sum", None, dZ)[
                            0
                        ] * _muldiv(op, Y)
                    elif rhs_target == "e":
                        dX = _gspmm(
                            _gidx, "copy_rhs", "sum", None, dZ * _muldiv(op, Y)
                        )[0]
                    else:  # rhs_target = !lhs_target
                        dX = _gspmm(_gidx, "mul", "sum", _muldiv(op, Y), dZ)[0]
            else:  # lhs_target == 'e'
                if op in ["add", "sub", "copy_lhs"]:
                    dX = dZ
                else:  # mul, div, dot
                    dX = _gsddmm(
                        gidx, "mul", dZ, _muldiv(op, Y), "e", rhs_target
                    )
            dX = _reduce_grad(dX, X.shape)
        else:
            dX = tf.zeros_like(X)
        if op != "copy_lhs":
            if rhs_target in ["u", "v"]:
                _gidx = gidx if rhs_target == "v" else gidx.reverse()
                if op in ["add", "sub", "copy_rhs"]:
                    dY = _gspmm(
                        _gidx, "copy_rhs", "sum", None, _addsub(op, dZ)
                    )[0]
                else:  # mul, div, dot
                    if lhs_target == rhs_target:
                        dY = _gspmm(_gidx, "copy_rhs", "sum", None, dZ)[0] * X
                    elif lhs_target == "e":
                        dY = _gspmm(_gidx, "copy_rhs", "sum", None, dZ * X)[0]
                    else:  # rhs_target = !lhs_target
                        dY = _gspmm(_gidx, "mul", "sum", X, dZ)[0]
                    if op == "div":
                        dY = -dY / (Y**2)
            else:
                if op in ["add", "sub", "copy_rhs"]:
                    dY = _addsub(op, dZ)
                else:  # mul, div, dot
                    dY = _gsddmm(gidx, "mul", dZ, X, "e", lhs_target)
                    if op == "div":
                        dY = -dY / (Y**2)
            dY = _reduce_grad(dY, Y.shape)
        else:
            dY = tf.zeros_like(Y)
        return dX, dY

    return out, grad
def gsddmm(gidx, op, X, Y, lhs_target="u", rhs_target="v"):
    @tf.custom_gradient
    def _lambda(X, Y):
        return gsddmm_real(gidx, op, X, Y, lhs_target, rhs_target)

    if X is None:
        X = tf.zeros(())
    if Y is None:
@@ -232,29 +277,30 @@ def gsddmm(gidx, op, X, Y, lhs_target='u', rhs_target='v'):
    return _lambda(X, Y)


def edge_softmax_real(gidx, score, eids=ALL, norm_by="dst"):
    if not is_all(eids):
        gidx = gidx.edge_subgraph([eids], True).graph
    if norm_by == "src":
        gidx = gidx.reverse()
    score_max = _gspmm(gidx, "copy_rhs", "max", None, score)[0]
    score = tf.math.exp(_gsddmm(gidx, "sub", score, score_max, "e", "v"))
    score_sum = _gspmm(gidx, "copy_rhs", "sum", None, score)[0]
    out = _gsddmm(gidx, "div", score, score_sum, "e", "v")

    def edge_softmax_backward(grad_out):
        sds = out * grad_out
        accum = gspmm(gidx, "copy_rhs", "sum", None, sds)
        grad_score = sds - gsddmm(gidx, "mul", out, accum, "e", "v")
        return grad_score

    return out, edge_softmax_backward


def edge_softmax(gidx, logits, eids=ALL, norm_by="dst"):
    @tf.custom_gradient
    def _lambda(logits):
        return edge_softmax_real(gidx, logits, eids, norm_by)

    return _lambda(logits)
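edge_softmax_real composes the generalized SpMM/SDDMM kernels to normalize edge scores per destination node: shift by the per-node max, exponentiate, then divide by the per-node sum. The same computation on a toy graph in plain NumPy, with assumed edges and scores, purely for illustration:

import numpy as np

# Two edges point at node 0 and one edge at node 1 (destination per edge).
dst = np.array([0, 0, 1])
score = np.array([1.0, 2.0, 0.5])

# copy_rhs/max over incoming edges, then sub + exp, then copy_rhs/sum, then div:
score_max = np.array([score[dst == v].max() for v in range(2)])
exp_score = np.exp(score - score_max[dst])
score_sum = np.array([exp_score[dst == v].sum() for v in range(2)])
out = exp_score / score_sum[dst]
print(out)  # the two edges into node 0 sum to 1.0; the single edge into node 1 is 1.0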
@@ -263,7 +309,7 @@ def segment_reduce_real(op, x, offsets):
    def segment_reduce_backward(dy):
        m = x.shape[0]
        if op == "sum":
            offsets_np = asnumpy(offsets[1:])
            indices_np = np.zeros((m + 1,), dtype=offsets_np.dtype)
            np.add.at(indices_np, offsets_np, np.ones_like(offsets_np))
@@ -281,6 +327,7 @@ def segment_reduce(op, x, offsets):
    @tf.custom_gradient
    def _lambda(x):
        return segment_reduce_real(op, x, offsets)

    return _lambda(x)
@@ -289,7 +336,7 @@ def scatter_add_real(x, idx, m):
    def scatter_add_backward(dy):
        return tf.gather(dy, idx)

    return y, scatter_add_backward
@@ -297,53 +344,102 @@ def scatter_add(x, idx, m):
    @tf.custom_gradient
    def _lambda(x):
        return scatter_add_real(x, idx, m)

    return _lambda(x)
def csrmm_real(gidxA, A_weights, gidxB, B_weights, num_vtypes):
    gidxC, C_weights = _csrmm(gidxA, A_weights, gidxB, B_weights, num_vtypes)
    nrows, ncols, C_indptr, C_indices, C_eids = gidxC.adjacency_matrix_tensors(
        0, False, "csr"
    )

    def grad(dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights):
        # Only the last argument is meaningful.
        dgidxA, dA_weights = _csrmm(
            gidxC,
            dC_weights,
            gidxB.reverse(),
            B_weights,
            gidxA.number_of_ntypes(),
        )
        dgidxB, dB_weights = _csrmm(
            gidxA.reverse(),
            A_weights,
            gidxC,
            dC_weights,
            gidxB.number_of_ntypes(),
        )
        dA_weights = _csrmask(dgidxA, dA_weights, gidxA)
        dB_weights = _csrmask(dgidxB, dB_weights, gidxB)
        return dA_weights, dB_weights

    return (
        tf.constant(nrows),
        tf.constant(ncols),
        C_indptr,
        C_indices,
        C_eids,
        C_weights,
    ), grad


def csrmm(gidxA, A_weights, gidxB, B_weights, num_vtypes):
    @tf.custom_gradient
    def _lambda(A_weights, B_weights):
        return csrmm_real(gidxA, A_weights, gidxB, B_weights, num_vtypes)

    nrows, ncols, C_indptr, C_indices, C_eids, C_weights = _lambda(
        A_weights, B_weights
    )
    gidxC = create_unitgraph_from_csr(
        num_vtypes,
        nrows.numpy(),
        ncols.numpy(),
        C_indptr,
        C_indices,
        C_eids,
        ["coo", "csr", "csc"],
    )
    return gidxC, C_weights


def csrsum_real(gidxs, weights):
    gidxC, C_weights = _csrsum(gidxs, weights)
    nrows, ncols, C_indptr, C_indices, C_eids = gidxC.adjacency_matrix_tensors(
        0, False, "csr"
    )

    def grad(dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights):
        # Only the last argument is meaningful.
        return tuple(_csrmask(gidxC, dC_weights, gidx) for gidx in gidxs)

    return (
        tf.constant(nrows),
        tf.constant(ncols),
        C_indptr,
        C_indices,
        C_eids,
        C_weights,
    ), grad


def csrsum(gidxs, weights):
    @tf.custom_gradient
    def _lambda(*weights):
        return csrsum_real(gidxs, weights)

    nrows, ncols, C_indptr, C_indices, C_eids, C_weights = _lambda(*weights)
    num_vtypes = gidxs[0].number_of_ntypes()
    gidxC = create_unitgraph_from_csr(
        num_vtypes,
        nrows.numpy(),
        ncols.numpy(),
        C_indptr,
        C_indices,
        C_eids,
        ["coo", "csr", "csc"],
    )
    return gidxC, C_weights
@@ -352,10 +448,13 @@ def csrmask_real(gidxA, A_weights, gidxB):
    def grad(dB_weights):
        return _csrmask(gidxB, dB_weights, gidxA)

    return B_weights, grad


def csrmask(gidxA, A_weights, gidxB):
    @tf.custom_gradient
    def _lambda(A_weights):
        return csrmask_real(gidxA, A_weights, gidxB)

    return _lambda(A_weights)
"""Sparse optimizer is not supported for tensorflow""" """Sparse optimizer is not supported for tensorflow"""
\ No newline at end of file
"""Tensorflow backend implementation""" """Tensorflow backend implementation"""
from __future__ import absolute_import from __future__ import absolute_import
from distutils.version import LooseVersion
import tensorflow as tf
import builtins import builtins
import numbers import numbers
from distutils.version import LooseVersion
import numpy as np import numpy as np
import tensorflow as tf
from ... import ndarray as nd from ... import ndarray as nd
from ..._deprecate import kernel as K from ..._deprecate import kernel as K
from ...function.base import TargetCode from ...function.base import TargetCode
if LooseVersion(tf.__version__) < LooseVersion("2.3.0"): if LooseVersion(tf.__version__) < LooseVersion("2.3.0"):
raise RuntimeError("DGL requires TensorFlow>=2.3.0 for the official DLPack support.") raise RuntimeError(
"DGL requires TensorFlow>=2.3.0 for the official DLPack support."
)
def zerocopy_to_dlpack(data): def zerocopy_to_dlpack(data):
return tf.experimental.dlpack.to_dlpack(data) return tf.experimental.dlpack.to_dlpack(data)
def zerocopy_from_dlpack(dlpack_tensor): def zerocopy_from_dlpack(dlpack_tensor):
# TODO(Jinjing): Tensorflow requires memory to be 64-bytes aligned. We check the # TODO(Jinjing): Tensorflow requires memory to be 64-bytes aligned. We check the
# alignment and make a copy if needed. The functionality is better in TF's main repo. # alignment and make a copy if needed. The functionality is better in TF's main repo.
...@@ -26,15 +30,17 @@ def zerocopy_from_dlpack(dlpack_tensor): ...@@ -26,15 +30,17 @@ def zerocopy_from_dlpack(dlpack_tensor):
def data_type_dict(): def data_type_dict():
return {'float16': tf.float16, return {
'float32': tf.float32, "float16": tf.float16,
'float64': tf.float64, "float32": tf.float32,
'uint8': tf.uint8, "float64": tf.float64,
'int8': tf.int8, "uint8": tf.uint8,
'int16': tf.int16, "int8": tf.int8,
'int32': tf.int32, "int16": tf.int16,
'int64': tf.int64, "int32": tf.int32,
'bool' : tf.bool} "int64": tf.int64,
"bool": tf.bool,
}
def cpu(): def cpu():
...@@ -73,18 +79,22 @@ def get_preferred_sparse_format(): ...@@ -73,18 +79,22 @@ def get_preferred_sparse_format():
def sparse_matrix(data, index, shape, force_format=False): def sparse_matrix(data, index, shape, force_format=False):
fmt = index[0] fmt = index[0]
if fmt != 'coo': if fmt != "coo":
raise TypeError( raise TypeError(
'Tensorflow backend only supports COO format. But got %s.' % fmt) "Tensorflow backend only supports COO format. But got %s." % fmt
)
# tf.SparseTensor only supports int64 indexing, # tf.SparseTensor only supports int64 indexing,
# therefore manually casting to int64 when input in int32 # therefore manually casting to int64 when input in int32
spmat = tf.SparseTensor(indices=tf.cast(tf.transpose( spmat = tf.SparseTensor(
index[1], (1, 0)), tf.int64), values=data, dense_shape=shape) indices=tf.cast(tf.transpose(index[1], (1, 0)), tf.int64),
values=data,
dense_shape=shape,
)
return spmat, None return spmat, None
def sparse_matrix_indices(spmat): def sparse_matrix_indices(spmat):
return ('coo', spmat.indices) return ("coo", spmat.indices)
def is_tensor(obj): def is_tensor(obj):
...@@ -107,6 +117,7 @@ def context(input): ...@@ -107,6 +117,7 @@ def context(input):
spec = tf.DeviceSpec.from_string(input.device) spec = tf.DeviceSpec.from_string(input.device)
return "/{}:{}".format(spec.device_type.lower(), spec.device_index) return "/{}:{}".format(spec.device_type.lower(), spec.device_index)
def device_type(ctx): def device_type(ctx):
return tf.DeviceSpec.from_string(ctx).device_type.lower() return tf.DeviceSpec.from_string(ctx).device_type.lower()
...@@ -122,7 +133,7 @@ def to_backend_ctx(dglctx): ...@@ -122,7 +133,7 @@ def to_backend_ctx(dglctx):
elif dev_type == 2: elif dev_type == 2:
return "/gpu:%d" % (dglctx.device_id) return "/gpu:%d" % (dglctx.device_id)
else: else:
raise ValueError('Unsupported DGL device context:', dglctx) raise ValueError("Unsupported DGL device context:", dglctx)
def astype(input, ty): def astype(input, ty):
...@@ -143,17 +154,21 @@ def copy_to(input, ctx, **kwargs): ...@@ -143,17 +154,21 @@ def copy_to(input, ctx, **kwargs):
new_tensor = tf.identity(input) new_tensor = tf.identity(input)
return new_tensor return new_tensor
def is_pinned(input): def is_pinned(input):
return False # not sure how to do this return False # not sure how to do this
def sum(input, dim, keepdims=False): def sum(input, dim, keepdims=False):
if input.dtype == tf.bool: if input.dtype == tf.bool:
input = tf.cast(input, tf.int32) input = tf.cast(input, tf.int32)
return tf.reduce_sum(input, axis=dim, keepdims=keepdims) return tf.reduce_sum(input, axis=dim, keepdims=keepdims)
def floor_div(in1, in2): def floor_div(in1, in2):
return astype(in1 / in2, dtype(in1)) return astype(in1 / in2, dtype(in1))
def reduce_sum(input): def reduce_sum(input):
if input.dtype == tf.bool: if input.dtype == tf.bool:
input = tf.cast(input, tf.int32) input = tf.cast(input, tf.int32)
...@@ -192,9 +207,13 @@ def reduce_min(input): ...@@ -192,9 +207,13 @@ def reduce_min(input):
def argsort(input, dim, descending): def argsort(input, dim, descending):
if descending: if descending:
return tf.cast(tf.argsort(input, axis=dim, direction="DESCENDING"), dtype=tf.int64) return tf.cast(
tf.argsort(input, axis=dim, direction="DESCENDING"), dtype=tf.int64
)
else: else:
return tf.cast(tf.argsort(input, axis=dim, direction="ASCENDING"), dtype=tf.int64) return tf.cast(
tf.argsort(input, axis=dim, direction="ASCENDING"), dtype=tf.int64
)
def topk(input, k, dim, descending=True): def topk(input, k, dim, descending=True):
...@@ -248,7 +267,10 @@ def stack(seq, dim): ...@@ -248,7 +267,10 @@ def stack(seq, dim):
def split(input, sizes_or_sections, dim): def split(input, sizes_or_sections, dim):
return [copy_to(_, input.device) for _ in tf.split(input, sizes_or_sections, axis=dim)] return [
copy_to(_, input.device)
for _ in tf.split(input, sizes_or_sections, axis=dim)
]
def repeat(input, repeats, dim): def repeat(input, repeats, dim):
...@@ -283,7 +305,9 @@ def scatter_row(data, row_index, value): ...@@ -283,7 +305,9 @@ def scatter_row(data, row_index, value):
# notorious legacy issue that int32 type data is always on CPU, which will # notorious legacy issue that int32 type data is always on CPU, which will
# crash the program since DGL requires feature data to be on the same device # crash the program since DGL requires feature data to be on the same device
# as graph structure. # as graph structure.
return copy_to(tf.tensor_scatter_nd_update(data, row_index, value), data.device) return copy_to(
tf.tensor_scatter_nd_update(data, row_index, value), data.device
)
def index_add_inplace(data, row_idx, value): def index_add_inplace(data, row_idx, value):
...@@ -357,10 +381,11 @@ def pad_packed_tensor(input, lengths, value, l_min=None): ...@@ -357,10 +381,11 @@ def pad_packed_tensor(input, lengths, value, l_min=None):
cum_row = 0 cum_row = 0
pad_nparray = np.zeros((ndim, 2), dtype=np.int32) pad_nparray = np.zeros((ndim, 2), dtype=np.int32)
for l in lengths: for l in lengths:
t = input[cum_row:cum_row+l] t = input[cum_row : cum_row + l]
pad_nparray[0, 1] = max_len - l pad_nparray[0, 1] = max_len - l
t = tf.pad(t, tf.constant(pad_nparray), t = tf.pad(
mode='CONSTANT', constant_values=value) t, tf.constant(pad_nparray), mode="CONSTANT", constant_values=value
)
tensor_list.append(t) tensor_list.append(t)
cum_row += l cum_row += l
return tf.stack(tensor_list, axis=0) return tf.stack(tensor_list, axis=0)
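
The loop above slices each length-l segment out of the packed input, pads its first dimension up to max_len, and stacks the results into a (batch, max_len, ...) tensor. A hedged, self-contained sketch of the same shape transformation with toy lengths:

import tensorflow as tf

# Packed input: 6 rows of 2 features, made of segments with lengths [2, 1, 3].
packed = tf.reshape(tf.range(12, dtype=tf.float32), (6, 2))
lengths = [2, 1, 3]
max_len = max(lengths)

padded, cum = [], 0
for l in lengths:
    seg = packed[cum : cum + l]
    # Pad only the first (row) dimension up to max_len, as in the code above.
    seg = tf.pad(seg, [[0, max_len - l], [0, 0]], constant_values=0.0)
    padded.append(seg)
    cum += l
out = tf.stack(padded, axis=0)  # shape (3, 3, 2)
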
...@@ -384,26 +409,35 @@ def equal(x, y): ...@@ -384,26 +409,35 @@ def equal(x, y):
def allclose(x, y, rtol=1e-4, atol=1e-4): def allclose(x, y, rtol=1e-4, atol=1e-4):
return np.allclose(tf.convert_to_tensor(x).numpy(), return np.allclose(
tf.convert_to_tensor(y).numpy(), rtol=rtol, atol=atol) tf.convert_to_tensor(x).numpy(),
tf.convert_to_tensor(y).numpy(),
rtol=rtol,
atol=atol,
)
def logical_not(input): def logical_not(input):
return ~input return ~input
def logical_and(input1, input2): def logical_and(input1, input2):
return tf.math.logical_and(input1, input2) return tf.math.logical_and(input1, input2)
def clone(input): def clone(input):
# TF tensor is always immutable so returning the input is safe. # TF tensor is always immutable so returning the input is safe.
return input return input
def clamp(data, min_val, max_val): def clamp(data, min_val, max_val):
return tf.clip_by_value(data, min_val, max_val) return tf.clip_by_value(data, min_val, max_val)
def replace_inf_with_zero(x): def replace_inf_with_zero(x):
return tf.where(tf.abs(x) == np.inf, 0, x) return tf.where(tf.abs(x) == np.inf, 0, x)
def count_nonzero(input): def count_nonzero(input):
return int(tf.math.count_nonzero(input)) return int(tf.math.count_nonzero(input))
...@@ -429,7 +463,7 @@ def full_1d(length, fill_value, dtype, ctx): ...@@ -429,7 +463,7 @@ def full_1d(length, fill_value, dtype, ctx):
def nonzero_1d(input): def nonzero_1d(input):
nonzero_bool = tf.cast(input, tf.bool) nonzero_bool = tf.cast(input, tf.bool)
return tf.reshape(tf.where(nonzero_bool), (-1, )) return tf.reshape(tf.where(nonzero_bool), (-1,))
def sort_1d(input): def sort_1d(input):
...@@ -461,7 +495,7 @@ def zerocopy_from_numpy(np_array): ...@@ -461,7 +495,7 @@ def zerocopy_from_numpy(np_array):
def zerocopy_to_dgl_ndarray(data): def zerocopy_to_dgl_ndarray(data):
if device_type(data.device) == 'gpu' and data.dtype in (tf.int32, tf.int64): if device_type(data.device) == "gpu" and data.dtype in (tf.int32, tf.int64):
# NOTE: TF doesn't keep signed tensors on GPU due to legacy issues with # NOTE: TF doesn't keep signed tensors on GPU due to legacy issues with
# shape inference. Convert it to unsigned and cast it back afterwards. # shape inference. Convert it to unsigned and cast it back afterwards.
if data.dtype == tf.int32: if data.dtype == tf.int32:
...@@ -481,35 +515,78 @@ def zerocopy_from_dgl_ndarray(input): ...@@ -481,35 +515,78 @@ def zerocopy_from_dgl_ndarray(input):
return zerocopy_from_dlpack(input.to_dlpack()) return zerocopy_from_dlpack(input.to_dlpack())
def binary_reduce(reducer, binary_op, graph, lhs, rhs, lhs_data, rhs_data, def binary_reduce(
out_size, lhs_map=(None, None), rhs_map=(None, None), out_map=(None, None)): reducer,
binary_op,
graph,
lhs,
rhs,
lhs_data,
rhs_data,
out_size,
lhs_map=(None, None),
rhs_map=(None, None),
out_map=(None, None),
):
@tf.custom_gradient @tf.custom_gradient
def _lambda(lhs_data, rhs_data): def _lambda(lhs_data, rhs_data):
return binary_reduce_real(reducer, binary_op, graph, lhs, rhs, lhs_data, rhs_data, return binary_reduce_real(
out_size, lhs_map, rhs_map, out_map) reducer,
binary_op,
graph,
lhs,
rhs,
lhs_data,
rhs_data,
out_size,
lhs_map,
rhs_map,
out_map,
)
return _lambda(lhs_data, rhs_data) return _lambda(lhs_data, rhs_data)
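
binary_reduce wraps the kernel call in tf.custom_gradient so the backward pass invokes the matching backward kernels defined below rather than TensorFlow's autodiff. A minimal sketch of the same pattern with a toy op (not the DGL kernels):

import tensorflow as tf

@tf.custom_gradient
def square(x):
    y = x * x

    def grad(dy):
        # Gradient of x**2 is 2*x; dy is the upstream gradient.
        return dy * 2.0 * x

    return y, grad

x = tf.constant(3.0)
with tf.GradientTape() as tape:
    tape.watch(x)
    y = square(x)
print(tape.gradient(y, x))  # tf.Tensor(6.0, shape=(), dtype=float32)
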
def binary_reduce_real(reducer, binary_op, graph, lhs, rhs, lhs_data, rhs_data, def binary_reduce_real(
out_size, lhs_map, rhs_map, out_map): reducer,
binary_op,
graph,
lhs,
rhs,
lhs_data,
rhs_data,
out_size,
lhs_map,
rhs_map,
out_map,
):
with tf.device(lhs_data.device): with tf.device(lhs_data.device):
lhs_data_nd = zerocopy_to_dgl_ndarray(lhs_data) lhs_data_nd = zerocopy_to_dgl_ndarray(lhs_data)
rhs_data_nd = zerocopy_to_dgl_ndarray(rhs_data) rhs_data_nd = zerocopy_to_dgl_ndarray(rhs_data)
feat_shape = K.infer_binary_feature_shape( feat_shape = K.infer_binary_feature_shape(
binary_op, lhs_data_nd, rhs_data_nd) binary_op, lhs_data_nd, rhs_data_nd
)
out_shape = feat_shape out_shape = feat_shape
if binary_op == 'dot': if binary_op == "dot":
out_shape = feat_shape[:-1] out_shape = feat_shape[:-1]
out_data = tf.zeros((out_size,) + out_shape, dtype=lhs_data.dtype) out_data = tf.zeros((out_size,) + out_shape, dtype=lhs_data.dtype)
out_data_nd = zerocopy_to_dgl_ndarray(out_data) out_data_nd = zerocopy_to_dgl_ndarray(out_data)
K.binary_op_reduce( K.binary_op_reduce(
reducer if reducer != 'mean' else 'sum', reducer if reducer != "mean" else "sum",
binary_op, graph, lhs, rhs, lhs_data_nd, rhs_data_nd, binary_op,
out_data_nd, lhs_map[0], rhs_map[0], out_map[0]) graph,
lhs,
rhs,
lhs_data_nd,
rhs_data_nd,
out_data_nd,
lhs_map[0],
rhs_map[0],
out_map[0],
)
# normalize if mean reducer # normalize if mean reducer
# NOTE(zihao): this is a temporary hack and we should have a better solution in the future. # NOTE(zihao): this is a temporary hack and we should have a better solution in the future.
if reducer == 'mean': if reducer == "mean":
degs = tf.zeros((out_data.shape[0],), dtype=lhs_data.dtype) degs = tf.zeros((out_data.shape[0],), dtype=lhs_data.dtype)
degs_nd = zerocopy_to_dgl_ndarray(degs) degs_nd = zerocopy_to_dgl_ndarray(degs)
if lhs != TargetCode.DST: # src or edge if lhs != TargetCode.DST: # src or edge
...@@ -523,12 +600,15 @@ def binary_reduce_real(reducer, binary_op, graph, lhs, rhs, lhs_data, rhs_data, ...@@ -523,12 +600,15 @@ def binary_reduce_real(reducer, binary_op, graph, lhs, rhs, lhs_data, rhs_data,
in_ones = tf.ones((n,), dtype=lhs_data.dtype) in_ones = tf.ones((n,), dtype=lhs_data.dtype)
in_ones_nd = zerocopy_to_dgl_ndarray(in_ones) in_ones_nd = zerocopy_to_dgl_ndarray(in_ones)
K.copy_reduce( K.copy_reduce(
'sum', graph, target, in_ones_nd, degs_nd, in_map, out_map[0]) "sum", graph, target, in_ones_nd, degs_nd, in_map, out_map[0]
)
# reshape # reshape
degs = tf.reshape(degs, degs = tf.reshape(
(out_data.shape[0],) + (1,) * (out_data.ndim - 1)) degs, (out_data.shape[0],) + (1,) * (out_data.ndim - 1)
degs = tf.clip_by_value(degs, clip_value_min=1, )
clip_value_max=np.inf) # ??? degs = tf.clip_by_value(
degs, clip_value_min=1, clip_value_max=np.inf
) # ???
out_data = out_data / degs out_data = out_data / degs
else: else:
degs = None degs = None
...@@ -537,80 +617,129 @@ def binary_reduce_real(reducer, binary_op, graph, lhs, rhs, lhs_data, rhs_data, ...@@ -537,80 +617,129 @@ def binary_reduce_real(reducer, binary_op, graph, lhs, rhs, lhs_data, rhs_data,
with tf.device(grad_out.device): with tf.device(grad_out.device):
grad_lhs = None grad_lhs = None
grad_rhs = None grad_rhs = None
if reducer == 'mean': if reducer == "mean":
grad_out = grad_out / degs grad_out = grad_out / degs
grad_out_nd = zerocopy_to_dgl_ndarray(grad_out) grad_out_nd = zerocopy_to_dgl_ndarray(grad_out)
# compute gradient for lhs # compute gradient for lhs
grad_lhs = tf.zeros((lhs_data_nd.shape[0],) + feat_shape) grad_lhs = tf.zeros((lhs_data_nd.shape[0],) + feat_shape)
K.backward_lhs_binary_op_reduce( K.backward_lhs_binary_op_reduce(
reducer if reducer != 'mean' else 'sum', reducer if reducer != "mean" else "sum",
binary_op, graph, lhs, rhs, lhs_data_nd, rhs_data_nd, binary_op,
out_data_nd, grad_out_nd, zerocopy_to_dgl_ndarray(grad_lhs), graph,
lhs_map[1], rhs_map[1], out_map[1]) lhs,
rhs,
lhs_data_nd,
rhs_data_nd,
out_data_nd,
grad_out_nd,
zerocopy_to_dgl_ndarray(grad_lhs),
lhs_map[1],
rhs_map[1],
out_map[1],
)
grad_lhs = _reduce_grad(grad_lhs, lhs_data_nd.shape) grad_lhs = _reduce_grad(grad_lhs, lhs_data_nd.shape)
# compute gradient for rhs # compute gradient for rhs
grad_rhs = tf.zeros((rhs_data_nd.shape[0],) + feat_shape) grad_rhs = tf.zeros((rhs_data_nd.shape[0],) + feat_shape)
K.backward_rhs_binary_op_reduce( K.backward_rhs_binary_op_reduce(
reducer if reducer != 'mean' else 'sum', reducer if reducer != "mean" else "sum",
binary_op, graph, lhs, rhs, lhs_data_nd, rhs_data_nd, binary_op,
out_data_nd, grad_out_nd, zerocopy_to_dgl_ndarray(grad_rhs), graph,
lhs_map[1], rhs_map[1], out_map[1]) lhs,
rhs,
lhs_data_nd,
rhs_data_nd,
out_data_nd,
grad_out_nd,
zerocopy_to_dgl_ndarray(grad_rhs),
lhs_map[1],
rhs_map[1],
out_map[1],
)
grad_rhs = _reduce_grad(grad_rhs, rhs_data_nd.shape) grad_rhs = _reduce_grad(grad_rhs, rhs_data_nd.shape)
return grad_lhs, grad_rhs return grad_lhs, grad_rhs
return out_data, grad return out_data, grad
def copy_reduce(reducer, graph, target, in_data, out_size, in_map=(None, None), def copy_reduce(
out_map=(None, None)): reducer,
graph,
target,
in_data,
out_size,
in_map=(None, None),
out_map=(None, None),
):
@tf.custom_gradient @tf.custom_gradient
def _lambda(in_data): def _lambda(in_data):
return copy_reduce_real(reducer, graph, target, in_data, out_size, in_map, return copy_reduce_real(
out_map) reducer, graph, target, in_data, out_size, in_map, out_map
)
return _lambda(in_data) return _lambda(in_data)
def copy_reduce_real(reducer, graph, target, in_data, out_size, in_map, def copy_reduce_real(
out_map): reducer, graph, target, in_data, out_size, in_map, out_map
):
with tf.device(in_data.device): with tf.device(in_data.device):
out_data = tf.zeros( out_data = tf.zeros(
(out_size,) + tuple(in_data.shape[1:]), dtype=in_data.dtype) (out_size,) + tuple(in_data.shape[1:]), dtype=in_data.dtype
)
in_data_nd = zerocopy_to_dgl_ndarray(in_data) in_data_nd = zerocopy_to_dgl_ndarray(in_data)
out_data_nd = zerocopy_to_dgl_ndarray(out_data) out_data_nd = zerocopy_to_dgl_ndarray(out_data)
K.copy_reduce( K.copy_reduce(
reducer if reducer != 'mean' else 'sum', reducer if reducer != "mean" else "sum",
graph, target, in_data_nd, out_data_nd, in_map[0], out_map[0]) graph,
target,
in_data_nd,
out_data_nd,
in_map[0],
out_map[0],
)
# normalize if mean reducer # normalize if mean reducer
# NOTE(zihao): this is a temporary hack and we should have a better solution in the future. # NOTE(zihao): this is a temporary hack and we should have a better solution in the future.
if reducer == 'mean': if reducer == "mean":
in_ones = tf.ones(in_data.shape[0], dtype=in_data.dtype) in_ones = tf.ones(in_data.shape[0], dtype=in_data.dtype)
degs = tf.zeros(out_data.shape[0], dtype=in_data.dtype) degs = tf.zeros(out_data.shape[0], dtype=in_data.dtype)
in_ones_nd = zerocopy_to_dgl_ndarray(in_ones) in_ones_nd = zerocopy_to_dgl_ndarray(in_ones)
degs_nd = zerocopy_to_dgl_ndarray(degs) degs_nd = zerocopy_to_dgl_ndarray(degs)
K.copy_reduce( K.copy_reduce(
'sum', graph, target, in_ones_nd, degs_nd, in_map[0], out_map[0]) "sum", graph, target, in_ones_nd, degs_nd, in_map[0], out_map[0]
)
# reshape # reshape
degs = tf.reshape(degs, degs = tf.reshape(
(out_data.shape[0],) + (1,) * (out_data.ndim - 1)) degs, (out_data.shape[0],) + (1,) * (out_data.ndim - 1)
degs = tf.clip_by_value(degs, clip_value_min=1, )
clip_value_max=np.inf) # TODO: ??? degs = tf.clip_by_value(
degs, clip_value_min=1, clip_value_max=np.inf
) # TODO: ???
out_data = out_data / degs out_data = out_data / degs
else: else:
degs = None degs = None
def grad(grad_out): def grad(grad_out):
with tf.device(grad_out.device): with tf.device(grad_out.device):
if reducer == 'mean': if reducer == "mean":
grad_out = grad_out / degs grad_out = grad_out / degs
grad_out_nd = zerocopy_to_dgl_ndarray(grad_out) grad_out_nd = zerocopy_to_dgl_ndarray(grad_out)
grad_in = tf.zeros(in_data_nd.shape) grad_in = tf.zeros(in_data_nd.shape)
K.backward_copy_reduce( K.backward_copy_reduce(
reducer if reducer != 'mean' else 'sum', reducer if reducer != "mean" else "sum",
graph, target, in_data_nd, out_data_nd, grad_out_nd, graph,
zerocopy_to_dgl_ndarray(grad_in), in_map[1], out_map[1]) target,
in_data_nd,
out_data_nd,
grad_out_nd,
zerocopy_to_dgl_ndarray(grad_in),
in_map[1],
out_map[1],
)
return grad_in return grad_in
return out_data, grad return out_data, grad
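
For the 'mean' reducer, both routines above compute a plain sum and then divide by per-destination degrees, clipping the degrees to at least 1 so that zero-degree rows stay untouched instead of producing NaNs. A small illustrative sketch of that normalization step with toy numbers:

import numpy as np
import tensorflow as tf

summed = tf.constant([[6.0], [0.0], [4.0]])  # per-destination message sums
degs = tf.constant([[3.0], [0.0], [2.0]])    # in-degrees per destination
# Zero-degree destinations get a divisor of 1, so their (zero) sums are unchanged.
degs = tf.clip_by_value(degs, clip_value_min=1, clip_value_max=np.inf)
mean = summed / degs                          # [[2.0], [0.0], [2.0]]
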
...@@ -640,10 +769,11 @@ def _reduce_grad(grad, shape): ...@@ -640,10 +769,11 @@ def _reduce_grad(grad, shape):
num_to_squeeze = len(grad_shape) - len(in_shape) num_to_squeeze = len(grad_shape) - len(in_shape)
# pad inshape # pad inshape
in_shape = (1,) * num_to_squeeze + in_shape in_shape = (1,) * num_to_squeeze + in_shape
reduce_idx = np.asarray(np.nonzero(np.asarray(grad_shape) - np.asarray(in_shape))) reduce_idx = np.asarray(
np.nonzero(np.asarray(grad_shape) - np.asarray(in_shape))
)
reduce_idx += 1 # skip batch dim reduce_idx += 1 # skip batch dim
reduce_idx_tensor = tf.constant(tuple( reduce_idx_tensor = tf.constant(tuple(reduce_idx.flatten().tolist()))
reduce_idx.flatten().tolist()))
grad = tf.reduce_sum(grad, axis=reduce_idx_tensor, keepdims=True) grad = tf.reduce_sum(grad, axis=reduce_idx_tensor, keepdims=True)
return tf.reshape(grad, shape) return tf.reshape(grad, shape)
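
_reduce_grad sums the incoming gradient over the feature axes that were broadcast in the forward computation so it matches the original operand's shape. A toy, hedged walk-through of the index arithmetic (shapes are illustrative):

import numpy as np
import tensorflow as tf

# grad has shape (5, 4, 8) while the original input had shape (5, 1, 8),
# so the broadcast axis (1) must be summed out.
grad = tf.ones((5, 4, 8))
in_shape = (5, 1, 8)
grad_shape, target_shape = (4, 8), (1, 8)  # shapes without the batch dimension
reduce_idx = np.nonzero(np.asarray(grad_shape) - np.asarray(target_shape))[0]
reduce_idx += 1                            # skip the batch dimension
reduced = tf.reduce_sum(grad, axis=tuple(reduce_idx), keepdims=True)
print(tf.reshape(reduced, in_shape).shape)  # (5, 1, 8)
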
...@@ -741,6 +871,7 @@ def is_no_grad(x): ...@@ -741,6 +871,7 @@ def is_no_grad(x):
def is_recording(): def is_recording():
raise NotImplementedError("Tensorflow doesn't support is_recording") raise NotImplementedError("Tensorflow doesn't support is_recording")
no_grad = None no_grad = None
initialize_context() initialize_context()
...@@ -11,39 +11,48 @@ ALL = "__ALL__" ...@@ -11,39 +11,48 @@ ALL = "__ALL__"
# An alias for [:] # An alias for [:]
SLICE_FULL = slice(None, None, None) SLICE_FULL = slice(None, None, None)
# Reserved column names for storing parent node/edge types and IDs in flattened heterographs # Reserved column names for storing parent node/edge types and IDs in flattened heterographs
NTYPE = '_TYPE' NTYPE = "_TYPE"
NID = '_ID' NID = "_ID"
ETYPE = '_TYPE' ETYPE = "_TYPE"
EID = '_ID' EID = "_ID"
_INTERNAL_COLUMNS = {NTYPE, NID, ETYPE, EID} _INTERNAL_COLUMNS = {NTYPE, NID, ETYPE, EID}
def is_internal_column(name): def is_internal_column(name):
"""Return true if the column name is reversed by DGL.""" """Return true if the column name is reversed by DGL."""
return name in _INTERNAL_COLUMNS return name in _INTERNAL_COLUMNS
def is_all(arg): def is_all(arg):
"""Return true if the argument is a special symbol for all nodes or edges.""" """Return true if the argument is a special symbol for all nodes or edges."""
return isinstance(arg, str) and arg == ALL return isinstance(arg, str) and arg == ALL
# pylint: disable=invalid-name # pylint: disable=invalid-name
_default_formatwarning = warnings.formatwarning _default_formatwarning = warnings.formatwarning
class DGLWarning(UserWarning): class DGLWarning(UserWarning):
"""DGL Warning class.""" """DGL Warning class."""
# pylint: disable=unused-argument # pylint: disable=unused-argument
def dgl_warning_format(message, category, filename, lineno, line=None): def dgl_warning_format(message, category, filename, lineno, line=None):
"""Format DGL warnings.""" """Format DGL warnings."""
if issubclass(category, DGLWarning): if issubclass(category, DGLWarning):
return "DGL Warning: {}\n".format(message) return "DGL Warning: {}\n".format(message)
else: else:
return _default_formatwarning(message, category, filename, lineno, line=None) return _default_formatwarning(
message, category, filename, lineno, line=None
)
def dgl_warning(message, category=DGLWarning, stacklevel=2): def dgl_warning(message, category=DGLWarning, stacklevel=2):
"""DGL warning wrapper that defaults to ``DGLWarning`` instead of ``UserWarning`` category.""" """DGL warning wrapper that defaults to ``DGLWarning`` instead of ``UserWarning`` category."""
return warnings.warn(message, category=category, stacklevel=stacklevel) return warnings.warn(message, category=category, stacklevel=stacklevel)
warnings.formatwarning = dgl_warning_format warnings.formatwarning = dgl_warning_format
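
With the formatter installed above, warnings raised through dgl_warning are rendered as a single prefixed line instead of Python's default multi-line format. A hedged usage sketch, assuming these names are importable from dgl.base:

from dgl.base import dgl_warning

dgl_warning("feature 'h' is empty")
# Printed to stderr as: DGL Warning: feature 'h' is empty
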
_init_internal_api() _init_internal_api()
...@@ -2,9 +2,10 @@ ...@@ -2,9 +2,10 @@
reference: tvm/python/tvm/collections.py reference: tvm/python/tvm/collections.py
""" """
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from . import _api_internal
from ._ffi.object import ObjectBase, register_object from ._ffi.object import ObjectBase, register_object
from ._ffi.object_generic import convert_to_object from ._ffi.object_generic import convert_to_object
from . import _api_internal
@register_object @register_object
...@@ -29,8 +30,11 @@ class List(ObjectBase): ...@@ -29,8 +30,11 @@ class List(ObjectBase):
return [self[idx] for idx in range(start, stop, step)] return [self[idx] for idx in range(start, stop, step)]
if i < -len(self) or i >= len(self): if i < -len(self) or i >= len(self):
raise IndexError("List index out of range. List size: {}, got index {}" raise IndexError(
.format(len(self), i)) "List index out of range. List size: {}, got index {}".format(
len(self), i
)
)
if i < 0: if i < 0:
i += len(self) i += len(self)
ret = _api_internal._ListGetItem(self, i) ret = _api_internal._ListGetItem(self, i)
...@@ -60,7 +64,7 @@ class Map(ObjectBase): ...@@ -60,7 +64,7 @@ class Map(ObjectBase):
def items(self): def items(self):
"""Get the items from the map""" """Get the items from the map"""
akvs = _api_internal._MapItems(self) akvs = _api_internal._MapItems(self)
return [(akvs[i], akvs[i+1]) for i in range(0, len(akvs), 2)] return [(akvs[i], akvs[i + 1]) for i in range(0, len(akvs), 2)]
def __len__(self): def __len__(self):
return _api_internal._MapSize(self) return _api_internal._MapSize(self)
...@@ -76,12 +80,13 @@ class StrMap(Map): ...@@ -76,12 +80,13 @@ class StrMap(Map):
def items(self): def items(self):
"""Get the items from the map""" """Get the items from the map"""
akvs = _api_internal._MapItems(self) akvs = _api_internal._MapItems(self)
return [(akvs[i], akvs[i+1]) for i in range(0, len(akvs), 2)] return [(akvs[i], akvs[i + 1]) for i in range(0, len(akvs), 2)]
@register_object @register_object
class Value(ObjectBase): class Value(ObjectBase):
"""Object wrapper for various values.""" """Object wrapper for various values."""
@property @property
def data(self): def data(self):
"""Return the value data.""" """Return the value data."""
......
...@@ -2,17 +2,19 @@ ...@@ -2,17 +2,19 @@
# pylint: disable=not-callable # pylint: disable=not-callable
import numpy as np import numpy as np
from .base import DGLError, is_all, NID, EID, ALL, dgl_warning
from . import backend as F from . import backend as F
from . import function as fn from . import function as fn
from .frame import Frame
from .udf import NodeBatch, EdgeBatch
from . import ops from . import ops
from .base import ALL, EID, NID, DGLError, dgl_warning, is_all
from .frame import Frame
from .udf import EdgeBatch, NodeBatch
def is_builtin(func): def is_builtin(func):
"""Return true if the function is a DGL builtin function.""" """Return true if the function is a DGL builtin function."""
return isinstance(func, fn.BuiltinFunction) return isinstance(func, fn.BuiltinFunction)
def invoke_node_udf(graph, nid, ntype, func, *, ndata=None, orig_nid=None): def invoke_node_udf(graph, nid, ntype, func, *, ndata=None, orig_nid=None):
"""Invoke user-defined node function on the given nodes. """Invoke user-defined node function on the given nodes.
...@@ -43,9 +45,12 @@ def invoke_node_udf(graph, nid, ntype, func, *, ndata=None, orig_nid=None): ...@@ -43,9 +45,12 @@ def invoke_node_udf(graph, nid, ntype, func, *, ndata=None, orig_nid=None):
nid = graph.nodes(ntype=ntype) nid = graph.nodes(ntype=ntype)
else: else:
ndata = graph._node_frames[ntid].subframe(nid) ndata = graph._node_frames[ntid].subframe(nid)
nbatch = NodeBatch(graph, nid if orig_nid is None else orig_nid, ntype, ndata) nbatch = NodeBatch(
graph, nid if orig_nid is None else orig_nid, ntype, ndata
)
return func(nbatch) return func(nbatch)
def invoke_edge_udf(graph, eid, etype, func, *, orig_eid=None): def invoke_edge_udf(graph, eid, etype, func, *, orig_eid=None):
"""Invoke user-defined edge function on the given edges. """Invoke user-defined edge function on the given edges.
...@@ -70,20 +75,29 @@ def invoke_edge_udf(graph, eid, etype, func, *, orig_eid=None): ...@@ -70,20 +75,29 @@ def invoke_edge_udf(graph, eid, etype, func, *, orig_eid=None):
etid = graph.get_etype_id(etype) etid = graph.get_etype_id(etype)
stid, dtid = graph._graph.metagraph.find_edge(etid) stid, dtid = graph._graph.metagraph.find_edge(etid)
if is_all(eid): if is_all(eid):
u, v, eid = graph.edges(form='all') u, v, eid = graph.edges(form="all")
edata = graph._edge_frames[etid] edata = graph._edge_frames[etid]
else: else:
u, v = graph.find_edges(eid) u, v = graph.find_edges(eid)
edata = graph._edge_frames[etid].subframe(eid) edata = graph._edge_frames[etid].subframe(eid)
if len(u) == 0: if len(u) == 0:
dgl_warning('The input graph for the user-defined edge function ' \ dgl_warning(
'does not contain valid edges') "The input graph for the user-defined edge function "
"does not contain valid edges"
)
srcdata = graph._node_frames[stid].subframe(u) srcdata = graph._node_frames[stid].subframe(u)
dstdata = graph._node_frames[dtid].subframe(v) dstdata = graph._node_frames[dtid].subframe(v)
ebatch = EdgeBatch(graph, eid if orig_eid is None else orig_eid, ebatch = EdgeBatch(
etype, srcdata, edata, dstdata) graph,
eid if orig_eid is None else orig_eid,
etype,
srcdata,
edata,
dstdata,
)
return func(ebatch) return func(ebatch)
def invoke_udf_reduce(graph, func, msgdata, *, orig_nid=None): def invoke_udf_reduce(graph, func, msgdata, *, orig_nid=None):
"""Invoke user-defined reduce function on all the nodes in the graph. """Invoke user-defined reduce function on all the nodes in the graph.
...@@ -119,7 +133,9 @@ def invoke_udf_reduce(graph, func, msgdata, *, orig_nid=None): ...@@ -119,7 +133,9 @@ def invoke_udf_reduce(graph, func, msgdata, *, orig_nid=None):
unique_degs, bucketor = _bucketing(degs) unique_degs, bucketor = _bucketing(degs)
bkt_rsts = [] bkt_rsts = []
bkt_nodes = [] bkt_nodes = []
for deg, node_bkt, orig_nid_bkt in zip(unique_degs, bucketor(nodes), bucketor(orig_nid)): for deg, node_bkt, orig_nid_bkt in zip(
unique_degs, bucketor(nodes), bucketor(orig_nid)
):
if deg == 0: if deg == 0:
# skip reduce function for zero-degree nodes # skip reduce function for zero-degree nodes
continue continue
...@@ -127,7 +143,7 @@ def invoke_udf_reduce(graph, func, msgdata, *, orig_nid=None): ...@@ -127,7 +143,7 @@ def invoke_udf_reduce(graph, func, msgdata, *, orig_nid=None):
ndata_bkt = dstdata.subframe(node_bkt) ndata_bkt = dstdata.subframe(node_bkt)
# order the incoming edges per node by edge ID # order the incoming edges per node by edge ID
eid_bkt = F.zerocopy_to_numpy(graph.in_edges(node_bkt, form='eid')) eid_bkt = F.zerocopy_to_numpy(graph.in_edges(node_bkt, form="eid"))
assert len(eid_bkt) == deg * len(node_bkt) assert len(eid_bkt) == deg * len(node_bkt)
eid_bkt = np.sort(eid_bkt.reshape((len(node_bkt), deg)), 1) eid_bkt = np.sort(eid_bkt.reshape((len(node_bkt), deg)), 1)
eid_bkt = F.zerocopy_from_numpy(eid_bkt.flatten()) eid_bkt = F.zerocopy_from_numpy(eid_bkt.flatten())
...@@ -148,7 +164,9 @@ def invoke_udf_reduce(graph, func, msgdata, *, orig_nid=None): ...@@ -148,7 +164,9 @@ def invoke_udf_reduce(graph, func, msgdata, *, orig_nid=None):
retf._default_initializer = dstdata._default_initializer retf._default_initializer = dstdata._default_initializer
# merge bucket results and write to the result frame # merge bucket results and write to the result frame
if len(bkt_rsts) != 0: # if all the nodes have zero degree, no need to merge results. if (
len(bkt_rsts) != 0
): # if all the nodes have zero degree, no need to merge results.
merged_rst = {} merged_rst = {}
for k in bkt_rsts[0].keys(): for k in bkt_rsts[0].keys():
merged_rst[k] = F.cat([rst[k] for rst in bkt_rsts], dim=0) merged_rst[k] = F.cat([rst[k] for rst in bkt_rsts], dim=0)
...@@ -157,6 +175,7 @@ def invoke_udf_reduce(graph, func, msgdata, *, orig_nid=None): ...@@ -157,6 +175,7 @@ def invoke_udf_reduce(graph, func, msgdata, *, orig_nid=None):
return retf return retf
def _bucketing(val): def _bucketing(val):
"""Internal function to create groups on the values. """Internal function to create groups on the values.
...@@ -179,11 +198,14 @@ def _bucketing(val): ...@@ -179,11 +198,14 @@ def _bucketing(val):
for v in unique_val: for v in unique_val:
eqidx = F.nonzero_1d(F.equal(sorted_val, v)) eqidx = F.nonzero_1d(F.equal(sorted_val, v))
bkt_idx.append(F.gather_row(idx, eqidx)) bkt_idx.append(F.gather_row(idx, eqidx))
def bucketor(data): def bucketor(data):
bkts = [F.gather_row(data, idx) for idx in bkt_idx] bkts = [F.gather_row(data, idx) for idx in bkt_idx]
return bkts return bkts
return unique_val, bucketor return unique_val, bucketor
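
_bucketing groups node positions by their degree value so each bucket can be reduced with a fixed-width mailbox. A NumPy stand-in with toy degrees (not the backend ops) shows the grouping it produces:

import numpy as np

degs = np.array([2, 0, 2, 1, 0])        # in-degree per destination node
unique_val = np.unique(degs)             # [0, 1, 2]
bkt_idx = [np.nonzero(degs == v)[0] for v in unique_val]

def bucketor(data):
    return [data[idx] for idx in bkt_idx]

nodes = np.arange(len(degs))
buckets = bucketor(nodes)
# degree 0 -> nodes [1, 4]; degree 1 -> [3]; degree 2 -> [0, 2]
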
def data_dict_to_list(graph, data_dict, func, target): def data_dict_to_list(graph, data_dict, func, target):
"""Get node or edge feature data of the given name for all the types. """Get node or edge feature data of the given name for all the types.
...@@ -206,23 +228,23 @@ def data_dict_to_list(graph, data_dict, func, target): ...@@ -206,23 +228,23 @@ def data_dict_to_list(graph, data_dict, func, target):
data of type ``types[i]``. data of type ``types[i]``.
""" """
if isinstance(func, fn.BinaryMessageFunction): if isinstance(func, fn.BinaryMessageFunction):
if target in ['u', 'v']: if target in ["u", "v"]:
output_list = [None] * graph._graph.number_of_ntypes() output_list = [None] * graph._graph.number_of_ntypes()
for srctype, _, dsttype in graph.canonical_etypes: for srctype, _, dsttype in graph.canonical_etypes:
if target == 'u': if target == "u":
src_id = graph.get_ntype_id(srctype) src_id = graph.get_ntype_id(srctype)
output_list[src_id] = data_dict[srctype] output_list[src_id] = data_dict[srctype]
else: else:
dst_id = graph.get_ntype_id(dsttype) dst_id = graph.get_ntype_id(dsttype)
output_list[dst_id] = data_dict[dsttype] output_list[dst_id] = data_dict[dsttype]
else: # target == 'e' else: # target == 'e'
output_list = [None] * graph._graph.number_of_etypes() output_list = [None] * graph._graph.number_of_etypes()
for rel in graph.canonical_etypes: for rel in graph.canonical_etypes:
etid = graph.get_etype_id(rel) etid = graph.get_etype_id(rel)
output_list[etid] = data_dict[rel] output_list[etid] = data_dict[rel]
return output_list return output_list
else: else:
if target == 'u': if target == "u":
lhs_list = [None] * graph._graph.number_of_ntypes() lhs_list = [None] * graph._graph.number_of_ntypes()
if not isinstance(data_dict, dict): if not isinstance(data_dict, dict):
src_id, _ = graph._graph.metagraph.find_edge(0) src_id, _ = graph._graph.metagraph.find_edge(0)
...@@ -232,13 +254,14 @@ def data_dict_to_list(graph, data_dict, func, target): ...@@ -232,13 +254,14 @@ def data_dict_to_list(graph, data_dict, func, target):
src_id = graph.get_ntype_id(srctype) src_id = graph.get_ntype_id(srctype)
lhs_list[src_id] = data_dict[srctype] lhs_list[src_id] = data_dict[srctype]
return lhs_list return lhs_list
else: # target == 'e': else: # target == 'e':
rhs_list = [None] * graph._graph.number_of_etypes() rhs_list = [None] * graph._graph.number_of_etypes()
for rel in graph.canonical_etypes: for rel in graph.canonical_etypes:
etid = graph.get_etype_id(rel) etid = graph.get_etype_id(rel)
rhs_list[etid] = data_dict[rel] rhs_list[etid] = data_dict[rel]
return rhs_list return rhs_list
def invoke_gsddmm(graph, func): def invoke_gsddmm(graph, func):
"""Invoke g-SDDMM computation on the graph. """Invoke g-SDDMM computation on the graph.
...@@ -270,13 +293,16 @@ def invoke_gsddmm(graph, func): ...@@ -270,13 +293,16 @@ def invoke_gsddmm(graph, func):
if graph._graph.number_of_etypes() > 1: if graph._graph.number_of_etypes() > 1:
# Convert to list as dict is unordered. # Convert to list as dict is unordered.
if func.name == "copy_u": if func.name == "copy_u":
x = data_dict_to_list(graph, x, func, 'u') x = data_dict_to_list(graph, x, func, "u")
else: # "copy_e" else: # "copy_e"
x = data_dict_to_list(graph, x, func, 'e') x = data_dict_to_list(graph, x, func, "e")
z = op(graph, x) z = op(graph, x)
return {func.out_field : z} return {func.out_field: z}
def invoke_gspmm(graph, mfunc, rfunc, *, srcdata=None, dstdata=None, edata=None): def invoke_gspmm(
graph, mfunc, rfunc, *, srcdata=None, dstdata=None, edata=None
):
"""Invoke g-SPMM computation on the graph. """Invoke g-SPMM computation on the graph.
Parameters Parameters
...@@ -301,9 +327,11 @@ def invoke_gspmm(graph, mfunc, rfunc, *, srcdata=None, dstdata=None, edata=None) ...@@ -301,9 +327,11 @@ def invoke_gspmm(graph, mfunc, rfunc, *, srcdata=None, dstdata=None, edata=None)
""" """
# sanity check # sanity check
if mfunc.out_field != rfunc.msg_field: if mfunc.out_field != rfunc.msg_field:
raise DGLError('Invalid message ({}) and reduce ({}) function pairs.' raise DGLError(
' The output field of the message function must be equal to the' "Invalid message ({}) and reduce ({}) function pairs."
' message field of the reduce function.'.format(mfunc, rfunc)) " The output field of the message function must be equal to the"
" message field of the reduce function.".format(mfunc, rfunc)
)
if edata is None: if edata is None:
edata = graph.edata edata = graph.edata
if srcdata is None: if srcdata is None:
...@@ -315,7 +343,7 @@ def invoke_gspmm(graph, mfunc, rfunc, *, srcdata=None, dstdata=None, edata=None) ...@@ -315,7 +343,7 @@ def invoke_gspmm(graph, mfunc, rfunc, *, srcdata=None, dstdata=None, edata=None)
if isinstance(mfunc, fn.BinaryMessageFunction): if isinstance(mfunc, fn.BinaryMessageFunction):
x = alldata[mfunc.lhs][mfunc.lhs_field] x = alldata[mfunc.lhs][mfunc.lhs_field]
y = alldata[mfunc.rhs][mfunc.rhs_field] y = alldata[mfunc.rhs][mfunc.rhs_field]
op = getattr(ops, '{}_{}'.format(mfunc.name, rfunc.name)) op = getattr(ops, "{}_{}".format(mfunc.name, rfunc.name))
if graph._graph.number_of_etypes() > 1: if graph._graph.number_of_etypes() > 1:
lhs_target, _, rhs_target = mfunc.name.split("_", 2) lhs_target, _, rhs_target = mfunc.name.split("_", 2)
x = data_dict_to_list(graph, x, mfunc, lhs_target) x = data_dict_to_list(graph, x, mfunc, lhs_target)
...@@ -323,14 +351,15 @@ def invoke_gspmm(graph, mfunc, rfunc, *, srcdata=None, dstdata=None, edata=None) ...@@ -323,14 +351,15 @@ def invoke_gspmm(graph, mfunc, rfunc, *, srcdata=None, dstdata=None, edata=None)
z = op(graph, x, y) z = op(graph, x, y)
else: else:
x = alldata[mfunc.target][mfunc.in_field] x = alldata[mfunc.target][mfunc.in_field]
op = getattr(ops, '{}_{}'.format(mfunc.name, rfunc.name)) op = getattr(ops, "{}_{}".format(mfunc.name, rfunc.name))
if graph._graph.number_of_etypes() > 1 and not isinstance(x, tuple): if graph._graph.number_of_etypes() > 1 and not isinstance(x, tuple):
if mfunc.name == "copy_u": if mfunc.name == "copy_u":
x = data_dict_to_list(graph, x, mfunc, 'u') x = data_dict_to_list(graph, x, mfunc, "u")
else: # "copy_e" else: # "copy_e"
x = data_dict_to_list(graph, x, mfunc, 'e') x = data_dict_to_list(graph, x, mfunc, "e")
z = op(graph, x) z = op(graph, x)
return {rfunc.out_field : z} return {rfunc.out_field: z}
def message_passing(g, mfunc, rfunc, afunc): def message_passing(g, mfunc, rfunc, afunc):
"""Invoke message passing computation on the whole graph. """Invoke message passing computation on the whole graph.
...@@ -351,8 +380,12 @@ def message_passing(g, mfunc, rfunc, afunc): ...@@ -351,8 +380,12 @@ def message_passing(g, mfunc, rfunc, afunc):
dict[str, Tensor] dict[str, Tensor]
Results from the message passing computation. Results from the message passing computation.
""" """
if (is_builtin(mfunc) and is_builtin(rfunc) and if (
getattr(ops, '{}_{}'.format(mfunc.name, rfunc.name), None) is not None): is_builtin(mfunc)
and is_builtin(rfunc)
and getattr(ops, "{}_{}".format(mfunc.name, rfunc.name), None)
is not None
):
# invoke fused message passing # invoke fused message passing
ndata = invoke_gspmm(g, mfunc, rfunc) ndata = invoke_gspmm(g, mfunc, rfunc)
else: else:
...@@ -362,7 +395,9 @@ def message_passing(g, mfunc, rfunc, afunc): ...@@ -362,7 +395,9 @@ def message_passing(g, mfunc, rfunc, afunc):
msgdata = invoke_gsddmm(g, mfunc) msgdata = invoke_gsddmm(g, mfunc)
else: else:
orig_eid = g.edata.get(EID, None) orig_eid = g.edata.get(EID, None)
msgdata = invoke_edge_udf(g, ALL, g.canonical_etypes[0], mfunc, orig_eid=orig_eid) msgdata = invoke_edge_udf(
g, ALL, g.canonical_etypes[0], mfunc, orig_eid=orig_eid
)
# reduce phase # reduce phase
if is_builtin(rfunc): if is_builtin(rfunc):
msg = rfunc.msg_field msg = rfunc.msg_field
...@@ -372,9 +407,11 @@ def message_passing(g, mfunc, rfunc, afunc): ...@@ -372,9 +407,11 @@ def message_passing(g, mfunc, rfunc, afunc):
ndata = invoke_udf_reduce(g, rfunc, msgdata, orig_nid=orig_nid) ndata = invoke_udf_reduce(g, rfunc, msgdata, orig_nid=orig_nid)
# apply phase # apply phase
if afunc is not None: if afunc is not None:
for k, v in g.dstdata.items(): # include original node features for k, v in g.dstdata.items(): # include original node features
if k not in ndata: if k not in ndata:
ndata[k] = v ndata[k] = v
orig_nid = g.dstdata.get(NID, None) orig_nid = g.dstdata.get(NID, None)
ndata = invoke_node_udf(g, ALL, g.dsttypes[0], afunc, ndata=ndata, orig_nid=orig_nid) ndata = invoke_node_udf(
g, ALL, g.dsttypes[0], afunc, ndata=ndata, orig_nid=orig_nid
)
return ndata return ndata
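
From the user-facing API, the fused path above runs when both functions are builtins with a registered ops pair, while a Python reduce function triggers the SDDMM-plus-degree-bucketing fallback. A hedged usage sketch, assuming a DGLGraph g with a node feature 'h' (the field names are illustrative):

import dgl.function as fn

# Fused g-SPMM path: builtin message + builtin reduce.
g.update_all(fn.copy_u("h", "m"), fn.sum("m", "h_sum"))

# Fallback path: builtin message, user-defined reduce over the mailbox.
g.update_all(fn.copy_u("h", "m"), lambda nodes: {"h_sum": nodes.mailbox["m"].sum(1)})
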
...@@ -3,27 +3,25 @@ ...@@ -3,27 +3,25 @@
from .. import backend as F from .. import backend as F
from .._ffi.function import _init_api from .._ffi.function import _init_api
_COMM_MODES_MAP = { _COMM_MODES_MAP = {"remainder": 0}
'remainder': 0
}
class UniqueId(object): class UniqueId(object):
""" Class for allowing python code to create and communicate NCCL Unique """Class for allowing python code to create and communicate NCCL Unique
IDs, needed for creating communicators. IDs, needed for creating communicators.
""" """
def __init__(self, id_str=None): def __init__(self, id_str=None):
""" Create an object reference the current NCCL unique id. """Create an object reference the current NCCL unique id."""
"""
if id_str: if id_str:
if isinstance(id_str, bytes): if isinstance(id_str, bytes):
id_str = id_str.decode('utf-8') id_str = id_str.decode("utf-8")
self._handle = _CAPI_DGLNCCLUniqueIdFromString(id_str) self._handle = _CAPI_DGLNCCLUniqueIdFromString(id_str)
else: else:
self._handle = _CAPI_DGLNCCLGetUniqueId() self._handle = _CAPI_DGLNCCLGetUniqueId()
def get(self): def get(self):
""" Get the C-handle for this object. """Get the C-handle for this object."""
"""
return self._handle return self._handle
def __str__(self): def __str__(self):
...@@ -37,187 +35,196 @@ class UniqueId(object): ...@@ -37,187 +35,196 @@ class UniqueId(object):
class Communicator(object): class Communicator(object):
""" High-level wrapper for NCCL communication. """High-level wrapper for NCCL communication."""
"""
def __init__(self, size, rank, unique_id): def __init__(self, size, rank, unique_id):
""" Create a new NCCL communicator. """Create a new NCCL communicator.
Parameters Parameters
---------- ----------
size : int size : int
The number of processes in the communicator. The number of processes in the communicator.
rank : int rank : int
The rank of the current process in the communicator. The rank of the current process in the communicator.
unique_id : NCCLUniqueId unique_id : NCCLUniqueId
The unique id of the root process (rank=0). The unique id of the root process (rank=0).
Examples Examples
-------- --------
>>> from dgl.cuda.nccl import Communicator, UniqueId >>> from dgl.cuda.nccl import Communicator, UniqueId
The root process will generate a unique NCCL id and communicate it The root process will generate a unique NCCL id and communicate it
to the other processes. to the other processes.
>>> uid = UniqueId() >>> uid = UniqueId()
>>> store.set('nccl_root_id', str(uid)) >>> store.set('nccl_root_id', str(uid))
And all other processes create unique ids from the root processes. And all other processes create unique ids from the root processes.
>>> uid = UniqueId(store.get('nccl_root_id')) >>> uid = UniqueId(store.get('nccl_root_id'))
Then, all processes should create the communicator. Then, all processes should create the communicator.
>>> comm = Communicator(world_size, rank, uid) >>> comm = Communicator(world_size, rank, uid)
""" """
assert rank < size, "The rank of a process must be less than the " \ assert rank < size, (
"The rank of a process must be less than the "
"size of the communicator." "size of the communicator."
)
self._handle = _CAPI_DGLNCCLCreateComm(size, rank, unique_id.get()) self._handle = _CAPI_DGLNCCLCreateComm(size, rank, unique_id.get())
self._rank = rank self._rank = rank
self._size = size self._size = size
def sparse_all_to_all_push(self, idx, value, partition): def sparse_all_to_all_push(self, idx, value, partition):
""" Perform an all-to-all-v operation, where by all processors send out """Perform an all-to-all-v operation, where by all processors send out
a set of indices and corresponding values. Indices and values, a set of indices and corresponding values. Indices and values,
corresponding to the current process, will be copied into the output corresponding to the current process, will be copied into the output
arrays. arrays.
Parameters Parameters
---------- ----------
idx : tensor idx : tensor
The 1D set of indices to send to other processors. The 1D set of indices to send to other processors.
value : tensor value : tensor
The multi-dimension set of values to send to other processors. The multi-dimension set of values to send to other processors.
The first dimension must match that of `idx`. The first dimension must match that of `idx`.
partition : NDArrayPartition partition : NDArrayPartition
The object containing information for assigning indices to The object containing information for assigning indices to
processors. processors.
Returns Returns
------- -------
tensor tensor
The 1D tensor of the received indices. The 1D tensor of the received indices.
tensor tensor
The set of received values. The set of received values.
Examples Examples
-------- --------
To perform a sparse_all_to_all_push(), a partition object must be To perform a sparse_all_to_all_push(), a partition object must be
provided. A partition of a homogeneous graph, where the vertices are provided. A partition of a homogeneous graph, where the vertices are
striped across processes, can be generated via: striped across processes, can be generated via:
>>> from dgl.partition import NDArrayPartition >>> from dgl.partition import NDArrayPartition
>>> part = NDArrayPartition(g.num_nodes(), comm.size(), mode='remainder' ) >>> part = NDArrayPartition(g.num_nodes(), comm.size(), mode='remainder' )
With this partition, each processor can send values to be associated With this partition, each processor can send values to be associated
with vertices in the graph. So if we have an array `global_idxs` of all of with vertices in the graph. So if we have an array `global_idxs` of all of
the neighbors updated during mini-batch processing, and an array the neighbors updated during mini-batch processing, and an array
`global_values` containing the new values associated with the neighbors, `global_values` containing the new values associated with the neighbors,
we communicate them to the owning processes via: we communicate them to the owning processes via:
>>> my_idxs, my_values = comm.sparse_all_to_all_push(global_idxs, global_values, part) >>> my_idxs, my_values = comm.sparse_all_to_all_push(global_idxs, global_values, part)
This communication pattern is common when communicating gradient This communication pattern is common when communicating gradient
updates for node embeddings. updates for node embeddings.
Indices the current process owns do not need to be treated specially, Indices the current process owns do not need to be treated specially,
as internally they will be copied to the output array. If we have a as internally they will be copied to the output array. If we have a
set of indices in process 0 '[0, 3, 8, 9, 10]' and for process 1 set of indices in process 0 '[0, 3, 8, 9, 10]' and for process 1
'[0, 2, 4, 5, 8, 8, 9]', using a remainder partition will result in '[0, 2, 4, 5, 8, 8, 9]', using a remainder partition will result in
indices for process 0 of '[0, 8, 10, 0, 2, 4, 8, 8]', and for indices for process 0 of '[0, 8, 10, 0, 2, 4, 8, 8]', and for
process 1 of '[3, 9, 5, 9]'. process 1 of '[3, 9, 5, 9]'.
""" """
out_idx, out_value = _CAPI_DGLNCCLSparseAllToAllPush( out_idx, out_value = _CAPI_DGLNCCLSparseAllToAllPush(
self.get(), F.zerocopy_to_dgl_ndarray(idx), self.get(),
F.zerocopy_to_dgl_ndarray(idx),
F.zerocopy_to_dgl_ndarray(value), F.zerocopy_to_dgl_ndarray(value),
partition.get()) partition.get(),
return (F.zerocopy_from_dgl_ndarray(out_idx), )
F.zerocopy_from_dgl_ndarray(out_value)) return (
F.zerocopy_from_dgl_ndarray(out_idx),
F.zerocopy_from_dgl_ndarray(out_value),
)
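
The 'remainder' partition referenced in the docstring assigns global index i to rank i % world_size; a plain-NumPy sketch (no NCCL involved) reproduces the worked example above for two ranks:

import numpy as np

world_size = 2
idx_rank0 = np.array([0, 3, 8, 9, 10])
idx_rank1 = np.array([0, 2, 4, 5, 8, 8, 9])

def owner(idx):
    # 'remainder' partition: global index i belongs to rank i % world_size.
    return idx % world_size

recv0 = np.concatenate([idx_rank0[owner(idx_rank0) == 0],
                        idx_rank1[owner(idx_rank1) == 0]])
recv1 = np.concatenate([idx_rank0[owner(idx_rank0) == 1],
                        idx_rank1[owner(idx_rank1) == 1]])
# recv0 -> [0, 8, 10, 0, 2, 4, 8, 8]; recv1 -> [3, 9, 5, 9]
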
def sparse_all_to_all_pull(self, req_idx, value, partition): def sparse_all_to_all_pull(self, req_idx, value, partition):
""" Perform an all-to-all-v operation, where by all processors request """Perform an all-to-all-v operation, where by all processors request
the values corresponding to their set of indices. the values corresponding to their set of indices.
Parameters Parameters
---------- ----------
req_idx : IdArray req_idx : IdArray
The set of indices this processor is requesting. The set of indices this processor is requesting.
value : NDArray value : NDArray
The multi-dimension set of values that can be requested from The multi-dimension set of values that can be requested from
this processor. this processor.
partition : NDArrayPartition partition : NDArrayPartition
The object containing information for assigning indices to The object containing information for assigning indices to
processors. processors.
Returns Returns
------- -------
tensor tensor
The set of received values, corresponding to `req_idx`. The set of received values, corresponding to `req_idx`.
Examples Examples
-------- --------
To perform a sparse_all_to_all_pull(), a partition object must be To perform a sparse_all_to_all_pull(), a partition object must be
provided. A partition of a homogeneous graph, where the vertices are provided. A partition of a homogeneous graph, where the vertices are
striped across processes, can be generated via: striped across processes, can be generated via:
>>> from dgl.partition import NDArrayPartition >>> from dgl.partition import NDArrayPartition
>>> part = NDArrayPartition(g.num_nodes(), comm.size(), mode='remainder' ) >>> part = NDArrayPartition(g.num_nodes(), comm.size(), mode='remainder' )
With this partition, each processor can request values/features With this partition, each processor can request values/features
associated with vertices in the graph. So in the case where we have associated with vertices in the graph. So in the case where we have
a set of neighbors 'nbr_idxs' we need features for, and each process a set of neighbors 'nbr_idxs' we need features for, and each process
has a tensor 'node_feat' storing the features of nodes it owns in has a tensor 'node_feat' storing the features of nodes it owns in
the partition, the features can be requested via: the partition, the features can be requested via:
>>> nbr_values = comm.sparse_all_to_all_pull(nbr_idxs, node_feat, part) >>> nbr_values = comm.sparse_all_to_all_pull(nbr_idxs, node_feat, part)
Then the two arrays 'nbr_idxs' and 'nbr_values' form the sparse Then the two arrays 'nbr_idxs' and 'nbr_values' form the sparse
set of features, where 'nbr_idxs[i]' is the global node id, and set of features, where 'nbr_idxs[i]' is the global node id, and
'nbr_values[i]' is the feature vector for that node. This 'nbr_values[i]' is the feature vector for that node. This
communication pattern is useful for node features or node communication pattern is useful for node features or node
embeddings. embeddings.
""" """
out_value = _CAPI_DGLNCCLSparseAllToAllPull( out_value = _CAPI_DGLNCCLSparseAllToAllPull(
self.get(), F.zerocopy_to_dgl_ndarray(req_idx), self.get(),
F.zerocopy_to_dgl_ndarray(req_idx),
F.zerocopy_to_dgl_ndarray(value), F.zerocopy_to_dgl_ndarray(value),
partition.get()) partition.get(),
)
return F.zerocopy_from_dgl_ndarray(out_value) return F.zerocopy_from_dgl_ndarray(out_value)
def get(self): def get(self):
""" Get the C-Handle for this object. """Get the C-Handle for this object."""
"""
return self._handle return self._handle
def rank(self): def rank(self):
""" Get the rank of this process in this communicator. """Get the rank of this process in this communicator.
Returns Returns
------- -------
int int
The rank of this process. The rank of this process.
""" """
return self._rank return self._rank
def size(self): def size(self):
""" Get the size of this communicator. """Get the size of this communicator.
Returns Returns
------- -------
int int
The number of processes in this communicator. The number of processes in this communicator.
""" """
return self._size return self._size
def is_supported(): def is_supported():
""" Check if DGL was built with NCCL support. """Check if DGL was built with NCCL support.
Returns Returns
------- -------
bool bool
True if NCCL support was built in. True if NCCL support was built in.
""" """
return _CAPI_DGLNCCLHasSupport() return _CAPI_DGLNCCLHasSupport()
_init_api("dgl.cuda.nccl") _init_api("dgl.cuda.nccl")
...@@ -5,55 +5,74 @@ for downloading, processing, saving and loading data from external resources. ...@@ -5,55 +5,74 @@ for downloading, processing, saving and loading data from external resources.
from __future__ import absolute_import from __future__ import absolute_import
from . import citation_graph as citegrh from . import citation_graph as citegrh
from .citation_graph import CoraBinary, CitationGraphDataset from .adapter import *
from .minigc import *
from .tree import SST, SSTDataset
from .utils import *
from .sbm import SBMMixture, SBMMixtureDataset
from .reddit import RedditDataset
from .ppi import PPIDataset, LegacyPPIDataset
from .tu import TUDataset, LegacyTUDataset
from .gnn_benchmark import AmazonCoBuy, CoraFull, Coauthor, AmazonCoBuyComputerDataset, \
AmazonCoBuyPhotoDataset, CoauthorPhysicsDataset, CoauthorCSDataset, CoraFullDataset
from .karate import KarateClub, KarateClubDataset
from .gindt import GINDataset
from .bitcoinotc import BitcoinOTC, BitcoinOTCDataset from .bitcoinotc import BitcoinOTC, BitcoinOTCDataset
from .citation_graph import (
CitationGraphDataset,
CiteseerGraphDataset,
CoraBinary,
CoraGraphDataset,
PubmedGraphDataset,
)
from .csv_dataset import CSVDataset
from .dgl_dataset import DGLBuiltinDataset, DGLDataset
from .fakenews import FakeNewsDataset
from .flickr import FlickrDataset
from .fraud import FraudAmazonDataset, FraudDataset, FraudYelpDataset
from .gdelt import GDELT, GDELTDataset from .gdelt import GDELT, GDELTDataset
from .gindt import GINDataset
from .gnn_benchmark import (
AmazonCoBuy,
AmazonCoBuyComputerDataset,
AmazonCoBuyPhotoDataset,
Coauthor,
CoauthorCSDataset,
CoauthorPhysicsDataset,
CoraFull,
CoraFullDataset,
)
from .icews18 import ICEWS18, ICEWS18Dataset from .icews18 import ICEWS18, ICEWS18Dataset
from .karate import KarateClub, KarateClubDataset
from .knowledge_graph import FB15k237Dataset, FB15kDataset, WN18Dataset
from .minigc import *
from .ppi import LegacyPPIDataset, PPIDataset
from .qm7b import QM7b, QM7bDataset from .qm7b import QM7b, QM7bDataset
from .qm9 import QM9, QM9Dataset from .qm9 import QM9, QM9Dataset
from .qm9_edge import QM9Edge, QM9EdgeDataset from .qm9_edge import QM9Edge, QM9EdgeDataset
from .dgl_dataset import DGLDataset, DGLBuiltinDataset from .rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset
from .citation_graph import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset from .reddit import RedditDataset
from .knowledge_graph import FB15k237Dataset, FB15kDataset, WN18Dataset from .sbm import SBMMixture, SBMMixtureDataset
from .rdf import AIFBDataset, MUTAGDataset, BGSDataset, AMDataset from .synthetic import (
from .fraud import FraudDataset, FraudYelpDataset, FraudAmazonDataset BA2MotifDataset,
from .fakenews import FakeNewsDataset BACommunityDataset,
from .csv_dataset import CSVDataset BAShapeDataset,
from .adapter import * TreeCycleDataset,
from .synthetic import BAShapeDataset, BACommunityDataset, TreeCycleDataset, TreeGridDataset, BA2MotifDataset TreeGridDataset,
)
from .tree import SST, SSTDataset
from .tu import LegacyTUDataset, TUDataset
from .utils import *
from .wikics import WikiCSDataset from .wikics import WikiCSDataset
from .flickr import FlickrDataset
from .yelp import YelpDataset from .yelp import YelpDataset
def register_data_args(parser): def register_data_args(parser):
parser.add_argument( parser.add_argument(
"--dataset", "--dataset",
type=str, type=str,
required=False, required=False,
help= help="The input dataset. Can be cora, citeseer, pubmed, syn(synthetic dataset) or reddit",
"The input dataset. Can be cora, citeseer, pubmed, syn(synthetic dataset) or reddit"
) )
def load_data(args): def load_data(args):
if args.dataset == 'cora': if args.dataset == "cora":
return citegrh.load_cora() return citegrh.load_cora()
elif args.dataset == 'citeseer': elif args.dataset == "citeseer":
return citegrh.load_citeseer() return citegrh.load_citeseer()
elif args.dataset == 'pubmed': elif args.dataset == "pubmed":
return citegrh.load_pubmed() return citegrh.load_pubmed()
elif args.dataset is not None and args.dataset.startswith('reddit'): elif args.dataset is not None and args.dataset.startswith("reddit"):
return RedditDataset(self_loop=('self-loop' in args.dataset)) return RedditDataset(self_loop=("self-loop" in args.dataset))
else: else:
raise ValueError('Unknown dataset: {}'.format(args.dataset)) raise ValueError("Unknown dataset: {}".format(args.dataset))
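
A hedged usage sketch of the two helpers above, assuming they are imported from dgl.data (the 'cora' value is only an example and the dataset is downloaded on first use):

import argparse

from dgl.data import load_data, register_data_args

parser = argparse.ArgumentParser()
register_data_args(parser)
args = parser.parse_args(["--dataset", "cora"])
data = load_data(args)  # delegates to citation_graph.load_cora()
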
"""Dataset adapters for re-purposing a dataset for a different kind of training task.""" """Dataset adapters for re-purposing a dataset for a different kind of training task."""
import os
import json import json
import os
import numpy as np import numpy as np
from .. import backend as F from .. import backend as F
from ..base import DGLError
from ..convert import graph as create_dgl_graph from ..convert import graph as create_dgl_graph
from ..sampling.negative import _calc_redundancy from ..sampling.negative import _calc_redundancy
from .dgl_dataset import DGLDataset
from . import utils from . import utils
from ..base import DGLError from .dgl_dataset import DGLDataset
from .. import backend as F
__all__ = ['AsNodePredDataset', 'AsLinkPredDataset', 'AsGraphPredDataset'] __all__ = ["AsNodePredDataset", "AsLinkPredDataset", "AsGraphPredDataset"]
class AsNodePredDataset(DGLDataset): class AsNodePredDataset(DGLDataset):
...@@ -77,83 +77,118 @@ class AsNodePredDataset(DGLDataset): ...@@ -77,83 +77,118 @@ class AsNodePredDataset(DGLDataset):
True True
""" """
def __init__(self, def __init__(self, dataset, split_ratio=None, target_ntype=None, **kwargs):
dataset,
split_ratio=None,
target_ntype=None,
**kwargs):
self.dataset = dataset self.dataset = dataset
self.split_ratio = split_ratio self.split_ratio = split_ratio
self.target_ntype = target_ntype self.target_ntype = target_ntype
super().__init__(self.dataset.name + '-as-nodepred', super().__init__(
hash_key=(split_ratio, target_ntype, dataset.name, 'nodepred'), **kwargs) self.dataset.name + "-as-nodepred",
hash_key=(split_ratio, target_ntype, dataset.name, "nodepred"),
**kwargs
)
def process(self): def process(self):
is_ogb = hasattr(self.dataset, 'get_idx_split') is_ogb = hasattr(self.dataset, "get_idx_split")
if is_ogb: if is_ogb:
g, label = self.dataset[0] g, label = self.dataset[0]
self.g = g.clone() self.g = g.clone()
self.g.ndata['label'] = F.reshape(label, (g.num_nodes(),)) self.g.ndata["label"] = F.reshape(label, (g.num_nodes(),))
else: else:
self.g = self.dataset[0].clone() self.g = self.dataset[0].clone()
if 'label' not in self.g.nodes[self.target_ntype].data: if "label" not in self.g.nodes[self.target_ntype].data:
raise ValueError("Missing node labels. Make sure labels are stored " raise ValueError(
"under name 'label'.") "Missing node labels. Make sure labels are stored "
"under name 'label'."
)
if self.split_ratio is None: if self.split_ratio is None:
if is_ogb: if is_ogb:
split = self.dataset.get_idx_split() split = self.dataset.get_idx_split()
train_idx, val_idx, test_idx = split['train'], split['valid'], split['test'] train_idx, val_idx, test_idx = (
split["train"],
split["valid"],
split["test"],
)
n = self.g.num_nodes() n = self.g.num_nodes()
train_mask = utils.generate_mask_tensor(utils.idx2mask(train_idx, n)) train_mask = utils.generate_mask_tensor(
val_mask = utils.generate_mask_tensor(utils.idx2mask(val_idx, n)) utils.idx2mask(train_idx, n)
test_mask = utils.generate_mask_tensor(utils.idx2mask(test_idx, n)) )
self.g.ndata['train_mask'] = train_mask val_mask = utils.generate_mask_tensor(
self.g.ndata['val_mask'] = val_mask utils.idx2mask(val_idx, n)
self.g.ndata['test_mask'] = test_mask )
test_mask = utils.generate_mask_tensor(
utils.idx2mask(test_idx, n)
)
self.g.ndata["train_mask"] = train_mask
self.g.ndata["val_mask"] = val_mask
self.g.ndata["test_mask"] = test_mask
else: else:
assert "train_mask" in self.g.nodes[self.target_ntype].data, \ assert (
"train_mask is not provided, please specify split_ratio to generate the masks" "train_mask" in self.g.nodes[self.target_ntype].data
assert "val_mask" in self.g.nodes[self.target_ntype].data, \ ), "train_mask is not provided, please specify split_ratio to generate the masks"
"val_mask is not provided, please specify split_ratio to generate the masks" assert (
assert "test_mask" in self.g.nodes[self.target_ntype].data, \ "val_mask" in self.g.nodes[self.target_ntype].data
"test_mask is not provided, please specify split_ratio to generate the masks" ), "val_mask is not provided, please specify split_ratio to generate the masks"
assert (
"test_mask" in self.g.nodes[self.target_ntype].data
), "test_mask is not provided, please specify split_ratio to generate the masks"
else: else:
if self.verbose: if self.verbose:
print('Generating train/val/test masks...') print("Generating train/val/test masks...")
utils.add_nodepred_split(self, self.split_ratio, self.target_ntype) utils.add_nodepred_split(self, self.split_ratio, self.target_ntype)
self._set_split_index() self._set_split_index()
self.num_classes = getattr(self.dataset, 'num_classes', None) self.num_classes = getattr(self.dataset, "num_classes", None)
if self.num_classes is None: if self.num_classes is None:
self.num_classes = len(F.unique(self.g.nodes[self.target_ntype].data['label'])) self.num_classes = len(
F.unique(self.g.nodes[self.target_ntype].data["label"])
)
def has_cache(self): def has_cache(self):
return os.path.isfile(os.path.join(self.save_path, 'graph_{}.bin'.format(self.hash))) return os.path.isfile(
os.path.join(self.save_path, "graph_{}.bin".format(self.hash))
)
def load(self): def load(self):
with open(os.path.join(self.save_path, 'info_{}.json'.format(self.hash)), 'r') as f: with open(
os.path.join(self.save_path, "info_{}.json".format(self.hash)), "r"
) as f:
info = json.load(f) info = json.load(f)
if (info['split_ratio'] != self.split_ratio if (
or info['target_ntype'] != self.target_ntype): info["split_ratio"] != self.split_ratio
raise ValueError('Provided split ratio is different from the cached file. ' or info["target_ntype"] != self.target_ntype
'Re-process the dataset.') ):
self.split_ratio = info['split_ratio'] raise ValueError(
self.target_ntype = info['target_ntype'] "Provided split ratio is different from the cached file. "
self.num_classes = info['num_classes'] "Re-process the dataset."
gs, _ = utils.load_graphs(os.path.join(self.save_path, 'graph_{}.bin'.format(self.hash))) )
self.split_ratio = info["split_ratio"]
self.target_ntype = info["target_ntype"]
self.num_classes = info["num_classes"]
gs, _ = utils.load_graphs(
os.path.join(self.save_path, "graph_{}.bin".format(self.hash))
)
self.g = gs[0] self.g = gs[0]
self._set_split_index() self._set_split_index()
def save(self): def save(self):
utils.save_graphs(os.path.join(self.save_path, 'graph_{}.bin'.format(self.hash)), [self.g]) utils.save_graphs(
with open(os.path.join(self.save_path, 'info_{}.json'.format(self.hash)), 'w') as f: os.path.join(self.save_path, "graph_{}.bin".format(self.hash)),
json.dump({ [self.g],
'split_ratio': self.split_ratio, )
'target_ntype': self.target_ntype, with open(
'num_classes': self.num_classes}, f) os.path.join(self.save_path, "info_{}.json".format(self.hash)), "w"
) as f:
json.dump(
{
"split_ratio": self.split_ratio,
"target_ntype": self.target_ntype,
"num_classes": self.num_classes,
},
f,
)
def __getitem__(self, idx): def __getitem__(self, idx):
return self.g return self.g
...@@ -164,19 +199,18 @@ class AsNodePredDataset(DGLDataset): ...@@ -164,19 +199,18 @@ class AsNodePredDataset(DGLDataset):
def _set_split_index(self): def _set_split_index(self):
"""Add train_idx/val_idx/test_idx as dataset attributes according to corresponding mask.""" """Add train_idx/val_idx/test_idx as dataset attributes according to corresponding mask."""
ndata = self.g.nodes[self.target_ntype].data ndata = self.g.nodes[self.target_ntype].data
self.train_idx = F.nonzero_1d(ndata['train_mask']) self.train_idx = F.nonzero_1d(ndata["train_mask"])
self.val_idx = F.nonzero_1d(ndata['val_mask']) self.val_idx = F.nonzero_1d(ndata["val_mask"])
self.test_idx = F.nonzero_1d(ndata['test_mask']) self.test_idx = F.nonzero_1d(ndata["test_mask"])
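Editor's note: the adapter above is meant to wrap an existing node-classification dataset and regenerate the split; a minimal usage sketch (Cora is only an example source dataset, and the PyTorch backend is assumed for the tensor calls):

    from dgl.data import CoraGraphDataset, AsNodePredDataset

    # Wrap Cora and generate an 80/10/10 split instead of the built-in one.
    dataset = AsNodePredDataset(CoraGraphDataset(), split_ratio=[0.8, 0.1, 0.1])
    g = dataset[0]                       # the single graph, with masks attached
    print(dataset.num_classes)           # inferred from labels if the wrapped dataset lacks it
    print(g.ndata["train_mask"].sum())   # mask generated from split_ratio
    print(dataset.train_idx[:5])         # index view set by _set_split_index()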
def negative_sample(g, num_samples): def negative_sample(g, num_samples):
"""Random sample negative edges from graph, excluding self-loops, """Random sample negative edges from graph, excluding self-loops,
the result samples might be less than num_samples the result samples might be less than num_samples
""" """
num_nodes = g.num_nodes() num_nodes = g.num_nodes()
redundancy = _calc_redundancy( redundancy = _calc_redundancy(num_samples, g.num_edges(), num_nodes**2)
num_samples, g.num_edges(), num_nodes ** 2) sample_size = int(num_samples * (1 + redundancy))
sample_size = int(num_samples*(1+redundancy))
edges = np.random.randint(0, num_nodes, size=(2, sample_size)) edges = np.random.randint(0, num_nodes, size=(2, sample_size))
edges = np.unique(edges, axis=1) edges = np.unique(edges, axis=1)
# remove self loop # remove self loop
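# Editor's note: a standalone sketch of the oversample-then-deduplicate idea
# used above. Illustrative only: the redundancy factor is hard-coded instead
# of coming from _calc_redundancy, and edges already present in the graph are
# not excluded here (the full helper still has to handle that).
def _sketch_candidate_negatives(num_nodes, num_samples, redundancy=0.5):
    size = int(num_samples * (1 + redundancy))   # oversample
    cand = np.random.randint(0, num_nodes, size=(2, size))
    cand = np.unique(cand, axis=1)               # drop duplicate pairs
    cand = cand[:, cand[0] != cand[1]]           # drop self-loops
    return cand[:, :num_samples]                 # may return fewer than requested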
...@@ -236,49 +270,71 @@ class AsLinkPredDataset(DGLDataset): ...@@ -236,49 +270,71 @@ class AsLinkPredDataset(DGLDataset):
True True
""" """
def __init__(self, def __init__(self, dataset, split_ratio=None, neg_ratio=3, **kwargs):
dataset,
split_ratio=None,
neg_ratio=3,
**kwargs):
self.g = dataset[0] self.g = dataset[0]
self.num_nodes = self.g.num_nodes() self.num_nodes = self.g.num_nodes()
self.dataset = dataset self.dataset = dataset
self.split_ratio = split_ratio self.split_ratio = split_ratio
self.neg_ratio = neg_ratio self.neg_ratio = neg_ratio
super().__init__(dataset.name + '-as-linkpred', super().__init__(
hash_key=(neg_ratio, split_ratio, dataset.name, 'linkpred'), **kwargs) dataset.name + "-as-linkpred",
hash_key=(neg_ratio, split_ratio, dataset.name, "linkpred"),
**kwargs
)
def process(self): def process(self):
if self.split_ratio is None: if self.split_ratio is None:
# Handle logics for OGB link prediction dataset # Handle logics for OGB link prediction dataset
assert hasattr(self.dataset, "get_edge_split"), \ assert hasattr(
"dataset doesn't have get_edge_split method, please specify split_ratio and neg_ratio to generate the split" self.dataset, "get_edge_split"
), "dataset doesn't have get_edge_split method, please specify split_ratio and neg_ratio to generate the split"
# This is likely to be an ogb dataset # This is likely to be an ogb dataset
self.edge_split = self.dataset.get_edge_split() self.edge_split = self.dataset.get_edge_split()
self._train_graph = self.g self._train_graph = self.g
if 'source_node' in self.edge_split["test"]: if "source_node" in self.edge_split["test"]:
# Probably ogbl-citation2 # Probably ogbl-citation2
pos_e = (self.edge_split["valid"]["source_node"], self.edge_split["valid"]["target_node"]) pos_e = (
neg_e_size = self.edge_split["valid"]['target_node_neg'].shape[-1] self.edge_split["valid"]["source_node"],
neg_e_src = np.repeat(self.edge_split['valid']['source_node'], neg_e_size) self.edge_split["valid"]["target_node"],
neg_e_dst = np.reshape(self.edge_split["valid"]["target_node_neg"], -1) )
neg_e_size = self.edge_split["valid"]["target_node_neg"].shape[
-1
]
neg_e_src = np.repeat(
self.edge_split["valid"]["source_node"], neg_e_size
)
neg_e_dst = np.reshape(
self.edge_split["valid"]["target_node_neg"], -1
)
self._val_edges = pos_e, (neg_e_src, neg_e_dst) self._val_edges = pos_e, (neg_e_src, neg_e_dst)
pos_e = (self.edge_split["test"]["source_node"], self.edge_split["test"]["target_node"]) pos_e = (
neg_e_size = self.edge_split["test"]['target_node_neg'].shape[-1] self.edge_split["test"]["source_node"],
neg_e_src = np.repeat(self.edge_split['test']['source_node'], neg_e_size) self.edge_split["test"]["target_node"],
neg_e_dst = np.reshape(self.edge_split["test"]["target_node_neg"], -1) )
neg_e_size = self.edge_split["test"]["target_node_neg"].shape[
-1
]
neg_e_src = np.repeat(
self.edge_split["test"]["source_node"], neg_e_size
)
neg_e_dst = np.reshape(
self.edge_split["test"]["target_node_neg"], -1
)
self._test_edges = pos_e, (neg_e_src, neg_e_dst) self._test_edges = pos_e, (neg_e_src, neg_e_dst)
elif 'edge' in self.edge_split["test"]: elif "edge" in self.edge_split["test"]:
# Probably ogbl-collab # Probably ogbl-collab
pos_e_tensor, neg_e_tensor = self.edge_split["valid"][ pos_e_tensor, neg_e_tensor = (
"edge"], self.edge_split["valid"]["edge_neg"] self.edge_split["valid"]["edge"],
self.edge_split["valid"]["edge_neg"],
)
pos_e = (pos_e_tensor[:, 0], pos_e_tensor[:, 1]) pos_e = (pos_e_tensor[:, 0], pos_e_tensor[:, 1])
neg_e = (neg_e_tensor[:, 0], neg_e_tensor[:, 1]) neg_e = (neg_e_tensor[:, 0], neg_e_tensor[:, 1])
self._val_edges = pos_e, neg_e self._val_edges = pos_e, neg_e
pos_e_tensor, neg_e_tensor = self.edge_split["test"][ pos_e_tensor, neg_e_tensor = (
"edge"], self.edge_split["test"]["edge_neg"] self.edge_split["test"]["edge"],
self.edge_split["test"]["edge_neg"],
)
pos_e = (pos_e_tensor[:, 0], pos_e_tensor[:, 1]) pos_e = (pos_e_tensor[:, 0], pos_e_tensor[:, 1])
neg_e = (neg_e_tensor[:, 0], neg_e_tensor[:, 1]) neg_e = (neg_e_tensor[:, 0], neg_e_tensor[:, 1])
self._test_edges = pos_e, neg_e self._test_edges = pos_e, neg_e
...@@ -292,40 +348,65 @@ class AsLinkPredDataset(DGLDataset): ...@@ -292,40 +348,65 @@ class AsLinkPredDataset(DGLDataset):
n = graph.num_edges() n = graph.num_edges()
src, dst = graph.edges() src, dst = graph.edges()
src, dst = F.asnumpy(src), F.asnumpy(dst) src, dst = F.asnumpy(src), F.asnumpy(dst)
n_train, n_val, n_test = int( n_train, n_val, n_test = (
n * ratio[0]), int(n * ratio[1]), int(n * ratio[2]) int(n * ratio[0]),
int(n * ratio[1]),
int(n * ratio[2]),
)
idx = np.random.permutation(n) idx = np.random.permutation(n)
train_pos_idx = idx[:n_train] train_pos_idx = idx[:n_train]
val_pos_idx = idx[n_train:n_train+n_val] val_pos_idx = idx[n_train : n_train + n_val]
test_pos_idx = idx[n_train+n_val:] test_pos_idx = idx[n_train + n_val :]
neg_src, neg_dst = negative_sample( neg_src, neg_dst = negative_sample(
graph, self.neg_ratio*(n_val+n_test)) graph, self.neg_ratio * (n_val + n_test)
neg_n_val, neg_n_test = self.neg_ratio * n_val, self.neg_ratio * n_test )
neg_n_val, neg_n_test = (
self.neg_ratio * n_val,
self.neg_ratio * n_test,
)
neg_val_src, neg_val_dst = neg_src[:neg_n_val], neg_dst[:neg_n_val] neg_val_src, neg_val_dst = neg_src[:neg_n_val], neg_dst[:neg_n_val]
neg_test_src, neg_test_dst = neg_src[neg_n_val:], neg_dst[neg_n_val:] neg_test_src, neg_test_dst = (
self._val_edges = (F.tensor(src[val_pos_idx]), F.tensor(dst[val_pos_idx]) neg_src[neg_n_val:],
), (F.tensor(neg_val_src), F.tensor(neg_val_dst)) neg_dst[neg_n_val:],
self._test_edges = (F.tensor(src[test_pos_idx]), )
F.tensor(dst[test_pos_idx])), (F.tensor(neg_test_src), F.tensor(neg_test_dst)) self._val_edges = (
F.tensor(src[val_pos_idx]),
F.tensor(dst[val_pos_idx]),
), (F.tensor(neg_val_src), F.tensor(neg_val_dst))
self._test_edges = (
F.tensor(src[test_pos_idx]),
F.tensor(dst[test_pos_idx]),
), (F.tensor(neg_test_src), F.tensor(neg_test_dst))
self._train_graph = create_dgl_graph( self._train_graph = create_dgl_graph(
(src[train_pos_idx], dst[train_pos_idx]), num_nodes=self.num_nodes) (src[train_pos_idx], dst[train_pos_idx]),
num_nodes=self.num_nodes,
)
self._train_graph.ndata["feat"] = graph.ndata["feat"] self._train_graph.ndata["feat"] = graph.ndata["feat"]
def has_cache(self): def has_cache(self):
return os.path.isfile(os.path.join(self.save_path, 'graph_{}.bin'.format(self.hash))) return os.path.isfile(
os.path.join(self.save_path, "graph_{}.bin".format(self.hash))
)
def load(self): def load(self):
gs, tensor_dict = utils.load_graphs( gs, tensor_dict = utils.load_graphs(
os.path.join(self.save_path, 'graph_{}.bin'.format(self.hash))) os.path.join(self.save_path, "graph_{}.bin".format(self.hash))
)
self.g = gs[0] self.g = gs[0]
self._train_graph = self.g self._train_graph = self.g
self._val_edges = (tensor_dict["val_pos_src"], tensor_dict["val_pos_dst"]), ( self._val_edges = (
tensor_dict["val_neg_src"], tensor_dict["val_neg_dst"]) tensor_dict["val_pos_src"],
self._test_edges = (tensor_dict["test_pos_src"], tensor_dict["test_pos_dst"]), ( tensor_dict["val_pos_dst"],
tensor_dict["test_neg_src"], tensor_dict["test_neg_dst"]) ), (tensor_dict["val_neg_src"], tensor_dict["val_neg_dst"])
self._test_edges = (
with open(os.path.join(self.save_path, 'info_{}.json'.format(self.hash)), 'r') as f: tensor_dict["test_pos_src"],
tensor_dict["test_pos_dst"],
), (tensor_dict["test_neg_src"], tensor_dict["test_neg_dst"])
with open(
os.path.join(self.save_path, "info_{}.json".format(self.hash)), "r"
) as f:
info = json.load(f) info = json.load(f)
self.split_ratio = info["split_ratio"] self.split_ratio = info["split_ratio"]
self.neg_ratio = info["neg_ratio"] self.neg_ratio = info["neg_ratio"]
...@@ -341,12 +422,18 @@ class AsLinkPredDataset(DGLDataset): ...@@ -341,12 +422,18 @@ class AsLinkPredDataset(DGLDataset):
"test_neg_src": self._test_edges[1][0], "test_neg_src": self._test_edges[1][0],
"test_neg_dst": self._test_edges[1][1], "test_neg_dst": self._test_edges[1][1],
} }
utils.save_graphs(os.path.join(self.save_path, 'graph_{}.bin'.format(self.hash)), [ utils.save_graphs(
self._train_graph], tensor_dict) os.path.join(self.save_path, "graph_{}.bin".format(self.hash)),
with open(os.path.join(self.save_path, 'info_{}.json'.format(self.hash)), 'w') as f: [self._train_graph],
json.dump({ tensor_dict,
'split_ratio': self.split_ratio, )
'neg_ratio': self.neg_ratio}, f) with open(
os.path.join(self.save_path, "info_{}.json".format(self.hash)), "w"
) as f:
json.dump(
{"split_ratio": self.split_ratio, "neg_ratio": self.neg_ratio},
f,
)
@property @property
def feat_size(self): def feat_size(self):
...@@ -370,6 +457,7 @@ class AsLinkPredDataset(DGLDataset): ...@@ -370,6 +457,7 @@ class AsLinkPredDataset(DGLDataset):
def __len__(self): def __len__(self):
return 1 return 1
class AsGraphPredDataset(DGLDataset): class AsGraphPredDataset(DGLDataset):
"""Repurpose a dataset for standard graph property prediction task. """Repurpose a dataset for standard graph property prediction task.
...@@ -425,23 +513,24 @@ class AsGraphPredDataset(DGLDataset): ...@@ -425,23 +513,24 @@ class AsGraphPredDataset(DGLDataset):
ndata_schemes={'feat': Scheme(shape=(9,), dtype=torch.int64)} ndata_schemes={'feat': Scheme(shape=(9,), dtype=torch.int64)}
edata_schemes={'feat': Scheme(shape=(3,), dtype=torch.int64)}), tensor([0])) edata_schemes={'feat': Scheme(shape=(3,), dtype=torch.int64)}), tensor([0]))
""" """
def __init__(self,
dataset, def __init__(self, dataset, split_ratio=None, **kwargs):
split_ratio=None,
**kwargs):
self.dataset = dataset self.dataset = dataset
self.split_ratio = split_ratio self.split_ratio = split_ratio
super().__init__(dataset.name + '-as-graphpred', super().__init__(
hash_key=(split_ratio, dataset.name, 'graphpred'), **kwargs) dataset.name + "-as-graphpred",
hash_key=(split_ratio, dataset.name, "graphpred"),
**kwargs
)
def process(self): def process(self):
is_ogb = hasattr(self.dataset, 'get_idx_split') is_ogb = hasattr(self.dataset, "get_idx_split")
if self.split_ratio is None: if self.split_ratio is None:
if is_ogb: if is_ogb:
split = self.dataset.get_idx_split() split = self.dataset.get_idx_split()
self.train_idx = split['train'] self.train_idx = split["train"]
self.val_idx = split['valid'] self.val_idx = split["valid"]
self.test_idx = split['test'] self.test_idx = split["test"]
else: else:
# Handle FakeNewsDataset # Handle FakeNewsDataset
try: try:
...@@ -449,11 +538,13 @@ class AsGraphPredDataset(DGLDataset): ...@@ -449,11 +538,13 @@ class AsGraphPredDataset(DGLDataset):
self.val_idx = F.nonzero_1d(self.dataset.val_mask) self.val_idx = F.nonzero_1d(self.dataset.val_mask)
self.test_idx = F.nonzero_1d(self.dataset.test_mask) self.test_idx = F.nonzero_1d(self.dataset.test_mask)
except: except:
raise DGLError('The input dataset does not have a default train/val/test ' raise DGLError(
'split. Please specify split_ratio to generate the split.') "The input dataset does not have a default train/val/test "
"split. Please specify split_ratio to generate the split."
)
else: else:
if self.verbose: if self.verbose:
print('Generating train/val/test split...') print("Generating train/val/test split...")
train_ratio, val_ratio, _ = self.split_ratio train_ratio, val_ratio, _ = self.split_ratio
num_graphs = len(self.dataset) num_graphs = len(self.dataset)
num_train = int(num_graphs * train_ratio) num_train = int(num_graphs * train_ratio)
...@@ -461,10 +552,10 @@ class AsGraphPredDataset(DGLDataset): ...@@ -461,10 +552,10 @@ class AsGraphPredDataset(DGLDataset):
idx = np.random.permutation(num_graphs) idx = np.random.permutation(num_graphs)
self.train_idx = F.tensor(idx[:num_train]) self.train_idx = F.tensor(idx[:num_train])
self.val_idx = F.tensor(idx[num_train: num_train + num_val]) self.val_idx = F.tensor(idx[num_train : num_train + num_val])
self.test_idx = F.tensor(idx[num_train + num_val:]) self.test_idx = F.tensor(idx[num_train + num_val :])
if hasattr(self.dataset, 'num_classes'): if hasattr(self.dataset, "num_classes"):
# GINDataset, MiniGCDataset, FakeNewsDataset, TUDataset, # GINDataset, MiniGCDataset, FakeNewsDataset, TUDataset,
# LegacyTUDataset, BA2MotifDataset # LegacyTUDataset, BA2MotifDataset
self.num_classes = self.dataset.num_classes self.num_classes = self.dataset.num_classes
...@@ -472,42 +563,58 @@ class AsGraphPredDataset(DGLDataset): ...@@ -472,42 +563,58 @@ class AsGraphPredDataset(DGLDataset):
# None for multi-label classification and regression # None for multi-label classification and regression
self.num_classes = None self.num_classes = None
if hasattr(self.dataset, 'num_tasks'): if hasattr(self.dataset, "num_tasks"):
# OGB datasets # OGB datasets
self.num_tasks = self.dataset.num_tasks self.num_tasks = self.dataset.num_tasks
else: else:
self.num_tasks = 1 self.num_tasks = 1
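Editor's note: as with the other adapters, this wrapper is dropped around an existing graph-classification dataset; a sketch using GINDataset (any dataset exposing num_classes would do, exact defaults may differ):

    from dgl.data import GINDataset, AsGraphPredDataset

    dataset = AsGraphPredDataset(GINDataset("MUTAG", self_loop=True),
                                 split_ratio=[0.8, 0.1, 0.1])
    g, label = dataset[0]      # __getitem__ simply delegates to the wrapped dataset
    print(len(dataset.train_idx), len(dataset.val_idx), len(dataset.test_idx))
    print(dataset.num_classes, dataset.num_tasks)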
def has_cache(self): def has_cache(self):
return os.path.isfile(os.path.join(self.save_path, 'info_{}.json'.format(self.hash))) return os.path.isfile(
os.path.join(self.save_path, "info_{}.json".format(self.hash))
)
def load(self): def load(self):
with open(os.path.join(self.save_path, 'info_{}.json'.format(self.hash)), 'r') as f: with open(
os.path.join(self.save_path, "info_{}.json".format(self.hash)), "r"
) as f:
info = json.load(f) info = json.load(f)
if info['split_ratio'] != self.split_ratio: if info["split_ratio"] != self.split_ratio:
raise ValueError('Provided split ratio is different from the cached file. ' raise ValueError(
'Re-process the dataset.') "Provided split ratio is different from the cached file. "
self.split_ratio = info['split_ratio'] "Re-process the dataset."
self.num_tasks = info['num_tasks'] )
self.num_classes = info['num_classes'] self.split_ratio = info["split_ratio"]
self.num_tasks = info["num_tasks"]
split = np.load(os.path.join(self.save_path, 'split_{}.npz'.format(self.hash))) self.num_classes = info["num_classes"]
self.train_idx = F.zerocopy_from_numpy(split['train_idx'])
self.val_idx = F.zerocopy_from_numpy(split['val_idx']) split = np.load(
self.test_idx = F.zerocopy_from_numpy(split['test_idx']) os.path.join(self.save_path, "split_{}.npz".format(self.hash))
)
self.train_idx = F.zerocopy_from_numpy(split["train_idx"])
self.val_idx = F.zerocopy_from_numpy(split["val_idx"])
self.test_idx = F.zerocopy_from_numpy(split["test_idx"])
def save(self): def save(self):
if not os.path.exists(self.save_path): if not os.path.exists(self.save_path):
os.makedirs(self.save_path) os.makedirs(self.save_path)
with open(os.path.join(self.save_path, 'info_{}.json'.format(self.hash)), 'w') as f: with open(
json.dump({ os.path.join(self.save_path, "info_{}.json".format(self.hash)), "w"
'split_ratio': self.split_ratio, ) as f:
'num_tasks': self.num_tasks, json.dump(
'num_classes': self.num_classes}, f) {
np.savez(os.path.join(self.save_path, 'split_{}.npz'.format(self.hash)), "split_ratio": self.split_ratio,
train_idx=F.zerocopy_to_numpy(self.train_idx), "num_tasks": self.num_tasks,
val_idx=F.zerocopy_to_numpy(self.val_idx), "num_classes": self.num_classes,
test_idx=F.zerocopy_to_numpy(self.test_idx)) },
f,
)
np.savez(
os.path.join(self.save_path, "split_{}.npz".format(self.hash)),
train_idx=F.zerocopy_to_numpy(self.train_idx),
val_idx=F.zerocopy_to_numpy(self.val_idx),
test_idx=F.zerocopy_to_numpy(self.test_idx),
)
def __getitem__(self, idx): def __getitem__(self, idx):
return self.dataset[idx] return self.dataset[idx]
...@@ -518,9 +625,9 @@ class AsGraphPredDataset(DGLDataset): ...@@ -518,9 +625,9 @@ class AsGraphPredDataset(DGLDataset):
@property @property
def node_feat_size(self): def node_feat_size(self):
g = self[0][0] g = self[0][0]
return g.ndata['feat'].shape[-1] if 'feat' in g.ndata else None return g.ndata["feat"].shape[-1] if "feat" in g.ndata else None
@property @property
def edge_feat_size(self): def edge_feat_size(self):
g = self[0][0] g = self[0][0]
return g.edata['feat'].shape[-1] if 'feat' in g.edata else None return g.edata["feat"].shape[-1] if "feat" in g.edata else None
""" BitcoinOTC dataset for fraud detection """ """ BitcoinOTC dataset for fraud detection """
import numpy as np
import os
import datetime import datetime
import gzip import gzip
import os
import shutil import shutil
from .dgl_dataset import DGLBuiltinDataset import numpy as np
from .utils import download, makedirs, save_graphs, load_graphs, check_sha1
from ..convert import graph as dgl_graph
from .. import backend as F from .. import backend as F
from ..convert import graph as dgl_graph
from .dgl_dataset import DGLBuiltinDataset
from .utils import check_sha1, download, load_graphs, makedirs, save_graphs
class BitcoinOTCDataset(DGLBuiltinDataset): class BitcoinOTCDataset(DGLBuiltinDataset):
...@@ -68,35 +69,44 @@ class BitcoinOTCDataset(DGLBuiltinDataset): ...@@ -68,35 +69,44 @@ class BitcoinOTCDataset(DGLBuiltinDataset):
>>> >>>
""" """
_url = 'https://snap.stanford.edu/data/soc-sign-bitcoinotc.csv.gz' _url = "https://snap.stanford.edu/data/soc-sign-bitcoinotc.csv.gz"
_sha1_str = 'c14281f9e252de0bd0b5f1c6e2bae03123938641' _sha1_str = "c14281f9e252de0bd0b5f1c6e2bae03123938641"
def __init__(self, raw_dir=None, force_reload=False, verbose=False, transform=None): def __init__(
super(BitcoinOTCDataset, self).__init__(name='bitcoinotc', self, raw_dir=None, force_reload=False, verbose=False, transform=None
url=self._url, ):
raw_dir=raw_dir, super(BitcoinOTCDataset, self).__init__(
force_reload=force_reload, name="bitcoinotc",
verbose=verbose, url=self._url,
transform=transform) raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose,
transform=transform,
)
def download(self): def download(self):
gz_file_path = os.path.join(self.raw_dir, self.name + '.csv.gz') gz_file_path = os.path.join(self.raw_dir, self.name + ".csv.gz")
download(self.url, path=gz_file_path) download(self.url, path=gz_file_path)
if not check_sha1(gz_file_path, self._sha1_str): if not check_sha1(gz_file_path, self._sha1_str):
raise UserWarning('File {} is downloaded but the content hash does not match. ' raise UserWarning(
'The repo may be outdated or the download may be incomplete. ' "File {} is downloaded but the content hash does not match. "
'Otherwise you can create an issue for it.'.format(self.name + '.csv.gz')) "The repo may be outdated or the download may be incomplete. "
"Otherwise you can create an issue for it.".format(
self.name + ".csv.gz"
)
)
self._extract_gz(gz_file_path, self.raw_path) self._extract_gz(gz_file_path, self.raw_path)
def process(self): def process(self):
filename = os.path.join(self.save_path, self.name + '.csv') filename = os.path.join(self.save_path, self.name + ".csv")
data = np.loadtxt(filename, delimiter=',').astype(np.int64) data = np.loadtxt(filename, delimiter=",").astype(np.int64)
data[:, 0:2] = data[:, 0:2] - data[:, 0:2].min() data[:, 0:2] = data[:, 0:2] - data[:, 0:2].min()
delta = datetime.timedelta(days=14).total_seconds() delta = datetime.timedelta(days=14).total_seconds()
# The source code is not released, but the paper indicates there are # The source code is not released, but the paper indicates there are
# a total of 137 samples. The cutoff below yields exactly 137 samples. # a total of 137 samples. The cutoff below yields exactly 137 samples.
time_index = np.around( time_index = np.around((data[:, 3] - data[:, 3].min()) / delta).astype(
(data[:, 3] - data[:, 3].min()) / delta).astype(np.int64) np.int64
)
self._graphs = [] self._graphs = []
for i in range(time_index.max()): for i in range(time_index.max()):
...@@ -104,19 +114,21 @@ class BitcoinOTCDataset(DGLBuiltinDataset): ...@@ -104,19 +114,21 @@ class BitcoinOTCDataset(DGLBuiltinDataset):
edges = data[row_mask][:, 0:2] edges = data[row_mask][:, 0:2]
rate = data[row_mask][:, 2] rate = data[row_mask][:, 2]
g = dgl_graph((edges[:, 0], edges[:, 1])) g = dgl_graph((edges[:, 0], edges[:, 1]))
g.edata['h'] = F.tensor(rate.reshape(-1, 1), dtype=F.data_type_dict['int64']) g.edata["h"] = F.tensor(
rate.reshape(-1, 1), dtype=F.data_type_dict["int64"]
)
self._graphs.append(g) self._graphs.append(g)
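Editor's note: the 14-day snapshot bucketing above can be checked in isolation with made-up timestamps (the values below are illustrative, not from the real CSV):

    import datetime
    import numpy as np

    delta = datetime.timedelta(days=14).total_seconds()
    ts = np.array([0, 5, 15, 29, 30, 41], dtype=np.float64) * 24 * 3600  # fake epoch times, in days
    time_index = np.around((ts - ts.min()) / delta).astype(np.int64)
    print(time_index)   # [0 0 1 2 2 3] -- each edge lands in the nearest 14-day snapshot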
def has_cache(self): def has_cache(self):
graph_path = os.path.join(self.save_path, 'dgl_graph.bin') graph_path = os.path.join(self.save_path, "dgl_graph.bin")
return os.path.exists(graph_path) return os.path.exists(graph_path)
def save(self): def save(self):
graph_path = os.path.join(self.save_path, 'dgl_graph.bin') graph_path = os.path.join(self.save_path, "dgl_graph.bin")
save_graphs(graph_path, self.graphs) save_graphs(graph_path, self.graphs)
def load(self): def load(self):
graph_path = os.path.join(self.save_path, 'dgl_graph.bin') graph_path = os.path.join(self.save_path, "dgl_graph.bin")
self._graphs = load_graphs(graph_path)[0] self._graphs = load_graphs(graph_path)[0]
@property @property
...@@ -124,7 +136,7 @@ class BitcoinOTCDataset(DGLBuiltinDataset): ...@@ -124,7 +136,7 @@ class BitcoinOTCDataset(DGLBuiltinDataset):
return self._graphs return self._graphs
def __len__(self): def __len__(self):
r""" Number of graphs in the dataset. r"""Number of graphs in the dataset.
Return Return
------- -------
...@@ -133,7 +145,7 @@ class BitcoinOTCDataset(DGLBuiltinDataset): ...@@ -133,7 +145,7 @@ class BitcoinOTCDataset(DGLBuiltinDataset):
return len(self.graphs) return len(self.graphs)
def __getitem__(self, item): def __getitem__(self, item):
r""" Get graph by index r"""Get graph by index
Parameters Parameters
---------- ----------
...@@ -155,7 +167,7 @@ class BitcoinOTCDataset(DGLBuiltinDataset): ...@@ -155,7 +167,7 @@ class BitcoinOTCDataset(DGLBuiltinDataset):
@property @property
def is_temporal(self): def is_temporal(self):
r""" Are the graphs temporal graphs r"""Are the graphs temporal graphs
Returns Returns
------- -------
...@@ -166,12 +178,12 @@ class BitcoinOTCDataset(DGLBuiltinDataset): ...@@ -166,12 +178,12 @@ class BitcoinOTCDataset(DGLBuiltinDataset):
def _extract_gz(self, file, target_dir, overwrite=False): def _extract_gz(self, file, target_dir, overwrite=False):
if os.path.exists(target_dir) and not overwrite: if os.path.exists(target_dir) and not overwrite:
return return
print('Extracting file to {}'.format(target_dir)) print("Extracting file to {}".format(target_dir))
fname = os.path.basename(file) fname = os.path.basename(file)
makedirs(target_dir) makedirs(target_dir)
out_file_path = os.path.join(target_dir, fname[:-3]) out_file_path = os.path.join(target_dir, fname[:-3])
with gzip.open(file, 'rb') as f_in: with gzip.open(file, "rb") as f_in:
with open(out_file_path, 'wb') as f_out: with open(out_file_path, "wb") as f_out:
shutil.copyfileobj(f_in, f_out) shutil.copyfileobj(f_in, f_out)
......
import os import os
import numpy as np import numpy as np
from .dgl_dataset import DGLDataset
from .utils import save_graphs, load_graphs, Subset
from .. import backend as F from .. import backend as F
from ..base import DGLError from ..base import DGLError
from .dgl_dataset import DGLDataset
from .utils import Subset, load_graphs, save_graphs
class CSVDataset(DGLDataset): class CSVDataset(DGLDataset):
...@@ -65,11 +67,24 @@ class CSVDataset(DGLDataset): ...@@ -65,11 +67,24 @@ class CSVDataset(DGLDataset):
Please refer to :ref:`guide-data-pipeline-loadcsv`. Please refer to :ref:`guide-data-pipeline-loadcsv`.
""" """
META_YAML_NAME = 'meta.yaml'
def __init__(self, data_path, force_reload=False, verbose=True, ndata_parser=None, META_YAML_NAME = "meta.yaml"
edata_parser=None, gdata_parser=None, transform=None):
from .csv_dataset_base import load_yaml_with_sanity_check, DefaultDataParser def __init__(
self,
data_path,
force_reload=False,
verbose=True,
ndata_parser=None,
edata_parser=None,
gdata_parser=None,
transform=None,
):
from .csv_dataset_base import (
DefaultDataParser,
load_yaml_with_sanity_check,
)
self.graphs = None self.graphs = None
self.data = None self.data = None
self.ndata_parser = {} if ndata_parser is None else ndata_parser self.ndata_parser = {} if ndata_parser is None else ndata_parser
...@@ -79,17 +94,29 @@ class CSVDataset(DGLDataset): ...@@ -79,17 +94,29 @@ class CSVDataset(DGLDataset):
meta_yaml_path = os.path.join(data_path, CSVDataset.META_YAML_NAME) meta_yaml_path = os.path.join(data_path, CSVDataset.META_YAML_NAME)
if not os.path.exists(meta_yaml_path): if not os.path.exists(meta_yaml_path):
raise DGLError( raise DGLError(
"'{}' cannot be found under {}.".format(CSVDataset.META_YAML_NAME, data_path)) "'{}' cannot be found under {}.".format(
CSVDataset.META_YAML_NAME, data_path
)
)
self.meta_yaml = load_yaml_with_sanity_check(meta_yaml_path) self.meta_yaml = load_yaml_with_sanity_check(meta_yaml_path)
ds_name = self.meta_yaml.dataset_name ds_name = self.meta_yaml.dataset_name
super().__init__(ds_name, raw_dir=os.path.dirname( super().__init__(
meta_yaml_path), force_reload=force_reload, verbose=verbose, transform=transform) ds_name,
raw_dir=os.path.dirname(meta_yaml_path),
force_reload=force_reload,
verbose=verbose,
transform=transform,
)
def process(self): def process(self):
"""Parse node/edge data from CSV files and construct DGL.Graphs """Parse node/edge data from CSV files and construct DGL.Graphs"""
""" from .csv_dataset_base import (
from .csv_dataset_base import NodeData, EdgeData, GraphData, DGLGraphConstructor DGLGraphConstructor,
EdgeData,
GraphData,
NodeData,
)
meta_yaml = self.meta_yaml meta_yaml = self.meta_yaml
base_dir = self.raw_dir base_dir = self.raw_dir
node_data = [] node_data = []
...@@ -97,36 +124,58 @@ class CSVDataset(DGLDataset): ...@@ -97,36 +124,58 @@ class CSVDataset(DGLDataset):
if meta_node is None: if meta_node is None:
continue continue
ntype = meta_node.ntype ntype = meta_node.ntype
data_parser = self.ndata_parser if callable( data_parser = (
self.ndata_parser) else self.ndata_parser.get(ntype, self.default_data_parser) self.ndata_parser
if callable(self.ndata_parser)
else self.ndata_parser.get(ntype, self.default_data_parser)
)
ndata = NodeData.load_from_csv( ndata = NodeData.load_from_csv(
meta_node, base_dir=base_dir, separator=meta_yaml.separator, data_parser=data_parser) meta_node,
base_dir=base_dir,
separator=meta_yaml.separator,
data_parser=data_parser,
)
node_data.append(ndata) node_data.append(ndata)
edge_data = [] edge_data = []
for meta_edge in meta_yaml.edge_data: for meta_edge in meta_yaml.edge_data:
if meta_edge is None: if meta_edge is None:
continue continue
etype = tuple(meta_edge.etype) etype = tuple(meta_edge.etype)
data_parser = self.edata_parser if callable( data_parser = (
self.edata_parser) else self.edata_parser.get(etype, self.default_data_parser) self.edata_parser
if callable(self.edata_parser)
else self.edata_parser.get(etype, self.default_data_parser)
)
edata = EdgeData.load_from_csv( edata = EdgeData.load_from_csv(
meta_edge, base_dir=base_dir, separator=meta_yaml.separator, data_parser=data_parser) meta_edge,
base_dir=base_dir,
separator=meta_yaml.separator,
data_parser=data_parser,
)
edge_data.append(edata) edge_data.append(edata)
graph_data = None graph_data = None
if meta_yaml.graph_data is not None: if meta_yaml.graph_data is not None:
meta_graph = meta_yaml.graph_data meta_graph = meta_yaml.graph_data
data_parser = self.default_data_parser if self.gdata_parser is None else self.gdata_parser data_parser = (
self.default_data_parser
if self.gdata_parser is None
else self.gdata_parser
)
graph_data = GraphData.load_from_csv( graph_data = GraphData.load_from_csv(
meta_graph, base_dir=base_dir, separator=meta_yaml.separator, data_parser=data_parser) meta_graph,
base_dir=base_dir,
separator=meta_yaml.separator,
data_parser=data_parser,
)
# construct graphs # construct graphs
self.graphs, self.data = DGLGraphConstructor.construct_graphs( self.graphs, self.data = DGLGraphConstructor.construct_graphs(
node_data, edge_data, graph_data) node_data, edge_data, graph_data
)
if len(self.data) == 1: if len(self.data) == 1:
self.labels = list(self.data.values())[0] self.labels = list(self.data.values())[0]
def has_cache(self): def has_cache(self):
graph_path = os.path.join(self.save_path, graph_path = os.path.join(self.save_path, self.name + ".bin")
self.name + '.bin')
if os.path.exists(graph_path): if os.path.exists(graph_path):
return True return True
...@@ -135,14 +184,11 @@ class CSVDataset(DGLDataset): ...@@ -135,14 +184,11 @@ class CSVDataset(DGLDataset):
def save(self): def save(self):
if self.graphs is None: if self.graphs is None:
raise DGLError("No graphs available in dataset") raise DGLError("No graphs available in dataset")
graph_path = os.path.join(self.save_path, graph_path = os.path.join(self.save_path, self.name + ".bin")
self.name + '.bin') save_graphs(graph_path, self.graphs, labels=self.data)
save_graphs(graph_path, self.graphs,
labels=self.data)
def load(self): def load(self):
graph_path = os.path.join(self.save_path, graph_path = os.path.join(self.save_path, self.name + ".bin")
self.name + '.bin')
self.graphs, self.data = load_graphs(graph_path) self.graphs, self.data = load_graphs(graph_path)
if len(self.data) == 1: if len(self.data) == 1:
self.labels = list(self.data.values())[0] self.labels = list(self.data.values())[0]
......
import ast
import os import os
from typing import Callable, List, Optional
import numpy as np import numpy as np
from typing import List, Optional, Callable
from .. import backend as F
from ..convert import heterograph as dgl_heterograph
from ..base import dgl_warning, DGLError
import ast
import pydantic as dt
import pandas as pd import pandas as pd
import pydantic as dt
import yaml import yaml
from .. import backend as F
from ..base import DGLError, dgl_warning
from ..convert import heterograph as dgl_heterograph
class MetaNode(dt.BaseModel): class MetaNode(dt.BaseModel):
""" Class of node_data in YAML. Internal use only. """ """Class of node_data in YAML. Internal use only."""
file_name: str file_name: str
ntype: Optional[str] = '_V' ntype: Optional[str] = "_V"
graph_id_field: Optional[str] = 'graph_id' graph_id_field: Optional[str] = "graph_id"
node_id_field: Optional[str] = 'node_id' node_id_field: Optional[str] = "node_id"
class MetaEdge(dt.BaseModel): class MetaEdge(dt.BaseModel):
""" Class of edge_data in YAML. Internal use only. """ """Class of edge_data in YAML. Internal use only."""
file_name: str file_name: str
etype: Optional[List[str]] = ['_V', '_E', '_V'] etype: Optional[List[str]] = ["_V", "_E", "_V"]
graph_id_field: Optional[str] = 'graph_id' graph_id_field: Optional[str] = "graph_id"
src_id_field: Optional[str] = 'src_id' src_id_field: Optional[str] = "src_id"
dst_id_field: Optional[str] = 'dst_id' dst_id_field: Optional[str] = "dst_id"
class MetaGraph(dt.BaseModel): class MetaGraph(dt.BaseModel):
""" Class of graph_data in YAML. Internal use only. """ """Class of graph_data in YAML. Internal use only."""
file_name: str file_name: str
graph_id_field: Optional[str] = 'graph_id' graph_id_field: Optional[str] = "graph_id"
class MetaYaml(dt.BaseModel): class MetaYaml(dt.BaseModel):
""" Class of YAML. Internal use only. """ """Class of YAML. Internal use only."""
version: Optional[str] = '1.0.0'
version: Optional[str] = "1.0.0"
dataset_name: str dataset_name: str
separator: Optional[str] = ',' separator: Optional[str] = ","
node_data: List[MetaNode] node_data: List[MetaNode]
edge_data: List[MetaEdge] edge_data: List[MetaEdge]
graph_data: Optional[MetaGraph] = None graph_data: Optional[MetaGraph] = None
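Editor's note: taken together, these models describe what a minimal meta.yaml must contain. The sketch below builds the equivalent dict in Python and lets pydantic fill in the defaults; the module path follows the relative imports in this diff and the file names are made up:

    from dgl.data.csv_dataset_base import MetaYaml  # internal module, path assumed

    yaml_data = {
        "dataset_name": "toy_csv_dataset",
        "node_data": [{"file_name": "nodes.csv"}],    # ntype defaults to "_V"
        "edge_data": [{"file_name": "edges.csv"}],    # etype defaults to ["_V", "_E", "_V"]
        # "graph_data": {"file_name": "graphs.csv"},  # optional, for graph-level labels
    }
    meta = MetaYaml(**yaml_data)
    print(meta.version, meta.separator)               # defaults: "1.0.0" ","
    print(meta.node_data[0].node_id_field)            # default: "node_id"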
def load_yaml_with_sanity_check(yaml_file): def load_yaml_with_sanity_check(yaml_file):
""" Load yaml and do sanity check. Internal use only. """ """Load yaml and do sanity check. Internal use only."""
with open(yaml_file) as f: with open(yaml_file) as f:
yaml_data = yaml.load(f, Loader=yaml.loader.SafeLoader) yaml_data = yaml.load(f, Loader=yaml.loader.SafeLoader)
try: try:
meta_yaml = MetaYaml(**yaml_data) meta_yaml = MetaYaml(**yaml_data)
except dt.ValidationError as e: except dt.ValidationError as e:
print( print("Details of pydantic.ValidationError:\n{}".format(e.json()))
"Details of pydantic.ValidationError:\n{}".format(e.json())) raise DGLError(
"Validation Error for YAML fields. Details are shown above."
)
if meta_yaml.version != "1.0.0":
raise DGLError( raise DGLError(
"Validation Error for YAML fields. Details are shown above.") "Invalid CSVDataset version {}. Supported versions: '1.0.0'".format(
if meta_yaml.version != '1.0.0': meta_yaml.version
raise DGLError("Invalid CSVDataset version {}. Supported versions: '1.0.0'".format( )
meta_yaml.version)) )
ntypes = [meta.ntype for meta in meta_yaml.node_data] ntypes = [meta.ntype for meta in meta_yaml.node_data]
if len(ntypes) > len(set(ntypes)): if len(ntypes) > len(set(ntypes)):
raise DGLError( raise DGLError(
"Each node CSV file must have a unique node type name, but found duplicate node type: {}.".format(ntypes)) "Each node CSV file must have a unique node type name, but found duplicate node type: {}.".format(
ntypes
)
)
etypes = [tuple(meta.etype) for meta in meta_yaml.edge_data] etypes = [tuple(meta.etype) for meta in meta_yaml.edge_data]
if len(etypes) > len(set(etypes)): if len(etypes) > len(set(etypes)):
raise DGLError( raise DGLError(
"Each edge CSV file must have a unique edge type name, but found duplicate edge type: {}.".format(etypes)) "Each edge CSV file must have a unique edge type name, but found duplicate edge type: {}.".format(
etypes
)
)
return meta_yaml return meta_yaml
...@@ -74,7 +89,10 @@ def _validate_data_length(data_dict): ...@@ -74,7 +89,10 @@ def _validate_data_length(data_dict):
res = lst.count(lst[0]) == len(lst) res = lst.count(lst[0]) == len(lst)
if not res: if not res:
raise DGLError( raise DGLError(
"All data are required to have same length while some of them does not. Length of data={}".format(str(len_dict))) "All data are required to have same length while some of them does not. Length of data={}".format(
str(len_dict)
)
)
def _tensor(data, dtype=None): def _tensor(data, dtype=None):
...@@ -86,8 +104,10 @@ def _tensor(data, dtype=None): ...@@ -86,8 +104,10 @@ def _tensor(data, dtype=None):
ret = F.tensor(ret, dtype=F.float32) ret = F.tensor(ret, dtype=F.float32)
return ret return ret
class BaseData: class BaseData:
""" Class of base data which is inherited by Node/Edge/GraphData. Internal use only. """ """Class of base data which is inherited by Node/Edge/GraphData. Internal use only."""
@staticmethod @staticmethod
def read_csv(file_name, base_dir, separator): def read_csv(file_name, base_dir, separator):
csv_path = file_name csv_path = file_name
...@@ -106,31 +126,40 @@ class BaseData: ...@@ -106,31 +126,40 @@ class BaseData:
class NodeData(BaseData): class NodeData(BaseData):
""" Class of node data which is used for DGLGraph construction. Internal use only. """ """Class of node data which is used for DGLGraph construction. Internal use only."""
def __init__(self, node_id, data, type=None, graph_id=None): def __init__(self, node_id, data, type=None, graph_id=None):
self.id = np.array(node_id) self.id = np.array(node_id)
self.data = data self.data = data
self.type = type if type is not None else '_V' self.type = type if type is not None else "_V"
self.graph_id = np.array( self.graph_id = (
graph_id) if graph_id is not None else np.full(len(node_id), 0) np.array(graph_id)
if graph_id is not None
else np.full(len(node_id), 0)
)
_validate_data_length( _validate_data_length(
{**{'id': self.id, 'graph_id': self.graph_id}, **self.data}) {**{"id": self.id, "graph_id": self.graph_id}, **self.data}
)
@staticmethod @staticmethod
def load_from_csv(meta: MetaNode, data_parser: Callable, base_dir=None, separator=','): def load_from_csv(
meta: MetaNode, data_parser: Callable, base_dir=None, separator=","
):
df = BaseData.read_csv(meta.file_name, base_dir, separator) df = BaseData.read_csv(meta.file_name, base_dir, separator)
node_ids = BaseData.pop_from_dataframe(df, meta.node_id_field) node_ids = BaseData.pop_from_dataframe(df, meta.node_id_field)
graph_ids = BaseData.pop_from_dataframe(df, meta.graph_id_field) graph_ids = BaseData.pop_from_dataframe(df, meta.graph_id_field)
if node_ids is None: if node_ids is None:
raise DGLError("Missing node id field [{}] in file [{}].".format( raise DGLError(
meta.node_id_field, meta.file_name)) "Missing node id field [{}] in file [{}].".format(
meta.node_id_field, meta.file_name
)
)
ntype = meta.ntype ntype = meta.ntype
ndata = data_parser(df) ndata = data_parser(df)
return NodeData(node_ids, ndata, type=ntype, graph_id=graph_ids) return NodeData(node_ids, ndata, type=ntype, graph_id=graph_ids)
@staticmethod @staticmethod
def to_dict(node_data: List['NodeData']) -> dict: def to_dict(node_data: List["NodeData"]) -> dict:
# node_ids could be numeric or non-numeric values, but duplication is not allowed. # node_ids could be numeric or non-numeric values, but duplication is not allowed.
node_dict = {} node_dict = {}
for n_data in node_data: for n_data in node_data:
...@@ -139,112 +168,159 @@ class NodeData(BaseData): ...@@ -139,112 +168,159 @@ class NodeData(BaseData):
idx = n_data.graph_id == graph_id idx = n_data.graph_id == graph_id
ids = n_data.id[idx] ids = n_data.id[idx]
u_ids, u_indices, u_counts = np.unique( u_ids, u_indices, u_counts = np.unique(
ids, return_index=True, return_counts=True) ids, return_index=True, return_counts=True
)
if len(ids) > len(u_ids): if len(ids) > len(u_ids):
raise DGLError("Node IDs are required to be unique but the following ids are duplicate: {}".format( raise DGLError(
u_ids[u_counts > 1])) "Node IDs are required to be unique but the following ids are duplicate: {}".format(
u_ids[u_counts > 1]
)
)
if graph_id not in node_dict: if graph_id not in node_dict:
node_dict[graph_id] = {} node_dict[graph_id] = {}
node_dict[graph_id][n_data.type] = {'mapping': {index: i for i, node_dict[graph_id][n_data.type] = {
index in enumerate(ids[u_indices])}, "mapping": {
'data': {k: _tensor(v[idx][u_indices]) index: i for i, index in enumerate(ids[u_indices])
for k, v in n_data.data.items()}, },
'dtype': ids.dtype} "data": {
k: _tensor(v[idx][u_indices])
for k, v in n_data.data.items()
},
"dtype": ids.dtype,
}
return node_dict return node_dict
class EdgeData(BaseData): class EdgeData(BaseData):
""" Class of edge data which is used for DGLGraph construction. Internal use only. """ """Class of edge data which is used for DGLGraph construction. Internal use only."""
def __init__(self, src_id, dst_id, data, type=None, graph_id=None): def __init__(self, src_id, dst_id, data, type=None, graph_id=None):
self.src = np.array(src_id) self.src = np.array(src_id)
self.dst = np.array(dst_id) self.dst = np.array(dst_id)
self.data = data self.data = data
self.type = type if type is not None else ('_V', '_E', '_V') self.type = type if type is not None else ("_V", "_E", "_V")
self.graph_id = np.array( self.graph_id = (
graph_id) if graph_id is not None else np.full(len(src_id), 0) np.array(graph_id)
if graph_id is not None
else np.full(len(src_id), 0)
)
_validate_data_length( _validate_data_length(
{**{'src': self.src, 'dst': self.dst, 'graph_id': self.graph_id}, **self.data}) {
**{"src": self.src, "dst": self.dst, "graph_id": self.graph_id},
**self.data,
}
)
@staticmethod @staticmethod
def load_from_csv(meta: MetaEdge, data_parser: Callable, base_dir=None, separator=','): def load_from_csv(
meta: MetaEdge, data_parser: Callable, base_dir=None, separator=","
):
df = BaseData.read_csv(meta.file_name, base_dir, separator) df = BaseData.read_csv(meta.file_name, base_dir, separator)
src_ids = BaseData.pop_from_dataframe(df, meta.src_id_field) src_ids = BaseData.pop_from_dataframe(df, meta.src_id_field)
if src_ids is None: if src_ids is None:
raise DGLError("Missing src id field [{}] in file [{}].".format( raise DGLError(
meta.src_id_field, meta.file_name)) "Missing src id field [{}] in file [{}].".format(
meta.src_id_field, meta.file_name
)
)
dst_ids = BaseData.pop_from_dataframe(df, meta.dst_id_field) dst_ids = BaseData.pop_from_dataframe(df, meta.dst_id_field)
if dst_ids is None: if dst_ids is None:
raise DGLError("Missing dst id field [{}] in file [{}].".format( raise DGLError(
meta.dst_id_field, meta.file_name)) "Missing dst id field [{}] in file [{}].".format(
meta.dst_id_field, meta.file_name
)
)
graph_ids = BaseData.pop_from_dataframe(df, meta.graph_id_field) graph_ids = BaseData.pop_from_dataframe(df, meta.graph_id_field)
etype = tuple(meta.etype) etype = tuple(meta.etype)
edata = data_parser(df) edata = data_parser(df)
return EdgeData(src_ids, dst_ids, edata, type=etype, graph_id=graph_ids) return EdgeData(src_ids, dst_ids, edata, type=etype, graph_id=graph_ids)
@staticmethod @staticmethod
def to_dict(edge_data: List['EdgeData'], node_dict: dict) -> dict: def to_dict(edge_data: List["EdgeData"], node_dict: dict) -> dict:
edge_dict = {} edge_dict = {}
for e_data in edge_data: for e_data in edge_data:
(src_type, e_type, dst_type) = e_data.type (src_type, e_type, dst_type) = e_data.type
graph_ids = np.unique(e_data.graph_id) graph_ids = np.unique(e_data.graph_id)
for graph_id in graph_ids: for graph_id in graph_ids:
if graph_id in edge_dict and e_data.type in edge_dict[graph_id]: if graph_id in edge_dict and e_data.type in edge_dict[graph_id]:
raise DGLError(f"Duplicate edge type[{e_data.type}] for same graph[{graph_id}], please place the same edge_type for same graph into single EdgeData.") raise DGLError(
f"Duplicate edge type[{e_data.type}] for same graph[{graph_id}], please place the same edge_type for same graph into single EdgeData."
)
idx = e_data.graph_id == graph_id idx = e_data.graph_id == graph_id
src_mapping = node_dict[graph_id][src_type]['mapping'] src_mapping = node_dict[graph_id][src_type]["mapping"]
dst_mapping = node_dict[graph_id][dst_type]['mapping'] dst_mapping = node_dict[graph_id][dst_type]["mapping"]
orig_src_ids = e_data.src[idx].astype(node_dict[graph_id][src_type]['dtype']) orig_src_ids = e_data.src[idx].astype(
orig_dst_ids = e_data.dst[idx].astype(node_dict[graph_id][dst_type]['dtype']) node_dict[graph_id][src_type]["dtype"]
)
orig_dst_ids = e_data.dst[idx].astype(
node_dict[graph_id][dst_type]["dtype"]
)
src_ids = [src_mapping[index] for index in orig_src_ids] src_ids = [src_mapping[index] for index in orig_src_ids]
dst_ids = [dst_mapping[index] for index in orig_dst_ids] dst_ids = [dst_mapping[index] for index in orig_dst_ids]
if graph_id not in edge_dict: if graph_id not in edge_dict:
edge_dict[graph_id] = {} edge_dict[graph_id] = {}
edge_dict[graph_id][e_data.type] = {'edges': (_tensor(src_ids), _tensor(dst_ids)), edge_dict[graph_id][e_data.type] = {
'data': {k: _tensor(v[idx]) "edges": (_tensor(src_ids), _tensor(dst_ids)),
for k, v in e_data.data.items()}} "data": {
k: _tensor(v[idx]) for k, v in e_data.data.items()
},
}
return edge_dict return edge_dict
class GraphData(BaseData): class GraphData(BaseData):
""" Class of graph data which is used for DGLGraph construction. Internal use only. """ """Class of graph data which is used for DGLGraph construction. Internal use only."""
def __init__(self, graph_id, data): def __init__(self, graph_id, data):
self.graph_id = np.array(graph_id) self.graph_id = np.array(graph_id)
self.data = data self.data = data
_validate_data_length({**{'graph_id': self.graph_id}, **self.data}) _validate_data_length({**{"graph_id": self.graph_id}, **self.data})
@staticmethod @staticmethod
def load_from_csv(meta: MetaGraph, data_parser: Callable, base_dir=None, separator=','): def load_from_csv(
meta: MetaGraph, data_parser: Callable, base_dir=None, separator=","
):
df = BaseData.read_csv(meta.file_name, base_dir, separator) df = BaseData.read_csv(meta.file_name, base_dir, separator)
graph_ids = BaseData.pop_from_dataframe(df, meta.graph_id_field) graph_ids = BaseData.pop_from_dataframe(df, meta.graph_id_field)
if graph_ids is None: if graph_ids is None:
raise DGLError("Missing graph id field [{}] in file [{}].".format( raise DGLError(
meta.graph_id_field, meta.file_name)) "Missing graph id field [{}] in file [{}].".format(
meta.graph_id_field, meta.file_name
)
)
gdata = data_parser(df) gdata = data_parser(df)
return GraphData(graph_ids, gdata) return GraphData(graph_ids, gdata)
@staticmethod @staticmethod
def to_dict(graph_data: 'GraphData', graphs_dict: dict) -> dict: def to_dict(graph_data: "GraphData", graphs_dict: dict) -> dict:
missing_ids = np.setdiff1d( missing_ids = np.setdiff1d(
np.array(list(graphs_dict.keys())), graph_data.graph_id) np.array(list(graphs_dict.keys())), graph_data.graph_id
)
if len(missing_ids) > 0: if len(missing_ids) > 0:
raise DGLError( raise DGLError(
"Found following graph ids in node/edge CSVs but not in graph CSV: {}.".format(missing_ids)) "Found following graph ids in node/edge CSVs but not in graph CSV: {}.".format(
missing_ids
)
)
graph_ids = graph_data.graph_id graph_ids = graph_data.graph_id
graphs = [] graphs = []
for graph_id in graph_ids: for graph_id in graph_ids:
if graph_id not in graphs_dict: if graph_id not in graphs_dict:
graphs_dict[graph_id] = dgl_heterograph( graphs_dict[graph_id] = dgl_heterograph(
{('_V', '_E', '_V'): ([], [])}) {("_V", "_E", "_V"): ([], [])}
)
for graph_id in graph_ids: for graph_id in graph_ids:
graphs.append(graphs_dict[graph_id]) graphs.append(graphs_dict[graph_id])
data = {k: F.reshape(_tensor(v), (len(graphs), -1)) for k, v in graph_data.data.items()} data = {
k: F.reshape(_tensor(v), (len(graphs), -1))
for k, v in graph_data.data.items()
}
return graphs, data return graphs, data
class DGLGraphConstructor: class DGLGraphConstructor:
""" Class for constructing DGLGraph from Node/Edge/Graph data. Internal use only. """ """Class for constructing DGLGraph from Node/Edge/Graph data. Internal use only."""
@staticmethod @staticmethod
def construct_graphs(node_data, edge_data, graph_data=None): def construct_graphs(node_data, edge_data, graph_data=None):
if not isinstance(node_data, list): if not isinstance(node_data, list):
...@@ -253,12 +329,10 @@ class DGLGraphConstructor: ...@@ -253,12 +329,10 @@ class DGLGraphConstructor:
edge_data = [edge_data] edge_data = [edge_data]
node_dict = NodeData.to_dict(node_data) node_dict = NodeData.to_dict(node_data)
edge_dict = EdgeData.to_dict(edge_data, node_dict) edge_dict = EdgeData.to_dict(edge_data, node_dict)
graph_dict = DGLGraphConstructor._construct_graphs( graph_dict = DGLGraphConstructor._construct_graphs(node_dict, edge_dict)
node_dict, edge_dict)
if graph_data is None: if graph_data is None:
graph_data = GraphData(np.full(1, 0), {}) graph_data = GraphData(np.full(1, 0), {})
graphs, data = GraphData.to_dict( graphs, data = GraphData.to_dict(graph_data, graph_dict)
graph_data, graph_dict)
return graphs, data return graphs, data
@staticmethod @staticmethod
...@@ -266,40 +340,47 @@ class DGLGraphConstructor: ...@@ -266,40 +340,47 @@ class DGLGraphConstructor:
graph_dict = {} graph_dict = {}
for graph_id in node_dict: for graph_id in node_dict:
if graph_id not in edge_dict: if graph_id not in edge_dict:
edge_dict[graph_id][('_V', '_E', '_V')] = {'edges': ([], [])} edge_dict[graph_id][("_V", "_E", "_V")] = {"edges": ([], [])}
graph = dgl_heterograph({etype: edata['edges'] graph = dgl_heterograph(
for etype, edata in edge_dict[graph_id].items()}, {
num_nodes_dict={ntype: len(ndata['mapping']) etype: edata["edges"]
for ntype, ndata in node_dict[graph_id].items()}) for etype, edata in edge_dict[graph_id].items()
},
num_nodes_dict={
ntype: len(ndata["mapping"])
for ntype, ndata in node_dict[graph_id].items()
},
)
def assign_data(type, src_data, dst_data): def assign_data(type, src_data, dst_data):
for key, value in src_data.items(): for key, value in src_data.items():
dst_data[type].data[key] = value dst_data[type].data[key] = value
for type, data in node_dict[graph_id].items(): for type, data in node_dict[graph_id].items():
assign_data(type, data['data'], graph.nodes) assign_data(type, data["data"], graph.nodes)
for (type), data in edge_dict[graph_id].items(): for (type), data in edge_dict[graph_id].items():
assign_data(type, data['data'], graph.edges) assign_data(type, data["data"], graph.edges)
graph_dict[graph_id] = graph graph_dict[graph_id] = graph
return graph_dict return graph_dict
class DefaultDataParser: class DefaultDataParser:
""" Default data parser for CSVDataset. It """Default data parser for CSVDataset. It
1. ignores any columns which does not have a header. 1. ignores any columns which does not have a header.
2. tries to convert to list of numeric values(generated by 2. tries to convert to list of numeric values(generated by
np.array().tolist()) if cell data is a str separated by ','. np.array().tolist()) if cell data is a str separated by ','.
3. read data and infer data type directly, otherwise. 3. read data and infer data type directly, otherwise.
""" """
def __call__(self, df: pd.DataFrame): def __call__(self, df: pd.DataFrame):
data = {} data = {}
for header in df: for header in df:
if 'Unnamed' in header: if "Unnamed" in header:
dgl_warning("Unamed column is found. Ignored...") dgl_warning("Unamed column is found. Ignored...")
continue continue
dt = df[header].to_numpy().squeeze() dt = df[header].to_numpy().squeeze()
if len(dt) > 0 and isinstance(dt[0], str): if len(dt) > 0 and isinstance(dt[0], str):
#probably consists of list of numeric values # probably consists of list of numeric values
dt = np.array([ast.literal_eval(row) for row in dt]) dt = np.array([ast.literal_eval(row) for row in dt])
data[header] = dt data[header] = dt
return data return data
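Editor's note: a quick sketch of what this parser produces for a small DataFrame (pandas only; the column names are made up and the import path is assumed from the imports in this diff):

    import pandas as pd
    from dgl.data.csv_dataset_base import DefaultDataParser  # internal module, path assumed

    df = pd.DataFrame({
        "label": [0, 1, 0],                            # numeric column: used as-is
        "feat": ["0.1, 0.2", "0.3, 0.4", "0.5, 0.6"],  # str column: ast.literal_eval per cell
    })
    data = DefaultDataParser()(df)
    print(data["label"])        # [0 1 0]
    print(data["feat"].shape)   # (3, 2) -- each cell became a tuple of floats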
...@@ -3,11 +3,15 @@ ...@@ -3,11 +3,15 @@
from __future__ import absolute_import from __future__ import absolute_import
import os, sys, hashlib
import traceback
import abc import abc
from .utils import download, extract_archive, get_download_dir, makedirs import hashlib
import os
import sys
import traceback
from ..utils import retry_method_with_fix from ..utils import retry_method_with_fix
from .utils import download, extract_archive, get_download_dir, makedirs
class DGLDataset(object): class DGLDataset(object):
r"""The basic DGL dataset for creating graph datasets. r"""The basic DGL dataset for creating graph datasets.
...@@ -75,8 +79,18 @@ class DGLDataset(object): ...@@ -75,8 +79,18 @@ class DGLDataset(object):
hash : str hash : str
Hash value for the dataset and the setting. Hash value for the dataset and the setting.
""" """
def __init__(self, name, url=None, raw_dir=None, save_dir=None,
hash_key=(), force_reload=False, verbose=False, transform=None): def __init__(
self,
name,
url=None,
raw_dir=None,
save_dir=None,
hash_key=(),
force_reload=False,
verbose=False,
transform=None,
):
self._name = name self._name = name
self._url = url self._url = url
self._force_reload = force_reload self._force_reload = force_reload
...@@ -131,8 +145,7 @@ class DGLDataset(object): ...@@ -131,8 +145,7 @@ class DGLDataset(object):
@abc.abstractmethod @abc.abstractmethod
def process(self): def process(self):
r"""Overwrite to realize your own logic of processing the input data. r"""Overwrite to realize your own logic of processing the input data."""
"""
pass pass
def has_cache(self): def has_cache(self):
...@@ -177,21 +190,21 @@ class DGLDataset(object): ...@@ -177,21 +190,21 @@ class DGLDataset(object):
try: try:
self.load() self.load()
if self.verbose: if self.verbose:
print('Done loading data from cached files.') print("Done loading data from cached files.")
except KeyboardInterrupt: except KeyboardInterrupt:
raise raise
except: except:
load_flag = False load_flag = False
if self.verbose: if self.verbose:
print(traceback.format_exc()) print(traceback.format_exc())
print('Loading from cache failed, re-processing.') print("Loading from cache failed, re-processing.")
if not load_flag: if not load_flag:
self._download() self._download()
self.process() self.process()
self.save() self.save()
if self.verbose: if self.verbose:
print('Done saving data into cached files.') print("Done saving data into cached files.")
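Editor's note: the load-or-process flow above is the hook every subclass plugs into; a minimal, self-contained subclass might look like the following (toy example with made-up names, not part of this diff):

    import os
    import torch
    import dgl
    from dgl.data import DGLDataset
    from dgl.data.utils import save_graphs, load_graphs

    class ToyDataset(DGLDataset):
        def __init__(self, raw_dir=None, force_reload=False, verbose=False):
            super().__init__(name="toy", raw_dir=raw_dir,
                             force_reload=force_reload, verbose=verbose)

        def process(self):
            # Build a single 4-node cycle graph with random features.
            self.g = dgl.graph((torch.tensor([0, 1, 2, 3]),
                                torch.tensor([1, 2, 3, 0])))
            self.g.ndata["feat"] = torch.randn(4, 8)

        def has_cache(self):
            return os.path.isfile(os.path.join(self.save_path, "toy.bin"))

        def save(self):
            os.makedirs(self.save_path, exist_ok=True)
            save_graphs(os.path.join(self.save_path, "toy.bin"), [self.g])

        def load(self):
            self.g = load_graphs(os.path.join(self.save_path, "toy.bin"))[0][0]

        def __getitem__(self, idx):
            return self.g

        def __len__(self):
            return 1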
def _get_hash(self): def _get_hash(self):
"""Compute the hash of the input tuple """Compute the hash of the input tuple
...@@ -205,62 +218,54 @@ class DGLDataset(object): ...@@ -205,62 +218,54 @@ class DGLDataset(object):
'a770b222' 'a770b222'
""" """
hash_func = hashlib.sha1() hash_func = hashlib.sha1()
hash_func.update(str(self._hash_key).encode('utf-8')) hash_func.update(str(self._hash_key).encode("utf-8"))
return hash_func.hexdigest()[:8] return hash_func.hexdigest()[:8]
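Editor's note: for example, the 8-character hash of an AsNodePredDataset-style key can be reproduced directly (the key values below are illustrative):

    import hashlib

    hash_key = ([0.8, 0.1, 0.1], None, "cora", "nodepred")
    print(hashlib.sha1(str(hash_key).encode("utf-8")).hexdigest()[:8])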
@property @property
def url(self): def url(self):
r"""Get url to download the raw dataset. r"""Get url to download the raw dataset."""
"""
return self._url return self._url
@property @property
def name(self): def name(self):
r"""Name of the dataset. r"""Name of the dataset."""
"""
return self._name return self._name
@property @property
def raw_dir(self): def raw_dir(self):
r"""Raw file directory contains the input data folder. r"""Raw file directory contains the input data folder."""
"""
return self._raw_dir return self._raw_dir
@property @property
def raw_path(self): def raw_path(self):
r"""Directory contains the input data files. r"""Directory contains the input data files.
By default raw_path = os.path.join(self.raw_dir, self.name) By default raw_path = os.path.join(self.raw_dir, self.name)
""" """
return os.path.join(self.raw_dir, self.name) return os.path.join(self.raw_dir, self.name)
@property @property
def save_dir(self): def save_dir(self):
r"""Directory to save the processed dataset. r"""Directory to save the processed dataset."""
"""
return self._save_dir return self._save_dir
@property @property
def save_path(self): def save_path(self):
r"""Path to save the processed dataset. r"""Path to save the processed dataset."""
"""
return os.path.join(self._save_dir, self.name) return os.path.join(self._save_dir, self.name)
@property @property
def verbose(self): def verbose(self):
r"""Whether to print information. r"""Whether to print information."""
"""
return self._verbose return self._verbose
@property @property
def hash(self): def hash(self):
r"""Hash value for the dataset and the setting. r"""Hash value for the dataset and the setting."""
"""
return self._hash return self._hash
@abc.abstractmethod @abc.abstractmethod
def __getitem__(self, idx): def __getitem__(self, idx):
r"""Gets the data object at index. r"""Gets the data object at index."""
"""
pass pass
@abc.abstractmethod @abc.abstractmethod
...@@ -269,8 +274,11 @@ class DGLDataset(object): ...@@ -269,8 +274,11 @@ class DGLDataset(object):
pass pass
def __repr__(self): def __repr__(self):
return f'Dataset("{self.name}", num_graphs={len(self)},' + \ return (
f' save_path={self.save_path})' f'Dataset("{self.name}", num_graphs={len(self)},'
+ f" save_path={self.save_path})"
)
class DGLBuiltinDataset(DGLDataset): class DGLBuiltinDataset(DGLDataset):
r"""The Basic DGL Builtin Dataset. r"""The Basic DGL Builtin Dataset.
...@@ -299,21 +307,31 @@ class DGLBuiltinDataset(DGLDataset): ...@@ -299,21 +307,31 @@ class DGLBuiltinDataset(DGLDataset):
a transformed version. The :class:`~dgl.DGLGraph` object will be a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access. transformed before every access.
""" """
def __init__(self, name, url, raw_dir=None, hash_key=(),
force_reload=False, verbose=False, transform=None): def __init__(
super(DGLBuiltinDataset, self).__init__(name, self,
url=url, name,
raw_dir=raw_dir, url,
save_dir=None, raw_dir=None,
hash_key=hash_key, hash_key=(),
force_reload=force_reload, force_reload=False,
verbose=verbose, verbose=False,
transform=transform) transform=None,
):
super(DGLBuiltinDataset, self).__init__(
name,
url=url,
raw_dir=raw_dir,
save_dir=None,
hash_key=hash_key,
force_reload=force_reload,
verbose=verbose,
transform=transform,
)
def download(self): def download(self):
r""" Automatically download data and extract it. r"""Automatically download data and extract it."""
"""
if self.url is not None: if self.url is not None:
zip_file_path = os.path.join(self.raw_dir, self.name + '.zip') zip_file_path = os.path.join(self.raw_dir, self.name + ".zip")
download(self.url, path=zip_file_path) download(self.url, path=zip_file_path)
extract_archive(zip_file_path, self.raw_path) extract_archive(zip_file_path, self.raw_path)
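For orientation, the caching pipeline these base classes wire together (load from cache when has_cache() succeeds, otherwise download, process and save) is easiest to see in a minimal subclass. The sketch below is illustrative only: the class name, file name and tensor contents are invented, and only the overridden hooks follow the DGLDataset interface shown in this diff.

import os

import torch as th

from dgl.data import DGLDataset
from dgl.data.utils import load_info, save_info


class MyToyDataset(DGLDataset):
    """Hypothetical dataset showing the standard DGLDataset hooks."""

    def __init__(self, raw_dir=None, force_reload=False, verbose=False):
        super(MyToyDataset, self).__init__(
            name="my_toy",
            raw_dir=raw_dir,
            force_reload=force_reload,
            verbose=verbose,
        )

    def process(self):
        # Build graphs/labels from files under self.raw_path.
        self.labels = th.arange(10)

    def has_cache(self):
        return os.path.exists(os.path.join(self.save_path, "info.pkl"))

    def save(self):
        os.makedirs(self.save_path, exist_ok=True)
        save_info(os.path.join(self.save_path, "info.pkl"), {"label": self.labels})

    def load(self):
        self.labels = load_info(os.path.join(self.save_path, "info.pkl"))["label"]

    def __getitem__(self, idx):
        return self.labels[idx]

    def __len__(self):
        return self.labels.shape[0]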
import os import os
import numpy as np import numpy as np
import scipy.sparse as sp import scipy.sparse as sp
from .dgl_dataset import DGLBuiltinDataset
from .utils import save_graphs, load_graphs, _get_dgl_url
from .utils import save_info, load_info
from ..convert import graph
from .. import backend as F from .. import backend as F
from ..convert import graph
from .dgl_dataset import DGLBuiltinDataset
from .utils import _get_dgl_url, load_graphs, load_info, save_graphs, save_info
class FakeNewsDataset(DGLBuiltinDataset): class FakeNewsDataset(DGLBuiltinDataset):
...@@ -113,30 +113,41 @@ class FakeNewsDataset(DGLBuiltinDataset): ...@@ -113,30 +113,41 @@ class FakeNewsDataset(DGLBuiltinDataset):
>>> labels = dataset.labels >>> labels = dataset.labels
""" """
file_urls = { file_urls = {
'gossipcop': 'dataset/FakeNewsGOS.zip', "gossipcop": "dataset/FakeNewsGOS.zip",
'politifact': 'dataset/FakeNewsPOL.zip' "politifact": "dataset/FakeNewsPOL.zip",
} }
def __init__(self, name, feature_name, raw_dir=None, transform=None): def __init__(self, name, feature_name, raw_dir=None, transform=None):
assert name in ['gossipcop', 'politifact'], \ assert name in [
"Only supports 'gossipcop' or 'politifact'." "gossipcop",
"politifact",
], "Only supports 'gossipcop' or 'politifact'."
url = _get_dgl_url(self.file_urls[name]) url = _get_dgl_url(self.file_urls[name])
assert feature_name in ['bert', 'content', 'profile', 'spacy'], \ assert feature_name in [
"Only supports 'bert', 'content', 'profile', or 'spacy'" "bert",
"content",
"profile",
"spacy",
], "Only supports 'bert', 'content', 'profile', or 'spacy'"
self.feature_name = feature_name self.feature_name = feature_name
super(FakeNewsDataset, self).__init__(name=name, super(FakeNewsDataset, self).__init__(
url=url, name=name, url=url, raw_dir=raw_dir, transform=transform
raw_dir=raw_dir, )
transform=transform)
def process(self): def process(self):
"""process raw data to graph, labels and masks""" """process raw data to graph, labels and masks"""
self.labels = F.tensor(np.load(os.path.join(self.raw_path, 'graph_labels.npy'))) self.labels = F.tensor(
np.load(os.path.join(self.raw_path, "graph_labels.npy"))
)
num_graphs = self.labels.shape[0] num_graphs = self.labels.shape[0]
node_graph_id = np.load(os.path.join(self.raw_path, 'node_graph_id.npy')) node_graph_id = np.load(
edges = np.genfromtxt(os.path.join(self.raw_path, 'A.txt'), delimiter=',', dtype=int) os.path.join(self.raw_path, "node_graph_id.npy")
)
edges = np.genfromtxt(
os.path.join(self.raw_path, "A.txt"), delimiter=",", dtype=int
)
src = edges[:, 0] src = edges[:, 0]
dst = edges[:, 1] dst = edges[:, 1]
g = graph((src, dst)) g = graph((src, dst))
...@@ -148,9 +159,9 @@ class FakeNewsDataset(DGLBuiltinDataset): ...@@ -148,9 +159,9 @@ class FakeNewsDataset(DGLBuiltinDataset):
self.graphs = [g.subgraph(node_idx) for node_idx in node_idx_list] self.graphs = [g.subgraph(node_idx) for node_idx in node_idx_list]
train_idx = np.load(os.path.join(self.raw_path, 'train_idx.npy')) train_idx = np.load(os.path.join(self.raw_path, "train_idx.npy"))
val_idx = np.load(os.path.join(self.raw_path, 'val_idx.npy')) val_idx = np.load(os.path.join(self.raw_path, "val_idx.npy"))
test_idx = np.load(os.path.join(self.raw_path, 'test_idx.npy')) test_idx = np.load(os.path.join(self.raw_path, "test_idx.npy"))
train_mask = np.zeros(num_graphs, dtype=np.bool) train_mask = np.zeros(num_graphs, dtype=np.bool)
val_mask = np.zeros(num_graphs, dtype=np.bool) val_mask = np.zeros(num_graphs, dtype=np.bool)
test_mask = np.zeros(num_graphs, dtype=np.bool) test_mask = np.zeros(num_graphs, dtype=np.bool)
...@@ -161,40 +172,47 @@ class FakeNewsDataset(DGLBuiltinDataset): ...@@ -161,40 +172,47 @@ class FakeNewsDataset(DGLBuiltinDataset):
self.val_mask = F.tensor(val_mask) self.val_mask = F.tensor(val_mask)
self.test_mask = F.tensor(test_mask) self.test_mask = F.tensor(test_mask)
feature_file = 'new_' + self.feature_name + '_feature.npz' feature_file = "new_" + self.feature_name + "_feature.npz"
self.feature = F.tensor(sp.load_npz(os.path.join(self.raw_path, feature_file)).todense()) self.feature = F.tensor(
sp.load_npz(os.path.join(self.raw_path, feature_file)).todense()
)
def save(self): def save(self):
"""save the graph list and the labels""" """save the graph list and the labels"""
graph_path = os.path.join(self.save_path, self.name + '_dgl_graph.bin') graph_path = os.path.join(self.save_path, self.name + "_dgl_graph.bin")
info_path = os.path.join(self.save_path, self.name + '_dgl_graph.pkl') info_path = os.path.join(self.save_path, self.name + "_dgl_graph.pkl")
save_graphs(str(graph_path), self.graphs) save_graphs(str(graph_path), self.graphs)
save_info(info_path, {'label': self.labels, save_info(
'feature': self.feature, info_path,
'train_mask': self.train_mask, {
'val_mask': self.val_mask, "label": self.labels,
'test_mask': self.test_mask}) "feature": self.feature,
"train_mask": self.train_mask,
"val_mask": self.val_mask,
"test_mask": self.test_mask,
},
)
def has_cache(self): def has_cache(self):
""" check whether there are processed data in `self.save_path` """ """check whether there are processed data in `self.save_path`"""
graph_path = os.path.join(self.save_path, self.name + '_dgl_graph.bin') graph_path = os.path.join(self.save_path, self.name + "_dgl_graph.bin")
info_path = os.path.join(self.save_path, self.name + '_dgl_graph.pkl') info_path = os.path.join(self.save_path, self.name + "_dgl_graph.pkl")
return os.path.exists(graph_path) and os.path.exists(info_path) return os.path.exists(graph_path) and os.path.exists(info_path)
def load(self): def load(self):
"""load processed data from directory `self.save_path`""" """load processed data from directory `self.save_path`"""
graph_path = os.path.join(self.save_path, self.name + '_dgl_graph.bin') graph_path = os.path.join(self.save_path, self.name + "_dgl_graph.bin")
info_path = os.path.join(self.save_path, self.name + '_dgl_graph.pkl') info_path = os.path.join(self.save_path, self.name + "_dgl_graph.pkl")
graphs, _ = load_graphs(str(graph_path)) graphs, _ = load_graphs(str(graph_path))
info = load_info(str(info_path)) info = load_info(str(info_path))
self.graphs = graphs self.graphs = graphs
self.labels = info['label'] self.labels = info["label"]
self.feature = info['feature'] self.feature = info["feature"]
self.train_mask = info['train_mask'] self.train_mask = info["train_mask"]
self.val_mask = info['val_mask'] self.val_mask = info["val_mask"]
self.test_mask = info['test_mask'] self.test_mask = info["test_mask"]
@property @property
def num_classes(self): def num_classes(self):
...@@ -207,7 +225,7 @@ class FakeNewsDataset(DGLBuiltinDataset): ...@@ -207,7 +225,7 @@ class FakeNewsDataset(DGLBuiltinDataset):
return self.labels.shape[0] return self.labels.shape[0]
def __getitem__(self, i): def __getitem__(self, i):
r""" Get graph and label by index r"""Get graph and label by index
Parameters Parameters
---------- ----------
......
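As a quick sanity check of the FakeNewsDataset fields touched above (graph list, labels, shared feature matrix and the three boolean masks), usage along the following lines should work; 'gossipcop' with 'profile' is just one of the documented name/feature combinations.

from dgl.data import FakeNewsDataset

dataset = FakeNewsDataset("gossipcop", "profile")
g, label = dataset[0]            # one news propagation graph and its label
feat = dataset.feature           # node feature matrix shared across graphs
train_mask = dataset.train_mask  # boolean mask over graph indices
print(len(dataset), dataset.num_classes)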
"""Flickr Dataset""" """Flickr Dataset"""
import os
import json import json
import os
import numpy as np import numpy as np
import scipy.sparse as sp import scipy.sparse as sp
from .. import backend as F from .. import backend as F
from ..convert import from_scipy from ..convert import from_scipy
from ..transforms import reorder_graph from ..transforms import reorder_graph
from .dgl_dataset import DGLBuiltinDataset from .dgl_dataset import DGLBuiltinDataset
from .utils import generate_mask_tensor, load_graphs, save_graphs, _get_dgl_url from .utils import _get_dgl_url, generate_mask_tensor, load_graphs, save_graphs
class FlickrDataset(DGLBuiltinDataset): class FlickrDataset(DGLBuiltinDataset):
...@@ -65,66 +67,78 @@ class FlickrDataset(DGLBuiltinDataset): ...@@ -65,66 +67,78 @@ class FlickrDataset(DGLBuiltinDataset):
>>> test_mask = g.ndata['test_mask'] >>> test_mask = g.ndata['test_mask']
""" """
def __init__(self, raw_dir=None, force_reload=False, verbose=False, transform=None, def __init__(
reorder=False): self,
_url = _get_dgl_url('dataset/flickr.zip') raw_dir=None,
force_reload=False,
verbose=False,
transform=None,
reorder=False,
):
_url = _get_dgl_url("dataset/flickr.zip")
self._reorder = reorder self._reorder = reorder
super(FlickrDataset, self).__init__(name='flickr', super(FlickrDataset, self).__init__(
raw_dir=raw_dir, name="flickr",
url=_url, raw_dir=raw_dir,
force_reload=force_reload, url=_url,
verbose=verbose, force_reload=force_reload,
transform=transform) verbose=verbose,
transform=transform,
)
def process(self): def process(self):
"""process raw data to graph, labels and masks""" """process raw data to graph, labels and masks"""
coo_adj = sp.load_npz(os.path.join(self.raw_path, "adj_full.npz")) coo_adj = sp.load_npz(os.path.join(self.raw_path, "adj_full.npz"))
g = from_scipy(coo_adj) g = from_scipy(coo_adj)
features = np.load(os.path.join(self.raw_path, 'feats.npy')) features = np.load(os.path.join(self.raw_path, "feats.npy"))
features = F.tensor(features, dtype=F.float32) features = F.tensor(features, dtype=F.float32)
y = [-1] * features.shape[0] y = [-1] * features.shape[0]
with open(os.path.join(self.raw_path, 'class_map.json')) as f: with open(os.path.join(self.raw_path, "class_map.json")) as f:
class_map = json.load(f) class_map = json.load(f)
for key, item in class_map.items(): for key, item in class_map.items():
y[int(key)] = item y[int(key)] = item
labels = F.tensor(np.array(y), dtype=F.int64) labels = F.tensor(np.array(y), dtype=F.int64)
with open(os.path.join(self.raw_path, 'role.json')) as f: with open(os.path.join(self.raw_path, "role.json")) as f:
role = json.load(f) role = json.load(f)
train_mask = np.zeros(features.shape[0], dtype=bool) train_mask = np.zeros(features.shape[0], dtype=bool)
train_mask[role['tr']] = True train_mask[role["tr"]] = True
val_mask = np.zeros(features.shape[0], dtype=bool) val_mask = np.zeros(features.shape[0], dtype=bool)
val_mask[role['va']] = True val_mask[role["va"]] = True
test_mask = np.zeros(features.shape[0], dtype=bool) test_mask = np.zeros(features.shape[0], dtype=bool)
test_mask[role['te']] = True test_mask[role["te"]] = True
g.ndata['feat'] = features g.ndata["feat"] = features
g.ndata['label'] = labels g.ndata["label"] = labels
g.ndata['train_mask'] = generate_mask_tensor(train_mask) g.ndata["train_mask"] = generate_mask_tensor(train_mask)
g.ndata['val_mask'] = generate_mask_tensor(val_mask) g.ndata["val_mask"] = generate_mask_tensor(val_mask)
g.ndata['test_mask'] = generate_mask_tensor(test_mask) g.ndata["test_mask"] = generate_mask_tensor(test_mask)
if self._reorder: if self._reorder:
self._graph = reorder_graph( self._graph = reorder_graph(
g, node_permute_algo='rcmk', edge_permute_algo='dst', store_ids=False) g,
node_permute_algo="rcmk",
edge_permute_algo="dst",
store_ids=False,
)
else: else:
self._graph = g self._graph = g
def has_cache(self): def has_cache(self):
graph_path = os.path.join(self.save_path, 'dgl_graph.bin') graph_path = os.path.join(self.save_path, "dgl_graph.bin")
return os.path.exists(graph_path) return os.path.exists(graph_path)
def save(self): def save(self):
graph_path = os.path.join(self.save_path, 'dgl_graph.bin') graph_path = os.path.join(self.save_path, "dgl_graph.bin")
save_graphs(graph_path, self._graph) save_graphs(graph_path, self._graph)
def load(self): def load(self):
graph_path = os.path.join(self.save_path, 'dgl_graph.bin') graph_path = os.path.join(self.save_path, "dgl_graph.bin")
g, _ = load_graphs(graph_path) g, _ = load_graphs(graph_path)
self._graph = g[0] self._graph = g[0]
...@@ -137,7 +151,7 @@ class FlickrDataset(DGLBuiltinDataset): ...@@ -137,7 +151,7 @@ class FlickrDataset(DGLBuiltinDataset):
return 1 return 1
def __getitem__(self, idx): def __getitem__(self, idx):
r""" Get graph object r"""Get graph object
Parameters Parameters
---------- ----------
...@@ -161,4 +175,4 @@ class FlickrDataset(DGLBuiltinDataset): ...@@ -161,4 +175,4 @@ class FlickrDataset(DGLBuiltinDataset):
if self._transform is None: if self._transform is None:
return self._graph return self._graph
else: else:
return self._transform(self._graph) return self._transform(self._graph)
\ No newline at end of file
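Similarly, the FlickrDataset reformatted above stores everything on a single graph, with the node data keys assigned in process(); a rough usage sketch:

from dgl.data import FlickrDataset

dataset = FlickrDataset()
g = dataset[0]                      # the dataset holds a single graph
feat = g.ndata["feat"]
label = g.ndata["label"]
train_mask = g.ndata["train_mask"]
print(g.num_nodes(), feat.shape)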
"""Fraud Dataset """Fraud Dataset
""" """
import os import os
from scipy import io
import numpy as np import numpy as np
from scipy import io
from .utils import save_graphs, load_graphs, _get_dgl_url from .. import backend as F
from ..convert import heterograph from ..convert import heterograph
from .dgl_dataset import DGLBuiltinDataset from .dgl_dataset import DGLBuiltinDataset
from .. import backend as F from .utils import _get_dgl_url, load_graphs, save_graphs
class FraudDataset(DGLBuiltinDataset): class FraudDataset(DGLBuiltinDataset):
...@@ -77,61 +78,74 @@ class FraudDataset(DGLBuiltinDataset): ...@@ -77,61 +78,74 @@ class FraudDataset(DGLBuiltinDataset):
>>> label = graph.ndata['label'] >>> label = graph.ndata['label']
""" """
file_urls = { file_urls = {
'yelp': 'dataset/FraudYelp.zip', "yelp": "dataset/FraudYelp.zip",
'amazon': 'dataset/FraudAmazon.zip' "amazon": "dataset/FraudAmazon.zip",
} }
relations = { relations = {
'yelp': ['net_rsr', 'net_rtr', 'net_rur'], "yelp": ["net_rsr", "net_rtr", "net_rur"],
'amazon': ['net_upu', 'net_usu', 'net_uvu'] "amazon": ["net_upu", "net_usu", "net_uvu"],
}
file_names = {
'yelp': 'YelpChi.mat',
'amazon': 'Amazon.mat'
} }
node_name = { file_names = {"yelp": "YelpChi.mat", "amazon": "Amazon.mat"}
'yelp': 'review', node_name = {"yelp": "review", "amazon": "user"}
'amazon': 'user'
} def __init__(
self,
def __init__(self, name, raw_dir=None, random_seed=717, train_size=0.7, name,
val_size=0.1, force_reload=False, verbose=True, transform=None): raw_dir=None,
assert name in ['yelp', 'amazon'], "only supports 'yelp', or 'amazon'" random_seed=717,
train_size=0.7,
val_size=0.1,
force_reload=False,
verbose=True,
transform=None,
):
assert name in ["yelp", "amazon"], "only supports 'yelp', or 'amazon'"
url = _get_dgl_url(self.file_urls[name]) url = _get_dgl_url(self.file_urls[name])
self.seed = random_seed self.seed = random_seed
self.train_size = train_size self.train_size = train_size
self.val_size = val_size self.val_size = val_size
super(FraudDataset, self).__init__(name=name, super(FraudDataset, self).__init__(
url=url, name=name,
raw_dir=raw_dir, url=url,
hash_key=(random_seed, train_size, val_size), raw_dir=raw_dir,
force_reload=force_reload, hash_key=(random_seed, train_size, val_size),
verbose=verbose, force_reload=force_reload,
transform=transform) verbose=verbose,
transform=transform,
)
def process(self): def process(self):
"""process raw data to graph, labels, splitting masks""" """process raw data to graph, labels, splitting masks"""
file_path = os.path.join(self.raw_path, self.file_names[self.name]) file_path = os.path.join(self.raw_path, self.file_names[self.name])
data = io.loadmat(file_path) data = io.loadmat(file_path)
node_features = data['features'].todense() node_features = data["features"].todense()
# remove additional dimension of length 1 in raw .mat file # remove additional dimension of length 1 in raw .mat file
node_labels = data['label'].squeeze() node_labels = data["label"].squeeze()
graph_data = {} graph_data = {}
for relation in self.relations[self.name]: for relation in self.relations[self.name]:
adj = data[relation].tocoo() adj = data[relation].tocoo()
row, col = adj.row, adj.col row, col = adj.row, adj.col
graph_data[(self.node_name[self.name], relation, self.node_name[self.name])] = (row, col) graph_data[
(self.node_name[self.name], relation, self.node_name[self.name])
] = (row, col)
g = heterograph(graph_data) g = heterograph(graph_data)
g.ndata['feature'] = F.tensor(node_features, dtype=F.data_type_dict['float32']) g.ndata["feature"] = F.tensor(
g.ndata['label'] = F.tensor(node_labels, dtype=F.data_type_dict['int64']) node_features, dtype=F.data_type_dict["float32"]
)
g.ndata["label"] = F.tensor(
node_labels, dtype=F.data_type_dict["int64"]
)
self.graph = g self.graph = g
self._random_split(g.ndata['feature'], self.seed, self.train_size, self.val_size) self._random_split(
g.ndata["feature"], self.seed, self.train_size, self.val_size
)
def __getitem__(self, idx): def __getitem__(self, idx):
r""" Get graph object r"""Get graph object
Parameters Parameters
---------- ----------
...@@ -171,51 +185,61 @@ class FraudDataset(DGLBuiltinDataset): ...@@ -171,51 +185,61 @@ class FraudDataset(DGLBuiltinDataset):
def save(self): def save(self):
"""save processed data to directory `self.save_path`""" """save processed data to directory `self.save_path`"""
graph_path = os.path.join(self.save_path, self.name + '_dgl_graph_{}.bin'.format(self.hash)) graph_path = os.path.join(
self.save_path, self.name + "_dgl_graph_{}.bin".format(self.hash)
)
save_graphs(str(graph_path), self.graph) save_graphs(str(graph_path), self.graph)
def load(self): def load(self):
"""load processed data from directory `self.save_path`""" """load processed data from directory `self.save_path`"""
graph_path = os.path.join(self.save_path, self.name + '_dgl_graph_{}.bin'.format(self.hash)) graph_path = os.path.join(
self.save_path, self.name + "_dgl_graph_{}.bin".format(self.hash)
)
graph_list, _ = load_graphs(str(graph_path)) graph_list, _ = load_graphs(str(graph_path))
g = graph_list[0] g = graph_list[0]
self.graph = g self.graph = g
def has_cache(self): def has_cache(self):
"""check whether there are processed data in `self.save_path`""" """check whether there are processed data in `self.save_path`"""
graph_path = os.path.join(self.save_path, self.name + '_dgl_graph_{}.bin'.format(self.hash)) graph_path = os.path.join(
self.save_path, self.name + "_dgl_graph_{}.bin".format(self.hash)
)
return os.path.exists(graph_path) return os.path.exists(graph_path)
def _random_split(self, x, seed=717, train_size=0.7, val_size=0.1): def _random_split(self, x, seed=717, train_size=0.7, val_size=0.1):
"""split the dataset into training set, validation set and testing set""" """split the dataset into training set, validation set and testing set"""
assert 0 <= train_size + val_size <= 1, \ assert 0 <= train_size + val_size <= 1, (
"The sum of valid training set size and validation set size " \ "The sum of valid training set size and validation set size "
"must between 0 and 1 (inclusive)." "must between 0 and 1 (inclusive)."
)
N = x.shape[0] N = x.shape[0]
index = np.arange(N) index = np.arange(N)
if self.name == 'amazon': if self.name == "amazon":
# 0-3304 are unlabeled nodes # 0-3304 are unlabeled nodes
index = np.arange(3305, N) index = np.arange(3305, N)
index = np.random.RandomState(seed).permutation(index) index = np.random.RandomState(seed).permutation(index)
train_idx = index[:int(train_size * len(index))] train_idx = index[: int(train_size * len(index))]
val_idx = index[len(index) - int(val_size * len(index)):] val_idx = index[len(index) - int(val_size * len(index)) :]
test_idx = index[int(train_size * len(index)):len(index) - int(val_size * len(index))] test_idx = index[
int(train_size * len(index)) : len(index)
- int(val_size * len(index))
]
train_mask = np.zeros(N, dtype=np.bool) train_mask = np.zeros(N, dtype=np.bool)
val_mask = np.zeros(N, dtype=np.bool) val_mask = np.zeros(N, dtype=np.bool)
test_mask = np.zeros(N, dtype=np.bool) test_mask = np.zeros(N, dtype=np.bool)
train_mask[train_idx] = True train_mask[train_idx] = True
val_mask[val_idx] = True val_mask[val_idx] = True
test_mask[test_idx] = True test_mask[test_idx] = True
self.graph.ndata['train_mask'] = F.tensor(train_mask) self.graph.ndata["train_mask"] = F.tensor(train_mask)
self.graph.ndata['val_mask'] = F.tensor(val_mask) self.graph.ndata["val_mask"] = F.tensor(val_mask)
self.graph.ndata['test_mask'] = F.tensor(test_mask) self.graph.ndata["test_mask"] = F.tensor(test_mask)
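The split above shuffles an index array once and then carves it into contiguous train, test and validation blocks; with made-up sizes the slicing works out as follows:

import numpy as np

N, train_size, val_size = 10, 0.7, 0.1
index = np.random.RandomState(717).permutation(np.arange(N))
train_idx = index[: int(train_size * N)]                      # first 7 indices
val_idx = index[N - int(val_size * N):]                       # last 1 index
test_idx = index[int(train_size * N): N - int(val_size * N)]  # middle 2 indices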
class FraudYelpDataset(FraudDataset): class FraudYelpDataset(FraudDataset):
r""" Fraud Yelp Dataset r"""Fraud Yelp Dataset
The Yelp dataset includes hotel and restaurant reviews filtered (spam) and recommended The Yelp dataset includes hotel and restaurant reviews filtered (spam) and recommended
(legitimate) by Yelp. A spam review detection task can be conducted, which is a binary (legitimate) by Yelp. A spam review detection task can be conducted, which is a binary
...@@ -278,20 +302,30 @@ class FraudYelpDataset(FraudDataset): ...@@ -278,20 +302,30 @@ class FraudYelpDataset(FraudDataset):
>>> label = graph.ndata['label'] >>> label = graph.ndata['label']
""" """
def __init__(self, raw_dir=None, random_seed=717, train_size=0.7, def __init__(
val_size=0.1, force_reload=False, verbose=True, transform=None): self,
super(FraudYelpDataset, self).__init__(name='yelp', raw_dir=None,
raw_dir=raw_dir, random_seed=717,
random_seed=random_seed, train_size=0.7,
train_size=train_size, val_size=0.1,
val_size=val_size, force_reload=False,
force_reload=force_reload, verbose=True,
verbose=verbose, transform=None,
transform=transform) ):
super(FraudYelpDataset, self).__init__(
name="yelp",
raw_dir=raw_dir,
random_seed=random_seed,
train_size=train_size,
val_size=val_size,
force_reload=force_reload,
verbose=verbose,
transform=transform,
)
class FraudAmazonDataset(FraudDataset): class FraudAmazonDataset(FraudDataset):
r""" Fraud Amazon Dataset r"""Fraud Amazon Dataset
The Amazon dataset includes product reviews under the Musical Instruments category. The Amazon dataset includes product reviews under the Musical Instruments category.
Users with more than 80% helpful votes are labelled as benign entities and users with Users with more than 80% helpful votes are labelled as benign entities and users with
...@@ -359,13 +393,23 @@ class FraudAmazonDataset(FraudDataset): ...@@ -359,13 +393,23 @@ class FraudAmazonDataset(FraudDataset):
>>> label = graph.ndata['label'] >>> label = graph.ndata['label']
""" """
def __init__(self, raw_dir=None, random_seed=717, train_size=0.7, def __init__(
val_size=0.1, force_reload=False, verbose=True, transform=None): self,
super(FraudAmazonDataset, self).__init__(name='amazon', raw_dir=None,
raw_dir=raw_dir, random_seed=717,
random_seed=random_seed, train_size=0.7,
train_size=train_size, val_size=0.1,
val_size=val_size, force_reload=False,
force_reload=force_reload, verbose=True,
verbose=verbose, transform=None,
transform=transform) ):
super(FraudAmazonDataset, self).__init__(
name="amazon",
raw_dir=raw_dir,
random_seed=random_seed,
train_size=train_size,
val_size=val_size,
force_reload=force_reload,
verbose=verbose,
transform=transform,
)
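Both fraud datasets ultimately hand back one heterograph whose node data carries the feature, label and split masks set in FraudDataset.process(); a rough usage sketch:

from dgl.data import FraudYelpDataset

dataset = FraudYelpDataset()
graph = dataset[0]                # heterograph with a single node type
feature = graph.ndata["feature"]
label = graph.ndata["label"]
train_mask = graph.ndata["train_mask"]
print(graph.canonical_etypes)     # the three 'net_*' relations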
""" GDELT dataset for temporal graph """ """ GDELT dataset for temporal graph """
import numpy as np
import os import os
from .dgl_dataset import DGLBuiltinDataset import numpy as np
from .utils import loadtxt, save_info, load_info, _get_dgl_url
from ..convert import graph as dgl_graph
from .. import backend as F from .. import backend as F
from ..convert import graph as dgl_graph
from .dgl_dataset import DGLBuiltinDataset
from .utils import _get_dgl_url, load_info, loadtxt, save_info
class GDELTDataset(DGLBuiltinDataset): class GDELTDataset(DGLBuiltinDataset):
...@@ -69,23 +70,32 @@ class GDELTDataset(DGLBuiltinDataset): ...@@ -69,23 +70,32 @@ class GDELTDataset(DGLBuiltinDataset):
.... ....
>>> >>>
""" """
def __init__(self, mode='train', raw_dir=None,
force_reload=False, verbose=False, transform=None): def __init__(
self,
mode="train",
raw_dir=None,
force_reload=False,
verbose=False,
transform=None,
):
mode = mode.lower() mode = mode.lower()
assert mode in ['train', 'valid', 'test'], "Mode not valid." assert mode in ["train", "valid", "test"], "Mode not valid."
self.mode = mode self.mode = mode
self.num_nodes = 23033 self.num_nodes = 23033
_url = _get_dgl_url('dataset/gdelt.zip') _url = _get_dgl_url("dataset/gdelt.zip")
super(GDELTDataset, self).__init__(name='GDELT', super(GDELTDataset, self).__init__(
url=_url, name="GDELT",
raw_dir=raw_dir, url=_url,
force_reload=force_reload, raw_dir=raw_dir,
verbose=verbose, force_reload=force_reload,
transform=transform) verbose=verbose,
transform=transform,
)
def process(self): def process(self):
file_path = os.path.join(self.raw_path, self.mode + '.txt') file_path = os.path.join(self.raw_path, self.mode + ".txt")
self.data = loadtxt(file_path, delimiter='\t').astype(np.int64) self.data = loadtxt(file_path, delimiter="\t").astype(np.int64)
# The source code is not released, but the paper indicates there are # The source code is not released, but the paper indicates there are
# 137 samples in total. The cutoff below has exactly 137 samples. # 137 samples in total. The cutoff below has exactly 137 samples.
...@@ -94,25 +104,34 @@ class GDELTDataset(DGLBuiltinDataset): ...@@ -94,25 +104,34 @@ class GDELTDataset(DGLBuiltinDataset):
self._end_time = self.time_index.max() self._end_time = self.time_index.max()
def has_cache(self): def has_cache(self):
info_path = os.path.join(self.save_path, self.mode + '_info.pkl') info_path = os.path.join(self.save_path, self.mode + "_info.pkl")
return os.path.exists(info_path) return os.path.exists(info_path)
def save(self): def save(self):
info_path = os.path.join(self.save_path, self.mode + '_info.pkl') info_path = os.path.join(self.save_path, self.mode + "_info.pkl")
save_info(info_path, {'data': self.data, save_info(
'time_index': self.time_index, info_path,
'start_time': self.start_time, {
'end_time': self.end_time}) "data": self.data,
"time_index": self.time_index,
"start_time": self.start_time,
"end_time": self.end_time,
},
)
def load(self): def load(self):
info_path = os.path.join(self.save_path, self.mode + '_info.pkl') info_path = os.path.join(self.save_path, self.mode + "_info.pkl")
info = load_info(info_path) info = load_info(info_path)
self.data, self.time_index, self._start_time, self._end_time = \ self.data, self.time_index, self._start_time, self._end_time = (
info['data'], info['time_index'], info['start_time'], info['end_time'] info["data"],
info["time_index"],
info["start_time"],
info["end_time"],
)
@property @property
def start_time(self): def start_time(self):
r""" Start time of events in the temporal graph r"""Start time of events in the temporal graph
Returns Returns
------- -------
...@@ -122,7 +141,7 @@ class GDELTDataset(DGLBuiltinDataset): ...@@ -122,7 +141,7 @@ class GDELTDataset(DGLBuiltinDataset):
@property @property
def end_time(self): def end_time(self):
r""" End time of events in the temporal graph r"""End time of events in the temporal graph
Returns Returns
------- -------
...@@ -131,7 +150,7 @@ class GDELTDataset(DGLBuiltinDataset): ...@@ -131,7 +150,7 @@ class GDELTDataset(DGLBuiltinDataset):
return self._end_time return self._end_time
def __getitem__(self, t): def __getitem__(self, t):
r""" Get graph by with events before time `t + self.start_time` r"""Get graph by with events before time `t + self.start_time`
Parameters Parameters
---------- ----------
...@@ -153,7 +172,9 @@ class GDELTDataset(DGLBuiltinDataset): ...@@ -153,7 +172,9 @@ class GDELTDataset(DGLBuiltinDataset):
edges = self.data[row_mask][:, [0, 2]] edges = self.data[row_mask][:, [0, 2]]
rate = self.data[row_mask][:, 1] rate = self.data[row_mask][:, 1]
g = dgl_graph((edges[:, 0], edges[:, 1])) g = dgl_graph((edges[:, 0], edges[:, 1]))
g.edata['rel_type'] = F.tensor(rate.reshape(-1, 1), dtype=F.data_type_dict['int64']) g.edata["rel_type"] = F.tensor(
rate.reshape(-1, 1), dtype=F.data_type_dict["int64"]
)
if self._transform is not None: if self._transform is not None:
g = self._transform(g) g = self._transform(g)
return g return g
...@@ -169,7 +190,7 @@ class GDELTDataset(DGLBuiltinDataset): ...@@ -169,7 +190,7 @@ class GDELTDataset(DGLBuiltinDataset):
@property @property
def is_temporal(self): def is_temporal(self):
r""" Does the dataset contain temporal graphs r"""Does the dataset contain temporal graphs
Returns Returns
------- -------
......
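Finally, for the GDELT temporal dataset above, indexing by an offset from start_time returns a graph of all events up to that time, with the integer relation ids stored in edge data as shown in the diff; an illustrative sketch:

from dgl.data import GDELTDataset

train_data = GDELTDataset(mode="train")
print(train_data.start_time, train_data.end_time)
g = train_data[0]                 # events up to start_time + 0
rel_type = g.edata["rel_type"]    # integer relation ids, shape (num_edges, 1)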