Commit a09f0171 authored by Jing Zhang's avatar Jing Zhang
Browse files

add f16_f8_f16

parent c6033af3
......@@ -328,6 +328,16 @@ void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances(
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
template <typename ALayout,
......@@ -549,6 +559,20 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_nk_mn_instances(op_ptrs);
}
}
else if constexpr(is_same_v<ADataType, ck::half_t> && is_same_v<BDataType, ck::f8_t> &&
is_same_v<CDataType, ck::half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instances(op_ptrs);
}
}
#endif
return op_ptrs;
}
......
......@@ -28,8 +28,8 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto MNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto MNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
static constexpr auto MNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
template <ck::tensor_operation::device::GemmSpecialization GemmSpec>
// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
......@@ -98,6 +98,9 @@ void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
add_device_operation_instances(
instances, device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances<GemmDefault>{});
add_device_operation_instances(
instances, device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances<MNPadding>{});
add_device_operation_instances(
instances, device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances<MNKPadding>{});
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment