Commit a8747955 authored by Bartlomiej Wroblewski's avatar Bartlomiej Wroblewski
Browse files

Reference contraction: Fix incorrect order of B matrix dimensions

parent a1d9285f
......@@ -32,12 +32,12 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
struct Argument : public ck::tensor_operation::device::BaseArgument
{
Argument(const Tensor<ADataType>& a_ms_ks,
const Tensor<BDataType>& b_ns_ks,
const Tensor<BDataType>& b_ks_ns,
Tensor<CDataType>& c_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op)
: a_ms_ks_{a_ms_ks},
b_ns_ks_{b_ns_ks},
b_ks_ns_{b_ks_ns},
c_ms_ns_{c_ms_ns},
a_element_op_{a_element_op},
b_element_op_{b_element_op}
......@@ -45,7 +45,7 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
}
const Tensor<ADataType>& a_ms_ks_;
const Tensor<BDataType>& b_ns_ks_;
const Tensor<BDataType>& b_ks_ns_;
Tensor<CDataType>& c_ms_ns_;
AElementwiseOperation a_element_op_;
......@@ -75,7 +75,7 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
arg.a_element_op_(
v_a, ck::type_convert<AccDataType>(arg.a_ms_ks_(m0, m1, k0, k1)));
arg.b_element_op_(
v_b, ck::type_convert<AccDataType>(arg.b_ns_ks_(n0, n1, k0, k1)));
v_b, ck::type_convert<AccDataType>(arg.b_ks_ns_(k0, k1, n0, n1)));
v_acc += v_a * v_b;
}
......@@ -113,12 +113,12 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
}
static auto MakeArgument(const Tensor<ADataType>& a_ms_ks,
const Tensor<BDataType>& b_ns_ks,
const Tensor<BDataType>& b_ks_ns,
Tensor<CDataType>& c_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op)
{
return Argument{a_ms_ks, b_ns_ks, c_ms_ns, a_element_op, b_element_op};
return Argument{a_ms_ks, b_ks_ns, c_ms_ns, a_element_op, b_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment