Commit a8747955 authored by Bartlomiej Wroblewski's avatar Bartlomiej Wroblewski
Browse files

Reference contraction: Fix incorrect order of B matrix dimensions

parent a1d9285f
...@@ -32,12 +32,12 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base ...@@ -32,12 +32,12 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
struct Argument : public ck::tensor_operation::device::BaseArgument struct Argument : public ck::tensor_operation::device::BaseArgument
{ {
Argument(const Tensor<ADataType>& a_ms_ks, Argument(const Tensor<ADataType>& a_ms_ks,
const Tensor<BDataType>& b_ns_ks, const Tensor<BDataType>& b_ks_ns,
Tensor<CDataType>& c_ms_ns, Tensor<CDataType>& c_ms_ns,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op) BElementwiseOperation b_element_op)
: a_ms_ks_{a_ms_ks}, : a_ms_ks_{a_ms_ks},
b_ns_ks_{b_ns_ks}, b_ks_ns_{b_ks_ns},
c_ms_ns_{c_ms_ns}, c_ms_ns_{c_ms_ns},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
b_element_op_{b_element_op} b_element_op_{b_element_op}
...@@ -45,7 +45,7 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base ...@@ -45,7 +45,7 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
} }
const Tensor<ADataType>& a_ms_ks_; const Tensor<ADataType>& a_ms_ks_;
const Tensor<BDataType>& b_ns_ks_; const Tensor<BDataType>& b_ks_ns_;
Tensor<CDataType>& c_ms_ns_; Tensor<CDataType>& c_ms_ns_;
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
...@@ -75,7 +75,7 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base ...@@ -75,7 +75,7 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
arg.a_element_op_( arg.a_element_op_(
v_a, ck::type_convert<AccDataType>(arg.a_ms_ks_(m0, m1, k0, k1))); v_a, ck::type_convert<AccDataType>(arg.a_ms_ks_(m0, m1, k0, k1)));
arg.b_element_op_( arg.b_element_op_(
v_b, ck::type_convert<AccDataType>(arg.b_ns_ks_(n0, n1, k0, k1))); v_b, ck::type_convert<AccDataType>(arg.b_ks_ns_(k0, k1, n0, n1)));
v_acc += v_a * v_b; v_acc += v_a * v_b;
} }
...@@ -113,12 +113,12 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base ...@@ -113,12 +113,12 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
} }
static auto MakeArgument(const Tensor<ADataType>& a_ms_ks, static auto MakeArgument(const Tensor<ADataType>& a_ms_ks,
const Tensor<BDataType>& b_ns_ks, const Tensor<BDataType>& b_ks_ns,
Tensor<CDataType>& c_ms_ns, Tensor<CDataType>& c_ms_ns,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op) BElementwiseOperation b_element_op)
{ {
return Argument{a_ms_ks, b_ns_ks, c_ms_ns, a_element_op, b_element_op}; return Argument{a_ms_ks, b_ks_ns, c_ms_ns, a_element_op, b_element_op};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment