Commit e2127d7a authored by coderfeli's avatar coderfeli
Browse files

impl fp16 in ckprofiler

parent 04f09f08
...@@ -16,7 +16,8 @@ namespace ck { ...@@ -16,7 +16,8 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8)) #ifdef CK_ENABLE_FP8
#ifdef CK_ENABLE_BF16
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances_part1( void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances_part1(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col, Col,
...@@ -174,6 +175,165 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_i ...@@ -174,6 +175,165 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_i
MultiplyMultiply>>>& instances); MultiplyMultiply>>>& instances);
#endif #endif
#ifdef CK_ENABLE_FP16
void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_default_instances_part1(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_kpadding_instances_part1(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_default_instances_part2(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_kpadding_instances_part2(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_default_instances_part1(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_kpadding_instances_part1(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_default_instances_part2(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_kpadding_instances_part2(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v2_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
#endif
#endif
#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_INT8)) #if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_INT8))
void add_device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_default_instances( void add_device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
...@@ -292,7 +452,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu ...@@ -292,7 +452,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
{ {
std::vector<std::unique_ptr<DeviceOp>> op_ptrs; std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8)) #ifdef CK_ENABLE_FP8
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, f8_t> && if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, f8_t> &&
is_same_v<CDataType, bhalf_t>) is_same_v<CDataType, bhalf_t>)
{ {
...@@ -329,6 +490,44 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu ...@@ -329,6 +490,44 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
} }
} }
#endif #endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, f8_t> &&
is_same_v<CDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_default_instances_part1(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_kpadding_instances_part1(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_default_instances_part2(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_kpadding_instances_part2(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_default_instances_part1(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_kpadding_instances_part1(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_default_instances_part2(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_kpadding_instances_part2(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v1_default_instances(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v1_kpadding_instances(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v2_default_instances(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v2_kpadding_instances(
op_ptrs);
}
}
#endif
#endif
#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_INT8)) #if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_INT8))
if constexpr(is_same_v<ADataType, int8_t> && is_same_v<BDataType, int8_t> && if constexpr(is_same_v<ADataType, int8_t> && is_same_v<BDataType, int8_t> &&
is_same_v<CDataType, half_t>) is_same_v<CDataType, half_t>)
......
...@@ -15,6 +15,19 @@ list(APPEND GEMM_MULTIPLY_MULTIPLY_INSTANCES ...@@ -15,6 +15,19 @@ list(APPEND GEMM_MULTIPLY_MULTIPLY_INSTANCES
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_default_instance_part1.cpp
device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_kpadding_instance_part1.cpp
device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_default_instance_part2.cpp
device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_kpadding_instance_part2.cpp
device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_default_instance_part1.cpp
device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_kpadding_instance_part1.cpp
device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_default_instance_part2.cpp
device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_kpadding_instance_part2.cpp
device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v1_default_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v2_default_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp
device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_default_instance.cpp device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_default_instance.cpp
device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_kpadding_instance.cpp device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_kpadding_instance.cpp
device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v1_default_instance.cpp device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v1_default_instance.cpp
...@@ -32,6 +45,15 @@ set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_ ...@@ -32,6 +45,15 @@ set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_
set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_default_instance_part2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_default_instance_part2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_kpadding_instance_part2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_kpadding_instance_part2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_default_instance_part1.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_kpadding_instance_part1.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_default_instance_part2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_kpadding_instance_part2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_default_instance_part1.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_kpadding_instance_part1.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_default_instance_part2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_kpadding_instance_part2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
......
...@@ -28,6 +28,7 @@ enum struct GemmDataType ...@@ -28,6 +28,7 @@ enum struct GemmDataType
F16_F16_F16_F8, // 6 F16_F16_F16_F8, // 6
F8_F8_BF16, // 7 F8_F8_BF16, // 7
INT8_INT8_F16, // 8 INT8_INT8_F16, // 8
F8_F8_F16, // 9
}; };
#define OP_NAME "gemm_multiply_multiply" #define OP_NAME "gemm_multiply_multiply"
...@@ -40,7 +41,7 @@ int profile_gemm_multiply_multiply(int argc, char* argv[]) ...@@ -40,7 +41,7 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: " printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: "
"f16->f8; 7: f8->bf16, " "f16->f8; 7: f8->bf16, "
"comp f8; 8: int8->f16)\n"); "comp f8; 8: int8->f16; 9: f8->f16, comp f8;)\n");
printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); printf(" 2: A[k, m] * B[k, n] = C[m, n];\n");
...@@ -166,6 +167,11 @@ int profile_gemm_multiply_multiply(int argc, char* argv[]) ...@@ -166,6 +167,11 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
return profile( return profile(
F8{}, F8{}, F8{}, F32{}, F32{}, F32{}, BF16{}, Row{}, Col{}, Row{}, Col{}, Row{}); F8{}, F8{}, F8{}, F32{}, F32{}, F32{}, BF16{}, Row{}, Col{}, Row{}, Col{}, Row{});
} }
else if(data_type == GemmDataType::F8_F8_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
return profile(
F8{}, F8{}, F8{}, F32{}, F32{}, F32{}, F16{}, Row{}, Col{}, Row{}, Col{}, Row{});
}
else if(data_type == GemmDataType::INT8_INT8_F16 && layout == GemmMatrixLayout::MK_NK_MN) else if(data_type == GemmDataType::INT8_INT8_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{ {
return profile( return profile(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment