"git@developer.sourcefind.cn:gaoqiong/composable_kernel.git" did not exist on "73480fee3635310aedbbec68b6084c94cfd2457d"
Commit 762d30bf authored by Bartlomiej Wroblewski

Fix the order of B matrix dimensions across examples and profiler

parent 367a6f4b
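For reference, the sketch below illustrates the convention this commit adopts: the B tensor of the contraction is now described in [N0, N1, K0, K1] order instead of [K0, K1, N0, N1], and the strides passed to the examples and the profiler follow the same order. The helper names (`make_b_lengths`, `make_packed_strides`) are hypothetical and only illustrate the ordering; they are not part of the library API.

```cpp
// Hypothetical illustration of the B-dimension ordering change (not library code).
#include <array>
#include <cstdint>

using index_t = std::int32_t;

// Old convention: B described as b[k0][k1][n0][n1] -> lengths {K0, K1, N0, N1}
// New convention: B described as b[n0][n1][k0][k1] -> lengths {N0, N1, K0, K1}
std::array<index_t, 4> make_b_lengths(index_t N0, index_t N1, index_t K0, index_t K1)
{
    return {N0, N1, K0, K1}; // order now matches b_ns_ks / StridesB in the profiler
}

// Packed strides for a given length order, innermost dimension last.
std::array<index_t, 4> make_packed_strides(const std::array<index_t, 4>& lengths)
{
    std::array<index_t, 4> strides{};
    index_t running = 1;
    for(int d = 3; d >= 0; --d)
    {
        strides[d] = running;
        running *= lengths[d];
    }
    return strides;
}
```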
@@ -33,12 +33,12 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
     struct Argument : public ck::tensor_operation::device::BaseArgument
     {
         Argument(const Tensor<ADataType>& a_ms_ks,
-                 const Tensor<BDataType>& b_ks_ns,
+                 const Tensor<BDataType>& b_ns_ks,
                  Tensor<CDataType>& c_ms_ns,
                  AElementwiseOperation a_element_op,
                  BElementwiseOperation b_element_op)
             : a_ms_ks_{a_ms_ks},
-              b_ks_ns_{b_ks_ns},
+              b_ns_ks_{b_ns_ks},
               c_ms_ns_{c_ms_ns},
               a_element_op_{a_element_op},
               b_element_op_{b_element_op}
@@ -46,7 +46,7 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
         }
 
         const Tensor<ADataType>& a_ms_ks_;
-        const Tensor<BDataType>& b_ks_ns_;
+        const Tensor<BDataType>& b_ns_ks_;
         Tensor<CDataType>& c_ms_ns_;
         AElementwiseOperation a_element_op_;
@@ -75,7 +75,7 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
                 ComputeDataType v_a_compute_input =
                     ck::type_convert<ComputeDataType>(arg.a_ms_ks_(m0, m1, k0, k1));
                 ComputeDataType v_b_compute_input =
-                    ck::type_convert<ComputeDataType>(arg.b_ks_ns_(k0, k1, n0, n1));
+                    ck::type_convert<ComputeDataType>(arg.b_ns_ks_(n0, n1, k0, k1));
 
                 AccDataType v_a;
                 AccDataType v_b;
@@ -119,12 +119,12 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
     }
 
     static auto MakeArgument(const Tensor<ADataType>& a_ms_ks,
-                             const Tensor<BDataType>& b_ks_ns,
+                             const Tensor<BDataType>& b_ns_ks,
                              Tensor<CDataType>& c_ms_ns,
                              AElementwiseOperation a_element_op,
                              BElementwiseOperation b_element_op)
     {
-        return Argument{a_ms_ks, b_ks_ns, c_ms_ns, a_element_op, b_element_op};
+        return Argument{a_ms_ks, b_ns_ks, c_ms_ns, a_element_op, b_element_op};
     }
 
     static auto MakeInvoker() { return Invoker{}; }
...
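With this change the reference kernel reads B with its N dimensions first. A minimal standalone sketch of that accumulation loop, assuming plain std::vector storage instead of the library's Tensor class, could look like this:

```cpp
// Minimal sketch of the reference contraction with the new B ordering.
// Assumes packed std::vector storage; indices follow
// a[m0][m1][k0][k1], b[n0][n1][k0][k1], c[m0][m1][n0][n1].
#include <vector>

void reference_contraction(const std::vector<float>& a,
                           const std::vector<float>& b,
                           std::vector<float>& c,
                           int M0, int M1, int N0, int N1, int K0, int K1)
{
    auto a_at = [&](int m0, int m1, int k0, int k1) {
        return a[((m0 * M1 + m1) * K0 + k0) * K1 + k1];
    };
    // B is indexed as (n0, n1, k0, k1), matching b_ns_ks_ above.
    auto b_at = [&](int n0, int n1, int k0, int k1) {
        return b[((n0 * N1 + n1) * K0 + k0) * K1 + k1];
    };

    for(int m0 = 0; m0 < M0; ++m0)
        for(int m1 = 0; m1 < M1; ++m1)
            for(int n0 = 0; n0 < N0; ++n0)
                for(int n1 = 0; n1 < N1; ++n1)
                {
                    float acc = 0.f;
                    for(int k0 = 0; k0 < K0; ++k0)
                        for(int k1 = 0; k1 < K1; ++k1)
                            acc += a_at(m0, m1, k0, k1) * b_at(n0, n1, k0, k1);
                    c[((m0 * M1 + m1) * N0 + n0) * N1 + n1] = acc;
                }
}
```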
@@ -50,7 +50,7 @@ int profile_contraction_impl(ck::index_t do_verification,
                              const std::vector<ck::index_t>& N,
                              const std::vector<ck::index_t>& K,
                              const std::vector<ck::index_t>& StridesA, // [M0, M1, K0, K1]
-                             const std::vector<ck::index_t>& StridesB, // [K0, K1, N0, N1]
+                             const std::vector<ck::index_t>& StridesB, // [N0, N1, K0, K1]
                              const std::vector<ck::index_t>& StridesE, // [M0, M1, N0, N1]
                              const std::vector<ck::index_t>& StridesD) // [M0, M1, N0, N1]
 {
@@ -67,13 +67,13 @@ int profile_contraction_impl(ck::index_t do_verification,
     };
 
     Tensor<DataType> a_m_k(f_host_tensor_descriptor(M, K, StridesA));
-    Tensor<DataType> b_k_n(f_host_tensor_descriptor(K, N, StridesB));
+    Tensor<DataType> b_n_k(f_host_tensor_descriptor(N, K, StridesB));
     Tensor<DataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StridesE));
     Tensor<DataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StridesE));
     Tensor<DataType> d_m_n(f_host_tensor_descriptor(M, N, StridesD));
 
     std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
-    std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
+    std::cout << "b_n_k: " << b_n_k.mDesc << std::endl;
     std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
     std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl;
@@ -82,12 +82,12 @@ int profile_contraction_impl(ck::index_t do_verification,
     case 0: break;
     case 1:
         a_m_k.GenerateTensorValue(GeneratorTensor_2<DataType>{-5, 5});
-        b_k_n.GenerateTensorValue(GeneratorTensor_2<DataType>{-5, 5});
+        b_n_k.GenerateTensorValue(GeneratorTensor_2<DataType>{-5, 5});
         d_m_n.GenerateTensorValue(GeneratorTensor_2<DataType>{-5, 5});
         break;
     default:
         a_m_k.GenerateTensorValue(GeneratorTensor_3<DataType>{0.0, 1.0});
-        b_k_n.GenerateTensorValue(GeneratorTensor_3<DataType>{-0.5, 0.5});
+        b_n_k.GenerateTensorValue(GeneratorTensor_3<DataType>{-0.5, 0.5});
         d_m_n.GenerateTensorValue(GeneratorTensor_3<DataType>{-0.5, 0.5});
     }
@@ -95,12 +95,12 @@ int profile_contraction_impl(ck::index_t do_verification,
     using BElementOp = ck::tensor_operation::element_wise::PassThrough;
 
     DeviceMem a_device_buf(sizeof(DataType) * a_m_k.mDesc.GetElementSpaceSize());
-    DeviceMem b_device_buf(sizeof(DataType) * b_k_n.mDesc.GetElementSpaceSize());
+    DeviceMem b_device_buf(sizeof(DataType) * b_n_k.mDesc.GetElementSpaceSize());
     DeviceMem e_device_buf(sizeof(DataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
     DeviceMem d_device_buf(sizeof(DataType) * d_m_n.mDesc.GetElementSpaceSize());
 
     a_device_buf.ToDevice(a_m_k.mData.data());
-    b_device_buf.ToDevice(b_k_n.mData.data());
+    b_device_buf.ToDevice(b_n_k.mData.data());
     e_device_buf.SetZero();
     d_device_buf.ToDevice(d_m_n.mData.data());
@@ -109,9 +109,9 @@ int profile_contraction_impl(ck::index_t do_verification,
     const std::vector<index_t> e_ms_ns_lengths = {M[0], M[1], N[0], N[1]};
     const std::vector<index_t> d_m_n_lengths   = {M[0], M[1], N[0], N[1]};
 
-    // The order of dims in StridesB is [K0, K1, N0, N1] so need to change it to [N0, N1, K0, K1]
-    const std::vector<index_t> b_ns_ks_strides = {
-        StridesB[2], StridesB[3], StridesB[0], StridesB[1]};
+    // // The order of dims in StridesB is [K0, K1, N0, N1] so need to change it to [N0, N1, K0, K1]
+    // const std::vector<index_t> b_ns_ks_strides = {
+    //     StridesB[2], StridesB[3], StridesB[0], StridesB[1]};
 
     const auto a_element_op = AElementOp{};
     const auto b_element_op = BElementOp{};
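The block commented out above is where the profiler used to permute StridesB from [K0, K1, N0, N1] into [N0, N1, K0, K1] before handing it to the device op; with StridesB now supplied already in [N0, N1, K0, K1] order it is passed through unchanged at the two call sites below. A hedged sketch of the old permutation, written as a standalone helper rather than the profiler's actual code:

```cpp
// Sketch of the stride permutation that is no longer needed (illustrative only).
#include <vector>

using index_t = int;

// Old behaviour: StridesB arrived as [K0, K1, N0, N1] and had to be reordered.
std::vector<index_t> permute_b_strides_old(const std::vector<index_t>& strides_b_k_n)
{
    // [K0, K1, N0, N1] -> [N0, N1, K0, K1]
    return {strides_b_k_n[2], strides_b_k_n[3], strides_b_k_n[0], strides_b_k_n[1]};
}

// New behaviour: StridesB already comes in [N0, N1, K0, K1] order, so it is
// forwarded to the device operation as-is.
std::vector<index_t> b_strides_new(const std::vector<index_t>& strides_b_n_k)
{
    return strides_b_n_k;
}
```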
@@ -159,7 +159,7 @@ int profile_contraction_impl(ck::index_t do_verification,
         Tensor<DataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StridesE));
 
         auto ref_argument =
-            ref_op.MakeArgument(a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op);
+            ref_op.MakeArgument(a_m_k, b_n_k, c_m_n_host_result, a_element_op, b_element_op);
 
         ref_invoker.Run(ref_argument);
@@ -211,7 +211,7 @@ int profile_contraction_impl(ck::index_t do_verification,
             a_ms_ks_lengths,
             StridesA,
             b_ns_ks_lengths,
-            b_ns_ks_strides,
+            StridesB,
             std::array<std::vector<ck::index_t>, 1>{d_m_n_lengths},
             std::array<std::vector<ck::index_t>, 1>{StridesD},
             e_ms_ns_lengths,
@@ -230,7 +230,7 @@ int profile_contraction_impl(ck::index_t do_verification,
             a_ms_ks_lengths,
             StridesA,
             b_ns_ks_lengths,
-            b_ns_ks_strides,
+            StridesB,
             std::array<std::vector<ck::index_t>, 0>{},
             std::array<std::vector<ck::index_t>, 0>{},
             e_ms_ns_lengths,
@@ -309,7 +309,7 @@ int profile_contraction_impl(ck::index_t do_verification,
         if(do_log)
         {
             LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl;
-            LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
+            LogRangeAsType<float>(std::cout << "b: ", b_n_k.mData, ",") << std::endl;
             LogRangeAsType<float>(std::cout << "c_host : ", e_m_n_host_result.mData, ",")
                 << std::endl;
             LogRangeAsType<float>(std::cout << "c_device: ", e_m_n_device_result.mData, ",")
...
@@ -96,7 +96,7 @@ int profile_contraction_bilinear(int argc, char* argv[])
     if(default_strides)
     {
         assign_default_strides(a_layout, StridesA, {M[0], M[1], K[0], K[1]});
-        assign_default_strides(b_layout, StridesB, {K[0], K[1], N[0], N[1]});
+        assign_default_strides(b_layout, StridesB, {N[0], N[1], K[0], K[1]});
         assign_default_strides(cde_layout, StridesE, {M[0], M[1], N[0], N[1]});
         assign_default_strides(cde_layout, StridesD, {M[0], M[1], N[0], N[1]});
     }
...
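Callers that let the profiler pick default strides now pass the B lengths as {N[0], N[1], K[0], K[1]} as well. A minimal sketch of what packed strides for that length order look like, assuming a simple innermost-last packing (the real assign_default_strides also honours the requested layout, which is not modelled here):

```cpp
// Illustrative computation of packed strides for lengths given in
// [N0, N1, K0, K1] order; layout handling of assign_default_strides is omitted.
#include <vector>

std::vector<int> packed_strides(const std::vector<int>& lengths)
{
    std::vector<int> strides(lengths.size(), 1);
    for(int d = static_cast<int>(lengths.size()) - 2; d >= 0; --d)
        strides[d] = strides[d + 1] * lengths[d + 1];
    return strides;
}

// Example: lengths {N0, N1, K0, K1} = {8, 4, 16, 2} -> strides {128, 32, 2, 1},
// i.e. element (n0, n1, k0, k1) lives at offset n0*128 + n1*32 + k0*2 + k1.
```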
@@ -94,7 +94,7 @@ int profile_contraction_scale(int argc, char* argv[])
     if(default_strides)
     {
         assign_default_strides(a_layout, StridesA, {M[0], M[1], K[0], K[1]});
-        assign_default_strides(b_layout, StridesB, {K[0], K[1], N[0], N[1]});
+        assign_default_strides(b_layout, StridesB, {N[0], N[1], K[0], K[1]});
         assign_default_strides(cde_layout, StridesE, {M[0], M[1], N[0], N[1]});
         assign_default_strides(cde_layout, StridesD, {M[0], M[1], N[0], N[1]});
     }
...
@@ -62,7 +62,7 @@ class TestContraction : public ::testing::Test
         const auto& K = dimension_params.K;
 
         assign_default_strides(ALayout{}, StridesA, {M[0], M[1], K[0], K[1]});
-        assign_default_strides(BLayout{}, StridesB, {K[0], K[1], N[0], N[1]});
+        assign_default_strides(BLayout{}, StridesB, {N[0], N[1], K[0], K[1]});
         assign_default_strides(CDLayout{}, StridesC, {M[0], M[1], N[0], N[1]});
         assign_default_strides(CDLayout{}, StridesD, {M[0], M[1], N[0], N[1]});
...