Commit 3af2a90c authored by Adam Osewski's avatar Adam Osewski
Browse files

Add instance mk-nk-mn

parent f1814e23
...@@ -31,7 +31,7 @@ struct DeviceGroupedGemm : public BaseOperator ...@@ -31,7 +31,7 @@ struct DeviceGroupedGemm : public BaseOperator
{ {
static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr index_t NumDTensor = DsDataType::Size();
static_assert(DsLayout::Size() == DsDataType::Size(), "wrong! inconsisiten NumDTensor"); static_assert(DsLayout::Size() == DsDataType::Size(), "wrong! inconsistent NumDTensor");
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::vector<const void*>& p_a, MakeArgumentPointer(std::vector<const void*>& p_a,
......
...@@ -68,6 +68,19 @@ void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances( ...@@ -68,6 +68,19 @@ void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Col,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
template <typename ALayout, template <typename ALayout,
typename BLayout, typename BLayout,
typename ELayout, typename ELayout,
...@@ -114,6 +127,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -114,6 +127,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<ELayout, Row>) is_same_v<ELayout, Row>)
{ {
add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs); add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
} }
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> && else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>) is_same_v<ELayout, Row>)
......
...@@ -3,4 +3,5 @@ add_instance_library(device_grouped_gemm_instance ...@@ -3,4 +3,5 @@ add_instance_library(device_grouped_gemm_instance
device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp
device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp
device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp
) )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment