Commit e00a943e authored by myamlak

Merge remote-tracking branch 'origin/develop' into myamlak/cgemm

parents ffe12e2e 9f71ff48
 add_example_executable(example_conv2d_bwd_data_xdl conv2d_bwd_data_xdl.cpp)
-target_link_libraries(example_conv2d_bwd_data_xdl PRIVATE conv_fwd_util)
+target_link_libraries(example_conv2d_bwd_data_xdl PRIVATE conv_util)
@@ -77,9 +77,9 @@ using ReferenceConvBwdInstance = ck::tensor_operation::host::ReferenceConvBwdDat
 int main(int argc, char* argv[])
 {
-    bool do_verification = 0;
-    int init_method = 0;
-    int nrepeat = 5;
+    bool do_verification = true;
+    int init_method = 1;
+    bool time_kernel = false;

     // Conv shape
     ck::index_t N = 128;
@@ -102,13 +102,13 @@ int main(int argc, char* argv[])
     {
         do_verification = std::stoi(argv[1]);
         init_method = std::stoi(argv[2]);
-        nrepeat = std::stoi(argv[3]);
+        time_kernel = std::stoi(argv[3]);
     }
     else if(argc == 19)
     {
         do_verification = std::stoi(argv[1]);
         init_method = std::stoi(argv[2]);
-        nrepeat = std::stoi(argv[3]);
+        time_kernel = std::stoi(argv[3]);

         N = std::stoi(argv[4]);
         K = std::stoi(argv[5]);
@@ -130,7 +130,7 @@ int main(int argc, char* argv[])
     {
         printf("arg1: verification (0=no, 1=yes)\n");
         printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
-        printf("arg3: run kernel # of times (>1)\n");
+        printf("arg3: time kernel (0=no, 1=yes)\n");
         printf("arg4 to 18: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
                "RightPx\n");
         exit(0);
@@ -214,7 +214,7 @@ int main(int argc, char* argv[])
             "not support this Conv problem");
     }

-    float ave_time = invoker.Run(argument, nrepeat);
+    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

     std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X;
@@ -249,6 +249,10 @@ int main(int argc, char* argv[])
         in_device_buf.FromDevice(in_n_c_hi_wi_device_result.mData.data());

-        ck::utils::check_err(in_n_c_hi_wi_device_result.mData, in_n_c_hi_wi_host_result.mData);
+        return ck::utils::check_err(in_n_c_hi_wi_device_result.mData,
+                                    in_n_c_hi_wi_host_result.mData)
+                   ? 0
+                   : 1;
     }
+
+    return 0;
 }
 add_example_executable(example_conv2d_bwd_weight_xdl conv2d_bwd_weight_xdl.cpp)
-target_link_libraries(example_conv2d_bwd_weight_xdl PRIVATE conv_fwd_util)
+target_link_libraries(example_conv2d_bwd_weight_xdl PRIVATE conv_util)
@@ -82,9 +82,9 @@ using ReferenceConvBwdWeightInstance =
 int main(int argc, char* argv[])
 {
-    bool do_verification = 0;
-    int init_method = 0;
-    int nrepeat = 5;
+    bool do_verification = true;
+    int init_method = 1;
+    bool time_kernel = false;
     int do_log = 0;
     int split_k = 4;
@@ -109,7 +109,7 @@ int main(int argc, char* argv[])
     {
         do_verification = std::stoi(argv[1]);
         init_method = std::stoi(argv[2]);
-        nrepeat = std::stoi(argv[3]);
+        time_kernel = std::stoi(argv[3]);
         do_log = std::stoi(argv[4]);
         split_k = std::stoi(argv[5]);
     }
@@ -117,7 +117,7 @@ int main(int argc, char* argv[])
     {
         do_verification = std::stoi(argv[1]);
         init_method = std::stoi(argv[2]);
-        nrepeat = std::stoi(argv[3]);
+        time_kernel = std::stoi(argv[3]);
         do_log = std::stoi(argv[4]);
         split_k = std::stoi(argv[5]);
@@ -141,7 +141,7 @@ int main(int argc, char* argv[])
     {
         printf("arg1: verification (0=no, 1=yes)\n");
         printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
-        printf("arg3: run kernel # of times (>1)\n");
+        printf("arg3: time kernel (0=no, 1=yes)\n");
         printf("arg4: is show log (0=no, 1=yes)\n");
         printf("arg5: split-k \n");
         printf("arg6 to 19: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
@@ -246,7 +246,7 @@ int main(int argc, char* argv[])
         return 1;
     }

-    float ave_time = invoker.Run(argument, nrepeat);
+    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

     std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X;
@@ -291,6 +291,9 @@ int main(int argc, char* argv[])
             LogRangeAsType<float>(std::cout << "wei_host : ", wei_k_c_y_x_host_result.mData, ",")
                 << std::endl;
         }

-        ck::utils::check_err(wei_k_c_y_x_device_result.mData, wei_k_c_y_x_host_result.mData);
+        return ck::utils::check_err(wei_k_c_y_x_device_result.mData, wei_k_c_y_x_host_result.mData)
+                   ? 0
+                   : 1;
     }
+
+    return 0;
 }
-add_example_executable(example_reduce_blockwise reduce_blockwise.cpp)
+add_example_executable(example_reduce_blockwise reduce_blockwise.cpp -D 16,64,32,960 -v 1 1 10)
@@ -116,10 +116,9 @@ class SimpleAppArgs
     std::vector<size_t> inLengths;
     std::vector<float> scales;

-    bool do_verification = false;
+    bool do_verification = true;
     int init_method = 1;
-    int nrepeat = 5;
+    bool time_kernel = false;

     public:
     void show_usage(const char* cmd)
@@ -135,7 +134,7 @@ class SimpleAppArgs
         std::cout << "Arg1 -- init method (0=no init, 1=single integer value, 2=scope integer "
                      "value, 3=decimal value)"
                   << std::endl;
-        std::cout << "Arg2 -- number of repeats to run the kernel" << std::endl;
+        std::cout << "Arg2 -- time kernel (0=no, 1=yes)" << std::endl;
     };

     int processArgs(int argc, char* argv[])
@@ -182,7 +181,7 @@ class SimpleAppArgs
             throw std::runtime_error("Invalid cmd-line arguments, more arguments are needed!");

         init_method = std::atoi(argv[optind++]);
-        nrepeat = std::atoi(argv[optind]);
+        time_kernel = std::atoi(argv[optind]);

         if(scales.empty())
         {
@@ -352,7 +351,7 @@ int main(int argc, char* argv[])
     auto invoker_ptr = reduce.MakeInvokerPointer();

-    float avg_time = invoker_ptr->Run(argument_ptr.get(), args.nrepeat);
+    float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, args.time_kernel});

     std::size_t num_bytes = invariant_total_length * reduce_total_length * sizeof(InDataType) +
                             invariant_total_length * sizeof(OutDataType);
@@ -362,16 +361,17 @@ int main(int argc, char* argv[])
     std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s, " << reduce_name
               << std::endl;

+    bool pass = true;
     if(args.do_verification)
     {
         out_dev.FromDevice(out.mData.data());
-        ck::utils::check_err(out.mData, out_ref.mData);
+        pass &= ck::utils::check_err(out.mData, out_ref.mData);

         if(NeedIndices)
         {
             out_indices_dev.FromDevice(out_indices.mData.data());
-            ck::utils::check_err(out_indices.mData, out_indices_ref.mData);
-            ;
+            pass &= ck::utils::check_err(out_indices.mData, out_indices_ref.mData);
         };
     };

+    return pass ? 0 : 1;
 }
@@ -149,9 +149,9 @@ int main(int argc, char* argv[])
 {
     using namespace ck::host_reduce;

-    bool do_verification = 0;
-    int init_method = 0;
-    int nrepeat = 5;
+    bool do_verification = true;
+    int init_method = 1;
+    bool time_kernel = false;

     // Pool shape
     ck::index_t N = 128;
@@ -171,13 +171,13 @@ int main(int argc, char* argv[])
     {
         do_verification = std::stoi(argv[1]);
         init_method = std::stoi(argv[2]);
-        nrepeat = std::stoi(argv[3]);
+        time_kernel = std::stoi(argv[3]);
     }
     else if(argc == 16)
     {
         do_verification = std::stoi(argv[1]);
         init_method = std::stoi(argv[2]);
-        nrepeat = std::stoi(argv[3]);
+        time_kernel = std::stoi(argv[3]);

         N = std::stoi(argv[4]);
         C = std::stoi(argv[5]);
@@ -196,7 +196,7 @@ int main(int argc, char* argv[])
     {
         printf("arg1: verification (0=no, 1=yes)\n");
         printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
-        printf("arg3: run kernel # of times (>1)\n");
+        printf("arg3: time kernel (0=no, 1=yes)\n");
         printf("arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, LeftPy, LeftPx, RightPy, "
                "RightPx\n");
         exit(0);
@@ -271,7 +271,7 @@ int main(int argc, char* argv[])
             "not support this problem");
     }

-    float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat);
+    float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});

     std::size_t flop = std::size_t(2) * N * C * Ho * Wo * Y * X;
@@ -285,6 +285,7 @@ int main(int argc, char* argv[])
     std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
               << std::endl;

+    bool pass = true;
     if(do_verification)
     {
         pool_host_verify<InDataType,
@@ -302,14 +303,15 @@ int main(int argc, char* argv[])
         out_device_buf.FromDevice(out_n_c_ho_wo_device.mData.data());

-        ck::utils::check_err(out_n_c_ho_wo_device.mData, out_n_c_ho_wo_host.mData);
+        pass &= ck::utils::check_err(out_n_c_ho_wo_device.mData, out_n_c_ho_wo_host.mData);

         if constexpr(NeedIndices)
         {
             out_indices_device_buf.FromDevice(out_indices_n_c_ho_wo_device.mData.data());
-            // ck::utils::check_err(out_indices_n_c_ho_wo_device.mData,
-            //                      out_indices_n_c_ho_wo_host.mData);;
+            pass &= ck::utils::check_err(out_indices_n_c_ho_wo_device.mData,
+                                         out_indices_n_c_ho_wo_host.mData);
         };
     }

+    return pass ? 0 : 1;
 }
@@ -105,9 +105,9 @@ using ReferenceGemmInstance = ck::tensor_operation::host::
 int main(int argc, char* argv[])
 {
-    bool do_verification = 0;
-    int init_method = 0;
-    int nrepeat = 5;
+    bool do_verification = true;
+    int init_method = 1;
+    bool time_kernel = false;

     // GEMM shape
     ck::index_t M = 3840;
@@ -125,13 +125,13 @@ int main(int argc, char* argv[])
     {
         do_verification = std::stoi(argv[1]);
         init_method = std::stoi(argv[2]);
-        nrepeat = std::stoi(argv[3]);
+        time_kernel = std::stoi(argv[3]);
     }
     else if(argc == 10)
     {
         do_verification = std::stoi(argv[1]);
         init_method = std::stoi(argv[2]);
-        nrepeat = std::stoi(argv[3]);
+        time_kernel = std::stoi(argv[3]);

         M = std::stoi(argv[4]);
         N = std::stoi(argv[5]);
@@ -145,7 +145,7 @@ int main(int argc, char* argv[])
     {
         printf("arg1: verification (0=no, 1=yes)\n");
         printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
-        printf("arg3: run kernel # of times (>1)\n");
+        printf("arg3: time kernel (0=no, 1=yes)\n");
         printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
         exit(0);
     }
@@ -219,7 +219,7 @@ int main(int argc, char* argv[])
             "not support this GEMM problem");
     }

-    float ave_time = invoker.Run(argument, nrepeat);
+    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

     std::size_t flop = std::size_t(2) * M * N * K;
     std::size_t num_btype =
@@ -244,7 +244,7 @@ int main(int argc, char* argv[])
         ref_invoker.Run(ref_argument);

-        ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
+        return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1;
     }

     return 0;
......
@@ -60,21 +60,21 @@ using ReferenceGemmInstance = ck::tensor_operation::host::
 int main(int argc, char* argv[])
 {
-    bool do_verification = 0;
-    int init_method = 0;
-    int nrepeat = 5;
+    bool do_verification = true;
+    int init_method = 1;
+    bool time_kernel = false;

     if(argc == 4)
     {
         do_verification = std::stoi(argv[1]);
         init_method = std::stoi(argv[2]);
-        nrepeat = std::stoi(argv[3]);
+        time_kernel = std::stoi(argv[3]);
     }
     else
     {
         printf("arg1: verification (0=no, 1=yes)\n");
         printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
-        printf("arg3: run kernel # of times (>1)\n");
+        printf("arg3: time kernel (0=no, 1=yes)\n");
         exit(0);
     }
@@ -202,7 +202,7 @@ int main(int argc, char* argv[])
             "not support this GEMM problem");
     }

-    float ave_time = invoker.Run(argument, nrepeat);
+    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

     float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
@@ -211,6 +211,7 @@ int main(int argc, char* argv[])
     std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
               << gemm.GetTypeString() << std::endl;

+    bool pass = true;
     if(do_verification)
     {
         for(std::size_t i = 0; i < gemm_shapes.size(); i++)
@@ -227,9 +228,9 @@ int main(int argc, char* argv[])
                 c_element_op);

             ref_invoker.Run(ref_argument);

-            ck::utils::check_err(c_device_tensors[i].mData, c_host_tensors[i].mData);
+            pass &= ck::utils::check_err(c_device_tensors[i].mData, c_host_tensors[i].mData);
         }
     }

-    return 0;
+    return pass ? 0 : 1;
 }
@@ -4,6 +4,7 @@
 #include <cstdlib>
 #include <stdlib.h>
 #include <half.hpp>
+#include "check_err.hpp"
 #include "config.hpp"
 #include "device.hpp"
 #include "host_tensor.hpp"
@@ -58,9 +59,9 @@ using ReferenceGemmInstance = ck::tensor_operation::host::
 int main(int argc, char* argv[])
 {
-    bool do_verification = 1;
+    bool do_verification = true;
     int init_method = 1;
-    int nrepeat = 5;
+    bool time_kernel = false;

     // GEMM shape
     ck::index_t M = 3840;
@@ -79,13 +80,13 @@ int main(int argc, char* argv[])
     {
         do_verification = std::stoi(argv[1]);
         init_method = std::stoi(argv[2]);
-        nrepeat = std::stoi(argv[3]);
+        time_kernel = std::stoi(argv[3]);
     }
     else if(argc == 10)
     {
         do_verification = std::stoi(argv[1]);
         init_method = std::stoi(argv[2]);
-        nrepeat = std::stoi(argv[3]);
+        time_kernel = std::stoi(argv[3]);

         M = std::stoi(argv[4]);
         N = std::stoi(argv[5]);
@@ -99,7 +100,7 @@ int main(int argc, char* argv[])
     {
         printf("arg1: verification (0=no, 1=yes)\n");
         printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
-        printf("arg3: run kernel # of times (>1)\n");
+        printf("arg3: time kernel (0=no, 1=yes)\n");
         printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
         exit(0);
     }
@@ -192,30 +193,13 @@ int main(int argc, char* argv[])
             "not support this GEMM problem");
     }

-    // warm up
-    invoker.Run(argument);
-
-    // timing
-    float total_time = 0;
-
-    for(int i = 0; i < nrepeat; ++i)
-    {
-        // init D0, D1 to 0
-        d0_device_buf.SetZero();
-        d1_device_buf.SetZero();
-
-        KernelTimer timer;
-        timer.Start();
-
-        invoker.Run(argument);
-
-        timer.End();
-
-        total_time += timer.GetElapsedTime();
-    }
-
-    float ave_time = total_time / nrepeat;
+    // init D0, D1 to 0
+    d0_device_buf.SetZero();
+    d1_device_buf.SetZero();
+
+    // If time_kernel == true, the kernel runs multiple times. This kernel uses atomic add, so the
+    // result will not be correct; set time_kernel = false for correctness tests.
+    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

     std::size_t flop = std::size_t(2) * M * N * K;
     std::size_t num_btype =
@@ -228,6 +212,7 @@ int main(int argc, char* argv[])
     std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
               << gemm.GetTypeString() << std::endl;

+    bool pass = true;
     if(do_verification)
     {
         c_device_buf.FromDevice(c_m_n_device_result.mData.data());
@@ -264,10 +249,19 @@ int main(int argc, char* argv[])
             d1_m_host_result(m) = ck::type_convert<DDataType>(d1_acc);
         }

-        check_error(c_m_n_host_result, c_m_n_device_result);
-        check_error(d0_m_host_result, d0_m_device_result);
-        check_error(d1_m_host_result, d1_m_device_result);
+        pass &= ck::utils::check_err(
+            c_m_n_device_result.mData, c_m_n_host_result.mData, "Error: Incorrect results c");
+        pass &= ck::utils::check_err(d0_m_device_result.mData,
+                                     d0_m_host_result.mData,
+                                     "Error: Incorrect results d0",
+                                     1e-3,
+                                     1e-3);
+        pass &= ck::utils::check_err(d1_m_device_result.mData,
+                                     d1_m_host_result.mData,
+                                     "Error: Incorrect results d1",
+                                     1e-3,
+                                     1e-3);
     }

-    return 0;
+    return pass ? 0 : 1;
 }
 add_example_executable(example_convnd_bwd_data_xdl convnd_bwd_data_xdl.cpp)
-target_link_libraries(example_convnd_bwd_data_xdl PRIVATE conv_fwd_util)
+target_link_libraries(example_convnd_bwd_data_xdl PRIVATE conv_util)
@@ -6,7 +6,7 @@
 #include <half.hpp>

 #include "config.hpp"
-#include "conv_fwd_util.hpp"
+#include "conv_util.hpp"
 #include "print.hpp"
 #include "device.hpp"
 #include "host_tensor.hpp"
@@ -87,7 +87,7 @@ void print_use_msg()
 {
     std::cout << "arg1: verification (0=no, 1=yes)\n"
               << "arg2: initialization (0=no init, 1=random value, 2= init to 1 )\n"
-              << "arg3: run kernel # of times (>1)\n"
+              << "arg3: time kernel (0=no, 1=yes)\n"
               << "arg4: N spatial dimensions (default 2)\n"
               << "Following arguments (depending on number of spatial dims):\n"
               << " N, K, C, \n"
@@ -105,40 +105,40 @@ ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, char* argv[])
     ck::utils::conv::ConvParams params;
     int arg_idx = 5;

-    params.num_dim_spatial = num_dim_spatial;
-    params.N = std::stoi(argv[arg_idx++]);
-    params.K = std::stoi(argv[arg_idx++]);
-    params.C = std::stoi(argv[arg_idx++]);
+    params.num_dim_spatial_ = num_dim_spatial;
+    params.N_ = std::stoi(argv[arg_idx++]);
+    params.K_ = std::stoi(argv[arg_idx++]);
+    params.C_ = std::stoi(argv[arg_idx++]);

-    params.filter_spatial_lengths.resize(num_dim_spatial);
+    params.filter_spatial_lengths_.resize(num_dim_spatial);
     for(int i = 0; i < num_dim_spatial; ++i)
     {
-        params.filter_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
+        params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]);
     }
-    params.input_spatial_lengths.resize(num_dim_spatial);
+    params.input_spatial_lengths_.resize(num_dim_spatial);
     for(int i = 0; i < num_dim_spatial; ++i)
     {
-        params.input_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
+        params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]);
     }
-    params.conv_filter_strides.resize(num_dim_spatial);
+    params.conv_filter_strides_.resize(num_dim_spatial);
     for(int i = 0; i < num_dim_spatial; ++i)
     {
-        params.conv_filter_strides[i] = std::stoi(argv[arg_idx++]);
+        params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]);
     }
-    params.conv_filter_dilations.resize(num_dim_spatial);
+    params.conv_filter_dilations_.resize(num_dim_spatial);
     for(int i = 0; i < num_dim_spatial; ++i)
     {
-        params.conv_filter_dilations[i] = std::stoi(argv[arg_idx++]);
+        params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]);
     }
-    params.input_left_pads.resize(num_dim_spatial);
+    params.input_left_pads_.resize(num_dim_spatial);
     for(int i = 0; i < num_dim_spatial; ++i)
     {
-        params.input_left_pads[i] = std::stoi(argv[arg_idx++]);
+        params.input_left_pads_[i] = std::stoi(argv[arg_idx++]);
     }
-    params.input_right_pads.resize(num_dim_spatial);
+    params.input_right_pads_.resize(num_dim_spatial);
     for(int i = 0; i < num_dim_spatial; ++i)
     {
-        params.input_right_pads[i] = std::stoi(argv[arg_idx++]);
+        params.input_right_pads_[i] = std::stoi(argv[arg_idx++]);
     }

     return params;
@@ -165,25 +165,25 @@ DeviceConvBwdDataBasePtr get_conv_instance(int num_dim_spatial)
 int main(int argc, char* argv[])
 {
-    bool do_verification = 0;
-    int init_method = 0;
-    int nrepeat = 5;
+    bool do_verification = true;
+    int init_method = 1;
+    bool time_kernel = false;
     int num_dim_spatial = 2;

     ck::utils::conv::ConvParams params;
-    params.C = 128;
+    params.C_ = 128;

     if(argc == 4)
     {
         do_verification = std::stoi(argv[1]);
         init_method = std::stoi(argv[2]);
-        nrepeat = std::stoi(argv[3]);
+        time_kernel = std::stoi(argv[3]);
     }
     else if(argc > 4)
     {
         do_verification = std::stoi(argv[1]);
         init_method = std::stoi(argv[2]);
-        nrepeat = std::stoi(argv[3]);
+        time_kernel = std::stoi(argv[3]);
         num_dim_spatial = std::stoi(argv[4]);

         // check args number
         int conv_args = 3 + num_dim_spatial * 6;
@@ -202,21 +202,21 @@ int main(int argc, char* argv[])
         exit(1);
     }

-    std::vector<std::size_t> input_dims{static_cast<std::size_t>(params.N),
-                                        static_cast<std::size_t>(params.C)};
+    std::vector<std::size_t> input_dims{static_cast<std::size_t>(params.N_),
+                                        static_cast<std::size_t>(params.C_)};
     input_dims.insert(std::end(input_dims),
-                      std::begin(params.input_spatial_lengths),
-                      std::end(params.input_spatial_lengths));
+                      std::begin(params.input_spatial_lengths_),
+                      std::end(params.input_spatial_lengths_));

-    std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params.K),
-                                         static_cast<std::size_t>(params.C)};
+    std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params.K_),
+                                         static_cast<std::size_t>(params.C_)};
     filter_dims.insert(std::end(filter_dims),
-                       std::begin(params.filter_spatial_lengths),
-                       std::end(params.filter_spatial_lengths));
+                       std::begin(params.filter_spatial_lengths_),
+                       std::end(params.filter_spatial_lengths_));

     const std::vector<ck::index_t>& output_spatial_lengths = params.GetOutputSpatialLengths();
-    std::vector<std::size_t> output_dims{static_cast<std::size_t>(params.N),
-                                         static_cast<std::size_t>(params.K)};
+    std::vector<std::size_t> output_dims{static_cast<std::size_t>(params.N_),
+                                         static_cast<std::size_t>(params.K_)};
     output_dims.insert(std::end(output_dims),
                        std::begin(output_spatial_lengths),
                        std::end(output_spatial_lengths));
@@ -263,16 +263,16 @@ int main(int argc, char* argv[])
         conv->MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
                                   static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
                                   static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
-                                  params.N,
-                                  params.K,
-                                  params.C,
-                                  params.input_spatial_lengths,
-                                  params.filter_spatial_lengths,
+                                  params.N_,
+                                  params.K_,
+                                  params.C_,
+                                  params.input_spatial_lengths_,
+                                  params.filter_spatial_lengths_,
                                   output_spatial_lengths,
-                                  params.conv_filter_strides,
-                                  params.conv_filter_dilations,
-                                  params.input_left_pads,
-                                  params.input_right_pads,
+                                  params.conv_filter_strides_,
+                                  params.conv_filter_dilations_,
+                                  params.input_left_pads_,
+                                  params.input_right_pads_,
                                   InElementOp{},
                                   WeiElementOp{},
                                   OutElementOp{});
@@ -284,16 +284,16 @@ int main(int argc, char* argv[])
             "not support this Conv problem");
     }

-    float ave_time = invoker->Run(argument.get(), nrepeat);
+    float ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel});

     std::size_t flop = ck::utils::conv::get_flops(
-        params.N, params.C, params.K, params.filter_spatial_lengths, output_spatial_lengths);
+        params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths);
     std::size_t num_btype = ck::utils::conv::get_btype<InDataType, WeiDataType, OutDataType>(
-        params.N,
-        params.C,
-        params.K,
-        params.input_spatial_lengths,
-        params.filter_spatial_lengths,
+        params.N_,
+        params.C_,
+        params.K_,
+        params.input_spatial_lengths_,
+        params.filter_spatial_lengths_,
         output_spatial_lengths);

     float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
@@ -310,10 +310,10 @@ int main(int argc, char* argv[])
         auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi_host_result,
                                                   wei_k_c_y_x,
                                                   out_n_k_ho_wo,
-                                                  params.conv_filter_strides,
-                                                  params.conv_filter_dilations,
-                                                  params.input_left_pads,
-                                                  params.input_right_pads,
+                                                  params.conv_filter_strides_,
+                                                  params.conv_filter_dilations_,
+                                                  params.input_left_pads_,
+                                                  params.input_right_pads_,
                                                   InElementOp{},
                                                   WeiElementOp{},
                                                   OutElementOp{});
@@ -322,7 +322,10 @@ int main(int argc, char* argv[])
         in_device_buf.FromDevice(in_n_c_hi_wi_device_result.mData.data());

-        check_error(in_n_c_hi_wi_host_result, in_n_c_hi_wi_device_result);
+        return ck::utils::check_err(in_n_c_hi_wi_device_result.mData,
+                                    in_n_c_hi_wi_host_result.mData)
+                   ? 0
+                   : 1;
     };

     switch(num_dim_spatial)
@@ -347,4 +350,5 @@ int main(int argc, char* argv[])
             }
         }
     }
+    return 0;
 }
@@ -4,6 +4,7 @@
 #include <cstdlib>
 #include <stdlib.h>
 #include <half.hpp>
+#include "check_err.hpp"
 #include "config.hpp"
 #include "device.hpp"
 #include "host_tensor.hpp"
@@ -57,18 +58,18 @@ using ReferenceBatchedGemmInstance = ck::tensor_operation::host::
 int main(int argc, char* argv[])
 {
-    bool do_verification = 1;
+    bool do_verification = true;
     int init_method = 1;
-    int nrepeat = 5;
+    bool time_kernel = false;

     // GEMM shape
-    ck::index_t M = 3840;
-    ck::index_t N = 4096;
-    ck::index_t K = 4096;
-
-    ck::index_t StrideA = 4096;
-    ck::index_t StrideB = 4096;
-    ck::index_t StrideC = 4096;
+    ck::index_t M = 2048;
+    ck::index_t N = 1920;
+    ck::index_t K = 2048;
+
+    ck::index_t StrideA = 2048;
+    ck::index_t StrideB = 2048;
+    ck::index_t StrideC = 1920;

     ck::index_t BatchCount = 4;
@@ -80,13 +81,13 @@ int main(int argc, char* argv[])
     {
         do_verification = std::stoi(argv[1]);
         init_method = std::stoi(argv[2]);
-        nrepeat = std::stoi(argv[3]);
+        time_kernel = std::stoi(argv[3]);
     }
     else if(argc == 11)
     {
         do_verification = std::stoi(argv[1]);
         init_method = std::stoi(argv[2]);
-        nrepeat = std::stoi(argv[3]);
+        time_kernel = std::stoi(argv[3]);

         M = std::stoi(argv[4]);
         N = std::stoi(argv[5]);
@@ -96,13 +97,13 @@ int main(int argc, char* argv[])
         StrideB = std::stoi(argv[8]);
         StrideC = std::stoi(argv[9]);
-        BatchCount = std::stoi(argv[9]);
+        BatchCount = std::stoi(argv[10]);
     }
     else
     {
         printf("arg1: verification (0=no, 1=yes)\n");
         printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
-        printf("arg3: run kernel # of times (>1)\n");
+        printf("arg3: time kernel (0=no, 1=yes)\n");
         printf("arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC, BatchCount\n");
         exit(0);
     }
@@ -204,30 +205,13 @@ int main(int argc, char* argv[])
             "not support this GEMM problem");
     }

-    // warm up
-    invoker.Run(argument);
-
-    // timing
-    float total_time = 0;
-
-    for(int i = 0; i < nrepeat; ++i)
-    {
-        // init D0, D1 to 0
-        d0_device_buf.SetZero();
-        d1_device_buf.SetZero();
-
-        KernelTimer timer;
-        timer.Start();
-
-        invoker.Run(argument);
-
-        timer.End();
-
-        total_time += timer.GetElapsedTime();
-    }
-
-    float ave_time = total_time / nrepeat;
+    // init D0, D1 to 0
+    d0_device_buf.SetZero();
+    d1_device_buf.SetZero();
+
+    // If time_kernel == true, the kernel runs multiple times. This kernel uses atomic add, so the
+    // result will not be correct; set time_kernel = false for correctness tests.
+    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

     std::size_t flop = std::size_t(2) * BatchCount * M * N * K;
     std::size_t num_btype = sizeof(ADataType) * BatchCount * M * K +
@@ -241,6 +225,7 @@ int main(int argc, char* argv[])
     std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
               << batched_gemm.GetTypeString() << std::endl;

+    bool pass = true;
     if(do_verification)
     {
         c_device_buf.FromDevice(c_g_m_n_device_result.mData.data());
@@ -264,7 +249,7 @@ int main(int argc, char* argv[])
             for(int n = 0; n < N; ++n)
             {
-                float d0_val = ck::type_convert<float>(c_g_m_n_host_result(m, n));
+                float d0_val = ck::type_convert<float>(c_g_m_n_host_result(batch, m, n));
                 float d1_val;

                 d1_element_op(d1_val, d0_val);
@@ -277,10 +262,18 @@ int main(int argc, char* argv[])
             }
         }

-        check_error(c_g_m_n_host_result, c_g_m_n_device_result);
-        check_error(d0_g_m_host_result, d0_g_m_device_result);
-        check_error(d1_g_m_host_result, d1_g_m_device_result);
+        pass &= ck::utils::check_err(c_g_m_n_host_result.mData, c_g_m_n_device_result.mData);
+        pass &= ck::utils::check_err(d0_g_m_device_result.mData,
+                                     d0_g_m_host_result.mData,
+                                     "Error: Incorrect results! D0",
+                                     1e-3,
+                                     1e-3);
+        pass &= ck::utils::check_err(d1_g_m_device_result.mData,
+                                     d1_g_m_host_result.mData,
+                                     "Error: Incorrect results! D1",
+                                     1e-3,
+                                     1e-3);
     }

-    return 0;
+    return pass ? 0 : 1;
 }
@@ -88,9 +88,9 @@ using ReferenceCGemmInstance = ck::tensor_operation::host::
 int main(int argc, char* argv[])
 {
-    bool do_verification = 0;
-    int init_method = 0;
-    int nrepeat = 5;
+    bool do_verification = true;
+    int init_method = 1;
+    bool time_kernel = false;

     // CGEMM shape
     ck::index_t M = 3840;
@@ -105,13 +105,13 @@ int main(int argc, char* argv[])
     {
         do_verification = std::stoi(argv[1]);
         init_method = std::stoi(argv[2]);
-        nrepeat = std::stoi(argv[3]);
+        time_kernel = std::stoi(argv[3]);
     }
     else if(argc == 10)
     {
         do_verification = std::stoi(argv[1]);
         init_method = std::stoi(argv[2]);
-        nrepeat = std::stoi(argv[3]);
+        time_kernel = std::stoi(argv[3]);

         M = std::stoi(argv[4]);
         N = std::stoi(argv[5]);
@@ -223,7 +223,7 @@ int main(int argc, char* argv[])
             "not support this CGEMM problem");
     }

-    float ave_time = invoker.Run(argument, nrepeat);
+    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

     std::size_t flop = std::size_t(8) * M * N * K;
     std::size_t num_btype = std::size_t(2) * sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
......
@@ -19,9 +19,18 @@ include_directories(BEFORE
 add_custom_target(examples)

-function(add_example_executable EXAMPLE_NAME)
+function(add_example_executable EXAMPLE_NAME FILE_NAME)
     message("adding example ${EXAMPLE_NAME}")
-    add_executable(${EXAMPLE_NAME} ${ARGN})
+    add_executable(${EXAMPLE_NAME} ${FILE_NAME})
+    target_link_libraries(${EXAMPLE_NAME} PRIVATE host_tensor)
+    add_test(NAME ${EXAMPLE_NAME} COMMAND $<TARGET_FILE:${EXAMPLE_NAME}> ${ARGN})
+    add_dependencies(examples ${EXAMPLE_NAME})
+    add_dependencies(check ${EXAMPLE_NAME})
+endfunction(add_example_executable EXAMPLE_NAME)
+
+function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
+    message("adding example ${EXAMPLE_NAME}")
+    add_executable(${EXAMPLE_NAME} ${FILE_NAME})
     target_link_libraries(${EXAMPLE_NAME} PRIVATE host_tensor)
     add_dependencies(examples ${EXAMPLE_NAME})
 endfunction(add_example_executable EXAMPLE_NAME)
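With this change, the trailing arguments captured in ${ARGN} become the command line of the test that add_test registers, which is how the reduce example above wires up its self-test: add_example_executable(example_reduce_blockwise reduce_blockwise.cpp -D 16,64,32,960 -v 1 1 10). Examples that should build but not run under ctest use add_example_executable_no_testing instead.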
......
@@ -109,6 +109,10 @@
 // experimental feature: use __builtin_memcpy instead of union to do bit_cast
 #define CK_EXPERIMENTAL_USE_MEMCPY_FOR_BIT_CAST 1

+// experimental feature: optimize for inter-wave scheduling policy
+#define CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING 0
+#define CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS 1
+
 // hack: has underlying assumptions that need to be satisfied, otherwise it's a bug
 // hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be
 // thread-invariant, otherwise it's a bug
......
#pragma once
// "_PACKAGE_" to avoid name contentions: the macros like
// HIP_VERSION_MAJOR are defined in HIP_VERSION.h.
// clang-format off
#define CK_HIP_PACKAGE_VERSION_MAJOR @CK_HIP_VERSION_MAJOR@
#define CK_HIP_PACKAGE_VERSION_MINOR @CK_HIP_VERSION_MINOR@
#define CK_HIP_PACKAGE_VERSION_PATCH @CK_HIP_VERSION_PATCH@
// clang-format on
#ifndef CK_HIP_PACKAGE_VERSION_MAJOR
#define CK_HIP_PACKAGE_VERSION_MAJOR 0
#endif
#ifndef CK_HIP_PACKAGE_VERSION_MINOR
#define CK_HIP_PACKAGE_VERSION_MINOR 0
#endif
#ifndef CK_HIP_PACKAGE_VERSION_PATCH
#define CK_HIP_PACKAGE_VERSION_PATCH 0
#endif
// 3 decimal digits for major and minor, 6 digits for patch number.
// Max number is 999,999,999999 == 0xE8,D4A5,0FFF that fits into 64-bit math.
#if CK_HIP_PACKAGE_VERSION_MAJOR > 999 || CK_HIP_PACKAGE_VERSION_MINOR > 999 || \
    CK_HIP_PACKAGE_VERSION_PATCH > 999999
#error "Too big HIP version number(s)"
#endif
#define CK_HIP_PACKAGE_VERSION_FLAT \
((CK_HIP_PACKAGE_VERSION_MAJOR * 1000ULL + CK_HIP_PACKAGE_VERSION_MINOR) * 1000000 + \
CK_HIP_PACKAGE_VERSION_PATCH)
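Downstream code can then compare the flattened value with an ordinary integer comparison; a minimal sketch, with an illustrative HIP 4.3.0 cutoff that is not taken from this commit:

// Hypothetical guard for a code path that needs at least HIP 4.3.0.
#if CK_HIP_PACKAGE_VERSION_FLAT >= ((4ULL * 1000 + 3) * 1000000 + 0)
// ... HIP >= 4.3.0 only code ...
#endif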
#pragma once
#cmakedefine01 CK_TIME_KERNEL
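After CMake configures this header, CK_TIME_KERNEL expands to 0 or 1; a plausible (assumed, not shown in this commit) use is seeding the time_kernel_ field of the StreamConfig defined below:

// Assumed usage: default the timing switch from the build configuration.
StreamConfig config{nullptr, static_cast<bool>(CK_TIME_KERNEL)};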
#pragma once
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
struct StreamConfig
{
hipStream_t stream_id_ = nullptr;
bool time_kernel_ = false;
};
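This struct is what replaces the old nrepeat parameter of Invoker::Run throughout the examples above; a minimal sketch of the new calling convention, with invoker and argument standing in for any of the example instances:

// Correctness run: execute the kernel once, without timing.
invoker.Run(argument, StreamConfig{nullptr, false});

// Timing run: the invoker times the kernel and returns the average time in ms.
float ave_time = invoker.Run(argument, StreamConfig{nullptr, true});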
@@ -7,6 +7,21 @@
 namespace ck {

+enum struct LoopScheduler
+{
+    Default,
+    Interwave,
+};
+
+constexpr LoopScheduler make_default_loop_scheduler()
+{
+#if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
+    return LoopScheduler::Interwave;
+#else
+    return LoopScheduler::Default;
+#endif // if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
+}
+
 template <index_t BlockSize,
           typename FloatAB,
           typename FloatAcc,
@@ -302,7 +317,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
         });
     }

-    private:
+    protected:
     // A[M0, M1, M2, KPerThread]
     static constexpr auto a_thread_desc_ =
         make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));
@@ -339,4 +354,232 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
     BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
 };
+
+// Note: To enable the inter-wave loop scheduler, the macro
+// CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=1 must be set explicitly, as a few required intrinsics
+// are not yet available in the latest ROCm release. With unsupported compilers, the inter-wave
+// loop scheduler falls back to the default loop scheduler, which is selected by
+// CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=0.
+template <index_t BlockSize,
+          typename FloatAB,
+          typename FloatAcc,
+          typename AK0MK1BlockDesc,
+          typename BK0NK1BlockDesc,
+          index_t MPerXDL,
+          index_t NPerXDL,
+          index_t MRepeat,
+          index_t NRepeat,
+          index_t KPack,
+          index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS>
+struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
+    : public BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
+                                                                 FloatAB,
+                                                                 FloatAcc,
+                                                                 AK0MK1BlockDesc,
+                                                                 BK0NK1BlockDesc,
+                                                                 MPerXDL,
+                                                                 NPerXDL,
+                                                                 MRepeat,
+                                                                 NRepeat,
+                                                                 KPack>
+{
+    using Base = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
+                                                                     FloatAB,
+                                                                     FloatAcc,
+                                                                     AK0MK1BlockDesc,
+                                                                     BK0NK1BlockDesc,
+                                                                     MPerXDL,
+                                                                     NPerXDL,
+                                                                     MRepeat,
+                                                                     NRepeat,
+                                                                     KPack>;
+
+#if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
+    using Base::a_block_desc_m0_m1_m2_k;
+    using Base::A_K1;
+    using Base::b_block_desc_n0_n1_n2_k;
+    using Base::B_K1;
+    using Base::c_thread_buf_;
+    using Base::c_thread_desc_;
+    using Base::CalculateAThreadOriginDataIndex;
+    using Base::CalculateBThreadOriginDataIndex;
+    using Base::I0;
+    using Base::I1;
+    using Base::KPerThread;
+    using Base::xdlops_gemm;
+
+    static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack);
+
+    // 2-wave optimized blockwise gemm
+    template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
+    __device__ void Run(const ABlockBuffer& a_block_buf,
+                        const BBlockBuffer& b_block_buf,
+                        CThreadBuffer& c_thread_buf) const
+    {
+        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
+            a_thread_desc_.GetElementSpaceSize());
+        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
+            b_thread_desc_.GetElementSpaceSize());
+
+        static_for<0, KPerThread, KPerInnerLoop>{}([&](auto k) {
+            static_for<0, MRepeat, 1>{}([&](auto m0) {
+                // read A
+                a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
+                                   make_tuple(m0, I0, I0, k),
+                                   a_block_buf,
+                                   a_thread_desc_,
+                                   make_tuple(m0, I0, I0, I0),
+                                   a_thread_buf);
+            });
+            static_for<0, NRepeat, 1>{}([&](auto n0) {
+                // read B
+                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
+                                   make_tuple(n0, I0, I0, k),
+                                   b_block_buf,
+                                   b_thread_desc_,
+                                   make_tuple(n0, I0, I0, I0),
+                                   b_thread_buf);
+            });
+            __builtin_amdgcn_sched_barrier();
+            // NOTE: Synchronize threads in a workgroup at the start of each MAC cluster, except
+            // the first, as the non-MAC cluster can be shortened a bit with no observable
+            // negative impact. The desired effect is that waves in a workgroup execute MAC in
+            // sync. This avoids out-of-sync waves hijacking MAC resources from other workgroups,
+            // which would reduce the chance of latency hiding while waiting for the rest of the
+            // workgroup at the eventual sync point.
+            if constexpr(k.value != 0 || KPerInnerLoop == KPerThread)
+            {
+                asm volatile("s_barrier" ::);
+                __builtin_amdgcn_sched_barrier();
+            }
+            static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
+                static_for<0, MRepeat, 1>{}([&](auto m0) {
+                    static_for<0, NRepeat, 1>{}([&](auto n0) {
+                        vector_type<FloatAB, KPack> a_thread_vec;
+                        vector_type<FloatAB, KPack> b_thread_vec;
+
+                        static_for<0, KPack, 1>{}([&](auto i) {
+                            a_thread_vec.template AsType<FloatAB>()(i) =
+                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
+                                    make_tuple(m0, 0, 0, k_ + i))>{}];
+                            b_thread_vec.template AsType<FloatAB>()(i) =
+                                b_thread_buf[Number<b_thread_desc_.CalculateOffset(
+                                    make_tuple(n0, 0, 0, k_ + i))>{}];
+                        });
+
+                        using mfma_input_type =
+                            typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
+
+                        constexpr index_t c_offset =
+                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
+
+                        // The block_sync_lds() here performs double duty:
+                        // A) it safeguards against data hazards, because the barrier from
+                        //    blockwise_gemm was moved here;
+                        // B) it reduces VMEM FIFO congestion by applying small delays to
+                        //    different wavefronts.
+                        // It is performed near the end of the MAC cluster to minimize the
+                        // lgkmcnt penalty.
+                        if constexpr(k.value == KPerThread - KPerInnerLoop &&
+                                     k_.value == KPerInnerLoop - KPack && m0.value == MRepeat - 1 &&
+                                     n0.value == NRepeat - 1)
+                        {
+                            __builtin_amdgcn_sched_barrier();
+                            block_sync_lds();
+                            __builtin_amdgcn_sched_barrier();
+                        }
+                        // TODO: insert setprio in a more precise manner, since there may be
+                        // more than one MFMA instruction in a single call
+                        xdlops_gemm.template Run(
+                            a_thread_vec.template AsType<mfma_input_type>(),
+                            b_thread_vec.template AsType<mfma_input_type>(),
+                            c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                        if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
+                        {
+                            __builtin_amdgcn_sched_barrier();
+                            __builtin_amdgcn_s_setprio(1);
+                            __builtin_amdgcn_sched_barrier();
+                        }
+                    });
+                });
+            });
+            __builtin_amdgcn_sched_barrier();
+            __builtin_amdgcn_s_setprio(0);
+            __builtin_amdgcn_sched_barrier();
+        });
+    }
+
+    protected:
+    // A[M0, M1, M2, KPerInnerLoop]
+    static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
+        make_tuple(Number<MRepeat>{}, I1, I1, Number<KPerInnerLoop>{}));
+
+    // B[N0, N1, N2, KPerInnerLoop]
+    static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
+        make_tuple(Number<NRepeat>{}, I1, I1, Number<KPerInnerLoop>{}));
+
+    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
+                                                         FloatAB,
+                                                         decltype(a_block_desc_m0_m1_m2_k),
+                                                         decltype(a_thread_desc_),
+                                                         Sequence<1, 1, 1, KPerInnerLoop>,
+                                                         Sequence<0, 1, 2, 3>,
+                                                         3,
+                                                         A_K1,
+                                                         A_K1>;
+
+    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
+                                                         FloatAB,
+                                                         decltype(b_block_desc_n0_n1_n2_k),
+                                                         decltype(b_thread_desc_),
+                                                         Sequence<1, 1, 1, KPerInnerLoop>,
+                                                         Sequence<0, 1, 2, 3>,
+                                                         3,
+                                                         B_K1,
+                                                         B_K1>;
+
+    AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
+    BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
+#endif // #if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
+};
+
+template <index_t BlockSize,
+          typename FloatAB,
+          typename FloatAcc,
+          typename AK0MK1BlockDesc,
+          typename BK0NK1BlockDesc,
+          index_t MPerXDL,
+          index_t NPerXDL,
+          index_t MRepeat,
+          index_t NRepeat,
+          index_t KPack,
+          LoopScheduler LoopSched>
+constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
+{
+    if constexpr(LoopSched == LoopScheduler::Default)
+    {
+        return BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
+                                                                   FloatAB,
+                                                                   FloatAcc,
+                                                                   AK0MK1BlockDesc,
+                                                                   BK0NK1BlockDesc,
+                                                                   MPerXDL,
+                                                                   NPerXDL,
+                                                                   MRepeat,
+                                                                   NRepeat,
+                                                                   KPack>{};
+    }
+    else if constexpr(LoopSched == LoopScheduler::Interwave)
+    {
+        return BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
+                                                                            FloatAB,
+                                                                            FloatAcc,
+                                                                            AK0MK1BlockDesc,
+                                                                            BK0NK1BlockDesc,
+                                                                            MPerXDL,
+                                                                            NPerXDL,
+                                                                            MRepeat,
+                                                                            NRepeat,
+                                                                            KPack>{};
+    }
+}
+
 } // namespace ck
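A gridwise kernel can then obtain the appropriate blockwise gemm through the selector without naming either concrete type; a hedged sketch in which the template arguments are placeholders for whatever the enclosing kernel defines:

// decltype over the selector call yields the default or inter-wave variant,
// depending on the LoopScheduler argument.
using BlockwiseGemm = decltype(
    ck::BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<BlockSize,
                                                                  FloatAB,
                                                                  FloatAcc,
                                                                  AK0MK1BlockDesc,
                                                                  BK0NK1BlockDesc,
                                                                  MPerXDL,
                                                                  NPerXDL,
                                                                  MRepeat,
                                                                  NRepeat,
                                                                  KPack,
                                                                  ck::make_default_loop_scheduler()>());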