Unverified Commit a9dabcc7 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Feature] Random Walk for 0.5 (#1209)

* trying to refactor IndexSelect

* partial implementation

* add index select and assign for floats as well

* move to random choice source

* more updates

* fixes

* fixes

* more fixes

* adding python impl

* fixes

* unit test

* lint

* lint x2

* lint x3

* update metapath2vec

* debugging performance

* still debugging for performance

* tuning

* switching to succvec

* redo

* revert non-uniform sampler to use vector

* still not fast

* why does this crash with OpenMP???

* because there was a data race!!!

* add documentations and remove assign op

* lint

* lint x2

* lol what have i done

* lint x3

* fix and disable gpu testing

* bugfix

* generic random walk

* reorg the random walk source code

* Update randomwalks.h

* Update randomwalks_cpu.cc

* rename file

* move internal function to anonymous ns

* reorg & docstrings

* constant restart probability

* docstring fix

* more commit

* random walk with restart, tested

* some fixes

* switch to using NDArray for choice

* massive fix & docstring

* lint x?

* lint x??

* fix

* export symbols

* skip gpu test

* addresses comments

* replaces another VecToIdArray

* add randomwalks.h to include

* replace void * with template
parent 5967d817
......@@ -92,6 +92,8 @@ file(GLOB DGL_SRC
src/*.cc
src/array/*.cc
src/array/cpu/*.cc
src/random/*.cc
src/random/cpu/*.cc
src/kernel/*.cc
src/kernel/cpu/*.cc
src/runtime/*.cc
......
......@@ -89,11 +89,9 @@ def generate_metapath():
hg, author_names, conf_names, paper_names = construct_graph()
for conf_idx in tqdm.trange(hg.number_of_nodes('conf')):
traces = dgl.contrib.sampling.metapath_random_walk(
hg, ['cp', 'pa', 'ap', 'pc'] * walk_length, [conf_idx], num_walks_per_node)
traces = traces[0]
for trace in traces:
tr = np.insert(trace.numpy(), 0, conf_idx)
traces, _ = dgl.sampling.random_walk(
hg, [conf_idx] * num_walks_per_node, metapath=['cp', 'pa', 'ap', 'pc'] * walk_length)
for tr in traces:
outline = ' '.join(
(conf_names if i % 4 == 0 else author_names)[tr[i]]
for i in range(0, len(tr), 2)) # skip paper
......
......@@ -12,18 +12,22 @@
#include <dgl/runtime/ndarray.h>
#include <algorithm>
#include <vector>
#include <tuple>
#include <utility>
namespace dgl {
typedef uint64_t dgl_id_t;
typedef uint64_t dgl_type_t;
typedef dgl::runtime::NDArray IdArray;
typedef dgl::runtime::NDArray DegreeArray;
typedef dgl::runtime::NDArray BoolArray;
typedef dgl::runtime::NDArray IntArray;
typedef dgl::runtime::NDArray FloatArray;
typedef dgl::runtime::NDArray TypeArray;
using dgl::runtime::NDArray;
typedef NDArray IdArray;
typedef NDArray DegreeArray;
typedef NDArray BoolArray;
typedef NDArray IntArray;
typedef NDArray FloatArray;
typedef NDArray TypeArray;
namespace aten {
......@@ -101,9 +105,13 @@ BoolArray LT(IdArray lhs, dgl_id_t rhs);
/*! \brief Stack two arrays (of len L) into a 2*L length array */
IdArray HStack(IdArray arr1, IdArray arr2);
/*! \brief Return the data under the index. In numpy notation, A[I] */
int64_t IndexSelect(IdArray array, int64_t index);
IdArray IndexSelect(IdArray array, IdArray index);
/*!
* \brief Return the data under the index. In numpy notation, A[I]
* \tparam ValueType The type of return value.
*/
template<typename ValueType>
ValueType IndexSelect(NDArray array, uint64_t index);
NDArray IndexSelect(NDArray array, IdArray index);
/*!
* \brief Relabel the given ids to consecutive ids.
......@@ -121,6 +129,68 @@ inline bool IsValidIdArray(const dgl::runtime::NDArray& arr) {
return arr->ndim == 1 && arr->dtype.code == kDLInt;
}
/*!
* \brief Packs a tensor containing padded sequences of variable length.
*
* Similar to \c pack_padded_sequence in PyTorch, except that
*
* 1. The length for each sequence (before padding) is inferred as the number
* of elements before the first occurrence of \c pad_value.
* 2. It does not sort the sequences by length.
* 3. Along with the tensor containing the packed sequence, it returns both the
* length, as well as the offsets to the packed tensor, of each sequence.
*
* \param array The tensor containing sequences padded to the same length
* \param pad_value The padding value
* \return A triplet of packed tensor, the length tensor, and the offset tensor
*
* \note Example: consider the following array with padding value -1:
*
* <code>
* [[1, 2, -1, -1],
* [3, 4, 5, -1]]
* </code>
*
* The packed tensor would be [1, 2, 3, 4, 5].
*
* The length tensor would be [2, 3], i.e. the length of each sequence before padding.
*
* The offset tensor would be [0, 2], i.e. the offset to the packed tensor for each
* sequence (before padding)
*/
template<typename ValueType>
std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, ValueType pad_value);
/*!
* \brief Batch-slice a 1D or 2D array, and then pack the list of sliced arrays
* by concatenation.
*
* If a 2D array is given, then the function is equivalent to:
*
* <code>
* def ConcatSlices(array, lengths):
* slices = [array[i, :l] for i, l in enumerate(lengths)]
* packed = np.concatenate(slices)
* offsets = np.cumsum([0] + lengths[:-1])
* return packed, offsets
* </code>
*
* If a 1D array is given, then the function is equivalent to
*
* <code>
* def ConcatSlices(array, lengths):
* slices = [array[:l] for l in lengths]
* packed = np.concatenate(slices)
* offsets = np.cumsum([0] + lengths[:-1])
* return packed, offsets
* </code>
*
* \param array A 1D or 2D tensor for slicing
* \param lengths A 1D tensor indicating the number of elements to slice
* \return The tensor with packed slices along with the offsets.
*/
std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths);
//////////////////////////////////////////////////////////////////////
// Sparse matrix
//////////////////////////////////////////////////////////////////////
......@@ -285,6 +355,181 @@ IdArray VecToIdArray(const std::vector<T>& vec,
return ret.CopyTo(ctx);
}
///////////////////////// Dispatchers //////////////////////////
/*
* Dispatch according to device:
*
* ATEN_XPU_SWITCH(array->ctx.device_type, XPU, {
* // Now XPU is a placeholder for array->ctx.device_type
* DeviceSpecificImplementation<XPU>(...);
* });
*/
#define ATEN_XPU_SWITCH(val, XPU, ...) do { \
if ((val) == kDLCPU) { \
constexpr auto XPU = kDLCPU; \
{__VA_ARGS__} \
} else { \
LOG(FATAL) << "Device type: " << (val) << " is not supported."; \
} \
} while (0)
/*
* Dispatch according to integral type (either int32 or int64):
*
* ATEN_ID_TYPE_SWITCH(array->dtype, IdType, {
* // Now IdType is the type corresponding to data type in array.
* // For instance, one can do this for a CPU array:
* DType *data = static_cast<DType *>(array->data);
* });
*/
#define ATEN_ID_TYPE_SWITCH(val, IdType, ...) do { \
CHECK_EQ((val).code, kDLInt) << "ID must be integer type"; \
if ((val).bits == 32) { \
typedef int32_t IdType; \
{__VA_ARGS__} \
} else if ((val).bits == 64) { \
typedef int64_t IdType; \
{__VA_ARGS__} \
} else { \
LOG(FATAL) << "ID can only be int32 or int64"; \
} \
} while (0)
/*
* Dispatch according to float type (either float32 or float64):
*
 * ATEN_FLOAT_TYPE_SWITCH(array->dtype, FloatType, "array", {
* // Now FloatType is the type corresponding to data type in array.
* // For instance, one can do this for a CPU array:
* FloatType *data = static_cast<FloatType *>(array->data);
* });
*/
#define ATEN_FLOAT_TYPE_SWITCH(val, FloatType, val_name, ...) do { \
CHECK_EQ((val).code, kDLFloat) \
<< (val_name) << " must be float type"; \
if ((val).bits == 32) { \
typedef float FloatType; \
{__VA_ARGS__} \
} else if ((val).bits == 64) { \
typedef double FloatType; \
{__VA_ARGS__} \
} else { \
LOG(FATAL) << (val_name) << " can only be float32 or float64"; \
} \
} while (0)
/*
* Dispatch according to data type (int32, int64, float32 or float64):
*
 * ATEN_DTYPE_SWITCH(array->dtype, DType, "array", {
* // Now DType is the type corresponding to data type in array.
* // For instance, one can do this for a CPU array:
* DType *data = static_cast<DType *>(array->data);
* });
*/
#define ATEN_DTYPE_SWITCH(val, DType, val_name, ...) do { \
if ((val).code == kDLInt && (val).bits == 32) { \
typedef int32_t DType; \
{__VA_ARGS__} \
} else if ((val).code == kDLInt && (val).bits == 64) { \
typedef int64_t DType; \
{__VA_ARGS__} \
} else if ((val).code == kDLFloat && (val).bits == 32) { \
typedef float DType; \
{__VA_ARGS__} \
} else if ((val).code == kDLFloat && (val).bits == 64) { \
typedef double DType; \
{__VA_ARGS__} \
} else { \
LOG(FATAL) << (val_name) << " can only be int32, int64, float32 or float64"; \
} \
} while (0)
/*
* Dispatch according to integral type of CSR graphs.
* Identical to ATEN_ID_TYPE_SWITCH except for a different error message.
*/
#define ATEN_CSR_DTYPE_SWITCH(val, DType, ...) do { \
if ((val).code == kDLInt && (val).bits == 32) { \
typedef int32_t DType; \
{__VA_ARGS__} \
} else if ((val).code == kDLInt && (val).bits == 64) { \
typedef int64_t DType; \
{__VA_ARGS__} \
} else { \
LOG(FATAL) << "CSR matrix data can only be int32 or int64"; \
} \
} while (0)
// Macro to dispatch according to device context, index type and data type
// TODO(minjie): In our current use cases, data type and id type are the
// same. For example, data array is used to store edge ids.
#define ATEN_CSR_SWITCH(csr, XPU, IdType, DType, ...) \
ATEN_XPU_SWITCH(csr.indptr->ctx.device_type, XPU, { \
ATEN_ID_TYPE_SWITCH(csr.indptr->dtype, IdType, { \
typedef IdType DType; \
{__VA_ARGS__} \
}); \
});
// Macro to dispatch according to device context and index type
#define ATEN_CSR_IDX_SWITCH(csr, XPU, IdType, ...) \
ATEN_XPU_SWITCH(csr.indptr->ctx.device_type, XPU, { \
ATEN_ID_TYPE_SWITCH(csr.indptr->dtype, IdType, { \
{__VA_ARGS__} \
}); \
});
// Macro to dispatch according to device context, index type and data type
// TODO(minjie): In our current use cases, data type and id type are the
// same. For example, data array is used to store edge ids.
#define ATEN_COO_SWITCH(coo, XPU, IdType, DType, ...) \
ATEN_XPU_SWITCH(coo.row->ctx.device_type, XPU, { \
ATEN_ID_TYPE_SWITCH(coo.row->dtype, IdType, { \
typedef IdType DType; \
{__VA_ARGS__} \
}); \
});
// Macro to dispatch according to device context and index type
#define ATEN_COO_IDX_SWITCH(coo, XPU, IdType, ...) \
ATEN_XPU_SWITCH(coo.row->ctx.device_type, XPU, { \
ATEN_ID_TYPE_SWITCH(coo.row->dtype, IdType, { \
{__VA_ARGS__} \
}); \
});
///////////////////////// Array checks //////////////////////////
#define IS_INT32(a) \
((a)->dtype.code == kDLInt && (a)->dtype.bits == 32)
#define IS_INT64(a) \
((a)->dtype.code == kDLInt && (a)->dtype.bits == 64)
#define IS_FLOAT32(a) \
((a)->dtype.code == kDLFloat && (a)->dtype.bits == 32)
#define IS_FLOAT64(a) \
((a)->dtype.code == kDLFloat && (a)->dtype.bits == 64)
#define CHECK_IF(cond, prop, value_name, dtype_name) \
CHECK(cond) << "Expecting " << (prop) << " of " << (value_name) << " to be " << (dtype_name)
#define CHECK_INT32(value, value_name) \
CHECK_IF(IS_INT32(value), "dtype", value_name, "int32")
#define CHECK_INT64(value, value_name) \
CHECK_IF(IS_INT64(value), "dtype", value_name, "int64")
#define CHECK_INT(value, value_name) \
CHECK_IF(IS_INT32(value) || IS_INT64(value), "dtype", value_name, "int32 or int64")
#define CHECK_FLOAT32(value, value_name) \
CHECK_IF(IS_FLOAT32(value), "dtype", value_name, "float32")
#define CHECK_FLOAT64(value, value_name) \
CHECK_IF(IS_FLOAT64(value), "dtype", value_name, "float64")
#define CHECK_FLOAT(value, value_name) \
CHECK_IF(IS_FLOAT32(value) || IS_FLOAT64(value), "dtype", value_name, "float32 or float64")
#define CHECK_NDIM(value, _ndim, value_name) \
CHECK_IF((value)->ndim == (_ndim), "ndim", value_name, _ndim)
} // namespace aten
} // namespace dgl
......
......@@ -57,6 +57,11 @@ class BaseHeteroGraph : public runtime::Object {
return meta_graph_->NumEdges();
}
/*! \return given the edge type, find the source type */
virtual std::pair<dgl_type_t, dgl_type_t> GetEndpointTypes(dgl_type_t etype) const {
return meta_graph_->FindEdge(etype);
}
/*! \return the meta graph */
virtual GraphPtr meta_graph() const {
return meta_graph_;
......
......@@ -7,10 +7,12 @@
#ifndef DGL_RANDOM_H_
#define DGL_RANDOM_H_
#include <dgl/array.h>
#include <dmlc/thread_local.h>
#include <dmlc/logging.h>
#include <random>
#include <thread>
#include <vector>
namespace dgl {
......@@ -95,8 +97,15 @@ class RandomEngine {
return dist(rng_);
}
/*!
 * \brief Pick a random integer between 0 and N-1 (where N is the length of \c prob)
 *        according to the given probabilities
* \param prob Array of unnormalized probability of each element. Must be non-negative.
*/
template<typename IdxType>
IdxType Choice(FloatArray prob);
private:
std::mt19937 rng_;
std::default_random_engine rng_;
};
}; // namespace dgl
......
......@@ -168,6 +168,23 @@ class IterAdapter {
* operator[] only provide const access, use Set to mutate the content.
*
* \tparam T The content ObjectRef type.
*
* \note The element type must subclass \c ObjectRef. Otherwise, the
* compiler would throw an error:
*
* <code>
* error: no type named 'type' in 'struct std::enable_if<false, void>'
* </code>
*
* Example:
*
* <code>
* // List<int> list; // fails
* // List<NDArray> list2; // fails
* List<Value> list; // works
* list.push_back(Value(MakeValue(1))); // works
* list.push_back(Value(MakeValue(NDArray::Empty(shape, dtype, ctx)))); // works
* </code>
*/
template<typename T,
typename = typename std::enable_if<std::is_base_of<ObjectRef, T>::value>::type >
......@@ -366,6 +383,23 @@ class List : public ObjectRef {
*
* \tparam K The key ObjectRef type.
* \tparam V The value ObjectRef type.
*
* \note The element type must subclass \c ObjectRef. Otherwise, the
* compiler would throw an error:
*
* <code>
* error: no type named 'type' in 'struct std::enable_if<false, void>'
* </code>
*
* Example:
*
* <code>
* // Map<std::string, int> map; // fails
* // Map<std::string, NDArray> map2; // fails
* Map<std::string, Value> map; // works
* map.Set("key1", Value(MakeValue(1))); // works
* map.Set("key2", Value(MakeValue(NDArray::Empty(shape, dtype, ctx)))); // works
* </code>
*/
template<typename K,
typename V,
......
......@@ -182,6 +182,18 @@ class NDArray {
* \return The created NDArray view.
*/
DGL_DLL static NDArray FromDLPack(DLManagedTensor* tensor);
/*!
* \brief Create a NDArray by copying from std::vector.
* \tparam T Type of vector data. Determines the dtype of returned array.
*/
template<typename T>
DGL_DLL static NDArray FromVector(
const std::vector<T>& vec, DLContext ctx = DLContext{kDLCPU, 0});
template<typename T>
static NDArray FromVector(const std::vector<T>& vec, DLDataType dtype, DLContext ctx);
/*!
* \brief Function to copy data from one array to another.
* \param from The source array.
......
/*!
 * Copyright (c) 2019 by Contributors
 * \file dgl/sampling/randomwalks.h
 * \brief Random walk functions.
 */
#ifndef DGL_SAMPLING_RANDOMWALKS_H_
#define DGL_SAMPLING_RANDOMWALKS_H_

#include <dgl/base_heterograph.h>
#include <dgl/array.h>

#include <vector>
#include <utility>

namespace dgl {

namespace sampling {

/*!
 * \brief Metapath-based random walk.
 * \param hg The heterograph.
 * \param seeds A 1D array of seed nodes, with the type the source type of the first
 *        edge type in the metapath.
 * \param metapath A 1D array of edge types representing the metapath.
 * \param prob A vector of 1D float arrays, indicating the transition probability of
 *        each edge by edge type.  An empty float array assumes uniform transition.
 * \return A pair of
 *         1. One 2D array of shape (len(seeds), len(metapath) + 1) with node IDs.  The
 *            paths that terminated early are padded with -1.
 *         2. One 1D array of shape (len(metapath) + 1) with node type IDs.
 */
std::pair<IdArray, TypeArray> RandomWalk(
    const HeteroGraphPtr hg,
    const IdArray seeds,
    const TypeArray metapath,
    const std::vector<FloatArray> &prob);

/*!
 * \brief Metapath-based random walk with restart probability.
 * \param hg The heterograph.
 * \param seeds A 1D array of seed nodes, with the type the source type of the first
 *        edge type in the metapath.
 * \param metapath A 1D array of edge types representing the metapath.
 * \param prob A vector of 1D float arrays, indicating the transition probability of
 *        each edge by edge type.  An empty float array assumes uniform transition.
 * \param restart_prob Restart probability
 * \return A pair of
 *         1. One 2D array of shape (len(seeds), len(metapath) + 1) with node IDs.  The
 *            paths that terminated early are padded with -1.
 *         2. One 1D array of shape (len(metapath) + 1) with node type IDs.
 */
std::pair<IdArray, TypeArray> RandomWalkWithRestart(
    const HeteroGraphPtr hg,
    const IdArray seeds,
    const TypeArray metapath,
    const std::vector<FloatArray> &prob,
    double restart_prob);

/*!
 * \brief Metapath-based random walk with stepwise restart probability.  Useful
 * for PinSAGE-like models.
 * \param hg The heterograph.
 * \param seeds A 1D array of seed nodes, with the type the source type of the first
 *        edge type in the metapath.
 * \param metapath A 1D array of edge types representing the metapath.
 * \param prob A vector of 1D float arrays, indicating the transition probability of
 *        each edge by edge type.  An empty float array assumes uniform transition.
 * \param restart_prob Restart probability array which has the same number of elements
 *        as \c metapath, indicating the probability to terminate after transition.
 * \return A pair of
 *         1. One 2D array of shape (len(seeds), len(metapath) + 1) with node IDs.  The
 *            paths that terminated early are padded with -1.
 *         2. One 1D array of shape (len(metapath) + 1) with node type IDs.
 */
std::pair<IdArray, TypeArray> RandomWalkWithStepwiseRestart(
    const HeteroGraphPtr hg,
    const IdArray seeds,
    const TypeArray metapath,
    const std::vector<FloatArray> &prob,
    FloatArray restart_prob);

}  // namespace sampling

}  // namespace dgl

#endif  // DGL_SAMPLING_RANDOMWALKS_H_
......@@ -12,6 +12,7 @@ from . import nn
from . import contrib
from . import container
from . import random
from . import sampling
from ._ffi.runtime_ctypes import TypeCode
from ._ffi.function import register_func, get_global_func, list_global_func_names, extract_ext_funcs
......
......@@ -286,14 +286,16 @@ def _init_api_prefix(module_name, prefix):
for name in list_global_func_names():
if name.startswith("_"):
continue
if not name.startswith(prefix):
name_split = name.rsplit('.', 1)
if name_split[0] != prefix:
continue
fname = name[len(prefix)+1:]
target_module = module
if fname.find(".") != -1:
print('Warning: invalid API name "%s".' % fname)
if len(name_split) == 1:
print('Warning: invalid API name "%s".' % name)
continue
fname = name_split[1]
target_module = module
f = get_global_func(name)
ff = _get_api(f)
ff.__name__ = fname
......
"""Sampler modules."""
from .randomwalks import *
"""Random walk routines
"""
from .._ffi.function import _init_api
from .. import backend as F
from ..base import DGLError
from .. import ndarray as nd
from .. import utils
__all__ = [
'random_walk',
'pack_traces']
def random_walk(g, nodes, *, metapath=None, length=None, prob=None, restart_prob=None):
    """Generate random walk traces from an array of seed nodes (or starting nodes),
    based on the given metapath.

    Each entry in ``nodes`` produces exactly one trace; to obtain multiple traces
    per seed node, repeat the node in ``nodes``.  A trace would

    1. Start from the given seed and set ``t`` to 0.

    2. Pick and traverse along edge type ``metapath[t]`` from the current node.

    3. If no edge can be found, halt.  Otherwise, increment ``t`` and go to step 2.

    The returned traces all have length ``len(metapath) + 1``, where the first node
    is the seed node itself.

    If a random walk stops in advance, the trace is padded with -1 to have the same
    length.

    Parameters
    ----------
    g : DGLGraph
        The graph.
    nodes : Tensor
        Node ID tensor from which the random walk traces starts.
    metapath : list[str or tuple of str], optional
        Metapath, specified as a list of edge types.
        If omitted, we assume that ``g`` only has one node & edge type.  In this
        case, the argument ``length`` specifies the length of random walk traces.
    length : int, optional
        Length of random walks.
        Affects only when ``metapath`` is omitted.
    prob : str, optional
        The name of the edge feature tensor on the graph storing the (unnormalized)
        probabilities associated with each edge for choosing the next node.
        The feature tensor must be non-negative.
        If omitted, we assume the neighbors are picked uniformly.
    restart_prob : float or Tensor, optional
        Probability to stop after each step.
        If a tensor is given, ``restart_prob`` should have the same length as
        ``metapath``.

    Returns
    -------
    traces : Tensor
        A 2-dimensional node ID tensor with shape (num_seeds, len(metapath) + 1).
    types : Tensor
        A 1-dimensional node type ID tensor with shape (len(metapath) + 1).
        The type IDs match the ones in the original graph ``g``.

    Examples
    --------
    The following creates a homogeneous graph:

    >>> g1 = dgl.graph([(0, 1), (1, 2), (1, 3), (2, 0), (3, 0)], 'user', 'follow')

    Normal random walk:

    >>> dgl.sampling.random_walk(g1, [0, 1, 2, 0], length=4)
    (tensor([[0, 1, 2, 0, 1],
             [1, 3, 0, 1, 3],
             [2, 0, 1, 3, 0],
             [0, 1, 2, 0, 1]]), tensor([0, 0, 0, 0, 0]))

    The first tensor indicates the random walk path for each seed node.
    The j-th element in the second tensor indicates the node type ID of the j-th node
    in every path.  In this case, it is returning all 0 (``user``).

    Random walk with restart:

    >>> dgl.sampling.random_walk(g1, [0, 1, 2, 0], length=4, restart_prob=0.5)
    (tensor([[ 0, -1, -1, -1, -1],
             [ 1,  3,  0, -1, -1],
             [ 2, -1, -1, -1, -1],
             [ 0, -1, -1, -1, -1]]), tensor([0, 0, 0, 0, 0]))

    Non-uniform random walk:

    >>> g1.edata['p'] = torch.FloatTensor([1, 0, 1, 1, 1])  # disallow going from 1 to 2
    >>> dgl.sampling.random_walk(g1, [0, 1, 2, 0], length=4, prob='p')
    (tensor([[0, 1, 3, 0, 1],
             [1, 3, 0, 1, 3],
             [2, 0, 1, 3, 0],
             [0, 1, 3, 0, 1]]), tensor([0, 0, 0, 0, 0]))

    Metapath-based random walk:

    >>> g2 = dgl.heterograph({
    ...     ('user', 'follow', 'user'): [(0, 1), (1, 2), (1, 3), (2, 0), (3, 0)],
    ...     ('user', 'view', 'item'): [(0, 0), (0, 1), (1, 1), (2, 2), (3, 2), (3, 1)],
    ...     ('item', 'viewed-by', 'user'): [(0, 0), (1, 0), (1, 1), (2, 2), (2, 3), (1, 3)]})
    >>> dgl.sampling.random_walk(
    ...     g2, [0, 1, 2, 0], metapath=['follow', 'view', 'viewed-by'] * 2)
    (tensor([[0, 1, 1, 1, 2, 2, 3],
             [1, 3, 1, 1, 2, 2, 2],
             [2, 0, 1, 1, 3, 1, 1],
             [0, 1, 1, 0, 1, 1, 3]]), tensor([0, 0, 1, 0, 0, 1, 0]))

    Metapath-based random walk, with restarts only on items (i.e. after traversing a
    "view" relationship):

    >>> dgl.sampling.random_walk(
    ...     g2, [0, 1, 2, 0], metapath=['follow', 'view', 'viewed-by'] * 2,
    ...     restart_prob=torch.FloatTensor([0, 0.5, 0, 0, 0.5, 0]))
    (tensor([[ 0,  1, -1, -1, -1, -1, -1],
             [ 1,  3,  1,  0,  1,  1,  0],
             [ 2,  0,  1,  1,  3,  2,  2],
             [ 0,  1,  1,  3,  0,  0,  0]]), tensor([0, 0, 1, 0, 0, 1, 0]))
    """
    n_etypes = len(g.canonical_etypes)
    n_ntypes = len(g.ntypes)
    if metapath is None:
        # Without a metapath the walk is only well-defined on a homogeneous graph,
        # where every step uses the single edge type (ID 0).
        if n_etypes > 1 or n_ntypes > 1:
            raise DGLError("metapath not specified and the graph is not homogeneous.")
        if length is None:
            raise ValueError("Please specify either the metapath or the random walk length.")
        metapath = [0] * length
    else:
        metapath = [g.get_etype_id(etype) for etype in metapath]
    gidx = g._graph
    nodes = utils.toindex(nodes).todgltensor()
    # The C++ side expects the metapath array on the same device as the seeds.
    metapath = utils.toindex(metapath).todgltensor().copyto(nodes.ctx)
    # Load the probability tensor from the edge frames; an empty array per edge
    # type signals uniform transition to the C API.
    if prob is None:
        p_nd = [nd.array([], ctx=nodes.ctx) for _ in g.canonical_etypes]
    else:
        p_nd = []
        for etype in g.canonical_etypes:
            if prob in g.edges[etype].data:
                prob_nd = F.zerocopy_to_dgl_ndarray(g.edges[etype].data[prob])
                if prob_nd.ctx != nodes.ctx:
                    raise ValueError(
                        'context of seed node array and edges[%s].data[%s] are different' %
                        (etype, prob))
            else:
                prob_nd = nd.array([], ctx=nodes.ctx)
            p_nd.append(prob_nd)
    # Actual random walk: dispatch on the form of restart_prob (absent, per-step
    # tensor, or single float).
    if restart_prob is None:
        traces, types = _CAPI_DGLSamplingRandomWalk(gidx, nodes, metapath, p_nd)
    elif F.is_tensor(restart_prob):
        restart_prob = F.zerocopy_to_dgl_ndarray(restart_prob)
        traces, types = _CAPI_DGLSamplingRandomWalkWithStepwiseRestart(
            gidx, nodes, metapath, p_nd, restart_prob)
    else:
        traces, types = _CAPI_DGLSamplingRandomWalkWithRestart(
            gidx, nodes, metapath, p_nd, restart_prob)
    traces = F.zerocopy_from_dgl_ndarray(traces.data)
    types = F.zerocopy_from_dgl_ndarray(types.data)
    return traces, types
def pack_traces(traces, types):
    """Pack the padded traces returned by ``random_walk()`` into a concatenated array.

    The padding values (-1) are removed, and the length and offset of each trace is
    returned along with the concatenated node ID and node type arrays.

    Parameters
    ----------
    traces : Tensor
        A 2-dimensional node ID tensor.
    types : Tensor
        A 1-dimensional node type ID tensor.

    Returns
    -------
    concat_vids : Tensor
        An array of all node IDs concatenated and padding values removed.
    concat_types : Tensor
        An array of node types corresponding for each node in ``concat_vids``.
        Has the same length as ``concat_vids``.
    lengths : Tensor
        Length of each trace in the original traces tensor.
    offsets : Tensor
        Offset of each trace of the original traces tensor in the new concatenated
        tensor.

    Examples
    --------
    >>> g2 = dgl.heterograph({
    ...     ('user', 'follow', 'user'): [(0, 1), (1, 2), (1, 3), (2, 0), (3, 0)],
    ...     ('user', 'view', 'item'): [(0, 0), (0, 1), (1, 1), (2, 2), (3, 2), (3, 1)],
    ...     ('item', 'viewed-by', 'user'): [(0, 0), (1, 0), (1, 1), (2, 2), (2, 3), (1, 3)]})
    >>> traces, types = dgl.sampling.random_walk(
    ...     g2, [0, 0], metapath=['follow', 'view', 'viewed-by'] * 2,
    ...     restart_prob=torch.FloatTensor([0, 0.5, 0, 0, 0.5, 0]))
    >>> traces, types
    (tensor([[ 0,  1, -1, -1, -1, -1, -1],
             [ 0,  1,  1,  3,  0,  0,  0]]), tensor([0, 0, 1, 0, 0, 1, 0]))
    >>> concat_vids, concat_types, lengths, offsets = dgl.sampling.pack_traces(traces, types)
    >>> concat_vids
    tensor([0, 1, 0, 1, 1, 3, 0, 0, 0])
    >>> concat_types
    tensor([0, 0, 0, 0, 1, 0, 0, 1, 0])
    >>> lengths
    tensor([2, 7])
    >>> offsets
    tensor([0, 2])

    The first tensor ``concat_vids`` is the concatenation of all paths, i.e. flattened
    array of ``traces``, excluding all padding values (-1).

    The second tensor ``concat_types`` stands for the node type IDs of all corresponding
    nodes in the first tensor.

    The third and fourth tensor indicates the length and the offset of each path.  With
    these tensors it is easy to obtain the i-th random walk path with:

    >>> vids = concat_vids.split(lengths.tolist())
    >>> vtypes = concat_types.split(lengths.tolist())
    >>> vids[1], vtypes[1]
    (tensor([0, 1, 1, 3, 0, 0, 0]), tensor([0, 0, 1, 0, 0, 1, 0]))
    """
    # Convert the framework tensors to DGL NDArrays for the C API, then convert
    # the four results back without copying.
    traces = F.zerocopy_to_dgl_ndarray(traces)
    types = F.zerocopy_to_dgl_ndarray(types)
    concat_vids, concat_types, lengths, offsets = _CAPI_DGLSamplingPackTraces(traces, types)
    concat_vids = F.zerocopy_from_dgl_ndarray(concat_vids.data)
    concat_types = F.zerocopy_from_dgl_ndarray(concat_types.data)
    lengths = F.zerocopy_from_dgl_ndarray(lengths.data)
    offsets = F.zerocopy_from_dgl_ndarray(offsets.data)
    return concat_vids, concat_types, lengths, offsets
_init_api('dgl.sampling.randomwalks', __name__)
......@@ -7,7 +7,6 @@
#include "../c_api_common.h"
#include "./array_op.h"
#include "./arith.h"
#include "./common.h"
namespace dgl {
......@@ -201,25 +200,35 @@ IdArray HStack(IdArray lhs, IdArray rhs) {
return ret;
}
IdArray IndexSelect(IdArray array, IdArray index) {
IdArray ret;
NDArray IndexSelect(NDArray array, IdArray index) {
NDArray ret;
// TODO(BarclayII): check if array and index match in context
ATEN_XPU_SWITCH(array->ctx.device_type, XPU, {
ATEN_ID_TYPE_SWITCH(array->dtype, IdType, {
ret = impl::IndexSelect<XPU, IdType>(array, index);
ATEN_DTYPE_SWITCH(array->dtype, DType, "values", {
ATEN_ID_TYPE_SWITCH(index->dtype, IdType, {
ret = impl::IndexSelect<XPU, DType, IdType>(array, index);
});
});
});
return ret;
}
int64_t IndexSelect(IdArray array, int64_t index) {
int64_t ret = 0;
template<typename ValueType>
ValueType IndexSelect(NDArray array, uint64_t index) {
ValueType ret = 0;
ATEN_XPU_SWITCH(array->ctx.device_type, XPU, {
ATEN_ID_TYPE_SWITCH(array->dtype, IdType, {
ret = impl::IndexSelect<XPU, IdType>(array, index);
ATEN_DTYPE_SWITCH(array->dtype, DType, "values", {
ret = impl::IndexSelect<XPU, DType>(array, index);
});
});
return ret;
}
template int32_t IndexSelect<int32_t>(NDArray array, uint64_t index);
template int64_t IndexSelect<int64_t>(NDArray array, uint64_t index);
template uint32_t IndexSelect<uint32_t>(NDArray array, uint64_t index);
template uint64_t IndexSelect<uint64_t>(NDArray array, uint64_t index);
template float IndexSelect<float>(NDArray array, uint64_t index);
template double IndexSelect<double>(NDArray array, uint64_t index);
IdArray Relabel_(const std::vector<IdArray>& arrays) {
IdArray ret;
......@@ -231,6 +240,36 @@ IdArray Relabel_(const std::vector<IdArray>& arrays) {
return ret;
}
template<typename ValueType>
std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, ValueType pad_value) {
std::tuple<NDArray, IdArray, IdArray> ret;
ATEN_XPU_SWITCH(array->ctx.device_type, XPU, {
ATEN_DTYPE_SWITCH(array->dtype, DType, "array", {
ret = impl::Pack<XPU, DType>(array, static_cast<DType>(pad_value));
});
});
return ret;
}
template std::tuple<NDArray, IdArray, IdArray> Pack<int32_t>(NDArray, int32_t);
template std::tuple<NDArray, IdArray, IdArray> Pack<int64_t>(NDArray, int64_t);
template std::tuple<NDArray, IdArray, IdArray> Pack<uint32_t>(NDArray, uint32_t);
template std::tuple<NDArray, IdArray, IdArray> Pack<uint64_t>(NDArray, uint64_t);
template std::tuple<NDArray, IdArray, IdArray> Pack<float>(NDArray, float);
template std::tuple<NDArray, IdArray, IdArray> Pack<double>(NDArray, double);
std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths) {
std::pair<NDArray, IdArray> ret;
ATEN_XPU_SWITCH(array->ctx.device_type, XPU, {
ATEN_DTYPE_SWITCH(array->dtype, DType, "array", {
ATEN_ID_TYPE_SWITCH(lengths->dtype, IdType, {
ret = impl::ConcatSlices<XPU, DType, IdType>(array, lengths);
});
});
});
return ret;
}
///////////////////////// CSR routines //////////////////////////
bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) {
......
......@@ -8,6 +8,8 @@
#include <dgl/array.h>
#include <vector>
#include <tuple>
#include <utility>
namespace dgl {
namespace aten {
......@@ -34,15 +36,21 @@ IdArray BinaryElewise(IdType lhs, IdArray rhs);
template <DLDeviceType XPU, typename IdType>
IdArray HStack(IdArray arr1, IdArray arr2);
template <DLDeviceType XPU, typename IdType>
IdArray IndexSelect(IdArray array, IdArray index);
template <DLDeviceType XPU, typename DType, typename IdType>
NDArray IndexSelect(NDArray array, IdArray index);
template <DLDeviceType XPU, typename IdType>
int64_t IndexSelect(IdArray array, int64_t index);
template <DLDeviceType XPU, typename DType>
DType IndexSelect(NDArray array, uint64_t index);
template <DLDeviceType XPU, typename IdType>
IdArray Relabel_(const std::vector<IdArray>& arrays);
template <DLDeviceType XPU, typename DType>
std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, DType pad_value);
template <DLDeviceType XPU, typename DType, typename IdType>
std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths);
// sparse arrays
template <DLDeviceType XPU, typename IdType>
......
/*!
 * Copyright (c) 2019 by Contributors
 * \file array/common.h
 * \brief Array operator common utilities
 */
#ifndef DGL_ARRAY_COMMON_H_
#define DGL_ARRAY_COMMON_H_
namespace dgl {
namespace aten {
// Dispatch on device type. Only CPU is handled; any other device aborts
// via LOG(FATAL). Inside the body, XPU is a constexpr device constant that
// can be used as a template argument.
// NOTE(review): LOG/CHECK come from the includer (dmlc logging) — this
// header has no includes of its own; confirm include order at call sites.
#define ATEN_XPU_SWITCH(val, XPU, ...) do { \
  if ((val) == kDLCPU) { \
    constexpr auto XPU = kDLCPU; \
    {__VA_ARGS__} \
  } else { \
    LOG(FATAL) << "Device type: " << (val) << " is not supported."; \
  } \
} while (0)
// Dispatch on an integer DLDataType; binds IdType to int32_t or int64_t
// inside the body. Non-integer or other bit widths abort.
#define ATEN_ID_TYPE_SWITCH(val, IdType, ...) do { \
  CHECK_EQ((val).code, kDLInt) << "ID must be integer type"; \
  if ((val).bits == 32) { \
    typedef int32_t IdType; \
    {__VA_ARGS__} \
  } else if ((val).bits == 64) { \
    typedef int64_t IdType; \
    {__VA_ARGS__} \
  } else { \
    LOG(FATAL) << "ID can only be int32 or int64"; \
  } \
} while (0)
// Dispatch on a floating DLDataType; binds FloatType to float or double.
// val_name is used only to label error messages.
#define ATEN_FLOAT_TYPE_SWITCH(val, FloatType, val_name, ...) do { \
  CHECK_EQ((val).code, kDLFloat) \
    << (val_name) << " must be float type"; \
  if ((val).bits == 32) { \
    typedef float FloatType; \
    {__VA_ARGS__} \
  } else if ((val).bits == 64) { \
    typedef double FloatType; \
    {__VA_ARGS__} \
  } else { \
    LOG(FATAL) << (val_name) << " can only be float32 or float64"; \
  } \
} while (0)
// Dispatch on the dtype of a CSR matrix's data array (int32/int64 only).
#define ATEN_CSR_DTYPE_SWITCH(val, DType, ...) do { \
  if ((val).code == kDLInt && (val).bits == 32) { \
    typedef int32_t DType; \
    {__VA_ARGS__} \
  } else if ((val).code == kDLInt && (val).bits == 64) { \
    typedef int64_t DType; \
    {__VA_ARGS__} \
  } else { \
    LOG(FATAL) << "CSR matrix data can only be int32 or int64"; \
  } \
} while (0)
// Macro to dispatch according to device context, index type and data type
// TODO(minjie): In our current use cases, data type and id type are the
// same. For example, data array is used to store edge ids.
// (DType is simply aliased to IdType below for that reason.)
#define ATEN_CSR_SWITCH(csr, XPU, IdType, DType, ...) \
  ATEN_XPU_SWITCH(csr.indptr->ctx.device_type, XPU, { \
    ATEN_ID_TYPE_SWITCH(csr.indptr->dtype, IdType, { \
      typedef IdType DType; \
      {__VA_ARGS__} \
    }); \
  });
// Macro to dispatch according to device context and index type
#define ATEN_CSR_IDX_SWITCH(csr, XPU, IdType, ...) \
  ATEN_XPU_SWITCH(csr.indptr->ctx.device_type, XPU, { \
    ATEN_ID_TYPE_SWITCH(csr.indptr->dtype, IdType, { \
      {__VA_ARGS__} \
    }); \
  });
// Macro to dispatch according to device context, index type and data type
// TODO(minjie): In our current use cases, data type and id type are the
// same. For example, data array is used to store edge ids.
#define ATEN_COO_SWITCH(coo, XPU, IdType, DType, ...) \
  ATEN_XPU_SWITCH(coo.row->ctx.device_type, XPU, { \
    ATEN_ID_TYPE_SWITCH(coo.row->dtype, IdType, { \
      typedef IdType DType; \
      {__VA_ARGS__} \
    }); \
  });
// Macro to dispatch according to device context and index type
#define ATEN_COO_IDX_SWITCH(coo, XPU, IdType, ...) \
  ATEN_XPU_SWITCH(coo.row->ctx.device_type, XPU, { \
    ATEN_ID_TYPE_SWITCH(coo.row->dtype, IdType, { \
      {__VA_ARGS__} \
    }); \
  });
}  // namespace aten
}  // namespace dgl
#endif  // DGL_ARRAY_COMMON_H_
/*!
* Copyright (c) 2019 by Contributors
* \file array/cpu/array_index_select.cc
* \brief Array index select CPU implementation
*/
#include <dgl/array.h>
namespace dgl {
using runtime::NDArray;
namespace aten {
namespace impl {
// Gathers array[index[i]] for every i into a freshly allocated 1-D array
// of the same dtype/context as `array`.
// \param array  1-D source values of element type DType.
// \param index  1-D integer indices of type IdType.
// \return       NDArray of length index->shape[0].
// Aborts (CHECK failure) if any index falls outside [0, array->shape[0]).
template<DLDeviceType XPU, typename DType, typename IdType>
NDArray IndexSelect(NDArray array, IdArray index) {
  const DType* array_data = static_cast<DType*>(array->data);
  const IdType* idx_data = static_cast<IdType*>(index->data);
  const int64_t arr_len = array->shape[0];
  const int64_t len = index->shape[0];
  NDArray ret = NDArray::Empty({len}, array->dtype, array->ctx);
  DType* ret_data = static_cast<DType*>(ret->data);
  for (int64_t i = 0; i < len; ++i) {
    // Guard both ends: IdType is signed, so a negative index would
    // otherwise read before the buffer.
    CHECK_GE(idx_data[i], 0) << "Index out of range.";
    CHECK_LT(idx_data[i], arr_len) << "Index out of range.";
    ret_data[i] = array_data[idx_data[i]];
  }
  return ret;
}
template NDArray IndexSelect<kDLCPU, int32_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLCPU, int32_t, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLCPU, int64_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLCPU, int64_t, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLCPU, float, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLCPU, float, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLCPU, double, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLCPU, double, int64_t>(NDArray, IdArray);
// Reads the single element array[index] and returns it by value.
// \param array  1-D source values of element type DType.
// \param index  position to read; must be < array->shape[0].
// Aborts (CHECK failure) on an out-of-range index instead of reading
// past the buffer (the previous version performed no check at all).
template <DLDeviceType XPU, typename DType>
DType IndexSelect(NDArray array, uint64_t index) {
  CHECK_LT(index, static_cast<uint64_t>(array->shape[0]))
      << "Index out of range.";
  const DType* data = static_cast<DType*>(array->data);
  return data[index];
}
template int32_t IndexSelect<kDLCPU, int32_t>(NDArray array, uint64_t index);
template int64_t IndexSelect<kDLCPU, int64_t>(NDArray array, uint64_t index);
template uint32_t IndexSelect<kDLCPU, uint32_t>(NDArray array, uint64_t index);
template uint64_t IndexSelect<kDLCPU, uint64_t>(NDArray array, uint64_t index);
template float IndexSelect<kDLCPU, float>(NDArray array, uint64_t index);
template double IndexSelect<kDLCPU, double>(NDArray array, uint64_t index);
}; // namespace impl
}; // namespace aten
}; // namespace dgl
......@@ -156,35 +156,6 @@ IdArray Range(IdType low, IdType high, DLContext ctx) {
template IdArray Range<kDLCPU, int32_t>(int32_t, int32_t, DLContext);
template IdArray Range<kDLCPU, int64_t>(int64_t, int64_t, DLContext);
///////////////////////////// IndexSelect /////////////////////////////
// Gathers array[index[i]] for every i into a new IdArray of the same
// dtype/context. Mirrors the NDArray overload; kept consistent with it by
// also rejecting negative indices (IdType is signed int32/int64 here).
template <DLDeviceType XPU, typename IdType>
IdArray IndexSelect(IdArray array, IdArray index) {
  const IdType* array_data = static_cast<IdType*>(array->data);
  const IdType* idx_data = static_cast<IdType*>(index->data);
  const int64_t arr_len = array->shape[0];
  const int64_t len = index->shape[0];
  IdArray ret = NDArray::Empty({len}, array->dtype, array->ctx);
  IdType* ret_data = static_cast<IdType*>(ret->data);
  for (int64_t i = 0; i < len; ++i) {
    CHECK_GE(idx_data[i], 0) << "Index out of range.";
    CHECK_LT(idx_data[i], arr_len) << "Index out of range.";
    ret_data[i] = array_data[idx_data[i]];
  }
  return ret;
}
template IdArray IndexSelect<kDLCPU, int32_t>(IdArray, IdArray);
template IdArray IndexSelect<kDLCPU, int64_t>(IdArray, IdArray);
// Reads the single element at `index` and widens it to int64_t so callers
// get a uniform return type regardless of the array's storage width.
template <DLDeviceType XPU, typename IdType>
int64_t IndexSelect(IdArray array, int64_t index) {
  const IdType* elems = static_cast<IdType*>(array->data);
  const IdType picked = elems[index];
  return picked;
}
template int64_t IndexSelect<kDLCPU, int32_t>(IdArray array, int64_t index);
template int64_t IndexSelect<kDLCPU, int64_t>(IdArray array, int64_t index);
///////////////////////////// Relabel_ /////////////////////////////
template <DLDeviceType XPU, typename IdType>
......
/*!
* Copyright (c) 2019 by Contributors
* \file array/cpu/array_index_select.cc
 * \brief Array pack and slice-concatenation CPU implementation
*/
#include <dgl/array.h>
#include <tuple>
#include <utility>
namespace dgl {
using runtime::NDArray;
namespace aten {
namespace impl {
// Concatenates the first lengths[i] elements of each row i of `array` into
// one flat array.
// \param array    2-D (rows x cols) values, or 1-D in which case the single
//                 row is shared by every slice (stride 0).
// \param lengths  per-row element counts (IdType).
// \return         pair of (concatenated values, per-row start offsets).
// Fixed: the previous version read offsets_data[rows-1]/length_data[rows-1]
// unconditionally, which is out-of-bounds when `lengths` is empty; the
// running prefix sum below handles rows == 0 naturally.
template<DLDeviceType XPU, typename DType, typename IdType>
std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths) {
  const int64_t rows = lengths->shape[0];
  const int64_t cols = (array->ndim == 1 ? array->shape[0] : array->shape[1]);
  // stride 0 makes every row alias the same 1-D source buffer.
  const int64_t stride = (array->ndim == 1 ? 0 : cols);
  const DType *array_data = static_cast<DType *>(array->data);
  const IdType *length_data = static_cast<IdType *>(lengths->data);

  IdArray offsets = NewIdArray(rows, array->ctx, sizeof(IdType) * 8);
  IdType *offsets_data = static_cast<IdType *>(offsets->data);
  // Exclusive prefix sum of lengths; total_length ends as the grand total.
  int64_t total_length = 0;
  for (int64_t i = 0; i < rows; ++i) {
    offsets_data[i] = static_cast<IdType>(total_length);
    total_length += length_data[i];
  }

  NDArray concat = NDArray::Empty({total_length}, array->dtype, array->ctx);
  DType *concat_data = static_cast<DType *>(concat->data);
#pragma omp parallel for
  for (int64_t i = 0; i < rows; ++i) {
    for (int64_t j = 0; j < length_data[i]; ++j)
      concat_data[offsets_data[i] + j] = array_data[i * stride + j];
  }
  return std::make_pair(concat, offsets);
}
template std::pair<NDArray, IdArray> ConcatSlices<kDLCPU, int32_t, int32_t>(NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDLCPU, int64_t, int32_t>(NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDLCPU, float, int32_t>(NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDLCPU, double, int32_t>(NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDLCPU, int32_t, int64_t>(NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDLCPU, int64_t, int64_t>(NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDLCPU, float, int64_t>(NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDLCPU, double, int64_t>(NDArray, IdArray);
// Packs a 2-D padded array: each row is truncated at its first pad_value.
// \return tuple of (flat packed values, per-row valid lengths (int64),
//         per-row start offsets into the flat array).
template<DLDeviceType XPU, typename DType>
std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, DType pad_value) {
  CHECK_NDIM(array, 2, "array");
  const DType *data = static_cast<DType *>(array->data);
  const int64_t nrows = array->shape[0];
  const int64_t ncols = array->shape[1];
  IdArray length = NewIdArray(nrows, array->ctx);
  int64_t *len_out = static_cast<int64_t *>(length->data);
  // Rows are independent, so the scan parallelizes trivially.
#pragma omp parallel for
  for (int64_t r = 0; r < nrows; ++r) {
    const DType *row = data + r * ncols;
    int64_t valid = 0;
    while (valid < ncols && row[valid] != pad_value)
      ++valid;
    len_out[r] = valid;
  }
  auto packed = ConcatSlices<XPU, DType, int64_t>(array, length);
  return std::make_tuple(packed.first, length, packed.second);
}
template std::tuple<NDArray, IdArray, IdArray> Pack<kDLCPU, int32_t>(NDArray, int32_t);
template std::tuple<NDArray, IdArray, IdArray> Pack<kDLCPU, int64_t>(NDArray, int64_t);
template std::tuple<NDArray, IdArray, IdArray> Pack<kDLCPU, float>(NDArray, float);
template std::tuple<NDArray, IdArray, IdArray> Pack<kDLCPU, double>(NDArray, double);
}; // namespace impl
}; // namespace aten
}; // namespace dgl
......@@ -68,16 +68,6 @@ struct PairHash {
}
};
// Copies a std::vector into a freshly allocated 1-D NDArray with the given
// dtype and context. The caller is responsible for `dtype` matching DType.
template <typename DType>
inline runtime::NDArray VecToNDArray(const std::vector<DType>& vec,
                                     DLDataType dtype, DLContext ctx) {
  const int64_t len = vec.size();
  NDArray out = NDArray::Empty({len}, dtype, ctx);
  DType* dst = static_cast<DType*>(out->data);
  for (int64_t i = 0; i < len; ++i)
    dst[i] = vec[i];
  return out;
}
// Returns true iff the CSR matrix has an explicit edge-data array attached.
inline bool CSRHasData(CSRMatrix csr) {
  return csr.data.defined();
}
......@@ -257,7 +247,7 @@ NDArray CSRGetData(CSRMatrix csr, int64_t row, int64_t col) {
}
}
}
return VecToNDArray(ret_vec, csr.data->dtype, csr.data->ctx);
return NDArray::FromVector(ret_vec, csr.data->dtype, csr.data->ctx);
}
template NDArray CSRGetData<kDLCPU, int32_t, int32_t>(CSRMatrix, int64_t, int64_t);
......@@ -301,7 +291,7 @@ NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) {
}
}
return VecToNDArray(ret_vec, csr.data->dtype, csr.data->ctx);
return NDArray::FromVector(ret_vec, csr.data->dtype, csr.data->ctx);
}
template NDArray CSRGetData<kDLCPU, int32_t, int32_t>(CSRMatrix csr, NDArray rows, NDArray cols);
......@@ -381,9 +371,9 @@ std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray rows, NDArray c
}
}
return {VecToIdArray(ret_rows, csr.indptr->dtype.bits, csr.indptr->ctx),
VecToIdArray(ret_cols, csr.indptr->dtype.bits, csr.indptr->ctx),
VecToNDArray(ret_data, csr.data->dtype, csr.data->ctx)};
return {NDArray::FromVector(ret_rows, csr.indptr->dtype, csr.indptr->ctx),
NDArray::FromVector(ret_cols, csr.indptr->dtype, csr.indptr->ctx),
NDArray::FromVector(ret_data, csr.data->dtype, csr.data->ctx)};
}
template std::vector<NDArray> CSRGetDataAndIndices<kDLCPU, int32_t, int32_t>(
......@@ -613,8 +603,8 @@ CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray
DType* ptr = static_cast<DType*>(sub_data_arr->data);
std::copy(sub_data.begin(), sub_data.end(), ptr);
return CSRMatrix{new_nrows, new_ncols,
VecToIdArray(sub_indptr, csr.indptr->dtype.bits, csr.indptr->ctx),
VecToIdArray(sub_indices, csr.indptr->dtype.bits, csr.indptr->ctx),
NDArray::FromVector(sub_indptr, csr.indptr->dtype, csr.indptr->ctx),
NDArray::FromVector(sub_indices, csr.indptr->dtype, csr.indptr->ctx),
sub_data_arr};
}
......
......@@ -303,8 +303,8 @@ bool COO::IsMultigraph() const {
// Returns the (src, dst) endpoints of edge `eid`.
// Fixed: merge residue left both the old `const auto` lookups and the new
// explicitly-typed `aten::IndexSelect<dgl_id_t>` lookups in the body,
// redeclaring src/dst — only the typed version is kept.
std::pair<dgl_id_t, dgl_id_t> COO::FindEdge(dgl_id_t eid) const {
  CHECK(eid < NumEdges()) << "Invalid edge id: " << eid;
  const dgl_id_t src = aten::IndexSelect<dgl_id_t>(adj_.row, eid);
  const dgl_id_t dst = aten::IndexSelect<dgl_id_t>(adj_.col, eid);
  return std::pair<dgl_id_t, dgl_id_t>(src, dst);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment