Unverified Commit 70eebf22 authored by zjing14's avatar zjing14 Committed by GitHub
Browse files

Merge branch 'develop' into grouped_gemm_multi_abd_fixed_nk_example

parents 5608328c 98fd41f5
...@@ -17,8 +17,9 @@ ...@@ -17,8 +17,9 @@
static void print_helper_msg() static void print_helper_msg()
{ {
std::cout << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" std::cout << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"
<< "arg2: data type (0: fp32; 1: f64)\n" << "arg2: data type (0: fp32; 1: f64; 2: f16; 3: bf16)\n"
<< "arg3: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + " << "arg3: compute data type (0: fp32; 1: f64; 2: f16; 3: bf16)\n"
<< "arg4: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + "
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n" "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n"
<< " 1: A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + " << " 1: A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + "
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n" "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n"
...@@ -26,40 +27,42 @@ static void print_helper_msg() ...@@ -26,40 +27,42 @@ static void print_helper_msg()
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n" "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n"
<< " 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + " << " 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + "
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1])\n" "D[m0, m1, n0, n1] = E[m0, m1, n0, n1])\n"
<< "arg4: verification (0: no; 1: yes)\n" << "arg5: verification (0: no; 1: yes)\n"
<< "arg5: initialization (0: no init; 1: integer value; 2: decimal " << "arg6: initialization (0: no init; 1: integer value; 2: decimal "
<< "value)\n" << "value)\n"
<< "arg6: print tensor value (0: no; 1: yes)\n" << "arg7: print tensor value (0: no; 1: yes)\n"
<< "arg7: time kernel (0: no, 1: yes)\n" << "arg8: time kernel (0: no, 1: yes)\n"
<< "arg8 and arg9: alpha and beta\n" << "arg9: alpha\n"
<< "arg10 to 15: M0, M1, N0, N1, K0, K1\n" << "arg10: beta\n"
<< "arg16 to 31: Strides for A, B, D and E (skip for default)\n" << "arg11 to 16: M0, M1, N0, N1, K0, K1\n"
<< "arg17 to 32: Strides for A, B, D and E (skip for default)\n"
<< std::endl; << std::endl;
} }
int profile_contraction_bilinear(int argc, char* argv[]) int profile_contraction_bilinear(int argc, char* argv[])
{ {
const bool default_strides = argc == 16; const bool default_strides = argc == 17;
if(argc != 32 && argc != 16) if(argc != 33 && argc != 17)
{ {
print_helper_msg(); print_helper_msg();
exit(1); exit(1);
} }
const auto data_type = static_cast<ContractionDataType>(std::stoi(argv[2])); const auto data_type = static_cast<ContractionDataType>(std::stoi(argv[2]));
const auto layout = static_cast<ContractionMatrixLayout>(std::stoi(argv[3])); const auto compute_data_type = static_cast<ContractionComputeDataType>(std::stoi(argv[3]));
const bool do_verification = std::stoi(argv[4]); const auto layout = static_cast<ContractionMatrixLayout>(std::stoi(argv[4]));
const ck::index_t init_method = std::stoi(argv[5]); const bool do_verification = std::stoi(argv[5]);
const bool do_log = std::stoi(argv[6]); const ck::index_t init_method = std::stoi(argv[6]);
const bool time_kernel = std::stoi(argv[7]); const bool do_log = std::stoi(argv[7]);
const float alpha = std::stof(argv[8]); const bool time_kernel = std::stoi(argv[8]);
const float beta = std::stof(argv[9]); const float alpha = std::stof(argv[9]);
const float beta = std::stof(argv[10]);
std::vector<ck::index_t> M; std::vector<ck::index_t> M;
std::vector<ck::index_t> N; std::vector<ck::index_t> N;
std::vector<ck::index_t> K; std::vector<ck::index_t> K;
const ck::index_t dims_arg_num = 10; const ck::index_t dims_arg_num = 11;
collect_index_params(argv, M, dims_arg_num, 2); collect_index_params(argv, M, dims_arg_num, 2);
collect_index_params(argv, N, dims_arg_num + 2, 2); collect_index_params(argv, N, dims_arg_num + 2, 2);
collect_index_params(argv, K, dims_arg_num + 4, 2); collect_index_params(argv, K, dims_arg_num + 4, 2);
...@@ -76,90 +79,130 @@ int profile_contraction_bilinear(int argc, char* argv[]) ...@@ -76,90 +79,130 @@ int profile_contraction_bilinear(int argc, char* argv[])
collect_index_params(argv, StridesD, dims_arg_num + 18, 4); collect_index_params(argv, StridesD, dims_arg_num + 18, 4);
} }
using F32 = float; using F16 = ck::half_t;
using F64 = double; using BF16 = ck::bhalf_t;
using F32 = float;
auto profile = [&](auto a_layout, auto b_layout, auto cde_layout, auto type) { using F64 = double;
using ALayout = decltype(a_layout);
using BLayout = decltype(b_layout); auto profile =
using CDELayout = decltype(cde_layout); [&](auto a_layout, auto b_layout, auto cde_layout, auto type, auto compute_type) {
using ALayout = decltype(a_layout);
using DataType = decltype(type); using BLayout = decltype(b_layout);
using CDELayout = decltype(cde_layout);
if(default_strides)
using DataType = decltype(type);
using ComputeDataType = decltype(compute_type);
if(default_strides)
{
assign_default_strides(a_layout, StridesA, {M[0], M[1], K[0], K[1]});
assign_default_strides(b_layout, StridesB, {N[0], N[1], K[0], K[1]});
assign_default_strides(cde_layout, StridesE, {M[0], M[1], N[0], N[1]});
assign_default_strides(cde_layout, StridesD, {M[0], M[1], N[0], N[1]});
}
bool pass = ck::profiler::profile_contraction_impl<ALayout,
BLayout,
CDELayout,
DataType,
ComputeDataType,
ck::Tuple<DataType>,
Bilinear>(do_verification,
init_method,
do_log,
time_kernel,
Bilinear{alpha, beta},
M,
N,
K,
StridesA,
StridesB,
StridesE,
StridesD);
return pass;
};
auto run_profile_for_datatype = [&](auto type, auto compute_type) {
if(layout == ContractionMatrixLayout::MK_KN_MN_MN)
{ {
assign_default_strides(a_layout, StridesA, {M[0], M[1], K[0], K[1]}); return profile(Row{}, Row{}, Row{}, type, compute_type);
assign_default_strides(b_layout, StridesB, {K[0], K[1], N[0], N[1]});
assign_default_strides(cde_layout, StridesE, {M[0], M[1], N[0], N[1]});
assign_default_strides(cde_layout, StridesD, {M[0], M[1], N[0], N[1]});
} }
bool pass = ck::profiler::profile_contraction_impl<ALayout, else if(layout == ContractionMatrixLayout::MK_NK_MN_MN)
BLayout, {
CDELayout, return profile(Row{}, Col{}, Row{}, type, compute_type);
DataType, }
ck::Tuple<DataType>, else if(layout == ContractionMatrixLayout::KM_KN_MN_MN)
Bilinear>(do_verification, {
init_method, return profile(Col{}, Row{}, Row{}, type, compute_type);
do_log, }
time_kernel, else if(layout == ContractionMatrixLayout::KM_NK_MN_MN)
Bilinear{alpha, beta}, {
M, return profile(Col{}, Col{}, Row{}, type, compute_type);
N, }
K, return false;
StridesA,
StridesB,
StridesE,
StridesD);
return pass;
}; };
if(data_type == ContractionDataType::F32_F32_F32_F32 && if(data_type == ContractionDataType::F32_F32_F32_F32)
layout == ContractionMatrixLayout::MK_KN_MN_MN)
{
return profile(Row{}, Row{}, Row{}, F32{});
}
else if(data_type == ContractionDataType::F32_F32_F32_F32 &&
layout == ContractionMatrixLayout::MK_NK_MN_MN)
{ {
return profile(Row{}, Col{}, Row{}, F32{}); if(compute_data_type == ContractionComputeDataType::F32)
} {
else if(data_type == ContractionDataType::F32_F32_F32_F32 && return run_profile_for_datatype(F32{}, F32{});
layout == ContractionMatrixLayout::KM_KN_MN_MN) }
{ else if(compute_data_type == ContractionComputeDataType::F16)
return profile(Col{}, Row{}, Row{}, F32{}); {
} return run_profile_for_datatype(F32{}, F16{});
else if(data_type == ContractionDataType::F32_F32_F32_F32 && }
layout == ContractionMatrixLayout::KM_NK_MN_MN) else if(compute_data_type == ContractionComputeDataType::BF16)
{ {
return profile(Col{}, Col{}, Row{}, F32{}); return run_profile_for_datatype(F32{}, BF16{});
} }
else if(data_type == ContractionDataType::F64_F64_F64_F64 && else
layout == ContractionMatrixLayout::MK_KN_MN_MN) {
{ std::cout << "Incorrect combination of data type and compute data type." << std::endl;
return profile(Row{}, Row{}, Row{}, F64{}); return 1;
} }
else if(data_type == ContractionDataType::F64_F64_F64_F64 &&
layout == ContractionMatrixLayout::MK_NK_MN_MN)
{
return profile(Row{}, Col{}, Row{}, F64{});
} }
else if(data_type == ContractionDataType::F64_F64_F64_F64 && else if(data_type == ContractionDataType::F64_F64_F64_F64)
layout == ContractionMatrixLayout::KM_KN_MN_MN)
{ {
return profile(Col{}, Row{}, Row{}, F64{}); if(compute_data_type == ContractionComputeDataType::F64)
{
return run_profile_for_datatype(F64{}, F64{});
}
else if(compute_data_type == ContractionComputeDataType::F32)
{
return run_profile_for_datatype(F64{}, F32{});
}
else
{
std::cout << "Incorrect combination of data type and compute data type." << std::endl;
return 1;
}
} }
else if(data_type == ContractionDataType::F64_F64_F64_F64 && else if(data_type == ContractionDataType::F16_F16_F16_F16)
layout == ContractionMatrixLayout::KM_NK_MN_MN)
{ {
return profile(Col{}, Col{}, Row{}, F64{}); if(compute_data_type == ContractionComputeDataType::F32)
{
return run_profile_for_datatype(F16{}, F32{});
}
else
{
std::cout << "Incorrect combination of data type and compute data type." << std::endl;
return 1;
}
} }
else else if(data_type == ContractionDataType::BF16_BF16_BF16_BF16)
{ {
std::cout << "this data_type & layout is not implemented" << std::endl; if(compute_data_type == ContractionComputeDataType::F32)
{
return 1; return run_profile_for_datatype(BF16{}, F32{});
}
else
{
std::cout << "Incorrect combination of data type and compute data type." << std::endl;
return 1;
}
} }
return 1;
} }
REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_contraction_bilinear); REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_contraction_bilinear);
...@@ -17,8 +17,9 @@ ...@@ -17,8 +17,9 @@
static void print_helper_msg() static void print_helper_msg()
{ {
std::cout << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" std::cout << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"
<< "arg2: data type (0: fp32; 1: f64)\n" << "arg2: data type (0: fp32; 1: f64; 2: f16; 3: bf16)\n"
<< "arg3: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + " << "arg3: compute data type (0: fp32; 1: f64; 2: f16; 3: bf16)\n"
<< "arg4: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + "
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n" "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n"
<< " 1: A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + " << " 1: A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + "
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n" "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n"
...@@ -26,39 +27,40 @@ static void print_helper_msg() ...@@ -26,39 +27,40 @@ static void print_helper_msg()
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n" "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n"
<< " 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + " << " 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + "
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1])\n" "D[m0, m1, n0, n1] = E[m0, m1, n0, n1])\n"
<< "arg4: verification (0: no; 1: yes)\n" << "arg5: verification (0: no; 1: yes)\n"
<< "arg5: initialization (0: no init; 1: integer value; 2: decimal " << "arg6: initialization (0: no init; 1: integer value; 2: decimal "
<< "value)\n" << "value)\n"
<< "arg6: print tensor value (0: no; 1: yes)\n" << "arg7: print tensor value (0: no; 1: yes)\n"
<< "arg7: time kernel (0: no, 1: yes)\n" << "arg8: time kernel (0: no, 1: yes)\n"
<< "arg8: alpha\n" << "arg9: alpha\n"
<< "arg9 to 14: M0, M1, N0, N1, K0, K1\n" << "arg10 to 15: M0, M1, N0, N1, K0, K1\n"
<< "arg15 to 30: Strides for A, B, D and E (skip for default)\n" << "arg16 to 31: Strides for A, B, D and E (skip for default)\n"
<< std::endl; << std::endl;
} }
int profile_contraction_scale(int argc, char* argv[]) int profile_contraction_scale(int argc, char* argv[])
{ {
const bool default_strides = argc == 15; const bool default_strides = argc == 16;
if(argc != 31 && argc != 15) if(argc != 32 && argc != 16)
{ {
print_helper_msg(); print_helper_msg();
exit(1); exit(1);
} }
const auto data_type = static_cast<ContractionDataType>(std::stoi(argv[2])); const auto data_type = static_cast<ContractionDataType>(std::stoi(argv[2]));
const auto layout = static_cast<ContractionMatrixLayout>(std::stoi(argv[3])); const auto compute_data_type = static_cast<ContractionComputeDataType>(std::stoi(argv[3]));
const bool do_verification = std::stoi(argv[4]); const auto layout = static_cast<ContractionMatrixLayout>(std::stoi(argv[4]));
const ck::index_t init_method = std::stoi(argv[5]); const bool do_verification = std::stoi(argv[5]);
const bool do_log = std::stoi(argv[6]); const ck::index_t init_method = std::stoi(argv[6]);
const bool time_kernel = std::stoi(argv[7]); const bool do_log = std::stoi(argv[7]);
const float alpha = std::stof(argv[8]); const bool time_kernel = std::stoi(argv[8]);
const float alpha = std::stof(argv[9]);
std::vector<ck::index_t> M; std::vector<ck::index_t> M;
std::vector<ck::index_t> N; std::vector<ck::index_t> N;
std::vector<ck::index_t> K; std::vector<ck::index_t> K;
const ck::index_t dims_arg_num = 9; const ck::index_t dims_arg_num = 10;
collect_index_params(argv, M, dims_arg_num, 2); collect_index_params(argv, M, dims_arg_num, 2);
collect_index_params(argv, N, dims_arg_num + 2, 2); collect_index_params(argv, N, dims_arg_num + 2, 2);
collect_index_params(argv, K, dims_arg_num + 4, 2); collect_index_params(argv, K, dims_arg_num + 4, 2);
...@@ -75,88 +77,131 @@ int profile_contraction_scale(int argc, char* argv[]) ...@@ -75,88 +77,131 @@ int profile_contraction_scale(int argc, char* argv[])
collect_index_params(argv, StridesD, dims_arg_num + 18, 4); collect_index_params(argv, StridesD, dims_arg_num + 18, 4);
} }
using F32 = float; using F16 = ck::half_t;
using F64 = double; using BF16 = ck::bhalf_t;
using F32 = float;
auto profile = [&](auto a_layout, auto b_layout, auto cde_layout, auto type) { using F64 = double;
using ALayout = decltype(a_layout);
using BLayout = decltype(b_layout); auto profile =
using CDELayout = decltype(cde_layout); [&](auto a_layout, auto b_layout, auto cde_layout, auto type, auto compute_type) {
using ALayout = decltype(a_layout);
using DataType = decltype(type); using BLayout = decltype(b_layout);
using CDELayout = decltype(cde_layout);
if(default_strides)
using DataType = decltype(type);
using ComputeDataType = decltype(compute_type);
if(default_strides)
{
assign_default_strides(a_layout, StridesA, {M[0], M[1], K[0], K[1]});
assign_default_strides(b_layout, StridesB, {N[0], N[1], K[0], K[1]});
assign_default_strides(cde_layout, StridesE, {M[0], M[1], N[0], N[1]});
assign_default_strides(cde_layout, StridesD, {M[0], M[1], N[0], N[1]});
}
bool pass = ck::profiler::profile_contraction_impl<ALayout,
BLayout,
CDELayout,
DataType,
ComputeDataType,
ck::Tuple<>,
Scale>(do_verification,
init_method,
do_log,
time_kernel,
Scale{alpha},
M,
N,
K,
StridesA,
StridesB,
StridesE,
StridesD);
return pass;
};
auto run_profile_for_datatype = [&](auto type, auto compute_type) {
if(layout == ContractionMatrixLayout::MK_KN_MN_MN)
{ {
assign_default_strides(a_layout, StridesA, {M[0], M[1], K[0], K[1]}); return profile(Row{}, Row{}, Row{}, type, compute_type);
assign_default_strides(b_layout, StridesB, {K[0], K[1], N[0], N[1]});
assign_default_strides(cde_layout, StridesE, {M[0], M[1], N[0], N[1]});
assign_default_strides(cde_layout, StridesD, {M[0], M[1], N[0], N[1]});
} }
else if(layout == ContractionMatrixLayout::MK_NK_MN_MN)
bool pass = ck::profiler:: {
profile_contraction_impl<ALayout, BLayout, CDELayout, DataType, ck::Tuple<>, Scale>( return profile(Row{}, Col{}, Row{}, type, compute_type);
do_verification, }
init_method, else if(layout == ContractionMatrixLayout::KM_KN_MN_MN)
do_log, {
time_kernel, return profile(Col{}, Row{}, Row{}, type, compute_type);
Scale{alpha}, }
M, else if(layout == ContractionMatrixLayout::KM_NK_MN_MN)
N, {
K, return profile(Col{}, Col{}, Row{}, type, compute_type);
StridesA, }
StridesB, return false;
StridesE,
StridesD);
return pass;
}; };
if(data_type == ContractionDataType::F32_F32_F32_F32 && if(data_type == ContractionDataType::F32_F32_F32_F32)
layout == ContractionMatrixLayout::MK_KN_MN_MN)
{
return profile(Row{}, Row{}, Row{}, F32{});
}
else if(data_type == ContractionDataType::F32_F32_F32_F32 &&
layout == ContractionMatrixLayout::MK_NK_MN_MN)
{
return profile(Row{}, Col{}, Row{}, F32{});
}
else if(data_type == ContractionDataType::F32_F32_F32_F32 &&
layout == ContractionMatrixLayout::KM_KN_MN_MN)
{ {
return profile(Col{}, Row{}, Row{}, F32{}); if(compute_data_type == ContractionComputeDataType::F32)
} {
else if(data_type == ContractionDataType::F32_F32_F32_F32 && return run_profile_for_datatype(F32{}, F32{});
layout == ContractionMatrixLayout::KM_NK_MN_MN) }
{ else if(compute_data_type == ContractionComputeDataType::F16)
return profile(Col{}, Col{}, Row{}, F32{}); {
} return run_profile_for_datatype(F32{}, F16{});
else if(data_type == ContractionDataType::F64_F64_F64_F64 && }
layout == ContractionMatrixLayout::MK_KN_MN_MN) else if(compute_data_type == ContractionComputeDataType::BF16)
{ {
return profile(Row{}, Row{}, Row{}, F64{}); return run_profile_for_datatype(F32{}, BF16{});
} }
else if(data_type == ContractionDataType::F64_F64_F64_F64 && else
layout == ContractionMatrixLayout::MK_NK_MN_MN) {
{ std::cout << "Incorrect combination of data type and compute data type." << std::endl;
return profile(Row{}, Col{}, Row{}, F64{}); return 1;
}
} }
else if(data_type == ContractionDataType::F64_F64_F64_F64 && else if(data_type == ContractionDataType::F64_F64_F64_F64)
layout == ContractionMatrixLayout::KM_KN_MN_MN)
{ {
return profile(Col{}, Row{}, Row{}, F64{}); if(compute_data_type == ContractionComputeDataType::F64)
{
return run_profile_for_datatype(F64{}, F64{});
}
else if(compute_data_type == ContractionComputeDataType::F32)
{
return run_profile_for_datatype(F64{}, F32{});
}
else
{
std::cout << "Incorrect combination of data type and compute data type." << std::endl;
return 1;
}
} }
else if(data_type == ContractionDataType::F64_F64_F64_F64 && else if(data_type == ContractionDataType::F16_F16_F16_F16)
layout == ContractionMatrixLayout::KM_NK_MN_MN)
{ {
return profile(Col{}, Col{}, Row{}, F64{}); if(compute_data_type == ContractionComputeDataType::F32)
{
return run_profile_for_datatype(F16{}, F32{});
}
else
{
std::cout << "Incorrect combination of data type and compute data type." << std::endl;
return 1;
}
} }
else else if(data_type == ContractionDataType::BF16_BF16_BF16_BF16)
{ {
std::cout << "this data_type & layout is not implemented" << std::endl; if(compute_data_type == ContractionComputeDataType::F32)
{
return 1; return run_profile_for_datatype(BF16{}, F32{});
}
else
{
std::cout << "Incorrect combination of data type and compute data type." << std::endl;
return 1;
}
} }
return 1;
} }
REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_contraction_scale); REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_contraction_scale);
...@@ -19,7 +19,8 @@ enum struct RearrangeOp ...@@ -19,7 +19,8 @@ enum struct RearrangeOp
enum struct ConvLayout enum struct ConvLayout
{ {
NHWC, // 0 GNHWC, // 0
NHWGC, // 1
}; };
enum struct DataType enum struct DataType
...@@ -42,7 +43,8 @@ static void print_helper_msg() ...@@ -42,7 +43,8 @@ static void print_helper_msg()
<< " 1: Input fp16, Weight fp16, Output fp16\n" << " 1: Input fp16, Weight fp16, Output fp16\n"
<< " 2: Input bf16, Weight bf16, Output bf16\n" << " 2: Input bf16, Weight bf16, Output bf16\n"
<< " 3: Input int8, Weight int8, Output int8)\n" << " 3: Input int8, Weight int8, Output int8)\n"
<< "arg3: tensor layout (0: Input[N, Hi, Wi, C], Output[N * Ho * Wo, Y * X * C])\n" << "arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Output[G * N * Ho * Wo, Y * X * C],\n"
<< " 1: Input[N, Hi, Wi, G, C], Output[N * Ho * Wo * G, Y * X * C])\n"
<< "arg4: verification (0: no, 1: yes)\n" << "arg4: verification (0: no, 1: yes)\n"
<< "arg5: initialization (0: no init, 1: integer value, 2: decimal value)\n" << "arg5: initialization (0: no init, 1: integer value, 2: decimal value)\n"
<< "arg6: print tensor value (0: no; 1: yes)\n" << "arg6: print tensor value (0: no; 1: yes)\n"
...@@ -114,11 +116,9 @@ int profile_conv_tensor_rearrange(int argc, char* argv[]) ...@@ -114,11 +116,9 @@ int profile_conv_tensor_rearrange(int argc, char* argv[])
return pass ? 0 : 1; return pass ? 0 : 1;
}; };
// Image To Column
if(rearrange_op == RearrangeOp::ImageToColumn) if(rearrange_op == RearrangeOp::ImageToColumn)
{ {
// NHWC if(layout == ConvLayout::GNHWC)
if(layout == ConvLayout::NHWC)
{ {
if(num_dim_spatial == 1) if(num_dim_spatial == 1)
{ {
...@@ -178,11 +178,70 @@ int profile_conv_tensor_rearrange(int argc, char* argv[]) ...@@ -178,11 +178,70 @@ int profile_conv_tensor_rearrange(int argc, char* argv[])
} }
} }
} }
else if(layout == ConvLayout::NHWGC)
{
if(num_dim_spatial == 1)
{
if(data_type == DataType::F32_F32)
{
return profile(I1, NWGC{}, F32{}, F32{}, ImageToColumn{});
}
else if(data_type == DataType::F16_F16)
{
return profile(I1, NWGC{}, F16{}, F16{}, ImageToColumn{});
}
else if(data_type == DataType::BF16_BF16)
{
return profile(I1, NWGC{}, BF16{}, BF16{}, ImageToColumn{});
}
else if(data_type == DataType::INT8_INT8)
{
return profile(I1, NWGC{}, INT8{}, INT8{}, ImageToColumn{});
}
}
else if(num_dim_spatial == 2)
{
if(data_type == DataType::F32_F32)
{
return profile(I2, NHWGC{}, F32{}, F32{}, ImageToColumn{});
}
else if(data_type == DataType::F16_F16)
{
return profile(I2, NHWGC{}, F16{}, F16{}, ImageToColumn{});
}
else if(data_type == DataType::BF16_BF16)
{
return profile(I2, NHWGC{}, BF16{}, BF16{}, ImageToColumn{});
}
else if(data_type == DataType::INT8_INT8)
{
return profile(I2, NHWGC{}, INT8{}, INT8{}, ImageToColumn{});
}
}
else if(num_dim_spatial == 3)
{
if(data_type == DataType::F32_F32)
{
return profile(I3, NDHWGC{}, F32{}, F32{}, ImageToColumn{});
}
else if(data_type == DataType::F16_F16)
{
return profile(I3, NDHWGC{}, F16{}, F16{}, ImageToColumn{});
}
else if(data_type == DataType::BF16_BF16)
{
return profile(I3, NDHWGC{}, BF16{}, BF16{}, ImageToColumn{});
}
else if(data_type == DataType::INT8_INT8)
{
return profile(I3, NDHWGC{}, INT8{}, INT8{}, ImageToColumn{});
}
}
}
} }
else if(rearrange_op == RearrangeOp::ColumnToImage) else if(rearrange_op == RearrangeOp::ColumnToImage)
{ {
// NHWC if(layout == ConvLayout::GNHWC)
if(layout == ConvLayout::NHWC)
{ {
if(num_dim_spatial == 1) if(num_dim_spatial == 1)
{ {
...@@ -242,6 +301,66 @@ int profile_conv_tensor_rearrange(int argc, char* argv[]) ...@@ -242,6 +301,66 @@ int profile_conv_tensor_rearrange(int argc, char* argv[])
} }
} }
} }
else if(layout == ConvLayout::NHWGC)
{
if(num_dim_spatial == 1)
{
if(data_type == DataType::F32_F32)
{
return profile(I1, NWGC{}, F32{}, F32{}, ColumnToImage{});
}
else if(data_type == DataType::F16_F16)
{
return profile(I1, NWGC{}, F16{}, F16{}, ColumnToImage{});
}
else if(data_type == DataType::BF16_BF16)
{
return profile(I1, NWGC{}, BF16{}, BF16{}, ColumnToImage{});
}
else if(data_type == DataType::INT8_INT8)
{
return profile(I1, NWGC{}, INT8{}, INT8{}, ColumnToImage{});
}
}
else if(num_dim_spatial == 2)
{
if(data_type == DataType::F32_F32)
{
return profile(I2, NHWGC{}, F32{}, F32{}, ColumnToImage{});
}
else if(data_type == DataType::F16_F16)
{
return profile(I2, NHWGC{}, F16{}, F16{}, ColumnToImage{});
}
else if(data_type == DataType::BF16_BF16)
{
return profile(I2, NHWGC{}, BF16{}, BF16{}, ColumnToImage{});
}
else if(data_type == DataType::INT8_INT8)
{
return profile(I2, NHWGC{}, INT8{}, INT8{}, ColumnToImage{});
}
}
else if(num_dim_spatial == 3)
{
if(data_type == DataType::F32_F32)
{
return profile(I3, NDHWGC{}, F32{}, F32{}, ColumnToImage{});
}
else if(data_type == DataType::F16_F16)
{
return profile(I3, NDHWGC{}, F16{}, F16{}, ColumnToImage{});
}
else if(data_type == DataType::BF16_BF16)
{
return profile(I3, NDHWGC{}, BF16{}, BF16{}, ColumnToImage{});
}
else if(data_type == DataType::INT8_INT8)
{
return profile(I3, NDHWGC{}, INT8{}, INT8{}, ColumnToImage{});
}
}
}
} }
std::cout << "this data_type & layout is not implemented" << std::endl; std::cout << "this data_type & layout is not implemented" << std::endl;
......
...@@ -27,6 +27,8 @@ enum struct GemmDataType ...@@ -27,6 +27,8 @@ enum struct GemmDataType
F16_F16_F16, // 1 F16_F16_F16, // 1
BF16_BF16_BF16, // 2 BF16_BF16_BF16, // 2
INT8_INT8_INT8, // 3 INT8_INT8_INT8, // 3
F8_F16_F16, // 4
F16_F8_F16, // 5
}; };
#define OP_NAME "grouped_gemm" #define OP_NAME "grouped_gemm"
...@@ -56,7 +58,7 @@ int profile_grouped_gemm(int argc, char* argv[]) ...@@ -56,7 +58,7 @@ int profile_grouped_gemm(int argc, char* argv[])
{ {
std::cout std::cout
<< "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"
<< "arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n" << "arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: fp8@fp6; 5: f16@f8)\n"
<< "arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n" << "arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"
<< " 1: A[m, k] * B[n, k] = C[m, n];\n" << " 1: A[m, k] * B[n, k] = C[m, n];\n"
<< " 2: A[k, m] * B[k, n] = C[m, n];\n" << " 2: A[k, m] * B[k, n] = C[m, n];\n"
...@@ -169,6 +171,46 @@ int profile_grouped_gemm(int argc, char* argv[]) ...@@ -169,6 +171,46 @@ int profile_grouped_gemm(int argc, char* argv[])
StrideCs, StrideCs,
kbatch); kbatch);
} }
else if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_grouped_gemm_impl<ck::f8_t,
ck::half_t,
ck::half_t,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(do_verification,
init_method,
do_log,
time_kernel,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatch);
}
else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_grouped_gemm_impl<ck::half_t,
ck::f8_t,
ck::half_t,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(do_verification,
init_method,
do_log,
time_kernel,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatch);
}
else else
{ {
throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented"); throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented");
......
...@@ -8,8 +8,7 @@ MY_PROJECT_SOURCE=$1 ...@@ -8,8 +8,7 @@ MY_PROJECT_SOURCE=$1
cmake \ cmake \
-D CMAKE_PREFIX_PATH=/opt/rocm \ -D CMAKE_PREFIX_PATH=/opt/rocm \
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_CXX_FLAGS="-std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker \ -D CMAKE_CXX_FLAGS="-std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \
-save-temps=$PWD" \
-D CMAKE_BUILD_TYPE=Release \ -D CMAKE_BUILD_TYPE=Release \
-D BUILD_DEV=ON \ -D BUILD_DEV=ON \
-D GPU_TARGETS="gfx908;gfx90a;gfx940" \ -D GPU_TARGETS="gfx908;gfx90a;gfx940" \
......
fips = no
setuid = root
setgid = root
pid = /var/run/stunnel.pid
debug = 7
options = NO_SSLv2
options = NO_SSLv3
[redis-cli]
client = yes
accept = 127.0.0.1:6379
#!/bin/bash
set -e
COMPILERS_HASH_DIR=${COMPILERS_HASH_DIR:-"/tmp/.sccache"}
SCCACHE_EXTRAFILES=${SCCACHE_EXTRAFILES:-"${COMPILERS_HASH_DIR}/rocm_compilers_hash_file"}
SCCACHE_BIN=${SCCACHE_BIN:-"${SCCACHE_INSTALL_LOCATION}/sccache"}
ENFORCE_REDIS="false"
while [ "$1" != "" ];
do
case $1 in
--enforce_redis )
shift; ENFORCE_REDIS="true" ;;
--no-hipcc )
shift ;;
*)
break ;;
esac
done
setup_rocm_compilers_hash_file() {
mkdir -p "$COMPILERS_HASH_DIR"
HIPCC_MD5="$(md5sum "${ROCM_PATH}/bin/hipcc")"
pushd "${ROCM_PATH}/amdgcn/bitcode"
DEVICELIBS_BITCODES_MD5="$(find . -type f -exec md5sum {} \; | sort | md5sum)"
popd
HIPCC_HASH_VALUE="${HIPCC_MD5%% *}"
DEVICELIBS_BITCODES_HASH_VALUE="${DEVICELIBS_BITCODES_MD5%% *}"
# MD5 checksums of clang and clang-offload-bundler cannot be used since they will keep changing
# if the ROCM_PATH changes, ie; for every mainline build.
# This is because ROCM_PATH gets encoded into the clang/clang-offload-bundler binaries as part
# of RPATH.
# The versions themselves contain the commit hash of the compiler repo at the time of building.
# Hence, this should be a viable alternative to using the binary checksum itself.
CLANG_VERSION="$("${ROCM_PATH}/llvm/bin/clang" --version | head -n 1)"
CLANG_OFFLOAD_BUNDLER_VERSION="$("${ROCM_PATH}/llvm/bin/clang-offload-bundler" --version | head -n 1)"
printf '%s: %s\n' 'clang version' "${CLANG_VERSION}" | tee -a "$SCCACHE_EXTRAFILES"
printf '%s: %s\n' 'clang-offload-bundler version' "${CLANG_OFFLOAD_BUNDLER_VERSION}" | tee -a "$SCCACHE_EXTRAFILES"
printf '%s: %s\n' 'hipcc md5sum' "${HIPCC_HASH_VALUE}" | tee -a "$SCCACHE_EXTRAFILES"
printf '%s: %s\n' 'devicelibs bitcode md5sum' "${DEVICELIBS_BITCODES_HASH_VALUE}" | tee -a "$SCCACHE_EXTRAFILES"
echo "sccache-wrapper: compilers hash file set up at ${SCCACHE_EXTRAFILES}"
cat "$SCCACHE_EXTRAFILES"
}
if [ "${ENFORCE_REDIS}" == "true" ]; then
if [ -z "${SCCACHE_REDIS}" ]; then
echo "SCCACHE_REDIS not set. Not wrapping compilers with sccache."
exit 10
else
response=$(redis-cli -u ${SCCACHE_REDIS} ping) || true
if [ "${response}" != "PONG" ]; then
echo "Redis server unreachable. Not wrapping compilers with sccache."
exit 20
fi
fi
fi
setup_rocm_compilers_hash_file
$SCCACHE_BIN --version
$SCCACHE_BIN --start-server
...@@ -13,31 +13,27 @@ function(add_test_executable TEST_NAME) ...@@ -13,31 +13,27 @@ function(add_test_executable TEST_NAME)
if(DEFINED DTYPES) if(DEFINED DTYPES)
foreach(source IN LISTS ARGN) foreach(source IN LISTS ARGN)
set(test 0) set(test 0)
foreach(type IN LISTS DTYPES) if((source MATCHES "_fp16" OR source MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES)
if(type MATCHES "fp16") set(test 1)
set(type1 "_f16") endif()
elseif(type MATCHES "fp32") if((source MATCHES "_fp32" OR source MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES)
set(type1 "_f32") set(test 1)
elseif(type MATCHES "fp8") endif()
set(type1 "_f8") if((source MATCHES "_fp64" OR source MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES)
elseif(type MATCHES "bf16") set(test 1)
set(type1 "_b16") endif()
elseif(type MATCHES "fp64") if((source MATCHES "_fp8" OR source MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES)
set(type1 "_f64") set(test 1)
elseif(type MATCHES "int8") endif()
set(type1 "_i8") if((source MATCHES "_bf8" OR source MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES)
endif() set(test 1)
if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}") endif()
#if filename matches any selected type, exit type loop and do no exclude the file from the list if((source MATCHES "_bf16" OR source MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES)
set(test 0) set(test 1)
break() endif()
elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR if((source MATCHES "_int8" OR source MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES)
source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND set(test 1)
NOT(source MATCHES type OR source MATCHES type1)) endif()
#if filename contains a type which doesn't match any selected type, mark it for removal
set(test 1)
endif()
endforeach()
if(test EQUAL 1) if(test EQUAL 1)
message("removing test ${source} ") message("removing test ${source} ")
list(REMOVE_ITEM ARGN "${source}") list(REMOVE_ITEM ARGN "${source}")
...@@ -72,31 +68,27 @@ function(add_gtest_executable TEST_NAME) ...@@ -72,31 +68,27 @@ function(add_gtest_executable TEST_NAME)
if(DEFINED DTYPES) if(DEFINED DTYPES)
foreach(source IN LISTS ARGN) foreach(source IN LISTS ARGN)
set(test 0) set(test 0)
foreach(type IN LISTS DTYPES) if((source MATCHES "_fp16" OR source MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES)
if(type MATCHES "fp16") set(test 1)
set(type1 "_f16") endif()
elseif(type MATCHES "fp32") if((source MATCHES "_fp32" OR source MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES)
set(type1 "_f32") set(test 1)
elseif(type MATCHES "fp8") endif()
set(type1 "_f8") if((source MATCHES "_fp64" OR source MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES)
elseif(type MATCHES "bf16") set(test 1)
set(type1 "_b16") endif()
elseif(type MATCHES "fp64") if((source MATCHES "_fp8" OR source MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES)
set(type1 "_f64") set(test 1)
elseif(type MATCHES "int8") endif()
set(type1 "_i8") if((source MATCHES "_bf8" OR source MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES)
endif() set(test 1)
if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}") endif()
#if filename matches any selected type, exit type loop and do no exclude the file from the list if((source MATCHES "_bf16" OR source MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES)
set(test 0) set(test 1)
break() endif()
elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR if((source MATCHES "_int8" OR source MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES)
source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND set(test 1)
NOT(source MATCHES type OR source MATCHES type1)) endif()
#if filename contains a type which doesn't match any selected type, mark it for removal
set(test 1)
endif()
endforeach()
if(test EQUAL 1) if(test EQUAL 1)
message("removing gtest ${source} ") message("removing gtest ${source} ")
list(REMOVE_ITEM ARGN "${source}") list(REMOVE_ITEM ARGN "${source}")
......
...@@ -10,9 +10,12 @@ ...@@ -10,9 +10,12 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "profiler/profile_contraction_impl.hpp" #include "profiler/profile_contraction_impl.hpp"
#include "profiler/profile_contraction_utils.hpp"
using F32 = float; using F16 = ck::half_t;
using F64 = double; using BF16 = ck::bhalf_t;
using F32 = float;
using F64 = double;
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
...@@ -20,49 +23,49 @@ using Col = ck::tensor_layout::gemm::ColumnMajor; ...@@ -20,49 +23,49 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using Bilinear = ck::tensor_operation::element_wise::Bilinear; using Bilinear = ck::tensor_operation::element_wise::Bilinear;
using Scale = ck::tensor_operation::element_wise::Scale; using Scale = ck::tensor_operation::element_wise::Scale;
struct MemoryParams struct Dimensions
{ {
std::vector<ck::index_t> M; std::vector<ck::index_t> M;
std::vector<ck::index_t> N; std::vector<ck::index_t> N;
std::vector<ck::index_t> K; std::vector<ck::index_t> K;
std::vector<ck::index_t> StridesA;
std::vector<ck::index_t> StridesB;
std::vector<ck::index_t> StridesC;
std::vector<ck::index_t> StridesD;
}; };
template <typename Tuple> template <typename Tuple>
class TestContraction : public ::testing::Test class TestContraction : public ::testing::Test
{ {
protected: protected:
using ALayout = std::tuple_element_t<0, Tuple>; using ALayout = std::tuple_element_t<0, Tuple>;
using BLayout = std::tuple_element_t<1, Tuple>; using BLayout = std::tuple_element_t<1, Tuple>;
using CDLayout = std::tuple_element_t<2, Tuple>; using CDLayout = std::tuple_element_t<2, Tuple>;
using DataType = std::tuple_element_t<3, Tuple>; using DataType = std::tuple_element_t<3, Tuple>;
using DTupleDataType = std::tuple_element_t<4, Tuple>; using DTupleDataType = std::tuple_element_t<4, Tuple>;
using CDElementOp = std::tuple_element_t<5, Tuple>; using ComputeDataType = std::tuple_element_t<5, Tuple>;
using CDElementOp = std::tuple_element_t<6, Tuple>;
std::vector<MemoryParams> list_of_memory_params = {{{32, 32},
{32, 32}, std::vector<Dimensions> dimension_list = {{{32, 32}, {32, 32}, {32, 32}},
{32, 32}, {{16, 16}, {32, 32}, {16, 16}}};
{32768, 1024, 32, 1},
{32768, 1024, 32, 1}, std::vector<ck::index_t> init_methods = {1, 2};
{32768, 1024, 32, 1},
{32768, 1024, 32, 1}},
{{16, 16},
{32, 32},
{16, 16},
{4096, 256, 16, 1},
{16, 1, 8192, 256},
{16384, 1024, 32, 1},
{16384, 1024, 32, 1}}};
std::vector<ck::index_t> init_methods = {0, 1, 2};
std::unique_ptr<CDElementOp> p_cd_element_op; std::unique_ptr<CDElementOp> p_cd_element_op;
void Run() void Run()
{ {
for(auto& memory_params : list_of_memory_params) for(auto& dimension_params : dimension_list)
{ {
std::vector<ck::index_t> StridesA;
std::vector<ck::index_t> StridesB;
std::vector<ck::index_t> StridesC;
std::vector<ck::index_t> StridesD;
const auto& M = dimension_params.M;
const auto& N = dimension_params.N;
const auto& K = dimension_params.K;
assign_default_strides(ALayout{}, StridesA, {M[0], M[1], K[0], K[1]});
assign_default_strides(BLayout{}, StridesB, {N[0], N[1], K[0], K[1]});
assign_default_strides(CDLayout{}, StridesC, {M[0], M[1], N[0], N[1]});
assign_default_strides(CDLayout{}, StridesD, {M[0], M[1], N[0], N[1]});
for(const ck::index_t init_method : init_methods) for(const ck::index_t init_method : init_methods)
{ {
bool pass = bool pass =
...@@ -70,19 +73,20 @@ class TestContraction : public ::testing::Test ...@@ -70,19 +73,20 @@ class TestContraction : public ::testing::Test
BLayout, BLayout,
CDLayout, CDLayout,
DataType, DataType,
ComputeDataType,
DTupleDataType, DTupleDataType,
CDElementOp>(true /*do_verification*/, CDElementOp>(true /*do_verification*/,
init_method, init_method,
false /*do_logs*/, false /*do_logs*/,
false /*time_kernel*/, false /*time_kernel*/,
*p_cd_element_op, *p_cd_element_op,
memory_params.M, dimension_params.M,
memory_params.N, dimension_params.N,
memory_params.K, dimension_params.K,
memory_params.StridesA, StridesA,
memory_params.StridesB, StridesB,
memory_params.StridesC, StridesC,
memory_params.StridesD); StridesD);
EXPECT_TRUE(pass); EXPECT_TRUE(pass);
} }
} }
...@@ -99,24 +103,18 @@ class TestContractionBilinear : public TestContraction<Tuple> ...@@ -99,24 +103,18 @@ class TestContractionBilinear : public TestContraction<Tuple>
{ {
}; };
#define ALL_LAYOUT_COMBINATIONS(dt, tuple_dt, compute_dt, op) \
std::tuple<Row, Row, Row, dt, tuple_dt, compute_dt, op>, \
std::tuple<Row, Col, Row, dt, tuple_dt, compute_dt, op>, \
std::tuple<Col, Row, Row, dt, tuple_dt, compute_dt, op>, \
std::tuple<Col, Col, Row, dt, tuple_dt, compute_dt, op>
using BilinearKernelTypes = using BilinearKernelTypes =
::testing::Types<std::tuple<Row, Row, Row, F32, ck::Tuple<F32>, Bilinear>, ::testing::Types<ALL_LAYOUT_COMBINATIONS(F32, ck::Tuple<F32>, F32, Bilinear),
std::tuple<Row, Col, Row, F32, ck::Tuple<F32>, Bilinear>, ALL_LAYOUT_COMBINATIONS(F64, ck::Tuple<F64>, F64, Bilinear)>;
std::tuple<Col, Row, Row, F32, ck::Tuple<F32>, Bilinear>,
std::tuple<Col, Col, Row, F32, ck::Tuple<F32>, Bilinear>, using ScaleKernelTypes = ::testing::Types<ALL_LAYOUT_COMBINATIONS(F32, ck::Tuple<>, F32, Scale),
std::tuple<Row, Row, Row, F64, ck::Tuple<F32>, Bilinear>, ALL_LAYOUT_COMBINATIONS(F64, ck::Tuple<>, F64, Scale)>;
std::tuple<Row, Col, Row, F64, ck::Tuple<F32>, Bilinear>,
std::tuple<Col, Row, Row, F64, ck::Tuple<F32>, Bilinear>,
std::tuple<Col, Col, Row, F64, ck::Tuple<F32>, Bilinear>>;
using ScaleKernelTypes = ::testing::Types<std::tuple<Row, Row, Row, F32, ck::Tuple<>, Scale>,
std::tuple<Row, Col, Row, F32, ck::Tuple<>, Scale>,
std::tuple<Col, Row, Row, F32, ck::Tuple<>, Scale>,
std::tuple<Col, Col, Row, F32, ck::Tuple<>, Scale>,
std::tuple<Row, Row, Row, F64, ck::Tuple<>, Scale>,
std::tuple<Row, Col, Row, F64, ck::Tuple<>, Scale>,
std::tuple<Col, Row, Row, F64, ck::Tuple<>, Scale>,
std::tuple<Col, Col, Row, F64, ck::Tuple<>, Scale>>;
TYPED_TEST_SUITE(TestContractionBilinear, BilinearKernelTypes); TYPED_TEST_SUITE(TestContractionBilinear, BilinearKernelTypes);
TYPED_TEST_SUITE(TestContractionScale, ScaleKernelTypes); TYPED_TEST_SUITE(TestContractionScale, ScaleKernelTypes);
...@@ -136,3 +134,46 @@ TYPED_TEST(TestContractionScale, scale) ...@@ -136,3 +134,46 @@ TYPED_TEST(TestContractionScale, scale)
this->p_cd_element_op = std::make_unique<Scale>(0.5f); this->p_cd_element_op = std::make_unique<Scale>(0.5f);
this->Run(); this->Run();
} }
template <typename Tuple>
class TestContractionScaleMixedPrecision : public TestContraction<Tuple>
{
};
template <typename Tuple>
class TestContractionBilinearMixedPrecision : public TestContraction<Tuple>
{
};
using BilinearKernelTypesMixedPrecision =
::testing::Types<ALL_LAYOUT_COMBINATIONS(F32, ck::Tuple<F32>, F16, Bilinear),
ALL_LAYOUT_COMBINATIONS(F32, ck::Tuple<F32>, BF16, Bilinear),
ALL_LAYOUT_COMBINATIONS(F64, ck::Tuple<F64>, F32, Bilinear),
ALL_LAYOUT_COMBINATIONS(F16, ck::Tuple<F16>, F32, Bilinear),
ALL_LAYOUT_COMBINATIONS(BF16, ck::Tuple<BF16>, F32, Bilinear)>;
using ScaleKernelTypesMixedPrecision =
::testing::Types<ALL_LAYOUT_COMBINATIONS(F32, ck::Tuple<>, F16, Scale),
ALL_LAYOUT_COMBINATIONS(F32, ck::Tuple<>, BF16, Scale),
ALL_LAYOUT_COMBINATIONS(F64, ck::Tuple<>, F32, Scale),
ALL_LAYOUT_COMBINATIONS(F16, ck::Tuple<>, F32, Scale),
ALL_LAYOUT_COMBINATIONS(BF16, ck::Tuple<>, F32, Scale)>;
TYPED_TEST_SUITE(TestContractionBilinearMixedPrecision, BilinearKernelTypesMixedPrecision);
TYPED_TEST_SUITE(TestContractionScaleMixedPrecision, ScaleKernelTypesMixedPrecision);
TYPED_TEST(TestContractionBilinearMixedPrecision, bilinear)
{
this->p_cd_element_op = std::make_unique<Bilinear>(1.f, 1.f);
this->Run();
this->p_cd_element_op = std::make_unique<Bilinear>(-0.5f, 0.5f);
this->Run();
}
TYPED_TEST(TestContractionScaleMixedPrecision, scale)
{
this->p_cd_element_op = std::make_unique<Scale>(1.f);
this->Run();
this->p_cd_element_op = std::make_unique<Scale>(0.5f);
this->Run();
}
...@@ -34,11 +34,11 @@ class ContractionInstanceWrapper ...@@ -34,11 +34,11 @@ class ContractionInstanceWrapper
static constexpr ck::index_t NumDim = 2; static constexpr ck::index_t NumDim = 2;
// clang-format off // clang-format off
using ContractionDeviceInstance = ck::tensor_operation::device:: using ContractionDeviceInstance = ck::tensor_operation::device::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, F32, F32, F32, F32, ck::Tuple<F32>, F32, Pass, Pass, Bilinear, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, ABlockTransferSrcVectorDim, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, BBlockTransferSrcVectorDim, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, CDEBlockTransferScalarPerVector>; DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, F32, F32, F32, F32, ck::Tuple<F32>, F32, Pass, Pass, Bilinear, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, ABlockTransferSrcVectorDim, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, BBlockTransferSrcVectorDim, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, CDEBlockTransferScalarPerVector, F32>;
// clang-format on // clang-format on
bool isSupported(std::vector<ck::index_t>& ADims, bool isSupported(std::vector<ck::index_t>& ADims,
......
...@@ -45,14 +45,20 @@ class TestConvTensorRearrange : public ::testing::Test ...@@ -45,14 +45,20 @@ class TestConvTensorRearrange : public ::testing::Test
using namespace ck::tensor_layout::convolution; using namespace ck::tensor_layout::convolution;
using namespace ck::conv_tensor_rearrange_op; using namespace ck::conv_tensor_rearrange_op;
using KernelTypes1d = using KernelTypes1d = ::testing::Types<std::tuple<GNWC, ImageToColumn>,
::testing::Types<std::tuple<GNWC, ImageToColumn>, std::tuple<GNWC, ColumnToImage>>; std::tuple<GNWC, ColumnToImage>,
std::tuple<NWGC, ImageToColumn>,
std::tuple<NWGC, ColumnToImage>>;
using KernelTypes2d = using KernelTypes2d = ::testing::Types<std::tuple<GNHWC, ImageToColumn>,
::testing::Types<std::tuple<GNHWC, ImageToColumn>, std::tuple<GNHWC, ColumnToImage>>; std::tuple<GNHWC, ColumnToImage>,
std::tuple<NHWGC, ImageToColumn>,
std::tuple<NHWGC, ColumnToImage>>;
using KernelTypes3d = using KernelTypes3d = ::testing::Types<std::tuple<GNDHWC, ImageToColumn>,
::testing::Types<std::tuple<GNDHWC, ImageToColumn>, std::tuple<GNDHWC, ColumnToImage>>; std::tuple<GNDHWC, ColumnToImage>,
std::tuple<NDHWGC, ImageToColumn>,
std::tuple<NDHWGC, ColumnToImage>>;
template <typename Tuple> template <typename Tuple>
class TestConvTensorRearrange1d : public TestConvTensorRearrange<Tuple> class TestConvTensorRearrange1d : public TestConvTensorRearrange<Tuple>
...@@ -77,16 +83,16 @@ TYPED_TEST(TestConvTensorRearrange1d, Test1D) ...@@ -77,16 +83,16 @@ TYPED_TEST(TestConvTensorRearrange1d, Test1D)
{ {
this->conv_params.clear(); this->conv_params.clear();
this->conv_params.push_back({1, 1, 4, 1, 192, {3}, {28}, {1}, {1}, {1}, {1}}); this->conv_params.push_back({1, 2, 4, 1, 192, {3}, {28}, {1}, {1}, {1}, {1}});
this->conv_params.push_back({1, 1, 64, 1, 64, {3}, {14}, {1}, {1}, {1}, {1}}); this->conv_params.push_back({1, 2, 64, 1, 64, {3}, {14}, {1}, {1}, {1}, {1}});
this->conv_params.push_back({1, 1, 64, 1, 64, {1}, {7}, {3}, {1}, {0}, {0}}); this->conv_params.push_back({1, 2, 64, 1, 64, {1}, {7}, {3}, {1}, {0}, {0}});
this->conv_params.push_back({1, 1, 64, 1, 64, {1}, {3}, {1}, {1}, {0}, {0}}); this->conv_params.push_back({1, 2, 64, 1, 64, {1}, {3}, {1}, {1}, {0}, {0}});
// ScalarPerVector should be 1 // ScalarPerVector should be 1
this->conv_params.push_back({1, 1, 4, 1, 1, {3}, {28}, {1}, {1}, {1}, {1}}); this->conv_params.push_back({1, 2, 4, 1, 1, {3}, {28}, {1}, {1}, {1}, {1}});
// stride != 1 // stride != 1
this->conv_params.push_back({1, 1, 1, 1, 4, {3}, {28}, {2}, {1}, {1}, {1}}); this->conv_params.push_back({1, 2, 1, 1, 4, {3}, {28}, {2}, {1}, {1}, {1}});
// dilation != 1 // dilation != 1
this->conv_params.push_back({1, 1, 1, 1, 4, {3}, {28}, {1}, {2}, {1}, {1}}); this->conv_params.push_back({1, 2, 1, 1, 4, {3}, {28}, {1}, {2}, {1}, {1}});
#ifdef CK_ENABLE_FP32 #ifdef CK_ENABLE_FP32
this->template Run<1, float, float>(); this->template Run<1, float, float>();
#endif #endif
...@@ -106,13 +112,13 @@ TYPED_TEST(TestConvTensorRearrange2d, Test2D) ...@@ -106,13 +112,13 @@ TYPED_TEST(TestConvTensorRearrange2d, Test2D)
this->conv_params.clear(); this->conv_params.clear();
this->conv_params.push_back( this->conv_params.push_back(
{2, 1, 4, 1, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); {2, 2, 4, 1, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back( this->conv_params.push_back(
{2, 1, 64, 1, 64, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); {2, 2, 64, 1, 64, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back({2, 1, 64, 1, 64, {1, 1}, {7, 7}, {3, 3}, {1, 1}, {0, 0}, {0, 0}}); this->conv_params.push_back({2, 1, 64, 1, 64, {1, 1}, {7, 7}, {3, 3}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back({2, 1, 64, 1, 64, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); this->conv_params.push_back({2, 1, 64, 1, 64, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back( this->conv_params.push_back(
{2, 1, 64, 1, 64, {3, 3}, {28, 28}, {2, 2}, {2, 2}, {1, 1}, {1, 1}}); {2, 2, 64, 1, 64, {3, 3}, {28, 28}, {2, 2}, {2, 2}, {1, 1}, {1, 1}});
#ifdef CK_ENABLE_FP32 #ifdef CK_ENABLE_FP32
this->template Run<2, float, float>(); this->template Run<2, float, float>();
#endif #endif
...@@ -131,13 +137,13 @@ TYPED_TEST(TestConvTensorRearrange3d, Test3D) ...@@ -131,13 +137,13 @@ TYPED_TEST(TestConvTensorRearrange3d, Test3D)
{ {
this->conv_params.clear(); this->conv_params.clear();
this->conv_params.push_back( this->conv_params.push_back(
{3, 1, 16, 1, 64, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {3, 3, 3}, {0, 0, 0}, {0, 0, 0}}); {3, 2, 16, 1, 64, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {3, 3, 3}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back( this->conv_params.push_back(
{3, 1, 2, 1, 64, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); {3, 2, 2, 1, 64, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back( this->conv_params.push_back(
{3, 1, 32, 1, 64, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); {3, 2, 32, 1, 64, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back( this->conv_params.push_back(
{3, 1, 64, 1, 64, {3, 3, 3}, {14, 14, 14}, {2, 2, 2}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}}); {3, 2, 64, 1, 64, {3, 3, 3}, {14, 14, 14}, {2, 2, 2}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}});
#ifdef CK_ENABLE_FP32 #ifdef CK_ENABLE_FP32
this->template Run<3, float, float>(); this->template Run<3, float, float>();
#endif #endif
......
...@@ -53,7 +53,7 @@ class TestConvTensorRearrangeInterface : public ::testing::Test ...@@ -53,7 +53,7 @@ class TestConvTensorRearrangeInterface : public ::testing::Test
template <typename ConvTensorRearrangeOp> template <typename ConvTensorRearrangeOp>
bool Run() bool Run()
{ {
const auto G = conv_param.G_;
const auto N = conv_param.N_; const auto N = conv_param.N_;
const auto C = conv_param.C_; const auto C = conv_param.C_;
const auto FakeC = const auto FakeC =
...@@ -71,13 +71,13 @@ class TestConvTensorRearrangeInterface : public ::testing::Test ...@@ -71,13 +71,13 @@ class TestConvTensorRearrangeInterface : public ::testing::Test
const auto image_desc = const auto image_desc =
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<ImLayout>( ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<ImLayout>(
conv_param); conv_param);
const auto gemm_desc = HostTensorDescriptor({NDoHoWo, CZYX}); const auto gemm_desc = HostTensorDescriptor({G, NDoHoWo, CZYX});
std::array<ck::index_t, NDimSpatial> input_spatial_lengths{}; std::array<ck::index_t, NDimSpatial> input_spatial_lengths{};
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths{}; std::array<ck::index_t, NDimSpatial> filter_spatial_lengths{};
std::array<ck::index_t, NDimSpatial> output_spatial_lengths{}; std::array<ck::index_t, NDimSpatial> output_spatial_lengths{};
std::array<ck::index_t, NDimSpatial + 3> input_g_n_c_wis_strides{}; std::array<ck::index_t, NDimSpatial + 3> input_g_n_c_wis_strides{};
std::array<ck::index_t, 2> output_m_k_strides{}; std::array<ck::index_t, 3> output_g_m_k_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_strides{}; std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{}; std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
std::array<ck::index_t, NDimSpatial> input_left_pads{}; std::array<ck::index_t, NDimSpatial> input_left_pads{};
...@@ -89,7 +89,7 @@ class TestConvTensorRearrangeInterface : public ::testing::Test ...@@ -89,7 +89,7 @@ class TestConvTensorRearrangeInterface : public ::testing::Test
copy(conv_param.filter_spatial_lengths_, filter_spatial_lengths); copy(conv_param.filter_spatial_lengths_, filter_spatial_lengths);
copy(conv_param.output_spatial_lengths_, output_spatial_lengths); copy(conv_param.output_spatial_lengths_, output_spatial_lengths);
copy(image_desc.GetStrides(), input_g_n_c_wis_strides); copy(image_desc.GetStrides(), input_g_n_c_wis_strides);
copy(gemm_desc.GetStrides(), output_m_k_strides); copy(gemm_desc.GetStrides(), output_g_m_k_strides);
copy(conv_param.conv_filter_strides_, conv_filter_strides); copy(conv_param.conv_filter_strides_, conv_filter_strides);
copy(conv_param.conv_filter_dilations_, conv_filter_dilations); copy(conv_param.conv_filter_dilations_, conv_filter_dilations);
copy(conv_param.input_left_pads_, input_left_pads); copy(conv_param.input_left_pads_, input_left_pads);
...@@ -100,13 +100,14 @@ class TestConvTensorRearrangeInterface : public ::testing::Test ...@@ -100,13 +100,14 @@ class TestConvTensorRearrangeInterface : public ::testing::Test
auto img2col = DeviceImgToColInstance{}; auto img2col = DeviceImgToColInstance{};
auto argument = img2col.MakeArgument(nullptr, auto argument = img2col.MakeArgument(nullptr,
nullptr, nullptr,
G,
N, N,
IsCPacked ? C : FakeC, IsCPacked ? C : FakeC,
input_spatial_lengths, input_spatial_lengths,
filter_spatial_lengths, filter_spatial_lengths,
output_spatial_lengths, output_spatial_lengths,
input_g_n_c_wis_strides, input_g_n_c_wis_strides,
output_m_k_strides, output_g_m_k_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
...@@ -119,13 +120,14 @@ class TestConvTensorRearrangeInterface : public ::testing::Test ...@@ -119,13 +120,14 @@ class TestConvTensorRearrangeInterface : public ::testing::Test
auto col2img = DeviceColToimgInstance{}; auto col2img = DeviceColToimgInstance{};
auto argument = col2img.MakeArgument(nullptr, auto argument = col2img.MakeArgument(nullptr,
nullptr, nullptr,
G,
N, N,
IsCPacked ? C : FakeC, IsCPacked ? C : FakeC,
input_spatial_lengths, input_spatial_lengths,
filter_spatial_lengths, filter_spatial_lengths,
output_spatial_lengths, output_spatial_lengths,
input_g_n_c_wis_strides, input_g_n_c_wis_strides,
output_m_k_strides, output_g_m_k_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
......
if (USE_BITINT_EXTENSION_INT4) if (USE_BITINT_EXTENSION_INT4)
add_gtest_executable(test_int4 int4.cpp) add_gtest_executable(test_int4 test_int4.cpp)
if(result EQUAL 0) if(result EQUAL 0)
target_link_libraries(test_int4 PRIVATE utility) target_link_libraries(test_int4 PRIVATE utility)
endif() endif()
endif() endif()
add_gtest_executable(test_fp8 fp8.cpp) add_gtest_executable(test_fp8 test_fp8.cpp)
if(result EQUAL 0) if(result EQUAL 0)
target_link_libraries(test_fp8 PRIVATE utility) target_link_libraries(test_fp8 PRIVATE utility)
endif() endif()
add_gtest_executable(test_bf8 bf8.cpp) add_gtest_executable(test_bf8 test_bf8.cpp)
if(result EQUAL 0) if(result EQUAL 0)
target_link_libraries(test_bf8 PRIVATE utility) target_link_libraries(test_bf8 PRIVATE utility)
endif() endif()
......
...@@ -108,6 +108,10 @@ TEST_F(TestGGemmSplitKInterface_MKNKMN, KLoops) ...@@ -108,6 +108,10 @@ TEST_F(TestGGemmSplitKInterface_MKNKMN, KLoops)
// kloops % 2 // kloops % 2
Ks = std::vector<int>{256, 512, 320, 768}; Ks = std::vector<int>{256, 512, 320, 768};
EXPECT_FALSE(
DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch));
Ks = std::vector<int>{256, 512, 384, 768};
EXPECT_TRUE( EXPECT_TRUE(
DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch)); DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment