Commit 1f45a250 authored by Rostyslav Geyyer's avatar Rostyslav Geyyer
Browse files

Use a PassThrough instead of ConvertBF16RTN to calculate reference

parent 82414279
...@@ -94,28 +94,19 @@ struct UnaryConvert ...@@ -94,28 +94,19 @@ struct UnaryConvert
struct ConvertBF16RTN struct ConvertBF16RTN
{ {
// convert to bf16 using round to nearest (rtn)
template <typename Y, typename X> template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const; __host__ __device__ void operator()(Y& y, const X& x) const
// convert fp16->bf16 using rounding to nearest (rtn) via fp32
template <>
__host__ __device__ void operator()<bhalf_t, half_t>(bhalf_t& y, const half_t& x) const
{ {
y = bf16_convert_rtn<bhalf_t>(x); // check Y datatype
} static_assert(is_same<Y, bhalf_t>::value,
"Data type is not supported by this operation!");
// convert fp32->bf16 using rounding to nearest (rtn) // check X datatype
template <> static_assert(is_same<X, float>::value || is_same<X, half_t>::value,
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const "Data type is not supported by this operation!");
{
y = bf16_convert_rtn<bhalf_t>(x);
}
// need to keep this specialization for fp16->fp16 ops y = bf16_convert_rtn<Y>(x);
template <>
__host__ __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
{
y = x;
} }
}; };
......
...@@ -66,8 +66,24 @@ struct ReferenceGemm : public device::BaseOperator ...@@ -66,8 +66,24 @@ struct ReferenceGemm : public device::BaseOperator
ADataType v_a; ADataType v_a;
BDataType v_b; BDataType v_b;
// use PassThrough instead of ConvertBF16RTN for reference calculation
if constexpr(is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::ConvertBF16RTN>)
{
ck::tensor_operation::element_wise::PassThrough{}(v_a, arg.a_m_k_(m, k));
}
else
{
arg.a_element_op_(v_a, arg.a_m_k_(m, k)); arg.a_element_op_(v_a, arg.a_m_k_(m, k));
}
// same for B matrix
if constexpr(is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::ConvertBF16RTN>)
{
ck::tensor_operation::element_wise::PassThrough{}(v_b, arg.b_k_n_(k, n));
}
else
{
arg.b_element_op_(v_b, arg.b_k_n_(k, n)); arg.b_element_op_(v_b, arg.b_k_n_(k, n));
}
v_acc += v_acc +=
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b); ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment