Commit fd650950 authored by M.Emin Ozturk's avatar M.Emin Ozturk
Browse files

clang

parent f1055b34
...@@ -317,7 +317,6 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) ...@@ -317,7 +317,6 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
std::cout << "Computing GEMM on host..." << std::endl; std::cout << "Computing GEMM on host..." << std::endl;
} }
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMXGemm<ADataType, using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMXGemm<ADataType,
BDataType, BDataType,
CDataType, CDataType,
...@@ -340,7 +339,6 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) ...@@ -340,7 +339,6 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
PassThrough{}, PassThrough{},
PassThrough{}); PassThrough{});
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
if(config.verbosity > 0) if(config.verbosity > 0)
...@@ -357,12 +355,10 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) ...@@ -357,12 +355,10 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
<< ((res_verified) ? " (PASSED!)" : " (FAILED!)") << std::endl; << ((res_verified) ? " (PASSED!)" : " (FAILED!)") << std::endl;
} }
res_verified = res_verified && ck::utils::check_err(c_m_n_device_result, res_verified = res_verified && ck::utils::check_err(c_m_n_device_result,
c_m_n_host_result, c_m_n_host_result,
"Error: Incorrect results!"); "Error: Incorrect results!");
if(config.verbosity > 0 && res_verified) if(config.verbosity > 0 && res_verified)
std::cout << "Done." << std::endl; std::cout << "Done." << std::endl;
} }
......
...@@ -30,8 +30,13 @@ auto calculate_rtol_atol(const ck_tile::index_t K, ...@@ -30,8 +30,13 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
} }
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType, template <typename ADataType,
typename ALayout, typename BLayout, typename CLayout> typename BDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::DeviceMem& b_k_n_dev_buf, ck_tile::DeviceMem& b_k_n_dev_buf,
ck_tile::DeviceMem& c_m_n_dev_buf, ck_tile::DeviceMem& c_m_n_dev_buf,
...@@ -57,9 +62,9 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, ...@@ -57,9 +62,9 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
args.stride_B = stride_B; args.stride_B = stride_B;
args.stride_C = stride_C; args.stride_C = stride_C;
float ave_time = gemm_calc<ADataType, BDataType, AccDataType, CDataType, float ave_time =
ALayout, BLayout, CLayout>( gemm_calc<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
std::size_t flop = std::size_t(2) * M * N * K; std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_byte = std::size_t num_byte =
...@@ -69,14 +74,11 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, ...@@ -69,14 +74,11 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K
<< " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C << " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C
<< " A_Layout =" << ALayout::name << " A_Layout =" << ALayout::name << " B_Layout =" << BLayout::name
<< " B_Layout =" << BLayout::name << " C_Layout =" << CLayout::name << " A Type = " << DataTypeTraits<ADataType>::name
<< " C_Layout =" << CLayout::name << " B Type = " << DataTypeTraits<BDataType>::name
<< " A Type = " << DataTypeTraits<ADataType>::name << " C Type = " << DataTypeTraits<CDataType>::name << " : " << ave_time << " ms, "
<< " B Type = " << DataTypeTraits<BDataType>::name << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
<< " C Type = " << DataTypeTraits<CDataType>::name
<< " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< std::endl;
return ave_time; return ave_time;
} }
...@@ -92,10 +94,10 @@ int run_gemm_example_with_layouts(int argc, ...@@ -92,10 +94,10 @@ int run_gemm_example_with_layouts(int argc,
if(!result) if(!result)
return -1; return -1;
using ADataType = typename GemmBasicTypeConfig<PrecType>::ADataType; using ADataType = typename GemmBasicTypeConfig<PrecType>::ADataType;
using BDataType = typename GemmBasicTypeConfig<PrecType>::BDataType; using BDataType = typename GemmBasicTypeConfig<PrecType>::BDataType;
using CDataType = typename GemmBasicTypeConfig<PrecType>::CDataType; using CDataType = typename GemmBasicTypeConfig<PrecType>::CDataType;
using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType; using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType;
ck_tile::index_t M = arg_parser.get_int("m"); ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n"); ck_tile::index_t N = arg_parser.get_int("n");
...@@ -133,19 +135,19 @@ int run_gemm_example_with_layouts(int argc, ...@@ -133,19 +135,19 @@ int run_gemm_example_with_layouts(int argc,
c_m_n_dev_buf.SetZero(); c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero(); c_m_n_dev_result.SetZero();
invoke_gemm<ADataType, BDataType, AccDataType, CDataType, invoke_gemm<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
ALayout, BLayout, CLayout>(a_m_k_dev_buf, a_m_k_dev_buf,
b_k_n_dev_buf, b_k_n_dev_buf,
c_m_n_dev_buf, c_m_n_dev_buf,
M, M,
N, N,
K, K,
stride_A, stride_A,
stride_B, stride_B,
stride_C, stride_C,
kbatch, kbatch,
n_warmup, n_warmup,
n_repeat); n_repeat);
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
bool pass = true; bool pass = true;
...@@ -160,9 +162,9 @@ int run_gemm_example_with_layouts(int argc, ...@@ -160,9 +162,9 @@ int run_gemm_example_with_layouts(int argc,
a_m_k, b_k_n, c_m_n_host_ref); a_m_k, b_k_n, c_m_n_host_ref);
const float max_accumulated_value = const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType> const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
(K, kbatch, max_accumulated_value); K, kbatch, max_accumulated_value);
pass = ck_tile::check_err(c_m_n_dev_result, pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_host_ref, c_m_n_host_ref,
"Error: Incorrect results!", "Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}), rtol_atol.at(ck_tile::number<0>{}),
...@@ -218,9 +220,9 @@ int run_gemm_example_with_layouts(int argc, ...@@ -218,9 +220,9 @@ int run_gemm_example_with_layouts(int argc,
c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data()); c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
const float max_accumulated_value = const float max_accumulated_value =
*std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end()); *std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType> const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
(K, kbatch, max_accumulated_value); K, kbatch, max_accumulated_value);
pass = ck_tile::check_err(c_m_n_dev_result, pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_gpu_ref, c_m_n_gpu_ref,
"Error: Incorrect results!", "Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}), rtol_atol.at(ck_tile::number<0>{}),
......
...@@ -224,12 +224,12 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 ...@@ -224,12 +224,12 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
}(); }();
// Pad both M and K to be multiples of the block sizes // Pad both M and K to be multiples of the block sizes
const auto a_grid_desc_m_k = transform_tensor_descriptor( const auto a_grid_desc_m_k =
a_grid_desc_mraw_kraw, transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_right_pad_transform(M, MPad - M), make_tuple(make_right_pad_transform(M, MPad - M),
make_right_pad_transform(K, KPad - K)), make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_m_k, a_grid_desc_m_k,
...@@ -322,14 +322,14 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 ...@@ -322,14 +322,14 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
} }
}(); }();
// Pad both N and K to be multiples of the block sizes // Pad both N and K to be multiples of the block sizes
const auto b_grid_desc_n_k = transform_tensor_descriptor( const auto b_grid_desc_n_k =
b_grid_desc_nraw_kraw, transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_right_pad_transform(N, NPad - N), make_tuple(make_right_pad_transform(N, NPad - N),
make_right_pad_transform(K, KPad - K)), make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_n_k, b_grid_desc_n_k,
...@@ -990,7 +990,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 ...@@ -990,7 +990,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
!(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)) !(is_same<tensor_layout::gemm::RowMajor, ALayout>::value))
{ {
if(!(karg.M % MPerBlock == 0)) if(!(karg.M % MPerBlock == 0))
...@@ -1008,7 +1008,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 ...@@ -1008,7 +1008,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)) (is_same<tensor_layout::gemm::RowMajor, BLayout>::value))
{ {
if(!(karg.N % NPerBlock == 0)) if(!(karg.N % NPerBlock == 0))
...@@ -1075,7 +1075,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 ...@@ -1075,7 +1075,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl; << __LINE__ << ", in function: " << __func__ << std::endl;
} }
return false; return false;
} }
...@@ -1093,9 +1092,9 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 ...@@ -1093,9 +1092,9 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< __LINE__ << ", in function: " << __func__ << std::endl; << __LINE__ << ", in function: " << __func__ << std::endl;
} }
std::cout << "Arg N (" << karg.N std::cout << "Arg N (" << karg.N
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector (" << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl; << __LINE__ << ", in function: " << __func__ << std::endl;
return false; return false;
} }
} }
...@@ -1110,7 +1109,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 ...@@ -1110,7 +1109,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl; << __LINE__ << ", in function: " << __func__ << std::endl;
} }
return false; return false;
} }
} }
...@@ -1128,7 +1127,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 ...@@ -1128,7 +1127,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__ << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl; << std::endl;
} }
return false; return false;
} }
} }
...@@ -1145,7 +1144,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 ...@@ -1145,7 +1144,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__ << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl; << std::endl;
} }
return false; return false;
} }
} }
......
...@@ -55,8 +55,7 @@ int profile_gemm_universal_streamk(int argc, char* argv[]) ...@@ -55,8 +55,7 @@ int profile_gemm_universal_streamk(int argc, char* argv[])
printf("arg18: memory for rotating buffer (default 0, size in MB)\n"); printf("arg18: memory for rotating buffer (default 0, size in MB)\n");
exit(1); exit(1);
} }
int M; int M;
int N; int N;
int StrideA; int StrideA;
...@@ -76,7 +75,7 @@ int profile_gemm_universal_streamk(int argc, char* argv[]) ...@@ -76,7 +75,7 @@ int profile_gemm_universal_streamk(int argc, char* argv[])
N = std::stoi(argv[9]); N = std::stoi(argv[9]);
StrideB = std::stoi(argv[12]); StrideB = std::stoi(argv[12]);
} }
const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2])); const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
const auto layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3])); const auto layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
const bool do_verification = std::stoi(argv[4]); const bool do_verification = std::stoi(argv[4]);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment