Unverified Commit 81831111 authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Misc] clang-format auto fix. (#4805)



* [Misc] clang-format auto fix.

* fix
Co-authored-by: default avatarSteve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 16e771c0
......@@ -5,11 +5,13 @@
*/
#include <dgl/array.h>
#include <dgl/sampling/negative.h>
#include <dgl/base_heterograph.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
#include <dgl/sampling/negative.h>
#include <utility>
#include "../../../c_api_common.h"
using namespace dgl::runtime;
......@@ -19,13 +21,8 @@ namespace dgl {
namespace sampling {
std::pair<IdArray, IdArray> GlobalUniformNegativeSampling(
HeteroGraphPtr hg,
dgl_type_t etype,
int64_t num_samples,
int num_trials,
bool exclude_self_loops,
bool replace,
double redundancy) {
HeteroGraphPtr hg, dgl_type_t etype, int64_t num_samples, int num_trials,
bool exclude_self_loops, bool replace, double redundancy) {
auto format = hg->SelectFormat(etype, CSC_CODE | CSR_CODE);
if (format == SparseFormat::kCSC) {
CSRMatrix csc = hg->GetCSCMatrix(etype);
......@@ -40,13 +37,14 @@ std::pair<IdArray, IdArray> GlobalUniformNegativeSampling(
return CSRGlobalUniformNegativeSampling(
csr, num_samples, num_trials, exclude_self_loops, replace, redundancy);
} else {
LOG(FATAL) << "COO format is not supported in global uniform negative sampling";
LOG(FATAL)
<< "COO format is not supported in global uniform negative sampling";
return {IdArray(), IdArray()};
}
}
DGL_REGISTER_GLOBAL("sampling.negative._CAPI_DGLGlobalUniformNegativeSampling")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
CHECK_LE(etype, hg->NumEdgeTypes()) << "invalid edge type " << etype;
......@@ -57,7 +55,8 @@ DGL_REGISTER_GLOBAL("sampling.negative._CAPI_DGLGlobalUniformNegativeSampling")
double redundancy = args[6];
List<Value> result;
std::pair<IdArray, IdArray> ret = GlobalUniformNegativeSampling(
hg.sptr(), etype, num_samples, num_trials, exclude_self_loops, replace, redundancy);
hg.sptr(), etype, num_samples, num_trials, exclude_self_loops,
replace, redundancy);
result.push_back(Value(MakeValue(ret.first)));
result.push_back(Value(MakeValue(ret.second)));
*rv = result;
......
......@@ -9,13 +9,14 @@
#include <dgl/array.h>
#include <dgl/runtime/device_api.h>
#include <tuple>
namespace dgl {
namespace sampling {
namespace impl {
template<typename IdxType>
template <typename IdxType>
class DeviceEdgeHashmap {
public:
struct EdgeItem {
......@@ -23,14 +24,19 @@ class DeviceEdgeHashmap {
IdxType cnt;
};
DeviceEdgeHashmap() = delete;
DeviceEdgeHashmap(int64_t num_dst, int64_t num_items_each_dst,
IdxType* dst_unique_edges, EdgeItem *edge_hashmap):
_num_dst(num_dst), _num_items_each_dst(num_items_each_dst),
_dst_unique_edges(dst_unique_edges), _edge_hashmap(edge_hashmap) {}
DeviceEdgeHashmap(
int64_t num_dst, int64_t num_items_each_dst, IdxType *dst_unique_edges,
EdgeItem *edge_hashmap)
: _num_dst(num_dst),
_num_items_each_dst(num_items_each_dst),
_dst_unique_edges(dst_unique_edges),
_edge_hashmap(edge_hashmap) {}
// return the old cnt of this edge
inline __device__ IdxType InsertEdge(const IdxType &src, const IdxType &dst_idx);
inline __device__ IdxType
InsertEdge(const IdxType &src, const IdxType &dst_idx);
inline __device__ IdxType GetDstCount(const IdxType &dst_idx);
inline __device__ IdxType GetEdgeCount(const IdxType &src, const IdxType &dst_idx);
inline __device__ IdxType
GetEdgeCount(const IdxType &src, const IdxType &dst_idx);
private:
int64_t _num_dst;
......@@ -43,19 +49,21 @@ class DeviceEdgeHashmap {
}
};
template<typename IdxType>
template <typename IdxType>
class FrequencyHashmap {
public:
static constexpr int64_t kDefaultEdgeTableScale = 3;
FrequencyHashmap() = delete;
FrequencyHashmap(int64_t num_dst, int64_t num_items_each_dst, DGLContext ctx, cudaStream_t stream,
int64_t edge_table_scale = kDefaultEdgeTableScale);
FrequencyHashmap(
int64_t num_dst, int64_t num_items_each_dst, DGLContext ctx,
cudaStream_t stream, int64_t edge_table_scale = kDefaultEdgeTableScale);
~FrequencyHashmap();
using EdgeItem = typename DeviceEdgeHashmap<IdxType>::EdgeItem;
std::tuple<IdArray, IdArray, IdArray> Topk(
const IdxType *src_data, const IdxType *dst_data, DGLDataType dtype,
const int64_t num_edges, const int64_t num_edges_per_node,
const int64_t num_pick);
private:
DGLContext _ctx;
cudaStream_t _stream;
......
......@@ -6,7 +6,9 @@
#include <dgl/array.h>
#include <dgl/base_heterograph.h>
#include <utility>
#include "randomwalks_impl.h"
namespace dgl {
......@@ -18,10 +20,9 @@ namespace sampling {
namespace impl {
template<DGLDeviceType XPU, typename IdxType>
template <DGLDeviceType XPU, typename IdxType>
TypeArray GetNodeTypesFromMetapath(
const HeteroGraphPtr hg,
const TypeArray metapath) {
const HeteroGraphPtr hg, const TypeArray metapath) {
uint64_t num_etypes = metapath->shape[0];
TypeArray result = TypeArray::Empty(
{metapath->shape[0] + 1}, metapath->dtype, metapath->ctx);
......@@ -38,8 +39,8 @@ TypeArray GetNodeTypesFromMetapath(
dgl_type_t dsttype = src_dst_type.second;
if (srctype != curr_type) {
LOG(FATAL) << "source of edge type #" << i <<
" does not match destination of edge type #" << i - 1;
LOG(FATAL) << "source of edge type #" << i
<< " does not match destination of edge type #" << i - 1;
return result;
}
curr_type = dsttype;
......@@ -48,14 +49,10 @@ TypeArray GetNodeTypesFromMetapath(
return result;
}
template
TypeArray GetNodeTypesFromMetapath<kDGLCPU, int32_t>(
const HeteroGraphPtr hg,
const TypeArray metapath);
template
TypeArray GetNodeTypesFromMetapath<kDGLCPU, int64_t>(
const HeteroGraphPtr hg,
const TypeArray metapath);
template TypeArray GetNodeTypesFromMetapath<kDGLCPU, int32_t>(
const HeteroGraphPtr hg, const TypeArray metapath);
template TypeArray GetNodeTypesFromMetapath<kDGLCPU, int64_t>(
const HeteroGraphPtr hg, const TypeArray metapath);
}; // namespace impl
......
......@@ -4,11 +4,13 @@
* \brief DGL sampler
*/
#include <cuda_runtime.h>
#include <dgl/array.h>
#include <dgl/base_heterograph.h>
#include <dgl/runtime/device_api.h>
#include <cuda_runtime.h>
#include <utility>
#include "randomwalks_impl.h"
namespace dgl {
......@@ -20,19 +22,17 @@ namespace sampling {
namespace impl {
template<DGLDeviceType XPU, typename IdxType>
template <DGLDeviceType XPU, typename IdxType>
TypeArray GetNodeTypesFromMetapath(
const HeteroGraphPtr hg,
const TypeArray metapath) {
const HeteroGraphPtr hg, const TypeArray metapath) {
uint64_t num_etypes = metapath->shape[0];
auto cpu_ctx = DGLContext{kDGLCPU, 0};
auto metapath_ctx = metapath->ctx;
auto stream = DeviceAPI::Get(metapath_ctx)->GetStream();
TypeArray h_result = TypeArray::Empty(
{metapath->shape[0] + 1}, metapath->dtype, cpu_ctx);
TypeArray h_result =
TypeArray::Empty({metapath->shape[0] + 1}, metapath->dtype, cpu_ctx);
auto h_result_data = h_result.Ptr<IdxType>();
auto h_metapath = metapath.CopyTo(cpu_ctx);
......@@ -48,8 +48,8 @@ TypeArray GetNodeTypesFromMetapath(
dgl_type_t dsttype = src_dst_type.second;
if (srctype != curr_type) {
LOG(FATAL) << "source of edge type #" << i <<
" does not match destination of edge type #" << i - 1;
LOG(FATAL) << "source of edge type #" << i
<< " does not match destination of edge type #" << i - 1;
}
curr_type = dsttype;
h_result_data[i + 1] = dsttype;
......@@ -60,14 +60,10 @@ TypeArray GetNodeTypesFromMetapath(
return result;
}
template
TypeArray GetNodeTypesFromMetapath<kDGLCUDA, int32_t>(
const HeteroGraphPtr hg,
const TypeArray metapath);
template
TypeArray GetNodeTypesFromMetapath<kDGLCUDA, int64_t>(
const HeteroGraphPtr hg,
const TypeArray metapath);
template TypeArray GetNodeTypesFromMetapath<kDGLCUDA, int32_t>(
const HeteroGraphPtr hg, const TypeArray metapath);
template TypeArray GetNodeTypesFromMetapath<kDGLCUDA, int64_t>(
const HeteroGraphPtr hg, const TypeArray metapath);
}; // namespace impl
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment