Unverified Commit 0166515c authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge branch 'main' into issue/300

parents f0300ff3 a23c4d13
...@@ -6,6 +6,8 @@ ...@@ -6,6 +6,8 @@
inline infiniDtype_t ggmlTypeToInfiniType(GGML_TYPE type) { inline infiniDtype_t ggmlTypeToInfiniType(GGML_TYPE type) {
switch (type) { switch (type) {
case GGML_TYPE_Q8_K:
return INFINI_DTYPE_BOOL;
case GGML_TYPE_I8: case GGML_TYPE_I8:
return INFINI_DTYPE_I8; return INFINI_DTYPE_I8;
case GGML_TYPE_I16: case GGML_TYPE_I16:
...@@ -14,10 +16,10 @@ inline infiniDtype_t ggmlTypeToInfiniType(GGML_TYPE type) { ...@@ -14,10 +16,10 @@ inline infiniDtype_t ggmlTypeToInfiniType(GGML_TYPE type) {
return INFINI_DTYPE_I32; return INFINI_DTYPE_I32;
case GGML_TYPE_I64: case GGML_TYPE_I64:
return INFINI_DTYPE_I64; return INFINI_DTYPE_I64;
case GGML_TYPE_F16:
return INFINI_DTYPE_F16;
case GGML_TYPE_BF16: case GGML_TYPE_BF16:
return INFINI_DTYPE_BF16; return INFINI_DTYPE_BF16;
case GGML_TYPE_F16:
return INFINI_DTYPE_F16;
case GGML_TYPE_F32: case GGML_TYPE_F32:
return INFINI_DTYPE_F32; return INFINI_DTYPE_F32;
case GGML_TYPE_F64: case GGML_TYPE_F64:
......
...@@ -9,12 +9,16 @@ ...@@ -9,12 +9,16 @@
inline double getVal(void *ptr, GGML_TYPE ggml_type) { inline double getVal(void *ptr, GGML_TYPE ggml_type) {
switch (ggml_type) { switch (ggml_type) {
case GGML_TYPE_BF16:
return utils::cast<float>(*(bf16_t *)ptr);
case GGML_TYPE_F16: case GGML_TYPE_F16:
return utils::cast<double>(*(fp16_t *)ptr); return utils::cast<float>(*(fp16_t *)ptr);
case GGML_TYPE_F32: case GGML_TYPE_F32:
return *(float *)ptr; return *(float *)ptr;
case GGML_TYPE_F64: case GGML_TYPE_F64:
return *(double *)ptr; return *(double *)ptr;
case GGML_TYPE_Q8_K:
return *(bool *)ptr;
case GGML_TYPE_I8: case GGML_TYPE_I8:
return *(int8_t *)ptr; return *(int8_t *)ptr;
case GGML_TYPE_I16: case GGML_TYPE_I16:
...@@ -30,12 +34,16 @@ inline double getVal(void *ptr, GGML_TYPE ggml_type) { ...@@ -30,12 +34,16 @@ inline double getVal(void *ptr, GGML_TYPE ggml_type) {
inline size_t ggmlSizeOf(GGML_TYPE ggml_type) { inline size_t ggmlSizeOf(GGML_TYPE ggml_type) {
switch (ggml_type) { switch (ggml_type) {
case GGML_TYPE_BF16:
return sizeof(bf16_t);
case GGML_TYPE_F16: case GGML_TYPE_F16:
return sizeof(fp16_t); return sizeof(fp16_t);
case GGML_TYPE_F32: case GGML_TYPE_F32:
return sizeof(float); return sizeof(float);
case GGML_TYPE_F64: case GGML_TYPE_F64:
return sizeof(double); return sizeof(double);
case GGML_TYPE_Q8_K:
return sizeof(bool);
case GGML_TYPE_I8: case GGML_TYPE_I8:
return sizeof(int8_t); return sizeof(int8_t);
case GGML_TYPE_I16: case GGML_TYPE_I16:
......
#include "tensor.hpp" #include "tensor.hpp"
#include "gguf.hpp"
#include "utils.hpp" #include "utils.hpp"
#include <cstring> #include <cstring>
#include <infinirt.h> #include <infinirt.h>
...@@ -19,6 +20,40 @@ void printData(const T *data, const std::vector<size_t> &shape, const std::vecto ...@@ -19,6 +20,40 @@ void printData(const T *data, const std::vector<size_t> &shape, const std::vecto
} }
} }
// The type int8_t is represented by signed char, with a range of –128 to 127.
// It may contain non-printable characters and thus cannot be printed directly.
template <>
void printData(const int8_t *data, const std::vector<size_t> &shape,
const std::vector<ptrdiff_t> &strides, size_t dim) {
if (dim == shape.size() - 1) {
for (size_t i = 0; i < shape[dim]; i++) {
std::cout << static_cast<int>(*(data + i * strides[dim])) << " ";
}
std::cout << std::endl;
} else if (dim < shape.size() - 1) {
for (size_t i = 0; i < shape[dim]; i++) {
printData(data + i * strides[dim], shape, strides, dim + 1);
std::cout << std::endl;
}
}
}
template <>
void printData(const bf16_t *data, const std::vector<size_t> &shape,
const std::vector<ptrdiff_t> &strides, size_t dim) {
if (dim == shape.size() - 1) {
for (size_t i = 0; i < shape[dim]; i++) {
std::cout << utils::cast<float>(*(data + i * strides[dim])) << " ";
}
std::cout << std::endl;
} else if (dim < shape.size() - 1) {
for (size_t i = 0; i < shape[dim]; i++) {
printData(data + i * strides[dim], shape, strides, dim + 1);
std::cout << std::endl;
}
}
}
template <> template <>
void printData(const fp16_t *data, const std::vector<size_t> &shape, void printData(const fp16_t *data, const std::vector<size_t> &shape,
const std::vector<ptrdiff_t> &strides, size_t dim) { const std::vector<ptrdiff_t> &strides, size_t dim) {
...@@ -26,6 +61,7 @@ void printData(const fp16_t *data, const std::vector<size_t> &shape, ...@@ -26,6 +61,7 @@ void printData(const fp16_t *data, const std::vector<size_t> &shape,
for (size_t i = 0; i < shape[dim]; i++) { for (size_t i = 0; i < shape[dim]; i++) {
std::cout << utils::cast<float>(*(data + i * strides[dim])) << " "; std::cout << utils::cast<float>(*(data + i * strides[dim])) << " ";
} }
std::cout << std::endl;
} else if (dim < shape.size() - 1) { } else if (dim < shape.size() - 1) {
for (size_t i = 0; i < shape[dim]; i++) { for (size_t i = 0; i < shape[dim]; i++) {
printData(data + i * strides[dim], shape, strides, dim + 1); printData(data + i * strides[dim], shape, strides, dim + 1);
...@@ -227,6 +263,8 @@ void Tensor::debug() const { ...@@ -227,6 +263,8 @@ void Tensor::debug() const {
auto tensor = to(INFINI_DEVICE_CPU, 0); auto tensor = to(INFINI_DEVICE_CPU, 0);
std::cout << "Tensor: " << tensor->info() << std::endl; std::cout << "Tensor: " << tensor->info() << std::endl;
switch (_ggml_type) { switch (_ggml_type) {
case GGML_TYPE_BF16:
printData((bf16_t *)(tensor->data()), _shape, _strides, 0);
case GGML_TYPE_F16: case GGML_TYPE_F16:
printData((fp16_t *)(tensor->data()), _shape, _strides, 0); printData((fp16_t *)(tensor->data()), _shape, _strides, 0);
break; break;
...@@ -236,6 +274,9 @@ void Tensor::debug() const { ...@@ -236,6 +274,9 @@ void Tensor::debug() const {
case GGML_TYPE_F64: case GGML_TYPE_F64:
printData((double *)(tensor->data()), _shape, _strides, 0); printData((double *)(tensor->data()), _shape, _strides, 0);
break; break;
case GGML_TYPE_Q8_K:
printData((bool *)(tensor->data()), _shape, _strides, 0);
break;
case GGML_TYPE_I8: case GGML_TYPE_I8:
printData((int8_t *)(tensor->data()), _shape, _strides, 0); printData((int8_t *)(tensor->data()), _shape, _strides, 0);
break; break;
...@@ -245,6 +286,9 @@ void Tensor::debug() const { ...@@ -245,6 +286,9 @@ void Tensor::debug() const {
case GGML_TYPE_I32: case GGML_TYPE_I32:
printData((int32_t *)(tensor->data()), _shape, _strides, 0); printData((int32_t *)(tensor->data()), _shape, _strides, 0);
break; break;
case GGML_TYPE_I64:
printData((int64_t *)(tensor->data()), _shape, _strides, 0);
break;
default: default:
std::cout << "Unsupported GGML type" << std::endl; std::cout << "Unsupported GGML type" << std::endl;
break; break;
......
...@@ -5,8 +5,8 @@ ...@@ -5,8 +5,8 @@
#ifdef ENABLE_CPU_API #ifdef ENABLE_CPU_API
#include "cpu/cpu_handle.h" #include "cpu/cpu_handle.h"
#endif #endif
#ifdef ENABLE_CUDA_API #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API)
#include "cuda/cuda_handle.h" #include "nvidia/nvidia_handle.h"
#endif #endif
#ifdef ENABLE_CAMBRICON_API #ifdef ENABLE_CAMBRICON_API
#include "bang/bang_handle.h" #include "bang/bang_handle.h"
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
#include "kunlun/kunlun_handle.h" #include "kunlun/kunlun_handle.h"
#endif #endif
#ifdef ENABLE_METAX_API #ifdef ENABLE_METAX_API
#include "maca/maca_handle.h" #include "metax/metax_handle.h"
#endif #endif
__C infiniStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr) { __C infiniStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr) {
...@@ -41,8 +41,11 @@ __C infiniStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr) { ...@@ -41,8 +41,11 @@ __C infiniStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr) {
#ifdef ENABLE_CPU_API #ifdef ENABLE_CPU_API
CREATE(INFINI_DEVICE_CPU, cpu); CREATE(INFINI_DEVICE_CPU, cpu);
#endif #endif
#ifdef ENABLE_CUDA_API #ifdef ENABLE_NVIDIA_API
CREATE(INFINI_DEVICE_NVIDIA, cuda::nvidia); CREATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
CREATE(INFINI_DEVICE_ILUVATAR, iluvatar);
#endif #endif
#ifdef ENABLE_CAMBRICON_API #ifdef ENABLE_CAMBRICON_API
CREATE(INFINI_DEVICE_CAMBRICON, bang::cambricon); CREATE(INFINI_DEVICE_CAMBRICON, bang::cambricon);
...@@ -57,7 +60,7 @@ __C infiniStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr) { ...@@ -57,7 +60,7 @@ __C infiniStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr) {
CREATE(INFINI_DEVICE_KUNLUN, kunlun); CREATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif #endif
#ifdef ENABLE_METAX_API #ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, maca); CREATE(INFINI_DEVICE_METAX, metax);
#endif #endif
default: default:
...@@ -78,8 +81,11 @@ __C infiniStatus_t infiniopDestroyHandle(infiniopHandle_t handle) { ...@@ -78,8 +81,11 @@ __C infiniStatus_t infiniopDestroyHandle(infiniopHandle_t handle) {
#ifdef ENABLE_CPU_API #ifdef ENABLE_CPU_API
DELETE(INFINI_DEVICE_CPU, cpu); DELETE(INFINI_DEVICE_CPU, cpu);
#endif #endif
#ifdef ENABLE_CUDA_API #ifdef ENABLE_NVIDIA_API
DELETE(INFINI_DEVICE_NVIDIA, cuda::nvidia); DELETE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
DELETE(INFINI_DEVICE_ILUVATAR, iluvatar);
#endif #endif
#ifdef ENABLE_CAMBRICON_API #ifdef ENABLE_CAMBRICON_API
DELETE(INFINI_DEVICE_CAMBRICON, bang::cambricon); DELETE(INFINI_DEVICE_CAMBRICON, bang::cambricon);
...@@ -94,7 +100,7 @@ __C infiniStatus_t infiniopDestroyHandle(infiniopHandle_t handle) { ...@@ -94,7 +100,7 @@ __C infiniStatus_t infiniopDestroyHandle(infiniopHandle_t handle) {
DELETE(INFINI_DEVICE_KUNLUN, kunlun); DELETE(INFINI_DEVICE_KUNLUN, kunlun);
#endif #endif
#ifdef ENABLE_METAX_API #ifdef ENABLE_METAX_API
DELETE(INFINI_DEVICE_METAX, maca); DELETE(INFINI_DEVICE_METAX, metax);
#endif #endif
default: default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......
#include "../../../utils.h" #include "../../../utils.h"
#include "../pool.h" #include "../pool.h"
#include "maca_handle.h" #include "metax_handle.h"
#include <hcblas/hcblas.h> #include <hcblas/hcblas.h>
#include <hcdnn/hcdnn.h> #include <hcdnn/hcdnn.h>
#include <memory> #include <memory>
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#define CHECK_MCBLAS(API) CHECK_INTERNAL(API, HCBLAS_STATUS_SUCCESS) #define CHECK_MCBLAS(API) CHECK_INTERNAL(API, HCBLAS_STATUS_SUCCESS)
#define CHECK_MCDNN(API) CHECK_INTERNAL(API, HCDNN_STATUS_SUCCESS) #define CHECK_MCDNN(API) CHECK_INTERNAL(API, HCDNN_STATUS_SUCCESS)
namespace device::maca { namespace device::metax {
class Handle::Internal { class Handle::Internal {
Pool<hcblasHandle_t> mcblas_handles; Pool<hcblasHandle_t> mcblas_handles;
...@@ -39,4 +39,4 @@ public: ...@@ -39,4 +39,4 @@ public:
hcdnnDataType_t getHcdnnDtype(infiniDtype_t dt); hcdnnDataType_t getHcdnnDtype(infiniDtype_t dt);
} // namespace device::maca } // namespace device::metax
#include "common_maca.h" #include "metax_common.h"
namespace device::maca { namespace device::metax {
Handle::Handle(infiniDevice_t device, int device_id) Handle::Handle(infiniDevice_t device, int device_id)
: InfiniopHandle{device, device_id}, : InfiniopHandle{device, device_id},
_internal(std::make_shared<Handle::Internal>(device_id)) {} _internal(std::make_shared<Handle::Internal>(device_id)) {}
...@@ -83,4 +83,4 @@ infiniStatus_t Handle::create(InfiniopHandle **handle_ptr, int device_id) { ...@@ -83,4 +83,4 @@ infiniStatus_t Handle::create(InfiniopHandle **handle_ptr, int device_id) {
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
} // namespace device::maca } // namespace device::metax
#ifndef __INFINIOP_MACA_HANDLE_H__ #ifndef __INFINIOP_METAX_HANDLE_H__
#define __INFINIOP_MACA_HANDLE_H__ #define __INFINIOP_METAX_HANDLE_H__
#include "../../handle.h" #include "../../handle.h"
#include <memory> #include <memory>
namespace device::maca { namespace device::metax {
struct Handle : public InfiniopHandle { struct Handle : public InfiniopHandle {
Handle(int device_id); Handle(int device_id);
class Internal; class Internal;
...@@ -20,6 +20,6 @@ private: ...@@ -20,6 +20,6 @@ private:
std::shared_ptr<Internal> _internal; std::shared_ptr<Internal> _internal;
}; };
} // namespace device::maca } // namespace device::metax
#endif // __INFINIOP_MACA_HANDLE_H__ #endif // __INFINIOP_METAX_HANDLE_H__
#define INFINIOP_MACA_KERNEL __global__ void #define INFINIOP_METAX_KERNEL __global__ void
// Posible maximum number of threads per block for MACA architectures
// Posible maximum number of threads per block for METAX architectures
// Used for picking correct kernel launch configuration // Used for picking correct kernel launch configuration
#define MACA_BLOCK_SIZE_1024 1024 #define METAX_BLOCK_SIZE_1024 1024
#define MACA_BLOCK_SIZE_512 512 #define METAX_BLOCK_SIZE_512 512
#define CHECK_METAX(API) CHECK_INTERNAL(API, hcSuccess)
#define CHECK_MACA(API) CHECK_INTERNAL(API, hcSuccess) using cuda_bfloat16 = hpcc_bfloat16;
using cuda_bfloat162 = hpcc_bfloat162;
namespace device::maca { namespace device::metax {
// return the memory offset of original tensor, given the flattened index of broadcasted tensor // return the memory offset of original tensor, given the flattened index of broadcasted tensor
__forceinline__ __device__ __host__ size_t __forceinline__ __device__ __host__ size_t
...@@ -37,10 +41,8 @@ indexToOffset( ...@@ -37,10 +41,8 @@ indexToOffset(
} }
return res; return res;
} }
} // namespace device::maca } // namespace device::metax
#ifdef ENABLE_MACA_API
#include <maca_fp16.h>
__forceinline__ __device__ float __forceinline__ __device__ float
exp_(const float val) { exp_(const float val) {
return expf(val); return expf(val);
...@@ -48,7 +50,7 @@ exp_(const float val) { ...@@ -48,7 +50,7 @@ exp_(const float val) {
__forceinline__ __device__ long double __forceinline__ __device__ long double
exp_(const long double val) { exp_(const long double val) {
return expl(val); return exp(val);
} }
__forceinline__ __device__ double __forceinline__ __device__ double
...@@ -60,4 +62,8 @@ __forceinline__ __device__ __half ...@@ -60,4 +62,8 @@ __forceinline__ __device__ __half
exp_(const __half x) { exp_(const __half x) {
return hexp(x); return hexp(x);
} }
#endif
__forceinline__ __device__ __hpcc_bfloat16
exp_(const __hpcc_bfloat16 x) {
return hexp(x);
}
#include "cuda_handle.cuh" #include "nvidia_handle.cuh"
namespace device::cuda { namespace device {
namespace nvidia {
Handle::Handle(infiniDevice_t device, int device_id) Handle::Handle(infiniDevice_t device, int device_id)
: InfiniopHandle{device, device_id}, : InfiniopHandle{device, device_id},
...@@ -34,6 +36,7 @@ infiniStatus_t Handle::Internal::useCublas(cudaStream_t stream, const Fn<cublasH ...@@ -34,6 +36,7 @@ infiniStatus_t Handle::Internal::useCublas(cudaStream_t stream, const Fn<cublasH
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
#ifdef ENABLE_CUDNN_API
infiniStatus_t Handle::Internal::useCudnn(cudaStream_t stream, const Fn<cudnnHandle_t> &f) const { infiniStatus_t Handle::Internal::useCudnn(cudaStream_t stream, const Fn<cudnnHandle_t> &f) const {
auto handle = dnn_handles.pop(); auto handle = dnn_handles.pop();
if (!handle) { if (!handle) {
...@@ -44,6 +47,7 @@ infiniStatus_t Handle::Internal::useCudnn(cudaStream_t stream, const Fn<cudnnHan ...@@ -44,6 +47,7 @@ infiniStatus_t Handle::Internal::useCudnn(cudaStream_t stream, const Fn<cudnnHan
dnn_handles.push(std::move(*handle)); dnn_handles.push(std::move(*handle));
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
#endif
int Handle::Internal::warpSize() const { return _warp_size; } int Handle::Internal::warpSize() const { return _warp_size; }
int Handle::Internal::maxThreadsPerBlock() const { return _max_threads_per_block; } int Handle::Internal::maxThreadsPerBlock() const { return _max_threads_per_block; }
...@@ -54,6 +58,7 @@ int Handle::Internal::gridSizeX() const { return _grid_size[0]; } ...@@ -54,6 +58,7 @@ int Handle::Internal::gridSizeX() const { return _grid_size[0]; }
int Handle::Internal::gridSizeY() const { return _grid_size[1]; } int Handle::Internal::gridSizeY() const { return _grid_size[1]; }
int Handle::Internal::gridSizeZ() const { return _grid_size[2]; } int Handle::Internal::gridSizeZ() const { return _grid_size[2]; }
#ifdef ENABLE_CUDNN_API
cudnnDataType_t getCudnnDtype(infiniDtype_t dt) { cudnnDataType_t getCudnnDtype(infiniDtype_t dt) {
switch (dt) { switch (dt) {
case INFINI_DTYPE_F16: case INFINI_DTYPE_F16:
...@@ -68,7 +73,7 @@ cudnnDataType_t getCudnnDtype(infiniDtype_t dt) { ...@@ -68,7 +73,7 @@ cudnnDataType_t getCudnnDtype(infiniDtype_t dt) {
return CUDNN_DATA_INT8; return CUDNN_DATA_INT8;
case INFINI_DTYPE_I32: case INFINI_DTYPE_I32:
return CUDNN_DATA_INT32; return CUDNN_DATA_INT32;
#ifndef ENABLE_ILUVATAR_CUDA_API #ifndef ENABLE_ILUVATAR_API
case INFINI_DTYPE_I64: case INFINI_DTYPE_I64:
return CUDNN_DATA_INT64; return CUDNN_DATA_INT64;
#endif #endif
...@@ -78,17 +83,25 @@ cudnnDataType_t getCudnnDtype(infiniDtype_t dt) { ...@@ -78,17 +83,25 @@ cudnnDataType_t getCudnnDtype(infiniDtype_t dt) {
return CUDNN_DATA_FLOAT; return CUDNN_DATA_FLOAT;
} }
} }
#endif
namespace nvidia { infiniStatus_t Handle::create(InfiniopHandle **handle_ptr, int device_id) {
*handle_ptr = new Handle(INFINI_DEVICE_NVIDIA, device_id);
return INFINI_STATUS_SUCCESS;
}
} // namespace nvidia
namespace iluvatar {
Handle::Handle(int device_id) Handle::Handle(int device_id)
: cuda::Handle(INFINI_DEVICE_NVIDIA, device_id) {} : nvidia::Handle(INFINI_DEVICE_ILUVATAR, device_id) {}
infiniStatus_t Handle::create(InfiniopHandle **handle_ptr, int device_id) { infiniStatus_t Handle::create(InfiniopHandle **handle_ptr, int device_id) {
*handle_ptr = new Handle(device_id); *handle_ptr = new Handle(device_id);
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
} // namespace nvidia } // namespace iluvatar
} // namespace device::cuda } // namespace device
#ifndef __INFINIOP_CUDA_COMMON_CUH__ #ifndef __INFINIOP_CUDA_COMMON_CUH__
#define __INFINIOP_CUDA_COMMON_CUH__ #define __INFINIOP_CUDA_COMMON_CUH__
#include "cuda_handle.cuh"
#include "infinicore.h" #include "infinicore.h"
#include "nvidia_handle.cuh"
namespace device::cuda { namespace device::nvidia {
#ifdef ENABLE_CUDNN_API
cudnnDataType_t getCudnnDtype(infiniDtype_t dt); cudnnDataType_t getCudnnDtype(infiniDtype_t dt);
#endif
} // namespace device::cuda } // namespace device::nvidia
#endif // __INFINIOP_CUDA_COMMON_CUH__ #endif // __INFINIOP_CUDA_COMMON_CUH__
...@@ -3,19 +3,24 @@ ...@@ -3,19 +3,24 @@
#include "../../../utils.h" #include "../../../utils.h"
#include "../pool.h" #include "../pool.h"
#include "cuda_handle.h" #include "nvidia_handle.h"
#include <cublas_v2.h> #include <cublas_v2.h>
#include <cudnn.h>
#include <functional> #include <functional>
#ifdef ENABLE_CUDNN_API
#include <cudnn.h>
#endif
#define CHECK_CUBLAS(API) CHECK_INTERNAL(API, CUBLAS_STATUS_SUCCESS) #define CHECK_CUBLAS(API) CHECK_INTERNAL(API, CUBLAS_STATUS_SUCCESS)
#define CHECK_CUDNN(API) CHECK_INTERNAL(API, CUDNN_STATUS_SUCCESS) #define CHECK_CUDNN(API) CHECK_INTERNAL(API, CUDNN_STATUS_SUCCESS)
namespace device::cuda { namespace device::nvidia {
class Handle::Internal { class Handle::Internal {
Pool<cublasHandle_t> blas_handles; Pool<cublasHandle_t> blas_handles;
#ifdef ENABLE_CUDNN_API
Pool<cudnnHandle_t> dnn_handles; Pool<cudnnHandle_t> dnn_handles;
#endif
int _warp_size, int _warp_size,
_max_threads_per_block, _max_threads_per_block,
...@@ -29,7 +34,9 @@ public: ...@@ -29,7 +34,9 @@ public:
Internal(int); Internal(int);
infiniStatus_t useCublas(cudaStream_t stream, const Fn<cublasHandle_t> &f) const; infiniStatus_t useCublas(cudaStream_t stream, const Fn<cublasHandle_t> &f) const;
#ifdef ENABLE_CUDNN_API
infiniStatus_t useCudnn(cudaStream_t stream, const Fn<cudnnHandle_t> &f) const; infiniStatus_t useCudnn(cudaStream_t stream, const Fn<cudnnHandle_t> &f) const;
#endif
int warpSize() const; int warpSize() const;
int maxThreadsPerBlock() const; int maxThreadsPerBlock() const;
...@@ -41,6 +48,6 @@ public: ...@@ -41,6 +48,6 @@ public:
int gridSizeZ() const; int gridSizeZ() const;
}; };
} // namespace device::cuda } // namespace device::nvidia
#endif // __INFINIOP_CUDA_HANDLE_CUH__ #endif // __INFINIOP_CUDA_HANDLE_CUH__
...@@ -4,7 +4,9 @@ ...@@ -4,7 +4,9 @@
#include "../../handle.h" #include "../../handle.h"
#include <memory> #include <memory>
namespace device::cuda { namespace device {
namespace nvidia {
struct Handle : public InfiniopHandle { struct Handle : public InfiniopHandle {
class Internal; class Internal;
...@@ -13,21 +15,26 @@ struct Handle : public InfiniopHandle { ...@@ -13,21 +15,26 @@ struct Handle : public InfiniopHandle {
protected: protected:
Handle(infiniDevice_t device, int device_id); Handle(infiniDevice_t device, int device_id);
public:
static infiniStatus_t create(InfiniopHandle **handle_ptr, int device_id);
private: private:
std::shared_ptr<Internal> _internal; std::shared_ptr<Internal> _internal;
}; };
namespace nvidia { } // namespace nvidia
class Handle : public cuda::Handle { namespace iluvatar {
struct Handle : public nvidia::Handle {
Handle(int device_id); Handle(int device_id);
public: public:
static infiniStatus_t create(InfiniopHandle **handle_ptr, int device_id); static infiniStatus_t create(InfiniopHandle **handle_ptr, int device_id);
}; };
} // namespace nvidia } // namespace iluvatar
} // namespace device::cuda } // namespace device
#endif // __INFINIOP_CUDA_HANDLE_H__ #endif // __INFINIOP_CUDA_HANDLE_H__
...@@ -4,6 +4,9 @@ ...@@ -4,6 +4,9 @@
#define INFINIOP_CUDA_KERNEL __global__ void #define INFINIOP_CUDA_KERNEL __global__ void
#endif #endif
#include <cuda_bf16.h>
#include <cuda_fp16.h>
// Posible maximum number of threads per block for CUDA architectures // Posible maximum number of threads per block for CUDA architectures
// Used for picking correct kernel launch configuration // Used for picking correct kernel launch configuration
#define CUDA_BLOCK_SIZE_4096 4096 #define CUDA_BLOCK_SIZE_4096 4096
...@@ -12,8 +15,10 @@ ...@@ -12,8 +15,10 @@
#define CHECK_CUDA(API) CHECK_INTERNAL(API, cudaSuccess) #define CHECK_CUDA(API) CHECK_INTERNAL(API, cudaSuccess)
namespace device::cuda { using cuda_bfloat16 = nv_bfloat16;
using cuda_bfloat162 = nv_bfloat162;
namespace device::nvidia {
// return the memory offset of original tensor, given the flattened index of broadcasted tensor // return the memory offset of original tensor, given the flattened index of broadcasted tensor
__forceinline__ __device__ __host__ size_t __forceinline__ __device__ __host__ size_t
indexToReducedOffset( indexToReducedOffset(
...@@ -43,16 +48,14 @@ indexToOffset( ...@@ -43,16 +48,14 @@ indexToOffset(
} }
return res; return res;
} }
} // namespace device::cuda } // namespace device::nvidia
#ifdef ENABLE_CUDA_API
#include <cuda_fp16.h>
__forceinline__ __device__ float __forceinline__ __device__ float
exp_(const float val) { exp_(const float val) {
return expf(val); return expf(val);
} }
#ifndef ENABLE_ILUVATAR_CUDA_API #ifndef ENABLE_ILUVATAR_API
__forceinline__ __device__ long double __forceinline__ __device__ long double
exp_(const long double val) { exp_(const long double val) {
return expl(val); return expl(val);
...@@ -73,4 +76,3 @@ __forceinline__ __device__ __nv_bfloat16 ...@@ -73,4 +76,3 @@ __forceinline__ __device__ __nv_bfloat16
exp_(const __nv_bfloat16 x) { exp_(const __nv_bfloat16 x) {
return hexp(x); return hexp(x);
} }
#endif
#ifndef __INFINIOP_ELEMENTWISE_MACA_H__ #ifndef __INFINIOP_ELEMENTWISE_METAX_H__
#define __INFINIOP_ELEMENTWISE_MACA_H__ #define __INFINIOP_ELEMENTWISE_METAX_H__
#include "../../../utils.h" #include "../../../utils.h"
#include "../../devices/maca/common_maca.h" #include "../../devices/metax/metax_common.h"
#include "../../devices/maca/maca_kernel_common.h" #include "../../devices/metax/metax_kernel_common.h"
#include "elementwise_maca_api.h" #include "elementwise_metax_api.h"
namespace op::elementwise::maca { namespace op::elementwise::metax {
template <typename T> template <typename T>
__device__ __forceinline__ const T *typedInputPtr(const void *ptr) { __device__ __forceinline__ const T *typedInputPtr(const void *ptr) {
return reinterpret_cast<const T *>(ptr); return reinterpret_cast<const T *>(ptr);
...@@ -14,7 +14,7 @@ __device__ __forceinline__ const T *typedInputPtr(const void *ptr) { ...@@ -14,7 +14,7 @@ __device__ __forceinline__ const T *typedInputPtr(const void *ptr) {
__device__ __forceinline__ size_t getOutputIndex(size_t idx, bool is_contiguous, size_t ndim, __device__ __forceinline__ size_t getOutputIndex(size_t idx, bool is_contiguous, size_t ndim,
const size_t *shape, const ptrdiff_t *strides) { const size_t *shape, const ptrdiff_t *strides) {
return is_contiguous ? idx : device::maca::indexToOffset(idx, ndim, shape, strides); return is_contiguous ? idx : device::metax::indexToOffset(idx, ndim, shape, strides);
} }
struct InputIndexer { struct InputIndexer {
...@@ -30,8 +30,8 @@ struct InputIndexer { ...@@ -30,8 +30,8 @@ struct InputIndexer {
return input_contiguous[input_id] return input_contiguous[input_id]
? idx ? idx
: (input_broadcasted[input_id] : (input_broadcasted[input_id]
? device::maca::indexToReducedOffset(idx, ndim, output_strides, input_strides + input_id * ndim) ? device::metax::indexToReducedOffset(idx, ndim, output_strides, input_strides + input_id * ndim)
: device::maca::indexToOffset(idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim)); : device::metax::indexToOffset(idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim));
} }
}; };
...@@ -41,7 +41,7 @@ __device__ __forceinline__ void unpackInputsAndApply(F &&f, std::index_sequence< ...@@ -41,7 +41,7 @@ __device__ __forceinline__ void unpackInputsAndApply(F &&f, std::index_sequence<
} }
template <size_t N, typename Op, typename Tdata, typename... Args> template <size_t N, typename Op, typename Tdata, typename... Args>
INFINIOP_MACA_KERNEL elementwiseKernel( INFINIOP_METAX_KERNEL elementwiseKernel(
size_t output_size, size_t output_size,
size_t ndim, size_t ndim,
bool output_contiguous, bool output_contiguous,
...@@ -72,7 +72,7 @@ INFINIOP_MACA_KERNEL elementwiseKernel( ...@@ -72,7 +72,7 @@ INFINIOP_MACA_KERNEL elementwiseKernel(
} }
template <typename Op, typename Tout, typename... Tin> template <typename Op, typename Tout, typename... Tin>
INFINIOP_MACA_KERNEL elementwiseKernel( INFINIOP_METAX_KERNEL elementwiseKernel(
size_t output_size, size_t output_size,
size_t ndim, size_t ndim,
bool output_contiguous, bool output_contiguous,
...@@ -102,9 +102,9 @@ INFINIOP_MACA_KERNEL elementwiseKernel( ...@@ -102,9 +102,9 @@ INFINIOP_MACA_KERNEL elementwiseKernel(
} }
struct DeviceImpl::Opaque { struct DeviceImpl::Opaque {
std::shared_ptr<device::maca::Handle::Internal> internal; std::shared_ptr<device::metax::Handle::Internal> internal;
Opaque(const std::shared_ptr<device::maca::Handle::Internal> &internal) Opaque(const std::shared_ptr<device::metax::Handle::Internal> &internal)
: internal(internal) {} : internal(internal) {}
template <uint32_t BLOCK_SIZE, size_t N, typename Op, typename Tdata, typename... Args> template <uint32_t BLOCK_SIZE, size_t N, typename Op, typename Tdata, typename... Args>
...@@ -159,8 +159,8 @@ private: ...@@ -159,8 +159,8 @@ private:
const int8_t *d_meta_start = reinterpret_cast<int8_t *>(workspace) + input_arr_size; const int8_t *d_meta_start = reinterpret_cast<int8_t *>(workspace) + input_arr_size;
// copy the input pointer array and meta to device // copy the input pointer array and meta to device
CHECK_MACA(hcMemcpyAsync(workspace, h_inputs_arr, input_arr_size, hcMemcpyHostToDevice, stream)); CHECK_METAX(hcMemcpyAsync(workspace, h_inputs_arr, input_arr_size, hcMemcpyHostToDevice, stream));
CHECK_MACA(hcMemcpyAsync((void *)d_meta_start, info_meta_start, info.getMetaMemSize(), hcMemcpyHostToDevice, stream)); CHECK_METAX(hcMemcpyAsync((void *)d_meta_start, info_meta_start, info.getMetaMemSize(), hcMemcpyHostToDevice, stream));
// offset/assign the pointers // offset/assign the pointers
d_inputs_arr = reinterpret_cast<const void **>(workspace); d_inputs_arr = reinterpret_cast<const void **>(workspace);
...@@ -259,6 +259,6 @@ infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &inf ...@@ -259,6 +259,6 @@ infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &inf
std::forward<Args>(args)...); std::forward<Args>(args)...);
} }
} // namespace op::elementwise::maca } // namespace op::elementwise::metax
#endif #endif
#ifndef __INFINIOP_ELEMENTWISE_MACA_API_H__ #ifndef __INFINIOP_ELEMENTWISE_METAX_API_H__
#define __INFINIOP_ELEMENTWISE_MACA_API_H__ #define __INFINIOP_ELEMENTWISE_METAX_API_H__
#include "../elementwise.h" #include "../elementwise.h"
namespace op::elementwise::maca { namespace op::elementwise::metax {
class DeviceImpl final { class DeviceImpl final {
struct Opaque; struct Opaque;
...@@ -37,23 +37,23 @@ public: ...@@ -37,23 +37,23 @@ public:
void *stream, void *stream,
Args &&...args); Args &&...args);
}; };
} // namespace op::elementwise::maca } // namespace op::elementwise::metax
#define CREATE_ELEMENTWISE_MACA_DESCRIPTOR(HANDLE, DTYPE, OUT_DESC, INPUT_DESC_VEC) \ #define CREATE_ELEMENTWISE_METAX_DESCRIPTOR(HANDLE, DTYPE, OUT_DESC, INPUT_DESC_VEC) \
\ \
auto info_result = op::elementwise::ElementwiseInfo::create(OUT_DESC, INPUT_DESC_VEC); \ auto info_result = op::elementwise::ElementwiseInfo::create(OUT_DESC, INPUT_DESC_VEC); \
CHECK_RESULT(info_result); \ CHECK_RESULT(info_result); \
auto info = info_result.take(); \ auto info = info_result.take(); \
auto workspace_size = info.getMetaMemSize() + info.getInputSize() * sizeof(void *); \ auto workspace_size = info.getMetaMemSize() + info.getInputSize() * sizeof(void *); \
\ \
auto device_impl_result = op::elementwise::maca::DeviceImpl::create(HANDLE->internal()); \ auto device_impl_result = op::elementwise::metax::DeviceImpl::create(HANDLE->internal()); \
CHECK_RESULT(device_impl_result); \ CHECK_RESULT(device_impl_result); \
\ \
*desc_ptr = new Descriptor( \ *desc_ptr = new Descriptor( \
DTYPE, \ DTYPE, \
std::move(info), \ std::move(info), \
std::move(device_impl_result.take()), \ std::move(device_impl_result.take()), \
workspace_size, \ workspace_size, \
HANDLE->device, \ HANDLE->device, \
HANDLE->device_id); HANDLE->device_id);
#endif // __INFINIOP_ELEMENTWISE_MACA_API_H__ #endif // __INFINIOP_ELEMENTWISE_METAX_API_H__
...@@ -2,11 +2,11 @@ ...@@ -2,11 +2,11 @@
#define __INFINIOP_ELEMENTWISE_CUDA_H__ #define __INFINIOP_ELEMENTWISE_CUDA_H__
#include "../../../utils.h" #include "../../../utils.h"
#include "../../devices/cuda/cuda_common.cuh" #include "../../devices/nvidia/nvidia_common.cuh"
#include "../../devices/cuda/cuda_kernel_common.cuh" #include "../../devices/nvidia/nvidia_kernel_common.cuh"
#include "elementwise_cuda_api.cuh" #include "elementwise_nvidia_api.cuh"
namespace op::elementwise::cuda { namespace op::elementwise::nvidia {
/** /**
* @brief Casts an untyped device pointer to a typed pointer of type T. * @brief Casts an untyped device pointer to a typed pointer of type T.
...@@ -33,7 +33,7 @@ __device__ __forceinline__ const T *typedInputPtr(const void *ptr) { ...@@ -33,7 +33,7 @@ __device__ __forceinline__ const T *typedInputPtr(const void *ptr) {
*/ */
__device__ __forceinline__ size_t getOutputIndex(size_t idx, bool is_contiguous, size_t ndim, __device__ __forceinline__ size_t getOutputIndex(size_t idx, bool is_contiguous, size_t ndim,
const size_t *shape, const ptrdiff_t *strides) { const size_t *shape, const ptrdiff_t *strides) {
return is_contiguous ? idx : device::cuda::indexToOffset(idx, ndim, shape, strides); return is_contiguous ? idx : device::nvidia::indexToOffset(idx, ndim, shape, strides);
} }
/** /**
...@@ -61,8 +61,8 @@ struct InputIndexer { ...@@ -61,8 +61,8 @@ struct InputIndexer {
return input_contiguous[input_id] return input_contiguous[input_id]
? idx ? idx
: (input_broadcasted[input_id] : (input_broadcasted[input_id]
? device::cuda::indexToReducedOffset(idx, ndim, output_strides, input_strides + input_id * ndim) ? device::nvidia::indexToReducedOffset(idx, ndim, output_strides, input_strides + input_id * ndim)
: device::cuda::indexToOffset(idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim)); : device::nvidia::indexToOffset(idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim));
} }
}; };
...@@ -186,9 +186,9 @@ INFINIOP_CUDA_KERNEL elementwiseKernel( ...@@ -186,9 +186,9 @@ INFINIOP_CUDA_KERNEL elementwiseKernel(
} }
struct DeviceImpl::Opaque { struct DeviceImpl::Opaque {
std::shared_ptr<device::cuda::Handle::Internal> internal; std::shared_ptr<device::nvidia::Handle::Internal> internal;
Opaque(const std::shared_ptr<device::cuda::Handle::Internal> &internal) Opaque(const std::shared_ptr<device::nvidia::Handle::Internal> &internal)
: internal(internal) {} : internal(internal) {}
/** /**
...@@ -414,6 +414,6 @@ infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &inf ...@@ -414,6 +414,6 @@ infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &inf
std::forward<Args>(args)...); std::forward<Args>(args)...);
} }
} // namespace op::elementwise::cuda } // namespace op::elementwise::nvidia
#endif // __INFINIOP_ELEMENTWISE_CUDA_H__ #endif // __INFINIOP_ELEMENTWISE_CUDA_H__
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#include "../elementwise.h" #include "../elementwise.h"
namespace op::elementwise::cuda { namespace op::elementwise::nvidia {
/** /**
* @brief Define the methods and info needed by CUDA to perform elementwise operation * @brief Define the methods and info needed by CUDA to perform elementwise operation
...@@ -77,7 +77,7 @@ public: ...@@ -77,7 +77,7 @@ public:
void *stream, void *stream,
Args &&...args); Args &&...args);
}; };
} // namespace op::elementwise::cuda } // namespace op::elementwise::nvidia
/** /**
* @brief Define the process for initializing a Descriptor of an elementwise operation * @brief Define the process for initializing a Descriptor of an elementwise operation
...@@ -88,22 +88,22 @@ public: ...@@ -88,22 +88,22 @@ public:
* @param OUT_DESC The output tensor descriptor. * @param OUT_DESC The output tensor descriptor.
* @param INPUT_DESC_VEC A vector containing input tensor descriptors. * @param INPUT_DESC_VEC A vector containing input tensor descriptors.
*/ */
#define CREATE_ELEMENTWISE_CUDA_DESCRIPTOR(HANDLE, DTYPE, OUT_DESC, INPUT_DESC_VEC) \ #define CREATE_ELEMENTWISE_CUDA_DESCRIPTOR(HANDLE, DTYPE, OUT_DESC, INPUT_DESC_VEC) \
\ \
auto info_result = op::elementwise::ElementwiseInfo::create(OUT_DESC, INPUT_DESC_VEC); \ auto info_result = op::elementwise::ElementwiseInfo::create(OUT_DESC, INPUT_DESC_VEC); \
CHECK_RESULT(info_result); \ CHECK_RESULT(info_result); \
auto info = info_result.take(); \ auto info = info_result.take(); \
auto workspace_size = info.getMetaMemSize() + info.getInputSize() * sizeof(void *); \ auto workspace_size = info.getMetaMemSize() + info.getInputSize() * sizeof(void *); \
\ \
auto device_impl_result = op::elementwise::cuda::DeviceImpl::create(HANDLE->internal()); \ auto device_impl_result = op::elementwise::nvidia::DeviceImpl::create(HANDLE->internal()); \
CHECK_RESULT(device_impl_result); \ CHECK_RESULT(device_impl_result); \
\ \
*desc_ptr = new Descriptor( \ *desc_ptr = new Descriptor( \
DTYPE, \ DTYPE, \
std::move(info), \ std::move(info), \
std::move(device_impl_result.take()), \ std::move(device_impl_result.take()), \
workspace_size, \ workspace_size, \
HANDLE->device, \ HANDLE->device, \
HANDLE->device_id); HANDLE->device_id);
#endif // __INFINIOP_ELEMENTWISE_CUDA_API_H__ #endif // __INFINIOP_ELEMENTWISE_CUDA_API_H__
import functools
import inspect
import itertools
import pathlib
import ninetoothed
from ninetoothed.aot import _HEADER_PATH
CURRENT_FILE_PATH = pathlib.Path(__file__)
BUILD_DIRECTORY_PATH = (
CURRENT_FILE_PATH.parent.parent.parent.parent / "build" / "ninetoothed"
)
def build(premake, constexpr_param_grid, caller, op_name, output_dir):
    """Generate all constexpr specializations of one operator and a C dispatcher.

    For every combination of values in ``constexpr_param_grid`` a kernel is
    emitted via ``ninetoothed.make`` into ``output_dir``; then a single C
    source/header pair defining ``launch_<op_name>`` is written to
    ``BUILD_DIRECTORY_PATH``. The dispatcher compares the constexpr arguments
    at runtime and forwards to the matching specialized launcher.

    Args:
        premake: Callable invoked with one grid combination as keyword
            arguments; must return ``(arrangement, application, tensors)``
            suitable for ``ninetoothed.make``.
        constexpr_param_grid: Mapping of constexpr parameter name -> iterable
            of candidate values. String values are treated as dtype names
            (e.g. ``"fp16"``) and rewritten as ``INFINI_DTYPE_*`` enumerators.
        caller: Caller identifier forwarded to ``ninetoothed.make``.
        op_name: Base name used for kernel names, file names, and the
            generated launch function.
        output_dir: ``pathlib.Path`` receiving the per-kernel headers.
    """
    # Constexpr parameters carry a trailing underscore in the generated C
    # signature so they cannot collide with tensor parameter names. Deriving
    # them from the grid (instead of the leaked loop variable `combination`,
    # which raised NameError for an empty grid) keeps this well-defined.
    constexpr_param_names = [f"{name}_" for name in constexpr_param_grid]
    headers = []
    all_param_names = []
    launches = []
    for combination in _generate_param_value_combinations(constexpr_param_grid):
        arrangement, application, tensors = premake(**combination)
        # String values denote dtypes ("fp16" -> INFINI_DTYPE_F16, etc.).
        for param_name, param_value in combination.items():
            if isinstance(param_value, str):
                combination[param_name] = (
                    f"INFINI_DTYPE_{combination[param_name].replace('fp', 'F').upper()}"
                )
        combination = {f"{name}_": value for name, value in combination.items()}
        kernel_name = f"{op_name}_{_generate_suffix(combination.values())}"
        ninetoothed.make(
            arrangement,
            application,
            tensors,
            caller=caller,
            kernel_name=kernel_name,
            output_dir=output_dir,
        )
        header = output_dir / f"{kernel_name}.h"
        # Each specialized launcher takes the stream first, then one argument
        # per parameter of the application function.
        param_names = ("stream",) + tuple(
            inspect.signature(application).parameters.keys()
        )
        launch = f"""    if ({_generate_condition(combination)})
        return launch_{kernel_name}({", ".join(param_names)});"""
        headers.append(header)
        all_param_names.append(param_names)
        launches.append(launch)
    includes = "\n".join(f'#include "{header}"' for header in headers)
    # Merge the per-kernel parameter lists into one order-preserving,
    # duplicate-free list, starting from the longest so shared prefixes align.
    param_names = list(
        functools.reduce(
            lambda x, y: dict.fromkeys(x) | dict.fromkeys(y),
            sorted(all_param_names, key=len, reverse=True),
            {},
        )
    )
    param_types = [
        "NineToothedStream",
    ] + ["NineToothedTensor" for _ in range(len(param_names) - 1)]
    # Constexpr parameters are dispatched on as plain ints at the end of the
    # signature; their keys are identical for every combination in the grid.
    for param_name in constexpr_param_names:
        param_names.append(param_name)
        param_types.append("int")
    param_decls = ", ".join(
        f"{c_type} {c_param}" for c_param, c_type in zip(param_names, param_types)
    )
    source_file_name = f"{op_name}.c"
    header_file_name = f"{op_name}.h"
    func_sig = f"NineToothedResult launch_{op_name}({param_decls})"
    joined_launches = "\n".join(launches)
    # C++ consumers need C linkage for the dispatcher.
    op_decl = f'#ifdef __cplusplus\nextern "C" {func_sig};\n#else\n{func_sig};\n#endif'
    # Fall through to NOT_IMPLEMENTED when no specialization matches.
    op_def = f"""{func_sig} {{
{joined_launches}
    return INFINI_STATUS_NOT_IMPLEMENTED;
}}"""
    source_content = f"""#include "{header_file_name}"
#include "infinicore.h"

{includes}\n\n{op_def}\n"""
    header_content = f"""#include "{_HEADER_PATH}"
\n{op_decl}\n"""
    (BUILD_DIRECTORY_PATH / source_file_name).write_text(source_content)
    (BUILD_DIRECTORY_PATH / header_file_name).write_text(header_content)
def _generate_condition(combination):
return " && ".join(f"{param} == {value}" for param, value in combination.items())
def _generate_suffix(values):
return "_".join(f"{value}" for value in values)
def _generate_param_value_combinations(param_grid):
keys = list(param_grid.keys())
value_combinations = itertools.product(*param_grid.values())
return tuple(dict(zip(keys, combination)) for combination in value_combinations)
...@@ -19,7 +19,7 @@ infiniStatus_t Descriptor::create( ...@@ -19,7 +19,7 @@ infiniStatus_t Descriptor::create(
const auto &a_shape = a_desc->shape(); const auto &a_shape = a_desc->shape();
const auto &b_shape = b_desc->shape(); const auto &b_shape = b_desc->shape();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16);
CHECK_SAME_SHAPE(c_shape, a_shape, b_shape); CHECK_SAME_SHAPE(c_shape, a_shape, b_shape);
...@@ -43,6 +43,8 @@ infiniStatus_t Descriptor::calculate( ...@@ -43,6 +43,8 @@ infiniStatus_t Descriptor::calculate(
return _device_info->calculate<AddOp, float>(_info, output, inputs, stream); return _device_info->calculate<AddOp, float>(_info, output, inputs, stream);
case INFINI_DTYPE_F64: case INFINI_DTYPE_F64:
return _device_info->calculate<AddOp, double>(_info, output, inputs, stream); return _device_info->calculate<AddOp, double>(_info, output, inputs, stream);
case INFINI_DTYPE_BF16:
return _device_info->calculate<AddOp, bf16_t>(_info, output, inputs, stream);
default: default:
return INFINI_STATUS_BAD_TENSOR_DTYPE; return INFINI_STATUS_BAD_TENSOR_DTYPE;
} }
......
#ifndef __ADD_CUDA_H__ #ifndef __ADD_CUDA_H__
#define __ADD_CUDA_H__ #define __ADD_CUDA_H__
#include "../../../elementwise/cuda/elementwise_cuda.cuh"
#include <cuda_fp16.h>
namespace op::add::cuda { namespace op::add::cuda {
typedef struct AddOp { typedef struct AddOp {
public: public:
...@@ -12,7 +9,7 @@ public: ...@@ -12,7 +9,7 @@ public:
__device__ __forceinline__ T operator()(const T &a, const T &b) const { __device__ __forceinline__ T operator()(const T &a, const T &b) const {
if constexpr (std::is_same_v<T, half2>) { if constexpr (std::is_same_v<T, half2>) {
return __hadd2(a, b); return __hadd2(a, b);
} else if constexpr (std::is_same_v<T, half>) { } else if constexpr (std::is_same_v<T, half> || std::is_same_v<T, cuda_bfloat16>) {
return __hadd(a, b); return __hadd(a, b);
} else if constexpr (std::is_same_v<T, float>) { } else if constexpr (std::is_same_v<T, float>) {
return __fadd_rd(a, b); return __fadd_rd(a, b);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment