Commit 60dfab33 authored by zhanghj2
Browse files

float传bf16使用round_half_ulp_truncate (float-to-bf16 conversion now uses round_half_ulp_truncate instead of round_toward_zero)

parent 68971b5c
......@@ -276,23 +276,21 @@ __forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tenso
if constexpr (std::is_same_v<To_type, cutlass::bfloat16_t>) {
cutlass::NumericArrayConverter<To_type, From_type, numel, cutlass::FloatRoundStyle::round_to_nearest> convert_op;
*result_ptr = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
}
else if constexpr (std::is_same_v<To_type, cutlass::float_e4m3_t>) {
cutlass::NumericArrayConverter<To_type, From_type, numel,cutlass::FloatRoundStyle::round_to_nearest> convert_op;
*result_ptr = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
}
else {
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
*result_ptr = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
}
return tensor_To_type;
}
else if constexpr (std::is_same_v<To_type, cutlass::float_e4m3_t>) {
cutlass::NumericArrayConverter<To_type, From_type, numel,cutlass::FloatRoundStyle::round_to_nearest> convert_op;
*result_ptr = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
}
else {
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
*result_ptr = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
}
return tensor_To_type;
}
#else
{
if constexpr (std::is_same_v<To_type, cutlass::bfloat16_t>) {
cutlass::NumericArrayConverter<To_type, From_type, numel, cutlass::FloatRoundStyle::round_toward_zero> convert_op;
cutlass::NumericArrayConverter<To_type, From_type, numel, cutlass::FloatRoundStyle::round_half_ulp_truncate> convert_op;
*result_ptr = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
} else {
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
......@@ -300,8 +298,6 @@ __forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tenso
}
return tensor_To_type;
}
#endif
// cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
// // HACK: this requires tensor to be "contiguous"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment