Unverified Commit a5d8460c authored by ndickson-nvidia's avatar ndickson-nvidia Committed by GitHub
Browse files

[Bug][Feature] Added more missing FP16 specializations (#4140)

* * Added missing specializations for `__half` of `DLDataTypeTraits`, `IndexSelect`, `Full`, `Scatter_`, `CSRGetData`, `CSRMM`, `CSRSum`, `IndexSelectCPUFromGPU`
* Fixed casting issue in `_LinearSearchKernel` that was preventing it from supporting `__half`
* Added `#if`'d out specializations of `CSRGEMM`, `CSRGEAM`, and `Xgeam`, which would require functions that aren't currently provided by cublas

* * Added more specific error messages for unimplemented FP16 specializations of Xgeam, CSRGEMM, and CSRGEAM

* * Added missing instantiation of DLDataTypeTraits<__half>::dtype

* * Fixed linter error
* Added clearer comment explaining why the cast to long long is necessary

* * Worked around a compile error in some particular setup, where __half can't be constructed on the host side

* * Fixed linter formatting errors

* * Changes to comments as recommended

* * Made recommended changes to logging errors in FP16 specializations
* Also changed the existing Xgeam function for unsupported data types from LOG(INFO) to LOG(FATAL)
parent b8f905f1
...@@ -18,6 +18,10 @@ ...@@ -18,6 +18,10 @@
#include "serializer.h" #include "serializer.h"
#include "shared_mem.h" #include "shared_mem.h"
#ifdef DGL_USE_CUDA
#include <cuda_fp16.h>
#endif
// forward declaration // forward declaration
inline std::ostream& operator << (std::ostream& os, DGLType t); inline std::ostream& operator << (std::ostream& os, DGLType t);
...@@ -46,6 +50,11 @@ GEN_DLDATATYPETRAITS_FOR(int64_t, kDLInt, 64); ...@@ -46,6 +50,11 @@ GEN_DLDATATYPETRAITS_FOR(int64_t, kDLInt, 64);
// converting uints to signed DTypes. // converting uints to signed DTypes.
GEN_DLDATATYPETRAITS_FOR(uint32_t, kDLInt, 32); GEN_DLDATATYPETRAITS_FOR(uint32_t, kDLInt, 32);
GEN_DLDATATYPETRAITS_FOR(uint64_t, kDLInt, 64); GEN_DLDATATYPETRAITS_FOR(uint64_t, kDLInt, 64);
#ifdef DGL_USE_CUDA
#ifdef USE_FP16
GEN_DLDATATYPETRAITS_FOR(__half, kDLFloat, 16);
#endif
#endif
GEN_DLDATATYPETRAITS_FOR(float, kDLFloat, 32); GEN_DLDATATYPETRAITS_FOR(float, kDLFloat, 32);
GEN_DLDATATYPETRAITS_FOR(double, kDLFloat, 64); GEN_DLDATATYPETRAITS_FOR(double, kDLFloat, 64);
#undef GEN_DLDATATYPETRAITS_FOR #undef GEN_DLDATATYPETRAITS_FOR
......
...@@ -55,6 +55,10 @@ template NDArray IndexSelect<kDLGPU, int32_t, int32_t>(NDArray, IdArray); ...@@ -55,6 +55,10 @@ template NDArray IndexSelect<kDLGPU, int32_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLGPU, int32_t, int64_t>(NDArray, IdArray); template NDArray IndexSelect<kDLGPU, int32_t, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLGPU, int64_t, int32_t>(NDArray, IdArray); template NDArray IndexSelect<kDLGPU, int64_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLGPU, int64_t, int64_t>(NDArray, IdArray); template NDArray IndexSelect<kDLGPU, int64_t, int64_t>(NDArray, IdArray);
#ifdef USE_FP16
template NDArray IndexSelect<kDLGPU, __half, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLGPU, __half, int64_t>(NDArray, IdArray);
#endif
template NDArray IndexSelect<kDLGPU, float, int32_t>(NDArray, IdArray); template NDArray IndexSelect<kDLGPU, float, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLGPU, float, int64_t>(NDArray, IdArray); template NDArray IndexSelect<kDLGPU, float, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLGPU, double, int32_t>(NDArray, IdArray); template NDArray IndexSelect<kDLGPU, double, int32_t>(NDArray, IdArray);
...@@ -63,18 +67,30 @@ template NDArray IndexSelect<kDLGPU, double, int64_t>(NDArray, IdArray); ...@@ -63,18 +67,30 @@ template NDArray IndexSelect<kDLGPU, double, int64_t>(NDArray, IdArray);
template <DLDeviceType XPU, typename DType> template <DLDeviceType XPU, typename DType>
DType IndexSelect(NDArray array, int64_t index) { DType IndexSelect(NDArray array, int64_t index) {
auto device = runtime::DeviceAPI::Get(array->ctx); auto device = runtime::DeviceAPI::Get(array->ctx);
#ifdef USE_FP16
// The initialization constructor for __half is apparently a device-
// only function in some setups, but the current function, IndexSelect,
// isn't run on the device, so it doesn't have access to that constructor.
using SafeDType = typename std::conditional<
std::is_same<DType, __half>::value, uint16_t, DType>::type;
SafeDType ret = 0;
#else
DType ret = 0; DType ret = 0;
#endif
device->CopyDataFromTo( device->CopyDataFromTo(
static_cast<DType*>(array->data) + index, 0, &ret, 0, static_cast<DType*>(array->data) + index, 0, reinterpret_cast<DType*>(&ret), 0,
sizeof(DType), array->ctx, DLContext{kDLCPU, 0}, sizeof(DType), array->ctx, DLContext{kDLCPU, 0},
array->dtype, nullptr); array->dtype, nullptr);
return ret; return reinterpret_cast<DType&>(ret);
} }
template int32_t IndexSelect<kDLGPU, int32_t>(NDArray array, int64_t index); template int32_t IndexSelect<kDLGPU, int32_t>(NDArray array, int64_t index);
template int64_t IndexSelect<kDLGPU, int64_t>(NDArray array, int64_t index); template int64_t IndexSelect<kDLGPU, int64_t>(NDArray array, int64_t index);
template uint32_t IndexSelect<kDLGPU, uint32_t>(NDArray array, int64_t index); template uint32_t IndexSelect<kDLGPU, uint32_t>(NDArray array, int64_t index);
template uint64_t IndexSelect<kDLGPU, uint64_t>(NDArray array, int64_t index); template uint64_t IndexSelect<kDLGPU, uint64_t>(NDArray array, int64_t index);
#ifdef USE_FP16
template __half IndexSelect<kDLGPU, __half>(NDArray array, int64_t index);
#endif
template float IndexSelect<kDLGPU, float>(NDArray array, int64_t index); template float IndexSelect<kDLGPU, float>(NDArray array, int64_t index);
template double IndexSelect<kDLGPU, double>(NDArray array, int64_t index); template double IndexSelect<kDLGPU, double>(NDArray array, int64_t index);
......
...@@ -224,6 +224,9 @@ NDArray Full(DType val, int64_t length, DLContext ctx) { ...@@ -224,6 +224,9 @@ NDArray Full(DType val, int64_t length, DLContext ctx) {
template IdArray Full<kDLGPU, int32_t>(int32_t val, int64_t length, DLContext ctx); template IdArray Full<kDLGPU, int32_t>(int32_t val, int64_t length, DLContext ctx);
template IdArray Full<kDLGPU, int64_t>(int64_t val, int64_t length, DLContext ctx); template IdArray Full<kDLGPU, int64_t>(int64_t val, int64_t length, DLContext ctx);
#ifdef USE_FP16
template IdArray Full<kDLGPU, __half>(__half val, int64_t length, DLContext ctx);
#endif
template IdArray Full<kDLGPU, float>(float val, int64_t length, DLContext ctx); template IdArray Full<kDLGPU, float>(float val, int64_t length, DLContext ctx);
template IdArray Full<kDLGPU, double>(double val, int64_t length, DLContext ctx); template IdArray Full<kDLGPU, double>(double val, int64_t length, DLContext ctx);
......
...@@ -39,10 +39,16 @@ void Scatter_(IdArray index, NDArray value, NDArray out) { ...@@ -39,10 +39,16 @@ void Scatter_(IdArray index, NDArray value, NDArray out) {
template void Scatter_<kDLGPU, int32_t, int32_t>(IdArray, NDArray, NDArray); template void Scatter_<kDLGPU, int32_t, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLGPU, int64_t, int32_t>(IdArray, NDArray, NDArray); template void Scatter_<kDLGPU, int64_t, int32_t>(IdArray, NDArray, NDArray);
#ifdef USE_FP16
template void Scatter_<kDLGPU, __half, int32_t>(IdArray, NDArray, NDArray);
#endif
template void Scatter_<kDLGPU, float, int32_t>(IdArray, NDArray, NDArray); template void Scatter_<kDLGPU, float, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLGPU, double, int32_t>(IdArray, NDArray, NDArray); template void Scatter_<kDLGPU, double, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLGPU, int32_t, int64_t>(IdArray, NDArray, NDArray); template void Scatter_<kDLGPU, int32_t, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLGPU, int64_t, int64_t>(IdArray, NDArray, NDArray); template void Scatter_<kDLGPU, int64_t, int64_t>(IdArray, NDArray, NDArray);
#ifdef USE_FP16
template void Scatter_<kDLGPU, __half, int64_t>(IdArray, NDArray, NDArray);
#endif
template void Scatter_<kDLGPU, float, int64_t>(IdArray, NDArray, NDArray); template void Scatter_<kDLGPU, float, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLGPU, double, int64_t>(IdArray, NDArray, NDArray); template void Scatter_<kDLGPU, double, int64_t>(IdArray, NDArray, NDArray);
......
...@@ -52,6 +52,12 @@ NDArray CSRGetData( ...@@ -52,6 +52,12 @@ NDArray CSRGetData(
return rst; return rst;
} }
#ifdef USE_FP16
template NDArray CSRGetData<kDLGPU, int32_t, __half>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, __half filler);
template NDArray CSRGetData<kDLGPU, int64_t, __half>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, __half filler);
#endif
template NDArray CSRGetData<kDLGPU, int32_t, float>( template NDArray CSRGetData<kDLGPU, int32_t, float>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, float filler); CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, float filler);
template NDArray CSRGetData<kDLGPU, int64_t, float>( template NDArray CSRGetData<kDLGPU, int64_t, float>(
......
...@@ -253,6 +253,12 @@ std::pair<CSRMatrix, NDArray> CSRMM( ...@@ -253,6 +253,12 @@ std::pair<CSRMatrix, NDArray> CSRMM(
} }
} }
#ifdef USE_FP16
template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int32_t, __half>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int64_t, __half>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
#endif
template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int32_t, float>( template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int32_t, float>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray); const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int64_t, float>( template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int64_t, float>(
......
...@@ -166,6 +166,12 @@ std::pair<CSRMatrix, NDArray> CSRSum( ...@@ -166,6 +166,12 @@ std::pair<CSRMatrix, NDArray> CSRSum(
} }
} }
#ifdef USE_FP16
template std::pair<CSRMatrix, NDArray> CSRSum<kDLGPU, int32_t, __half>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
template std::pair<CSRMatrix, NDArray> CSRSum<kDLGPU, int64_t, __half>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
#endif
template std::pair<CSRMatrix, NDArray> CSRSum<kDLGPU, int32_t, float>( template std::pair<CSRMatrix, NDArray> CSRSum<kDLGPU, int32_t, float>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&); const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
template std::pair<CSRMatrix, NDArray> CSRSum<kDLGPU, int64_t, float>( template std::pair<CSRMatrix, NDArray> CSRSum<kDLGPU, int64_t, float>(
......
...@@ -34,6 +34,32 @@ struct CSRGEMM { ...@@ -34,6 +34,32 @@ struct CSRGEMM {
} }
}; };
#ifdef USE_FP16
// FP16 specialization of the CSRGEMM dispatch struct. cusparse provides no
// half-precision csrgemm2 routines, so the value-dependent entry points abort
// via LOG(FATAL); only the dtype-independent nnz computation is forwarded.
template <>
struct CSRGEMM<__half> {
  // Structure-only (index) computation: independent of the value dtype, so it
  // can be delegated directly to cusparse.
  template <typename... Args>
  static cusparseStatus_t nnz(Args... args) {
    return cusparseXcsrgemm2Nnz(args...);
  }

  template <typename... Args>
  static cusparseStatus_t bufferSizeExt(Args... args) {
    // TODO(ndickson): There is no cusparseHcsrgemm2_bufferSizeExt, so a different
    // implementation would be required.
    LOG(FATAL) << "CSRGEMM::bufferSizeExt does not support dtype half (FP16).";
    // Unreachable: LOG(FATAL) aborts. The value only satisfies the signature.
    return static_cast<cusparseStatus_t>(0);
  }

  template <typename... Args>
  static cusparseStatus_t compute(Args... args) {
    // TODO(ndickson): There is no cusparseHcsrgemm2, so a different
    // implementation would be required.
    LOG(FATAL) << "CSRGEMM::compute does not support dtype half (FP16).";
    // Unreachable: LOG(FATAL) aborts. The value only satisfies the signature.
    return static_cast<cusparseStatus_t>(0);
  }
};
#endif
template <> template <>
struct CSRGEMM<float> { struct CSRGEMM<float> {
template <typename... Args> template <typename... Args>
...@@ -91,6 +117,32 @@ struct CSRGEAM { ...@@ -91,6 +117,32 @@ struct CSRGEAM {
} }
}; };
#ifdef USE_FP16
// FP16 specialization of the CSRGEAM dispatch struct. cusparse provides no
// half-precision csrgeam2 routines, so the value-dependent entry points abort
// via LOG(FATAL); only the dtype-independent nnz computation is forwarded.
template <>
struct CSRGEAM<__half> {
  // Structure-only (index) computation: independent of the value dtype, so it
  // can be delegated directly to cusparse.
  template <typename... Args>
  static cusparseStatus_t nnz(Args... args) {
    return cusparseXcsrgeam2Nnz(args...);
  }

  template <typename... Args>
  static cusparseStatus_t bufferSizeExt(Args... args) {
    // TODO(ndickson): There is no cusparseHcsrgeam2_bufferSizeExt, so a different
    // implementation would be required.
    LOG(FATAL) << "CSRGEAM::bufferSizeExt does not support dtype half (FP16).";
    // Unreachable: LOG(FATAL) aborts. The value only satisfies the signature.
    return static_cast<cusparseStatus_t>(0);
  }

  template <typename... Args>
  static cusparseStatus_t compute(Args... args) {
    // TODO(ndickson): There is no cusparseHcsrgeam2, so a different
    // implementation would be required.
    LOG(FATAL) << "CSRGEAM::compute does not support dtype half (FP16).";
    // Unreachable: LOG(FATAL) aborts. The value only satisfies the signature.
    return static_cast<cusparseStatus_t>(0);
  }
};
#endif
template <> template <>
struct CSRGEAM<float> { struct CSRGEAM<float> {
template <typename... Args> template <typename... Args>
......
...@@ -28,10 +28,24 @@ cublasStatus_t Xgeam(cublasHandle_t handle, cublasOperation_t transa, ...@@ -28,10 +28,24 @@ cublasStatus_t Xgeam(cublasHandle_t handle, cublasOperation_t transa,
const DType* alpha, const DType* A, int lda, const DType* alpha, const DType* A, int lda,
const DType* beta, const DType* B, int ldb, const DType* beta, const DType* B, int ldb,
DType* C, int ldc) { DType* C, int ldc) {
LOG(INFO) << "Not supported dtype"; LOG(FATAL) << "Not supported dtype";
return CUBLAS_STATUS_EXECUTION_FAILED; return CUBLAS_STATUS_EXECUTION_FAILED;
} }
#ifdef USE_FP16
// FP16 specialization of Xgeam. Always fails: cublas exposes no
// half-precision geam routine, so this aborts via LOG(FATAL) rather than
// silently falling through to the generic "Not supported dtype" template.
template <>
cublasStatus_t Xgeam<__half>(cublasHandle_t handle, cublasOperation_t transa,
                             cublasOperation_t transb, int m, int n,
                             const __half* alpha, const __half* A, int lda,
                             const __half* beta, const __half* B, int ldb,
                             __half* C, int ldc) {
  // TODO(ndickson): There is no cublasHgeam, so a different
  // implementation would be required.
  LOG(FATAL) << "Xgeam does not support dtype half (FP16)";
  // Unreachable: LOG(FATAL) aborts. The value only satisfies the signature.
  return CUBLAS_STATUS_EXECUTION_FAILED;
}
#endif
template <> template <>
cublasStatus_t Xgeam<float>(cublasHandle_t handle, cublasOperation_t transa, cublasStatus_t Xgeam<float>(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n, cublasOperation_t transb, int m, int n,
......
...@@ -166,10 +166,19 @@ __global__ void _LinearSearchKernel( ...@@ -166,10 +166,19 @@ __global__ void _LinearSearchKernel(
break; break;
} }
} }
if (v == -1) if (v == -1) {
out[tx] = filler; out[tx] = filler;
else } else {
out[tx] = weights ? weights[v] : v; // The casts here are to be able to handle DType being __half.
// GCC treats int64_t as a distinct type from long long, so
// without the explcit cast to long long, it errors out saying
// that the implicit cast results in an ambiguous choice of
// constructor for __half.
// The using statement is to avoid a linter error about using
// long or long long.
using LongLong = long long; // NOLINT
out[tx] = weights ? weights[v] : DType(LongLong(v));
}
tx += stride_x; tx += stride_x;
} }
} }
......
...@@ -24,6 +24,9 @@ constexpr DLDataType DLDataTypeTraits<int32_t>::dtype; ...@@ -24,6 +24,9 @@ constexpr DLDataType DLDataTypeTraits<int32_t>::dtype;
constexpr DLDataType DLDataTypeTraits<int64_t>::dtype; constexpr DLDataType DLDataTypeTraits<int64_t>::dtype;
constexpr DLDataType DLDataTypeTraits<uint32_t>::dtype; constexpr DLDataType DLDataTypeTraits<uint32_t>::dtype;
constexpr DLDataType DLDataTypeTraits<uint64_t>::dtype; constexpr DLDataType DLDataTypeTraits<uint64_t>::dtype;
#ifdef USE_FP16
constexpr DLDataType DLDataTypeTraits<__half>::dtype;
#endif
constexpr DLDataType DLDataTypeTraits<float>::dtype; constexpr DLDataType DLDataTypeTraits<float>::dtype;
constexpr DLDataType DLDataTypeTraits<double>::dtype; constexpr DLDataType DLDataTypeTraits<double>::dtype;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment