Commit 363b6744 authored by mtgu0705

Add instances for gemm_ab_scale with a 1x128x128 scale block tile

parent 9dac9713
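
For orientation: the diff below renames the declarations, instance lists, and factory registrations from a 128x128x128 scale block tile to 1x128x128, and the profiler now passes ck::Number<1>{}, ck::Number<128>{}, ck::Number<128>{} in that slot, which suggests the three numbers are the per-M, per-N, and per-K scale block sizes. Under that assumption, the sketch below shows, in plain host-side C++, what an A/B block-scaled GEMM of this shape computes; the function and variable names are illustrative only and are not Composable Kernel API.

    // Illustrative reference of a block-scaled GEMM: D = (A .* a_scale) x (B .* b_scale),
    // assuming scale block sizes of 1 (M) x 128 (N) x 128 (K). A is MxK row-major ("mk"),
    // B is NxK ("nk"), D is MxN row-major ("mn"). float stands in for the f8/bf16 types
    // used by the real instances. Not CK code.
    #include <vector>

    void reference_gemm_ab_scale(int M, int N, int K,
                                 const std::vector<float>& a,       // M*K, f8 values as float
                                 const std::vector<float>& a_scale, // (M/1) * (K/128) scales
                                 const std::vector<float>& b,       // N*K, f8 values as float
                                 const std::vector<float>& b_scale, // (N/128) * (K/128) scales
                                 std::vector<float>& d)             // M*N (bf16 in the real op)
    {
        const int ScaleBlockM = 1, ScaleBlockN = 128, ScaleBlockK = 128;
        const int KBlocks = K / ScaleBlockK;

        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
            {
                float acc = 0.f;
                for(int k = 0; k < K; ++k)
                {
                    // each A element is dequantized with the scale of its 1x128 block,
                    // each B element with the scale of its 128x128 block
                    const float sa = a_scale[(m / ScaleBlockM) * KBlocks + k / ScaleBlockK];
                    const float sb = b_scale[(n / ScaleBlockN) * KBlocks + k / ScaleBlockK];
                    acc += (a[m * K + k] * sa) * (b[n * K + k] * sb);
                }
                d[m * N + n] = acc;
            }
    }

A 1x128 granularity on A commonly corresponds to per-row (per-token) scaling of activations, with 128x128 blocks scaling the weight matrix; that reading is an assumption here, not something this commit states.
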
@@ -17,7 +17,7 @@ namespace tensor_operation {
 namespace device {
 namespace instance {
 #if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
-void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instances(
+void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances(
 std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
 Col,
 Tuple<>,
@@ -28,14 +28,14 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_i
 F32,
 Tuple<>,
 BF16,
-128,
+1,
 128,
 128,
 PassThrough,
 PassThrough,
 PassThrough>>>& instances);
-void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instances(
+void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_kpadding_instances(
 std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
 Col,
 Tuple<>,
@@ -46,14 +46,14 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_
 F32,
 Tuple<>,
 BF16,
-128,
+1,
 128,
 128,
 PassThrough,
 PassThrough,
 PassThrough>>>& instances);
-void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instances(
+void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_mnpadding_instances(
 std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
 Col,
 Tuple<>,
@@ -64,14 +64,14 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding
 F32,
 Tuple<>,
 BF16,
-128,
+1,
 128,
 128,
 PassThrough,
 PassThrough,
 PassThrough>>>& instances);
-void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instances(
+void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_mnkpadding_instances(
 std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
 Col,
 Tuple<>,
@@ -82,14 +82,14 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpaddin
 F32,
 Tuple<>,
 BF16,
-128,
+1,
 128,
 128,
 PassThrough,
 PassThrough,
 PassThrough>>>& instances);
-void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instances(
+void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_default_instances(
 std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
 Col,
 Tuple<>,
@@ -100,14 +100,14 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default
 F32,
 Tuple<>,
 BF16,
-128,
+1,
 128,
 128,
 PassThrough,
 PassThrough,
 PassThrough>>>& instances);
-void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instances(
+void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_kpadding_instances(
 std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
 Col,
 Tuple<>,
@@ -118,14 +118,14 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpaddin
 F32,
 Tuple<>,
 BF16,
-128,
+1,
 128,
 128,
 PassThrough,
 PassThrough,
 PassThrough>>>& instances);
-void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instances(
+void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_mnkpadding_instances(
 std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
 Col,
 Tuple<>,
@@ -136,7 +136,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadd
 F32,
 Tuple<>,
 BF16,
-128,
+1,
 128,
 128,
 PassThrough,
@@ -163,7 +163,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
 B1DataType,
 Tuple<>,
 CDataType,
-128,
+1,
 128,
 128,
 ck::tensor_operation::element_wise::PassThrough,
@@ -180,7 +180,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
 B1DataType,
 Tuple<>,
 CDataType,
-128,
+1,
 128,
 128,
 ck::tensor_operation::element_wise::PassThrough,
@@ -198,20 +198,20 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
 if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
 is_same_v<CLayout, Row>)
 {
-add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instances(
+add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances(
 op_ptrs);
-add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instances(
+add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_kpadding_instances(
 op_ptrs);
-add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instances(
+add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_mnpadding_instances(
 op_ptrs);
-add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instances(
+add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_mnkpadding_instances(
 op_ptrs);
-add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instances(
+add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_default_instances(
 op_ptrs);
-add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instances(
+add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_kpadding_instances(
 op_ptrs);
-add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instances(
+add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_mnkpadding_instances(
 op_ptrs);
 }
 }

@@ -8,7 +8,7 @@ namespace tensor_operation {
 namespace device {
 namespace instance {
-void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instances(
+void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances(
 std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
 Col,
 Tuple<>,
@@ -19,7 +19,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_i
 F32,
 Tuple<>,
 BF16,
-128,
+1,
 128,
 128,
 PassThrough,
@@ -28,7 +28,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_i
 {
 add_device_operation_instances(
 instances,
-device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_instances<GemmDefault>{});
+device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_instances<GemmDefault>{});
 }
 } // namespace instance

@@ -8,7 +8,7 @@ namespace tensor_operation {
 namespace device {
 namespace instance {
-void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instances(
+void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_kpadding_instances(
 std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
 Col,
 Tuple<>,
@@ -19,7 +19,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_
 F32,
 Tuple<>,
 BF16,
-128,
+1,
 128,
 128,
 PassThrough,
@@ -28,7 +28,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_
 {
 add_device_operation_instances(
 instances,
-device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_instances<GemmKPadding>{});
+device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_instances<GemmKPadding>{});
 }
 } // namespace instance

@@ -8,7 +8,7 @@ namespace tensor_operation {
 namespace device {
 namespace instance {
-void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instances(
+void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_mnkpadding_instances(
 std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
 Col,
 Tuple<>,
@@ -19,7 +19,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpaddin
 F32,
 Tuple<>,
 BF16,
-128,
+1,
 128,
 128,
 PassThrough,
@@ -28,7 +28,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpaddin
 {
 add_device_operation_instances(
 instances,
-device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_instances<GemmMNKPadding>{});
+device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_instances<GemmMNKPadding>{});
 }
 } // namespace instance

@@ -8,7 +8,7 @@ namespace tensor_operation {
 namespace device {
 namespace instance {
-void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instances(
+void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_mnpadding_instances(
 std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
 Col,
 Tuple<>,
@@ -19,7 +19,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding
 F32,
 Tuple<>,
 BF16,
-128,
+1,
 128,
 128,
 PassThrough,
@@ -28,7 +28,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding
 {
 add_device_operation_instances(
 instances,
-device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_instances<GemmMNPadding>{});
+device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_instances<GemmMNPadding>{});
 }
 } // namespace instance

@@ -8,7 +8,7 @@ namespace tensor_operation {
 namespace device {
 namespace instance {
-void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instances(
+void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_default_instances(
 std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
 Col,
 Tuple<>,
@@ -19,7 +19,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default
 F32,
 Tuple<>,
 BF16,
-128,
+1,
 128,
 128,
 PassThrough,
@@ -28,7 +28,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default
 {
 add_device_operation_instances(
 instances,
-device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_instances<Intrawave,
+device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_instances<Intrawave,
 GemmDefault>{});
 }

@@ -8,7 +8,7 @@ namespace tensor_operation {
 namespace device {
 namespace instance {
-void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instances(
+void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_kpadding_instances(
 std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
 Col,
 Tuple<>,
@@ -19,7 +19,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpaddin
 F32,
 Tuple<>,
 BF16,
-128,
+1,
 128,
 128,
 PassThrough,
@@ -28,7 +28,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpaddin
 {
 add_device_operation_instances(
 instances,
-device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_instances<Intrawave,
+device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_instances<Intrawave,
 GemmKPadding>{});
 }

@@ -8,7 +8,7 @@ namespace tensor_operation {
 namespace device {
 namespace instance {
-void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instances(
+void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_mnkpadding_instances(
 std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
 Col,
 Tuple<>,
@@ -19,7 +19,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadd
 F32,
 Tuple<>,
 BF16,
-128,
+1,
 128,
 128,
 PassThrough,
@@ -28,7 +28,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadd
 {
 add_device_operation_instances(
 instances,
-device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_instances<Intrawave,
+device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_instances<Intrawave,
 GemmMNKPadding>{});
 }

@@ -32,8 +32,10 @@ enum struct GemmDataType
 enum struct ScaleBlockTile
 {
 Tile_128_128_128, // 0
+Tile_1_128_128, // 1
 };
 #define OP_NAME "gemm_ab_scale"
 #define OP_DESC "GEMM_AB_Scale"
@@ -154,8 +156,25 @@ int profile_gemm_ab_scale(int argc, char* argv[])
 return pass ? 0 : 1;
 };
+// if(data_type == GemmDataType::F8_F8_BF16 && layout == GemmMatrixLayout::MK_NK_MN &&
+// scale_block_tile == ScaleBlockTile::Tile_128_128_128)
+// {
+// return profile(F8{},
+// F32{},
+// F8{},
+// F32{},
+// F8{},
+// F32{},
+// BF16{},
+// ck::Number<128>{},
+// ck::Number<128>{},
+// ck::Number<128>{},
+// Row{},
+// Col{},
+// Row{});
+// }
 if(data_type == GemmDataType::F8_F8_BF16 && layout == GemmMatrixLayout::MK_NK_MN &&
-scale_block_tile == ScaleBlockTile::Tile_128_128_128)
+scale_block_tile == ScaleBlockTile::Tile_1_128_128)
 {
 return profile(F8{},
 F32{},
@@ -164,7 +183,7 @@ int profile_gemm_ab_scale(int argc, char* argv[])
 F8{},
 F32{},
 BF16{},
-ck::Number<128>{},
+ck::Number<1>{},
 ck::Number<128>{},
 ck::Number<128>{},
 Row{},

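Once these add_device_..._1_128_128_* functions are registered through the factory specialization shown in the header diff, client code obtains the new instances the same way as other CK operations. A rough sketch, where DeviceOp stands for the fully specialized DeviceGemmMultipleD_ABScale type (layouts, data types, the 1/128/128 scale block tile, PassThrough element-wise ops) and the relevant Composable Kernel instance-factory headers are assumed to be on the include path:

    #include <iostream>
    // plus the Composable Kernel headers that declare DeviceGemmMultipleD_ABScale and
    // the DeviceOperationInstanceFactory specialization for this operation (paths omitted).

    template <typename DeviceOp> // e.g. the DeviceGemmMultipleD_ABScale specialization above
    void list_gemm_ab_scale_instances()
    {
        // GetInstances() invokes the add_device_..._1_128_128_* registration functions
        // from the factory specialization and returns the collected op_ptrs.
        const auto op_ptrs = ck::tensor_operation::device::instance::
            DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

        std::cout << op_ptrs.size() << " instances:" << std::endl;
        for(const auto& op : op_ptrs)
            std::cout << "  " << op->GetTypeString() << std::endl;
    }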