Unverified commit 8ac27dad authored by Hongzhi (Steve) Chen, committed by GitHub

[Misc] clang-format auto fix. (#4824)



* [Misc] clang-format auto fix.

* blabla

* ablabla

* blabla
Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent bcd37684
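The changes below are purely mechanical: clang-format sorts `#include` directives within each block and rewraps any line that exceeds the column limit. The repository's actual `.clang-format` file is not part of this diff; the following is a minimal hypothetical sketch of settings that would produce edits like these (assumed, not taken from the repo):

```yaml
# Hypothetical .clang-format sketch; not the file used by this commit.
BasedOnStyle: Google  # 2-space indents, 80-column limit
SortIncludes: true    # alphabetize includes, e.g. <numeric> before <vector>
ColumnLimit: 80       # forces long parameter lists onto continuation lines
```

The first expanded file is the CUDA implementation of `CSRGetData`: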
```diff
@@ -4,9 +4,11 @@
  * @brief Retrieve entries of a CSR matrix
  */
 #include <dgl/array.h>
-#include <vector>
-#include <unordered_set>
+
 #include <numeric>
+#include <unordered_set>
+#include <vector>
+
 #include "../../runtime/cuda/cuda_common.h"
 #include "./utils.h"
```
```diff
@@ -19,7 +21,8 @@ namespace impl {
 template <DGLDeviceType XPU, typename IdType, typename DType>
 NDArray CSRGetData(
-    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, DType filler) {
+    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
+    NDArray weights, DType filler) {
   const int64_t rowlen = rows->shape[0];
   const int64_t collen = cols->shape[0];
```
```diff
@@ -31,44 +34,44 @@ NDArray CSRGetData(
   const int64_t rstlen = std::max(rowlen, collen);
   IdArray rst = NDArray::Empty({rstlen}, weights->dtype, rows->ctx);
-  if (rstlen == 0)
-    return rst;
+  if (rstlen == 0) return rst;
   cudaStream_t stream = runtime::getCurrentCUDAStream();
   const int nt = cuda::FindNumThreads(rstlen);
   const int nb = (rstlen + nt - 1) / nt;
   if (return_eids)
-    BUG_IF_FAIL(DGLDataTypeTraits<DType>::dtype == rows->dtype) <<
-      "DType does not match row's dtype.";
+    BUG_IF_FAIL(DGLDataTypeTraits<DType>::dtype == rows->dtype)
+        << "DType does not match row's dtype.";
   const IdType* indptr_data = csr.indptr.Ptr<IdType>();
   const IdType* indices_data = csr.indices.Ptr<IdType>();
   const IdType* data_data = CSRHasData(csr) ? csr.data.Ptr<IdType>() : nullptr;
   if (csr.is_pinned) {
-    CUDA_CALL(cudaHostGetDevicePointer(
-        &indptr_data, csr.indptr.Ptr<IdType>(), 0));
-    CUDA_CALL(cudaHostGetDevicePointer(
-        &indices_data, csr.indices.Ptr<IdType>(), 0));
+    CUDA_CALL(
+        cudaHostGetDevicePointer(&indptr_data, csr.indptr.Ptr<IdType>(), 0));
+    CUDA_CALL(
+        cudaHostGetDevicePointer(&indices_data, csr.indices.Ptr<IdType>(), 0));
     if (CSRHasData(csr)) {
-      CUDA_CALL(cudaHostGetDevicePointer(
-          &data_data, csr.data.Ptr<IdType>(), 0));
+      CUDA_CALL(
+          cudaHostGetDevicePointer(&data_data, csr.data.Ptr<IdType>(), 0));
     }
   }
   // TODO(minjie): use binary search for sorted csr
-  CUDA_KERNEL_CALL(cuda::_LinearSearchKernel,
-      nb, nt, 0, stream,
-      indptr_data, indices_data, data_data,
-      rows.Ptr<IdType>(), cols.Ptr<IdType>(),
-      row_stride, col_stride, rstlen,
-      return_eids ? nullptr : weights.Ptr<DType>(), filler, rst.Ptr<DType>());
+  CUDA_KERNEL_CALL(
+      cuda::_LinearSearchKernel, nb, nt, 0, stream, indptr_data, indices_data,
+      data_data, rows.Ptr<IdType>(), cols.Ptr<IdType>(), row_stride, col_stride,
+      rstlen, return_eids ? nullptr : weights.Ptr<DType>(), filler,
+      rst.Ptr<DType>());
   return rst;
 }
 template NDArray CSRGetData<kDGLCUDA, int32_t, __half>(
-    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, __half filler);
+    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
+    NDArray weights, __half filler);
 template NDArray CSRGetData<kDGLCUDA, int64_t, __half>(
-    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, __half filler);
+    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
+    NDArray weights, __half filler);
 #if BF16_ENABLED
 template NDArray CSRGetData<kDGLCUDA, int32_t, __nv_bfloat16>(
     CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
```
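In the `is_pinned` branch above, `cudaHostGetDevicePointer` maps page-locked host arrays into the device address space so the kernel can read them without an explicit copy. A self-contained sketch of that zero-copy pattern (illustrative only, not DGL code):

```cpp
#include <cstdio>

#include <cuda_runtime.h>

// Single-thread kernel that sums a pinned host array through its
// device-side alias. Demo only; real code would parallelize.
__global__ void SumKernel(const int* data, int n, int* out) {
  int s = 0;
  for (int i = 0; i < n; ++i) s += data[i];
  *out = s;
}

int main() {
  const int n = 4;
  int* h_data = nullptr;
  int* h_out = nullptr;
  // cudaHostAllocMapped pins the memory and makes it device-addressable.
  cudaHostAlloc(&h_data, n * sizeof(int), cudaHostAllocMapped);
  cudaHostAlloc(&h_out, sizeof(int), cudaHostAllocMapped);
  for (int i = 0; i < n; ++i) h_data[i] = i + 1;

  // Same call as in the diff: resolve device aliases for pinned host memory.
  int* d_data = nullptr;
  int* d_out = nullptr;
  cudaHostGetDevicePointer(&d_data, h_data, 0);
  cudaHostGetDevicePointer(&d_out, h_out, 0);

  SumKernel<<<1, 1>>>(d_data, n, d_out);
  cudaDeviceSynchronize();       // kernel writes land in h_out via the alias
  printf("sum = %d\n", *h_out);  // expect 10

  cudaFreeHost(h_data);
  cudaFreeHost(h_out);
  return 0;
}
```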
```diff
@@ -78,19 +81,25 @@ template NDArray CSRGetData<kDGLCUDA, int64_t, __nv_bfloat16>(
     NDArray weights, __nv_bfloat16 filler);
 #endif // BF16_ENABLED
 template NDArray CSRGetData<kDGLCUDA, int32_t, float>(
-    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, float filler);
+    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
+    NDArray weights, float filler);
 template NDArray CSRGetData<kDGLCUDA, int64_t, float>(
-    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, float filler);
+    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
+    NDArray weights, float filler);
 template NDArray CSRGetData<kDGLCUDA, int32_t, double>(
-    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, double filler);
+    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
+    NDArray weights, double filler);
 template NDArray CSRGetData<kDGLCUDA, int64_t, double>(
-    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, double filler);
+    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
+    NDArray weights, double filler);
 // For CSRGetData<XPU, IdType>(CSRMatrix, NDArray, NDArray)
 template NDArray CSRGetData<kDGLCUDA, int32_t, int32_t>(
-    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, int32_t filler);
+    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
+    NDArray weights, int32_t filler);
 template NDArray CSRGetData<kDGLCUDA, int64_t, int64_t>(
-    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, int64_t filler);
+    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
+    NDArray weights, int64_t filler);
 }  // namespace impl
 }  // namespace aten
```
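The launch configuration in this file (`nt = cuda::FindNumThreads(rstlen)` threads per block, `nb = (rstlen + nt - 1) / nt` blocks) is the standard one-thread-per-element ceil-division pattern. A minimal sketch using a hypothetical 256-thread cap in place of DGL's `FindNumThreads`:

```cpp
#include <cstdio>

#include <cuda_runtime.h>

// One thread per output element; the bounds check guards the final,
// possibly partial, block.
__global__ void FillKernel(float* out, float value, int64_t n) {
  int64_t tx = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
  if (tx < n) out[tx] = value;
}

int main() {
  const int64_t n = 1000;
  float* d_out = nullptr;
  cudaMalloc(&d_out, n * sizeof(float));

  // Hypothetical stand-in for cuda::FindNumThreads: cap threads per block.
  const int nt = n < 256 ? static_cast<int>(n) : 256;
  // Ceil division, exactly as in the diff: enough blocks to cover n items.
  const int nb = static_cast<int>((n + nt - 1) / nt);
  FillKernel<<<nb, nt>>>(d_out, 1.0f, n);
  cudaDeviceSynchronize();

  float first = 0.0f;
  cudaMemcpy(&first, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("nb=%d nt=%d first=%.1f\n", nb, nt, first);  // nb=4 nt=256 first=1.0
  cudaFree(d_out);
  return 0;
}
```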
The second expanded file is the cuSPARSE dispatcher header:

```diff
 /**
  * Copyright (c) 2020 by Contributors
  * @file array/cuda/dispatcher.cuh
- * @brief Templates to dispatch into different cuSPARSE routines based on the type
- * argument.
+ * @brief Templates to dispatch into different cuSPARSE routines based on the
+ * type argument.
  */
 #ifndef DGL_ARRAY_CUDA_CUSPARSE_DISPATCHER_CUH_
 #define DGL_ARRAY_CUDA_CUSPARSE_DISPATCHER_CUH_
 #include <cusparse.h>
 #include <dgl/runtime/c_runtime_api.h>
-#include "fp16.cuh"
 #include "bf16.cuh"
+#include "fp16.cuh"
 namespace dgl {
 namespace aten {
```
```diff
@@ -40,8 +41,8 @@ template <>
 struct CSRGEMM<__half> {
   template <typename... Args>
   static inline cusparseStatus_t bufferSizeExt(Args... args) {
-    // TODO(ndickson): There is no cusparseHcsrgemm2_bufferSizeExt, so a different
-    // implementation would be required.
+    // TODO(ndickson): There is no cusparseHcsrgemm2_bufferSizeExt, so a
+    // different implementation would be required.
     LOG(FATAL) << "CSRGEMM::bufferSizeExt does not support dtype half (FP16).";
     return static_cast<cusparseStatus_t>(0);
   }
```
```diff
@@ -65,9 +66,10 @@ template <>
 struct CSRGEMM<__nv_bfloat16> {
   template <typename... Args>
   static inline cusparseStatus_t bufferSizeExt(Args... args) {
-    // TODO(ndickson): There is no cusparseHcsrgemm2_bufferSizeExt, so a different
-    // implementation would be required.
-    LOG(FATAL) << "CSRGEMM::bufferSizeExt does not support dtype bfloat16 (BF16).";
+    // TODO(ndickson): There is no cusparseHcsrgemm2_bufferSizeExt, so a
+    // different implementation would be required.
+    LOG(FATAL)
+        << "CSRGEMM::bufferSizeExt does not support dtype bfloat16 (BF16).";
     return static_cast<cusparseStatus_t>(0);
   }
```
```diff
@@ -147,8 +149,8 @@ template <>
 struct CSRGEAM<__half> {
   template <typename... Args>
   static inline cusparseStatus_t bufferSizeExt(Args... args) {
-    // TODO(ndickson): There is no cusparseHcsrgeam2_bufferSizeExt, so a different
-    // implementation would be required.
+    // TODO(ndickson): There is no cusparseHcsrgeam2_bufferSizeExt, so a
+    // different implementation would be required.
     LOG(FATAL) << "CSRGEAM::bufferSizeExt does not support dtype half (FP16).";
     return static_cast<cusparseStatus_t>(0);
   }
```
```diff
@@ -172,9 +174,10 @@ template <>
 struct CSRGEAM<__nv_bfloat16> {
   template <typename... Args>
   static inline cusparseStatus_t bufferSizeExt(Args... args) {
-    // TODO(ndickson): There is no cusparseHcsrgeam2_bufferSizeExt, so a different
-    // implementation would be required.
-    LOG(FATAL) << "CSRGEAM::bufferSizeExt does not support dtype bfloat16 (BF16).";
+    // TODO(ndickson): There is no cusparseHcsrgeam2_bufferSizeExt, so a
+    // different implementation would be required.
+    LOG(FATAL)
+        << "CSRGEAM::bufferSizeExt does not support dtype bfloat16 (BF16).";
     return static_cast<cusparseStatus_t>(0);
   }
```
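The `CSRGEMM<DType>`/`CSRGEAM<DType>` structs above use template specialization to route a dtype-generic call site to the matching cuSPARSE routine, with unsupported dtypes (FP16/BF16) failing loudly at runtime. A stripped-down sketch of the same dispatch pattern, with illustrative names rather than DGL's actual API:

```cpp
#include <iostream>

// Primary template: left undefined, so instantiating an unsupported dtype
// is a compile-time error rather than a silent fallback.
template <typename DType>
struct GemmDispatch;

// Each specialization forwards to the dtype-specific backend routine
// (stubbed out here with prints; names are hypothetical, not DGL's API).
template <>
struct GemmDispatch<float> {
  template <typename... Args>
  static int run(Args... /*args*/) {
    std::cout << "would call the single-precision csrgemm routine\n";
    return 0;
  }
};

template <>
struct GemmDispatch<double> {
  template <typename... Args>
  static int run(Args... /*args*/) {
    std::cout << "would call the double-precision csrgemm routine\n";
    return 0;
  }
};

int main() {
  // Generic call sites pick the right backend purely from the type argument.
  GemmDispatch<float>::run(1, 2, 3);
  GemmDispatch<double>::run(1, 2, 3);
  return 0;
}
```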