Commit 363b6744 authored by mtgu0705's avatar mtgu0705
Browse files

add instance for gemm_ab_scale

parent 9dac9713
......@@ -17,7 +17,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instances(
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Col,
Tuple<>,
......@@ -28,14 +28,14 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_i
F32,
Tuple<>,
BF16,
128,
1,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instances(
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Col,
Tuple<>,
......@@ -46,14 +46,14 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_
F32,
Tuple<>,
BF16,
128,
1,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instances(
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_mnpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Col,
Tuple<>,
......@@ -64,14 +64,14 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding
F32,
Tuple<>,
BF16,
128,
1,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instances(
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Col,
Tuple<>,
......@@ -82,14 +82,14 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpaddin
F32,
Tuple<>,
BF16,
128,
1,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instances(
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Col,
Tuple<>,
......@@ -100,14 +100,14 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default
F32,
Tuple<>,
BF16,
128,
1,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instances(
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Col,
Tuple<>,
......@@ -118,14 +118,14 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpaddin
F32,
Tuple<>,
BF16,
128,
1,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instances(
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Col,
Tuple<>,
......@@ -136,7 +136,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadd
F32,
Tuple<>,
BF16,
128,
1,
128,
128,
PassThrough,
......@@ -163,7 +163,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
B1DataType,
Tuple<>,
CDataType,
128,
1,
128,
128,
ck::tensor_operation::element_wise::PassThrough,
......@@ -180,7 +180,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
B1DataType,
Tuple<>,
CDataType,
128,
1,
128,
128,
ck::tensor_operation::element_wise::PassThrough,
......@@ -198,20 +198,20 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instances(
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances(
op_ptrs);
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instances(
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_kpadding_instances(
op_ptrs);
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instances(
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_mnpadding_instances(
op_ptrs);
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instances(
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_mnkpadding_instances(
op_ptrs);
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instances(
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_default_instances(
op_ptrs);
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instances(
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_kpadding_instances(
op_ptrs);
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instances(
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_mnkpadding_instances(
op_ptrs);
}
}
......
......@@ -8,7 +8,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instances(
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Col,
Tuple<>,
......@@ -19,7 +19,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_i
F32,
Tuple<>,
BF16,
128,
1,
128,
128,
PassThrough,
......@@ -28,7 +28,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_i
{
add_device_operation_instances(
instances,
device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_instances<GemmDefault>{});
device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_instances<GemmDefault>{});
}
} // namespace instance
......
......@@ -8,7 +8,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instances(
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Col,
Tuple<>,
......@@ -19,7 +19,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_
F32,
Tuple<>,
BF16,
128,
1,
128,
128,
PassThrough,
......@@ -28,7 +28,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_
{
add_device_operation_instances(
instances,
device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_instances<GemmKPadding>{});
device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_instances<GemmKPadding>{});
}
} // namespace instance
......
......@@ -8,7 +8,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instances(
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Col,
Tuple<>,
......@@ -19,7 +19,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpaddin
F32,
Tuple<>,
BF16,
128,
1,
128,
128,
PassThrough,
......@@ -28,7 +28,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpaddin
{
add_device_operation_instances(
instances,
device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_instances<GemmMNKPadding>{});
device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_instances<GemmMNKPadding>{});
}
} // namespace instance
......
......@@ -8,7 +8,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instances(
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_mnpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Col,
Tuple<>,
......@@ -19,7 +19,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding
F32,
Tuple<>,
BF16,
128,
1,
128,
128,
PassThrough,
......@@ -28,7 +28,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding
{
add_device_operation_instances(
instances,
device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_instances<GemmMNPadding>{});
device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_instances<GemmMNPadding>{});
}
} // namespace instance
......
......@@ -8,7 +8,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instances(
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Col,
Tuple<>,
......@@ -19,7 +19,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default
F32,
Tuple<>,
BF16,
128,
1,
128,
128,
PassThrough,
......@@ -28,7 +28,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default
{
add_device_operation_instances(
instances,
device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_instances<Intrawave,
device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_instances<Intrawave,
GemmDefault>{});
}
......
......@@ -8,7 +8,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instances(
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Col,
Tuple<>,
......@@ -19,7 +19,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpaddin
F32,
Tuple<>,
BF16,
128,
1,
128,
128,
PassThrough,
......@@ -28,7 +28,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpaddin
{
add_device_operation_instances(
instances,
device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_instances<Intrawave,
device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_instances<Intrawave,
GemmKPadding>{});
}
......
......@@ -8,7 +8,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instances(
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Col,
Tuple<>,
......@@ -19,7 +19,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadd
F32,
Tuple<>,
BF16,
128,
1,
128,
128,
PassThrough,
......@@ -28,7 +28,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadd
{
add_device_operation_instances(
instances,
device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_instances<Intrawave,
device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_instances<Intrawave,
GemmMNKPadding>{});
}
......
This diff is collapsed.
......@@ -32,8 +32,10 @@ enum struct GemmDataType
enum struct ScaleBlockTile
{
Tile_128_128_128, // 0
Tile_1_128_128, // 1
};
#define OP_NAME "gemm_ab_scale"
#define OP_DESC "GEMM_AB_Scale"
......@@ -154,8 +156,25 @@ int profile_gemm_ab_scale(int argc, char* argv[])
return pass ? 0 : 1;
};
// if(data_type == GemmDataType::F8_F8_BF16 && layout == GemmMatrixLayout::MK_NK_MN &&
// scale_block_tile == ScaleBlockTile::Tile_128_128_128)
// {
// return profile(F8{},
// F32{},
// F8{},
// F32{},
// F8{},
// F32{},
// BF16{},
// ck::Number<128>{},
// ck::Number<128>{},
// ck::Number<128>{},
// Row{},
// Col{},
// Row{});
// }
if(data_type == GemmDataType::F8_F8_BF16 && layout == GemmMatrixLayout::MK_NK_MN &&
scale_block_tile == ScaleBlockTile::Tile_128_128_128)
scale_block_tile == ScaleBlockTile::Tile_1_128_128)
{
return profile(F8{},
F32{},
......@@ -164,7 +183,7 @@ int profile_gemm_ab_scale(int argc, char* argv[])
F8{},
F32{},
BF16{},
ck::Number<128>{},
ck::Number<1>{},
ck::Number<128>{},
ck::Number<128>{},
Row{},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment