Commit 3782ed3b authored by Bartlomiej Kocot's avatar Bartlomiej Kocot
Browse files

reproducer

parent 0f48e38a
...@@ -19,44 +19,44 @@ namespace tensor_operation { ...@@ -19,44 +19,44 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instances( // void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Col, // std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Col,
Row, // Row,
Empty_Tuple, // Empty_Tuple,
Row, // Row,
F16, // F16,
F16, // F16,
Empty_Tuple, // Empty_Tuple,
F16, // F16,
PassThrough, // PassThrough,
PassThrough, // PassThrough,
PassThrough>>>& instances); // PassThrough>>>& instances);
void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_instances( // void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Col, // std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Col,
Col, // Col,
Empty_Tuple, // Empty_Tuple,
Row, // Row,
F16, // F16,
F16, // F16,
Empty_Tuple, // Empty_Tuple,
F16, // F16,
PassThrough, // PassThrough,
PassThrough, // PassThrough,
PassThrough>>>& instances); // PassThrough>>>& instances);
void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_instances( // void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Row, // std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Row,
Row, // Row,
Empty_Tuple, // Empty_Tuple,
Row, // Row,
F16, // F16,
F16, // F16,
Empty_Tuple, // Empty_Tuple,
F16, // F16,
PassThrough, // PassThrough,
PassThrough, // PassThrough,
PassThrough>>>& instances); // PassThrough>>>& instances);
void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instances( void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Row, std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Row,
...@@ -71,161 +71,161 @@ void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instances( ...@@ -71,161 +71,161 @@ void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_irregular_instances( // void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_irregular_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Col, // std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Col,
Row, // Row,
Empty_Tuple, // Empty_Tuple,
Row, // Row,
F16, // F16,
F16, // F16,
Empty_Tuple, // Empty_Tuple,
F16, // F16,
PassThrough, // PassThrough,
PassThrough, // PassThrough,
PassThrough>>>& instances); // PassThrough>>>& instances);
void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_irregular_instances( // void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_irregular_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Col, // std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Col,
Col, // Col,
Empty_Tuple, // Empty_Tuple,
Row, // Row,
F16, // F16,
F16, // F16,
Empty_Tuple, // Empty_Tuple,
F16, // F16,
PassThrough, // PassThrough,
PassThrough, // PassThrough,
PassThrough>>>& instances); // PassThrough>>>& instances);
void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_irregular_instances( // void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_irregular_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Row, // std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Row,
Row, // Row,
Empty_Tuple, // Empty_Tuple,
Row, // Row,
F16, // F16,
F16, // F16,
Empty_Tuple, // Empty_Tuple,
F16, // F16,
PassThrough, // PassThrough,
PassThrough, // PassThrough,
PassThrough>>>& instances); // PassThrough>>>& instances);
void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instances( // void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Row, // std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Row,
Col, // Col,
Empty_Tuple, // Empty_Tuple,
Row, // Row,
F16, // F16,
F16, // F16,
Empty_Tuple, // Empty_Tuple,
F16, // F16,
PassThrough, // PassThrough,
PassThrough, // PassThrough,
PassThrough>>>& instances); // PassThrough>>>& instances);
void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instances( // void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Col, // std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Col,
Row, // Row,
Empty_Tuple, // Empty_Tuple,
Row, // Row,
int8_t, // int8_t,
int8_t, // int8_t,
Empty_Tuple, // Empty_Tuple,
int8_t, // int8_t,
PassThrough, // PassThrough,
PassThrough, // PassThrough,
PassThrough>>>& instances); // PassThrough>>>& instances);
void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_instances( // void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Col, // std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Col,
Col, // Col,
Empty_Tuple, // Empty_Tuple,
Row, // Row,
int8_t, // int8_t,
int8_t, // int8_t,
Empty_Tuple, // Empty_Tuple,
int8_t, // int8_t,
PassThrough, // PassThrough,
PassThrough, // PassThrough,
PassThrough>>>& instances); // PassThrough>>>& instances);
void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_instances( // void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Row, // std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Row,
Row, // Row,
Empty_Tuple, // Empty_Tuple,
Row, // Row,
int8_t, // int8_t,
int8_t, // int8_t,
Empty_Tuple, // Empty_Tuple,
int8_t, // int8_t,
PassThrough, // PassThrough,
PassThrough, // PassThrough,
PassThrough>>>& instances); // PassThrough>>>& instances);
void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_instances( // void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Row, // std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Row,
Col, // Col,
Empty_Tuple, // Empty_Tuple,
Row, // Row,
int8_t, // int8_t,
int8_t, // int8_t,
Empty_Tuple, // Empty_Tuple,
int8_t, // int8_t,
PassThrough, // PassThrough,
PassThrough, // PassThrough,
PassThrough>>>& instances); // PassThrough>>>& instances);
void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_irregular_instances( // void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_irregular_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Col, // std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Col,
Row, // Row,
Empty_Tuple, // Empty_Tuple,
Row, // Row,
int8_t, // int8_t,
int8_t, // int8_t,
Empty_Tuple, // Empty_Tuple,
int8_t, // int8_t,
PassThrough, // PassThrough,
PassThrough, // PassThrough,
PassThrough>>>& instances); // PassThrough>>>& instances);
void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_irregular_instances( // void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_irregular_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Col, // std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Col,
Col, // Col,
Empty_Tuple, // Empty_Tuple,
Row, // Row,
int8_t, // int8_t,
int8_t, // int8_t,
Empty_Tuple, // Empty_Tuple,
int8_t, // int8_t,
PassThrough, // PassThrough,
PassThrough, // PassThrough,
PassThrough>>>& instances); // PassThrough>>>& instances);
void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_irregular_instances( // void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_irregular_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Row, // std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Row,
Row, // Row,
Empty_Tuple, // Empty_Tuple,
Row, // Row,
int8_t, // int8_t,
int8_t, // int8_t,
Empty_Tuple, // Empty_Tuple,
int8_t, // int8_t,
PassThrough, // PassThrough,
PassThrough, // PassThrough,
PassThrough>>>& instances); // PassThrough>>>& instances);
void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_irregular_instances( // void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_irregular_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Row, // std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Row,
Col, // Col,
Empty_Tuple, // Empty_Tuple,
Row, // Row,
int8_t, // int8_t,
int8_t, // int8_t,
Empty_Tuple, // Empty_Tuple,
int8_t, // int8_t,
PassThrough, // PassThrough,
PassThrough, // PassThrough,
PassThrough>>>& instances); // PassThrough>>>& instances);
template <typename ALayout, template <typename ALayout,
typename BLayout, typename BLayout,
...@@ -268,30 +268,30 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche ...@@ -268,30 +268,30 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> && if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>) is_same_v<ELayout, Row>)
{ {
add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_instances(op_ptrs); // add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_instances(op_ptrs);
add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_irregular_instances( // add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_irregular_instances(
op_ptrs); // op_ptrs);
} }
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> && else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>) is_same_v<ELayout, Row>)
{ {
add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instances(op_ptrs); add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instances(op_ptrs);
add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instances( // add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instances(
op_ptrs); // op_ptrs);
} }
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> && else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>) is_same_v<ELayout, Row>)
{ {
add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instances(op_ptrs); // add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instances(op_ptrs);
add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_irregular_instances( // add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_irregular_instances(
op_ptrs); // op_ptrs);
} }
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> && else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>) is_same_v<ELayout, Row>)
{ {
add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_instances(op_ptrs); // add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_instances(op_ptrs);
add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_irregular_instances( // add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_irregular_instances(
op_ptrs); // op_ptrs);
} }
} }
else if constexpr(is_same_v<ADataType, int8_t> && is_same_v<BDataType, int8_t> && else if constexpr(is_same_v<ADataType, int8_t> && is_same_v<BDataType, int8_t> &&
...@@ -300,30 +300,30 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche ...@@ -300,30 +300,30 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> && if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>) is_same_v<ELayout, Row>)
{ {
add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_instances(op_ptrs); // add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_instances(op_ptrs);
add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_irregular_instances( // add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_irregular_instances(
op_ptrs); // op_ptrs);
} }
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> && else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>) is_same_v<ELayout, Row>)
{ {
add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_instances(op_ptrs); // add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_instances(op_ptrs);
add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_irregular_instances( // add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_irregular_instances(
op_ptrs); // op_ptrs);
} }
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> && else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>) is_same_v<ELayout, Row>)
{ {
add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instances(op_ptrs); // add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instances(op_ptrs);
add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_irregular_instances( // add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_irregular_instances(
op_ptrs); // op_ptrs);
} }
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> && else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>) is_same_v<ELayout, Row>)
{ {
add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_instances(op_ptrs); // add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_instances(op_ptrs);
add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_irregular_instances( // add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_irregular_instances(
op_ptrs); // op_ptrs);
} }
} }
......
add_instance_library(device_batched_gemm_multi_d_instance add_instance_library(device_batched_gemm_multi_d_instance
device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_instance.cpp device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instance
device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instance.cpp
device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instance.cpp
device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_instance.cpp
device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_irregular_instance.cpp
device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instance.cpp
device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_irregular_instance.cpp
device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_irregular_instance.cpp
device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_instance.cpp
device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_instance.cpp
device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instance.cpp
device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_instance.cpp
device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_irregular_instance.cpp
device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_irregular_instance.cpp
device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_irregular_instance.cpp
device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_irregular_instance.cpp
) )
...@@ -37,38 +37,38 @@ using device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instances = std::tu ...@@ -37,38 +37,38 @@ using device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instances = std::tu
// ##########################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | // ##########################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | |
// ##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // ##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// MPerBlock=128, NPerBlock=128 // MPerBlock=128, NPerBlock=128
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, // DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<4, 4>, S<4, 4>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, // DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<4, 4>, S<4, 4>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, // DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// // MPerBlock=128, NPerBlock=64 // // // MPerBlock=128, NPerBlock=64
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 16, 2, 4, 4, 1, S<8, 2>, S<4, 2>, S<8, 1, 2, 2>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, // DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 16, 2, 4, 4, 1, S<8, 2>, S<4, 2>, S<8, 1, 2, 2>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 16, 2, 4, 4, 1, S<2, 8>, S<2, 4>, S<8, 1, 2, 2>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, // DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 16, 2, 4, 4, 1, S<2, 8>, S<2, 4>, S<8, 1, 2, 2>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// // MPerBlock=64, NPerBlock=128 // // // MPerBlock=64, NPerBlock=128
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 16, 2, 4, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 2, 2>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, // DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 16, 2, 4, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 2, 2>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 16, 2, 4, 4, 1, S<2, 4>, S<2, 8>, S<8, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 2, 2>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, // DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 16, 2, 4, 4, 1, S<2, 4>, S<2, 8>, S<8, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 2, 2>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// MPerBlock=64, NPerBlock=64 // // MPerBlock=64, NPerBlock=64
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, // DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, // DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<8, 1>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, // DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<8, 1>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<8, 1>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, // DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<8, 1>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// MPerBlock=16, NPerBlock=64 // // MPerBlock=16, NPerBlock=64
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 64, 16, 2, 1, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, // DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 64, 16, 2, 1, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// MPerBlock=64, NPerBlock=16 // // MPerBlock=64, NPerBlock=16
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 16, 16, 2, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, // DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 16, 16, 2, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
// MPerBlock=16, NPerBlock=16 // // MPerBlock=16, NPerBlock=16
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 16, 16, 16, 16, 2, 2, 2, 1, S<2, 2>, S<2, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, // DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 16, 16, 16, 16, 2, 2, 2, 1, S<2, 2>, S<2, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>,
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 16, 16, 16, 16, 2, 2, 2, 1, S<1, 4>, S<1, 4>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, // DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 16, 16, 16, 16, 2, 2, 2, 1, S<1, 4>, S<1, 4>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// MPerBlock=8, NPerBlock=64 // MPerBlock=8, NPerBlock=64
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<4, 1>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, // DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<4, 1>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>,
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, // DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// MPerBlock=64, NPerBlock=8 // MPerBlock=64, NPerBlock=8
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>
// MPerBlock=8, NPerBlock=8 // MPerBlock=8, NPerBlock=8
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<4, 1>, S<2, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, // DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<4, 1>, S<2, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>,
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<1, 4>, S<1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, // DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<1, 4>, S<1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>,
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, // DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> // DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>
// clang-format on // clang-format on
>; >;
......
...@@ -86,8 +86,8 @@ bool profile_batched_gemm_impl(int do_verification, ...@@ -86,8 +86,8 @@ bool profile_batched_gemm_impl(int do_verification,
{ {
case 0: break; case 0: break;
case 1: case 1:
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}); a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{1, 2});
b_g_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}); b_g_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{1, 2});
break; break;
default: default:
a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}); a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
......
...@@ -25,10 +25,10 @@ class TestBatchedGemmMultiD : public ::testing::Test ...@@ -25,10 +25,10 @@ class TestBatchedGemmMultiD : public ::testing::Test
using BLayout = std::tuple_element_t<1, Tuple>; using BLayout = std::tuple_element_t<1, Tuple>;
using CLayout = std::tuple_element_t<2, Tuple>; using CLayout = std::tuple_element_t<2, Tuple>;
static constexpr int M = 512; static constexpr int M = 64;
static constexpr int N = 256; static constexpr int N = 8;
static constexpr int K = 128; static constexpr int K = 64;
static constexpr int BatchCount = 3; static constexpr int BatchCount = 1;
template <typename DataType> template <typename DataType>
void Run() void Run()
...@@ -61,14 +61,15 @@ class TestBatchedGemmMultiD : public ::testing::Test ...@@ -61,14 +61,15 @@ class TestBatchedGemmMultiD : public ::testing::Test
} }
}; };
using KernelTypes = ::testing::Types<std::tuple<Row, Row, Row>, using KernelTypes = ::testing::Types<//std::tuple<Row, Row, Row>,
std::tuple<Row, Col, Row>, std::tuple<Row, Col, Row>
std::tuple<Col, Row, Row>, // std::tuple<Col, Row, Row>,
std::tuple<Col, Col, Row>>; // std::tuple<Col, Col, Row>
>;
} // namespace } // namespace
TYPED_TEST_SUITE(TestBatchedGemmMultiD, KernelTypes); TYPED_TEST_SUITE(TestBatchedGemmMultiD, KernelTypes);
TYPED_TEST(TestBatchedGemmMultiD, f16) { this->template Run<F16>(); } TYPED_TEST(TestBatchedGemmMultiD, f16) { this->template Run<F16>(); }
TYPED_TEST(TestBatchedGemmMultiD, int8) { this->template Run<int8_t>(); } // TYPED_TEST(TestBatchedGemmMultiD, int8) { this->template Run<int8_t>(); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment