Commit 2b8a9941 authored by Artur Wojcik's avatar Artur Wojcik
Browse files

Merge branch 'develop' into uif2-initial

parents ce9d7c8d 707ad002
......@@ -132,9 +132,6 @@ struct ThreadwiseTensorSliceTransfer_v7r2
Number<num>{});
}
template <typename T>
using has_vec_len = decltype(std::declval<T&>().vec_len);
// SrcDescs: Tuple<const SrcDesc0&, const SrcDesc1&, ...>
// SrcBuffers: Tuple<const SrcBuffer0&, const SrcBuffer1&, ...>
template <typename SrcBuffers,
......@@ -159,19 +156,26 @@ struct ThreadwiseTensorSliceTransfer_v7r2
is_src_valid);
});
if constexpr(is_detected<has_vec_len, decltype(element_op_)>::value)
constexpr auto get_elem_op_vec_len = []() {
if constexpr(is_detected<is_pack8_invocable_t, decltype(element_op_)>::value)
{
constexpr auto elem_op_vec_len = decltype(element_op_)::vec_len;
static_assert(is_same<remove_cvref_t<decltype(elem_op_vec_len)>, index_t>::value,
"vec_len in element_op_ type is not index_t");
static_assert(elem_op_vec_len == 1 || elem_op_vec_len == 2 ||
elem_op_vec_len == 4 || elem_op_vec_len == 8,
"vec_len in element_op_ must be 1, 2, 4, 8");
if constexpr(decltype(element_op_)::is_pack8_invocable)
return math::min(8, SrcScalarPerVector);
}
if constexpr(is_detected<is_pack4_invocable_t, decltype(element_op_)>::value)
{
if constexpr(decltype(element_op_)::is_pack4_invocable)
return math::min(4, SrcScalarPerVector);
}
if constexpr(is_detected<is_pack2_invocable_t, decltype(element_op_)>::value)
{
if constexpr(decltype(element_op_)::is_pack2_invocable)
return math::min(2, SrcScalarPerVector);
}
return 1;
};
static_assert(SrcScalarPerVector % elem_op_vec_len == 0,
"vec_len in element_op_ cannot be divided by SrcScalarPerVector!");
constexpr index_t elem_op_vec_len = get_elem_op_vec_len();
// apply pointwise function
static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto i) {
......@@ -181,8 +185,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2
[&](auto iSrc) -> const auto& {
using SrcData = remove_cvref_t<tuple_element_t<iSrc.value, SrcDatas>>;
using elem_op_vec_t =
typename vector_type<SrcData, elem_op_vec_len>::type;
using elem_op_vec_t = typename vector_type<SrcData, elem_op_vec_len>::type;
return src_vectors[iSrc].template AsType<elem_op_vec_t>()[i];
},
......@@ -194,8 +197,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2
[&](auto iDst) -> auto& {
using DstData = remove_cvref_t<tuple_element_t<iDst.value, DstDatas>>;
using elem_op_vec_t =
typename vector_type<DstData, elem_op_vec_len>::type;
using elem_op_vec_t = typename vector_type<DstData, elem_op_vec_len>::type;
return dst_vectors(iDst).template AsType<elem_op_vec_t>()(i);
},
......@@ -211,42 +213,6 @@ struct ThreadwiseTensorSliceTransfer_v7r2
// ...)
unpack2(element_op_, dst_data_refs, src_data_refs);
});
}
else
{
// apply pointwise function
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
// get reference to src data
const auto src_data_refs = generate_tie(
// return type should be lvalue
[&](auto iSrc) -> const auto& {
using SrcData = remove_cvref_t<tuple_element_t<iSrc.value, SrcDatas>>;
return src_vectors[iSrc].template AsType<SrcData>()[i];
},
Number<nSrc>{});
// get reference to dst data
auto dst_data_refs = generate_tie(
// return type should be lvalue
[&](auto iDst) -> auto& {
using DstData = remove_cvref_t<tuple_element_t<iDst.value, DstDatas>>;
return dst_vectors(iDst).template AsType<DstData>()(i);
},
Number<nDst>{});
// apply pointwise function
// pointwise function signature:
// element_op_(dst_data_refs[I0],
// dst_data_refs[I1],
// ...,
// src_data_refs[I0],
// src_data_refs[I1],
// ...)
unpack2(element_op_, dst_data_refs, src_data_refs);
});
}
dst_vectors_tuple_(iAccess) = dst_vectors;
......
......@@ -31,4 +31,13 @@ struct nonesuch
template <template <class...> class Op, class... Args>
using is_detected = typename detail::detector<nonesuch, void, Op, Args...>::value_t;
template <typename T>
using is_pack2_invocable_t = decltype(std::declval<T&>().is_pack2_invocable);
template <typename T>
using is_pack4_invocable_t = decltype(std::declval<T&>().is_pack4_invocable);
template <typename T>
using is_pack8_invocable_t = decltype(std::declval<T&>().is_pack8_invocable);
} // namespace ck
......@@ -897,14 +897,3 @@ template <index_t NSize, index_t I>
using uniform_sequence_gen_t = typename uniform_sequence_gen<NSize, I>::type;
} // namespace ck
template <ck::index_t... Is>
std::ostream& operator<<(std::ostream& os, const ck::Sequence<Is...>)
{
using S = ck::Sequence<Is...>;
os << "{";
ck::static_for<0, S::Size() - ck::Number<1>{}, 1>{}(
[&](auto i) { os << S::At(i).value << ", "; });
os << S::At(S::Size() - ck::Number<1>{}).value << "}";
return os;
}
......@@ -14,27 +14,27 @@ namespace ck {
namespace tensor_operation {
namespace host {
///
/// @brief Reference implementation for forward convolution.
///
/// @paragraph
/// Tensor descriptor in GNCHW/GKCXY/GNKHW dimensional order
/// Supports both GNCHW/NGCHW as well as GNHWC/NHWGC physical layout
/// as long as dimensions in tensor descriptor is in GNCHW order
///
/// @tparam InDataType Input tensor data type.
/// @tparam WeiDataType Weights tensor data type.
/// @tparam OutDataType Output tensor data type.
/// @tparam InElementwiseOperation Functor for input tensor elementwise
/// operation.
/// @tparam WeiElementwiseOperation Functor for weights tensor elementwise
/// operation.
/// @tparam NDimSpatial Number of spatial dimensions.
///
/// input descriptor in [G, N, C, Do, Ho, Wo] order
/// weight descriptor in [G, K, C, Z, Y, X] order
/// output descriptor in [G, N, K, Di, Hi, Wi] order
/// phyiscal layout is irrelavent
//
// @brief Reference implementation for forward convolution.
//
// @paragraph
// Tensor descriptor in GNCHW/GKCXY/GNKHW dimensional order
// Supports both GNCHW/NGCHW as well as GNHWC/NHWGC physical layout
// as long as dimensions in tensor descriptor is in GNCHW order
//
// @tparam InDataType Input tensor data type.
// @tparam WeiDataType Weights tensor data type.
// @tparam OutDataType Output tensor data type.
// @tparam InElementwiseOperation Functor for input tensor elementwise
// operation.
// @tparam WeiElementwiseOperation Functor for weights tensor elementwise
// operation.
// @tparam NDimSpatial Number of spatial dimensions.
//
// input descriptor in [G, N, C, Do, Ho, Wo] order
// weight descriptor in [G, K, C, Z, Y, X] order
// output descriptor in [G, N, K, Di, Hi, Wi] order
// phyiscal layout is irrelavent
template <ck::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
......
......@@ -98,6 +98,26 @@ void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemmSplitK<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemmSplitK<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough, F8>>>&
instances);
void add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemmSplitK<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough, F8>>>&
instances);
void add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemmSplitK<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough, F8>>>&
instances);
void add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemmSplitK<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough, F8>>>&
instances);
#endif
template <typename ADataType,
......@@ -105,7 +125,8 @@ template <typename ADataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
typename CLayout,
typename ComputeType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGemmSplitK<ALayout,
BLayout,
......@@ -115,7 +136,8 @@ struct DeviceOperationInstanceFactory<
CDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>>
ck::tensor_operation::element_wise::PassThrough,
ComputeType>>
{
using DeviceOp = DeviceGemmSplitK<ALayout,
BLayout,
......@@ -125,7 +147,8 @@ struct DeviceOperationInstanceFactory<
CDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
ck::tensor_operation::element_wise::PassThrough,
ComputeType>;
static auto GetInstances()
{
......@@ -158,7 +181,7 @@ struct DeviceOperationInstanceFactory<
#endif
#ifdef CK_ENABLE_FP16
else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<CDataType, half_t>)
is_same_v<CDataType, half_t> && is_same_v<ComputeType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
......@@ -231,6 +254,31 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_xdl_splitk_f16_f8_f16_km_nk_mn_instances(op_ptrs);
}
}
else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<CDataType, half_t> && is_same_v<ComputeType, f8_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_mk_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_mk_nk_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_km_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_km_nk_mn_instances(op_ptrs);
}
}
#endif
return op_ptrs;
}
......
......@@ -15,6 +15,10 @@ list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_in
device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_km_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_km_nk_mn_instance.cpp)
device_gemm_xdl_splitk_f16_fp8_f16_km_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_mk_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_mk_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_km_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_km_nk_mn_instance.cpp)
add_instance_library(device_gemm_splitk_instance ${GEMM_SPLITK_INSTANCES})
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F8 = ck::f8_t;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using device_gemm_xdl_splitk_f16_f16_f16_comp_f8_km_kn_mn_instances = std::tuple<
// clang-format off
//#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
//#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type|
//#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| |
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F8>
// clang-format on
>;
void add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemmSplitK<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough, F8>>>&
instances)
{
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_comp_f8_km_kn_mn_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F8 = ck::f8_t;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using device_gemm_xdl_splitk_f16_f16_f16_comp_f8_km_nk_mn_instances = std::tuple<
// clang-format off
//#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
//#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type|
//#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| |
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F8>
// clang-format on
>;
void add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemmSplitK<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough, F8>>>&
instances)
{
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_comp_f8_km_nk_mn_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F8 = ck::f8_t;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
using device_gemm_xdl_splitk_f16_f16_f16_comp_f8_mk_kn_mn_generic_instances = std::tuple<
// clang-format off
//#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
//#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type|
//#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| |
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 8>, 2, F8>
// clang-format on
>;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using device_gemm_xdl_splitk_f16_f16_f16_comp_f8_mk_kn_mn_instances = std::tuple<
// clang-format off
//#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
//#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type|
//#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| |
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 192, 64, 4, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 32, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 192, 32, 4, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 64, 32, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F8>
// clang-format on
>;
void add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemmSplitK<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough, F8>>>&
instances)
{
add_device_operation_instances(
instances, device_gemm_xdl_splitk_f16_f16_f16_comp_f8_mk_kn_mn_generic_instances{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_comp_f8_mk_kn_mn_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F8 = ck::f8_t;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto MNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using device_gemm_xdl_splitk_f16_f16_f16_comp_f8_mk_nk_mn_generic_instances = std::tuple<
// clang-format off
//#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
//#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type|
//#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| |
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNKPadding, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 2, F8>
// clang-format on
>;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using device_gemm_xdl_splitk_f16_f16_f16_comp_f8_mk_nk_mn_instances = std::tuple<
// clang-format off
//#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
//#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type|
//#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| |
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNKPadding, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNKPadding, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNKPadding, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNKPadding, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNKPadding, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNKPadding, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNKPadding, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNKPadding, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNKPadding, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNKPadding, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNKPadding, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNKPadding, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNKPadding, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F8>
// clang-format on
>;
void add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemmSplitK<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough, F8>>>&
instances)
{
add_device_operation_instances(
instances, device_gemm_xdl_splitk_f16_f16_f16_comp_f8_mk_nk_mn_generic_instances{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_comp_f8_mk_nk_mn_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -30,7 +30,8 @@ template <typename ADataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
typename CLayout,
typename ComputeType = CDataType>
bool profile_gemm_splitk_impl(int do_verification,
int init_method,
bool do_log,
......@@ -103,7 +104,8 @@ bool profile_gemm_splitk_impl(int do_verification,
CDataType,
AElementOp,
BElementOp,
CElementOp>;
CElementOp,
ComputeType>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
......@@ -120,7 +122,8 @@ bool profile_gemm_splitk_impl(int do_verification,
AccDataType,
AElementOp,
BElementOp,
CElementOp>;
CElementOp,
ComputeType>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
......
......@@ -25,6 +25,7 @@ enum struct GemmDataType
INT8_INT8_INT8, // 3
F8_F16_F16, // 4
F16_F8_F16, // 5
F16_F16_F16_F8, // 6
};
#define OP_NAME "gemm_splitk"
......@@ -35,7 +36,8 @@ int profile_gemm_splitk(int argc, char* argv[])
if(argc != 15)
{
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8)\n");
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: f16, "
"comp f8)\n");
printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
printf(" 2: A[k, m] * B[k, n] = C[m, n];\n");
......@@ -80,7 +82,8 @@ int profile_gemm_splitk(int argc, char* argv[])
auto c_type,
auto a_layout,
auto b_layout,
auto c_layout) {
auto c_layout,
auto compute_type) {
using ADataType = decltype(a_type);
using BDataType = decltype(b_type);
using AccDataType = decltype(acc_type);
......@@ -90,6 +93,8 @@ int profile_gemm_splitk(int argc, char* argv[])
using BLayout = decltype(b_layout);
using CLayout = decltype(c_layout);
using ComputeType = decltype(compute_type);
const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K;
const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? N : M;
......@@ -100,7 +105,8 @@ int profile_gemm_splitk(int argc, char* argv[])
CDataType,
ALayout,
BLayout,
CLayout>(
CLayout,
ComputeType>(
do_verification,
init_method,
do_log,
......@@ -118,68 +124,84 @@ int profile_gemm_splitk(int argc, char* argv[])
if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
{
return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Row{}, Row{});
return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Row{}, Row{}, F32{});
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
{
return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Col{}, Row{});
return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Col{}, Row{}, F32{});
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
{
return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Row{}, Row{});
return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Row{}, Row{}, F32{});
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
{
return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Col{}, Row{});
return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Col{}, Row{}, F32{});
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{});
return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{}, F16{});
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{});
return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{}, F16{});
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
{
return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{});
return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{}, F16{});
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
{
return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{});
return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{}, F16{});
}
#if defined CK_ENABLE_FP8
else if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
return profile(F8{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{});
return profile(F8{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{}, F16{});
}
else if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
return profile(F8{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{});
return profile(F8{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{}, F16{});
}
else if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
{
return profile(F8{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{});
return profile(F8{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{}, F16{});
}
else if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
{
return profile(F8{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{});
return profile(F8{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{}, F16{});
}
else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
return profile(F16{}, F8{}, F32{}, F16{}, Row{}, Row{}, Row{});
return profile(F16{}, F8{}, F32{}, F16{}, Row{}, Row{}, Row{}, F16{});
}
else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
return profile(F16{}, F8{}, F32{}, F16{}, Row{}, Col{}, Row{});
return profile(F16{}, F8{}, F32{}, F16{}, Row{}, Col{}, Row{}, F16{});
}
else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::KM_KN_MN)
{
return profile(F16{}, F8{}, F32{}, F16{}, Col{}, Row{}, Row{});
return profile(F16{}, F8{}, F32{}, F16{}, Col{}, Row{}, Row{}, F16{});
}
else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::KM_NK_MN)
{
return profile(F16{}, F8{}, F32{}, F16{}, Col{}, Col{}, Row{});
return profile(F16{}, F8{}, F32{}, F16{}, Col{}, Col{}, Row{}, F16{});
}
else if(data_type == GemmDataType::F16_F16_F16_F8 && layout == GemmMatrixLayout::MK_KN_MN)
{
return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{}, F8{});
}
else if(data_type == GemmDataType::F16_F16_F16_F8 && layout == GemmMatrixLayout::MK_NK_MN)
{
return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{}, F8{});
}
else if(data_type == GemmDataType::F16_F16_F16_F8 && layout == GemmMatrixLayout::KM_KN_MN)
{
return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{}, F8{});
}
else if(data_type == GemmDataType::F16_F16_F16_F8 && layout == GemmMatrixLayout::KM_NK_MN)
{
return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{}, F8{});
}
#endif
else
......
......@@ -12,9 +12,7 @@ cmake
-save-temps=$PWD" \
-D CMAKE_BUILD_TYPE=Release \
-D BUILD_DEV=ON \
-D GPU_TARGETS="gfx90a" \
-D GPU_TARGETS="gfx908;gfx90a;gfx940" \
-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
-D USE_BITINT_EXTENSION_INT4=OFF \
${MY_PROJECT_SOURCE}
#-D GPU_TARGETS="gfx908;gfx90a;gfx940" \
......@@ -4,7 +4,7 @@ TEST_P(RRR_F16_F16_F16, TinyCases)
{
const std::vector<int> Ms{0, 1};
constexpr int N = 768;
constexpr int K = 1088;
constexpr int K = 544;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
......@@ -17,9 +17,9 @@ TEST_P(RRR_F16_F16_F16, TinyCases)
TEST_P(RRR_F16_F16_F16, SmallCases)
{
const std::vector<int> Ms{2, 3, 4, 5};
const std::vector<int> Ms{2, 1, 3, 4, 5, 0};
constexpr int N = 768;
constexpr int K = 1088;
constexpr int K = 544;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
......@@ -34,7 +34,7 @@ TEST_P(RRR_F16_F16_F16, MidCases)
{
const std::vector<int> Ms{167, 183, 177, 153, 139, 204};
constexpr int N = 768;
constexpr int K = 1088;
constexpr int K = 544;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
......@@ -49,7 +49,7 @@ TEST_P(RRR_F16_F16_F16, Regular)
{
const std::vector<int> Ms{64, 128, 256};
constexpr int N = 768;
constexpr int K = 640;
constexpr int K = 320;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
......@@ -79,7 +79,7 @@ TEST_P(RCR_F16_F16_F16, TinyCases)
{
const std::vector<int> Ms{0, 1};
constexpr int N = 768;
constexpr int K = 1088;
constexpr int K = 544;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
......@@ -91,9 +91,9 @@ TEST_P(RCR_F16_F16_F16, TinyCases)
TEST_P(RCR_F16_F16_F16, SmallCases)
{
const std::vector<int> Ms{2, 3, 4, 5};
const std::vector<int> Ms{2, 1, 3, 4, 5, 0};
constexpr int N = 768;
constexpr int K = 1088;
constexpr int K = 544;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
......@@ -123,7 +123,7 @@ TEST_P(RCR_F16_F16_F16, Regular)
{
const std::vector<int> Ms{32, 64, 128, 256};
constexpr int N = 768;
constexpr int K = 640;
constexpr int K = 320;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
......@@ -151,9 +151,9 @@ TEST_P(RCR_F16_F16_F16, MNKPadded)
TEST_P(RRR_F16_F16_F16_LargeK, TestLargeKBatch)
{
const std::vector<int> Ms{127, 150, 188, 210};
const std::vector<int> Ms{188, 210};
constexpr int N = 768;
constexpr int K = 8192;
constexpr int K = 4096;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment