Unverified Commit 23077c42 authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #309 from InfiniTensor/issue/254/fix

issue/254/fix 为elementwise算子添加bf16支持
parents 2790a7b2 ceda7c1c
......@@ -19,7 +19,7 @@ infiniStatus_t Descriptor::create(
const auto &a_shape = a_desc->shape();
const auto &b_shape = b_desc->shape();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16);
CHECK_SAME_SHAPE(c_shape, a_shape, b_shape);
......@@ -43,6 +43,8 @@ infiniStatus_t Descriptor::calculate(
return _device_info->calculate<AddOp, float>(_info, output, inputs, stream);
case INFINI_DTYPE_F64:
return _device_info->calculate<AddOp, double>(_info, output, inputs, stream);
case INFINI_DTYPE_BF16:
return _device_info->calculate<AddOp, bf16_t>(_info, output, inputs, stream);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
......
......@@ -20,7 +20,7 @@ infiniStatus_t Descriptor::create(
const auto &a_shape = a_desc->shape();
const auto &b_shape = b_desc->shape();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16);
CHECK_SAME_SHAPE(c_shape, a_shape, b_shape);
......@@ -44,6 +44,8 @@ infiniStatus_t Descriptor::calculate(
switch (_dtype) {
case INFINI_DTYPE_F16:
return _device_info->calculate<256, AddOp, half>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_BF16:
return _device_info->calculate<256, AddOp, __nv_bfloat16>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_F32:
return _device_info->calculate<256, AddOp, float>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_F64:
......
......@@ -2,6 +2,7 @@
#define __ADD_CUDA_H__
#include "../../../elementwise/cuda/elementwise_cuda.cuh"
#include <cuda_bf16.h>
#include <cuda_fp16.h>
namespace op::add::cuda {
......@@ -12,7 +13,7 @@ public:
__device__ __forceinline__ T operator()(const T &a, const T &b) const {
if constexpr (std::is_same_v<T, half2>) {
return __hadd2(a, b);
} else if constexpr (std::is_same_v<T, half>) {
} else if constexpr (std::is_same_v<T, half> || std::is_same_v<T, __nv_bfloat16>) {
return __hadd(a, b);
} else if constexpr (std::is_same_v<T, float>) {
return __fadd_rd(a, b);
......
......@@ -21,7 +21,7 @@ infiniStatus_t Descriptor::create(
const auto &min_shape = min_desc->shape();
const auto &max_shape = max_desc->shape();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16);
CHECK_SAME_SHAPE(out_shape, in_shape);
CHECK_SAME_SHAPE(out_shape, min_shape);
CHECK_SAME_SHAPE(out_shape, max_shape);
......@@ -45,6 +45,8 @@ infiniStatus_t Descriptor::calculate(
return _device_info->calculate<ClipOp, float>(_info, output, inputs, stream);
case INFINI_DTYPE_F64:
return _device_info->calculate<ClipOp, double>(_info, output, inputs, stream);
case INFINI_DTYPE_BF16:
return _device_info->calculate<ClipOp, bf16_t>(_info, output, inputs, stream);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
......
......@@ -22,7 +22,7 @@ infiniStatus_t Descriptor::create(
const auto &min_shape = min_desc->shape();
const auto &max_shape = max_desc->shape();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16);
CHECK_SAME_SHAPE(out_shape, in_shape);
CHECK_SAME_SHAPE(out_shape, min_shape);
CHECK_SAME_SHAPE(out_shape, max_shape);
......@@ -50,6 +50,8 @@ infiniStatus_t Descriptor::calculate(
return _device_info->calculate<256, ClipOp, float>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_F64:
return _device_info->calculate<256, ClipOp, double>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_BF16:
return _device_info->calculate<256, ClipOp, __nv_bfloat16>(_info, workspace, output, inputs, stream);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
......
......@@ -19,7 +19,7 @@ infiniStatus_t Descriptor::create(
const auto &a_shape = a_desc->shape();
const auto &b_shape = b_desc->shape();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16);
CHECK_SAME_SHAPE(out_shape, a_shape, b_shape);
......@@ -43,6 +43,8 @@ infiniStatus_t Descriptor::calculate(
return _device_info->calculate<MulOp, float>(_info, output, inputs, stream);
case INFINI_DTYPE_F64:
return _device_info->calculate<MulOp, double>(_info, output, inputs, stream);
case INFINI_DTYPE_BF16:
return _device_info->calculate<MulOp, bf16_t>(_info, output, inputs, stream);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
......
......@@ -20,7 +20,7 @@ infiniStatus_t Descriptor::create(
const auto &a_shape = a_desc->shape();
const auto &b_shape = b_desc->shape();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16);
CHECK_SAME_SHAPE(c_shape, a_shape, b_shape);
......@@ -48,6 +48,8 @@ infiniStatus_t Descriptor::calculate(
return _device_info->calculate<256, MulOp, float>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_F64:
return _device_info->calculate<256, MulOp, double>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_BF16:
return _device_info->calculate<256, MulOp, __nv_bfloat16>(_info, workspace, output, inputs, stream);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
......
......@@ -2,6 +2,7 @@
#define __MUL_CUDA_H__
#include "../../../elementwise/cuda/elementwise_cuda.cuh"
#include <cuda_bf16.h>
#include <cuda_fp16.h>
namespace op::mul::cuda {
......@@ -11,7 +12,7 @@ typedef struct MulOp {
__device__ __forceinline__ T operator()(const T &a, const T &b) const {
if constexpr (std::is_same_v<T, half2>) {
return __hmul2(a, b);
} else if constexpr (std::is_same_v<T, half>) {
} else if constexpr (std::is_same_v<T, half> || std::is_same_v<T, __nv_bfloat16>) {
return __hmul(a, b);
} else if constexpr (std::is_same_v<T, float>) {
return __fmul_rn(a, b);
......
......@@ -19,7 +19,7 @@ infiniStatus_t Descriptor::create(
const auto &a_shape = a_desc->shape();
const auto &b_shape = b_desc->shape();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16);
CHECK_SAME_SHAPE(c_shape, a_shape, b_shape);
......@@ -43,6 +43,8 @@ infiniStatus_t Descriptor::calculate(
return _device_info->calculate<SubOp, float>(_info, output, inputs, stream);
case INFINI_DTYPE_F64:
return _device_info->calculate<SubOp, double>(_info, output, inputs, stream);
case INFINI_DTYPE_BF16:
return _device_info->calculate<SubOp, bf16_t>(_info, output, inputs, stream);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
......
......@@ -20,7 +20,7 @@ infiniStatus_t Descriptor::create(
const auto &a_shape = a_desc->shape();
const auto &b_shape = b_desc->shape();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16);
CHECK_SAME_SHAPE(c_shape, a_shape, b_shape);
......@@ -48,6 +48,8 @@ infiniStatus_t Descriptor::calculate(
return _device_info->calculate<256, SubOp, float>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_F64:
return _device_info->calculate<256, SubOp, double>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_BF16:
return _device_info->calculate<256, SubOp, __nv_bfloat16>(_info, workspace, output, inputs, stream);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
......
......@@ -2,6 +2,7 @@
#define __SUB_CUDA_H__
#include "../../../elementwise/cuda/elementwise_cuda.cuh"
#include <cuda_bf16.h>
#include <cuda_fp16.h>
namespace op::sub::cuda {
......@@ -12,7 +13,7 @@ public:
__device__ __forceinline__ T operator()(const T &a, const T &b) const {
if constexpr (std::is_same_v<T, half2>) {
return __hsub2(a, b);
} else if constexpr (std::is_same_v<T, half>) {
} else if constexpr (std::is_same_v<T, half> || std::is_same_v<T, __nv_bfloat16>) {
return __hsub(a, b);
} else if constexpr (std::is_same_v<T, float>) {
return __fsub_rd(a, b);
......
......@@ -59,12 +59,13 @@ _TEST_CASES = [
]
# Data types used for testing
_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.F32]
_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.F32, InfiniDtype.BF16]
# Tolerance map for different data types
_TOLERANCE_MAP = {
InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3},
InfiniDtype.F32: {"atol": 1e-7, "rtol": 1e-7},
InfiniDtype.BF16: {"atol": 1e-3, "rtol": 1e-3},
}
DEBUG = False
......
......@@ -52,12 +52,13 @@ _TEST_CASES_ = [
]
_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.F32]
_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.F32, InfiniDtype.BF16]
_TOLERANCE_MAP = {
InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3},
InfiniDtype.F32: {"atol": 1e-7, "rtol": 1e-6},
InfiniDtype.BF16: {"atol": 1e-3, "rtol": 1e-3},
}
......
......@@ -59,12 +59,13 @@ _TEST_CASES = [
]
# Data types used for testing
_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.F32]
_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.F32, InfiniDtype.BF16]
# Tolerance map for different data types
_TOLERANCE_MAP = {
InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3},
InfiniDtype.F32: {"atol": 1e-7, "rtol": 1e-7},
InfiniDtype.BF16: {"atol": 1e-3, "rtol": 1e-3},
}
DEBUG = False
......
......@@ -59,12 +59,13 @@ _TEST_CASES = [
]
# Data types used for testing
_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.F32]
_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.F32, InfiniDtype.BF16]
# Tolerance map for different data types
_TOLERANCE_MAP = {
InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3},
InfiniDtype.F32: {"atol": 1e-7, "rtol": 1e-7},
InfiniDtype.BF16: {"atol": 1e-3, "rtol": 1e-3},
}
DEBUG = False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment