"...src/SimTKReference/ReferenceVelocityVerletDynamics.cpp" did not exist on "f14182c5dbe96961440f055fd15d3fc2811db9e0"
Commit 60dfab33 authored by zhanghj2's avatar zhanghj2
Browse files

float传bf16使用round_half_ulp_truncate

parent 68971b5c
...@@ -278,7 +278,6 @@ __forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tenso ...@@ -278,7 +278,6 @@ __forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tenso
*result_ptr = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data())); *result_ptr = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
} }
else if constexpr (std::is_same_v<To_type, cutlass::float_e4m3_t>) { else if constexpr (std::is_same_v<To_type, cutlass::float_e4m3_t>) {
cutlass::NumericArrayConverter<To_type, From_type, numel,cutlass::FloatRoundStyle::round_to_nearest> convert_op; cutlass::NumericArrayConverter<To_type, From_type, numel,cutlass::FloatRoundStyle::round_to_nearest> convert_op;
*result_ptr = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data())); *result_ptr = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
} }
...@@ -288,11 +287,10 @@ __forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tenso ...@@ -288,11 +287,10 @@ __forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tenso
} }
return tensor_To_type; return tensor_To_type;
} }
#else #else
{ {
if constexpr (std::is_same_v<To_type, cutlass::bfloat16_t>) { if constexpr (std::is_same_v<To_type, cutlass::bfloat16_t>) {
cutlass::NumericArrayConverter<To_type, From_type, numel, cutlass::FloatRoundStyle::round_toward_zero> convert_op; cutlass::NumericArrayConverter<To_type, From_type, numel, cutlass::FloatRoundStyle::round_half_ulp_truncate> convert_op;
*result_ptr = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data())); *result_ptr = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
} else { } else {
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op; cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
...@@ -300,8 +298,6 @@ __forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tenso ...@@ -300,8 +298,6 @@ __forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tenso
} }
return tensor_To_type; return tensor_To_type;
} }
#endif #endif
// cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op; // cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
// // HACK: this requires tensor to be "contiguous" // // HACK: this requires tensor to be "contiguous"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment