Unverified Commit b2d38ca8 authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] clang-format auto fix. (#4803)



* [Misc] clang-format auto fix.

* manual
Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 07dc8fb6
@@ -10,6 +10,7 @@
 #include <dgl/graph_serializer.h>
 #include <dmlc/io.h>
 #include <dmlc/serializer.h>
+
 #include <memory>

 namespace dmlc {
...
 /*!
 * Copyright (c) 2020-2022 by Contributors
 * \file array/tensordispatch.h
- * \brief This file defines the dispatcher of tensor operators to framework-specific
- * implementations.
+ * \brief This file defines the dispatcher of tensor operators to
+ * framework-specific implementations.
 *
 * The dispatcher consists of a TensorDispatcher singleton in DGL C library and
 * one separately-built shared library per supported backend.
@@ -15,14 +15,14 @@
 * The TensorDispatcher singleton maintains a mapping from an array operator to
 * the address of the corresponding symbol in the shared library. During
 * initialization, the TensorDispatcher checks which backend DGL is using.
- * It then locates and opens the corresponding shared library using dlopen(3) (or
- * LoadLibrary in Windows), and populates the said mapping above with dlsym(3)
- * (or GetProcAddress in Windows).
+ * It then locates and opens the corresponding shared library using dlopen(3)
+ * (or LoadLibrary in Windows), and populates the said mapping above with
+ * dlsym(3) (or GetProcAddress in Windows).
 *
- * A tensor operator in TensorDispatcher first checks whether the corresponding symbol
- * address is found in the mapping. If so, it calls the function located at the
- * symbol address instead, allocate/free pieces of memory on CPU/GPU.
- * If not, it falls back to DeviceAPI::AllocWorkspace/FreeWorkspace.
+ * A tensor operator in TensorDispatcher first checks whether the corresponding
+ * symbol address is found in the mapping. If so, it calls the function located
+ * at the symbol address instead, allocate/free pieces of memory on CPU/GPU. If
+ * not, it falls back to DeviceAPI::AllocWorkspace/FreeWorkspace.
 */
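A minimal, self-contained sketch of the dlopen(3)/dlsym(3) pattern this comment describes (POSIX only; the adapter library path is illustrative and the CPURawAlloc signature is an assumption taken from the names_ table further down, so this is not DGL's actual loader):

    #include <dlfcn.h>

    #include <cstddef>
    #include <cstdio>

    int main() {
      // Hypothetical adapter library; DGL resolves the real path per backend.
      void* handle = dlopen("libtensoradapter_pytorch.so", RTLD_LAZY);
      if (!handle) {
        // No adapter found: a dispatcher would fall back to its own DeviceAPI.
        std::fprintf(stderr, "no adapter: %s\n", dlerror());
        return 0;
      }
      // Populate an operator -> symbol-address table, one dlsym per known name.
      void* entry = dlsym(handle, "CPURawAlloc");
      if (entry != nullptr) {
        // Cast the raw address back to the function's own type before calling.
        auto raw_alloc = reinterpret_cast<void* (*)(std::size_t)>(entry);
        void* buf = raw_alloc(64);  // delegate the allocation to the backend
        (void)buf;                  // freeing would go through "CPURawDelete"
      }
      dlclose(handle);
      return 0;
    }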
 #ifndef DGL_RUNTIME_TENSORDISPATCH_H_
@@ -38,14 +38,18 @@
 #endif  // DGL_USE_CUDA
 #include "ndarray.h"

-/*! \brief Casts a pointer \c entry to a function pointer with signature of \c func */
+/*!
+ * \brief Casts a pointer \c entry to a function pointer with signature of \c
+ * func.
+ */
 #define FUNCCAST(func, entry) (*reinterpret_cast<decltype(&(func))>(entry))
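An illustration of the expansion (Fn and its signature are stand-ins, not DGL symbols): the untyped symbol address is cast back to the function's own pointer type and invoked.

    #include <cstddef>

    void* Fn(std::size_t) { return nullptr; }  // stand-in, not a DGL symbol

    #define FUNCCAST(func, entry) (*reinterpret_cast<decltype(&(func))>(entry))

    int main() {
      void* entry = reinterpret_cast<void*>(&Fn);  // e.g. what dlsym() returned
      void* p = FUNCCAST(Fn, entry)(64);  // calls Fn(64) through the raw address
      (void)p;
      return 0;
    }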
 namespace dgl {
 namespace runtime {

 /*!
- * \brief Dispatcher that delegates the function calls to framework-specific C++ APIs.
+ * \brief Dispatcher that delegates the function calls to framework-specific C++
+ * APIs.
 *
 * This class is not thread-safe.
 */
@@ -57,17 +61,14 @@ class TensorDispatcher {
     return &inst;
   }

-  /*! \brief Whether an adapter library is available */
-  inline bool IsAvailable() {
-    return available_;
-  }
+  /*! \brief Whether an adapter library is available. */
+  inline bool IsAvailable() { return available_; }

-  /*! \brief Load symbols from the given tensor adapter library path */
-  bool Load(const char *path_cstr);
+  /*! \brief Load symbols from the given tensor adapter library path. */
+  bool Load(const char* path_cstr);

   /*!
-   * \brief Allocate a piece of CPU memory via
-   * PyTorch's CPUAllocator.
+   * \brief Allocate a piece of CPU memory via PyTorch's CPUAllocator.
    * Used in CPUDeviceAPI::AllocWorkspace().
    *
    * \param nbytes The size to be allocated.
@@ -146,8 +147,8 @@ class TensorDispatcher {
   inline void RecordStream(void* ptr, DGLStreamHandle stream, int device_id) {
 #ifdef DGL_USE_CUDA
     auto entry = entrypoints_[Op::kRecordStream];
-    FUNCCAST(tensoradapter::RecordStream, entry)(
-        ptr, static_cast<cudaStream_t>(stream), device_id);
+    FUNCCAST(tensoradapter::RecordStream, entry)
+    (ptr, static_cast<cudaStream_t>(stream), device_id);
 #endif  // DGL_USE_CUDA
   }
@@ -162,14 +163,10 @@
   *
   * Must match the functions in tensoradapter/include/tensoradapter.h.
   */
-  static constexpr const char *names_[] = {
-      "CPURawAlloc",
-      "CPURawDelete",
+  static constexpr const char* names_[] = {
+      "CPURawAlloc", "CPURawDelete",
 #ifdef DGL_USE_CUDA
-      "CUDARawAlloc",
-      "CUDARawDelete",
-      "CUDACurrentStream",
-      "RecordStream",
+      "CUDARawAlloc", "CUDARawDelete", "CUDACurrentStream", "RecordStream",
 #endif  // DGL_USE_CUDA
   };
@@ -191,13 +188,9 @@ class TensorDispatcher {
   /*! \brief Entrypoints of each function */
   void* entrypoints_[num_entries_] = {
-      nullptr,
-      nullptr,
+      nullptr, nullptr,
 #ifdef DGL_USE_CUDA
-      nullptr,
-      nullptr,
-      nullptr,
-      nullptr,
+      nullptr, nullptr, nullptr, nullptr,
 #endif  // DGL_USE_CUDA
   };
...
@@ -34,8 +34,8 @@ class ThreadGroup {
   * `worker_callback` will only be called for values >= 1. This
   * allows use of the main thread as a worker.
   */
-  ThreadGroup(int num_workers,
-              std::function<void(int)> worker_callback,
-              bool exclude_worker0 = false);
+  ThreadGroup(
+      int num_workers, std::function<void(int)> worker_callback,
+      bool exclude_worker0 = false);
   ~ThreadGroup();
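A self-contained analog of the exclude_worker0 contract using only the standard library (not DGL's ThreadGroup): the group spawns workers 1..N-1, leaving index 0 for the caller's main thread.

    #include <iostream>
    #include <thread>
    #include <vector>

    int main() {
      auto work = [](int id) { std::cout << "worker " << id << "\n"; };
      const int num_workers = 4;
      std::vector<std::thread> threads;
      for (int i = 1; i < num_workers; ++i)  // the group only calls ids >= 1
        threads.emplace_back(work, i);
      work(0);  // ... so the caller can use the main thread as worker 0
      for (auto& t : threads) t.join();
      return 0;
    }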
@@ -70,8 +70,8 @@ class ThreadGroup {
 /*!
 * \brief Platform-agnostic no-op.
 */
-// This used to be Yield(), renaming to YieldThread() because windows.h defined it as a
-// macro in later SDKs.
+// This used to be Yield(), renaming to YieldThread() because windows.h defined
+// it as a macro in later SDKs.
 void YieldThread();

 /*!
@@ -79,7 +79,6 @@ void YieldThread();
 */
 int MaxConcurrency();

 }  // namespace threading
 }  // namespace runtime
 }  // namespace dgl
...
@@ -6,10 +6,11 @@
 #ifndef DGL_SAMPLER_H_
 #define DGL_SAMPLER_H_

-#include <vector>
-#include <string>
 #include <cstdlib>
 #include <ctime>
+#include <string>
+#include <vector>

 #include "graph_interface.h"
 #include "nodeflow.h"
@@ -32,13 +33,11 @@ class SamplerOp {
   * \param probability the transition probability (float/double).
   * \return a NodeFlow graph.
   */
-  template<typename ValueType>
-  static NodeFlow NeighborSample(const ImmutableGraph *graph,
-                                 const std::vector<dgl_id_t>& seeds,
-                                 const std::string &edge_type,
-                                 int num_hops, int expand_factor,
-                                 const bool add_self_loop,
-                                 const ValueType *probability);
+  template <typename ValueType>
+  static NodeFlow NeighborSample(
+      const ImmutableGraph *graph, const std::vector<dgl_id_t> &seeds,
+      const std::string &edge_type, int num_hops, int expand_factor,
+      const bool add_self_loop, const ValueType *probability);

   /*!
   * \brief Sample a graph from the seed vertices with layer sampling.
@@ -50,10 +49,9 @@ class SamplerOp {
   * \param layer_sizes The size of layers.
   * \return a NodeFlow graph.
   */
-  static NodeFlow LayerUniformSample(const ImmutableGraph *graph,
-                                     const std::vector<dgl_id_t>& seeds,
-                                     const std::string &neigh_type,
-                                     IdArray layer_sizes);
+  static NodeFlow LayerUniformSample(
+      const ImmutableGraph *graph, const std::vector<dgl_id_t> &seeds,
+      const std::string &neigh_type, IdArray layer_sizes);
 };

 }  // namespace dgl
...
@@ -6,40 +6,37 @@
 #ifndef DGL_SAMPLING_NEGATIVE_H_
 #define DGL_SAMPLING_NEGATIVE_H_

-#include <dgl/base_heterograph.h>
 #include <dgl/array.h>
+#include <dgl/base_heterograph.h>

 #include <utility>

 namespace dgl {
 namespace sampling {

 /*!
- * \brief Given an edge type, uniformly sample source-destination pairs that do not have
- * an edge in between using rejection sampling.
+ * \brief Given an edge type, uniformly sample source-destination pairs that do
+ * not have an edge in between using rejection sampling.
 *
- * \note This function may not return the same number of elements as the given number
- * of samples.
- * \note This function requires sorting the CSR or CSC matrix of the graph in-place. It
- * prefers CSC over CSR.
+ * \note This function may not return the same number of elements as the given
+ * number of samples.
+ * \note This function requires sorting the CSR or CSC matrix of the graph
+ * in-place. It prefers CSC over CSR.
 *
 * \param hg The graph.
 * \param etype The edge type.
 * \param num_samples The number of negative examples to sample.
 * \param num_trials The number of rejection sampling trials.
- * \param exclude_self_loops Do not include the examples where the source equals the
- * destination.
+ * \param exclude_self_loops Do not include the examples where the source equals
+ * the destination.
 * \param replace Whether to sample with replacement.
- * \param redundancy How much redundant negative examples to take in case of duplicate examples.
+ * \param redundancy How much redundant negative examples to take in case of
+ * duplicate examples.
 * \return The pair of source and destination tensors.
 */
 std::pair<IdArray, IdArray> GlobalUniformNegativeSampling(
-    HeteroGraphPtr hg,
-    dgl_type_t etype,
-    int64_t num_samples,
-    int num_trials,
-    bool exclude_self_loops,
-    bool replace,
-    double redundancy);
+    HeteroGraphPtr hg, dgl_type_t etype, int64_t num_samples, int num_trials,
+    bool exclude_self_loops, bool replace, double redundancy);
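The rejection-sampling idea behind this declaration, as a self-contained sketch (not DGL's kernel; here num_trials caps the total number of draws, a simplification of the per-sample trial count above, and the graph is a plain edge set rather than a sorted CSC matrix):

    #include <cstdint>
    #include <random>
    #include <set>
    #include <utility>
    #include <vector>

    using Edge = std::pair<int64_t, int64_t>;

    std::vector<Edge> negative_sample(
        const std::set<Edge>& edges, int64_t num_src, int64_t num_dst,
        int64_t num_samples, int64_t num_trials, bool exclude_self_loops) {
      std::mt19937_64 rng{0};
      std::uniform_int_distribution<int64_t> src(0, num_src - 1);
      std::uniform_int_distribution<int64_t> dst(0, num_dst - 1);
      std::vector<Edge> result;
      for (int64_t t = 0;
           t < num_trials && static_cast<int64_t>(result.size()) < num_samples;
           ++t) {
        Edge e{src(rng), dst(rng)};
        if (exclude_self_loops && e.first == e.second) continue;  // reject
        if (edges.count(e)) continue;  // reject pairs that are real edges
        result.push_back(e);
      }
      // As noted above, the result may hold fewer than num_samples pairs.
      return result;
    }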
 };  // namespace sampling
 };  // namespace dgl
...
@@ -6,81 +6,75 @@
 #ifndef DGL_SAMPLING_NEIGHBOR_H_
 #define DGL_SAMPLING_NEIGHBOR_H_

-#include <dgl/base_heterograph.h>
 #include <dgl/array.h>
+#include <dgl/base_heterograph.h>

 #include <vector>

 namespace dgl {
 namespace sampling {

 /*!
- * \brief Sample from the neighbors of the given nodes and return the sampled edges as a graph.
+ * \brief Sample from the neighbors of the given nodes and return the sampled
+ * edges as a graph.
 *
- * When sampling with replacement, the sampled subgraph could have parallel edges.
+ * When sampling with replacement, the sampled subgraph could have parallel
+ * edges.
 *
 * For sampling without replace, if fanout > the number of neighbors, all the
 * neighbors will be sampled.
 *
 * \param hg The input graph.
- * \param nodes Node IDs of each type. The vector length must be equal to the number
- * of node types. Empty array is allowed.
- * \param fanouts Number of sampled neighbors for each edge type. The vector length
- * should be equal to the number of edge types, or one if they all
- * have the same fanout.
+ * \param nodes Node IDs of each type. The vector length must be equal to the
+ * number of node types. Empty array is allowed.
+ * \param fanouts Number of sampled neighbors for each edge type. The vector
+ * length should be equal to the number of edge types, or one if they all have
+ * the same fanout.
 * \param dir Edge direction.
- * \param probability A vector of 1D float arrays, indicating the transition probability of
- * each edge by edge type. An empty float array assumes uniform transition.
- * \param exclude_edges Edges IDs of each type which will be excluded during sampling.
- * The vector length must be equal to the number of edges types. Empty array is allowed.
+ * \param probability A vector of 1D float arrays, indicating the transition
+ * probability of each edge by edge type. An empty float array assumes uniform
+ * transition.
+ * \param exclude_edges Edges IDs of each type which will be excluded during
+ * sampling. The vector length must be equal to the number of edges types. Empty
+ * array is allowed.
 * \param replace If true, sample with replacement.
- * \return Sampled neighborhoods as a graph. The return graph has the same schema as the
- * original one.
+ * \return Sampled neighborhoods as a graph. The return graph has the same
+ * schema as the original one.
 */
 HeteroSubgraph SampleNeighbors(
-    const HeteroGraphPtr hg,
-    const std::vector<IdArray>& nodes,
-    const std::vector<int64_t>& fanouts,
-    EdgeDir dir,
+    const HeteroGraphPtr hg, const std::vector<IdArray>& nodes,
+    const std::vector<int64_t>& fanouts, EdgeDir dir,
     const std::vector<FloatArray>& probability,
-    const std::vector<IdArray>& exclude_edges,
-    bool replace = true);
+    const std::vector<IdArray>& exclude_edges, bool replace = true);
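The fanout rule for sampling without replacement, shown for a single node's neighbor list (uniform case; a sketch, not DGL's implementation):

    #include <algorithm>
    #include <cstdint>
    #include <random>
    #include <vector>

    // Returns min(fanout, |nbrs|) distinct neighbors chosen uniformly.
    std::vector<int64_t> sample_without_replacement(
        std::vector<int64_t> nbrs, int64_t fanout, std::mt19937_64& rng) {
      if (static_cast<int64_t>(nbrs.size()) <= fanout)
        return nbrs;  // fanout exceeds the neighbor count: take all of them
      std::shuffle(nbrs.begin(), nbrs.end(), rng);
      nbrs.resize(fanout);
      return nbrs;
    }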
 /*!
- * Select the neighbors with k-largest weights on the connecting edges for each given node.
+ * Select the neighbors with k-largest weights on the connecting edges for each
+ * given node.
 *
 * If k > the number of neighbors, all the neighbors are sampled.
 *
 * \param hg The input graph.
- * \param nodes Node IDs of each type. The vector length must be equal to the number
- * of node types. Empty array is allowed.
+ * \param nodes Node IDs of each type. The vector length must be equal to the
+ * number of node types. Empty array is allowed.
- * \param k The k value for each edge type. The vector length
- * should be equal to the number of edge types, or one if they all
- * have the same fanout.
+ * \param k The k value for each edge type. The vector length should be equal to
+ * the number of edge types, or one if they all have the same fanout.
 * \param dir Edge direction.
- * \param weight A vector of 1D float arrays, indicating the weights associated with
- * each edge.
+ * \param weight A vector of 1D float arrays, indicating the weights associated
+ * with each edge.
- * \param ascending If true, elements are sorted by ascending order, equivalent to find
- * the K smallest values. Otherwise, find K largest values.
+ * \param ascending If true, elements are sorted by ascending order, equivalent
+ * to find the K smallest values. Otherwise, find K largest values.
- * \return Sampled neighborhoods as a graph. The return graph has the same schema as the
- * original one.
+ * \return Sampled neighborhoods as a graph. The return graph has the same
+ * schema as the original one.
 */
 HeteroSubgraph SampleNeighborsTopk(
-    const HeteroGraphPtr hg,
-    const std::vector<IdArray>& nodes,
-    const std::vector<int64_t>& k,
-    EdgeDir dir,
-    const std::vector<FloatArray>& weight,
-    bool ascending = false);
+    const HeteroGraphPtr hg, const std::vector<IdArray>& nodes,
+    const std::vector<int64_t>& k, EdgeDir dir,
+    const std::vector<FloatArray>& weight, bool ascending = false);

 HeteroSubgraph SampleNeighborsBiased(
-    const HeteroGraphPtr hg,
-    const IdArray& nodes,
-    const int64_t fanouts,
-    const NDArray& bias,
-    const NDArray& tag_offset,
-    const EdgeDir dir,
-    const bool replace
-);
+    const HeteroGraphPtr hg, const IdArray& nodes, const int64_t fanouts,
+    const NDArray& bias, const NDArray& tag_offset, const EdgeDir dir,
+    const bool replace);

 }  // namespace sampling
 }  // namespace dgl
...
@@ -6,11 +6,12 @@
 #ifndef DGL_SAMPLING_RANDOMWALKS_H_
 #define DGL_SAMPLING_RANDOMWALKS_H_

-#include <dgl/base_heterograph.h>
 #include <dgl/array.h>
-#include <vector>
-#include <utility>
+#include <dgl/base_heterograph.h>

 #include <tuple>
+#include <utility>
+#include <vector>

 namespace dgl {
@@ -19,71 +20,67 @@ namespace sampling {

 /*!
 * \brief Metapath-based random walk.
 * \param hg The heterograph.
- * \param seeds A 1D array of seed nodes, with the type the source type of the first
- * edge type in the metapath.
+ * \param seeds A 1D array of seed nodes, with the type the source type of the
+ * first edge type in the metapath.
 * \param metapath A 1D array of edge types representing the metapath.
- * \param prob A vector of 1D float arrays, indicating the transition probability of
- * each edge by edge type. An empty float array assumes uniform transition.
+ * \param prob A vector of 1D float arrays, indicating the transition
+ * probability of each edge by edge type. An empty float array assumes uniform
+ * transition.
 * \return A pair of
- * 1. One 2D array of shape (len(seeds), len(metapath) + 1) with node IDs. The
- *    paths that terminated early are padded with -1.
- * 2. One 2D array of shape (len(seeds), len(metapath)) with edge IDs. The
- *    paths that terminated early are padded with -1.
+ * 1. One 2D array of shape (len(seeds), len(metapath) + 1) with node
+ *    IDs. The paths that terminated early are padded with -1.
+ * 2. One 2D array of shape (len(seeds), len(metapath)) with edge IDs.
+ *    The paths that terminated early are padded with -1.
 * 3. One 1D array of shape (len(metapath) + 1) with node type IDs.
 */
 std::tuple<IdArray, IdArray, TypeArray> RandomWalk(
-    const HeteroGraphPtr hg,
-    const IdArray seeds,
-    const TypeArray metapath,
+    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
     const std::vector<FloatArray> &prob);
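A self-contained sketch of the walk-and-pad behavior the \return block describes (a homogeneous adjacency list stands in for the heterograph; a metapath walk would additionally switch edge type at every step; not DGL's kernel):

    #include <cstddef>
    #include <cstdint>
    #include <random>
    #include <vector>

    std::vector<int64_t> random_walk(
        const std::vector<std::vector<int64_t>>& neighbors, int64_t seed,
        int num_steps, std::mt19937_64& rng) {
      std::vector<int64_t> path{seed};
      int64_t cur = seed;
      for (int step = 0; step < num_steps; ++step) {
        const auto& nbrs = neighbors[cur];
        if (nbrs.empty()) break;  // dead end: terminate the walk early
        std::uniform_int_distribution<std::size_t> pick(0, nbrs.size() - 1);
        cur = nbrs[pick(rng)];
        path.push_back(cur);
      }
      path.resize(num_steps + 1, -1);  // pad early-terminated walks with -1
      return path;
    }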
 /*!
 * \brief Metapath-based random walk with restart probability.
 * \param hg The heterograph.
- * \param seeds A 1D array of seed nodes, with the type the source type of the first
- * edge type in the metapath.
+ * \param seeds A 1D array of seed nodes, with the type the source type of the
+ * first edge type in the metapath.
 * \param metapath A 1D array of edge types representing the metapath.
- * \param prob A vector of 1D float arrays, indicating the transition probability of
- * each edge by edge type. An empty float array assumes uniform transition.
- * \param restart_prob Restart probability
+ * \param prob A vector of 1D float arrays, indicating the transition
+ * probability of each edge by edge type. An empty float array assumes uniform
+ * transition.
+ * \param restart_prob Restart probability.
 * \return A pair of
- * 1. One 2D array of shape (len(seeds), len(metapath) + 1) with node IDs. The
- *    paths that terminated early are padded with -1.
- * 2. One 2D array of shape (len(seeds), len(metapath)) with edge IDs. The
- *    paths that terminated early are padded with -1.
+ * 1. One 2D array of shape (len(seeds), len(metapath) + 1) with node
+ *    IDs. The paths that terminated early are padded with -1.
+ * 2. One 2D array of shape (len(seeds), len(metapath)) with edge IDs.
+ *    The paths that terminated early are padded with -1.
 * 3. One 1D array of shape (len(metapath) + 1) with node type IDs.
 */
 std::tuple<IdArray, IdArray, TypeArray> RandomWalkWithRestart(
-    const HeteroGraphPtr hg,
-    const IdArray seeds,
-    const TypeArray metapath,
-    const std::vector<FloatArray> &prob,
-    double restart_prob);
+    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
+    const std::vector<FloatArray> &prob, double restart_prob);

 /*!
 * \brief Metapath-based random walk with stepwise restart probability. Useful
 * for PinSAGE-like models.
 * \param hg The heterograph.
- * \param seeds A 1D array of seed nodes, with the type the source type of the first
- * edge type in the metapath.
+ * \param seeds A 1D array of seed nodes, with the type the source type of the
+ * first edge type in the metapath.
 * \param metapath A 1D array of edge types representing the metapath.
- * \param prob A vector of 1D float arrays, indicating the transition probability of
- * each edge by edge type. An empty float array assumes uniform transition.
- * \param restart_prob Restart probability array which has the same number of elements
- * as \c metapath, indicating the probability to terminate after transition.
+ * \param prob A vector of 1D float arrays, indicating the transition
+ * probability of each edge by edge type. An empty float array assumes uniform
+ * transition.
+ * \param restart_prob Restart probability array which has the same number of
+ * elements as \c metapath, indicating the probability to terminate after
+ * transition.
 * \return A pair of
- * 1. One 2D array of shape (len(seeds), len(metapath) + 1) with node IDs. The
- *    paths that terminated early are padded with -1.
- * 2. One 2D array of shape (len(seeds), len(metapath)) with edge IDs. The
- *    paths that terminated early are padded with -1.
+ * 1. One 2D array of shape (len(seeds), len(metapath) + 1) with node
+ *    IDs. The paths that terminated early are padded with -1.
+ * 2. One 2D array of shape (len(seeds), len(metapath)) with edge IDs.
+ *    The paths that terminated early are padded with -1.
 * 3. One 1D array of shape (len(metapath) + 1) with node type IDs.
 */
 std::tuple<IdArray, IdArray, TypeArray> RandomWalkWithStepwiseRestart(
-    const HeteroGraphPtr hg,
-    const IdArray seeds,
-    const TypeArray metapath,
-    const std::vector<FloatArray> &prob,
-    FloatArray restart_prob);
+    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
+    const std::vector<FloatArray> &prob, FloatArray restart_prob);

 };  // namespace sampling
...
@@ -7,6 +7,7 @@
 #define DGL_SCHEDULER_H_

 #include <vector>

 #include "runtime/ndarray.h"

 namespace dgl {
@@ -21,8 +22,8 @@ namespace sched {
 * \param msg_ids The edge id for each message
 * \param vids The destination vertex for each message
 * \param recv_ids The recv nodes (for checking zero degree nodes)
- * \note If there are multiple messages going into the same destination vertex, then
- * there will be multiple copies of the destination vertex in vids
+ * \note If there are multiple messages going into the same destination vertex,
+ * then there will be multiple copies of the destination vertex in vids.
 * \return a vector of 5 IdArrays for degree bucketing. The 5 arrays are:
 *         degrees: degrees for each bucket
 *         nids: destination node ids
@@ -31,8 +32,8 @@ namespace sched {
 *         mid_section: number of messages in each bucket (used to split mids)
 */
 template <class IdType>
-std::vector<IdArray> DegreeBucketing(const IdArray& msg_ids, const IdArray& vids,
-                                     const IdArray& recv_ids);
+std::vector<IdArray> DegreeBucketing(
+    const IdArray& msg_ids, const IdArray& vids, const IdArray& recv_ids);
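The bucketing idea in isolation, as a sketch (not the DGL schedule, which also emits the mids/eids/section arrays listed above): count the messages destined for each vertex, then group vertices by that degree so every bucket can be reduced with one dense kernel call.

    #include <cstdint>
    #include <map>
    #include <vector>

    std::map<int64_t, std::vector<int64_t>> degree_bucketing(
        const std::vector<int64_t>& vids) {  // one destination vertex per message
      std::map<int64_t, int64_t> degree;     // vertex -> number of messages
      for (int64_t v : vids) ++degree[v];
      std::map<int64_t, std::vector<int64_t>> buckets;  // degree -> vertices
      for (const auto& [v, deg] : degree) buckets[deg].push_back(v);
      return buckets;
    }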
 /*!
 * \brief Generate degree bucketing schedule for group_apply edge
@@ -53,8 +54,8 @@ std::vector<IdArray> DegreeBucketing(const IdArray& msg_ids, const IdArray& vids
 *         new_uids, new_vids, and new_eids)
 */
 template <class IdType>
-std::vector<IdArray> GroupEdgeByNodeDegree(const IdArray& uids,
-                                           const IdArray& vids, const IdArray& eids);
+std::vector<IdArray> GroupEdgeByNodeDegree(
+    const IdArray& uids, const IdArray& vids, const IdArray& eids);

 }  // namespace sched
...
@@ -7,50 +7,51 @@
 #ifndef DGL_TRANSFORM_H_
 #define DGL_TRANSFORM_H_

-#include <vector>
 #include <tuple>
 #include <utility>
-#include "base_heterograph.h"
+#include <vector>

 #include "array.h"
+#include "base_heterograph.h"

 namespace dgl {

 namespace transform {

 /*!
- * \brief Given a list of graphs, remove the common nodes that do not have inbound and
- * outbound edges.
+ * \brief Given a list of graphs, remove the common nodes that do not have
+ * inbound and outbound edges.
 *
- * The graphs should have identical node ID space (i.e. should have the same set of nodes,
- * including types and IDs).
+ * The graphs should have identical node ID space (i.e. should have the same set
+ * of nodes, including types and IDs).
 *
 * \param graphs The list of graphs.
- * \param always_preserve The list of nodes to preserve regardless of whether the inbound
- * or outbound edges exist.
+ * \param always_preserve The list of nodes to preserve regardless of whether
+ * the inbound or outbound edges exist.
 *
- * \return A pair. The first element is the list of compacted graphs, and the second
- * element is the mapping from the compacted graphs and the original graph.
+ * \return A pair. The first element is the list of compacted graphs, and the
+ * second element is the mapping from the compacted graphs and the original
+ * graph.
 */
-std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>>
-CompactGraphs(
+std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>> CompactGraphs(
     const std::vector<HeteroGraphPtr> &graphs,
     const std::vector<IdArray> &always_preserve);

 /*!
 * \brief Convert a graph into a bipartite-structured graph for message passing.
 *
- * Specifically, we create one node type \c ntype_l on the "left" side and another
- * node type \c ntype_r on the "right" side for each node type \c ntype. The nodes of
- * type \c ntype_r would contain the nodes designated by the caller, and node type
- * \c ntype_l would contain the nodes that has an edge connecting to one of the
- * designated nodes.
+ * Specifically, we create one node type \c ntype_l on the "left" side and
+ * another node type \c ntype_r on the "right" side for each node type \c ntype.
+ * The nodes of type \c ntype_r would contain the nodes designated by the
+ * caller, and node type \c ntype_l would contain the nodes that has an edge
+ * connecting to one of the designated nodes.
 *
 * The nodes of \c ntype_l would also contain the nodes in node type \c ntype_r.
 *
- * This function is often used for constructing a series of dependency graphs for
- * multi-layer message passing, where we first construct a series of frontier graphs
- * on the original node space, and run the following to get the bipartite graph needed
- * for message passing with each GNN layer:
+ * This function is often used for constructing a series of dependency graphs
+ * for multi-layer message passing, where we first construct a series of
+ * frontier graphs on the original node space, and run the following to get the
+ * bipartite graph needed for message passing with each GNN layer:
 *
 * <code>
 *     bipartites = [None] * len(num_layers)
@@ -66,20 +67,21 @@ CompactGraphs(
 *
 * \param graph The graph.
 * \param rhs_nodes Designated nodes that would appear on the right side.
- * \param include_rhs_in_lhs If false, do not include the nodes of node type \c ntype_r
- * in \c ntype_l.
+ * \param include_rhs_in_lhs If false, do not include the nodes of node type \c
+ * ntype_r in \c ntype_l.
 *
 * \return A triplet containing
 *         * The bipartite-structured graph,
 *         * The induced node from the left side for each graph,
 *         * The induced edges.
 *
- * \note If include_rhs_in_lhs is true, then for each node type \c ntype, the nodes
- * in rhs_nodes[ntype] would always appear first in the nodes of type \c ntype_l
- * in the new graph.
+ * \note If include_rhs_in_lhs is true, then for each node type \c ntype, the
+ * nodes in rhs_nodes[ntype] would always appear first in the nodes of type \c
+ * ntype_l in the new graph.
 */
-std::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>>
-ToBlock(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes, bool include_rhs_in_lhs);
+std::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>> ToBlock(
+    HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes,
+    bool include_rhs_in_lhs);

 /*!
 * \brief Convert a multigraph to a simple graph.
@@ -87,7 +89,8 @@ ToBlock(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes, bool includ
 * \return A triplet of
 *         * \c hg : The said simple graph.
 *         * \c count : The array of edge occurrences per edge type.
- *         * \c edge_map : The mapping from original edge IDs to new edge IDs per edge type.
+ *         * \c edge_map : The mapping from original edge IDs to new edge IDs
+ *           per edge type.
 *
 * \note Example: consider a graph with the following edges
 *
@@ -99,13 +102,14 @@ ToBlock(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes, bool includ
 *
 *       [(0, 1), (1, 3), (1, 4), (2, 2)]
 *
- * * The second element is an array \c count. \c count[i] stands for the number of edges
- *   connecting simple_g.src[i] and simple_g.dst[i] in the original graph.
+ * * The second element is an array \c count. \c count[i] stands for the number
+ *   of edges connecting simple_g.src[i] and simple_g.dst[i] in the original
+ *   graph.
 *
 *       count[0] = [1, 2, 2, 1]
 *
- * * One can find the mapping between edges from the original graph to the new simple
- *   graph.
+ * * One can find the mapping between edges from the original graph to the new
+ *   simple graph.
 *
 *       edge_map[0] = [0, 1, 3, 1, 2, 2]
 */
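A self-contained sketch of this conversion (not DGL's implementation; new edge IDs follow first-seen order here, whereas the example above happens to list the simple graph's edges in sorted order):

    #include <cstdint>
    #include <map>
    #include <utility>
    #include <vector>

    struct SimpleGraph {
      std::vector<std::pair<int64_t, int64_t>> edges;  // deduplicated edges
      std::vector<int64_t> count;     // occurrences per new edge
      std::vector<int64_t> edge_map;  // old edge ID -> new edge ID
    };

    SimpleGraph to_simple(const std::vector<std::pair<int64_t, int64_t>>& edges) {
      SimpleGraph g;
      std::map<std::pair<int64_t, int64_t>, int64_t> ids;  // edge -> new ID
      for (const auto& e : edges) {
        auto it = ids.find(e);
        if (it == ids.end()) {
          it = ids.emplace(e, static_cast<int64_t>(g.edges.size())).first;
          g.edges.push_back(e);   // first occurrence defines the new edge
          g.count.push_back(0);
        }
        ++g.count[it->second];    // tally parallel edges
        g.edge_map.push_back(it->second);
      }
      return g;
    }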
@@ -118,11 +122,11 @@ ToSimpleGraph(const HeteroGraphPtr graph);
 * \param graph The graph.
 * \param eids The edge IDs to remove per edge type.
 *
- * \return A pair of the graph with edges removed, as well as the edge ID mapping from
- * the original graph to the new graph per edge type.
+ * \return A pair of the graph with edges removed, as well as the edge ID
+ * mapping from the original graph to the new graph per edge type.
 */
-std::pair<HeteroGraphPtr, std::vector<IdArray>>
-RemoveEdges(const HeteroGraphPtr graph, const std::vector<IdArray> &eids);
+std::pair<HeteroGraphPtr, std::vector<IdArray>> RemoveEdges(
+    const HeteroGraphPtr graph, const std::vector<IdArray> &eids);

 };  // namespace transform
...
@@ -12,12 +12,12 @@
 #include <dmlc/serializer.h>

 #include <deque>
+#include <memory>
 #include <queue>
 #include <string>
 #include <tuple>
 #include <utility>
 #include <vector>
-#include <memory>

 #include "dmlc/logging.h"
@@ -85,7 +85,8 @@ class StreamWithBuffer : public dmlc::SeekStream {
   *     // Read from remote sended pointer list
   *     StreamWithBuffer buf_strm(&blob, data_ptr_list)
   */
-  StreamWithBuffer(std::unique_ptr<dmlc::SeekStream> strm,
-                   const std::vector<void*>& data_ptr_list)
+  StreamWithBuffer(
+      std::unique_ptr<dmlc::SeekStream> strm,
+      const std::vector<void*>& data_ptr_list)
       : strm_(std::move(strm)), send_to_remote_(true) {
     for (void* data : data_ptr_list) {
@@ -136,8 +137,8 @@ class StreamWithBuffer : public dmlc::SeekStream {
   * \param size buffer size
   * \param data_ptr_list pointer list for NDArrays to deconstruct from
   */
-  StreamWithBuffer(char* p_buffer, size_t size,
-                   const std::vector<void*>& data_ptr_list)
+  StreamWithBuffer(
+      char* p_buffer, size_t size, const std::vector<void*>& data_ptr_list)
      : strm_(new dmlc::MemoryFixedSizeStream(p_buffer, size)),
        send_to_remote_(true) {
     for (void* data : data_ptr_list) {
...
@@ -9,6 +9,7 @@
 #include <memory>
 #include <tuple>
 #include <type_traits>

 #include "dmlc/logging.h"
 #include "meta_utils.h"
 #include "xbyak/xbyak.h"
@@ -61,8 +62,8 @@ class ElemWiseAddUpdate : public Xbyak::CodeGenerator {
 public:
   typedef typename Op::type DType;
   static_assert(
-      std::is_base_of<std::true_type,
-                      utils::has_type<DType, supported_types>>::value,
+      std::is_base_of<
+          std::true_type, utils::has_type<DType, supported_types>>::value,
       "Use case fail dgl::ElemWiseAddUpdate< Operator<DType> > DType is not "
       "supported !");
@@ -82,75 +83,84 @@ class ElemWiseAddUpdate : public Xbyak::CodeGenerator {
   static constexpr int UNIT_PER_REG =
       REG_BIT_SIZE / (UNIT_SIZE_BYTES * BITS_IN_BYTES);
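A worked instance of this constant, assuming 512-bit zmm registers (i.e. REG_BIT_SIZE = 512) and the usual 4/8-byte float/double sizes; this also explains the "size < 16" comparison in the generator code further down:

    // floats:  512 / (4 * 8) = 16 lanes per zmm register
    // doubles: 512 / (8 * 8) =  8 lanes per zmm register
    static_assert(512 / (sizeof(float) * 8) == 16, "16 floats per 512-bit reg");
    static_assert(512 / (sizeof(double) * 8) == 8, "8 doubles per 512-bit reg");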
-  template <class TType, class R1, class R2,
-            utils::CheckCmp<TType, float> = true>
+  template <
+      class TType, class R1, class R2, utils::CheckCmp<TType, float> = true>
   void alias_load(R1 r1, R2 r2) {
     vmovups(r1, r2);
   }
-  template <class TType, class R1, class R2,
-            utils::CheckCmp<TType, double> = true>
+  template <
+      class TType, class R1, class R2, utils::CheckCmp<TType, double> = true>
   void alias_load(R1 r1, R2 r2) {
     vmovupd(r1, r2);
   }
-  template <class TType, class R1, class R2,
-            utils::CheckCmp<TType, float> = true>
+  template <
+      class TType, class R1, class R2, utils::CheckCmp<TType, float> = true>
   void alias_save(R1 r1, R2 r2) {
     alias_load<TType>(r1, r2);
   }
-  template <class TType, class R1, class R2,
-            utils::CheckCmp<TType, double> = true>
+  template <
+      class TType, class R1, class R2, utils::CheckCmp<TType, double> = true>
   void alias_save(R1 r1, R2 r2) {
     alias_load<TType>(r1, r2);
   }
-  template <class TType, class R1, class R2, class R3,
-            utils::CheckCmp<TType, float> = true>
+  template <
+      class TType, class R1, class R2, class R3,
+      utils::CheckCmp<TType, float> = true>
   void alias_ADD(R1 r1, R2 r2, R3 r3) {
     vaddps(r1, r2, r3);
   }
-  template <class TType, class R1, class R2, class R3,
-            utils::CheckCmp<TType, double> = true>
+  template <
+      class TType, class R1, class R2, class R3,
+      utils::CheckCmp<TType, double> = true>
   void alias_ADD(R1 r1, R2 r2, R3 r3) {
     vaddpd(r1, r2, r3);
   }
-  template <class TType, class R1, class R2, class R3,
-            utils::CheckCmp<TType, float> = true>
+  template <
+      class TType, class R1, class R2, class R3,
+      utils::CheckCmp<TType, float> = true>
   void alias_SUB(R1 r1, R2 r2, R3 r3) {
     vsubps(r1, r2, r3);
   }
-  template <class TType, class R1, class R2, class R3,
-            utils::CheckCmp<TType, double> = true>
+  template <
+      class TType, class R1, class R2, class R3,
+      utils::CheckCmp<TType, double> = true>
   void alias_SUB(R1 r1, R2 r2, R3 r3) {
     vsubpd(r1, r2, r3);
   }
-  template <class TType, class R1, class R2, class R3,
-            utils::CheckCmp<TType, float> = true>
+  template <
+      class TType, class R1, class R2, class R3,
+      utils::CheckCmp<TType, float> = true>
   void alias_DIV(R1 r1, R2 r2, R3 r3) {
     vdivps(r1, r2, r3);
   }
-  template <class TType, class R1, class R2, class R3,
-            utils::CheckCmp<TType, double> = true>
+  template <
+      class TType, class R1, class R2, class R3,
+      utils::CheckCmp<TType, double> = true>
   void alias_DIV(R1 r1, R2 r2, R3 r3) {
     vdivpd(r1, r2, r3);
   }
-  template <class TType, class R1, class R2, class R3,
-            utils::CheckCmp<TType, float> = true>
+  template <
+      class TType, class R1, class R2, class R3,
+      utils::CheckCmp<TType, float> = true>
   void alias_MUL(R1 r1, R2 r2, R3 r3) {
     vmulps(r1, r2, r3);
   }
-  template <class TType, class R1, class R2, class R3,
-            utils::CheckCmp<TType, double> = true>
+  template <
+      class TType, class R1, class R2, class R3,
+      utils::CheckCmp<TType, double> = true>
   void alias_MUL(R1 r1, R2 r2, R3 r3) {
     vmulpd(r1, r2, r3);
   }
-  template <class Operator,
-            utils::Verify<Operator, ::dgl::aten::cpu::op::CopyLhs,
-                          supported_types> = true>
+  template <
+      class Operator,
+      utils::Verify<Operator, ::dgl::aten::cpu::op::CopyLhs, supported_types> =
+          true>
   void full_chunk_loop_operations() {
     typedef typename Operator::type IType;
     alias_load<IType>(zmm0, ptr[r_out_ + r9 * sizeof(IType)]);
@@ -158,9 +168,10 @@ class ElemWiseAddUpdate : public Xbyak::CodeGenerator {
     alias_ADD<IType>(zmm2, zmm0, zmm1);
     alias_save<IType>(ptr[r_out_ + r9 * sizeof(IType)], zmm2);
   }
-  template <class Operator,
-            utils::Verify<Operator, ::dgl::aten::cpu::op::CopyRhs,
-                          supported_types> = true>
+  template <
+      class Operator,
+      utils::Verify<Operator, ::dgl::aten::cpu::op::CopyRhs, supported_types> =
+          true>
   void full_chunk_loop_operations() {
     typedef typename Operator::type IType;
     alias_load<IType>(zmm0, ptr[r_out_ + r9 * sizeof(IType)]);
@@ -179,16 +190,20 @@ class ElemWiseAddUpdate : public Xbyak::CodeGenerator {
     alias_ADD<T>(zmm2, zmm0, zmm2);
     alias_save<T>(ptr[r_out_ + r9 * sizeof(T)], zmm2);
   }
-  template <class Operator, utils::Verify<Operator, ::dgl::aten::cpu::op::Add,
-                                          supported_types> = true>
+  template <
+      class Operator,
+      utils::Verify<Operator, ::dgl::aten::cpu::op::Add, supported_types> =
+          true>
   void full_chunk_loop_operations() {
     typedef typename Operator::type IType;
     loop_pre<IType>();
     alias_ADD<IType>(zmm2, zmm1, zmm2);
     loop_post<IType>();
   }
-  template <class Operator, utils::Verify<Operator, ::dgl::aten::cpu::op::Sub,
-                                          supported_types> = true>
+  template <
+      class Operator,
+      utils::Verify<Operator, ::dgl::aten::cpu::op::Sub, supported_types> =
+          true>
   void full_chunk_loop_operations() {
     typedef typename Operator::type IType;
     loop_pre<IType>();
@@ -196,8 +211,10 @@ class ElemWiseAddUpdate : public Xbyak::CodeGenerator {
     loop_post<IType>();
   }
-  template <class Operator, utils::Verify<Operator, ::dgl::aten::cpu::op::Div,
-                                          supported_types> = true>
+  template <
+      class Operator,
+      utils::Verify<Operator, ::dgl::aten::cpu::op::Div, supported_types> =
+          true>
   void full_chunk_loop_operations() {
     typedef typename Operator::type IType;
     loop_pre<IType>();
@@ -205,8 +222,10 @@ class ElemWiseAddUpdate : public Xbyak::CodeGenerator {
     loop_post<IType>();
   }
-  template <class Operator, utils::Verify<Operator, ::dgl::aten::cpu::op::Mul,
-                                          supported_types> = true>
+  template <
+      class Operator,
+      utils::Verify<Operator, ::dgl::aten::cpu::op::Mul, supported_types> =
+          true>
   void full_chunk_loop_operations() {
     typedef typename Operator::type IType;
     loop_pre<IType>();
@@ -214,17 +233,19 @@ class ElemWiseAddUpdate : public Xbyak::CodeGenerator {
     loop_post<IType>();
   }
-  template <class Operator,
-            utils::Verify<Operator, ::dgl::aten::cpu::op::CopyLhs,
-                          supported_types> = true>
+  template <
+      class Operator,
+      utils::Verify<Operator, ::dgl::aten::cpu::op::CopyLhs, supported_types> =
+          true>
   void remainder_operations(const Xbyak::Opmask mask) {
     typedef typename Operator::type IType;
     alias_load<IType>(make_zmm(zmm2) | mask, ptr[r_left_ + r9 * sizeof(IType)]);
   }
-  template <class Operator,
-            utils::Verify<Operator, ::dgl::aten::cpu::op::CopyRhs,
-                          supported_types> = true>
+  template <
+      class Operator,
+      utils::Verify<Operator, ::dgl::aten::cpu::op::CopyRhs, supported_types> =
+          true>
   void remainder_operations(const Xbyak::Opmask mask) {
     typedef typename Operator::type IType;
     alias_load<IType>(make_zmm(zmm2) | mask, ptr[r_right + r9 * sizeof(IType)]);
@@ -236,32 +257,40 @@ class ElemWiseAddUpdate : public Xbyak::CodeGenerator {
     alias_load<T>(make_zmm(zmm1) | mask, ptr[r_right + r9 * sizeof(T)]);
   }
-  template <class Operator, utils::Verify<Operator, ::dgl::aten::cpu::op::Mul,
-                                          supported_types> = true>
+  template <
+      class Operator,
+      utils::Verify<Operator, ::dgl::aten::cpu::op::Mul, supported_types> =
+          true>
   void remainder_operations(const Xbyak::Opmask mask) {
     typedef typename Operator::type IType;
     remainder_fetch_LR<IType>(mask);
     alias_MUL<IType>(zmm2, zmm2, zmm1);
   }
-  template <class Operator, utils::Verify<Operator, ::dgl::aten::cpu::op::Add,
-                                          supported_types> = true>
+  template <
+      class Operator,
+      utils::Verify<Operator, ::dgl::aten::cpu::op::Add, supported_types> =
+          true>
   void remainder_operations(const Xbyak::Opmask mask) {
     typedef typename Operator::type IType;
     remainder_fetch_LR<IType>(mask);
     alias_ADD<DType>(zmm2, zmm2, zmm1);
   }
-  template <class Operator, utils::Verify<Operator, ::dgl::aten::cpu::op::Div,
-                                          supported_types> = true>
+  template <
+      class Operator,
+      utils::Verify<Operator, ::dgl::aten::cpu::op::Div, supported_types> =
+          true>
   void remainder_operations(const Xbyak::Opmask mask) {
     typedef typename Operator::type IType;
     remainder_fetch_LR<IType>(mask);
     alias_DIV<DType>(zmm2, zmm2, zmm1);
   }
-  template <class Operator, utils::Verify<Operator, ::dgl::aten::cpu::op::Sub,
-                                          supported_types> = true>
+  template <
+      class Operator,
+      utils::Verify<Operator, ::dgl::aten::cpu::op::Sub, supported_types> =
+          true>
   void remainder_operations(const Xbyak::Opmask mask) {
     typedef typename Operator::type IType;
     remainder_fetch_LR<IType>(mask);
@@ -280,7 +309,8 @@ class ElemWiseAddUpdate : public Xbyak::CodeGenerator {
     if (current_cpu.has(Xbyak::util::Cpu::tAVX512F)) {
       /* prepare REMAINDER */
       mov(r8, r_size_);
-      and_(r8,
-           UNIT_PER_REG - 1);  // r8_modulo = size/(sizeof(zmm)/sizeof(float))
+      and_(
+          r8,
+          UNIT_PER_REG - 1);  // r8_modulo = size/(sizeof(zmm)/sizeof(float))
       xor_(r9, r9);                // reset r9
       cmp(r_size_, UNIT_PER_REG);  // if ( size < 16 ) { }
@@ -306,12 +336,12 @@ class ElemWiseAddUpdate : public Xbyak::CodeGenerator {
       sal(rax, cl);
      dec(rax);        // k1= (1 << r8 )-1
      kmovw(k1, eax);  // set bitmask
-      alias_load<DType>(make_zmm(zmm0) | k1,
-                        ptr[r_out_ + r9 * UNIT_SIZE_BYTES]);
+      alias_load<DType>(
+          make_zmm(zmm0) | k1, ptr[r_out_ + r9 * UNIT_SIZE_BYTES]);
       remainder_operations<Op>(k1);
       alias_ADD<DType>(zmm3, zmm2, zmm0);
-      alias_save<DType>(ptr[r_out_ + r9 * UNIT_SIZE_BYTES],
-                        make_zmm(zmm3) | k1);
+      alias_save<DType>(
+          ptr[r_out_ + r9 * UNIT_SIZE_BYTES], make_zmm(zmm3) | k1);
       L("done");
       applicable_ = true;
       log_intel("AVX512F cpu kernel is ready");
...
@@ -23,7 +23,8 @@ struct has_type<T, std::tuple<U, Ts...>> : has_type<T, std::tuple<Ts...>> {};
 template <typename T, typename... Ts>
 struct has_type<T, std::tuple<T, Ts...>> : std::true_type {};
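The has_type trait above, extracted into a self-contained sketch (the empty-tuple base case sits in a collapsed hunk and is assumed here); the recursion peels the first tuple element until it matches T or the list is exhausted:

    #include <tuple>
    #include <type_traits>

    template <typename T, typename Tup>
    struct has_type;
    template <typename T>
    struct has_type<T, std::tuple<>> : std::false_type {};  // assumed base case
    template <typename T, typename U, typename... Ts>
    struct has_type<T, std::tuple<U, Ts...>> : has_type<T, std::tuple<Ts...>> {};
    template <typename T, typename... Ts>
    struct has_type<T, std::tuple<T, Ts...>> : std::true_type {};

    using supported_types = std::tuple<float, double>;
    static_assert(has_type<float, supported_types>::value, "float is in the list");
    static_assert(!has_type<int, supported_types>::value, "int is not");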
-template <class OCmp, template <class> class ToP, class Tup,
-          int ok = std::tuple_size<Tup>::value>
+template <
+    class OCmp, template <class> class ToP, class Tup,
+    int ok = std::tuple_size<Tup>::value>
 struct DeepType;
@@ -38,7 +39,8 @@ struct DeepType<OCmp, ToP, Tup, 2> {
   typedef typename std::tuple_element<0, Tup>::type EL1;
   typedef typename std::tuple_element<1, Tup>::type EL2;
   enum {
-    value = (std::is_same<OCmp, ToP<EL1>>::value ||
+    value =
+        (std::is_same<OCmp, ToP<EL1>>::value ||
          std::is_same<OCmp, ToP<EL2>>::value)
   };
 };
@@ -49,7 +51,8 @@ struct DeepType<OCmp, ToP, Tup, 3> {
   typedef typename std::tuple_element<1, Tup>::type EL2;
   typedef typename std::tuple_element<2, Tup>::type EL3;
   enum {
-    value = (std::is_same<OCmp, ToP<EL1>>::value ||
+    value =
+        (std::is_same<OCmp, ToP<EL1>>::value ||
          std::is_same<OCmp, ToP<EL2>>::value ||
          std::is_same<OCmp, ToP<EL3>>::value)
   };
 };
...
@@ -3,54 +3,49 @@
 * \file api/api_container.cc
 * \brief Runtime container APIs. (reference: tvm/src/api/api_lang.cc)
 */
-#include <dgl/runtime/ndarray.h>
+#include <dgl/packed_func_ext.h>
 #include <dgl/runtime/container.h>
+#include <dgl/runtime/ndarray.h>
 #include <dgl/runtime/registry.h>
-#include <dgl/packed_func_ext.h>

 namespace dgl {
 namespace runtime {

-DGL_REGISTER_GLOBAL("_List")
-.set_body([] (DGLArgs args, DGLRetValue* rv) {
+DGL_REGISTER_GLOBAL("_List").set_body([](DGLArgs args, DGLRetValue* rv) {
   auto ret_obj = std::make_shared<runtime::ListObject>();
   for (int i = 0; i < args.size(); ++i) {
     ret_obj->data.push_back(args[i].obj_sptr());
   }
   *rv = ret_obj;
 });

-DGL_REGISTER_GLOBAL("_ListGetItem")
-.set_body([] (DGLArgs args, DGLRetValue* rv) {
+DGL_REGISTER_GLOBAL("_ListGetItem").set_body([](DGLArgs args, DGLRetValue* rv) {
   auto& sptr = args[0].obj_sptr();
   CHECK(sptr->is_type<ListObject>());
   auto* o = static_cast<const ListObject*>(sptr.get());
   int64_t i = args[1];
   CHECK_LT(i, o->data.size()) << "list out of bound";
   *rv = o->data[i];
 });

-DGL_REGISTER_GLOBAL("_ListSize")
-.set_body([] (DGLArgs args, DGLRetValue* rv) {
+DGL_REGISTER_GLOBAL("_ListSize").set_body([](DGLArgs args, DGLRetValue* rv) {
   auto& sptr = args[0].obj_sptr();
   CHECK(sptr->is_type<ListObject>());
   auto* o = static_cast<const ListObject*>(sptr.get());
   *rv = static_cast<int64_t>(o->data.size());
 });
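A self-contained analog of this registration pattern (standard library only; DGL's real DGLArgs/DGLRetValue are type-erased FFI values, stubbed here with simple types): a process-wide registry maps names to callables, populated by static objects at load time, so the Python FFI can look functions up by string.

    #include <functional>
    #include <iostream>
    #include <map>
    #include <string>
    #include <vector>

    using Args = std::vector<int>;                    // stand-in for DGLArgs
    using Fn = std::function<void(const Args&, int*)>;  // int* as DGLRetValue

    std::map<std::string, Fn>& registry() {
      static std::map<std::string, Fn> r;  // constructed on first use
      return r;
    }

    struct Register {  // the analog of what DGL_REGISTER_GLOBAL expands to
      Register(const std::string& name, Fn fn) { registry()[name] = std::move(fn); }
    };

    static Register reg_list_size("_ListSize", [](const Args& args, int* rv) {
      *rv = static_cast<int>(args.size());
    });

    int main() {
      int rv = 0;
      registry()["_ListSize"]({1, 2, 3}, &rv);  // dispatch by name
      std::cout << rv << "\n";                  // prints 3
      return 0;
    }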
DGL_REGISTER_GLOBAL("_Map") DGL_REGISTER_GLOBAL("_Map").set_body([](DGLArgs args, DGLRetValue* rv) {
.set_body([] (DGLArgs args, DGLRetValue* rv) {
CHECK_EQ(args.size() % 2, 0); CHECK_EQ(args.size() % 2, 0);
if (args.size() != 0 && args[0].type_code() == kStr) { if (args.size() != 0 && args[0].type_code() == kStr) {
// StrMap // StrMap
StrMapObject::ContainerType data; StrMapObject::ContainerType data;
for (int i = 0; i < args.size(); i += 2) { for (int i = 0; i < args.size(); i += 2) {
CHECK(args[i].type_code() == kStr) CHECK(args[i].type_code() == kStr) << "The key of the map must be string";
<< "The key of the map must be string";
CHECK(args[i + 1].type_code() == kObjectHandle) CHECK(args[i + 1].type_code() == kObjectHandle)
<< "The value of the map must be an object type"; << "The value of the map must be an object type";
data.emplace(std::make_pair(args[i].operator std::string(), data.emplace(std::make_pair(
args[i + 1].obj_sptr())); args[i].operator std::string(), args[i + 1].obj_sptr()));
} }
auto obj = std::make_shared<StrMapObject>(); auto obj = std::make_shared<StrMapObject>();
obj->data = std::move(data); obj->data = std::move(data);
...@@ -69,7 +64,7 @@ DGL_REGISTER_GLOBAL("_Map") ...@@ -69,7 +64,7 @@ DGL_REGISTER_GLOBAL("_Map")
obj->data = std::move(data); obj->data = std::move(data);
*rv = obj; *rv = obj;
} }
}); });
DGL_REGISTER_GLOBAL("_EmptyStrMap").set_body([](DGLArgs args, DGLRetValue* rv) { DGL_REGISTER_GLOBAL("_EmptyStrMap").set_body([](DGLArgs args, DGLRetValue* rv) {
StrMapObject::ContainerType data; StrMapObject::ContainerType data;
...@@ -78,8 +73,7 @@ DGL_REGISTER_GLOBAL("_EmptyStrMap").set_body([](DGLArgs args, DGLRetValue* rv) { ...@@ -78,8 +73,7 @@ DGL_REGISTER_GLOBAL("_EmptyStrMap").set_body([](DGLArgs args, DGLRetValue* rv) {
*rv = obj; *rv = obj;
}); });
DGL_REGISTER_GLOBAL("_MapSize") DGL_REGISTER_GLOBAL("_MapSize").set_body([](DGLArgs args, DGLRetValue* rv) {
.set_body([] (DGLArgs args, DGLRetValue* rv) {
auto& sptr = args[0].obj_sptr(); auto& sptr = args[0].obj_sptr();
if (sptr->is_type<MapObject>()) { if (sptr->is_type<MapObject>()) {
auto* o = static_cast<const MapObject*>(sptr.get()); auto* o = static_cast<const MapObject*>(sptr.get());
...@@ -89,10 +83,9 @@ DGL_REGISTER_GLOBAL("_MapSize") ...@@ -89,10 +83,9 @@ DGL_REGISTER_GLOBAL("_MapSize")
    auto* o = static_cast<const StrMapObject*>(sptr.get());
    *rv = static_cast<int64_t>(o->data.size());
  }
});

DGL_REGISTER_GLOBAL("_MapGetItem").set_body([](DGLArgs args, DGLRetValue* rv) {
  auto& sptr = args[0].obj_sptr();
  if (sptr->is_type<MapObject>()) {
    auto* o = static_cast<const MapObject*>(sptr.get());
@@ -106,10 +99,9 @@ DGL_REGISTER_GLOBAL("_MapGetItem")
    CHECK(it != o->data.end()) << "cannot find the key in the map";
    *rv = (*it).second;
  }
});

DGL_REGISTER_GLOBAL("_MapItems").set_body([](DGLArgs args, DGLRetValue* rv) {
  auto& sptr = args[0].obj_sptr();
  if (sptr->is_type<MapObject>()) {
    auto* o = static_cast<const MapObject*>(sptr.get());
@@ -129,10 +121,9 @@ DGL_REGISTER_GLOBAL("_MapItems")
    }
    *rv = rkvs;
  }
});

DGL_REGISTER_GLOBAL("_MapCount").set_body([](DGLArgs args, DGLRetValue* rv) {
  auto& sptr = args[0].obj_sptr();
  if (sptr->is_type<MapObject>()) {
    auto* o = static_cast<const MapObject*>(sptr.get());
@@ -142,20 +133,18 @@ DGL_REGISTER_GLOBAL("_MapCount")
    auto* o = static_cast<const StrMapObject*>(sptr.get());
    *rv = static_cast<int64_t>(o->data.count(args[1].operator std::string()));
  }
});

DGL_REGISTER_GLOBAL("_Value").set_body([](DGLArgs args, DGLRetValue* rv) {
  *rv = MakeValue(args[0]);
});

DGL_REGISTER_GLOBAL("_ValueGet").set_body([](DGLArgs args, DGLRetValue* rv) {
  auto& sptr = args[0].obj_sptr();
  CHECK(sptr->is_type<ValueObject>());
  auto* o = static_cast<const ValueObject*>(sptr.get());
  *rv = o->data;
});

}  // namespace runtime
}  // namespace dgl
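The globals above form DGL's container FFI surface. As a minimal sketch (not part of this commit) of how such a global is consumed from C++; Registry::Get and ObjectRef are assumptions about the runtime headers, not shown in this diff:

#include <dgl/runtime/registry.h>

// Hypothetical helper: look up the "_ListSize" global registered above and
// call it on a list object reference. The lambda registered above stores an
// int64_t into *rv, which converts back to int64_t here.
int64_t ListSizeOf(const dgl::runtime::ObjectRef& list) {
  const dgl::runtime::PackedFunc* f =
      dgl::runtime::Registry::Get("_ListSize");
  CHECK(f != nullptr) << "_ListSize is not registered";
  return (*f)(list);
}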
@@ -3,10 +3,11 @@
 * \file api/api_test.cc
 * \brief C APIs for testing FFI
 */
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/runtime/registry.h>

#include <thread>

namespace dgl {
@@ -18,12 +19,12 @@ namespace runtime {
// - The argument to pass to the python callback
// It returns what the python callback returns.
DGL_REGISTER_GLOBAL("_TestPythonCallback") DGL_REGISTER_GLOBAL("_TestPythonCallback")
.set_body([](DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
LOG(INFO) << "Inside C API"; LOG(INFO) << "Inside C API";
PackedFunc fn = args[0]; PackedFunc fn = args[0];
DGLArgs cb_args(args.values + 1, args.type_codes + 1, 1); DGLArgs cb_args(args.values + 1, args.type_codes + 1, 1);
fn.CallPacked(cb_args, rv); fn.CallPacked(cb_args, rv);
}); });
// Register an internal API for testing python callback.
// It receives two arguments:
@@ -34,17 +35,16 @@ DGL_REGISTER_GLOBAL("_TestPythonCallback")
// The API runs the python callback in a separate thread to test
// that the python GIL is properly released.
DGL_REGISTER_GLOBAL("_TestPythonCallbackThread") DGL_REGISTER_GLOBAL("_TestPythonCallbackThread")
.set_body([](DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
LOG(INFO) << "Inside C API"; LOG(INFO) << "Inside C API";
PackedFunc fn = args[0]; PackedFunc fn = args[0];
auto thr = std::make_shared<std::thread>( auto thr = std::make_shared<std::thread>([fn, args, rv]() {
[fn, args, rv]() {
LOG(INFO) << "Callback thread " << std::this_thread::get_id(); LOG(INFO) << "Callback thread " << std::this_thread::get_id();
DGLArgs cb_args(args.values + 1, args.type_codes + 1, 1); DGLArgs cb_args(args.values + 1, args.type_codes + 1, 1);
fn.CallPacked(cb_args, rv); fn.CallPacked(cb_args, rv);
}); });
thr->join(); thr->join();
}); });
} // namespace runtime } // namespace runtime
} // namespace dgl } // namespace dgl
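A hedged illustration of the calling convention these test APIs implement (argument 0 is the callback, argument 1 is forwarded to it). Only PackedFunc, DGLArgs, DGLRetValue, and CallPacked are taken from the code above; Registry::Get is an assumption about the registry header:

#include <dgl/runtime/packed_func.h>
#include <dgl/runtime/registry.h>

using namespace dgl::runtime;

// Illustrative caller: invoke _TestPythonCallback with a doubling callback.
void CallbackDemo() {
  PackedFunc doubler([](DGLArgs args, DGLRetValue* rv) {
    int64_t x = args[0];
    *rv = x * 2;
  });
  const PackedFunc* api = Registry::Get("_TestPythonCallback");
  if (api != nullptr) {
    int64_t out = (*api)(doubler, static_cast<int64_t>(21));  // out == 42
    LOG(INFO) << "callback returned " << out;
  }
}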
@@ -4,11 +4,12 @@
 * \brief DGL array arithmetic operations
 */
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
#include <dgl/runtime/ndarray.h>

#include "../c_api_common.h"
#include "./arith.h"
#include "./array_op.h"

using namespace dgl::runtime;
@@ -65,7 +66,6 @@ namespace aten {
  return ret; \
}

BINARY_ELEMENT_OP(Add, Add)
BINARY_ELEMENT_OP(Sub, Sub)
BINARY_ELEMENT_OP(Mul, Mul)
@@ -108,106 +108,104 @@ UNARY_ELEMENT_OP(Neg, Neg)
}  // namespace dgl

///////////////// Operator overloading for NDArray /////////////////
NDArray operator+(const NDArray& lhs, const NDArray& rhs) {
  return dgl::aten::Add(lhs, rhs);
}
NDArray operator-(const NDArray& lhs, const NDArray& rhs) {
  return dgl::aten::Sub(lhs, rhs);
}
NDArray operator*(const NDArray& lhs, const NDArray& rhs) {
  return dgl::aten::Mul(lhs, rhs);
}
NDArray operator/(const NDArray& lhs, const NDArray& rhs) {
  return dgl::aten::Div(lhs, rhs);
}
NDArray operator%(const NDArray& lhs, const NDArray& rhs) {
  return dgl::aten::Mod(lhs, rhs);
}
NDArray operator+(const NDArray& lhs, int64_t rhs) {
  return dgl::aten::Add(lhs, rhs);
}
NDArray operator-(const NDArray& lhs, int64_t rhs) {
  return dgl::aten::Sub(lhs, rhs);
}
NDArray operator*(const NDArray& lhs, int64_t rhs) {
  return dgl::aten::Mul(lhs, rhs);
}
NDArray operator/(const NDArray& lhs, int64_t rhs) {
  return dgl::aten::Div(lhs, rhs);
}
NDArray operator%(const NDArray& lhs, int64_t rhs) {
  return dgl::aten::Mod(lhs, rhs);
}
NDArray operator+(int64_t lhs, const NDArray& rhs) {
  return dgl::aten::Add(lhs, rhs);
}
NDArray operator-(int64_t lhs, const NDArray& rhs) {
  return dgl::aten::Sub(lhs, rhs);
}
NDArray operator*(int64_t lhs, const NDArray& rhs) {
  return dgl::aten::Mul(lhs, rhs);
}
NDArray operator/(int64_t lhs, const NDArray& rhs) {
  return dgl::aten::Div(lhs, rhs);
}
NDArray operator%(int64_t lhs, const NDArray& rhs) {
  return dgl::aten::Mod(lhs, rhs);
}
NDArray operator-(const NDArray& array) { return dgl::aten::Neg(array); }

NDArray operator>(const NDArray& lhs, const NDArray& rhs) {
  return dgl::aten::GT(lhs, rhs);
}
NDArray operator<(const NDArray& lhs, const NDArray& rhs) {
  return dgl::aten::LT(lhs, rhs);
}
NDArray operator>=(const NDArray& lhs, const NDArray& rhs) {
  return dgl::aten::GE(lhs, rhs);
}
NDArray operator<=(const NDArray& lhs, const NDArray& rhs) {
  return dgl::aten::LE(lhs, rhs);
}
NDArray operator==(const NDArray& lhs, const NDArray& rhs) {
  return dgl::aten::EQ(lhs, rhs);
}
NDArray operator!=(const NDArray& lhs, const NDArray& rhs) {
  return dgl::aten::NE(lhs, rhs);
}
NDArray operator>(const NDArray& lhs, int64_t rhs) {
  return dgl::aten::GT(lhs, rhs);
}
NDArray operator<(const NDArray& lhs, int64_t rhs) {
  return dgl::aten::LT(lhs, rhs);
}
NDArray operator>=(const NDArray& lhs, int64_t rhs) {
  return dgl::aten::GE(lhs, rhs);
}
NDArray operator<=(const NDArray& lhs, int64_t rhs) {
  return dgl::aten::LE(lhs, rhs);
}
NDArray operator==(const NDArray& lhs, int64_t rhs) {
  return dgl::aten::EQ(lhs, rhs);
}
NDArray operator!=(const NDArray& lhs, int64_t rhs) {
  return dgl::aten::NE(lhs, rhs);
}
NDArray operator>(int64_t lhs, const NDArray& rhs) {
  return dgl::aten::GT(lhs, rhs);
}
NDArray operator<(int64_t lhs, const NDArray& rhs) {
  return dgl::aten::LT(lhs, rhs);
}
NDArray operator>=(int64_t lhs, const NDArray& rhs) {
  return dgl::aten::GE(lhs, rhs);
}
NDArray operator<=(int64_t lhs, const NDArray& rhs) {
  return dgl::aten::LE(lhs, rhs);
}
NDArray operator==(int64_t lhs, const NDArray& rhs) {
  return dgl::aten::EQ(lhs, rhs);
}
NDArray operator!=(int64_t lhs, const NDArray& rhs) {
  return dgl::aten::NE(lhs, rhs);
}
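A short usage sketch of the overloads defined above; `a` and `b` stand for any two same-length integer NDArrays, and each call forwards to the corresponding dgl::aten kernel shown above:

using dgl::runtime::NDArray;

// Sketch: the overloads compose like ordinary arithmetic.
NDArray OverloadDemo(const NDArray& a, const NDArray& b) {
  NDArray s = a + b;            // dgl::aten::Add
  NDArray scaled = 2 * s;       // scalar-lhs overload -> dgl::aten::Mul
  NDArray mask = scaled >= 10;  // dgl::aten::GE, elementwise
  return -mask;                 // unary overload -> dgl::aten::Neg
}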
@@ -6,35 +6,32 @@
#ifndef DGL_ARRAY_CHECK_H_
#define DGL_ARRAY_CHECK_H_

#include <dgl/array.h>
#include <dgl/runtime/ndarray.h>

#include <string>
#include <vector>

namespace dgl {
namespace aten {
// Check whether the given arguments have the same context.
inline void CheckCtx(
    const DGLContext& ctx, const std::vector<NDArray>& arrays,
    const std::vector<std::string>& names) {
  for (size_t i = 0; i < arrays.size(); ++i) {
    if (IsNullArray(arrays[i])) continue;
    CHECK_EQ(ctx, arrays[i]->ctx)
        << "Expected device context " << ctx << ". But got " << arrays[i]->ctx
        << " for " << names[i] << ".";
  }
}
// Check whether input tensors are contiguous.
inline void CheckContiguous(
    const std::vector<NDArray>& arrays, const std::vector<std::string>& names) {
  for (size_t i = 0; i < arrays.size(); ++i) {
    if (IsNullArray(arrays[i])) continue;
    CHECK(arrays[i].IsContiguous())
        << "Expect " << names[i] << " to be a contiguous tensor";
  }
@@ -42,21 +39,18 @@ inline void CheckContiguous(
// Check whether input tensors have valid shape.
inline void CheckShape(
    const std::vector<uint64_t>& gdim, const std::vector<int>& uev_idx,
    const std::vector<NDArray>& arrays, const std::vector<std::string>& names) {
  for (size_t i = 0; i < arrays.size(); ++i) {
    if (IsNullArray(arrays[i])) continue;
    CHECK_GE(arrays[i]->ndim, 2)
        << "Expect " << names[i] << " to have ndim >= 2, "
        << "Note that for scalar feature we expand its "
        << "dimension with an additional dimension of "
        << "length one.";
    CHECK_EQ(gdim[uev_idx[i]], arrays[i]->shape[0])
        << "Expect " << names[i] << " to have size " << gdim[uev_idx[i]]
        << " on the first dimension, "
        << "but got " << arrays[i]->shape[0];
  }
}
......
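These three checks are meant to run together at a kernel entry point. A hypothetical call site (the feature names and entity sizes are invented for the demo; only the three helper signatures come from the header above):

// Hypothetical validation at a kernel entry point.
void ValidateFeatures(
    const DGLContext& ctx, dgl::runtime::NDArray ufeat,
    dgl::runtime::NDArray efeat, uint64_t num_nodes, uint64_t num_edges) {
  dgl::aten::CheckCtx(ctx, {ufeat, efeat}, {"ufeat", "efeat"});
  dgl::aten::CheckContiguous({ufeat, efeat}, {"ufeat", "efeat"});
  // gdim lists the size of each entity type; uev_idx picks, per array, which
  // entry of gdim its first dimension must match.
  dgl::aten::CheckShape(
      {num_nodes, num_edges}, {0, 1}, {ufeat, efeat}, {"ufeat", "efeat"});
}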
@@ -14,22 +14,21 @@ template <DGLDeviceType XPU, typename IdType>
IdArray CumSum(IdArray array, bool prepend_zero) {
  const int64_t len = array.NumElements();
  if (len == 0)
    return !prepend_zero ? array
                         : aten::Full(0, 1, array->dtype.bits, array->ctx);

  if (prepend_zero) {
    IdArray ret = aten::NewIdArray(len + 1, array->ctx, array->dtype.bits);
    const IdType* in_d = array.Ptr<IdType>();
    IdType* out_d = ret.Ptr<IdType>();
    out_d[0] = 0;
    for (int64_t i = 0; i < len; ++i) out_d[i + 1] = out_d[i] + in_d[i];
    return ret;
  } else {
    IdArray ret = aten::NewIdArray(len, array->ctx, array->dtype.bits);
    const IdType* in_d = array.Ptr<IdType>();
    IdType* out_d = ret.Ptr<IdType>();
    out_d[0] = in_d[0];
    for (int64_t i = 1; i < len; ++i) out_d[i] = out_d[i - 1] + in_d[i];
    return ret;
  }
}
......
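A semantics sketch for the two branches above, assuming dgl::aten::CumSum is the public wrapper that dispatches to this CPU kernel (NDArray::FromVector appears elsewhere in this changeset):

#include <dgl/array.h>

#include <vector>

using dgl::runtime::NDArray;

// Sketch only: prepend_zero controls whether the output grows by one.
void CumSumDemo() {
  NDArray in = NDArray::FromVector(std::vector<int64_t>{1, 2, 3});
  NDArray plain = dgl::aten::CumSum(in, false);  // expected: [1, 3, 6]
  NDArray padded = dgl::aten::CumSum(in, true);  // expected: [0, 1, 3, 6]
}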
@@ -10,9 +10,10 @@ using runtime::NDArray;
namespace aten {
namespace impl {

template <DGLDeviceType XPU, typename DType, typename IdType>
NDArray IndexSelect(NDArray array, IdArray index) {
  CHECK_EQ(array->shape[0], array.NumElements())
      << "Only support tensor"
      << " whose first dimension equals number of elements, e.g. (5,), (5, 1)";
  const DType* array_data = static_cast<DType*>(array->data);
......
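Given the CHECK above, this overload is effectively a 1-D gather. A sketch, assuming dgl::aten::IndexSelect is the dispatching wrapper over this kernel:

#include <dgl/array.h>

#include <vector>

using dgl::runtime::NDArray;

// Gather sketch: picked[i] = vals[idx[i]]. Wrapper name is an assumption.
void IndexSelectDemo() {
  NDArray vals = NDArray::FromVector(std::vector<int64_t>{10, 20, 30, 40});
  NDArray idx = NDArray::FromVector(std::vector<int64_t>{3, 0});
  NDArray picked = dgl::aten::IndexSelect(vals, idx);  // expected: [40, 10]
}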
@@ -15,8 +15,7 @@ IdArray NonZero(IdArray array) {
  std::vector<int64_t> ret;
  const IdType* data = array.Ptr<IdType>();
  for (int64_t i = 0; i < array->shape[0]; ++i)
    if (data[i] != 0) ret.push_back(i);
  return NDArray::FromVector(ret, array->ctx);
}
......
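A sketch of NonZero's result, assuming a public dgl::aten::NonZero wrapper over this kernel: the output holds the positions of the nonzero entries.

#include <dgl/array.h>

#include <vector>

using dgl::runtime::NDArray;

// Sketch only; wrapper name is an assumption.
void NonZeroDemo() {
  NDArray ids = NDArray::FromVector(std::vector<int64_t>{0, 5, 0, 7});
  NDArray idx = dgl::aten::NonZero(ids);  // expected: [1, 3]
}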
@@ -6,7 +6,9 @@
#include <dgl/array.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/runtime/parallel_for.h>

#include <numeric>

#include "../arith.h"

namespace dgl {
@@ -51,116 +53,186 @@ IdArray BinaryElewise(IdArray lhs, IdArray rhs) {
  const IdType* lhs_data = static_cast<IdType*>(lhs->data);
  const IdType* rhs_data = static_cast<IdType*>(rhs->data);
  IdType* ret_data = static_cast<IdType*>(ret->data);
  // TODO(BarclayII): this usually incurs lots of overhead in thread spawning,
  // scheduling, etc., especially since the workload is very light. Need to
  // replace with parallel_for.
  for (int64_t i = 0; i < lhs->shape[0]; i++) {
    ret_data[i] = Op::Call(lhs_data[i], rhs_data[i]);
  }
  return ret;
}

template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Add>(
    IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Sub>(
    IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Mul>(
    IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Div>(
    IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Mod>(
    IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::GT>(
    IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::LT>(
    IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::GE>(
    IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::LE>(
    IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::EQ>(
    IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::NE>(
    IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Add>(
    IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Sub>(
    IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Mul>(
    IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Div>(
    IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Mod>(
    IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::GT>(
    IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::LT>(
    IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::GE>(
    IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::LE>(
    IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::EQ>(
    IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::NE>(
    IdArray lhs, IdArray rhs);
template <DGLDeviceType XPU, typename IdType, typename Op>
IdArray BinaryElewise(IdArray lhs, IdType rhs) {
  IdArray ret = NewIdArray(lhs->shape[0], lhs->ctx, lhs->dtype.bits);
  const IdType* lhs_data = static_cast<IdType*>(lhs->data);
  IdType* ret_data = static_cast<IdType*>(ret->data);
  // TODO(BarclayII): this usually incurs lots of overhead in thread spawning,
  // scheduling, etc., especially since the workload is very light. Need to
  // replace with parallel_for.
  for (int64_t i = 0; i < lhs->shape[0]; i++) {
    ret_data[i] = Op::Call(lhs_data[i], rhs);
  }
  return ret;
}

template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Add>(
    IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Sub>(
    IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Mul>(
    IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Div>(
    IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Mod>(
    IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::GT>(
    IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::LT>(
    IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::GE>(
    IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::LE>(
    IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::EQ>(
    IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::NE>(
    IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Add>(
    IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Sub>(
    IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Mul>(
    IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Div>(
    IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Mod>(
    IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::GT>(
    IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::LT>(
    IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::GE>(
    IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::LE>(
    IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::EQ>(
    IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::NE>(
    IdArray lhs, int64_t rhs);
template <DGLDeviceType XPU, typename IdType, typename Op>
IdArray BinaryElewise(IdType lhs, IdArray rhs) {
  IdArray ret = NewIdArray(rhs->shape[0], rhs->ctx, rhs->dtype.bits);
  const IdType* rhs_data = static_cast<IdType*>(rhs->data);
  IdType* ret_data = static_cast<IdType*>(ret->data);
  // TODO(BarclayII): this usually incurs lots of overhead in thread spawning,
  // scheduling, etc., especially since the workload is very light. Need to
  // replace with parallel_for.
  for (int64_t i = 0; i < rhs->shape[0]; i++) {
    ret_data[i] = Op::Call(lhs, rhs_data[i]);
  }
  return ret;
}

template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Add>(
    int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Sub>(
    int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Mul>(
    int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Div>(
    int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Mod>(
    int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::GT>(
    int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::LT>(
    int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::GE>(
    int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::LE>(
    int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::EQ>(
    int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::NE>(
    int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Add>(
    int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Sub>(
    int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Mul>(
    int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Div>(
    int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Mod>(
    int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::GT>(
    int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::LT>(
    int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::GE>(
    int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::LE>(
    int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::EQ>(
    int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::NE>(
    int64_t lhs, IdArray rhs);
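All three BinaryElewise overloads follow the same Op::Call pattern. A usage sketch that would be valid inside this translation unit (the arith functors come from ../arith.h, included above; the demo function itself is hypothetical):

// Sketch: combine two id arrays, then compare one against a scalar.
void ElewiseDemo(IdArray a, IdArray b) {
  IdArray sum = BinaryElewise<kDGLCPU, int64_t, arith::Add>(a, b);
  IdArray gt100 = BinaryElewise<kDGLCPU, int64_t, arith::GT>(
      a, static_cast<int64_t>(100));  // per-element result of a[i] > 100
  (void)sum;
  (void)gt100;
}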
template <DGLDeviceType XPU, typename IdType, typename Op>
IdArray UnaryElewise(IdArray lhs) {
  IdArray ret = NewIdArray(lhs->shape[0], lhs->ctx, lhs->dtype.bits);
  const IdType* lhs_data = static_cast<IdType*>(lhs->data);
  IdType* ret_data = static_cast<IdType*>(ret->data);
  // TODO(BarclayII): this usually incurs lots of overhead in thread spawning,
  // scheduling, etc., especially since the workload is very light. Need to
  // replace with parallel_for.
  for (int64_t i = 0; i < lhs->shape[0]; i++) {
    ret_data[i] = Op::Call(lhs_data[i]);
  }
@@ -180,10 +252,14 @@ NDArray Full(DType val, int64_t length, DGLContext ctx) {
  return ret;
}

template NDArray Full<kDGLCPU, int32_t>(
    int32_t val, int64_t length, DGLContext ctx);
template NDArray Full<kDGLCPU, int64_t>(
    int64_t val, int64_t length, DGLContext ctx);
template NDArray Full<kDGLCPU, float>(
    float val, int64_t length, DGLContext ctx);
template NDArray Full<kDGLCPU, double>(
    double val, int64_t length, DGLContext ctx);

///////////////////////////// Range /////////////////////////////
@@ -216,7 +292,8 @@ IdArray Relabel_(const std::vector<IdArray>& arrays) {
    }
  }
  // map array
  IdArray maparr =
      NewIdArray(newid, DGLContext{kDGLCPU, 0}, sizeof(IdType) * 8);
  IdType* maparr_data = static_cast<IdType*>(maparr->data);
  for (const auto& kv : oldv2newv) {
    maparr_data[kv.second] = kv.first;
......