You need to sign in or sign up before continuing.
Unverified Commit 9c135fd5 authored by VoVAllen's avatar VoVAllen Committed by GitHub
Browse files

Merge pull request #4 from jermainewang/master

Sync with latest commit
parents 9d3f299d 00add9f2
// DGL Graph interface
#ifndef DGL_DGLGRAPH_H_
#define DGL_DGLGRAPH_H_
/*!
* Copyright (c) 2018 by Contributors
* \file dgl/graph.h
* \brief DGL graph index class.
*/
#ifndef DGL_GRAPH_H_
#define DGL_GRAPH_H_
#include <stdint.h>
#include <vector>
#include <cstdint>
#include "runtime/ndarray.h"
namespace dgl {
......@@ -17,7 +22,7 @@ class GraphOp;
struct Subgraph;
/*!
* \brief Base dgl graph class.
* \brief Base dgl graph index class.
*
* DGL's graph is directed. Vertices are integers enumerated from zero. Edges
* are uniquely identified by the two endpoints. Multi-edge is currently not
......@@ -41,7 +46,7 @@ class Graph {
} EdgeArray;
/*! \brief default constructor */
Graph(bool multigraph = false) : is_multigraph_(multigraph) {}
explicit Graph(bool multigraph = false) : is_multigraph_(multigraph) {}
/*! \brief default copy constructor */
Graph(const Graph& other) = default;
......@@ -192,7 +197,7 @@ class Graph {
* \return the id arrays of the two endpoints of the edges.
*/
EdgeArray InEdges(IdArray vids) const;
/*!
* \brief Get the out edges of the vertex.
* \note The returned src id array is filled with vid.
......@@ -347,4 +352,4 @@ struct Subgraph {
} // namespace dgl
#endif // DGL_DGLGRAPH_H_
#endif // DGL_GRAPH_H_
// Graph operations
/*!
* Copyright (c) 2018 by Contributors
* \file dgl/graph_op.h
* \brief Operations on graph index.
*/
#ifndef DGL_GRAPH_OP_H_
#define DGL_GRAPH_OP_H_
#include <vector>
#include "graph.h"
namespace dgl {
......
# C API and runtime
Borrowed and adapted from TVM project.
......@@ -7,8 +7,8 @@
* used by compiled tvm operators, usually user do not need to use these
* function directly.
*/
#ifndef TVM_RUNTIME_C_BACKEND_API_H_
#define TVM_RUNTIME_C_BACKEND_API_H_
#ifndef DGL_RUNTIME_C_BACKEND_API_H_
#define DGL_RUNTIME_C_BACKEND_API_H_
#include "c_runtime_api.h"
......@@ -136,4 +136,4 @@ TVM_DLL int TVMBackendRunOnce(void** handle,
#ifdef __cplusplus
} // TVM_EXTERN_C
#endif
#endif // TVM_RUNTIME_C_BACKEND_API_H_
#endif // DGL_RUNTIME_C_BACKEND_API_H_
/*!
* Copyright (c) 2016 by Contributors
* \file tvm/runtime/c_runtime_api.h
* \file dgl/runtime/c_runtime_api.h
* \brief TVM runtime library.
*
* The philosophy of TVM project is to customize the compilation
......@@ -15,8 +15,8 @@
* - Use TVMFuncListGlobalNames to get global function name
* - Use TVMFuncCall to call these functions.
*/
#ifndef TVM_RUNTIME_C_RUNTIME_API_H_
#define TVM_RUNTIME_C_RUNTIME_API_H_
#ifndef DGL_RUNTIME_C_RUNTIME_API_H_
#define DGL_RUNTIME_C_RUNTIME_API_H_
// Macros to do weak linking
#ifdef _MSC_VER
......@@ -530,4 +530,4 @@ TVM_DLL int TVMStreamStreamSynchronize(int device_type,
#ifdef __cplusplus
} // TVM_EXTERN_C
#endif
#endif // TVM_RUNTIME_C_RUNTIME_API_H_
#endif // DGL_RUNTIME_C_RUNTIME_API_H_
/*!
* Copyright (c) 2016 by Contributors
* \file tvm/runtime/device_api.h
* \file dgl/runtime/device_api.h
* \brief Abstract device memory management API
*/
#ifndef TVM_RUNTIME_DEVICE_API_H_
#define TVM_RUNTIME_DEVICE_API_H_
#ifndef DGL_RUNTIME_DEVICE_API_H_
#define DGL_RUNTIME_DEVICE_API_H_
#include <string>
#include "packed_func.h"
......@@ -180,4 +180,4 @@ class DeviceAPI {
constexpr int kRPCSessMask = 128;
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_DEVICE_API_H_
#endif // DGL_RUNTIME_DEVICE_API_H_
/*!
* Copyright (c) 2017 by Contributors
* \file tvm/runtime/module.h
* \file dgl/runtime/module.h
* \brief Runtime container of the functions generated by TVM,
* This is used to support dynamically link, load and save
* functions from different convention under unified API.
*/
#ifndef TVM_RUNTIME_MODULE_H_
#define TVM_RUNTIME_MODULE_H_
#ifndef DGL_RUNTIME_MODULE_H_
#define DGL_RUNTIME_MODULE_H_
#include <dmlc/io.h>
#include <memory>
......@@ -174,4 +174,4 @@ inline const ModuleNode* Module::operator->() const {
} // namespace tvm
#include "packed_func.h"
#endif // TVM_RUNTIME_MODULE_H_
#endif // DGL_RUNTIME_MODULE_H_
/*!
* Copyright (c) 2017 by Contributors
* \file tvm/runtime/ndarray.h
* \file dgl/runtime/ndarray.h
* \brief Abstract device memory management API
*/
#ifndef TVM_RUNTIME_NDARRAY_H_
#define TVM_RUNTIME_NDARRAY_H_
#ifndef DGL_RUNTIME_NDARRAY_H_
#define DGL_RUNTIME_NDARRAY_H_
#include <atomic>
#include <vector>
......@@ -422,4 +422,4 @@ inline bool NDArray::Load(dmlc::Stream* strm) {
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_NDARRAY_H_
#endif // DGL_RUNTIME_NDARRAY_H_
/*!
* Copyright (c) 2017 by Contributors
* \file tvm/runtime/packed_func.h
* \file dgl/runtime/packed_func.h
* \brief Type-erased function used across TVM API.
*/
#ifndef TVM_RUNTIME_PACKED_FUNC_H_
#define TVM_RUNTIME_PACKED_FUNC_H_
#ifndef DGL_RUNTIME_PACKED_FUNC_H_
#define DGL_RUNTIME_PACKED_FUNC_H_
#include <dmlc/logging.h>
#include <functional>
......@@ -1212,4 +1212,4 @@ inline PackedFunc Module::GetFunction(const std::string& name, bool query_import
}
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_PACKED_FUNC_H_
#endif // DGL_RUNTIME_PACKED_FUNC_H_
/*!
* Copyright (c) 2017 by Contributors
* \file tvm/runtime/registry.h
* \file dgl/runtime/registry.h
* \brief This file defines the TVM global function registry.
*
* The registered functions will be made available to front-end
......@@ -22,8 +22,8 @@
* });
* \endcode
*/
#ifndef TVM_RUNTIME_REGISTRY_H_
#define TVM_RUNTIME_REGISTRY_H_
#ifndef DGL_RUNTIME_REGISTRY_H_
#define DGL_RUNTIME_REGISTRY_H_
#include <string>
#include <vector>
......@@ -141,4 +141,4 @@ class Registry {
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_REGISTRY_H_
#endif // DGL_RUNTIME_REGISTRY_H_
/*!
* Copyright (c) 2017 by Contributors
* \file tvm/runtime/serializer.h
* \file dgl/runtime/serializer.h
* \brief Serializer extension to support TVM data types
* Include this file to enable serialization of DLDataType, DLContext
*/
#ifndef TVM_RUNTIME_SERIALIZER_H_
#define TVM_RUNTIME_SERIALIZER_H_
#ifndef DGL_RUNTIME_SERIALIZER_H_
#define DGL_RUNTIME_SERIALIZER_H_
#include <dmlc/io.h>
#include <dmlc/serializer.h>
......@@ -48,4 +48,4 @@ struct Handler<DLContext> {
} // namespace serializer
} // namespace dmlc
#endif // TVM_RUNTIME_SERIALIZER_H_
#endif // DGL_RUNTIME_SERIALIZER_H_
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/runtime/threading_backend.h
* \file dgl/runtime/threading_backend.h
* \brief Utilities for manipulating thread pool threads.
*/
#ifndef TVM_RUNTIME_THREADING_BACKEND_H_
#define TVM_RUNTIME_THREADING_BACKEND_H_
#ifndef DGL_RUNTIME_THREADING_BACKEND_H_
#define DGL_RUNTIME_THREADING_BACKEND_H_
#include <functional>
#include <memory>
......@@ -82,4 +82,4 @@ int MaxConcurrency();
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_THREADING_BACKEND_H_
#endif // DGL_RUNTIME_THREADING_BACKEND_H_
/*!
* Copyright (c) 2017 by Contributors
* \file tvm/runtime/util.h
* \file dgl/runtime/util.h
* \brief Useful runtime util.
*/
#ifndef TVM_RUNTIME_UTIL_H_
#define TVM_RUNTIME_UTIL_H_
#ifndef DGL_RUNTIME_UTIL_H_
#define DGL_RUNTIME_UTIL_H_
#include "c_runtime_api.h"
......@@ -50,4 +50,4 @@ enum TVMStructFieldKind : int {
} // namespace intrinsic
} // namespace ir
} // namespace tvm
#endif // TVM_RUNTIME_UTIL_H_
#endif // DGL_RUNTIME_UTIL_H_
// DGL Scheduler interface
/*!
* Copyright (c) 2018 by Contributors
* \file dgl/scheduler.h
* \brief Operations on graph index.
*/
#ifndef DGL_SCHEDULER_H_
#define DGL_SCHEDULER_H_
#include "runtime/ndarray.h"
#include <vector>
#include "runtime/ndarray.h"
namespace dgl {
......@@ -25,8 +29,8 @@ namespace sched {
*/
std::vector<IdArray> DegreeBucketing(const IdArray& vids);
} // namespace sched
} // namespace sched
} // namespace dgl
} // namespace dgl
#endif // DGL_SCHEDULER_H_
#endif // DGL_SCHEDULER_H_
from . import backend
from . import data
from . import function
from . import generator
from . import nn
from ._ffi.runtime_ctypes import TypeCode
......@@ -10,6 +9,5 @@ from ._ffi.base import DGLError, __version__
from .base import ALL
from .batched_graph import *
from .generator import *
from .graph import DGLGraph, __MSG__, __REPR__
from .graph import DGLGraph
from .subgraph import DGLSubGraph
......@@ -46,8 +46,14 @@ def from_numpy(np_data):
def pack(tensors):
return F.concat(*tensors, dim=0)
def unpack(x, indices_or_sections=1):
return th.split(x, indices_or_sections)
def unpack(x, split_sizes_or_sections=1):
if isinstance(split_sizes_or_sections, list):
np_arr = x.asnumpy()
indices = np.cumsum(split_sizes_or_sections)
res = np.split(np_arr, indices[:-1])
return [tensor(arr, dtype=x.dtype) for arr in res]
else:
return F.split(x, split_sizes_or_sections)
# TODO this doesn't exist for symbol.
def shape(x):
......@@ -66,7 +72,10 @@ def unique(x):
return mx.nd.array(tmp, ctx=x.context, dtype=x.dtype)
def gather_row(data, row_index):
return data[row_index,]
if isinstance(row_index, F.NDArray):
return F.take(data, row_index)
else:
return data[row_index,]
scatter_row = mx.nd.contrib.index_copy
......@@ -114,6 +123,27 @@ def get_context(x):
def _typestr(arr_dtype):
return arr_dtype
def get_tvmtype(arr):
arr_dtype = arr.dtype
if arr_dtype == np.float16:
return TVMType('float16')
elif arr_dtype == np.float32:
return TVMType('float32')
elif arr_dtype == np.float64:
return TVMType('float64')
elif arr_dtype == np.int16:
return TVMType('int16')
elif arr_dtype == np.int32:
return TVMType('int32')
elif arr_dtype == np.int64:
return TVMType('int64')
elif arr_dtype == np.int8:
return TVMType('int8')
elif arr_dtype == np.uint8:
return TVMType('uint8')
else:
raise RuntimeError('Unsupported data type:', arr_dtype)
def zerocopy_to_dlpack(arr):
"""Return a dlpack compatible array using zero copy."""
return arr.to_dlpack_for_read()
......
......@@ -93,23 +93,24 @@ def get_context(arr):
return TVMContext(
TVMContext.STR2MASK[arr.device.type], arr.device.index)
def _typestr(arr_dtype):
def get_tvmtype(arr):
arr_dtype = arr.dtype
if arr_dtype in (th.float16, th.half):
return 'float16'
return TVMType('float16')
elif arr_dtype in (th.float32, th.float):
return 'float32'
return TVMType('float32')
elif arr_dtype in (th.float64, th.double):
return 'float64'
return TVMType('float64')
elif arr_dtype in (th.int16, th.short):
return 'int16'
return TVMType('int16')
elif arr_dtype in (th.int32, th.int):
return 'int32'
return TVMType('int32')
elif arr_dtype in (th.int64, th.long):
return 'int64'
return TVMType('int64')
elif arr_dtype == th.int8:
return 'int8'
return TVMType('int8')
elif arr_dtype == th.uint8:
return 'uint8'
return TVMType('uint8')
else:
raise RuntimeError('Unsupported data type:', arr_dtype)
......@@ -130,20 +131,6 @@ def zerocopy_from_numpy(np_data):
"""Return a tensor that shares the numpy data."""
return th.from_numpy(np_data)
'''
data = arr_data
assert data.is_contiguous()
arr = TVMArray()
shape = c_array(tvm_shape_index_t, tuple(data.shape))
arr.data = ctypes.cast(data.data_ptr(), ctypes.c_void_p)
arr.shape = shape
arr.strides = None
arr.dtype = TVMType(_typestr(data.dtype))
arr.ndim = len(shape)
arr.ctx = get_context(data)
return arr
'''
def nonzero_1d(arr):
"""Return a 1D tensor with nonzero element indices in a 1D vector"""
assert arr.dim() == 1
......
"""Module for base types and utilities."""
from __future__ import absolute_import
import warnings
from ._ffi.base import DGLError
# A special argument for selecting all nodes/edges.
ALL = "__ALL__"
......@@ -6,5 +11,4 @@ ALL = "__ALL__"
def is_all(arg):
return isinstance(arg, str) and arg == ALL
__MSG__ = "__MSG__"
__REPR__ = "__REPR__"
dgl_warning = warnings.warn
......@@ -206,10 +206,33 @@ def get_gnp_generator(args):
return nx.fast_gnp_random_graph(n, p, seed, True)
return _gen
class ScipyGraph(object):
"""A simple graph object that uses scipy matrix."""
def __init__(self, mat):
self._mat = mat
def get_graph(self):
return self._mat
def number_of_nodes(self):
return self._mat.shape[0]
def number_of_edges(self):
return self._mat.getnnz()
def get_scipy_generator(args):
n = args.syn_gnp_n
p = (2 * np.log(n) / n) if args.syn_gnp_p == 0. else args.syn_gnp_p
def _gen(seed):
return ScipyGraph(sp.random(n, n, p, format='coo'))
return _gen
def load_synthetic(args):
ty = args.syn_type
if ty == 'gnp':
gen = get_gnp_generator(args)
elif ty == 'scipy':
gen = get_scipy_generator(args)
else:
raise ValueError('Unknown graph generator type: {}'.format(ty))
return GCNSyntheticDataset(
......
"""Dataset utilities."""
from __future__ import absolute_import
import os
import os, sys
import hashlib
import warnings
import zipfile
......@@ -125,17 +126,22 @@ def extract_archive(file, target_dir):
target_dir : str
Target directory of the archive to be uncompressed
"""
if os.path.exists(target_dir):
return
if file.endswith('.gz') or file.endswith('.tar') or file.endswith('.tgz'):
archive = tarfile.open(file, 'r')
elif file.endswith('.zip'):
archive = zipfile.ZipFile(file, 'r')
else:
raise Exception('Unrecognized file type: ' + file)
print('Extracting file to {}'.format(target_dir))
archive.extractall(path=target_dir)
archive.close()
def get_download_dir():
dirname = '_download'
"""Get the absolute path to the download directory."""
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
dirname = os.path.join(curr_path, '../../../_download')
if not os.path.exists(dirname):
os.makedirs(dirname)
return dirname
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment