"...composable_kernel_rocm.git" did not exist on "85fc91c3218c1d85169ed1fe95eef7b07942e648"
Commit 1abe377b authored by Bartlomiej Kocot
Browse files

Stylistic minor fixes

parent 8354aad7
...@@ -262,11 +262,11 @@ int main(int argc, char* argv[]) ...@@ -262,11 +262,11 @@ int main(int argc, char* argv[])
AElementOp, AElementOp,
BElementOp>; BElementOp>;
auto ref_gemm = ReferenceOpInstance{}; auto ref_op = ReferenceOpInstance{};
auto ref_invoker = ref_gemm.MakeInvoker(); auto ref_invoker = ref_op.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument( auto ref_argument =
a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op); ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op);
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
......
...@@ -262,11 +262,11 @@ int main(int argc, char* argv[]) ...@@ -262,11 +262,11 @@ int main(int argc, char* argv[])
AElementOp, AElementOp,
BElementOp>; BElementOp>;
auto ref_gemm = ReferenceOpInstance{}; auto ref_op = ReferenceOpInstance{};
auto ref_invoker = ref_gemm.MakeInvoker(); auto ref_invoker = ref_op.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument( auto ref_argument =
a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op); ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op);
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
......
...@@ -245,12 +245,12 @@ int main(int argc, char* argv[]) ...@@ -245,12 +245,12 @@ int main(int argc, char* argv[])
AElementOp, AElementOp,
BElementOp>; BElementOp>;
auto ref_gemm = ReferenceOpInstance{}; auto ref_op = ReferenceOpInstance{};
auto ref_invoker = ref_gemm.MakeInvoker(); auto ref_invoker = ref_op.MakeInvoker();
Tensor<float> empty_tensor(std::vector<ck::index_t>{}, std::vector<ck::index_t>{}); Tensor<float> empty_tensor(std::vector<ck::index_t>{}, std::vector<ck::index_t>{});
auto ref_argument = ref_gemm.MakeArgument( auto ref_argument =
a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op); ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op);
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
......
...@@ -245,12 +245,12 @@ int main(int argc, char* argv[]) ...@@ -245,12 +245,12 @@ int main(int argc, char* argv[])
AElementOp, AElementOp,
BElementOp>; BElementOp>;
auto ref_gemm = ReferenceOpInstance{}; auto ref_op = ReferenceOpInstance{};
auto ref_invoker = ref_gemm.MakeInvoker(); auto ref_invoker = ref_op.MakeInvoker();
Tensor<float> empty_tensor(std::vector<ck::index_t>{}, std::vector<ck::index_t>{}); Tensor<float> empty_tensor(std::vector<ck::index_t>{}, std::vector<ck::index_t>{});
auto ref_argument = ref_gemm.MakeArgument( auto ref_argument =
a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op); ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op);
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
......
...@@ -11,8 +11,6 @@ ...@@ -11,8 +11,6 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
using Bilinear = ck::tensor_operation::element_wise::Bilinear;
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace host { namespace host {
......
...@@ -49,22 +49,22 @@ Best Perf: 1.42509 ms, 102.988 TFlops, 234.086 GB/s ...@@ -49,22 +49,22 @@ Best Perf: 1.42509 ms, 102.988 TFlops, 234.086 GB/s
## Profile contraction kernels ## Profile contraction kernels
```bash ```bash
#arg1: tensor operation (contraction=CONTRACTION) #arg1: tensor operation (contraction_bilinear=CONTRACTION+Bilinear)
#arg2: data type (0: fp32; 1: f64) #arg2: data type (0: fp32; 1: f64)
#arg3: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + D[m0, m1, n0, n1] = C[m0, m1, n0, n1]; #arg3: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1];
# 1: A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = C[m0, m1, n0, n1]; # 1: A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1];
# 2: A[k0, k1, m0, m1] * B[k0, k1, n0, n1] + D[m0, m1, n0, n1] = C[m0, m1, n0, n1]; # 2: A[k0, k1, m0, m1] * B[k0, k1, n0, n1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1];
# 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = C[m0, m1, n0, n1]) # 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1])
#arg4: verification (0: no; 1: yes) #arg4: verification (0: no; 1: yes)
#arg5: initialization (0: no init; 1: integer value; 2: decimal value) #arg5: initialization (0: no init; 1: integer value; 2: decimal value)
#arg6: print tensor value (0: no; 1: yes) #arg6: print tensor value (0: no; 1: yes)
#arg7: time kernel (0: no, 1: yes) #arg7: time kernel (0: no, 1: yes)
#arg8 and arg9(optional): alpha and beta for bilinear (pass only alpha for scale) #arg8 and arg9: alpha and beta
#arg9/10 to 14/15: M0, M1, N0, N1, K0, K1 #arg10 to 15: M0, M1, N0, N1, K0, K1
#arg15/16 to 30/31: Strides for A, B, C and D (skip for default) #arg16 to 31: Strides for A, B, D and E (skip for default)
################ op datatype layout verify init log time alpha beta M0 M1 N0 N1 K0 K1 ################ op datatype layout verify init log time alpha beta M0 M1 N0 N1 K0 K1
./bin/ckProfiler contraction 0 1 0 0 0 1 1.0 1.0 128 128 128 128 128 128 ./bin/ckProfiler contraction_bilinear 0 1 0 0 0 1 1.0 1.0 128 128 128 128 128 128
``` ```
Result (MI100) Result (MI100)
...@@ -72,7 +72,7 @@ Result (MI100) ...@@ -72,7 +72,7 @@ Result (MI100)
a_m_k: dim 4, lengths {128, 128, 128, 128}, strides {2097152, 16384, 128, 1} a_m_k: dim 4, lengths {128, 128, 128, 128}, strides {2097152, 16384, 128, 1}
b_k_n: dim 4, lengths {128, 128, 128, 128}, strides {128, 1, 2097152, 16384} b_k_n: dim 4, lengths {128, 128, 128, 128}, strides {128, 1, 2097152, 16384}
d_m_n: dim 4, lengths {128, 128, 128, 128}, strides {2097152, 16384, 128, 1} d_m_n: dim 4, lengths {128, 128, 128, 128}, strides {2097152, 16384, 128, 1}
c_m_n: dim 4, lengths {128, 128, 128, 128}, strides {2097152, 16384, 128, 1} e_m_n: dim 4, lengths {128, 128, 128, 128}, strides {2097152, 16384, 128, 1}
.... ....
Best Perf: 211.405 ms, 41.6077 TFlops, 15.2372 GB/s Best Perf: 211.405 ms, 41.6077 TFlops, 15.2372 GB/s
``` ```
...@@ -33,7 +33,7 @@ using Scale = ck::tensor_operation::element_wise::Scale; ...@@ -33,7 +33,7 @@ using Scale = ck::tensor_operation::element_wise::Scale;
template <typename ALayout, template <typename ALayout,
typename BLayout, typename BLayout,
typename CDLayout, typename CDELayout,
typename DataType, typename DataType,
typename DTupleDataType, typename DTupleDataType,
typename CDElementOp> typename CDElementOp>
...@@ -47,7 +47,7 @@ int profile_contraction_impl(ck::index_t do_verification, ...@@ -47,7 +47,7 @@ int profile_contraction_impl(ck::index_t do_verification,
const std::vector<ck::index_t>& K, const std::vector<ck::index_t>& K,
const std::vector<ck::index_t>& StridesA, const std::vector<ck::index_t>& StridesA,
const std::vector<ck::index_t>& StridesB, const std::vector<ck::index_t>& StridesB,
const std::vector<ck::index_t>& StridesC, const std::vector<ck::index_t>& StridesE,
const std::vector<ck::index_t>& StridesD) const std::vector<ck::index_t>& StridesD)
{ {
bool pass = true; bool pass = true;
...@@ -64,8 +64,8 @@ int profile_contraction_impl(ck::index_t do_verification, ...@@ -64,8 +64,8 @@ int profile_contraction_impl(ck::index_t do_verification,
Tensor<DataType> a_m_k(f_host_tensor_descriptor(M, K, StridesA)); Tensor<DataType> a_m_k(f_host_tensor_descriptor(M, K, StridesA));
Tensor<DataType> b_k_n(f_host_tensor_descriptor(K, N, StridesB)); Tensor<DataType> b_k_n(f_host_tensor_descriptor(K, N, StridesB));
Tensor<DataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StridesC)); Tensor<DataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StridesE));
Tensor<DataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StridesC)); Tensor<DataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StridesE));
Tensor<DataType> d_m_n(f_host_tensor_descriptor(M, N, StridesD)); Tensor<DataType> d_m_n(f_host_tensor_descriptor(M, N, StridesD));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
...@@ -100,10 +100,10 @@ int profile_contraction_impl(ck::index_t do_verification, ...@@ -100,10 +100,10 @@ int profile_contraction_impl(ck::index_t do_verification,
e_device_buf.SetZero(); e_device_buf.SetZero();
d_device_buf.ToDevice(d_m_n.mData.data()); d_device_buf.ToDevice(d_m_n.mData.data());
const std::vector<index_t> a_m_k_lengths = {M[0], M[1], K[0], K[1]}; const std::vector<index_t> a_ms_ks_lengths = {M[0], M[1], K[0], K[1]};
const std::vector<index_t> b_n_k_lengths = {N[0], N[1], K[0], K[1]}; const std::vector<index_t> b_ns_ks_lengths = {N[0], N[1], K[0], K[1]};
const std::vector<index_t> c_m_n_lengths = {M[0], M[1], N[0], N[1]}; const std::vector<index_t> e_ms_ns_lengths = {M[0], M[1], N[0], N[1]};
const std::vector<index_t> d_m_n_lengths = {M[0], M[1], N[0], N[1]}; const std::vector<index_t> d_m_n_lengths = {M[0], M[1], N[0], N[1]};
const auto a_element_op = AElementOp{}; const auto a_element_op = AElementOp{};
const auto b_element_op = BElementOp{}; const auto b_element_op = BElementOp{};
...@@ -143,7 +143,7 @@ int profile_contraction_impl(ck::index_t do_verification, ...@@ -143,7 +143,7 @@ int profile_contraction_impl(ck::index_t do_verification,
auto ref_op = ReferenceGemmInstance{}; auto ref_op = ReferenceGemmInstance{};
auto ref_invoker = ref_op.MakeInvoker(); auto ref_invoker = ref_op.MakeInvoker();
Tensor<DataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StridesC)); Tensor<DataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StridesE));
auto ref_argument = auto ref_argument =
ref_op.MakeArgument(a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op); ref_op.MakeArgument(a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op);
...@@ -169,6 +169,10 @@ int profile_contraction_impl(ck::index_t do_verification, ...@@ -169,6 +169,10 @@ int profile_contraction_impl(ck::index_t do_verification,
cde_element_op(e_m_n_host_result(m0, m1, n0, n1), cde_element_op(e_m_n_host_result(m0, m1, n0, n1),
c_m_n_host_result(m0, m1, n0, n1)); c_m_n_host_result(m0, m1, n0, n1));
} }
else
{
static_assert("Unsupported CDElementOp in contraction profiler.");
}
} }
} }
} }
...@@ -191,37 +195,41 @@ int profile_contraction_impl(ck::index_t do_verification, ...@@ -191,37 +195,41 @@ int profile_contraction_impl(ck::index_t do_verification,
static_cast<DataType*>(b_device_buf.GetDeviceBuffer()), static_cast<DataType*>(b_device_buf.GetDeviceBuffer()),
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()}, std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
static_cast<DataType*>(e_device_buf.GetDeviceBuffer()), static_cast<DataType*>(e_device_buf.GetDeviceBuffer()),
a_m_k_lengths, a_ms_ks_lengths,
StridesA, StridesA,
b_n_k_lengths, b_ns_ks_lengths,
StridesB, StridesB,
std::array<std::vector<ck::index_t>, 1>{d_m_n_lengths}, std::array<std::vector<ck::index_t>, 1>{d_m_n_lengths},
std::array<std::vector<ck::index_t>, 1>{StridesD}, std::array<std::vector<ck::index_t>, 1>{StridesD},
c_m_n_lengths, e_ms_ns_lengths,
StridesC, StridesE,
a_element_op, a_element_op,
b_element_op, b_element_op,
cde_element_op); cde_element_op);
} }
else else if constexpr(is_same<CDElementOp, Scale>::value)
{ {
argument_ptr = argument_ptr =
op_ptr->MakeArgumentPointer(static_cast<DataType*>(a_device_buf.GetDeviceBuffer()), op_ptr->MakeArgumentPointer(static_cast<DataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(b_device_buf.GetDeviceBuffer()), static_cast<DataType*>(b_device_buf.GetDeviceBuffer()),
std::array<const void*, 0>{}, std::array<const void*, 0>{},
static_cast<DataType*>(e_device_buf.GetDeviceBuffer()), static_cast<DataType*>(e_device_buf.GetDeviceBuffer()),
a_m_k_lengths, a_ms_ks_lengths,
StridesA, StridesA,
b_n_k_lengths, b_ns_ks_lengths,
StridesB, StridesB,
std::array<std::vector<ck::index_t>, 0>{}, std::array<std::vector<ck::index_t>, 0>{},
std::array<std::vector<ck::index_t>, 0>{}, std::array<std::vector<ck::index_t>, 0>{},
c_m_n_lengths, e_ms_ns_lengths,
StridesC, StridesE,
a_element_op, a_element_op,
b_element_op, b_element_op,
cde_element_op); cde_element_op);
} }
else
{
static_assert("Unsupported CDElementOp in contraction profiler.");
}
auto invoker_ptr = op_ptr->MakeInvokerPointer(); auto invoker_ptr = op_ptr->MakeInvokerPointer();
...@@ -316,8 +324,17 @@ int profile_contraction_impl(ck::index_t do_verification, ...@@ -316,8 +324,17 @@ int profile_contraction_impl(ck::index_t do_verification,
std::cout << " BLayout = ColumnMajor"; std::cout << " BLayout = ColumnMajor";
} }
if constexpr(is_same<CDELayout, tensor_layout::gemm::RowMajor>::value)
{
std::cout << " CDELayout = RowMajor";
}
else if constexpr(is_same<CDELayout, tensor_layout::gemm::ColumnMajor>::value)
{
std::cout << " CDELayout = ColumnMajor";
}
std::cout << " M = " << M << " N = " << N << " K = " << K << " StridesA = " << StridesA std::cout << " M = " << M << " N = " << N << " K = " << K << " StridesA = " << StridesA
<< " StridesB = " << StridesB << " StridesC = " << StridesC << " : " << best_avg_time << " StridesB = " << StridesB << " StridesE = " << StridesE << " : " << best_avg_time
<< " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " GB/s, " << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " GB/s, "
<< best_op_name << std::endl; << best_op_name << std::endl;
......
...@@ -13,6 +13,20 @@ using Col = ck::tensor_layout::gemm::ColumnMajor; ...@@ -13,6 +13,20 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using Bilinear = ck::tensor_operation::element_wise::Bilinear; using Bilinear = ck::tensor_operation::element_wise::Bilinear;
using Scale = ck::tensor_operation::element_wise::Scale; using Scale = ck::tensor_operation::element_wise::Scale;
enum struct ContractionMatrixLayout
{
MK_KN_MN_MN, // 0
MK_NK_MN_MN, // 1
KM_KN_MN_MN, // 2
KM_NK_MN_MN, // 3
};
enum struct ContractionDataType
{
F32_F32_F32_F32, // 0
F64_F64_F64_F64, // 1
};
inline void collect_index_params(char* argv[], inline void collect_index_params(char* argv[],
std::vector<ck::index_t>& params, std::vector<ck::index_t>& params,
const ck::index_t from, const ck::index_t from,
......
...@@ -11,20 +11,6 @@ ...@@ -11,20 +11,6 @@
#include "profiler/profile_contraction_utils.hpp" #include "profiler/profile_contraction_utils.hpp"
#include "profiler_operation_registry.hpp" #include "profiler_operation_registry.hpp"
enum struct ContractionMatrixLayout
{
MK_KN_MN_MN, // 0
MK_NK_MN_MN, // 1
KM_KN_MN_MN, // 2
KM_NK_MN_MN, // 3
};
enum struct ContractionDataType
{
F32_F32_F32_F32, // 0
F64_F64_F64_F64, // 1
};
#define OP_NAME "contraction_bilinear" #define OP_NAME "contraction_bilinear"
#define OP_DESC "CONTRACTION+Bilinear" #define OP_DESC "CONTRACTION+Bilinear"
...@@ -33,13 +19,13 @@ static void print_helper_msg() ...@@ -33,13 +19,13 @@ static void print_helper_msg()
std::cout << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" std::cout << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"
<< "arg2: data type (0: fp32; 1: f64)\n" << "arg2: data type (0: fp32; 1: f64)\n"
<< "arg3: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + " << "arg3: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + "
"D[m0, m1, n0, n1] = C[m0, m1, n0, n1];\n" "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n"
<< " 1: A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + " << " 1: A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + "
"D[m0, m1, n0, n1] = C[m0, m1, n0, n1];\n" "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n"
<< " 2: A[k0, k1, m0, m1] * B[k0, k1, n0, n1] + " << " 2: A[k0, k1, m0, m1] * B[k0, k1, n0, n1] + "
"D[m0, m1, n0, n1] = C[m0, m1, n0, n1];\n" "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n"
<< " 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + " << " 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + "
"D[m0, m1, n0, n1] = C[m0, m1, n0, n1])\n" "D[m0, m1, n0, n1] = E[m0, m1, n0, n1])\n"
<< "arg4: verification (0: no; 1: yes)\n" << "arg4: verification (0: no; 1: yes)\n"
<< "arg5: initialization (0: no init; 1: integer value; 2: decimal " << "arg5: initialization (0: no init; 1: integer value; 2: decimal "
<< "value)\n" << "value)\n"
...@@ -47,7 +33,7 @@ static void print_helper_msg() ...@@ -47,7 +33,7 @@ static void print_helper_msg()
<< "arg7: time kernel (0: no, 1: yes)\n" << "arg7: time kernel (0: no, 1: yes)\n"
<< "arg8 and arg9: alpha and beta\n" << "arg8 and arg9: alpha and beta\n"
<< "arg10 to 15: M0, M1, N0, N1, K0, K1\n" << "arg10 to 15: M0, M1, N0, N1, K0, K1\n"
<< "arg16 to 31: Strides for A, B, C and D (skip for default)\n" << "arg16 to 31: Strides for A, B, D and E (skip for default)\n"
<< std::endl; << std::endl;
} }
...@@ -80,23 +66,23 @@ int profile_contraction_bilinear(int argc, char* argv[]) ...@@ -80,23 +66,23 @@ int profile_contraction_bilinear(int argc, char* argv[])
std::vector<ck::index_t> StridesA; std::vector<ck::index_t> StridesA;
std::vector<ck::index_t> StridesB; std::vector<ck::index_t> StridesB;
std::vector<ck::index_t> StridesC; std::vector<ck::index_t> StridesE;
std::vector<ck::index_t> StridesD; std::vector<ck::index_t> StridesD;
if(!default_strides) if(!default_strides)
{ {
collect_index_params(argv, StridesA, dims_arg_num + 6, 4); collect_index_params(argv, StridesA, dims_arg_num + 6, 4);
collect_index_params(argv, StridesB, dims_arg_num + 10, 4); collect_index_params(argv, StridesB, dims_arg_num + 10, 4);
collect_index_params(argv, StridesC, dims_arg_num + 14, 4); collect_index_params(argv, StridesE, dims_arg_num + 14, 4);
collect_index_params(argv, StridesD, dims_arg_num + 18, 4); collect_index_params(argv, StridesD, dims_arg_num + 18, 4);
} }
using F32 = float; using F32 = float;
using F64 = double; using F64 = double;
auto profile = [&](auto a_layout, auto b_layout, auto cd_layout, auto type) { auto profile = [&](auto a_layout, auto b_layout, auto cde_layout, auto type) {
using ALayout = decltype(a_layout); using ALayout = decltype(a_layout);
using BLayout = decltype(b_layout); using BLayout = decltype(b_layout);
using CDLayout = decltype(cd_layout); using CDELayout = decltype(cde_layout);
using DataType = decltype(type); using DataType = decltype(type);
...@@ -104,12 +90,12 @@ int profile_contraction_bilinear(int argc, char* argv[]) ...@@ -104,12 +90,12 @@ int profile_contraction_bilinear(int argc, char* argv[])
{ {
assign_default_strides(a_layout, StridesA, {M[0], M[1], K[0], K[1]}); assign_default_strides(a_layout, StridesA, {M[0], M[1], K[0], K[1]});
assign_default_strides(b_layout, StridesB, {K[0], K[1], N[0], N[1]}); assign_default_strides(b_layout, StridesB, {K[0], K[1], N[0], N[1]});
assign_default_strides(cd_layout, StridesC, {M[0], M[1], N[0], N[1]}); assign_default_strides(cde_layout, StridesE, {M[0], M[1], N[0], N[1]});
assign_default_strides(cd_layout, StridesD, {M[0], M[1], N[0], N[1]}); assign_default_strides(cde_layout, StridesD, {M[0], M[1], N[0], N[1]});
} }
bool pass = ck::profiler::profile_contraction_impl<ALayout, bool pass = ck::profiler::profile_contraction_impl<ALayout,
BLayout, BLayout,
CDLayout, CDELayout,
DataType, DataType,
ck::Tuple<DataType>, ck::Tuple<DataType>,
Bilinear>(do_verification, Bilinear>(do_verification,
...@@ -122,7 +108,7 @@ int profile_contraction_bilinear(int argc, char* argv[]) ...@@ -122,7 +108,7 @@ int profile_contraction_bilinear(int argc, char* argv[])
K, K,
StridesA, StridesA,
StridesB, StridesB,
StridesC, StridesE,
StridesD); StridesD);
return pass; return pass;
......
...@@ -11,20 +11,6 @@ ...@@ -11,20 +11,6 @@
#include "profiler/profile_contraction_utils.hpp" #include "profiler/profile_contraction_utils.hpp"
#include "profiler_operation_registry.hpp" #include "profiler_operation_registry.hpp"
enum struct ContractionMatrixLayout
{
MK_KN_MN_MN, // 0
MK_NK_MN_MN, // 1
KM_KN_MN_MN, // 2
KM_NK_MN_MN, // 3
};
enum struct ContractionDataType
{
F32_F32_F32_F32, // 0
F64_F64_F64_F64, // 1
};
#define OP_NAME "contraction_scale" #define OP_NAME "contraction_scale"
#define OP_DESC "CONTRACTION+Scale" #define OP_DESC "CONTRACTION+Scale"
...@@ -33,13 +19,13 @@ static void print_helper_msg() ...@@ -33,13 +19,13 @@ static void print_helper_msg()
std::cout << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" std::cout << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"
<< "arg2: data type (0: fp32; 1: f64)\n" << "arg2: data type (0: fp32; 1: f64)\n"
<< "arg3: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + " << "arg3: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + "
"D[m0, m1, n0, n1] = C[m0, m1, n0, n1];\n" "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n"
<< " 1: A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + " << " 1: A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + "
"D[m0, m1, n0, n1] = C[m0, m1, n0, n1];\n" "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n"
<< " 2: A[k0, k1, m0, m1] * B[k0, k1, n0, n1] + " << " 2: A[k0, k1, m0, m1] * B[k0, k1, n0, n1] + "
"D[m0, m1, n0, n1] = C[m0, m1, n0, n1];\n" "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n"
<< " 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + " << " 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + "
"D[m0, m1, n0, n1] = C[m0, m1, n0, n1])\n" "D[m0, m1, n0, n1] = E[m0, m1, n0, n1])\n"
<< "arg4: verification (0: no; 1: yes)\n" << "arg4: verification (0: no; 1: yes)\n"
<< "arg5: initialization (0: no init; 1: integer value; 2: decimal " << "arg5: initialization (0: no init; 1: integer value; 2: decimal "
<< "value)\n" << "value)\n"
...@@ -47,7 +33,7 @@ static void print_helper_msg() ...@@ -47,7 +33,7 @@ static void print_helper_msg()
<< "arg7: time kernel (0: no, 1: yes)\n" << "arg7: time kernel (0: no, 1: yes)\n"
<< "arg8: alpha\n" << "arg8: alpha\n"
<< "arg9 to 14: M0, M1, N0, N1, K0, K1\n" << "arg9 to 14: M0, M1, N0, N1, K0, K1\n"
<< "arg15 to 30: Strides for A, B, C and D (skip for default)\n" << "arg15 to 30: Strides for A, B, D and E (skip for default)\n"
<< std::endl; << std::endl;
} }
...@@ -79,23 +65,23 @@ int profile_contraction_scale(int argc, char* argv[]) ...@@ -79,23 +65,23 @@ int profile_contraction_scale(int argc, char* argv[])
std::vector<ck::index_t> StridesA; std::vector<ck::index_t> StridesA;
std::vector<ck::index_t> StridesB; std::vector<ck::index_t> StridesB;
std::vector<ck::index_t> StridesC; std::vector<ck::index_t> StridesE;
std::vector<ck::index_t> StridesD; std::vector<ck::index_t> StridesD;
if(!default_strides) if(!default_strides)
{ {
collect_index_params(argv, StridesA, dims_arg_num + 6, 4); collect_index_params(argv, StridesA, dims_arg_num + 6, 4);
collect_index_params(argv, StridesB, dims_arg_num + 10, 4); collect_index_params(argv, StridesB, dims_arg_num + 10, 4);
collect_index_params(argv, StridesC, dims_arg_num + 14, 4); collect_index_params(argv, StridesE, dims_arg_num + 14, 4);
collect_index_params(argv, StridesD, dims_arg_num + 18, 4); collect_index_params(argv, StridesD, dims_arg_num + 18, 4);
} }
using F32 = float; using F32 = float;
using F64 = double; using F64 = double;
auto profile = [&](auto a_layout, auto b_layout, auto cd_layout, auto type) { auto profile = [&](auto a_layout, auto b_layout, auto cde_layout, auto type) {
using ALayout = decltype(a_layout); using ALayout = decltype(a_layout);
using BLayout = decltype(b_layout); using BLayout = decltype(b_layout);
using CDLayout = decltype(cd_layout); using CDELayout = decltype(cde_layout);
using DataType = decltype(type); using DataType = decltype(type);
...@@ -103,12 +89,12 @@ int profile_contraction_scale(int argc, char* argv[]) ...@@ -103,12 +89,12 @@ int profile_contraction_scale(int argc, char* argv[])
{ {
assign_default_strides(a_layout, StridesA, {M[0], M[1], K[0], K[1]}); assign_default_strides(a_layout, StridesA, {M[0], M[1], K[0], K[1]});
assign_default_strides(b_layout, StridesB, {K[0], K[1], N[0], N[1]}); assign_default_strides(b_layout, StridesB, {K[0], K[1], N[0], N[1]});
assign_default_strides(cd_layout, StridesC, {M[0], M[1], N[0], N[1]}); assign_default_strides(cde_layout, StridesE, {M[0], M[1], N[0], N[1]});
assign_default_strides(cd_layout, StridesD, {M[0], M[1], N[0], N[1]}); assign_default_strides(cde_layout, StridesD, {M[0], M[1], N[0], N[1]});
} }
bool pass = ck::profiler:: bool pass = ck::profiler::
profile_contraction_impl<ALayout, BLayout, CDLayout, DataType, ck::Tuple<>, Scale>( profile_contraction_impl<ALayout, BLayout, CDELayout, DataType, ck::Tuple<>, Scale>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -119,7 +105,7 @@ int profile_contraction_scale(int argc, char* argv[]) ...@@ -119,7 +105,7 @@ int profile_contraction_scale(int argc, char* argv[])
K, K,
StridesA, StridesA,
StridesB, StridesB,
StridesC, StridesE,
StridesD); StridesD);
return pass; return pass;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment