Unverified Commit 9c135fd5 authored by VoVAllen's avatar VoVAllen Committed by GitHub
Browse files

Merge pull request #4 from jermainewang/master

Sync with latest commit
parents 9d3f299d 00add9f2
// DGL Graph interface /*!
#ifndef DGL_DGLGRAPH_H_ * Copyright (c) 2018 by Contributors
#define DGL_DGLGRAPH_H_ * \file dgl/graph.h
* \brief DGL graph index class.
*/
#ifndef DGL_GRAPH_H_
#define DGL_GRAPH_H_
#include <stdint.h> #include <vector>
#include <cstdint>
#include "runtime/ndarray.h" #include "runtime/ndarray.h"
namespace dgl { namespace dgl {
...@@ -17,7 +22,7 @@ class GraphOp; ...@@ -17,7 +22,7 @@ class GraphOp;
struct Subgraph; struct Subgraph;
/*! /*!
* \brief Base dgl graph class. * \brief Base dgl graph index class.
* *
* DGL's graph is directed. Vertices are integers enumerated from zero. Edges * DGL's graph is directed. Vertices are integers enumerated from zero. Edges
* are uniquely identified by the two endpoints. Multi-edge is currently not * are uniquely identified by the two endpoints. Multi-edge is currently not
...@@ -41,7 +46,7 @@ class Graph { ...@@ -41,7 +46,7 @@ class Graph {
} EdgeArray; } EdgeArray;
/*! \brief default constructor */ /*! \brief default constructor */
Graph(bool multigraph = false) : is_multigraph_(multigraph) {} explicit Graph(bool multigraph = false) : is_multigraph_(multigraph) {}
/*! \brief default copy constructor */ /*! \brief default copy constructor */
Graph(const Graph& other) = default; Graph(const Graph& other) = default;
...@@ -347,4 +352,4 @@ struct Subgraph { ...@@ -347,4 +352,4 @@ struct Subgraph {
} // namespace dgl } // namespace dgl
#endif // DGL_DGLGRAPH_H_ #endif // DGL_GRAPH_H_
// Graph operations /*!
* Copyright (c) 2018 by Contributors
* \file dgl/graph_op.h
* \brief Operations on graph index.
*/
#ifndef DGL_GRAPH_OP_H_ #ifndef DGL_GRAPH_OP_H_
#define DGL_GRAPH_OP_H_ #define DGL_GRAPH_OP_H_
#include <vector>
#include "graph.h" #include "graph.h"
namespace dgl { namespace dgl {
......
# C API and runtime
Borrowed and adapted from TVM project.
...@@ -7,8 +7,8 @@ ...@@ -7,8 +7,8 @@
* used by compiled tvm operators, usually user do not need to use these * used by compiled tvm operators, usually user do not need to use these
* function directly. * function directly.
*/ */
#ifndef TVM_RUNTIME_C_BACKEND_API_H_ #ifndef DGL_RUNTIME_C_BACKEND_API_H_
#define TVM_RUNTIME_C_BACKEND_API_H_ #define DGL_RUNTIME_C_BACKEND_API_H_
#include "c_runtime_api.h" #include "c_runtime_api.h"
...@@ -136,4 +136,4 @@ TVM_DLL int TVMBackendRunOnce(void** handle, ...@@ -136,4 +136,4 @@ TVM_DLL int TVMBackendRunOnce(void** handle,
#ifdef __cplusplus #ifdef __cplusplus
} // TVM_EXTERN_C } // TVM_EXTERN_C
#endif #endif
#endif // TVM_RUNTIME_C_BACKEND_API_H_ #endif // DGL_RUNTIME_C_BACKEND_API_H_
/*! /*!
* Copyright (c) 2016 by Contributors * Copyright (c) 2016 by Contributors
* \file tvm/runtime/c_runtime_api.h * \file dgl/runtime/c_runtime_api.h
* \brief TVM runtime library. * \brief TVM runtime library.
* *
* The philosophy of TVM project is to customize the compilation * The philosophy of TVM project is to customize the compilation
...@@ -15,8 +15,8 @@ ...@@ -15,8 +15,8 @@
* - Use TVMFuncListGlobalNames to get global function name * - Use TVMFuncListGlobalNames to get global function name
* - Use TVMFuncCall to call these functions. * - Use TVMFuncCall to call these functions.
*/ */
#ifndef TVM_RUNTIME_C_RUNTIME_API_H_ #ifndef DGL_RUNTIME_C_RUNTIME_API_H_
#define TVM_RUNTIME_C_RUNTIME_API_H_ #define DGL_RUNTIME_C_RUNTIME_API_H_
// Macros to do weak linking // Macros to do weak linking
#ifdef _MSC_VER #ifdef _MSC_VER
...@@ -530,4 +530,4 @@ TVM_DLL int TVMStreamStreamSynchronize(int device_type, ...@@ -530,4 +530,4 @@ TVM_DLL int TVMStreamStreamSynchronize(int device_type,
#ifdef __cplusplus #ifdef __cplusplus
} // TVM_EXTERN_C } // TVM_EXTERN_C
#endif #endif
#endif // TVM_RUNTIME_C_RUNTIME_API_H_ #endif // DGL_RUNTIME_C_RUNTIME_API_H_
/*! /*!
* Copyright (c) 2016 by Contributors * Copyright (c) 2016 by Contributors
* \file tvm/runtime/device_api.h * \file dgl/runtime/device_api.h
* \brief Abstract device memory management API * \brief Abstract device memory management API
*/ */
#ifndef TVM_RUNTIME_DEVICE_API_H_ #ifndef DGL_RUNTIME_DEVICE_API_H_
#define TVM_RUNTIME_DEVICE_API_H_ #define DGL_RUNTIME_DEVICE_API_H_
#include <string> #include <string>
#include "packed_func.h" #include "packed_func.h"
...@@ -180,4 +180,4 @@ class DeviceAPI { ...@@ -180,4 +180,4 @@ class DeviceAPI {
constexpr int kRPCSessMask = 128; constexpr int kRPCSessMask = 128;
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
#endif // TVM_RUNTIME_DEVICE_API_H_ #endif // DGL_RUNTIME_DEVICE_API_H_
/*! /*!
* Copyright (c) 2017 by Contributors * Copyright (c) 2017 by Contributors
* \file tvm/runtime/module.h * \file dgl/runtime/module.h
* \brief Runtime container of the functions generated by TVM, * \brief Runtime container of the functions generated by TVM,
* This is used to support dynamically link, load and save * This is used to support dynamically link, load and save
* functions from different convention under unified API. * functions from different convention under unified API.
*/ */
#ifndef TVM_RUNTIME_MODULE_H_ #ifndef DGL_RUNTIME_MODULE_H_
#define TVM_RUNTIME_MODULE_H_ #define DGL_RUNTIME_MODULE_H_
#include <dmlc/io.h> #include <dmlc/io.h>
#include <memory> #include <memory>
...@@ -174,4 +174,4 @@ inline const ModuleNode* Module::operator->() const { ...@@ -174,4 +174,4 @@ inline const ModuleNode* Module::operator->() const {
} // namespace tvm } // namespace tvm
#include "packed_func.h" #include "packed_func.h"
#endif // TVM_RUNTIME_MODULE_H_ #endif // DGL_RUNTIME_MODULE_H_
/*! /*!
* Copyright (c) 2017 by Contributors * Copyright (c) 2017 by Contributors
* \file tvm/runtime/ndarray.h * \file dgl/runtime/ndarray.h
* \brief Abstract device memory management API * \brief Abstract device memory management API
*/ */
#ifndef TVM_RUNTIME_NDARRAY_H_ #ifndef DGL_RUNTIME_NDARRAY_H_
#define TVM_RUNTIME_NDARRAY_H_ #define DGL_RUNTIME_NDARRAY_H_
#include <atomic> #include <atomic>
#include <vector> #include <vector>
...@@ -422,4 +422,4 @@ inline bool NDArray::Load(dmlc::Stream* strm) { ...@@ -422,4 +422,4 @@ inline bool NDArray::Load(dmlc::Stream* strm) {
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
#endif // TVM_RUNTIME_NDARRAY_H_ #endif // DGL_RUNTIME_NDARRAY_H_
/*! /*!
* Copyright (c) 2017 by Contributors * Copyright (c) 2017 by Contributors
* \file tvm/runtime/packed_func.h * \file dgl/runtime/packed_func.h
* \brief Type-erased function used across TVM API. * \brief Type-erased function used across TVM API.
*/ */
#ifndef TVM_RUNTIME_PACKED_FUNC_H_ #ifndef DGL_RUNTIME_PACKED_FUNC_H_
#define TVM_RUNTIME_PACKED_FUNC_H_ #define DGL_RUNTIME_PACKED_FUNC_H_
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <functional> #include <functional>
...@@ -1212,4 +1212,4 @@ inline PackedFunc Module::GetFunction(const std::string& name, bool query_import ...@@ -1212,4 +1212,4 @@ inline PackedFunc Module::GetFunction(const std::string& name, bool query_import
} }
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
#endif // TVM_RUNTIME_PACKED_FUNC_H_ #endif // DGL_RUNTIME_PACKED_FUNC_H_
/*! /*!
* Copyright (c) 2017 by Contributors * Copyright (c) 2017 by Contributors
* \file tvm/runtime/registry.h * \file dgl/runtime/registry.h
* \brief This file defines the TVM global function registry. * \brief This file defines the TVM global function registry.
* *
* The registered functions will be made available to front-end * The registered functions will be made available to front-end
...@@ -22,8 +22,8 @@ ...@@ -22,8 +22,8 @@
* }); * });
* \endcode * \endcode
*/ */
#ifndef TVM_RUNTIME_REGISTRY_H_ #ifndef DGL_RUNTIME_REGISTRY_H_
#define TVM_RUNTIME_REGISTRY_H_ #define DGL_RUNTIME_REGISTRY_H_
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -141,4 +141,4 @@ class Registry { ...@@ -141,4 +141,4 @@ class Registry {
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
#endif // TVM_RUNTIME_REGISTRY_H_ #endif // DGL_RUNTIME_REGISTRY_H_
/*! /*!
* Copyright (c) 2017 by Contributors * Copyright (c) 2017 by Contributors
* \file tvm/runtime/serializer.h * \file dgl/runtime/serializer.h
* \brief Serializer extension to support TVM data types * \brief Serializer extension to support TVM data types
* Include this file to enable serialization of DLDataType, DLContext * Include this file to enable serialization of DLDataType, DLContext
*/ */
#ifndef TVM_RUNTIME_SERIALIZER_H_ #ifndef DGL_RUNTIME_SERIALIZER_H_
#define TVM_RUNTIME_SERIALIZER_H_ #define DGL_RUNTIME_SERIALIZER_H_
#include <dmlc/io.h> #include <dmlc/io.h>
#include <dmlc/serializer.h> #include <dmlc/serializer.h>
...@@ -48,4 +48,4 @@ struct Handler<DLContext> { ...@@ -48,4 +48,4 @@ struct Handler<DLContext> {
} // namespace serializer } // namespace serializer
} // namespace dmlc } // namespace dmlc
#endif // TVM_RUNTIME_SERIALIZER_H_ #endif // DGL_RUNTIME_SERIALIZER_H_
/*! /*!
* Copyright (c) 2018 by Contributors * Copyright (c) 2018 by Contributors
* \file tvm/runtime/threading_backend.h * \file dgl/runtime/threading_backend.h
* \brief Utilities for manipulating thread pool threads. * \brief Utilities for manipulating thread pool threads.
*/ */
#ifndef TVM_RUNTIME_THREADING_BACKEND_H_ #ifndef DGL_RUNTIME_THREADING_BACKEND_H_
#define TVM_RUNTIME_THREADING_BACKEND_H_ #define DGL_RUNTIME_THREADING_BACKEND_H_
#include <functional> #include <functional>
#include <memory> #include <memory>
...@@ -82,4 +82,4 @@ int MaxConcurrency(); ...@@ -82,4 +82,4 @@ int MaxConcurrency();
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
#endif // TVM_RUNTIME_THREADING_BACKEND_H_ #endif // DGL_RUNTIME_THREADING_BACKEND_H_
/*! /*!
* Copyright (c) 2017 by Contributors * Copyright (c) 2017 by Contributors
* \file tvm/runtime/util.h * \file dgl/runtime/util.h
* \brief Useful runtime util. * \brief Useful runtime util.
*/ */
#ifndef TVM_RUNTIME_UTIL_H_ #ifndef DGL_RUNTIME_UTIL_H_
#define TVM_RUNTIME_UTIL_H_ #define DGL_RUNTIME_UTIL_H_
#include "c_runtime_api.h" #include "c_runtime_api.h"
...@@ -50,4 +50,4 @@ enum TVMStructFieldKind : int { ...@@ -50,4 +50,4 @@ enum TVMStructFieldKind : int {
} // namespace intrinsic } // namespace intrinsic
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
#endif // TVM_RUNTIME_UTIL_H_ #endif // DGL_RUNTIME_UTIL_H_
// DGL Scheduler interface /*!
* Copyright (c) 2018 by Contributors
* \file dgl/scheduler.h
* \brief Operations on graph index.
*/
#ifndef DGL_SCHEDULER_H_ #ifndef DGL_SCHEDULER_H_
#define DGL_SCHEDULER_H_ #define DGL_SCHEDULER_H_
#include "runtime/ndarray.h"
#include <vector> #include <vector>
#include "runtime/ndarray.h"
namespace dgl { namespace dgl {
......
from . import backend from . import backend
from . import data from . import data
from . import function from . import function
from . import generator
from . import nn from . import nn
from ._ffi.runtime_ctypes import TypeCode from ._ffi.runtime_ctypes import TypeCode
...@@ -10,6 +9,5 @@ from ._ffi.base import DGLError, __version__ ...@@ -10,6 +9,5 @@ from ._ffi.base import DGLError, __version__
from .base import ALL from .base import ALL
from .batched_graph import * from .batched_graph import *
from .generator import * from .graph import DGLGraph
from .graph import DGLGraph, __MSG__, __REPR__
from .subgraph import DGLSubGraph from .subgraph import DGLSubGraph
...@@ -46,8 +46,14 @@ def from_numpy(np_data): ...@@ -46,8 +46,14 @@ def from_numpy(np_data):
def pack(tensors): def pack(tensors):
return F.concat(*tensors, dim=0) return F.concat(*tensors, dim=0)
def unpack(x, indices_or_sections=1): def unpack(x, split_sizes_or_sections=1):
return th.split(x, indices_or_sections) if isinstance(split_sizes_or_sections, list):
np_arr = x.asnumpy()
indices = np.cumsum(split_sizes_or_sections)
res = np.split(np_arr, indices[:-1])
return [tensor(arr, dtype=x.dtype) for arr in res]
else:
return F.split(x, split_sizes_or_sections)
# TODO this doesn't exist for symbol. # TODO this doesn't exist for symbol.
def shape(x): def shape(x):
...@@ -66,6 +72,9 @@ def unique(x): ...@@ -66,6 +72,9 @@ def unique(x):
return mx.nd.array(tmp, ctx=x.context, dtype=x.dtype) return mx.nd.array(tmp, ctx=x.context, dtype=x.dtype)
def gather_row(data, row_index): def gather_row(data, row_index):
if isinstance(row_index, F.NDArray):
return F.take(data, row_index)
else:
return data[row_index,] return data[row_index,]
scatter_row = mx.nd.contrib.index_copy scatter_row = mx.nd.contrib.index_copy
...@@ -114,6 +123,27 @@ def get_context(x): ...@@ -114,6 +123,27 @@ def get_context(x):
def _typestr(arr_dtype): def _typestr(arr_dtype):
return arr_dtype return arr_dtype
def get_tvmtype(arr):
arr_dtype = arr.dtype
if arr_dtype == np.float16:
return TVMType('float16')
elif arr_dtype == np.float32:
return TVMType('float32')
elif arr_dtype == np.float64:
return TVMType('float64')
elif arr_dtype == np.int16:
return TVMType('int16')
elif arr_dtype == np.int32:
return TVMType('int32')
elif arr_dtype == np.int64:
return TVMType('int64')
elif arr_dtype == np.int8:
return TVMType('int8')
elif arr_dtype == np.uint8:
return TVMType('uint8')
else:
raise RuntimeError('Unsupported data type:', arr_dtype)
def zerocopy_to_dlpack(arr): def zerocopy_to_dlpack(arr):
"""Return a dlpack compatible array using zero copy.""" """Return a dlpack compatible array using zero copy."""
return arr.to_dlpack_for_read() return arr.to_dlpack_for_read()
......
...@@ -93,23 +93,24 @@ def get_context(arr): ...@@ -93,23 +93,24 @@ def get_context(arr):
return TVMContext( return TVMContext(
TVMContext.STR2MASK[arr.device.type], arr.device.index) TVMContext.STR2MASK[arr.device.type], arr.device.index)
def _typestr(arr_dtype): def get_tvmtype(arr):
arr_dtype = arr.dtype
if arr_dtype in (th.float16, th.half): if arr_dtype in (th.float16, th.half):
return 'float16' return TVMType('float16')
elif arr_dtype in (th.float32, th.float): elif arr_dtype in (th.float32, th.float):
return 'float32' return TVMType('float32')
elif arr_dtype in (th.float64, th.double): elif arr_dtype in (th.float64, th.double):
return 'float64' return TVMType('float64')
elif arr_dtype in (th.int16, th.short): elif arr_dtype in (th.int16, th.short):
return 'int16' return TVMType('int16')
elif arr_dtype in (th.int32, th.int): elif arr_dtype in (th.int32, th.int):
return 'int32' return TVMType('int32')
elif arr_dtype in (th.int64, th.long): elif arr_dtype in (th.int64, th.long):
return 'int64' return TVMType('int64')
elif arr_dtype == th.int8: elif arr_dtype == th.int8:
return 'int8' return TVMType('int8')
elif arr_dtype == th.uint8: elif arr_dtype == th.uint8:
return 'uint8' return TVMType('uint8')
else: else:
raise RuntimeError('Unsupported data type:', arr_dtype) raise RuntimeError('Unsupported data type:', arr_dtype)
...@@ -130,20 +131,6 @@ def zerocopy_from_numpy(np_data): ...@@ -130,20 +131,6 @@ def zerocopy_from_numpy(np_data):
"""Return a tensor that shares the numpy data.""" """Return a tensor that shares the numpy data."""
return th.from_numpy(np_data) return th.from_numpy(np_data)
'''
data = arr_data
assert data.is_contiguous()
arr = TVMArray()
shape = c_array(tvm_shape_index_t, tuple(data.shape))
arr.data = ctypes.cast(data.data_ptr(), ctypes.c_void_p)
arr.shape = shape
arr.strides = None
arr.dtype = TVMType(_typestr(data.dtype))
arr.ndim = len(shape)
arr.ctx = get_context(data)
return arr
'''
def nonzero_1d(arr): def nonzero_1d(arr):
"""Return a 1D tensor with nonzero element indices in a 1D vector""" """Return a 1D tensor with nonzero element indices in a 1D vector"""
assert arr.dim() == 1 assert arr.dim() == 1
......
"""Module for base types and utilities.""" """Module for base types and utilities."""
from __future__ import absolute_import
import warnings
from ._ffi.base import DGLError
# A special argument for selecting all nodes/edges. # A special argument for selecting all nodes/edges.
ALL = "__ALL__" ALL = "__ALL__"
...@@ -6,5 +11,4 @@ ALL = "__ALL__" ...@@ -6,5 +11,4 @@ ALL = "__ALL__"
def is_all(arg): def is_all(arg):
return isinstance(arg, str) and arg == ALL return isinstance(arg, str) and arg == ALL
__MSG__ = "__MSG__" dgl_warning = warnings.warn
__REPR__ = "__REPR__"
...@@ -206,10 +206,33 @@ def get_gnp_generator(args): ...@@ -206,10 +206,33 @@ def get_gnp_generator(args):
return nx.fast_gnp_random_graph(n, p, seed, True) return nx.fast_gnp_random_graph(n, p, seed, True)
return _gen return _gen
class ScipyGraph(object):
"""A simple graph object that uses scipy matrix."""
def __init__(self, mat):
self._mat = mat
def get_graph(self):
return self._mat
def number_of_nodes(self):
return self._mat.shape[0]
def number_of_edges(self):
return self._mat.getnnz()
def get_scipy_generator(args):
n = args.syn_gnp_n
p = (2 * np.log(n) / n) if args.syn_gnp_p == 0. else args.syn_gnp_p
def _gen(seed):
return ScipyGraph(sp.random(n, n, p, format='coo'))
return _gen
def load_synthetic(args): def load_synthetic(args):
ty = args.syn_type ty = args.syn_type
if ty == 'gnp': if ty == 'gnp':
gen = get_gnp_generator(args) gen = get_gnp_generator(args)
elif ty == 'scipy':
gen = get_scipy_generator(args)
else: else:
raise ValueError('Unknown graph generator type: {}'.format(ty)) raise ValueError('Unknown graph generator type: {}'.format(ty))
return GCNSyntheticDataset( return GCNSyntheticDataset(
......
"""Dataset utilities.""" """Dataset utilities."""
from __future__ import absolute_import
import os import os, sys
import hashlib import hashlib
import warnings import warnings
import zipfile import zipfile
...@@ -125,17 +126,22 @@ def extract_archive(file, target_dir): ...@@ -125,17 +126,22 @@ def extract_archive(file, target_dir):
target_dir : str target_dir : str
Target directory of the archive to be uncompressed Target directory of the archive to be uncompressed
""" """
if os.path.exists(target_dir):
return
if file.endswith('.gz') or file.endswith('.tar') or file.endswith('.tgz'): if file.endswith('.gz') or file.endswith('.tar') or file.endswith('.tgz'):
archive = tarfile.open(file, 'r') archive = tarfile.open(file, 'r')
elif file.endswith('.zip'): elif file.endswith('.zip'):
archive = zipfile.ZipFile(file, 'r') archive = zipfile.ZipFile(file, 'r')
else: else:
raise Exception('Unrecognized file type: ' + file) raise Exception('Unrecognized file type: ' + file)
print('Extracting file to {}'.format(target_dir))
archive.extractall(path=target_dir) archive.extractall(path=target_dir)
archive.close() archive.close()
def get_download_dir(): def get_download_dir():
dirname = '_download' """Get the absolute path to the download directory."""
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
dirname = os.path.join(curr_path, '../../../_download')
if not os.path.exists(dirname): if not os.path.exists(dirname):
os.makedirs(dirname) os.makedirs(dirname)
return dirname return dirname
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment