Commit 8038fad9 authored by Rostyslav Geyyer's avatar Rostyslav Geyyer
Browse files

Update reference calculation

parent 7cf5d8f7
...@@ -21,7 +21,8 @@ template <typename ADataType, ...@@ -21,7 +21,8 @@ template <typename ADataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
typename ComputType = ADataType> typename ComputeTypeA = ADataType,
typename ComputeTypeB = ComputeTypeA>
struct ReferenceGemm : public device::BaseOperator struct ReferenceGemm : public device::BaseOperator
{ {
// Argument // Argument
...@@ -65,8 +66,8 @@ struct ReferenceGemm : public device::BaseOperator ...@@ -65,8 +66,8 @@ struct ReferenceGemm : public device::BaseOperator
for(int k = 0; k < K; ++k) for(int k = 0; k < K; ++k)
{ {
ComputType v_a; ComputeTypeA v_a;
ComputType v_b; ComputeTypeB v_b;
// use PassThrough instead of ConvertBF16RTN for reference calculation // use PassThrough instead of ConvertBF16RTN for reference calculation
if constexpr(is_same_v<AElementwiseOperation, if constexpr(is_same_v<AElementwiseOperation,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment