Commit 60dfab33 authored by zhanghj2
Browse files

float转bf16使用round_half_ulp_truncate

parent 68971b5c
...@@ -276,23 +276,21 @@ __forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tenso ...@@ -276,23 +276,21 @@ __forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tenso
if constexpr (std::is_same_v<To_type, cutlass::bfloat16_t>) { if constexpr (std::is_same_v<To_type, cutlass::bfloat16_t>) {
cutlass::NumericArrayConverter<To_type, From_type, numel, cutlass::FloatRoundStyle::round_to_nearest> convert_op; cutlass::NumericArrayConverter<To_type, From_type, numel, cutlass::FloatRoundStyle::round_to_nearest> convert_op;
*result_ptr = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data())); *result_ptr = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
} }
else if constexpr (std::is_same_v<To_type, cutlass::float_e4m3_t>) { else if constexpr (std::is_same_v<To_type, cutlass::float_e4m3_t>) {
cutlass::NumericArrayConverter<To_type, From_type, numel,cutlass::FloatRoundStyle::round_to_nearest> convert_op;
cutlass::NumericArrayConverter<To_type, From_type, numel,cutlass::FloatRoundStyle::round_to_nearest> convert_op; *result_ptr = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
*result_ptr = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data())); }
} else {
else { cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op; *result_ptr = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
*result_ptr = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data())); }
} return tensor_To_type;
return tensor_To_type;
} }
#else #else
{ {
if constexpr (std::is_same_v<To_type, cutlass::bfloat16_t>) { if constexpr (std::is_same_v<To_type, cutlass::bfloat16_t>) {
cutlass::NumericArrayConverter<To_type, From_type, numel, cutlass::FloatRoundStyle::round_toward_zero> convert_op; cutlass::NumericArrayConverter<To_type, From_type, numel, cutlass::FloatRoundStyle::round_half_ulp_truncate> convert_op;
*result_ptr = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data())); *result_ptr = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
} else { } else {
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op; cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
...@@ -300,8 +298,6 @@ __forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tenso ...@@ -300,8 +298,6 @@ __forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tenso
} }
return tensor_To_type; return tensor_To_type;
} }
#endif #endif
// cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op; // cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
// // HACK: this requires tensor to be "contiguous" // // HACK: this requires tensor to be "contiguous"
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment