Commit b95deabb authored by Rostyslav Geyyer's avatar Rostyslav Geyyer
Browse files

Format

parent 1f45a250
...@@ -99,12 +99,11 @@ struct ConvertBF16RTN ...@@ -99,12 +99,11 @@ struct ConvertBF16RTN
__host__ __device__ void operator()(Y& y, const X& x) const __host__ __device__ void operator()(Y& y, const X& x) const
{ {
// check Y datatype // check Y datatype
static_assert(is_same<Y, bhalf_t>::value, static_assert(is_same<Y, bhalf_t>::value, "Data type is not supported by this operation!");
"Data type is not supported by this operation!");
// check X datatype // check X datatype
static_assert(is_same<X, float>::value || is_same<X, half_t>::value, static_assert(is_same<X, float>::value || is_same<X, half_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = bf16_convert_rtn<Y>(x); y = bf16_convert_rtn<Y>(x);
} }
......
...@@ -67,7 +67,8 @@ struct ReferenceGemm : public device::BaseOperator ...@@ -67,7 +67,8 @@ struct ReferenceGemm : public device::BaseOperator
BDataType v_b; BDataType v_b;
// use PassThrough instead of ConvertBF16RTN for reference calculation // use PassThrough instead of ConvertBF16RTN for reference calculation
if constexpr(is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::ConvertBF16RTN>) if constexpr(is_same_v<AElementwiseOperation,
ck::tensor_operation::element_wise::ConvertBF16RTN>)
{ {
ck::tensor_operation::element_wise::PassThrough{}(v_a, arg.a_m_k_(m, k)); ck::tensor_operation::element_wise::PassThrough{}(v_a, arg.a_m_k_(m, k));
} }
...@@ -76,7 +77,8 @@ struct ReferenceGemm : public device::BaseOperator ...@@ -76,7 +77,8 @@ struct ReferenceGemm : public device::BaseOperator
arg.a_element_op_(v_a, arg.a_m_k_(m, k)); arg.a_element_op_(v_a, arg.a_m_k_(m, k));
} }
// same for B matrix // same for B matrix
if constexpr(is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::ConvertBF16RTN>) if constexpr(is_same_v<BElementwiseOperation,
ck::tensor_operation::element_wise::ConvertBF16RTN>)
{ {
ck::tensor_operation::element_wise::PassThrough{}(v_b, arg.b_k_n_(k, n)); ck::tensor_operation::element_wise::PassThrough{}(v_b, arg.b_k_n_(k, n));
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment