Unverified Commit f88d4ad8 authored by 蒋帅宏 (Shuaihong_Jiang), committed by GitHub

issue/254: Add BF16 support for operators on CPU and CUDA, with corresponding test code (#255)



* issue/254: Add BF16 support for operators on CPU and CUDA, with corresponding test code

* issue/254: Reformat the modified operators and resubmit

* Resolve conflicts with the latest main

* After resolving the conflicts, rms_norm no longer passed at its original precision; the tolerance is changed from
{"atol": 5e-3, "rtol": 5e-3} to
{"atol": 8e-3, "rtol": 8e-3}

* The FP16 rms_norm test case failed in debug mode (it passes locally but not on GitHub),
so the tolerance was doubled for testing

* Scale the rms_norm test input by 0.5 and restore the tolerance to its original values for the CI run

* issue/254: 1. Use the CHECK_DTYPE macro for data-type validation
2. Add a check for device BF16 support in the test utils.py

* issue/254: Change the rms_norm FP16 test tolerance from
torch.float16: {"atol": 1e-3, "rtol": 1e-3},
to torch.float16: {"atol": 2e-3, "rtol": 2e-3},
and remove the 0.5 input scaling

* issue/254: Add special-case handling for BF16 to the debug and debug_all
functions in utils.py

* Revise the device-type check method used for BF16 test support

* Revise the device check for BF16 test support

* issue/254: reduce redundancy in rms_norm.py

* issue/254: add back the missing comment in rms_norm.py

* issue/254: add fp32 tolerance condition in causal_softmax.py

---------
Co-authored-by: Zimin Li <coollizimin@gmail.com>
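
A note on the tolerance bullets above: both torch.allclose and np.testing.assert_allclose, which the tests rely on, treat the pair as |actual - desired| <= atol + rtol * |desired|, so raising atol/rtol from 1e-3 to 2e-3 roughly doubles the allowed error. A minimal sketch (the numbers are illustrative only):

    import torch

    actual = torch.tensor([1.0025])
    desired = torch.tensor([1.0000])

    # |1.0025 - 1.0| = 2.5e-3
    print(torch.allclose(actual, desired, atol=1e-3, rtol=1e-3))  # False: 2.5e-3 > 1e-3 + 1e-3 * 1.0
    print(torch.allclose(actual, desired, atol=2e-3, rtol=2e-3))  # True:  2.5e-3 <= 2e-3 + 2e-3 * 1.0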
parent 105065e2
@@ -20,7 +20,7 @@ infiniStatus_t Descriptor::create(
     const auto &up_shape = up_desc->shape();
     const auto &gate_shape = gate_desc->shape();
 
-    CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
+    CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
     CHECK_SAME_SHAPE(out_shape, up_shape, gate_shape);
 
     // create CUDA elementwise descriptor
@@ -43,6 +43,8 @@ infiniStatus_t Descriptor::calculate(
     switch (_dtype) {
     case INFINI_DTYPE_F16:
         return _device_info->calculate<256, SwiGLUOp, half>(_info, workspace, output, inputs, stream);
+    case INFINI_DTYPE_BF16:
+        return _device_info->calculate<256, SwiGLUOp, __nv_bfloat16>(_info, workspace, output, inputs, stream);
     case INFINI_DTYPE_F32:
         return _device_info->calculate<256, SwiGLUOp, float>(_info, workspace, output, inputs, stream);
     case INFINI_DTYPE_F64:
......
@@ -2,6 +2,7 @@
 #define __SWIGLU_CUDA_H__
 
 #include "../../../elementwise/cuda/elementwise_cuda.cuh"
+#include <cuda_bf16.h>
 #include <cuda_fp16.h>
 
 namespace op::swiglu::cuda {
@@ -13,6 +14,15 @@ private:
             return h2rcp(__hadd2(make_half2(1, 1), h2exp(__hneg2(x))));
         } else if constexpr (std::is_same_v<T, half>) {
             return hrcp(__hadd(half(1.f), __float2half(__expf(__half2float(__hneg(x))))));
+        } else if constexpr (std::is_same_v<T, __nv_bfloat162>) {
+            float x0 = __bfloat162float(__low2bfloat16(x));
+            float x1 = __bfloat162float(__high2bfloat16(x));
+            float sig0 = __frcp_rn(__fadd_rn(1.0f, __expf(-x0)));
+            float sig1 = __frcp_rn(__fadd_rn(1.0f, __expf(-x1)));
+            return __floats2bfloat162_rn(sig0, sig1);
+        } else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
+            float xf = __bfloat162float(x);
+            return __float2bfloat16_rn(__frcp_rn(__fadd_rn(1.0f, __expf(-xf))));
         } else if constexpr (std::is_same_v<T, float>) {
             return __frcp_rn(__fadd_rn(1, __expf(-x)));
         } else {
@@ -28,6 +38,23 @@ public:
             return __hmul2(__hmul2(gate, sigmoid(gate)), up);
         } else if constexpr (std::is_same_v<T, half>) {
             return __hmul(__hmul(gate, sigmoid(gate)), up);
+        } else if constexpr (std::is_same_v<T, __nv_bfloat162>) {
+            __nv_bfloat162 sig = sigmoid(gate);
+            float gate0 = __bfloat162float(__low2bfloat16(gate));
+            float gate1 = __bfloat162float(__high2bfloat16(gate));
+            float sig0 = __bfloat162float(__low2bfloat16(sig));
+            float sig1 = __bfloat162float(__high2bfloat16(sig));
+            float up0 = __bfloat162float(__low2bfloat16(up));
+            float up1 = __bfloat162float(__high2bfloat16(up));
+            float res0 = __fmul_rn(__fmul_rn(gate0, sig0), up0);
+            float res1 = __fmul_rn(__fmul_rn(gate1, sig1), up1);
+            return __floats2bfloat162_rn(res0, res1);
+        } else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
+            __nv_bfloat16 sig = sigmoid(gate);
+            float gatef = __bfloat162float(gate);
+            float sigf = __bfloat162float(sig);
+            float upf = __bfloat162float(up);
+            return __float2bfloat16_rn(__fmul_rn(__fmul_rn(gatef, sigf), upf));
         } else if constexpr (std::is_same_v<T, float>) {
             return __fmul_rn(__fmul_rn(gate, sigmoid(gate)), up);
         } else {
......
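The BF16 branches above follow one pattern for both the packed (__nv_bfloat162) and scalar (__nv_bfloat16) cases: widen to fp32, evaluate the sigmoid and the products with fp32 intrinsics, then round back to bfloat16. A rough PyTorch sketch of that behaviour for sanity-checking (the function name is illustrative, and unlike the packed kernel it skips the intermediate bf16 rounding of the sigmoid):

    import torch

    def swiglu_bf16_ref(up: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
        # Emulate the kernel's BF16 path: upcast, compute in fp32, downcast once.
        up_f, gate_f = up.float(), gate.float()
        sig = 1.0 / (1.0 + torch.exp(-gate_f))
        return (up_f * gate_f * sig).to(torch.bfloat16)

    up = torch.rand(4, 5632, dtype=torch.bfloat16)
    gate = torch.rand(4, 5632, dtype=torch.bfloat16)
    print(swiglu_bf16_ref(up, gate).dtype)  # torch.bfloat16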
@@ -2,32 +2,58 @@
 namespace op::common_cpu::reduce_op {
 
-float sum(const fp16_t *data, size_t len, ptrdiff_t stride) {
+template <typename HalfType>
+float sum_half_impl(const HalfType *data, size_t len, ptrdiff_t stride) {
     float result = 0;
     for (size_t i = 0; i < len; i++) {
         result += utils::cast<float>(data[i * stride]);
     }
     return result;
 }
 
-float max(const fp16_t *data, size_t len, ptrdiff_t stride) {
+template <typename HalfType>
+float max_half_impl(const HalfType *data, size_t len, ptrdiff_t stride) {
     float result = utils::cast<float>(data[0]);
     for (size_t i = 1; i < len; i++) {
         result = std::max(result, utils::cast<float>(data[i * stride]));
     }
     return result;
 }
 
-float sumSquared(const fp16_t *data, size_t len, ptrdiff_t stride) {
+template <typename HalfType>
+float sumSquared_half_impl(const HalfType *data, size_t len, ptrdiff_t stride) {
     float result = 0;
     for (size_t i = 0; i < len; i++) {
         float val = utils::cast<float>(data[i * stride]);
         result += val * val;
     }
     return result;
 }
 
+// fp16
+float sum(const fp16_t *data, size_t len, ptrdiff_t stride) {
+    return sum_half_impl(data, len, stride);
+}
+
+float max(const fp16_t *data, size_t len, ptrdiff_t stride) {
+    return max_half_impl(data, len, stride);
+}
+
+float sumSquared(const fp16_t *data, size_t len, ptrdiff_t stride) {
+    return sumSquared_half_impl(data, len, stride);
+}
+
+// bf16
+float sum(const bf16_t *data, size_t len, ptrdiff_t stride) {
+    return sum_half_impl(data, len, stride);
+}
+
+float max(const bf16_t *data, size_t len, ptrdiff_t stride) {
+    return max_half_impl(data, len, stride);
+}
+
+float sumSquared(const bf16_t *data, size_t len, ptrdiff_t stride) {
+    return sumSquared_half_impl(data, len, stride);
+}
+
 } // namespace op::common_cpu::reduce_op
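The shared sum_half_impl/max_half_impl/sumSquared_half_impl templates accumulate in fp32 for both fp16_t and bf16_t elements. Accumulating in the half-width type itself would be lossy: with bf16's short mantissa, a running sum soon stops absorbing small addends. A small PyTorch illustration of the effect (not library code; the exact stall point depends on the values):

    import torch

    x = torch.full((2048,), 0.01, dtype=torch.bfloat16)  # each element is ~0.01001 after bf16 rounding

    acc_bf16 = torch.zeros((), dtype=torch.bfloat16)
    for v in x:
        acc_bf16 = acc_bf16 + v          # result is rounded back to bf16 after every addition

    acc_f32 = sum(float(v) for v in x)   # what the CPU helpers do: cast each element, add in fp32

    print(acc_bf16.item(), acc_f32)      # the bf16 accumulator stalls far below the true sum of 20.5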
@@ -37,6 +37,7 @@ T sum(const T *data, size_t len, ptrdiff_t stride = 1) {
 }
 
 float sum(const fp16_t *data, size_t len, ptrdiff_t stride = 1);
+float sum(const bf16_t *data, size_t len, ptrdiff_t stride = 1);
 
 template <typename T, typename = std::enable_if_t<ReduceToSame<T>::value>>
 T max(const T *data, size_t len, ptrdiff_t stride = 1) {
@@ -49,6 +50,7 @@ T max(const T *data, size_t len, ptrdiff_t stride = 1) {
 }
 
 float max(const fp16_t *data, size_t len, ptrdiff_t stride = 1);
+float max(const bf16_t *data, size_t len, ptrdiff_t stride = 1);
 
 template <typename T, typename = std::enable_if_t<ReduceToSame<T>::value>>
 T sumSquared(const T *data, size_t len, ptrdiff_t stride = 1) {
@@ -62,6 +64,7 @@ T sumSquared(const T *data, size_t len, ptrdiff_t stride = 1) {
 }
 
 float sumSquared(const fp16_t *data, size_t len, ptrdiff_t stride = 1);
+float sumSquared(const bf16_t *data, size_t len, ptrdiff_t stride = 1);
 
 } // namespace reduce_op
......
@@ -61,3 +61,25 @@ fp16_t _f32_to_f16(float val) {
         return fp16_t{(uint16_t)sign};
     }
 }
+
+float _bf16_to_f32(bf16_t val) {
+    // Simply place the bf16 bits in the upper 16 bits of a float32 and zero the remaining 16 bits.
+    uint32_t bits32 = static_cast<uint32_t>(val._v) << 16;
+    float out;
+    std::memcpy(&out, &bits32, sizeof(out));
+    return out;
+}
+
+bf16_t _f32_to_bf16(float val) {
+    uint32_t bits32;
+    std::memcpy(&bits32, &val, sizeof(bits32));
+    // Add 0x7FFF before truncating, then use the parity of bit 16 (the lowest kept
+    // significand bit) to implement round-to-nearest-even.
+    const uint32_t rounding_bias = 0x00007FFF +          // 0111 1111 1111 1111
+                                   ((bits32 >> 16) & 1); // +1 when the lowest kept mantissa bit is odd, i.e. round half to even
+    uint16_t bf16_bits = static_cast<uint16_t>((bits32 + rounding_bias) >> 16);
+    return bf16_t{bf16_bits};
+}
@@ -16,6 +16,9 @@ typedef struct CustomBFloat16 bf16_t;
 float _f16_to_f32(fp16_t val);
 fp16_t _f32_to_f16(float val);
 
+float _bf16_to_f32(bf16_t val);
+bf16_t _f32_to_bf16(float val);
+
 namespace utils {
 
 // General template for non-fp16_t conversions
 template <typename TypeTo, typename TypeFrom>
@@ -25,11 +28,19 @@ TypeTo cast(TypeFrom val) {
     } else if constexpr (std::is_same<TypeTo, fp16_t>::value && std::is_same<TypeFrom, float>::value) {
         return _f32_to_f16(val);
     } else if constexpr (std::is_same<TypeTo, fp16_t>::value && !std::is_same<TypeFrom, float>::value) {
-        return _f32_to_f16(static_cast<TypeTo>(val));
+        return _f32_to_f16(static_cast<float>(val));
     } else if constexpr (std::is_same<TypeFrom, fp16_t>::value && std::is_same<TypeTo, float>::value) {
         return _f16_to_f32(val);
     } else if constexpr (std::is_same<TypeFrom, fp16_t>::value && !std::is_same<TypeTo, float>::value) {
         return static_cast<TypeTo>(_f16_to_f32(val));
+    } else if constexpr (std::is_same<TypeTo, bf16_t>::value && std::is_same<TypeFrom, float>::value) {
+        return _f32_to_bf16(val);
+    } else if constexpr (std::is_same<TypeTo, bf16_t>::value && !std::is_same<TypeFrom, float>::value) {
+        return _f32_to_bf16(static_cast<float>(val));
+    } else if constexpr (std::is_same<TypeFrom, bf16_t>::value && std::is_same<TypeTo, float>::value) {
+        return _bf16_to_f32(val);
+    } else if constexpr (std::is_same<TypeFrom, bf16_t>::value && !std::is_same<TypeTo, float>::value) {
+        return static_cast<TypeTo>(_bf16_to_f32(val));
     } else {
         return static_cast<TypeTo>(val);
     }
......
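_f32_to_bf16 above rounds to nearest-even by adding 0x7FFF plus the parity of bit 16 before truncating to the upper 16 bits, and _bf16_to_f32 simply places those 16 bits back in the high half of a float32. The same bit trick can be reproduced and cross-checked from Python; PyTorch's float32-to-bfloat16 cast also rounds to nearest-even, so the two should agree (the helper names below are illustrative only):

    import struct
    import torch

    def f32_to_bf16_bits(val: float) -> int:
        # Round-to-nearest-even: add 0x7FFF, plus 1 when the lowest kept mantissa bit is odd.
        bits32 = struct.unpack("<I", struct.pack("<f", val))[0]
        rounding_bias = 0x7FFF + ((bits32 >> 16) & 1)
        return ((bits32 + rounding_bias) >> 16) & 0xFFFF

    def bf16_bits_to_f32(bits16: int) -> float:
        # Place the bf16 bits in the upper half of a float32, zero the rest.
        return struct.unpack("<f", struct.pack("<I", bits16 << 16))[0]

    for v in (1.0, 0.1, 3.1415926, -2.5e-3):
        bits = f32_to_bf16_bits(v)
        assert bf16_bits_to_f32(bits) == torch.tensor(v).to(torch.bfloat16).item()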
@@ -34,11 +34,13 @@ _TEST_CASES_ = [
 ]
 
 # Data types used for testing
-_TENSOR_DTYPES = [torch.float16, torch.float32]
+_TENSOR_DTYPES = [torch.float16, torch.bfloat16, torch.float32]
 
 # Tolerance map for different data types
 _TOLERANCE_MAP = {
     torch.float16: {"atol": 1e-3, "rtol": 1e-2},
+    torch.bfloat16: {"atol": 5e-3, "rtol": 5e-2},
+    torch.float32: {"atol": 1e-5, "rtol": 1e-5},
 }
@@ -87,7 +89,7 @@ def test(
     y_stride=None,
     inplace=Inplace.OUT_OF_PLACE,
     dtype=torch.float16,
-    sync=None
+    sync=None,
 ):
     print(
         f"Testing CausalSoftmax on {torch_device} with shape:{shape} x_stride:{x_stride} y_stride:{y_stride} dtype:{dtype} inplace:{inplace}"
@@ -109,7 +111,7 @@ def test(
     y = torch.zeros(shape, dtype=dtype).to(torch_device)
     y = rearrange_if_needed(y, y_stride)
     y_tensor = to_tensor(y, lib)
 
     if sync is not None:
         sync()
@@ -144,9 +146,9 @@ def test(
     )
 
     lib_causal_softmax()
 
     if sync is not None:
         sync()
 
     atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
     if DEBUG:
......
@@ -31,12 +31,13 @@ _TEST_CASES = [
 ]
 
 # Data types used for testing
-_TENSOR_DTYPES = [torch.float16, torch.float32]
+_TENSOR_DTYPES = [torch.float16, torch.float32, torch.bfloat16]
 
 # Tolerance map for different data types
 _TOLERANCE_MAP = {
     torch.float16: {"atol": 0, "rtol": 1e-2},
     torch.float32: {"atol": 0, "rtol": 1e-3},
+    torch.bfloat16: {"atol": 0, "rtol": 5e-2},
 }
 
 DEBUG = False
@@ -84,7 +85,7 @@ def test(
     b_stride=None,
     c_stride=None,
     dtype=torch.float16,
-    sync=None
+    sync=None,
 ):
     print(
         f"Testing Gemm on {torch_device} with alpha:{alpha}, beta:{beta},"
@@ -152,8 +153,10 @@ def test(
     # Validate results
     atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
     if DEBUG:
         debug(c, ans, atol=atol, rtol=rtol)
 
     assert torch.allclose(c, ans, atol=atol, rtol=rtol)
 
     # Profiling workflow
......
+import torch
 import ctypes
 
 from .datatypes import *
 from .devices import *
@@ -223,6 +224,10 @@ def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True):
         If True, the function will print detailed information about any discrepancies between the tensors.
     """
     import numpy as np
+    # If BF16, convert everything to FP32 before comparing
+    if actual.dtype == torch.bfloat16 or desired.dtype == torch.bfloat16:
+        actual = actual.to(torch.float32)
+        desired = desired.to(torch.float32)
 
     print_discrepancy(actual, desired, atol, rtol, equal_nan, verbose)
     np.testing.assert_allclose(
@@ -230,6 +235,14 @@ def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True):
     )
 
+
+def filter_tensor_dtypes_by_device(device, tensor_dtypes):
+    if device in (InfiniDeviceEnum.CPU, InfiniDeviceEnum.NVIDIA):
+        return tensor_dtypes
+    else:
+        # Filter out torch.bfloat16
+        return [dt for dt in tensor_dtypes if dt != torch.bfloat16]
+
+
 def debug_all(
     actual_vals: Sequence,
     desired_vals: Sequence,
@@ -269,6 +282,9 @@ def debug_all(
     passed = False if condition == "or" else True
     for index, (actual, desired) in enumerate(zip(actual_vals, desired_vals)):
+        if actual.dtype == torch.bfloat16 or desired.dtype == torch.bfloat16:
+            actual = actual.to(torch.float32)
+            desired = desired.to(torch.float32)
         print(f" \033[36mCondition #{index + 1}:\033[0m {actual} == {desired}")
         indices = print_discrepancy(actual, desired, atol, rtol, equal_nan, verbose)
         if condition == "or":
@@ -418,6 +434,7 @@ def test_operator(lib, device, test_func, test_cases, tensor_dtypes):
     """
     lib.infinirtSetDevice(device, ctypes.c_int(0))
    handle = create_handle(lib)
+    tensor_dtypes = filter_tensor_dtypes_by_device(device, tensor_dtypes)
     try:
         for test_case in test_cases:
             for tensor_dtype in tensor_dtypes:
@@ -427,7 +444,7 @@ def test_operator(lib, device, test_func, test_cases, tensor_dtypes):
                     infiniDeviceEnum_str_map[device],
                     *test_case,
                     tensor_dtype,
-                    get_sync_func(device)
+                    get_sync_func(device),
                 )
     finally:
         destroy_handle(lib, handle)
@@ -480,11 +497,12 @@ def get_test_devices(args):
 
 def get_sync_func(device):
     import torch
 
     device_str = infiniDeviceEnum_str_map[device]
 
     if device == InfiniDeviceEnum.CPU:
         sync = None
     else:
         sync = getattr(torch, device_str).synchronize
 
     return sync
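The BF16 special-casing added to debug and debug_all exists because NumPy has no bfloat16 dtype: np.testing.assert_allclose converts its arguments to arrays, and that conversion raises a TypeError for bfloat16 tensors. Upcasting both sides to fp32 first sidesteps that while keeping the comparison meaningful. A minimal sketch of the workaround:

    import numpy as np
    import torch

    actual = torch.rand(3, dtype=torch.bfloat16)
    desired = actual.clone()

    # np.testing.assert_allclose(actual, desired)  # would raise: NumPy cannot represent bfloat16

    np.testing.assert_allclose(
        actual.to(torch.float32).numpy(),
        desired.to(torch.float32).numpy(),
        atol=0.0,
        rtol=0.0,
    )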
@@ -37,10 +37,11 @@ _TEST_CASES = [
 ]
 
 # Data types used for testing
-_TENSOR_DTYPES = [torch.float16]
+_TENSOR_DTYPES = [torch.float16, torch.bfloat16]
 
 _TOLERANCE_MAP = {
     torch.float16: {"atol": 0, "rtol": 0},
+    torch.bfloat16: {"atol": 0, "rtol": 0},
 }
......
@@ -23,21 +23,33 @@ from libinfiniop import (
 # Configuration (Internal Use Only)
 # ==============================================================================
 # These are not meant to be imported from other modules
-_TEST_CASES = [
-    # y_shape, x_shape, w_shape, y_stride, x_stride, w_dtype
-    ((1, 4), (1, 4), (4,), None, None, torch.float32),
-    ((16, 2048), (16, 2048), (2048,), None, None, torch.float32),
-    ((16, 2048), (16, 2048), (2048,), None, None, torch.float16),
-    ((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), torch.float32),
-    ((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), torch.float16),
+_TEST_CASES_ = [
+    # y_shape, x_shape, w_shape, y_stride, x_stride
+    ((1, 4), (1, 4), (4,), None, None),
+    ((1, 4), (1, 4), (4,), None, None),
+    ((16, 2048), (16, 2048), (2048,), None, None),
+    ((16, 2048), (16, 2048), (2048,), None, None),
+    ((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1)),
+    ((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1)),
 ]
 
+# w (weight) types
+# Note: 'None' means the same as input dtype
+_WEIGHT_DTYPES = [None, torch.float32]
+
 # x types used for testing
-_TENSOR_DTYPES = [torch.float16]
+_TENSOR_DTYPES = [torch.float16, torch.bfloat16]
+
+# Form the test cases by appending each element of _WEIGHT_DTYPES to each tuple in _TEST_CASES_
+_TEST_CASES = [
+    test_case + (w_dtype,)
+    for test_case in _TEST_CASES_
+    for w_dtype in _WEIGHT_DTYPES
+]
 
 # Tolerance map for different data types
 _TOLERANCE_MAP = {
-    torch.float16: {"atol": 1e-3, "rtol": 1e-3},
+    torch.float16: {"atol": 2e-3, "rtol": 2e-3},
+    torch.bfloat16: {"atol": 8e-3, "rtol": 8e-3},
 }
 
 DEBUG = False
@@ -73,13 +85,14 @@ def test(
     x_stride,
     w_dtype=torch.float16,
     dtype=torch.float16,
-    sync=None
+    sync=None,
 ):
     print(
         f"Testing RMS_Norm on {torch_device} with y_shape:{y_shape} x_shape:{x_shape} w_shape:{w_shape}"
         f" y_stride:{y_stride} x_stride:{x_stride} w_dtype:{w_dtype} dtype:{dtype}"
     )
+    w_dtype = w_dtype if w_dtype else dtype
 
     y = torch.zeros(y_shape, dtype=dtype).to(torch_device)
     x = torch.rand(x_shape, dtype=dtype).to(torch_device)
     w = torch.rand(w_shape, dtype=w_dtype).to(torch_device)
@@ -93,10 +106,10 @@ def test(
         for tensor, stride in zip([x, y], [x_stride, y_stride])
     ]
     x_tensor, y_tensor, w_tensor = [to_tensor(tensor, lib) for tensor in [x, y, w]]
 
     if sync is not None:
         sync()
 
     descriptor = infiniopRMSNormDescriptor_t()
     check_error(
@@ -169,7 +182,7 @@ if __name__ == "__main__":
         POINTER(c_uint64),
     ]
-    lib.infiniopRMSNorm.restypes = c_int32
+    lib.infiniopRMSNorm.restype = c_int32
     lib.infiniopRMSNorm.argtypes = [
         infiniopRMSNormDescriptor_t,
         c_void_p,
......
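The reworked rms_norm test separates the weight dtype from the base cases: _TEST_CASES_ lists only shapes and strides, _WEIGHT_DTYPES is [None, torch.float32], and the cross product appends one weight dtype to every base case; None is later resolved to the activation dtype by w_dtype = w_dtype if w_dtype else dtype. A standalone sketch of that expansion with a single base case:

    import torch

    _TEST_CASES_ = [((16, 2048), (16, 2048), (2048,), None, None)]
    _WEIGHT_DTYPES = [None, torch.float32]

    _TEST_CASES = [
        case + (w_dtype,)
        for case in _TEST_CASES_
        for w_dtype in _WEIGHT_DTYPES
    ]
    # -> one case ending in None (weight follows the activation dtype)
    #    and one ending in torch.float32

    dtype = torch.bfloat16
    for *_, w_dtype in _TEST_CASES:
        w_dtype = w_dtype if w_dtype else dtype
        print(w_dtype)  # torch.bfloat16, then torch.float32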
@@ -35,11 +35,12 @@ _TEST_CASES_ = [
 ]
 
 # Data types used for testing
-_TENSOR_DTYPES = [torch.float16, torch.float32]
+_TENSOR_DTYPES = [torch.float16, torch.bfloat16, torch.float32]
 
 # Tolerance map for different data types
 _TOLERANCE_MAP = {
     torch.float16: {"atol": 1e-3, "rtol": 1e-2},
+    torch.bfloat16: {"atol": 5e-3, "rtol": 5e-2},
     torch.float32: {"atol": 1e-4, "rtol": 1e-3},
 }
@@ -117,7 +118,7 @@ def test(
     y_strides=None,
     inplace=Inplace.OUT_OF_PLACE,
     dtype=torch.float32,
-    sync=None
+    sync=None,
 ):
     if inplace == Inplace.INPLACE_X:
         y_strides = x_strides
@@ -189,7 +190,7 @@ def test(
     )
 
     lib_rope()
 
     if sync is not None:
         sync()
......
@@ -14,7 +14,7 @@ from libinfiniop import (
     debug,
     get_tolerance,
     profile_operation,
-    create_workspace
+    create_workspace,
 )
 from enum import Enum, auto
@@ -36,6 +36,7 @@ _TEST_CASES_ = [
     ((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1), (45056, 5632, 1)),
 ]
 
 class Inplace(Enum):
     OUT_OF_PLACE = auto()
     INPLACE_A = auto()
@@ -57,11 +58,12 @@ _TEST_CASES = [
 ]
 
 # Data types used for testing
-_TENSOR_DTYPES = [torch.float16, torch.float32]
+_TENSOR_DTYPES = [torch.float16, torch.bfloat16, torch.float32]
 
 # Tolerance map for different data types
 _TOLERANCE_MAP = {
     torch.float16: {"atol": 1e-3, "rtol": 1e-3},
+    torch.bfloat16: {"atol": 5e-3, "rtol": 5e-3},
     torch.float32: {"atol": 2e-7, "rtol": 1e-7},
 }
@@ -80,7 +82,6 @@ infiniopSwiGLUDescriptor_t = POINTER(SwiGLUDescriptor)
 def swiglu(a, b):
     return a * b / (1 + torch.exp(-b.float()).to(b.dtype))
 
 def process_tensors(c, c_strides, a, a_stride, b, b_stride, inplace):
@@ -171,10 +172,13 @@ def test(
     def lib_swiglu():
         check_error(
             lib.infiniopSwiGLU(
                 descriptor,
                 workspace.data_ptr() if workspace is not None else None,
                 workspace_size.value,
-                c_tensor.data, a_tensor.data, b_tensor.data, None
+                c_tensor.data,
+                a_tensor.data,
+                b_tensor.data,
+                None,
             )
         )
......
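One way to read the new torch.bfloat16: {"atol": 5e-3, "rtol": 5e-3} entry: with inputs drawn from [0, 1), each bf16 rounding step contributes roughly a 2^-9 relative error, and the handful of such steps in a * b * sigmoid(b) stays within that budget. A quick empirical check against the test's own swiglu reference (illustrative only, not part of the test suite):

    import torch

    def swiglu(a, b):
        return a * b / (1 + torch.exp(-b.float()).to(b.dtype))

    a = torch.rand(13, 4, 4, dtype=torch.bfloat16)
    b = torch.rand(13, 4, 4, dtype=torch.bfloat16)

    bf16_out = swiglu(a, b).float()
    fp32_out = swiglu(a.float(), b.float())

    print((bf16_out - fp32_out).abs().max().item())  # typically a few 1e-3 at most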