Commit 32380a27 authored by Jing Zhang's avatar Jing Zhang
Browse files

add ckProfiler

parent 1675a341
......@@ -18,7 +18,7 @@ namespace device {
namespace instance {
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
......@@ -31,7 +31,7 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_inst
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
......@@ -44,7 +44,7 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_ins
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
......@@ -57,7 +57,7 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_in
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
......@@ -70,7 +70,7 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_i
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
......@@ -83,7 +83,7 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_in
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
......@@ -96,7 +96,7 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_i
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
......@@ -109,7 +109,7 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
......@@ -122,7 +122,7 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_in
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
......@@ -135,7 +135,7 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_i
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
......@@ -154,7 +154,7 @@ template <typename ADataType,
typename ALayout,
typename BLayout,
typename CLayout>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleD<
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleDSplitK<
ALayout,
BLayout,
Tuple<Row, Col>,
......@@ -167,7 +167,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::MultiplyMultiply>>
{
using DeviceOp = DeviceGemmMultipleD<ALayout,
using DeviceOp = DeviceGemmMultipleDSplitK<ALayout,
BLayout,
Tuple<Row, Col>,
CLayout,
......
......@@ -9,7 +9,7 @@ namespace device {
namespace instance {
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
......
......@@ -9,7 +9,7 @@ namespace device {
namespace instance {
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
......
......@@ -9,7 +9,7 @@ namespace device {
namespace instance {
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
......
......@@ -9,7 +9,7 @@ namespace device {
namespace instance {
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
......
......@@ -9,7 +9,7 @@ namespace device {
namespace instance {
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
......
......@@ -9,7 +9,7 @@ namespace device {
namespace instance {
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
......
......@@ -9,7 +9,7 @@ namespace device {
namespace instance {
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
......
......@@ -9,7 +9,7 @@ namespace device {
namespace instance {
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
......
......@@ -9,7 +9,7 @@ namespace device {
namespace instance {
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
......
......@@ -9,7 +9,7 @@ namespace device {
namespace instance {
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
......
......@@ -48,6 +48,7 @@ bool profile_gemm_multiply_multiply_impl(int do_verification,
int StrideD0,
int StrideD1,
int StrideE,
int KBatch,
int n_warmup,
int n_iter,
uint64_t rotating = 0)
......@@ -129,7 +130,7 @@ bool profile_gemm_multiply_multiply_impl(int do_verification,
d1_device_buf.ToDevice(d1_m_n.mData.data());
using DeviceOp =
ck::tensor_operation::device::DeviceGemmMultipleD<ALayout,
ck::tensor_operation::device::DeviceGemmMultipleDSplitK<ALayout,
BLayout,
ck::Tuple<D0Layout, D1Layout>,
ELayout,
......@@ -199,6 +200,7 @@ bool profile_gemm_multiply_multiply_impl(int do_verification,
StrideB,
std::array<ck::index_t, 2>{StrideD0, StrideD1},
StrideE,
KBatch,
a_element_op,
b_element_op,
c_element_op);
......
......@@ -34,7 +34,7 @@ enum struct GemmDataType
int profile_gemm_multiply_multiply(int argc, char* argv[])
{
if(argc != 16 && argc != 19)
if(argc != 16 && argc != 20)
{
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: "
......@@ -50,9 +50,10 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
printf("arg7: time kernel (0=no, 1=yes)\n");
printf("arg8 to 15: M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE\n");
printf("optional:\n");
printf("arg16: number of warm-up cycles (default 1)\n");
printf("arg17: number of iterations (default 10)\n");
printf("arg18: memory for rotating buffer (default 0, size in MB)\n");
printf("arg16: number of kbatch (default 1)\n");
printf("arg17: number of warm-up cycles (default 1)\n");
printf("arg18: number of iterations (default 10)\n");
printf("arg19: memory for rotating buffer (default 0, size in MB)\n");
exit(1);
}
......@@ -76,11 +77,13 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
int n_warmup = 1;
int n_iter = 10;
uint64_t rotating = 0;
if(argc == 18)
int KBatch = 1;
if(argc == 20)
{
n_warmup = std::stoi(argv[16]);
n_iter = std::stoi(argv[17]);
rotating = std::stoull(argv[18]) * 1024 * 1024;
KBatch = std::stoi(argv[16]);
n_warmup = std::stoi(argv[17]);
n_iter = std::stoi(argv[18]);
rotating = std::stoull(argv[19]) * 1024 * 1024;
}
using F32 = float;
......@@ -146,6 +149,7 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
(StrideD0 < 0) ? DefaultStrideD0 : StrideD0,
(StrideD1 < 0) ? DefaultStrideD1 : StrideD1,
(StrideE < 0) ? DefaultStrideE : StrideE,
KBatch,
n_warmup,
n_iter,
rotating);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment