Commit 8c9be1df authored by root's avatar root
Browse files

clang

parent cce106f6
...@@ -240,183 +240,399 @@ void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpaddin ...@@ -240,183 +240,399 @@ void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpaddin
#ifdef CK_ENABLE_BF16 #ifdef CK_ENABLE_BF16
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_default_instances( void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_default_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
DeviceGemm_Streamk_V2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& Row,
instances); Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instances( void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
DeviceGemm_Streamk_V2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& Row,
instances); Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instances( void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
DeviceGemm_Streamk_V2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& Row,
instances); Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instances( void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
DeviceGemm_Streamk_V2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& Row,
instances); Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v1_default_instances( void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v1_default_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
DeviceGemm_Streamk_V2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& Row,
instances); Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v1_kpadding_instances( void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v1_kpadding_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
DeviceGemm_Streamk_V2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& Row,
instances); Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v1_mnkpadding_instances( void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
DeviceGemm_Streamk_V2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& Row,
instances); Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v2_default_instances( void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v2_default_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
DeviceGemm_Streamk_V2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& Row,
instances); Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_instances( void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
DeviceGemm_Streamk_V2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& Row,
instances); Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instances( void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
DeviceGemm_Streamk_V2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& Row,
instances); Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_default_instances( void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_default_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
DeviceGemm_Streamk_V2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& Col,
instances); Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instances( void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
DeviceGemm_Streamk_V2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& Col,
instances); Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instances( void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
DeviceGemm_Streamk_V2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& Col,
instances); Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instances( void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
DeviceGemm_Streamk_V2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& Col,
instances); Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instances( void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
DeviceGemm_Streamk_V2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& Col,
instances); Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instances( void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
DeviceGemm_Streamk_V2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& Col,
instances); Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_default_instances( void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_default_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col,
DeviceGemm_Streamk_V2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& Row,
instances); Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_kpadding_instances( void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_kpadding_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col,
DeviceGemm_Streamk_V2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& Row,
instances); Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_mnpadding_instances( void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col,
DeviceGemm_Streamk_V2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& Row,
instances); Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_mnkpadding_instances( void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col,
DeviceGemm_Streamk_V2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& Row,
instances); Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v1_default_instances( void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v1_default_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col,
DeviceGemm_Streamk_V2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& Row,
instances); Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v1_kpadding_instances( void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v1_kpadding_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col,
DeviceGemm_Streamk_V2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& Row,
instances); Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v1_mnkpadding_instances( void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col,
DeviceGemm_Streamk_V2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& Row,
instances); Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v2_default_instances( void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v2_default_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col,
DeviceGemm_Streamk_V2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& Row,
instances); Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v2_kpadding_instances( void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v2_kpadding_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col,
DeviceGemm_Streamk_V2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& Row,
instances); Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v2_mnkpadding_instances( void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col,
DeviceGemm_Streamk_V2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& Row,
instances); Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_comp_default_instances( void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_comp_default_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col,
DeviceGemm_Streamk_V2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& Col,
instances); Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_comp_kpadding_instances( void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_comp_kpadding_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col,
DeviceGemm_Streamk_V2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& Col,
instances); Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_comp_mpadding_instances( void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_comp_mpadding_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col,
DeviceGemm_Streamk_V2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& Col,
instances); Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_comp_mkpadding_instances( void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_comp_mkpadding_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col,
DeviceGemm_Streamk_V2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& Col,
instances); Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v1_default_instances( void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v1_default_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col,
DeviceGemm_Streamk_V2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& Col,
instances); Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v1_kpadding_instances( void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v1_kpadding_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col,
DeviceGemm_Streamk_V2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& Col,
instances); Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v1_mkpadding_instances( void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v1_mkpadding_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col,
DeviceGemm_Streamk_V2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& Col,
instances); Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v2_default_instances( void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v2_default_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col,
DeviceGemm_Streamk_V2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& Col,
instances); Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v2_kpadding_instances( void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v2_kpadding_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col,
DeviceGemm_Streamk_V2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& Col,
instances); Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v2_mkpadding_instances( void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v2_mkpadding_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col,
DeviceGemm_Streamk_V2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& Col,
instances); Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif #endif
#if(defined(CK_ENABLE_FP8)) #if(defined(CK_ENABLE_FP8))
...@@ -708,7 +924,6 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemm_S ...@@ -708,7 +924,6 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemm_S
} }
#endif #endif
#ifdef CK_ENABLE_BF16 #ifdef CK_ENABLE_BF16
if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, bhalf_t> && if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, bhalf_t> &&
is_same_v<CDataType, bhalf_t>) is_same_v<CDataType, bhalf_t>)
...@@ -812,7 +1027,6 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemm_S ...@@ -812,7 +1027,6 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemm_S
} }
#endif #endif
#if(defined(CK_ENABLE_FP8)) #if(defined(CK_ENABLE_FP8))
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> && if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> &&
is_same_v<CDataType, half_t>) is_same_v<CDataType, half_t>)
......
...@@ -9,9 +9,15 @@ namespace device { ...@@ -9,9 +9,15 @@ namespace device {
namespace instance { namespace instance {
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_default_instances( void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_default_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col,
DeviceGemm_Streamk_V2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& Row,
instances) Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, instances,
......
...@@ -9,12 +9,19 @@ namespace device { ...@@ -9,12 +9,19 @@ namespace device {
namespace instance { namespace instance {
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_kpadding_instances( void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_kpadding_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col,
DeviceGemm_Streamk_V2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough,PassThrough>>>& Row,
instances) Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_instances<GemmKPadding>{}); instances,
device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_instances<GemmKPadding>{});
} }
} // namespace instance } // namespace instance
......
...@@ -9,12 +9,19 @@ namespace device { ...@@ -9,12 +9,19 @@ namespace device {
namespace instance { namespace instance {
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_mnkpadding_instances( void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col,
DeviceGemm_Streamk_V2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& Row,
instances) Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_instances<GemmMNKPadding>{}); instances,
device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_instances<GemmMNKPadding>{});
} }
} // namespace instance } // namespace instance
......
...@@ -9,9 +9,15 @@ namespace device { ...@@ -9,9 +9,15 @@ namespace device {
namespace instance { namespace instance {
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_mnpadding_instances( void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col,
DeviceGemm_Streamk_V2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough,PassThrough>>>& Row,
instances) Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, instances,
......
...@@ -9,9 +9,15 @@ namespace device { ...@@ -9,9 +9,15 @@ namespace device {
namespace instance { namespace instance {
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_instances( void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
DeviceGemm_Streamk_V2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& Row,
instances) Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, instances,
......
...@@ -9,9 +9,15 @@ namespace device { ...@@ -9,9 +9,15 @@ namespace device {
namespace instance { namespace instance {
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instances( void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
DeviceGemm_Streamk_V2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& Row,
instances) Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, instances,
......
...@@ -91,8 +91,6 @@ int profile_gemm_universal_streamk(int argc, char* argv[]) ...@@ -91,8 +91,6 @@ int profile_gemm_universal_streamk(int argc, char* argv[])
using F8 = ck::f8_t; using F8 = ck::f8_t;
#endif #endif
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment