Commit 07a673c6 authored by carlushuang's avatar carlushuang
Browse files

Merge remote-tracking branch 'origin/develop' into cpu_avx2

parents c0f698d5 ac0d8066
...@@ -20,7 +20,7 @@ using S = ck::Sequence<Is...>; ...@@ -20,7 +20,7 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n] // Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances = using device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances =
......
...@@ -20,8 +20,8 @@ using S = ck::Sequence<Is...>; ...@@ -20,8 +20,8 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization_t::MNPadding; static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
// Compilation parameters for a[m, k] * b[n, k] = c[m, n] // Compilation parameters for a[m, k] * b[n, k] = c[m, n]
using device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances = using device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances =
......
...@@ -20,7 +20,7 @@ using S = ck::Sequence<Is...>; ...@@ -20,7 +20,7 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// Compilation parameters for a[k, m] * b[k, n] = c[m, n] // Compilation parameters for a[k, m] * b[k, n] = c[m, n]
using device_gemm_xdl_f32_f32_f32_km_kn_mn_instances = using device_gemm_xdl_f32_f32_f32_km_kn_mn_instances =
......
...@@ -20,7 +20,7 @@ using S = ck::Sequence<Is...>; ...@@ -20,7 +20,7 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// Compilation parameters for a[k, m] * b[n, k] = c[m, n] // Compilation parameters for a[k, m] * b[n, k] = c[m, n]
using device_gemm_xdl_f32_f32_f32_km_nk_mn_instances = using device_gemm_xdl_f32_f32_f32_km_nk_mn_instances =
......
...@@ -20,7 +20,7 @@ using S = ck::Sequence<Is...>; ...@@ -20,7 +20,7 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n] // Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances = using device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances =
......
...@@ -20,7 +20,7 @@ using S = ck::Sequence<Is...>; ...@@ -20,7 +20,7 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// Compilation parameters for a[m, k] * b[n, k] = c[m, n] // Compilation parameters for a[m, k] * b[n, k] = c[m, n]
using device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances = using device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances =
......
...@@ -20,7 +20,7 @@ using S = ck::Sequence<Is...>; ...@@ -20,7 +20,7 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n] // Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances = std::tuple< using device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances = std::tuple<
......
...@@ -20,7 +20,7 @@ using S = ck::Sequence<Is...>; ...@@ -20,7 +20,7 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n] // Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances = std::tuple< using device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances = std::tuple<
......
...@@ -20,7 +20,7 @@ using S = ck::Sequence<Is...>; ...@@ -20,7 +20,7 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n] // Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances = std::tuple< using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances = std::tuple<
......
...@@ -20,7 +20,7 @@ using S = ck::Sequence<Is...>; ...@@ -20,7 +20,7 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n] // Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances = std::tuple< using device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances = std::tuple<
......
...@@ -20,7 +20,7 @@ using S = ck::Sequence<Is...>; ...@@ -20,7 +20,7 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// Compilation parameters for a[k, m] * b[k, n] = c[m, n] // Compilation parameters for a[k, m] * b[k, n] = c[m, n]
using device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances = std::tuple< using device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances = std::tuple<
......
...@@ -20,7 +20,7 @@ using S = ck::Sequence<Is...>; ...@@ -20,7 +20,7 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// Compilation parameters for a[k, m] * b[n, k] = c[m, n] // Compilation parameters for a[k, m] * b[n, k] = c[m, n]
using device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances = std::tuple< using device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances = std::tuple<
......
...@@ -20,7 +20,7 @@ using S = ck::Sequence<Is...>; ...@@ -20,7 +20,7 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization_t::MNPadding; static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n] // Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances = std::tuple< using device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances = std::tuple<
......
...@@ -20,7 +20,7 @@ using S = ck::Sequence<Is...>; ...@@ -20,7 +20,7 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// Compilation parameters for a[m, k] * b[n, k] = c[m, n] // Compilation parameters for a[m, k] * b[n, k] = c[m, n]
using device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances = std::tuple< using device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances = std::tuple<
......
...@@ -23,7 +23,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; ...@@ -23,7 +23,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceSum = ck::tensor_operation::element_wise::ReduceSum; using ReduceSum = ck::tensor_operation::element_wise::ReduceSum;
using ReduceSquareSum = ck::tensor_operation::element_wise::ReduceSquareSum; using ReduceSquareSum = ck::tensor_operation::element_wise::ReduceSquareSum;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// c[m, n] = a[k, m] * b[k, n] // c[m, n] = a[k, m] * b[k, n]
// d0[m] = reduce0(c[m, n]) // d0[m] = reduce0(c[m, n])
...@@ -36,7 +36,7 @@ using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances = s ...@@ -36,7 +36,7 @@ using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances = s
//###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 128, 256, 32, 4, 4, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
......
...@@ -23,7 +23,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; ...@@ -23,7 +23,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceSum = ck::tensor_operation::element_wise::ReduceSum; using ReduceSum = ck::tensor_operation::element_wise::ReduceSum;
using ReduceSquareSum = ck::tensor_operation::element_wise::ReduceSquareSum; using ReduceSquareSum = ck::tensor_operation::element_wise::ReduceSquareSum;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// c[m, n] = a[k, m] * b[n, k] // c[m, n] = a[k, m] * b[n, k]
// d0[m] = reduce0(c[m, n]) // d0[m] = reduce0(c[m, n])
......
...@@ -23,7 +23,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; ...@@ -23,7 +23,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceSum = ck::tensor_operation::element_wise::ReduceSum; using ReduceSum = ck::tensor_operation::element_wise::ReduceSum;
using ReduceSquareSum = ck::tensor_operation::element_wise::ReduceSquareSum; using ReduceSquareSum = ck::tensor_operation::element_wise::ReduceSquareSum;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// c[m, n] = a[m, k] * b[n, k] // c[m, n] = a[m, k] * b[n, k]
// d0[m] = reduce0(c[m, n]) // d0[m] = reduce0(c[m, n])
......
...@@ -23,7 +23,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; ...@@ -23,7 +23,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceSum = ck::tensor_operation::element_wise::ReduceSum; using ReduceSum = ck::tensor_operation::element_wise::ReduceSum;
using ReduceSquareSum = ck::tensor_operation::element_wise::ReduceSquareSum; using ReduceSquareSum = ck::tensor_operation::element_wise::ReduceSquareSum;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// c[m, n] = a[m, k] * b[n, k] // c[m, n] = a[m, k] * b[n, k]
// d0[m] = reduce0(c[m, n]) // d0[m] = reduce0(c[m, n])
......
...@@ -20,7 +20,7 @@ using S = ck::Sequence<Is...>; ...@@ -20,7 +20,7 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// Compilation parameters for a[k, m] * b[k, n] = c[m, n] // Compilation parameters for a[k, m] * b[k, n] = c[m, n]
using device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances = std::tuple< using device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances = std::tuple<
......
...@@ -20,7 +20,7 @@ using S = ck::Sequence<Is...>; ...@@ -20,7 +20,7 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// Compilation parameters for a[k, m] * b[n, k] = c[m, n] // Compilation parameters for a[k, m] * b[n, k] = c[m, n]
using device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances = std::tuple< using device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances = std::tuple<
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment