Commit 4bd4d6e3 authored by Minjie Wang, committed by GitHub

[Lint] Pylint (#330)

* fix lint for graph_index.py

* pylint for base.py

* pylint for batched_graph.py

* pylint for frame.py; simplify and fix bugs in frame when index is slice type

* pylint for graph.py

* pylint for immutable_graph_index.py

* pylint for init.py

* pylint for rest files in root package

* pylint for _ffi package

* pylint for function package

* pylint for runtime package

* pylint for runtime.ir package

* add pylint to ci

* fix mx tests

* fix lint errors

* fix ci

* fix as requested

* fix lint
parent 1e50cd2e
"""Package for DGL's internal IR."""
from .executor import *
from .program import get_current_prog, prog
"""Module for executors."""
# pylint: disable=invalid-name
from __future__ import absolute_import
from abc import abstractmethod
import functools
import operator

from ...base import DGLError
from ... import backend as F
from ...frame import FrameRef, Frame
from ... import utils
@@ -14,7 +15,29 @@ from . import var
from .var import VarType
from .registry import IR_REGISTRY
__all__ = [
    'OpCode', 'Executor',
    'NodeUDFExecutor', 'NODE_UDF',
    'EdgeUDFExecutor', 'EDGE_UDF',
    'SPMVExecutor', 'SPMV',
    'SPMVWithDataExecutor', 'SPMV_WITH_DATA',
    'ReadExecutor', 'READ',
    'ReadColExecutor', 'READ_COL',
    'ReadRowExecutor', 'READ_ROW',
    'MergeRowExecutor', 'MERGE_ROW',
    'UpdateDictExecutor', 'UPDATE_DICT',
    'NewDictExecutor', 'NEW_DICT',
    'Write_Executor', 'WRITE_',
    'WriteCol_Executor', 'WRITE_COL_',
    'WriteRow_Executor', 'WRITE_ROW_',
    'WriteDict_Executor', 'WRITE_DICT_',
    'AppendRow_Executor', 'APPEND_ROW_',
    'WriteRowInplace_Executor', 'WRITE_ROW_INPLACE_',
    'ClearFrame_Executor', 'CLEAR_FRAME_',
]
class OpCode(object):
    """Opcode for all the executor types."""
    # immutable op
    NODE_UDF = 0
    EDGE_UDF = 1
@@ -37,23 +60,49 @@ class OpCode(object):
    CLEAR_FRAME_ = 27
class Executor(object):
    """Base executor class.

    An executor is similar to a basic operator in a dataflow-based framework.
    It is evaluated by calling its ``run`` function.
    """
    @abstractmethod
    def opcode(self):
        """Return the opcode of this executor."""
        raise NotImplementedError

    @abstractmethod
    def arg_vars(self):
        """Return the argument variable list of this executor."""
        raise NotImplementedError

    @abstractmethod
    def ret_var(self):
        """Return the result variable of this executor."""
        raise NotImplementedError

    @abstractmethod
    def run(self):
        """Evaluate this executor.

        The function takes no arguments and returns None; all argument
        and result variables must be pre-bound.
        """
        raise NotImplementedError
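As an orientation aid (not part of this commit), a minimal executor satisfying this interface could look like the sketch below; the `IdentityExecutor` name and its opcode value are hypothetical:

class IdentityExecutor(Executor):
    """Hypothetical executor that binds its result to its argument's data."""
    def __init__(self, arg, ret):
        self.arg = arg  # var.Var holding the input
        self.ret = ret  # var.Var to receive the output

    def opcode(self):
        return -1  # illustration only; real executors return an OpCode value

    def arg_vars(self):
        return [self.arg]

    def ret_var(self):
        return self.ret

    def run(self):
        # Both variables are pre-bound, so run() just moves data across.
        self.ret.data = self.arg.data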
class NodeUDFExecutor(Executor):
    """Executor for a node UDF call.

    Parameters
    ----------
    fn : var.Var
        The UDF.
    fdnode : var.Var
        The node feature dict.
    fdmail : var.Var
        The mailbox data dict.
    ret : var.Var
        The variable for the new node feature dict.
    """
    def __init__(self, fn, fdnode, fdmail, ret):
        self.fn = fn
        self.fdnode = fdnode
@@ -88,13 +137,48 @@ IR_REGISTRY[OpCode.NODE_UDF] = {
    'ret_type' : VarType.FEAT_DICT,
    'executor_cls' : NodeUDFExecutor,
}
def NODE_UDF(fn, fdnode, fdmail=None, ret=None):
    """Apply the node UDF and get the new node features symbolically.

    Parameters
    ----------
    fn : var.Var
        The UDF.
    fdnode : var.Var
        The node feature dict.
    fdmail : var.Var
        The mailbox data dict.
    ret : var.Var, optional
        The return variable for the new node feature dict. If not given,
        a new variable will be created.

    Returns
    -------
    var.Var
        Variable for the result.
    """
    reg = IR_REGISTRY[OpCode.NODE_UDF]
    ret = var.new(reg['ret_type']) if ret is None else ret
    get_current_prog().issue(reg['executor_cls'](fn, fdnode, fdmail, ret))
    return ret
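Every value-producing wrapper below repeats the same three steps: look up the registry entry, allocate a fresh result variable unless one is supplied, and issue the executor to the current program. A hypothetical condensation of that shared pattern (the `_issue_value_op` helper does not exist in the code):

def _issue_value_op(opcode, args, ret=None):
    """Hypothetical sketch of the pattern shared by NODE_UDF and friends."""
    reg = IR_REGISTRY[opcode]
    ret = var.new(reg['ret_type']) if ret is None else ret
    get_current_prog().issue(reg['executor_cls'](*(args + (ret,))))
    return ret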
class EdgeUDFExecutor(Executor):
    """Executor for an edge UDF call.

    Parameters
    ----------
    fn : var.Var
        The UDF.
    fdsrc : var.Var
        The src node feature dict.
    fdedge : var.Var
        The edge feature dict.
    fddst : var.Var
        The dst node feature dict.
    ret : var.Var
        The variable for the new edge feature dict.
    """
    def __init__(self, fn, fdsrc, fdedge, fddst, ret):
        self.fn = fn
        self.fdsrc = fdsrc
@@ -126,12 +210,46 @@ IR_REGISTRY[OpCode.EDGE_UDF] = {
    'executor_cls' : EdgeUDFExecutor,
}
def EDGE_UDF(fn, fdsrc, fdedge, fddst, ret=None):
    """Apply the edge UDF and get the new edge features symbolically.

    Parameters
    ----------
    fn : var.Var
        The UDF.
    fdsrc : var.Var
        The src node feature dict.
    fdedge : var.Var
        The edge feature dict.
    fddst : var.Var
        The dst node feature dict.
    ret : var.Var, optional
        The return variable for the new edge feature dict. If not given,
        a new variable will be created.

    Returns
    -------
    var.Var
        Variable for the result.
    """
    reg = IR_REGISTRY[OpCode.EDGE_UDF]
    ret = var.new(reg['ret_type']) if ret is None else ret
    get_current_prog().issue(reg['executor_cls'](fn, fdsrc, fdedge, fddst, ret))
    return ret
class ReadExecutor(Executor):
    """Executor for reading data from a feature dict.

    Parameters
    ----------
    fd : var.Var
        The feature dict.
    row : var.Var
        The row index.
    col : var.Var
        The column name.
    ret : var.Var
        The variable for the result feature tensor.
    """
    def __init__(self, fd, row, col, ret):
        self.fd = fd
        self.row = row
@@ -159,13 +277,43 @@ IR_REGISTRY[OpCode.READ] = {
    'ret_type' : VarType.FEAT,
    'executor_cls' : ReadExecutor,
}
def READ(fd, row, col, ret=None):
    """Read the feature data at the given row and column symbolically.

    Parameters
    ----------
    fd : var.Var
        The feature dict.
    row : var.Var
        The row index.
    col : var.Var
        The column name.
    ret : var.Var, optional
        The return feature tensor. If not given, a new variable will be created.

    Returns
    -------
    var.Var
        Variable for the result.
    """
    reg = IR_REGISTRY[OpCode.READ]
    ret = var.new(reg['ret_type']) if ret is None else ret
    get_current_prog().issue(reg['executor_cls'](fd, row, col, ret))
    return ret
class ReadColExecutor(Executor):
    """Executor for reading column data from a feature dict.

    Parameters
    ----------
    fd : var.Var
        The feature dict.
    col : var.Var
        The column name.
    ret : var.Var
        The variable for the result feature tensor.
    """
    def __init__(self, fd, col, ret):
        self.fd = fd
        self.col = col
@@ -191,13 +339,41 @@ IR_REGISTRY[OpCode.READ_COL] = {
    'ret_type' : VarType.FEAT,
    'executor_cls' : ReadColExecutor,
}
def READ_COL(fd, col, ret=None):
    """Read the column data from the dictionary symbolically.

    Parameters
    ----------
    fd : var.Var
        The feature dict.
    col : var.Var
        The column name.
    ret : var.Var, optional
        The return feature tensor. If not given, a new variable will be created.

    Returns
    -------
    var.Var
        Variable for the result.
    """
    reg = IR_REGISTRY[OpCode.READ_COL]
    ret = var.new(reg['ret_type']) if ret is None else ret
    get_current_prog().issue(reg['executor_cls'](fd, col, ret))
    return ret
class ReadRowExecutor(Executor):
    """Executor for reading row data from a feature dict.

    Parameters
    ----------
    fd : var.Var
        The feature dict.
    row : var.Var
        The row index.
    ret : var.Var
        The variable for the result feature dict.
    """
    def __init__(self, fd, row, ret):
        self.fd = fd
        self.row = row
@@ -223,13 +399,42 @@ IR_REGISTRY[OpCode.READ_ROW] = {
    'ret_type' : VarType.FEAT_DICT,
    'executor_cls' : ReadRowExecutor,
}
def READ_ROW(fd, row, ret=None):
    """Read the row data from the dictionary symbolically.

    Parameters
    ----------
    fd : var.Var
        The feature dict.
    row : var.Var
        The row index.
    ret : var.Var, optional
        The return feature dict. If not given, a new variable will be created.

    Returns
    -------
    var.Var
        Variable for the result.
    """
    reg = IR_REGISTRY[OpCode.READ_ROW]
    ret = var.new(reg['ret_type']) if ret is None else ret
    get_current_prog().issue(reg['executor_cls'](fd, row, ret))
    return ret
class SPMVExecutor(Executor):
    """Executor for sparse-matrix-dense-matrix multiply.

    Parameters
    ----------
    spA : var.Var
        Variable for the sparse matrix lambda. The lambda returns the
        sparse matrix given a context object.
    B : var.Var
        Variable for the dense feature tensor.
    ret : var.Var
        Variable for the result.
    """
    def __init__(self, spA, B, ret):
        self.spA = spA
        self.B = B
@@ -258,8 +463,7 @@ class SPMVExecutor(Executor):
        # Flatten dims 1 and onward.
        B_shape = F.shape(B)
        feat_shape = B_shape[1:]
-        tmp_B_shape = (B_shape[0],
-                       functools.reduce(operator.mul, feat_shape, 1))
        tmp_B_shape = (B_shape[0], functools.reduce(operator.mul, feat_shape, 1))
        B = F.reshape(B, tmp_B_shape)
        C = F.spmm(spA, B)
        C_shape = (F.shape(C)[0],) + feat_shape
@@ -274,13 +478,45 @@ IR_REGISTRY[OpCode.SPMV] = {
    'ret_type' : VarType.FEAT,
    'executor_cls' : SPMVExecutor,
}
def SPMV(spA, B, ret=None):
    """Perform sparse-matrix-dense-matrix multiply symbolically.

    Parameters
    ----------
    spA : var.Var
        Variable for the sparse matrix lambda. The lambda returns the
        sparse matrix given a context object.
    B : var.Var
        Variable for the dense feature tensor.
    ret : var.Var, optional
        Variable for the result. If not given, a new variable will be created.

    Returns
    -------
    var.Var
        Variable for the result.
    """
    reg = IR_REGISTRY[OpCode.SPMV]
    ret = var.new(reg['ret_type']) if ret is None else ret
    get_current_prog().issue(reg['executor_cls'](spA, B, ret))
    return ret
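The executor above flattens all trailing feature dimensions into one so that F.spmm sees 2-D operands, then restores the original shape. The same trick in plain NumPy (shapes assumed for illustration only):

import functools
import operator
import numpy as np

B = np.zeros((4, 2, 3))                    # (num_nodes, d1, d2) feature tensor
feat_shape = B.shape[1:]
flat_dim = functools.reduce(operator.mul, feat_shape, 1)
B2d = B.reshape(B.shape[0], flat_dim)      # (4, 6): ready for a 2-D spmm
A = np.zeros((5, 4))                       # dense stand-in for the sparse matrix
C2d = A @ B2d                              # what F.spmm(spA, B) computes
C = C2d.reshape((C2d.shape[0],) + feat_shape)  # back to (5, 2, 3)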
class SPMVWithDataExecutor(Executor):
    """Executor for sparse-matrix-dense-matrix multiply with provided sparse data.

    Parameters
    ----------
    spA : var.Var
        Variable for the sparse matrix lambda. The lambda returns the
        sparse matrix given a context object.
    A_data : var.Var
        Variable for the sparse matrix data.
    B : var.Var
        Variable for the dense feature tensor.
    ret : var.Var
        Variable for the result.
    """
    def __init__(self, spA, A_data, B, ret):
        self.spA = spA
        self.A_data = A_data
@@ -320,8 +556,7 @@ class SPMVWithDataExecutor(Executor):
        # Flatten dims 1 and onward.
        B_shape = F.shape(B)
        feat_shape = B_shape[1:]
-        tmp_B_shape = (B_shape[0],
-                       functools.reduce(operator.mul, feat_shape, 1))
        tmp_B_shape = (B_shape[0], functools.reduce(operator.mul, feat_shape, 1))
        B = F.reshape(B, tmp_B_shape)
        C = F.spmm(spA, B)
        C_shape = (F.shape(C)[0],) + feat_shape
@@ -336,13 +571,44 @@ IR_REGISTRY[OpCode.SPMV_WITH_DATA] = {
    'ret_type' : VarType.FEAT,
    'executor_cls' : SPMVWithDataExecutor,
}
def SPMV_WITH_DATA(spA, A_data, B, ret=None):
    """Perform sparse-matrix-dense-matrix multiply with sparse data symbolically.

    Parameters
    ----------
    spA : var.Var
        Variable for the sparse matrix lambda. The lambda returns the
        sparse matrix given a context object.
    A_data : var.Var
        Variable for the sparse matrix data.
    B : var.Var
        Variable for the dense feature tensor.
    ret : var.Var, optional
        Variable for the result. If not given, a new variable will be created.

    Returns
    -------
    var.Var
        Variable for the result.
    """
    reg = IR_REGISTRY[OpCode.SPMV_WITH_DATA]
    ret = var.new(reg['ret_type']) if ret is None else ret
    get_current_prog().issue(reg['executor_cls'](spA, A_data, B, ret))
    return ret
class MergeRowExecutor(Executor):
    """Executor for merging row data according to the given order.

    Parameters
    ----------
    order : var.Var
        The order index.
    fd_list : list of var.Var
        The list of row data variables. Each represents a feature dict.
    ret : var.Var
        Variable for the result.
    """
    def __init__(self, order, fd_list, ret):
        self.order = order
        self.fd_list = fd_list
@@ -373,13 +639,43 @@ IR_REGISTRY[OpCode.MERGE_ROW] = {
    'ret_type' : VarType.FEAT_DICT,
    'executor_cls' : MergeRowExecutor,
}
def MERGE_ROW(idx_list, fd_list, ret=None):
    """Merge row data according to the given order symbolically.

    Parameters
    ----------
    idx_list : var.Var
        The order index.
    fd_list : list of var.Var
        The list of row data variables. Each represents a feature dict.
    ret : var.Var, optional
        Variable for the result. If not given, a new variable will be created.

    Returns
    -------
    var.Var
        Variable for the result.
    """
    reg = IR_REGISTRY[OpCode.MERGE_ROW]
    ret = var.new(reg['ret_type']) if ret is None else ret
    get_current_prog().issue(reg['executor_cls'](idx_list, fd_list, ret))
    return ret
class UpdateDictExecutor(Executor):
    """Executor for updating a feature dict with another one.

    Similar to Python dict's ``update`` but returns a new dictionary.

    Parameters
    ----------
    fd1 : var.Var
        Variable for the feature dict to be updated.
    fd2 : var.Var
        Variable for the provided feature dict.
    ret : var.Var
        Variable for the result.
    """
    def __init__(self, fd1, fd2, ret):
        self.fd1 = fd1
        self.fd2 = fd2
@@ -398,7 +694,7 @@ class UpdateDictExecutor(Executor):
        fd1_data = self.fd1.data
        fd2_data = self.fd2.data
        if (isinstance(fd1_data, utils.LazyDict)
                or isinstance(fd2_data, utils.LazyDict)):
            # NOTE: fd2 has higher priority
            ret_data = utils.HybridDict(fd2_data, fd1_data)
        else:
@@ -412,13 +708,45 @@ IR_REGISTRY[OpCode.UPDATE_DICT] = {
    'ret_type' : VarType.FEAT_DICT,
    'executor_cls' : UpdateDictExecutor,
}
def UPDATE_DICT(fd1, fd2, ret=None):
    """Update a feature dict with another one symbolically.

    Similar to Python dict's ``update`` but returns a new dictionary.

    Parameters
    ----------
    fd1 : var.Var
        Variable for the feature dict to be updated.
    fd2 : var.Var
        Variable for the provided feature dict.
    ret : var.Var, optional
        Variable for the result. If not given, a new variable will be created.

    Returns
    -------
    var.Var
        Variable for the result.
    """
    reg = IR_REGISTRY[OpCode.UPDATE_DICT]
    ret = var.new(reg['ret_type']) if ret is None else ret
    get_current_prog().issue(reg['executor_cls'](fd1, fd2, ret))
    return ret
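The semantics mirror a non-mutating dict merge in which fd2 wins on key conflicts; a plain-Python sketch (the executor additionally handles lazy dicts via utils.HybridDict):

d1 = {'h': 1, 'x': 0}
d2 = {'h': 2}
merged = dict(d1)
merged.update(d2)   # {'h': 2, 'x': 0}; d1 itself is untouched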
class NewDictExecutor(Executor):
    """Executor for creating a new feature dict.

    Parameters
    ----------
    fd_init : var.Var
        The feature dict from which to borrow the initializer.
    idx : var.Var
        The index in which to look up the number of rows.
    fd_scheme : var.Var
        The feature dict from which to look up the column scheme.
    ret : var.Var
        Variable for the result.
    """
    def __init__(self, fd_init, idx, fd_scheme, ret):
        self.fd_init = fd_init  # the feat dict to borrow the initializer from
        self.idx = idx  # the index to look up the number of rows
@@ -455,13 +783,45 @@ IR_REGISTRY[OpCode.NEW_DICT] = {
    'ret_type' : VarType.FEAT_DICT,
    'executor_cls' : NewDictExecutor,
}
def NEW_DICT(fd_init, idx, fd_scheme, ret=None):
    """Create a new dictionary symbolically.

    Parameters
    ----------
    fd_init : var.Var
        The feature dict from which to borrow the initializer.
    idx : var.Var
        The index in which to look up the number of rows.
    fd_scheme : var.Var
        The feature dict from which to look up the column scheme.
    ret : var.Var, optional
        Variable for the result. If not given, a new variable will be created.

    Returns
    -------
    var.Var
        Variable for the result.
    """
    reg = IR_REGISTRY[OpCode.NEW_DICT]
    ret = var.new(reg['ret_type']) if ret is None else ret
    get_current_prog().issue(reg['executor_cls'](fd_init, idx, fd_scheme, ret))
    return ret
class Write_Executor(Executor):
    """Executor for writing the given data to the feature dict.

    Parameters
    ----------
    fd : var.Var
        The feature dict.
    row : var.Var
        The row index.
    col : var.Var
        The column name.
    val : var.Var
        The given feature data.
    """
    def __init__(self, fd, row, col, val):
        self.fd = fd
        self.row = row
@@ -490,11 +850,36 @@ IR_REGISTRY[OpCode.WRITE_] = {
    'ret_type' : None,
    'executor_cls' : Write_Executor,
}
def WRITE_(fd, row, col, val):
    """Write the given data to the feature dict symbolically.

    Parameters
    ----------
    fd : var.Var
        The feature dict.
    row : var.Var
        The row index.
    col : var.Var
        The column name.
    val : var.Var
        The given feature data.
    """
    reg = IR_REGISTRY[OpCode.WRITE_]
    get_current_prog().issue(reg['executor_cls'](fd, row, col, val))
class WriteCol_Executor(Executor):
    """Executor for writing the given column data to the feature dict.

    Parameters
    ----------
    fd : var.Var
        The feature dict.
    col : var.Var
        The column name.
    val : var.Var
        The given feature data.
    """
    def __init__(self, fd, col, val):
        self.fd = fd
        self.col = col
@@ -521,11 +906,34 @@ IR_REGISTRY[OpCode.WRITE_COL_] = {
    'ret_type' : None,
    'executor_cls' : WriteCol_Executor,
}
def WRITE_COL_(fd, col, val):
    """Write the given column data to the feature dict symbolically.

    Parameters
    ----------
    fd : var.Var
        The feature dict.
    col : var.Var
        The column name.
    val : var.Var
        The given feature data.
    """
    reg = IR_REGISTRY[OpCode.WRITE_COL_]
    get_current_prog().issue(reg['executor_cls'](fd, col, val))
class WriteRow_Executor(Executor):
    """Executor for writing the given row data to the feature dict.

    Parameters
    ----------
    fd : var.Var
        The feature dict.
    row : var.Var
        The row index.
    val : var.Var
        The given feature data.
    """
    def __init__(self, fd, row, val):
        self.fd = fd
        self.row = row
@@ -552,11 +960,34 @@ IR_REGISTRY[OpCode.WRITE_ROW_] = {
    'ret_type' : None,
    'executor_cls' : WriteRow_Executor,
}
def WRITE_ROW_(fd, row, val):
    """Write the given row data to the feature dict symbolically.

    Parameters
    ----------
    fd : var.Var
        The feature dict.
    row : var.Var
        The row index.
    val : var.Var
        The given feature data.
    """
    reg = IR_REGISTRY[OpCode.WRITE_ROW_]
    get_current_prog().issue(reg['executor_cls'](fd, row, val))
class WriteRowInplace_Executor(Executor):
    """Executor for writing the given row data to the feature dict in-place.

    Parameters
    ----------
    fd : var.Var
        The feature dict.
    row : var.Var
        The row index.
    val : var.Var
        The given feature data.
    """
    def __init__(self, fd, row, val):
        self.fd = fd
        self.row = row
@@ -575,7 +1006,7 @@ class WriteRowInplace_Executor(Executor):
        fd_data = self.fd.data  # feature dict
        row_data = self.row.data  # idx
        val_data = self.val.data
-        fd_data.set_item_inplace(row_data, val_data, inplace=True)
        fd_data.update_data(row_data, val_data, inplace=True)
IR_REGISTRY[OpCode.WRITE_ROW_INPLACE_] = {
    'name' : 'WRITE_ROW_INPLACE_',
@@ -585,10 +1016,30 @@ IR_REGISTRY[OpCode.WRITE_ROW_INPLACE_] = {
}
def WRITE_ROW_INPLACE_(fd, row, val):
    """Write the given row data to the feature dict in-place, symbolically.

    Parameters
    ----------
    fd : var.Var
        The feature dict.
    row : var.Var
        The row index.
    val : var.Var
        The given feature data.
    """
    reg = IR_REGISTRY[OpCode.WRITE_ROW_INPLACE_]
    get_current_prog().issue(reg['executor_cls'](fd, row, val))
class WriteDict_Executor(Executor):
    """Executor for writing the given feature dict data into another one.

    Parameters
    ----------
    fd1 : var.Var
        The feature dict to be mutated.
    fd2 : var.Var
        The feature dict data.
    """
    def __init__(self, fd1, fd2):
        self.fd1 = fd1
        self.fd2 = fd2
@@ -614,11 +1065,30 @@ IR_REGISTRY[OpCode.WRITE_DICT_] = {
    'ret_type' : None,
    'executor_cls' : WriteDict_Executor,
}
def WRITE_DICT_(fd1, fd2):
    """Write the given feature dict data into another one symbolically.

    Parameters
    ----------
    fd1 : var.Var
        The feature dict to be mutated.
    fd2 : var.Var
        The feature dict data.
    """
    reg = IR_REGISTRY[OpCode.WRITE_DICT_]
    get_current_prog().issue(reg['executor_cls'](fd1, fd2))
class AppendRow_Executor(Executor):
    """Executor for appending one feature dict to another.

    Parameters
    ----------
    fd1 : var.Var
        The feature dict in the front.
    fd2 : var.Var
        The feature dict in the back.
    """
    def __init__(self, fd1, fd2):
        self.fd1 = fd1
        self.fd2 = fd2
@@ -644,10 +1114,26 @@ IR_REGISTRY[OpCode.APPEND_ROW_] = {
    'executor_cls' : AppendRow_Executor,
}
def APPEND_ROW_(fd1, fd2):
    """Append one feature dict to another symbolically.

    Parameters
    ----------
    fd1 : var.Var
        The feature dict in the front.
    fd2 : var.Var
        The feature dict in the back.
    """
    reg = IR_REGISTRY[OpCode.APPEND_ROW_]
    get_current_prog().issue(reg['executor_cls'](fd1, fd2))
class ClearFrame_Executor(Executor):
    """Executor for clearing the feature dict.

    Parameters
    ----------
    fd : var.Var
        The feature dict to be cleared.
    """
    def __init__(self, fd):
        self.fd = fd
@@ -672,6 +1158,14 @@ IR_REGISTRY[OpCode.CLEAR_FRAME_] = {
    'ret_type': None,
    'executor_cls': ClearFrame_Executor,
}
def CLEAR_FRAME_(fd):
    """Clear the feature dict symbolically.

    Parameters
    ----------
    fd : var.Var
        The feature dict to be cleared.
    """
    reg = IR_REGISTRY[OpCode.CLEAR_FRAME_]
    get_current_prog().issue(reg['executor_cls'](fd))
"""Module for program."""
from __future__ import absolute_import from __future__ import absolute_import
from contextlib import contextmanager from contextlib import contextmanager
...@@ -5,15 +6,26 @@ from contextlib import contextmanager ...@@ -5,15 +6,26 @@ from contextlib import contextmanager
from .registry import IR_REGISTRY from .registry import IR_REGISTRY
class Prog(object): class Prog(object):
"""The program.""" """The program.
A program is simply a list of executors.
"""
def __init__(self): def __init__(self):
self.execs = [] self.execs = []
self.varcount = 0 self.varcount = 0
def issue(self, exe): def issue(self, exe):
"""Issue an executor to this program.
Parameters
----------
exe : Executor
The executor.
"""
self.execs.append(exe) self.execs.append(exe)
def pprint_exe(self, exe): def pprint_exe(self, exe):
"""Internal function to pretty-print the executor."""
argstr = ', '.join([str(av) for av in exe.arg_vars()]) argstr = ', '.join([str(av) for av in exe.arg_vars()])
if exe.ret_var() is None: if exe.ret_var() is None:
# stmt # stmt
...@@ -28,21 +40,26 @@ class Prog(object): ...@@ -28,21 +40,26 @@ class Prog(object):
argstr)) argstr))
def pprint(self): def pprint(self):
"""Pretty-print the program."""
for exe in self.execs: for exe in self.execs:
self.pprint_exe(exe) self.pprint_exe(exe)
_current_prog = None # current program
CURRENT_PROG = None
def get_current_prog(): def get_current_prog():
global _current_prog """Get the current program."""
return _current_prog global CURRENT_PROG
return CURRENT_PROG
def set_current_prog(prog): def set_current_prog(program):
global _current_prog """Set the current program."""
_current_prog = prog global CURRENT_PROG
CURRENT_PROG = program
@contextmanager @contextmanager
def prog(): def prog():
"""A context manager to create a new program."""
set_current_prog(Prog()) set_current_prog(Prog())
yield get_current_prog() yield get_current_prog()
set_current_prog(None) set_current_prog(None)
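A usage sketch (import paths assumed for this era of DGL; the exact printed format follows pprint_exe above):

from dgl.runtime import ir
from dgl.runtime.ir import var

with ir.prog() as p:
    fd = var.FEAT_DICT(name='nf')       # symbolic feature dict, no data bound
    ret = ir.READ_COL(fd, var.STR('h', name='field'))   # issues an executor
    p.pprint()                          # e.g. prints: _z0 = READ_COL(nf, "h")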
"""Module for variables."""
# pylint: disable=invalid-name
from __future__ import absolute_import from __future__ import absolute_import
from .program import get_current_prog from .program import get_current_prog
class VarType(object): class VarType(object):
"""Variable types."""
# Types for symbolic objects (i.e, they might not be # Types for symbolic objects (i.e, they might not be
# concretized before evaluation. # concretized before evaluation.
FEAT = 0 FEAT = 0
...@@ -23,47 +26,65 @@ VAR_TYPE_NAME_MAP = [ ...@@ -23,47 +26,65 @@ VAR_TYPE_NAME_MAP = [
] ]
class Var(object): class Var(object):
"""Variable """Class for variables in IR.
Variables represent data in the IR. A variable can contain concrete values.
Otherwise, it can act as a "symbol", whose values are not materialized at the
moment, but later.
Parameters
----------
name : str name : str
The variable name.
type : int type : int
The type code.
data : any, default=None (not concretized) data : any, default=None (not concretized)
The data.
""" """
__slots__ = ['name', 'type', 'data'] __slots__ = ['name', 'typecode', 'data']
def __init__(self, name, type, data): def __init__(self, name, typecode, data):
self.name = name self.name = name
self.type = type self.typecode = typecode
self.data = data self.data = data
def __str__(self): def __str__(self):
if self.type == VarType.STR: if self.typecode == VarType.STR:
return '"%s"' % self.data return '"%s"' % self.data
else: else:
return self.name return self.name
def typestr(self): def typestr(self):
return VAR_TYPE_NAME_MAP[self.type] """Return the type string of this variable."""
return VAR_TYPE_NAME_MAP[self.typecode]
def new(type, data=None, name=None): def new(typecode, data=None, name=None):
"""Create a new variable."""
if name is None: if name is None:
cur_prog = get_current_prog() cur_prog = get_current_prog()
name = '_z%d' % cur_prog.varcount name = '_z%d' % cur_prog.varcount
cur_prog.varcount += 1 cur_prog.varcount += 1
return Var(name, type, data) return Var(name, typecode, data)
def FEAT(data=None, name=None): def FEAT(data=None, name=None):
"""Create a variable for feature tensor."""
return new(VarType.FEAT, data, name) return new(VarType.FEAT, data, name)
def FEAT_DICT(data=None, name=None): def FEAT_DICT(data=None, name=None):
"""Create a variable for feature dict."""
return new(VarType.FEAT_DICT, data, name) return new(VarType.FEAT_DICT, data, name)
def SPMAT(data=None, name=None): def SPMAT(data=None, name=None):
"""Create a variable for sparse matrix lambda."""
return new(VarType.SPMAT, data, name) return new(VarType.SPMAT, data, name)
def IDX(data=None, name=None): def IDX(data=None, name=None):
"""Create a variable for index."""
return new(VarType.IDX, data, name) return new(VarType.IDX, data, name)
def STR(data=None, name=None): def STR(data=None, name=None):
"""Create a variable for string value."""
return new(VarType.STR, data, name) return new(VarType.STR, data, name)
def FUNC(data=None, name=None): def FUNC(data=None, name=None):
"""Create a variable for function."""
return new(VarType.FUNC, data, name) return new(VarType.FUNC, data, name)
"""DGL mini-runtime.""" """DGL mini-runtime."""
class Runtime(object): class Runtime(object):
"""The mini runtime class."""
@staticmethod @staticmethod
def run(prog): def run(prog):
"""Run the given program."""
for exe in prog.execs: for exe in prog.execs:
#prog.pprint_exe(exe) #prog.pprint_exe(exe)
exe.run() exe.run()
@@ -3,27 +3,27 @@ from __future__ import absolute_import
from .. import utils
from .._ffi.function import _init_api
-from ..base import ALL, DGLError, is_all
from ..base import DGLError
from .. import backend as F
from ..frame import frame_like, FrameRef
from ..function.base import BuiltinFunction, BundledFunction
from ..udf import EdgeBatch, NodeBatch
from . import ir
-from .ir import var as var
from .ir import var
from . import degree_bucketing as db
from . import spmv

__all__ = [
    "schedule_send",
    "schedule_recv",
    "schedule_update_all",
    "schedule_snr",
    "schedule_apply_nodes",
    "schedule_apply_edges",
    "schedule_push",
    "schedule_pull"
]
def schedule_send(graph, u, v, eid, message_func):
    """Get the send schedule
@@ -132,7 +132,6 @@ def schedule_snr(graph,
    inplace: bool
        If True, the update will be done in place
    """
-    call_type = 'send_and_recv'
    u, v, eid = edge_tuples
    recv_nodes, _ = F.sort_1d(F.unique(v.tousertensor()))
    recv_nodes = utils.toindex(recv_nodes)
@@ -143,13 +142,12 @@ def schedule_snr(graph,
    var_eid = var.IDX(eid)
    var_recv_nodes = var.IDX(recv_nodes, name='recv_nodes')
    # generate send and reduce schedule
-    uv_getter = lambda : (var_u, var_v)
-    adj_creator = lambda : spmv.build_adj_matrix_uv(graph, (u, v), recv_nodes)
-    inc_creator = lambda : spmv.build_inc_matrix_dst(v, recv_nodes)
    uv_getter = lambda: (var_u, var_v)
    adj_creator = lambda: spmv.build_adj_matrix_uv(graph, (u, v), recv_nodes)
    inc_creator = lambda: spmv.build_inc_matrix_dst(v, recv_nodes)
-    reduced_feat = _gen_send_reduce(
-        graph, message_func, reduce_func,
-        var_eid, var_recv_nodes,
-        uv_getter, adj_creator, inc_creator)
    reduced_feat = _gen_send_reduce(graph, message_func, reduce_func,
                                    var_eid, var_recv_nodes,
                                    uv_getter, adj_creator, inc_creator)
    # generate apply schedule
    final_feat = _apply_with_accum(graph, var_recv_nodes, var_nf, reduced_feat, apply_func)
    if inplace:
@@ -180,7 +178,6 @@ def schedule_update_all(graph,
        nodes = utils.toindex(slice(0, graph.number_of_nodes()))
        schedule_apply_nodes(graph, nodes, apply_func, inplace=False)
    else:
-        call_type = 'update_all'
        eid = utils.toindex(slice(0, graph.number_of_edges()))  # shortcut for ALL
        recv_nodes = utils.toindex(slice(0, graph.number_of_nodes()))  # shortcut for ALL
        # create vars
@@ -191,12 +188,11 @@ def schedule_update_all(graph,
        def uv_getter():
            src, dst, _ = graph._graph.edges()
            return var.IDX(src), var.IDX(dst)
-        adj_creator = lambda : spmv.build_adj_matrix_graph(graph)
-        inc_creator = lambda : spmv.build_inc_matrix_graph(graph)
        adj_creator = lambda: spmv.build_adj_matrix_graph(graph)
        inc_creator = lambda: spmv.build_inc_matrix_graph(graph)
-        reduced_feat = _gen_send_reduce(
-            graph, message_func, reduce_func,
-            var_eid, var_recv_nodes,
-            uv_getter, adj_creator, inc_creator)
        reduced_feat = _gen_send_reduce(graph, message_func, reduce_func,
                                        var_eid, var_recv_nodes,
                                        uv_getter, adj_creator, inc_creator)
        # generate optional apply
        final_feat = _apply_with_accum(graph, var_recv_nodes, var_nf, reduced_feat, apply_func)
        ir.WRITE_DICT_(var_nf, final_feat)
@@ -226,8 +222,8 @@ def schedule_apply_nodes(graph,
    var_v = var.IDX(v)
    v_nf = ir.READ_ROW(var_nf, var_v)
    def _afunc_wrapper(node_data):
-        nb = NodeBatch(graph, v, node_data)
-        return apply_func(nb)
        nbatch = NodeBatch(graph, v, node_data)
        return apply_func(nbatch)
    afunc = var.FUNC(_afunc_wrapper)
    applied_feat = ir.NODE_UDF(afunc, v_nf)
    if inplace:
@@ -271,9 +267,8 @@ def schedule_apply_edges(graph,
    fddst = ir.READ_ROW(var_nf, var_v)
    fdedge = ir.READ_ROW(var_ef, var_eid)
    def _efunc_wrapper(src_data, edge_data, dst_data):
-        eb = EdgeBatch(graph, (u, v, eid),
-                       src_data, edge_data, dst_data)
-        return apply_func(eb)
        ebatch = EdgeBatch(graph, (u, v, eid), src_data, edge_data, dst_data)
        return apply_func(ebatch)
    _efunc = var.FUNC(_efunc_wrapper)
    new_fdedge = ir.EDGE_UDF(_efunc, fdsrc, fdedge, fddst)
    if inplace:
@@ -343,7 +338,6 @@ def schedule_pull(graph,
        if apply_func is not None:
            schedule_apply_nodes(graph, pull_nodes, apply_func, inplace)
    else:
-        call_type = 'send_and_recv'
        pull_nodes, _ = F.sort_1d(F.unique(pull_nodes.tousertensor()))
        pull_nodes = utils.toindex(pull_nodes)
        # create vars
@@ -353,13 +347,12 @@ def schedule_pull(graph,
        var_v = var.IDX(v)
        var_eid = var.IDX(eid)
        # generate send and reduce schedule
-        uv_getter = lambda : (var_u, var_v)
-        adj_creator = lambda : spmv.build_adj_matrix_uv(graph, (u, v), pull_nodes)
-        inc_creator = lambda : spmv.build_inc_matrix_dst(v, pull_nodes)
        uv_getter = lambda: (var_u, var_v)
        adj_creator = lambda: spmv.build_adj_matrix_uv(graph, (u, v), pull_nodes)
        inc_creator = lambda: spmv.build_inc_matrix_dst(v, pull_nodes)
-        reduced_feat = _gen_send_reduce(
-            graph, message_func, reduce_func,
-            var_eid, var_pull_nodes,
-            uv_getter, adj_creator, inc_creator)
        reduced_feat = _gen_send_reduce(graph, message_func, reduce_func,
                                        var_eid, var_pull_nodes,
                                        uv_getter, adj_creator, inc_creator)
        # generate optional apply
        final_feat = _apply_with_accum(graph, var_pull_nodes, var_nf, reduced_feat, apply_func)
        if inplace:
@@ -423,8 +416,8 @@ def _apply_with_accum(graph, var_nodes, var_nf, var_accum, apply_func):
    v_nf = ir.READ_ROW(var_nf, var_nodes)
    v_nf = ir.UPDATE_DICT(v_nf, var_accum)
    def _afunc_wrapper(node_data):
-        nb = NodeBatch(graph, var_nodes.data, node_data)
-        return apply_func(nb)
        nbatch = NodeBatch(graph, var_nodes.data, node_data)
        return apply_func(nbatch)
    afunc = var.FUNC(_afunc_wrapper)
    applied_feat = ir.NODE_UDF(afunc, v_nf)
    final_feat = ir.UPDATE_DICT(var_accum, applied_feat)
@@ -439,7 +432,6 @@ def _gen_reduce(graph, reduce_func, edge_tuples, recv_nodes):
    edge_tuples : tuple of utils.Index
    recv_nodes : utils.Index
    """
-    call_type = "recv"
    _, dst, eid = edge_tuples
    rfunc = _standardize_func_usage(reduce_func, 'reduce')
    rfunc_is_list = utils.is_iterable(rfunc)
@@ -451,9 +443,9 @@ def _gen_reduce(graph, reduce_func, edge_tuples, recv_nodes):
    tmpframe = FrameRef(frame_like(graph._node_frame._frame, len(recv_nodes)))

    # vars
-    msg = var.FEAT_DICT(graph._msg_frame, 'msg')
-    nf = var.FEAT_DICT(graph._node_frame, 'nf')
-    out = var.FEAT_DICT(data=tmpframe)
    var_msg = var.FEAT_DICT(graph._msg_frame, 'msg')
    var_nf = var.FEAT_DICT(graph._node_frame, 'nf')
    var_out = var.FEAT_DICT(data=tmpframe)

    if rfunc_is_list:
        # UDF message + builtin reducer
@@ -461,19 +453,19 @@ def _gen_reduce(graph, reduce_func, edge_tuples, recv_nodes):
        spmv_rfunc, rfunc = spmv.analyze_e2v_spmv(graph, rfunc)
        inc = spmv.build_inc_matrix_eid(graph._msg_frame.num_rows, eid, dst,
                                        recv_nodes)
-        spmv.gen_e2v_spmv_schedule(inc, spmv_rfunc, msg, out)
        spmv.gen_e2v_spmv_schedule(inc, spmv_rfunc, var_msg, var_out)

        if len(rfunc) == 0:
            # All mfunc and rfunc have been processed.
-            return out
            return var_out

        # convert the remaining rfunc to UDFs
        rfunc = BundledFunction(rfunc)

    # gen degree bucketing schedule for UDF recv
-    db.gen_degree_bucketing_schedule(graph, rfunc, eid, dst,
-                                     recv_nodes, nf, msg, out)
    db.gen_degree_bucketing_schedule(graph, rfunc, eid, dst,
                                     recv_nodes, var_nf, var_msg, var_out)
-    return out
    return var_out
def _gen_send_reduce(
    graph,
@@ -573,19 +565,19 @@ def _gen_send_reduce(
    # gen degree bucketing schedule for UDF recv
    mid = utils.toindex(slice(0, len(var_v.data)))  # message id is from 0~|dst|
-    db.gen_degree_bucketing_schedule(graph, rfunc,
-                                     mid, var_v.data, reduce_nodes,
-                                     var_nf, var_mf, var_out)
    db.gen_degree_bucketing_schedule(
        graph, rfunc, mid, var_v.data, reduce_nodes, var_nf, var_mf, var_out)
    return var_out
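For context, degree bucketing groups receiver nodes by in-degree so the reduce UDF always sees a rectangular (num_nodes, degree, ...) mailbox; a conceptual NumPy sketch, not the actual scheduler code:

import numpy as np

dst = np.array([0, 1, 1, 2, 2, 2])   # destination node of each message
degs = np.bincount(dst)              # in-degree per receiver node
for d in np.unique(degs[degs > 0]):
    nodes = np.flatnonzero(degs == d)
    # every node in `nodes` receives exactly d messages, so their
    # mailboxes stack into one dense (len(nodes), d, ...) tensor
    print(d, nodes)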
-def _gen_send(graph, nf, ef, u, v, eid, mfunc):
def _gen_send(graph, nfr, efr, u, v, eid, mfunc):
    """Internal function to generate the send schedule."""
-    fdsrc = ir.READ_ROW(nf, u)
-    fddst = ir.READ_ROW(nf, v)
-    fdedge = ir.READ_ROW(ef, eid)
    fdsrc = ir.READ_ROW(nfr, u)
    fddst = ir.READ_ROW(nfr, v)
    fdedge = ir.READ_ROW(efr, eid)
    def _mfunc_wrapper(src_data, edge_data, dst_data):
-        eb = EdgeBatch(graph, (u.data, v.data, eid.data),
-                       src_data, edge_data, dst_data)
-        return mfunc(eb)
        ebatch = EdgeBatch(graph, (u.data, v.data, eid.data),
                           src_data, edge_data, dst_data)
        return mfunc(ebatch)
    _mfunc_wrapper = var.FUNC(_mfunc_wrapper)
    msg = ir.EDGE_UDF(_mfunc_wrapper, fdsrc, fdedge, fddst)
    return msg
...
@@ -6,7 +6,7 @@ from .. import backend as F
from .. import utils

from . import ir
-from .ir import var as var
from .ir import var


def analyze_v2v_spmv(graph, mfunc, rfunc):
    """Analyze if SPMV from node space to node space can be applied.
@@ -54,7 +54,7 @@ def analyze_v2v_spmv(graph, mfunc, rfunc):
    return spmv_pairs, mfunc_left, rfunc_left


-def analyze_e2v_spmv(graph, rfunc):
def analyze_e2v_spmv(graph, rfunc):  # pylint: disable=unused-argument
    """Analyze if SPMV from edge space to node space can be applied.

    Parameters
@@ -80,16 +80,16 @@ def analyze_e2v_spmv(graph, rfunc):
            rfunc_left.append(rfn)
    return spmv_rfunc, rfunc_left


-def gen_v2v_spmv_schedule(adj, spmv_pairs, nf, ef, eid, out):
def gen_v2v_spmv_schedule(adj, spmv_pairs, nft, eft, eid, out):
    """Generate v2v spmv schedule.

    Parameters
    ----------
    adj : tuple (sparse matrix, utils.Index)
    spmv_pairs : list of pair
-    nf : var.Var
    nft : var.Var
        input node features
-    ef : var.Var
    eft : var.Var
        input edge features
    eid : var.Var
        eid index
@@ -103,16 +103,16 @@ def gen_v2v_spmv_schedule(adj, spmv_pairs, nft, eft, eid, out):
        eid = var.IDX(new_eid)
    for mfn, rfn in spmv_pairs:
        if mfn.use_edge_feature:
-            ftedge = ir.READ(ef, eid, var.STR(mfn.edge_field))
-            ftsrc = ir.READ_COL(nf, var.STR(mfn.src_field))
            ftedge = ir.READ(eft, eid, var.STR(mfn.edge_field))
            ftsrc = ir.READ_COL(nft, var.STR(mfn.src_field))
            ftdst = ir.SPMV_WITH_DATA(adj_var, ftedge, ftsrc)
        else:
-            ftsrc = ir.READ_COL(nf, var.STR(mfn.src_field))
            ftsrc = ir.READ_COL(nft, var.STR(mfn.src_field))
            ftdst = ir.SPMV(adj_var, ftsrc)
        # save for merge
        ir.WRITE_COL_(out, var.STR(rfn.out_field), ftdst)
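Conceptually, each (message, reduce) pair lowered here is one dense product h_out = A · h_src, with per-edge data folded into A's nonzeros in the SPMV_WITH_DATA case; a NumPy sketch using a dense stand-in for the sparse adjacency:

import numpy as np

A = np.array([[0., 1., 1.],          # A[dst, src] = 1 for each edge
              [0., 0., 1.],
              [0., 0., 0.]], dtype=np.float32)
h_src = np.ones((3, 4), dtype=np.float32)
h_out = A @ h_src                    # copy_src + sum lowered as one SPMV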
-def gen_e2v_spmv_schedule(inc, spmv_rfunc, mf, out):
def gen_e2v_spmv_schedule(inc, spmv_rfunc, mfr, out):
    """Generate e2v SPMV schedule.

    Parameters
@@ -127,7 +127,7 @@ def gen_e2v_spmv_schedule(inc, spmv_rfunc, mfr, out):
    incmat, _ = inc
    inc_var = var.SPMAT(incmat)
    for rfn in spmv_rfunc:
-        ftmsg = ir.READ_COL(mf, var.STR(rfn.msg_field))
        ftmsg = ir.READ_COL(mfr, var.STR(rfn.msg_field))
        ftdst = ir.SPMV(inc_var, ftmsg)
        ir.WRITE_COL_(out, var.STR(rfn.out_field), ftdst)
@@ -147,9 +147,9 @@ def build_adj_matrix_graph(graph):
        A index for data shuffling due to sparse format change. Return None
        if shuffle is not required.
    """
-    gi = graph._graph
-    _, shuffle_idx = gi.adjacency_matrix(False, F.cpu())
-    return lambda ctx : gi.adjacency_matrix(False, ctx)[0], shuffle_idx
    gidx = graph._graph
    _, shuffle_idx = gidx.adjacency_matrix(False, F.cpu())
    return lambda ctx: gidx.adjacency_matrix(False, ctx)[0], shuffle_idx
def _build_adj_matrix_index_uv(graph, edges, reduce_nodes):
    """Build adj matrix index and shape using the given (u, v) edges.
@@ -180,7 +180,7 @@ def _build_adj_matrix_index_uv(graph, edges, reduce_nodes):
        The dense shape.
    """
    # TODO(minjie): add node frontier for this
-    new2old, old2new = utils.build_relabel_map(reduce_nodes, sorted=True)
    _, old2new = utils.build_relabel_map(reduce_nodes, is_sorted=True)
    u, v = edges
    u = u.tousertensor()
    v = v.tousertensor()
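build_relabel_map compresses the receiver ids into a dense 0..n-1 range before the sparse index is built; the assumed semantics, sketched in plain Python:

import numpy as np

reduce_nodes = np.array([2, 5, 7])                # sorted unique receiver ids
old2new = {int(o): i for i, o in enumerate(reduce_nodes)}
v = np.array([5, 2, 7, 5])                        # edge destinations
v_relabeled = np.array([old2new[x] for x in v])   # -> [1, 0, 2, 1]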
@@ -218,13 +218,13 @@ def build_adj_matrix_uv(graph, edges, reduce_nodes):
        if shuffle is not required.
    """
    sp_idx, shape = _build_adj_matrix_index_uv(graph, edges, reduce_nodes)
-    u, v = edges
    u, _ = edges
    nnz = len(u)
    # FIXME(minjie): data type
    dat = F.ones((nnz,), dtype=F.float32, ctx=F.cpu())
    mat, shuffle_idx = F.sparse_matrix(dat, sp_idx, shape)
    shuffle_idx = utils.toindex(shuffle_idx) if shuffle_idx is not None else None
-    return utils.CtxCachedObject(lambda ctx : F.copy_to(mat, ctx)), shuffle_idx
    return utils.CtxCachedObject(lambda ctx: F.copy_to(mat, ctx)), shuffle_idx
def build_inc_matrix_graph(graph):
    """Build incidence matrix.
@@ -242,16 +242,16 @@ def build_inc_matrix_graph(graph):
        A index for data shuffling due to sparse format change. Return None
        if shuffle is not required.
    """
-    gi = graph._graph
    gidx = graph._graph
    # inc mat will not use data tensor so conversion index is not needed
-    return lambda ctx : gi.incidence_matrix('in', ctx)[0], None
    return lambda ctx: gidx.incidence_matrix('in', ctx)[0], None
def build_inc_matrix_eid(m, eid, dst, reduce_nodes):
    """Build incidence matrix using edge id and edge dst nodes.

    The incidence matrix is of shape (n, m), where n=len(reduce_nodes).
    The nnz is equal to len(eid).

    Invariant: len(eid) == len(dst)

    The dst nodes will be sorted in the *unique-ascending* order of
@@ -296,7 +296,7 @@ def build_inc_matrix_eid(m, eid, dst, reduce_nodes):
        A index for data shuffling due to sparse format change. Return None
        if shuffle is not required.
    """
-    new2old, old2new = utils.build_relabel_map(reduce_nodes, sorted=True)
    _, old2new = utils.build_relabel_map(reduce_nodes, is_sorted=True)
    dst = dst.tousertensor()
    eid = eid.tousertensor()
    # relabel edges dsts
@@ -311,7 +311,7 @@ def build_inc_matrix_eid(m, eid, dst, reduce_nodes):
    dat = F.ones((nnz,), dtype=F.float32, ctx=F.cpu())
    mat, _ = F.sparse_matrix(dat, ('coo', idx), (n, m))
    # inc mat will not use data tensor so conversion index is not needed
-    return utils.CtxCachedObject(lambda ctx : F.copy_to(mat, ctx)), None
    return utils.CtxCachedObject(lambda ctx: F.copy_to(mat, ctx)), None
def build_inc_matrix_dst(dst, reduce_nodes):
    """Build incidence matrix using only edge destinations.
@@ -332,7 +332,7 @@ def build_inc_matrix_dst(dst, reduce_nodes):
          [0, 0, 0, 0, 0],
          [0, 0, 1, 0, 0],
          [0, 0, 0, 1, 1]], shape=(5, 5))

    Parameters
    ----------
    dst : utils.Index
...
"""Class for subgraph data structure.""" """Class for subgraph data structure."""
from __future__ import absolute_import from __future__ import absolute_import
import networkx as nx
from . import backend as F
from .frame import Frame, FrameRef from .frame import Frame, FrameRef
from .graph import DGLGraph from .graph import DGLGraph
from . import utils from . import utils
...@@ -47,22 +44,24 @@ class DGLSubGraph(DGLGraph): ...@@ -47,22 +44,24 @@ class DGLSubGraph(DGLGraph):
def __init__(self, parent, parent_nid, parent_eid, graph_idx, shared=False): def __init__(self, parent, parent_nid, parent_eid, graph_idx, shared=False):
super(DGLSubGraph, self).__init__(graph_data=graph_idx, super(DGLSubGraph, self).__init__(graph_data=graph_idx,
readonly=graph_idx.is_readonly()) readonly=graph_idx.is_readonly())
if shared:
raise DGLError('Shared mode is not yet supported.')
self._parent = parent self._parent = parent
self._parent_nid = parent_nid self._parent_nid = parent_nid
self._parent_eid = parent_eid self._parent_eid = parent_eid
# override APIs # override APIs
def add_nodes(self, num, reprs=None): def add_nodes(self, num, data=None):
"""Add nodes. Disabled because BatchedDGLGraph is read-only.""" """Add nodes. Disabled because BatchedDGLGraph is read-only."""
raise RuntimeError('Readonly graph. Mutation is not allowed.') raise DGLError('Readonly graph. Mutation is not allowed.')
def add_edge(self, u, v, reprs=None): def add_edge(self, u, v, data=None):
"""Add one edge. Disabled because BatchedDGLGraph is read-only.""" """Add one edge. Disabled because BatchedDGLGraph is read-only."""
raise RuntimeError('Readonly graph. Mutation is not allowed.') raise DGLError('Readonly graph. Mutation is not allowed.')
def add_edges(self, u, v, reprs=None): def add_edges(self, u, v, data=None):
"""Add many edges. Disabled because BatchedDGLGraph is read-only.""" """Add many edges. Disabled because BatchedDGLGraph is read-only."""
raise RuntimeError('Readonly graph. Mutation is not allowed.') raise DGLError('Readonly graph. Mutation is not allowed.')
@property @property
def parent_nid(self): def parent_nid(self):
...@@ -110,10 +109,10 @@ class DGLSubGraph(DGLGraph): ...@@ -110,10 +109,10 @@ class DGLSubGraph(DGLGraph):
If true, use inplace write (no gradient but faster) If true, use inplace write (no gradient but faster)
""" """
self._parent._node_frame.update_rows( self._parent._node_frame.update_rows(
self._parent_nid, self._node_frame, inplace=inplace) self._parent_nid, self._node_frame, inplace=inplace)
if self._parent._edge_frame.num_rows != 0: if self._parent._edge_frame.num_rows != 0:
self._parent._edge_frame.update_rows( self._parent._edge_frame.update_rows(
self._get_parent_eid(), self._edge_frame, inplace=inplace) self._get_parent_eid(), self._edge_frame, inplace=inplace)
def copy_from_parent(self): def copy_from_parent(self):
"""Copy node/edge features from the parent graph. """Copy node/edge features from the parent graph.
......
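For readers skimming the subgraph changes above: after this commit a DGLSubGraph is strictly read-only, and feature movement between parent and child goes through copy_from_parent/copy_to_parent. A minimal sketch of that contract (graph shape and feature names invented for illustration; DGLError imported from dgl.base, matching the import used in the diff):

    import torch as th
    import dgl
    from dgl.base import DGLError

    g = dgl.DGLGraph()
    g.add_nodes(5)
    g.add_edges([0, 1, 2], [1, 2, 3])
    g.ndata['h'] = th.zeros((5, 3))

    sg = g.subgraph([1, 2, 3])      # a DGLSubGraph over parent nodes 1..3
    sg.copy_from_parent()           # pull the parent's 'h' rows into sg
    sg.ndata['h'] += 1              # mutate only the local copy
    sg.copy_to_parent()             # write rows 1..3 back into g

    try:
        sg.add_nodes(1)             # any structural mutation now raises
    except DGLError:
        pass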
...@@ -9,7 +9,7 @@ __all__ = ['bfs_nodes_generator', 'bfs_edges_generator', ...@@ -9,7 +9,7 @@ __all__ = ['bfs_nodes_generator', 'bfs_edges_generator',
'topological_nodes_generator', 'topological_nodes_generator',
'dfs_edges_generator', 'dfs_labeled_edges_generator',] 'dfs_edges_generator', 'dfs_labeled_edges_generator',]
def bfs_nodes_generator(graph, source, reversed=False): def bfs_nodes_generator(graph, source, reverse=False):
"""Node frontiers generator using breadth-first search. """Node frontiers generator using breadth-first search.
Parameters Parameters
...@@ -18,7 +18,7 @@ def bfs_nodes_generator(graph, source, reversed=False): ...@@ -18,7 +18,7 @@ def bfs_nodes_generator(graph, source, reversed=False):
The graph object. The graph object.
source : list, tensor of nodes source : list, tensor of nodes
Source nodes. Source nodes.
reversed : bool, default False reverse : bool, default False
If True, traverse following the in-edge direction. If True, traverse following the in-edge direction.
Returns Returns
...@@ -41,14 +41,14 @@ def bfs_nodes_generator(graph, source, reversed=False): ...@@ -41,14 +41,14 @@ def bfs_nodes_generator(graph, source, reversed=False):
""" """
ghandle = graph._graph._handle ghandle = graph._graph._handle
source = utils.toindex(source) source = utils.toindex(source)
ret = _CAPI_DGLBFSNodes(ghandle, source.todgltensor(), reversed) ret = _CAPI_DGLBFSNodes(ghandle, source.todgltensor(), reverse)
all_nodes = utils.toindex(ret(0)).tousertensor() all_nodes = utils.toindex(ret(0)).tousertensor()
# TODO(minjie): how to support directly creating python list # TODO(minjie): how to support directly creating python list
sections = utils.toindex(ret(1)).tonumpy().tolist() sections = utils.toindex(ret(1)).tonumpy().tolist()
node_frontiers = F.split(all_nodes, sections, dim=0) node_frontiers = F.split(all_nodes, sections, dim=0)
return node_frontiers return node_frontiers
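The rename from the builtin-shadowing `reversed` to `reverse` is API-visible; a small usage sketch (toy graph invented; importing from dgl.traversal is an assumption, the generators may also be re-exported at the dgl top level):

    import dgl
    from dgl.traversal import bfs_nodes_generator

    g = dgl.DGLGraph()
    g.add_nodes(4)
    g.add_edges([0, 0, 1, 2], [1, 2, 3, 3])   # 0 -> {1, 2} -> 3

    # forward BFS from 0: frontiers [0], [1, 2], [3]
    fronts = bfs_nodes_generator(g, [0])

    # reverse=True walks in-edges: BFS from 3 gives [3], [1, 2], [0]
    rev_fronts = bfs_nodes_generator(g, [3], reverse=True)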
def bfs_edges_generator(graph, source, reversed=False): def bfs_edges_generator(graph, source, reverse=False):
"""Edges frontiers generator using breadth-first search. """Edges frontiers generator using breadth-first search.
Parameters Parameters
...@@ -57,7 +57,7 @@ def bfs_edges_generator(graph, source, reversed=False): ...@@ -57,7 +57,7 @@ def bfs_edges_generator(graph, source, reversed=False):
The graph object. The graph object.
source : list, tensor of nodes source : list, tensor of nodes
Source nodes. Source nodes.
reversed : bool, default False reverse : bool, default False
If True, traverse following the in-edge direction. If True, traverse following the in-edge direction.
Returns Returns
...@@ -81,21 +81,21 @@ def bfs_edges_generator(graph, source, reversed=False): ...@@ -81,21 +81,21 @@ def bfs_edges_generator(graph, source, reversed=False):
""" """
ghandle = graph._graph._handle ghandle = graph._graph._handle
source = utils.toindex(source) source = utils.toindex(source)
ret = _CAPI_DGLBFSEdges(ghandle, source.todgltensor(), reversed) ret = _CAPI_DGLBFSEdges(ghandle, source.todgltensor(), reverse)
all_edges = utils.toindex(ret(0)).tousertensor() all_edges = utils.toindex(ret(0)).tousertensor()
# TODO(minjie): how to support directly creating python list # TODO(minjie): how to support directly creating python list
sections = utils.toindex(ret(1)).tonumpy().tolist() sections = utils.toindex(ret(1)).tonumpy().tolist()
edge_frontiers = F.split(all_edges, sections, dim=0) edge_frontiers = F.split(all_edges, sections, dim=0)
return edge_frontiers return edge_frontiers
def topological_nodes_generator(graph, reversed=False): def topological_nodes_generator(graph, reverse=False):
"""Node frontiers generator using topological traversal. """Node frontiers generator using topological traversal.
Parameters Parameters
---------- ----------
graph : DGLGraph graph : DGLGraph
The graph object. The graph object.
reversed : bool, optional reverse : bool, optional
If True, traverse following the in-edge direction. If True, traverse following the in-edge direction.
Returns Returns
...@@ -117,13 +117,13 @@ def topological_nodes_generator(graph, reversed=False): ...@@ -117,13 +117,13 @@ def topological_nodes_generator(graph, reversed=False):
[tensor([0]), tensor([1]), tensor([2]), tensor([3, 4]), tensor([5])] [tensor([0]), tensor([1]), tensor([2]), tensor([3, 4]), tensor([5])]
""" """
ghandle = graph._graph._handle ghandle = graph._graph._handle
ret = _CAPI_DGLTopologicalNodes(ghandle, reversed) ret = _CAPI_DGLTopologicalNodes(ghandle, reverse)
all_nodes = utils.toindex(ret(0)).tousertensor() all_nodes = utils.toindex(ret(0)).tousertensor()
# TODO(minjie): how to support directly creating python list # TODO(minjie): how to support directly creating python list
sections = utils.toindex(ret(1)).tonumpy().tolist() sections = utils.toindex(ret(1)).tonumpy().tolist()
return F.split(all_nodes, sections, dim=0) return F.split(all_nodes, sections, dim=0)
def dfs_edges_generator(graph, source, reversed=False): def dfs_edges_generator(graph, source, reverse=False):
"""Edge frontiers generator using depth-first-search (DFS). """Edge frontiers generator using depth-first-search (DFS).
Multiple source nodes can be specified to start the DFS traversal. One Multiple source nodes can be specified to start the DFS traversal. One
...@@ -137,7 +137,7 @@ def dfs_edges_generator(graph, source, reversed=False): ...@@ -137,7 +137,7 @@ def dfs_edges_generator(graph, source, reversed=False):
The graph object. The graph object.
source : list, tensor of nodes source : list, tensor of nodes
Source nodes. Source nodes.
reversed : bool, optional reverse : bool, optional
If True, traverse following the in-edge direction. If True, traverse following the in-edge direction.
Returns Returns
...@@ -162,7 +162,7 @@ def dfs_edges_generator(graph, source, reversed=False): ...@@ -162,7 +162,7 @@ def dfs_edges_generator(graph, source, reversed=False):
""" """
ghandle = graph._graph._handle ghandle = graph._graph._handle
source = utils.toindex(source) source = utils.toindex(source)
ret = _CAPI_DGLDFSEdges(ghandle, source.todgltensor(), reversed) ret = _CAPI_DGLDFSEdges(ghandle, source.todgltensor(), reverse)
all_edges = utils.toindex(ret(0)).tousertensor() all_edges = utils.toindex(ret(0)).tousertensor()
# TODO(minjie): how to support directly creating python list # TODO(minjie): how to support directly creating python list
sections = utils.toindex(ret(1)).tonumpy().tolist() sections = utils.toindex(ret(1)).tonumpy().tolist()
...@@ -171,7 +171,7 @@ def dfs_edges_generator(graph, source, reversed=False): ...@@ -171,7 +171,7 @@ def dfs_edges_generator(graph, source, reversed=False):
def dfs_labeled_edges_generator( def dfs_labeled_edges_generator(
graph, graph,
source, source,
reversed=False, reverse=False,
has_reverse_edge=False, has_reverse_edge=False,
has_nontree_edge=False, has_nontree_edge=False,
return_labels=True): return_labels=True):
...@@ -199,7 +199,7 @@ def dfs_labeled_edges_generator( ...@@ -199,7 +199,7 @@ def dfs_labeled_edges_generator(
The graph object. The graph object.
source : list, tensor of nodes source : list, tensor of nodes
Source nodes. Source nodes.
reversed : bool, optional reverse : bool, optional
If true, traverse following the in-edge direction. If true, traverse following the in-edge direction.
has_reverse_edge : bool, optional has_reverse_edge : bool, optional
True to include reverse edges. True to include reverse edges.
...@@ -234,12 +234,12 @@ def dfs_labeled_edges_generator( ...@@ -234,12 +234,12 @@ def dfs_labeled_edges_generator(
ghandle = graph._graph._handle ghandle = graph._graph._handle
source = utils.toindex(source) source = utils.toindex(source)
ret = _CAPI_DGLDFSLabeledEdges( ret = _CAPI_DGLDFSLabeledEdges(
ghandle, ghandle,
source.todgltensor(), source.todgltensor(),
reversed, reverse,
has_reverse_edge, has_reverse_edge,
has_nontree_edge, has_nontree_edge,
return_labels) return_labels)
all_edges = utils.toindex(ret(0)).tousertensor() all_edges = utils.toindex(ret(0)).tousertensor()
# TODO(minjie): how to support directly creating python list # TODO(minjie): how to support directly creating python list
if return_labels: if return_labels:
......
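Since dfs_labeled_edges_generator takes the most flags, one more hedged sketch, reusing the toy graph above and assuming (as in the upstream docs) that with return_labels=True each yielded frontier is an (edge ids, labels) pair whose label encoding comes from the C API:

    from dgl.traversal import dfs_labeled_edges_generator

    for edges, labels in dfs_labeled_edges_generator(
            g, [0], has_nontree_edge=True):
        # labels tag each edge id as a forward / reverse / non-tree edge
        print(edges, labels)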
"""User-defined function related data structures.""" """User-defined function related data structures."""
from __future__ import absolute_import from __future__ import absolute_import
from .base import ALL, is_all from .base import is_all
from . import backend as F from . import backend as F
from . import utils from . import utils
......
"""Utility module.""" """Utility module."""
from __future__ import absolute_import, division from __future__ import absolute_import, division
from collections import Mapping, Iterable from collections.abc import Mapping, Iterable
from functools import wraps from functools import wraps
import numpy as np import numpy as np
...@@ -43,7 +43,7 @@ class Index(object): ...@@ -43,7 +43,7 @@ class Index(object):
def _dispatch(self, data): def _dispatch(self, data):
"""Store data based on its type.""" """Store data based on its type."""
if F.is_tensor(data): if F.is_tensor(data):
if not (F.dtype(data) == F.int64): if F.dtype(data) != F.int64:
raise DGLError('Index data must be an int64 vector, but got: %s' % str(data)) raise DGLError('Index data must be an int64 vector, but got: %s' % str(data))
if len(F.shape(data)) > 1: if len(F.shape(data)) > 1:
raise DGLError('Index data must be 1D int64 vector, but got: %s' % str(data)) raise DGLError('Index data must be 1D int64 vector, but got: %s' % str(data))
...@@ -63,19 +63,17 @@ class Index(object): ...@@ -63,19 +63,17 @@ class Index(object):
self._slice_data = slice(data.start, data.stop) self._slice_data = slice(data.start, data.stop)
else: else:
try: try:
self._pydata = np.array([int(data)]).astype(np.int64) data = np.array(data).astype(np.int64)
except: except Exception: # pylint: disable=broad-except
try: raise DGLError('Error index data: %s' % str(data))
data = np.array(data).astype(np.int64) if data.ndim == 0: # scalar array
if data.ndim != 1: data = np.expand_dims(data, 0)
raise DGLError('Index data must be 1D int64 vector,' elif data.ndim != 1:
' but got: %s' % str(data)) raise DGLError('Index data must be 1D int64 vector,'
self._pydata = data ' but got: %s' % str(data))
except: self._pydata = data
raise DGLError('Error index data: %s' % str(data))
self._user_tensor_data[F.cpu()] = F.zerocopy_from_numpy(self._pydata) self._user_tensor_data[F.cpu()] = F.zerocopy_from_numpy(self._pydata)
def tonumpy(self): def tonumpy(self):
"""Convert to a numpy ndarray.""" """Convert to a numpy ndarray."""
if self._pydata is None: if self._pydata is None:
...@@ -96,8 +94,8 @@ class Index(object): ...@@ -96,8 +94,8 @@ class Index(object):
if len(self._user_tensor_data) == 0: if len(self._user_tensor_data) == 0:
if self._dgl_tensor_data is not None: if self._dgl_tensor_data is not None:
# zero copy from dgl tensor # zero copy from dgl tensor
dl = self._dgl_tensor_data.to_dlpack() dlpack = self._dgl_tensor_data.to_dlpack()
self._user_tensor_data[F.cpu()] = F.zerocopy_from_dlpack(dl) self._user_tensor_data[F.cpu()] = F.zerocopy_from_dlpack(dlpack)
else: else:
# zero copy from numpy array # zero copy from numpy array
self._user_tensor_data[F.cpu()] = F.zerocopy_from_numpy(self.tonumpy()) self._user_tensor_data[F.cpu()] = F.zerocopy_from_numpy(self.tonumpy())
...@@ -112,10 +110,17 @@ class Index(object): ...@@ -112,10 +110,17 @@ class Index(object):
if self._dgl_tensor_data is None: if self._dgl_tensor_data is None:
# zero copy from user tensor # zero copy from user tensor
tsor = self.tousertensor() tsor = self.tousertensor()
dl = F.zerocopy_to_dlpack(tsor) dlpack = F.zerocopy_to_dlpack(tsor)
self._dgl_tensor_data = nd.from_dlpack(dl) self._dgl_tensor_data = nd.from_dlpack(dlpack)
return self._dgl_tensor_data return self._dgl_tensor_data
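Both conversions above are zero-copy through DLPack; a tiny sketch of the round trip (Index/toindex are internal helpers in dgl.utils):

    from dgl.utils import toindex

    idx = toindex([0, 2, 4])
    user = idx.tousertensor()      # backend (e.g. torch) int64 tensor, no copy
    dgltensor = idx.todgltensor()  # DGL NDArray view via DLPack, no copy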
def slice_data(self):
"""Return the internal slice data.
If this index is not initialized from slice, the return will be None.
"""
return self._slice_data
def is_slice(self, start, stop): def is_slice(self, start, stop):
"""Check if Index wraps a slice data with given start and stop""" """Check if Index wraps a slice data with given start and stop"""
return self._slice_data == slice(start, stop) return self._slice_data == slice(start, stop)
...@@ -136,20 +141,26 @@ class Index(object): ...@@ -136,20 +141,26 @@ class Index(object):
Returns Returns
------- -------
utils.Index utils.Index
The values at the given position.
""" """
if index._slice_data is None: if self._slice_data is not None and self._slice_data.start == 0:
# short-cut for identical mapping
# NOTE: we don't check for out-of-bound error
return index
elif index._slice_data is None:
# the provided index is not a slice
tensor = self.tousertensor() tensor = self.tousertensor()
index = index.tousertensor() index = index.tousertensor()
return Index(F.gather_row(tensor, index)) return Index(F.gather_row(tensor, index))
elif self._slice_data is None: elif self._slice_data is None:
# the current index is not a slice but the provided is a slice
tensor = self.tousertensor() tensor = self.tousertensor()
index = index._slice_data index = index._slice_data
return Index(F.narrow_row(tensor, index.start, index.stop)) return Index(F.narrow_row(tensor, index.start, index.stop))
else: else:
# both self and index wrap a slice object, then return another # both self and index wrap a slice object, then return another
# Index wrapping a slice # Index wrapping a slice
start = self._slicedata.start start = self._slice_data.start
index = index._slice_data index = index._slice_data
return Index(slice(start + index.start, start + index.stop)) return Index(slice(start + index.start, start + index.stop))
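The three branches above keep slice-typed indices lazy whenever possible; a hedged sketch of how they compose (Index and the new slice_data accessor are internal to dgl.utils):

    from dgl.utils import Index, toindex

    rows = Index(slice(2, 7))      # rows 2..6 of some frame
    sel = Index(slice(1, 3))       # positions 1..2 within those rows

    # slice-of-slice stays a cheap slice: 2+1 .. 2+3
    assert rows.get_items(sel).slice_data() == slice(3, 5)

    # a slice starting at 0 is an identity map, so the provided
    # index is returned unchanged (no bounds check, per the comment)
    ident = Index(slice(0, 5))
    picks = toindex([0, 4])
    assert ident.get_items(picks) is picks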
...@@ -168,7 +179,7 @@ class Index(object): ...@@ -168,7 +179,7 @@ class Index(object):
Returns Returns
------- -------
utils.Index utils.Index
The new values.
""" """
tensor = self.tousertensor() tensor = self.tousertensor()
index = index.tousertensor() index = index.tousertensor()
...@@ -207,8 +218,24 @@ class Index(object): ...@@ -207,8 +218,24 @@ class Index(object):
tensor = self.tousertensor() tensor = self.tousertensor()
return F.sum(tensor, 0) > 0 return F.sum(tensor, 0) > 0
def toindex(x): def toindex(data):
return x if isinstance(x, Index) else Index(x) """Convert the given data to Index object.
Parameters
----------
data : index data
Data to create the index.
Returns
-------
Index
The index object.
See Also
--------
Index
"""
return data if isinstance(data, Index) else Index(data)
def zero_index(size): def zero_index(size):
"""Create a index with provided size initialized to zero """Create a index with provided size initialized to zero
...@@ -244,21 +271,22 @@ class LazyDict(Mapping): ...@@ -244,21 +271,22 @@ class LazyDict(Mapping):
class HybridDict(Mapping): class HybridDict(Mapping):
"""A readonly dictonary that merges several dict-like (python dict, LazyDict). """A readonly dictonary that merges several dict-like (python dict, LazyDict).
If there are duplicate keys, early keys have priority over latter ones
If there are duplicate keys, early keys have priority over latter ones.
""" """
def __init__(self, *dict_like_list): def __init__(self, *dict_like_list):
self._dict_like_list = dict_like_list self._dict_like_list = dict_like_list
self._keys = set() self._keys = set()
for d in dict_like_list: for obj in dict_like_list:
self._keys.update(d.keys()) self._keys.update(obj.keys())
def keys(self): def keys(self):
return self._keys return self._keys
def __getitem__(self, key): def __getitem__(self, key):
for d in self._dict_like_list: for obj in self._dict_like_list:
if key in d: if key in obj:
return d[key] return obj[key]
raise KeyError(key) raise KeyError(key)
def __contains__(self, key): def __contains__(self, key):
...@@ -290,7 +318,7 @@ class ReadOnlyDict(Mapping): ...@@ -290,7 +318,7 @@ class ReadOnlyDict(Mapping):
def __len__(self): def __len__(self):
return len(self._dict_like) return len(self._dict_like)
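HybridDict's lookup order makes the earlier arguments win; a quick illustration (internal helper in dgl.utils, keys and values invented):

    from dgl.utils import HybridDict

    merged = HybridDict({'a': 1}, {'a': 100, 'b': 2})
    assert merged['a'] == 1            # earlier dict wins on duplicates
    assert merged['b'] == 2
    assert set(merged.keys()) == {'a', 'b'}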
def build_relabel_map(x, sorted=False): def build_relabel_map(x, is_sorted=False):
"""Relabel the input ids to continuous ids that starts from zero. """Relabel the input ids to continuous ids that starts from zero.
Ids are assigned new ids according to their ascending order. Ids are assigned new ids according to their ascending order.
...@@ -310,7 +338,7 @@ def build_relabel_map(x, sorted=False): ...@@ -310,7 +338,7 @@ def build_relabel_map(x, sorted=False):
---------- ----------
x : Index x : Index
The input ids. The input ids.
sorted : bool, default=False is_sorted : bool, default=False
Whether the input has already been unique and sorted. Whether the input has already been unique and sorted.
Returns Returns
...@@ -323,7 +351,7 @@ def build_relabel_map(x, sorted=False): ...@@ -323,7 +351,7 @@ def build_relabel_map(x, sorted=False):
new id tensor: new_id = old_to_new[old_id] new id tensor: new_id = old_to_new[old_id]
""" """
x = x.tousertensor() x = x.tousertensor()
if not sorted: if not is_sorted:
unique_x, _ = F.sort_1d(F.unique(x)) unique_x, _ = F.sort_1d(F.unique(x))
else: else:
unique_x = x unique_x = x
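Concretely, the relabeling maps the sorted unique ids to 0..n-1 (a hedged example, assuming the function returns the (new_to_old, old_to_new) pair its docstring describes):

    from dgl.utils import toindex, build_relabel_map

    x = toindex([1, 9, 5])
    new_to_old, old_to_new = build_relabel_map(x)
    # ascending order gives 1 -> 0, 5 -> 1, 9 -> 2
    assert int(old_to_new[9]) == 2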
...@@ -397,6 +425,7 @@ def cached_member(cache, prefix): ...@@ -397,6 +425,7 @@ def cached_member(cache, prefix):
return _creator return _creator
def is_dict_like(obj): def is_dict_like(obj):
"""Return true if the object can be treated as a dictionary."""
return isinstance(obj, Mapping) return isinstance(obj, Mapping)
def reorder(dict_like, index): def reorder(dict_like, index):
......
"""Views of DGLGraph.""" """Views of DGLGraph."""
from __future__ import absolute_import from __future__ import absolute_import
from collections import MutableMapping, namedtuple from collections import namedtuple
from collections.abc import MutableMapping
from .base import ALL, is_all, DGLError from .base import ALL, is_all, DGLError
from . import backend as F from . import backend as F
from . import utils
NodeSpace = namedtuple('NodeSpace', ['data']) NodeSpace = namedtuple('NodeSpace', ['data'])
...@@ -41,6 +41,12 @@ class NodeView(object): ...@@ -41,6 +41,12 @@ class NodeView(object):
return F.arange(0, len(self)) return F.arange(0, len(self))
class NodeDataView(MutableMapping): class NodeDataView(MutableMapping):
"""The data view class when G.nodes[...].data is called.
See Also
--------
dgl.DGLGraph.nodes
"""
__slots__ = ['_graph', '_nodes'] __slots__ = ['_graph', '_nodes']
def __init__(self, graph, nodes): def __init__(self, graph, nodes):
...@@ -103,6 +109,12 @@ class EdgeView(object): ...@@ -103,6 +109,12 @@ class EdgeView(object):
return self._graph.all_edges(*args, **kwargs) return self._graph.all_edges(*args, **kwargs)
class EdgeDataView(MutableMapping): class EdgeDataView(MutableMapping):
"""The data view class when G.edges[...].data is called.
See Also
--------
dgl.DGLGraph.edges
"""
__slots__ = ['_graph', '_edges'] __slots__ = ['_graph', '_edges']
def __init__(self, graph, edges): def __init__(self, graph, edges):
......
...@@ -145,10 +145,11 @@ def test_create_from_elist(): ...@@ -145,10 +145,11 @@ def test_create_from_elist():
for i, (u, v) in enumerate(elist): for i, (u, v) in enumerate(elist):
assert g.edge_id(u, v)[0] == i assert g.edge_id(u, v)[0] == i
# immutable graph # immutable graph
g = create_graph_index(elist, readonly=True) # TODO: disabled due to torch support
for i, (u, v) in enumerate(elist): #g = create_graph_index(elist, readonly=True)
print(u, v, g.edge_id(u, v)[0]) #for i, (u, v) in enumerate(elist):
assert g.edge_id(u, v)[0] == i # print(u, v, g.edge_id(u, v)[0])
# assert g.edge_id(u, v)[0] == i
if __name__ == '__main__': if __name__ == '__main__':
test_edge_id() test_edge_id()
......
[MASTER]
# A comma-separated list of package or module names from where C extensions may
# be loaded. Extensions are loading into the active Python interpreter and may
# run arbitrary code.
extension-pkg-whitelist=
# Add files or directories to the blacklist. They should be base names, not
# paths.
ignore=CVS,_cy2,_cy3,backend,data,nn,contrib
# Add files or directories matching the regex patterns to the blacklist. The
# regex matches against base names, not paths.
ignore-patterns=
# Python code to execute, usually for sys.path manipulation such as
# pygtk.require().
#init-hook=
# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
# number of processors available to use.
jobs=4
# Control the amount of potential inferred values when inferring a single
# object. This can help the performance when dealing with large functions or
# complex, nested conditions.
limit-inference-results=100
# List of plugins (as comma separated values of python modules names) to load,
# usually to register additional checkers.
load-plugins=
# Pickle collected data for later comparisons.
persistent=yes
# Specify a configuration file.
#rcfile=
# When enabled, pylint would attempt to guess common misconfiguration and emit
# user-friendly hints instead of false-positive error messages.
suggestion-mode=yes
# Allow loading of arbitrary C extensions. Extensions are imported into the
# active Python interpreter and may run arbitrary code.
unsafe-load-any-extension=no
[MESSAGES CONTROL]
# Only show warnings with the listed confidence levels. Leave empty to show
# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED.
confidence=
# Disable the message, report, category or checker with the given id(s). You
# can either give multiple identifiers separated by comma (,) or put this
# option multiple times (only on the command line, not in the configuration
# file where it should appear only once). You can also use "--disable=all" to
# disable everything first and then reenable specific checks. For example, if
# you want to run only the similarities checker, you can use "--disable=all
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use "--disable=all --enable=classes
# --disable=W".
disable=design,
similarities,
no-self-use,
attribute-defined-outside-init,
locally-disabled,
star-args,
pointless-except,
bad-option-value,
global-statement,
fixme,
suppressed-message,
useless-suppression,
locally-enabled,
import-error,
unsubscriptable-object,
unbalanced-tuple-unpacking,
protected-access,
useless-object-inheritance,
no-else-return,
len-as-condition,
cyclic-import, # disabled due to the inevitable dgl.graph -> dgl.subgraph loop
undefined-variable, # disabled due to C extension (should enable)
# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
# multiple time (only on the command line, not in the configuration file where
# it should appear only once). See also the "--disable" option for examples.
enable=c-extension-no-member
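# Note: individual checks can also be silenced inline in the source with
# pragmas, e.g. the `# pylint: disable=broad-except` trailing comment used in
# the utils.py change above, or a module-level `# pylint: disable=...` comment
# near the top of a file.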
[REPORTS]
# Python expression which should return a score less than or equal to 10 (10
# is the highest score). You have access to the variables 'error', 'warning',
# 'refactor' and 'convention', which contain the number of messages in each
# category, as well as 'statement', which is the total number of statements
# analyzed. This is used by the global evaluation report (RP0004).
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
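# For example, 1 error, 4 warnings, 0 refactors and 2 convention messages
# over 200 statements score 10.0 - ((5*1 + 4 + 0 + 2) / 200) * 10 = 9.45.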
# Template used to display messages. This is a python new-style format string
# used to format the message information. See doc for all details.
#msg-template=
# Set the output format. Available formats are text, parseable, colorized, json
# and msvs (visual studio). You can also give a reporter class, e.g.
# mypackage.mymodule.MyReporterClass.
output-format=text
# Tells whether to display a full report or only the messages.
reports=no
# Activate the evaluation score.
score=yes
[REFACTORING]
# Maximum number of nested blocks for function / method body
max-nested-blocks=5
# Complete name of functions that never returns. When checking for
# inconsistent-return-statements if a never returning function is called then
# it will be considered as an explicit return statement and no message will be
# printed.
never-returning-functions=sys.exit
[MISCELLANEOUS]
# List of note tags to take in consideration, separated by a comma.
notes=FIXME,
XXX,
TODO
[BASIC]
# Naming style matching correct argument names.
argument-naming-style=snake_case
# Regular expression matching correct argument names. Overrides argument-
# naming-style.
#argument-rgx=
# Naming style matching correct attribute names.
attr-naming-style=snake_case
# Regular expression matching correct attribute names. Overrides attr-naming-
# style.
#attr-rgx=
# Bad variable names which should always be refused, separated by a comma.
bad-names=foo,
bar,
baz,
toto,
tutu,
tata
# Naming style matching correct class attribute names.
class-attribute-naming-style=any
# Regular expression matching correct class attribute names. Overrides class-
# attribute-naming-style.
#class-attribute-rgx=
# Naming style matching correct class names.
class-naming-style=PascalCase
# Regular expression matching correct class names. Overrides class-naming-
# style.
#class-rgx=
# Naming style matching correct constant names.
const-naming-style=UPPER_CASE
# Regular expression matching correct constant names. Overrides const-naming-
# style.
#const-rgx=
# Minimum line length for functions/classes that require docstrings, shorter
# ones are exempt.
docstring-min-length=-1
# Naming style matching correct function names.
function-naming-style=snake_case
# Regular expression matching correct function names. Overrides function-
# naming-style.
#function-rgx=
# Good variable names which should always be accepted, separated by a comma.
good-names=i,j,k,u,v,e,n,m,w,x,y,g,fn,ex,Run,_
# Include a hint for the correct naming format with invalid-name.
include-naming-hint=no
# Naming style matching correct inline iteration names.
inlinevar-naming-style=any
# Regular expression matching correct inline iteration names. Overrides
# inlinevar-naming-style.
#inlinevar-rgx=
# Naming style matching correct method names.
method-naming-style=snake_case
# Regular expression matching correct method names. Overrides method-naming-
# style.
#method-rgx=
# Naming style matching correct module names.
module-naming-style=snake_case
# Regular expression matching correct module names. Overrides module-naming-
# style.
#module-rgx=
# Colon-delimited sets of names that determine each other's naming style when
# the name regexes allow several styles.
name-group=
# Regular expression which should only match function or class names that do
# not require a docstring.
no-docstring-rgx=^_
# List of decorators that produce properties, such as abc.abstractproperty. Add
# to this list to register other decorators that produce valid properties.
# These decorators are taken in consideration only for invalid-name.
property-classes=abc.abstractproperty
# Naming style matching correct variable names.
variable-naming-style=snake_case
# Regular expression matching correct variable names. Overrides variable-
# naming-style.
#variable-rgx=
[VARIABLES]
# List of additional names supposed to be defined in builtins. Remember that
# you should avoid defining new builtins when possible.
additional-builtins=
# Tells whether unused global variables should be treated as a violation.
allow-global-unused-variables=yes
# List of strings which can identify a callback function by name. A callback
# name must start or end with one of those strings.
callbacks=cb_,
_cb
# A regular expression matching the name of dummy variables (i.e. expected to
# not be used).
dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
# Argument names that match this expression will be ignored. Default to name
# with leading underscore.
ignored-argument-names=_.*|^ignored_|^unused_
# Tells whether we should check for unused import in __init__ files.
init-import=no
# List of qualified module names which can have objects that can redefine
# builtins.
redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
[SPELLING]
# Limits count of emitted suggestions for spelling mistakes.
max-spelling-suggestions=4
# Spelling dictionary name. Available dictionaries: none. To make it work,
# install the python-enchant package.
spelling-dict=
# List of comma separated words that should not be checked.
spelling-ignore-words=
# A path to a file that contains private dictionary; one word per line.
spelling-private-dict-file=
# Tells whether to store unknown words to indicated private dictionary in
# --spelling-private-dict-file option instead of raising a message.
spelling-store-unknown-words=no
[LOGGING]
# Format style used to check logging format string. `old` means using %
# formatting, while `new` is for `{}` formatting.
logging-format-style=old
# Logging modules to check that the string format arguments are in logging
# function parameter format.
logging-modules=logging
[FORMAT]
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
expected-line-ending-format=
# Regexp for a line that is allowed to be longer than the limit.
ignore-long-lines=^\s*(# )?<?https?://\S+>?$
# Number of spaces of indent required inside a hanging or continued line.
indent-after-paren=4
# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
# tab).
indent-string=' '
# Maximum number of characters on a single line.
max-line-length=100
# Maximum number of lines in a module.
max-module-lines=4000
# List of optional constructs for which whitespace checking is disabled. `dict-
# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
# `trailing-comma` allows a space between comma and closing bracket: (a, ).
# `empty-line` allows space-only lines.
no-space-check=trailing-comma,
dict-separator
# Allow the body of a class to be on the same line as the declaration if body
# contains single statement.
single-line-class-stmt=no
# Allow the body of an if to be on the same line as the test if there is no
# else.
single-line-if-stmt=no
[SIMILARITIES]
# Ignore comments when computing similarities.
ignore-comments=yes
# Ignore docstrings when computing similarities.
ignore-docstrings=yes
# Ignore imports when computing similarities.
ignore-imports=no
# Minimum lines number of a similarity.
min-similarity-lines=4
[TYPECHECK]
# List of decorators that produce context managers, such as
# contextlib.contextmanager. Add to this list to register other decorators that
# produce valid context managers.
contextmanager-decorators=contextlib.contextmanager
# List of members which are set dynamically and missed by pylint inference
# system, and so shouldn't trigger E1101 when accessed. Python regular
# expressions are accepted.
generated-members=
# Tells whether missing members accessed in mixin class should be ignored. A
# mixin class is detected if its name ends with "mixin" (case insensitive).
ignore-mixin-members=yes
# Tells whether to warn about missing members when the owner of the attribute
# is inferred to be None.
ignore-none=yes
# This flag controls whether pylint should warn about no-member and similar
# checks whenever an opaque object is returned when inferring. The inference
# can return multiple potential results while evaluating a Python object, but
# some branches might not be evaluated, which results in partial inference. In
# that case, it might be useful to still emit no-member and other checks for
# the rest of the inferred objects.
ignore-on-opaque-inference=yes
# List of class names for which member attributes should not be checked (useful
# for classes with dynamically set attributes). This supports the use of
# qualified names.
ignored-classes=optparse.Values,thread._local,_thread._local
# List of module names for which member attributes should not be checked
# (useful for modules/projects where namespaces are manipulated during runtime
# and thus existing member attributes cannot be deduced by static analysis. It
# supports qualified module names, as well as Unix pattern matching.
ignored-modules=dgl.backend,dgl._api_internal
# Show a hint with possible names when a member name was not found. The aspect
# of finding the hint is based on edit distance.
missing-member-hint=yes
# The minimum edit distance a name should have in order to be considered a
# similar match for a missing member name.
missing-member-hint-distance=1
# The total number of similar names that should be taken in consideration when
# showing a hint for a missing member.
missing-member-max-choices=1
[IMPORTS]
# Allow wildcard imports from modules that define __all__.
allow-wildcard-with-all=yes
# Analyse import fallback blocks. This can be used to support both Python 2 and
# 3 compatible code, which means that the block might have code that exists
# only in one or another interpreter, leading to false positives when analysed.
analyse-fallback-blocks=no
# Deprecated modules which should not be used, separated by a comma.
deprecated-modules=optparse,tkinter.tix
# Create a graph of external dependencies in the given file (report RP0402 must
# not be disabled).
ext-import-graph=
# Create a graph of every (i.e. internal and external) dependencies in the
# given file (report RP0402 must not be disabled).
import-graph=
# Create a graph of internal dependencies in the given file (report RP0402 must
# not be disabled).
int-import-graph=
# Force import order to recognize a module as part of the standard
# compatibility libraries.
known-standard-library=
# Force import order to recognize a module as part of a third party library.
known-third-party=enchant
[DESIGN]
# Maximum number of arguments for function / method.
max-args=5
# Maximum number of attributes for a class (see R0902).
max-attributes=7
# Maximum number of boolean expressions in an if statement.
max-bool-expr=5
# Maximum number of branch for function / method body.
max-branches=12
# Maximum number of locals for function / method body.
max-locals=15
# Maximum number of parents for a class (see R0901).
max-parents=7
# Maximum number of public methods for a class (see R0904).
max-public-methods=20
# Maximum number of return / yield for function / method body.
max-returns=6
# Maximum number of statements in function / method body.
max-statements=50
# Minimum number of public methods for a class (see R0903).
min-public-methods=2
[CLASSES]
# List of method names used to declare (i.e. assign) instance attributes.
defining-attr-methods=__init__,
__new__,
setUp
# List of member names, which should be excluded from the protected access
# warning.
exclude-protected=_asdict,
_fields,
_replace,
_source,
_make
# List of valid names for the first argument in a class method.
valid-classmethod-first-arg=cls
# List of valid names for the first argument in a metaclass class method.
valid-metaclass-classmethod-first-arg=cls
[EXCEPTIONS]
# Exceptions that will emit a warning when being caught. Defaults to
# "Exception".
overgeneral-exceptions=Exception
...@@ -4,8 +4,7 @@ import mxnet as mx ...@@ -4,8 +4,7 @@ import mxnet as mx
import numpy as np import numpy as np
import scipy as sp import scipy as sp
import dgl import dgl
from dgl.graph import GraphIndex, create_graph_index from dgl.graph_index import map_to_subgraph_nid, GraphIndex, create_graph_index
from dgl.graph_index import map_to_subgraph_nid
from dgl import utils from dgl import utils
def generate_rand_graph(n): def generate_rand_graph(n):
......
...@@ -596,11 +596,12 @@ def test_repr(): ...@@ -596,11 +596,12 @@ def test_repr():
G.add_nodes(10) G.add_nodes(10)
G.add_edge(0, 1) G.add_edge(0, 1)
repr_string = G.__repr__() repr_string = G.__repr__()
print(repr_string)
G.ndata['x'] = th.zeros((10, 5)) G.ndata['x'] = th.zeros((10, 5))
G.add_edges([0, 1], 2) G.add_edges([0, 1], 2)
G.edata['y'] = th.zeros((3, 4)) G.edata['y'] = th.zeros((3, 4))
repr_string = G.__repr__() repr_string = G.__repr__()
print(repr_string)
if __name__ == '__main__': if __name__ == '__main__':
test_nx_conversion() test_nx_conversion()
......
...@@ -61,7 +61,7 @@ def test_column1(): ...@@ -61,7 +61,7 @@ def test_column1():
def test_column2(): def test_column2():
# Test frameref column getter/setter # Test frameref column getter/setter
data = Frame(create_test_data()) data = Frame(create_test_data())
f = FrameRef(data, [3, 4, 5, 6, 7]) f = FrameRef(data, toindex([3, 4, 5, 6, 7]))
assert f.num_rows == 5 assert f.num_rows == 5
assert len(f) == 3 assert len(f) == 3
assert U.allclose(f['a1'], data['a1'].data[3:8]) assert U.allclose(f['a1'], data['a1'].data[3:8])
...@@ -111,7 +111,7 @@ def test_append2(): ...@@ -111,7 +111,7 @@ def test_append2():
assert not f.is_span_whole_column() assert not f.is_span_whole_column()
assert f.num_rows == 3 * N assert f.num_rows == 3 * N
new_idx = list(range(N)) + list(range(2*N, 4*N)) new_idx = list(range(N)) + list(range(2*N, 4*N))
assert th.all(f.index().tousertensor() == th.tensor(new_idx, dtype=th.int64)) assert th.all(f._index.tousertensor() == th.tensor(new_idx, dtype=th.int64))
assert data.num_rows == 4 * N assert data.num_rows == 4 * N
def test_append3(): def test_append3():
...@@ -233,8 +233,8 @@ def test_row4(): ...@@ -233,8 +233,8 @@ def test_row4():
def test_sharing(): def test_sharing():
data = Frame(create_test_data()) data = Frame(create_test_data())
f1 = FrameRef(data, index=[0, 1, 2, 3]) f1 = FrameRef(data, index=toindex([0, 1, 2, 3]))
f2 = FrameRef(data, index=[2, 3, 4, 5, 6]) f2 = FrameRef(data, index=toindex([2, 3, 4, 5, 6]))
# test read # test read
for k, v in f1.items(): for k, v in f1.items():
assert U.allclose(data[k].data[0:4], v) assert U.allclose(data[k].data[0:4], v)
...@@ -260,8 +260,8 @@ def test_sharing(): ...@@ -260,8 +260,8 @@ def test_sharing():
def test_slicing(): def test_slicing():
data = Frame(create_test_data(grad=True)) data = Frame(create_test_data(grad=True))
f1 = FrameRef(data, index=slice(1, 5)) f1 = FrameRef(data, index=toindex(slice(1, 5)))
f2 = FrameRef(data, index=slice(3, 8)) f2 = FrameRef(data, index=toindex(slice(3, 8)))
# test read # test read
for k, v in f1.items(): for k, v in f1.items():
assert U.allclose(data[k].data[1:5], v) assert U.allclose(data[k].data[1:5], v)
...@@ -279,15 +279,15 @@ def test_slicing(): ...@@ -279,15 +279,15 @@ def test_slicing():
'a2': th.ones([2, D]), 'a2': th.ones([2, D]),
'a3': th.ones([2, D]), 'a3': th.ones([2, D]),
} }
f2_a1[0:2] = 1 f2_a1[toindex(slice(0,2))] = 1
assert U.allclose(f2['a1'], f2_a1) assert U.allclose(f2['a1'], f2_a1)
f1[2:4] = { f1[toindex(slice(2,4))] = {
'a1': th.zeros([2, D]), 'a1': th.zeros([2, D]),
'a2': th.zeros([2, D]), 'a2': th.zeros([2, D]),
'a3': th.zeros([2, D]), 'a3': th.zeros([2, D]),
} }
f2_a1[0:2] = 0 f2_a1[toindex(slice(0,2))] = 0
assert U.allclose(f2['a1'], f2_a1) assert U.allclose(f2['a1'], f2_a1)
def test_add_rows(): def test_add_rows():
...@@ -299,12 +299,48 @@ def test_add_rows(): ...@@ -299,12 +299,48 @@ def test_add_rows():
ans = th.cat([x, th.zeros(3, 4)]) ans = th.cat([x, th.zeros(3, 4)])
assert U.allclose(f1['x'], ans) assert U.allclose(f1['x'], ans)
f1.add_rows(4) f1.add_rows(4)
f1[4:8] = {'x': th.ones(4, 4), 'y': th.ones(4, 5)} f1[toindex(slice(4,8))] = {'x': th.ones(4, 4), 'y': th.ones(4, 5)}
ans = th.cat([ans, th.ones(4, 4)]) ans = th.cat([ans, th.ones(4, 4)])
assert U.allclose(f1['x'], ans) assert U.allclose(f1['x'], ans)
ans = th.cat([th.zeros(4, 5), th.ones(4, 5)]) ans = th.cat([th.zeros(4, 5), th.ones(4, 5)])
assert U.allclose(f1['y'], ans) assert U.allclose(f1['y'], ans)
def test_inplace():
f = FrameRef(Frame(create_test_data()))
print(f.schemes)
a1addr = f['a1'].data.data_ptr()
a2addr = f['a2'].data.data_ptr()
a3addr = f['a3'].data.data_ptr()
# column updates are always out-of-place
f['a1'] = th.ones((N, D))
newa1addr = f['a1'].data.data_ptr()
assert a1addr != newa1addr
a1addr = newa1addr
# full row update that becomes column update
f[toindex(slice(0, N))] = {'a1' : th.ones((N, D))}
assert f['a1'].data.data_ptr() != a1addr
# row update (outplace) w/ slice
f[toindex(slice(1, 4))] = {'a2' : th.ones((3, D))}
newa2addr = f['a2'].data.data_ptr()
assert a2addr != newa2addr
a2addr = newa2addr
# row update (outplace) w/ list
f[toindex([1, 3, 5])] = {'a2' : th.ones((3, D))}
newa2addr = f['a2'].data.data_ptr()
assert a2addr != newa2addr
a2addr = newa2addr
# row update (inplace) w/ slice
f.update_data(toindex(slice(1, 4)), {'a2' : th.ones((3, D))}, True)
newa2addr = f['a2'].data.data_ptr()
assert a2addr == newa2addr
# row update (inplace) w/ list
f.update_data(toindex([1, 3, 5]), {'a2' : th.ones((3, D))}, True)
newa2addr = f['a2'].data.data_ptr()
assert a2addr == newa2addr
if __name__ == '__main__': if __name__ == '__main__':
test_create() test_create()
test_column1() test_column1()
...@@ -319,3 +355,4 @@ if __name__ == '__main__': ...@@ -319,3 +355,4 @@ if __name__ == '__main__':
test_sharing() test_sharing()
test_slicing() test_slicing()
test_add_rows() test_add_rows()
test_inplace()
...@@ -33,9 +33,10 @@ def test_create_from_elist(): ...@@ -33,9 +33,10 @@ def test_create_from_elist():
for i, (u, v) in enumerate(elist): for i, (u, v) in enumerate(elist):
assert g.edge_id(u, v) == i assert g.edge_id(u, v) == i
# immutable graph # immutable graph
g = dgl.DGLGraph(elist, readonly=True) # XXX: not enabled for pytorch
for i, (u, v) in enumerate(elist): #g = dgl.DGLGraph(elist, readonly=True)
assert g.edge_id(u, v) == i #for i, (u, v) in enumerate(elist):
# assert g.edge_id(u, v) == i
def test_adjmat_cache(): def test_adjmat_cache():
n = 1000 n = 1000
...@@ -109,7 +110,7 @@ def test_incmat_cache(): ...@@ -109,7 +110,7 @@ def test_incmat_cache():
assert dur2 < dur1 assert dur2 < dur1
assert id(inc1) == id(inc2) assert id(inc1) == id(inc2)
# different arg should result in different cache # different arg should result in different cache
inc3 = g.incidence_matrix(type="both") inc3 = g.incidence_matrix("both")
assert id(inc3) != id(inc2) assert id(inc3) != id(inc2)
# manually clear the cache # manually clear the cache
g.clear_cache() g.clear_cache()
......
...@@ -112,7 +112,7 @@ def test_pickling_graph(): ...@@ -112,7 +112,7 @@ def test_pickling_graph():
assert new_g._message_func == _global_message_func assert new_g._message_func == _global_message_func
assert isinstance(new_g._reduce_func, type(reduce_func)) assert isinstance(new_g._reduce_func, type(reduce_func))
assert new_g._reduce_func._name == 'sum' assert new_g._reduce_func._name == 'sum'
assert new_g._reduce_func.op == backend.sum assert new_g._reduce_func.reduce_op == backend.sum
assert new_g._reduce_func.msg_field == 'x' assert new_g._reduce_func.msg_field == 'x'
assert new_g._reduce_func.out_field == 'x' assert new_g._reduce_func.out_field == 'x'
......
...@@ -3,3 +3,7 @@ ...@@ -3,3 +3,7 @@
# cpplint # cpplint
echo 'Checking code style of C++ codes...' echo 'Checking code style of C++ codes...'
python3 third_party/dmlc-core/scripts/lint.py dgl cpp include src python3 third_party/dmlc-core/scripts/lint.py dgl cpp include src
# pylint
echo 'Checking code style of python codes...'
python3 -m pylint --reports=y -v --rcfile=tests/lint/pylintrc python/dgl