Commit 6d9a07d7 authored by Jun Liu

Merge branch 'develop' into amd-develop

parents b30d416c a776978c
...@@ -8,6 +8,8 @@ ...@@ -8,6 +8,8 @@
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck { namespace ck {
// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory // Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory
...@@ -1156,27 +1158,56 @@ struct ThreadwiseTensorSliceTransfer_v4 ...@@ -1156,27 +1158,56 @@ struct ThreadwiseTensorSliceTransfer_v4
src_ref_to_origin_disp_idx + data_to_origin_disp_idx + src_ref_to_origin_disp_idx + data_to_origin_disp_idx +
i * src_scalar_step_in_vector); i * src_scalar_step_in_vector);
// apply type convert
src_tmp_vector.template AsType<SrcData>()(i) = src_buf[Number<src_offset>{}]; src_tmp_vector.template AsType<SrcData>()(i) = src_buf[Number<src_offset>{}];
}); });
} }
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
// DstData)
vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
// TODO: if SrcData and DstData are vetor type, then static_cast may not compile if constexpr(is_same<remove_cvref_t<SrcData>, f8_t>::value &&
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { is_same<remove_cvref_t<DstData>, half_t>::value &&
dst_tmp_vector.template AsType<DstData>()(i) = SrcScalarPerVector % 2 == 0)
type_convert<DstData>(src_tmp_vector.template AsType<SrcData>()[i]); {
}); // copy data from src_tmp_vector to dst_tmp_vector (cast data from SrcData to
// DstData)
vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
constexpr index_t pack_size = 2;
using dst_v_t = typename vector_type_maker_t<DstData, pack_size>::type;
using src_v_t = typename vector_type_maker_t<SrcData, pack_size>::type;
static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) {
ck::tensor_operation::element_wise::PassThroughPack2{}(
dst_tmp_vector.template AsType<dst_v_t>()(i),
src_tmp_vector.template AsType<src_v_t>()[i]);
});
// copy data from dst_tmp_vector into dst_buf // copy data from dst_tmp_vector into dst_buf
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
constexpr index_t dst_offset = dst_desc.CalculateOffset( constexpr index_t dst_offset = dst_desc.CalculateOffset(
dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector); dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i]; dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
}); });
}
else
{
// copy data from src_tmp_vector to dst_tmp_vector (cast data from SrcData to
// DstData)
vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
// TODO: if SrcData and DstData are vector type, then static_cast may not compile
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
dst_tmp_vector.template AsType<DstData>()(i) =
type_convert<DstData>(src_tmp_vector.template AsType<SrcData>()[i]);
});
// copy data from dst_tmp_vector into dst_buf
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
constexpr index_t dst_offset = dst_desc.CalculateOffset(
dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
});
}
}); });
} }
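The new branch above dispatches at compile time: when SrcData is f8_t, DstData is half_t, and SrcScalarPerVector is even, elements are converted two at a time through PassThroughPack2; otherwise the original scalar type_convert loop runs. A minimal standalone sketch of the same if constexpr dispatch pattern (plain C++17, not CK code; the uint8-to-float pair case only stands in for the f8-to-half fast path):

#include <array>
#include <cstdint>
#include <cstdio>
#include <type_traits>

// Convert Src -> Dst, taking a pair-wise path when N is even and the types
// match a "fast pair" case; otherwise fall back to a scalar loop.
template <typename Dst, typename Src, std::size_t N>
std::array<Dst, N> convert_all(const std::array<Src, N>& src)
{
    std::array<Dst, N> dst{};
    if constexpr(std::is_same_v<Src, std::uint8_t> && std::is_same_v<Dst, float> && N % 2 == 0)
    {
        // pair-wise path: one iteration handles two elements
        for(std::size_t i = 0; i < N; i += 2)
        {
            dst[i]     = static_cast<float>(src[i]);
            dst[i + 1] = static_cast<float>(src[i + 1]);
        }
    }
    else
    {
        // generic scalar fallback
        for(std::size_t i = 0; i < N; ++i)
            dst[i] = static_cast<Dst>(src[i]);
    }
    return dst;
}

int main()
{
    std::array<std::uint8_t, 4> in{1, 2, 3, 4};
    auto out = convert_all<float>(in);
    std::printf("%f %f %f %f\n", out[0], out[1], out[2], out[3]);
}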
......
...@@ -107,11 +107,12 @@ __host__ __device__ constexpr Y f8_convert_sr(X x); ...@@ -107,11 +107,12 @@ __host__ __device__ constexpr Y f8_convert_sr(X x);
template <> template <>
inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x) inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
{ {
constexpr int seed = 42; constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x); uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
float max_fp8 = 240.0f;
if(!std::isinf(x))
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
#if defined(__gfx94__) #if defined(__gfx94__)
float max_fp8 = 240.0f;
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
union union
{ {
float fval; float fval;
...@@ -144,7 +145,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x) ...@@ -144,7 +145,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
constexpr bool clip = true; constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 42; constexpr int seed = 1254739;
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x); uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
return utils:: return utils::
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>( cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
...@@ -156,7 +157,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x) ...@@ -156,7 +157,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
template <> template <>
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x) inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
{ {
constexpr int seed = 42; constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x); uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx94__) #if defined(__gfx94__)
union union
...@@ -191,7 +192,7 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x) ...@@ -191,7 +192,7 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
constexpr bool clip = true; constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 42; constexpr int seed = 1254739;
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x); uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
return utils:: return utils::
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>( cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
...@@ -207,9 +208,10 @@ __host__ __device__ constexpr Y f8_convert_rne(X x); ...@@ -207,9 +208,10 @@ __host__ __device__ constexpr Y f8_convert_rne(X x);
template <> template <>
inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x) inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
{ {
#if defined(__gfx94__)
float max_fp8 = 240.0f; float max_fp8 = 240.0f;
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x); if(!std::isinf(x))
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
#if defined(__gfx94__)
union union
{ {
float fval; float fval;
......
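The converters above now clamp finite inputs to the 240.0f limit on all targets (previously only under the gfx94 guard) while leaving infinities untouched, and switch the stochastic-rounding seed from 42 to 1254739. A standalone sketch of that saturation step (plain C++; illustrative only, not the CK implementation):

#include <cmath>
#include <cstdio>

// Clamp finite values into [-max_fp8, max_fp8]; infinities (and NaN) are not clamped.
float saturate_to_fp8_range(float x)
{
    constexpr float max_fp8 = 240.0f; // limit used by the code above
    if(!std::isinf(x))
        x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
    return x;
}

int main()
{
    std::printf("%f\n", saturate_to_fp8_range(1000.0f));  // 240
    std::printf("%f\n", saturate_to_fp8_range(-1000.0f)); // -240
    std::printf("%f\n", saturate_to_fp8_range(INFINITY)); // inf passes through
}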
...@@ -61,12 +61,12 @@ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor) ...@@ -61,12 +61,12 @@ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor)
decltype(dim_access_order), decltype(dim_access_order),
VectorDim, VectorDim,
ScalarPerVector, ScalarPerVector,
Sequence<false>, Sequence<true>,
Sequence<false>>{in_grid_desc, Sequence<true>>{in_grid_desc,
make_tuple(src_tensor.GetMultiIdxOffsets()), make_tuple(src_tensor.GetMultiIdxOffsets()),
out_grid_desc, out_grid_desc,
make_tuple(dst_tensor.GetMultiIdxOffsets()), make_tuple(dst_tensor.GetMultiIdxOffsets()),
tensor_operation::element_wise::PassThrough{}}; tensor_operation::element_wise::PassThrough{}};
transfer.Run(tie(in_grid_desc), transfer.Run(tie(in_grid_desc),
tie(src_tensor.GetBuffer()), tie(src_tensor.GetBuffer()),
...@@ -104,37 +104,25 @@ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor) ...@@ -104,37 +104,25 @@ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor)
else if constexpr(SrcTensorType::IsDynamicBuffer && !DstTensorType::IsDynamicBuffer) else if constexpr(SrcTensorType::IsDynamicBuffer && !DstTensorType::IsDynamicBuffer)
{ {
// Perform copy from DynamicBuffer to StaticBuffer // Perform copy from DynamicBuffer to StaticBuffer
const auto src_dst_slice_origin = const auto dst_slice_origin_idxs =
generate_tuple([&](auto) { return I0; }, Number<num_dims>{}); generate_tuple([&](auto) { return I0; }, Number<num_dims>{});
constexpr auto src_vector_tensor_lengths = generate_sequence_v2( auto transfer = ThreadwiseTensorSliceTransfer_v2<
[&](auto I) { std::remove_const_t<typename SrcTensorType::TensorElementType>,
if constexpr(I == VectorDim) std::remove_const_t<typename DstTensorType::TensorElementType>,
{ remove_cvref_t<decltype(in_grid_desc)>,
return Number<ScalarPerVector>{}; remove_cvref_t<decltype(out_grid_desc)>,
} decltype(thread_slice_lengths),
else decltype(dim_access_order),
{ VectorDim,
return I1; ScalarPerVector,
} I1,
}, false,
Number<num_dims>{}); false>{in_grid_desc, src_tensor.GetMultiIdxOffsets()};
auto transfer =
ThreadwiseTensorSliceTransfer_v4r1<typename SrcTensorType::TensorElementType,
typename DstTensorType::TensorElementType,
remove_cvref_t<decltype(in_grid_desc)>,
remove_cvref_t<decltype(out_grid_desc)>,
decltype(thread_slice_lengths),
decltype(dim_access_order),
decltype(src_vector_tensor_lengths),
decltype(dim_access_order)>{
src_tensor.GetMultiIdxOffsets()};
transfer.Run(in_grid_desc, transfer.Run(in_grid_desc,
src_dst_slice_origin,
src_tensor.GetBuffer(), src_tensor.GetBuffer(),
out_grid_desc, out_grid_desc,
src_dst_slice_origin, dst_slice_origin_idxs,
dst_tensor.GetBuffer()); dst_tensor.GetBuffer());
} }
else else
...@@ -183,10 +171,12 @@ template <typename DimAccessOrderTuple, ...@@ -183,10 +171,12 @@ template <typename DimAccessOrderTuple,
index_t ScalarPerVector, index_t ScalarPerVector,
typename SrcTensorType, typename SrcTensorType,
typename DstTensorType, typename DstTensorType,
typename ThreadLayoutTuple> typename ThreadShape,
__device__ void blockwise_copy(const SrcTensorType& src_tensor, typename ThreadUnrolledDesc>
DstTensorType& dst_tensor, __device__ void
[[maybe_unused]] ThreadLayoutTuple& thread_layout) blockwise_copy(const SrcTensorType& src_tensor,
DstTensorType& dst_tensor,
[[maybe_unused]] const Layout<ThreadShape, ThreadUnrolledDesc>& thread_layout)
{ {
static_assert(SrcTensorType::IsDynamicBuffer && DstTensorType::IsDynamicBuffer); static_assert(SrcTensorType::IsDynamicBuffer && DstTensorType::IsDynamicBuffer);
static_assert(is_detected<is_tuple, DimAccessOrderTuple>::value); static_assert(is_detected<is_tuple, DimAccessOrderTuple>::value);
...@@ -199,12 +189,12 @@ __device__ void blockwise_copy(const SrcTensorType& src_tensor, ...@@ -199,12 +189,12 @@ __device__ void blockwise_copy(const SrcTensorType& src_tensor,
constexpr auto tile_lengths_seq = constexpr auto tile_lengths_seq =
generate_sequence_v2([](auto I) { return size(SrcShapeType{}.At(I)); }, Number<num_dims>{}); generate_sequence_v2([](auto I) { return size(SrcShapeType{}.At(I)); }, Number<num_dims>{});
constexpr auto thread_layout_seq = generate_sequence_v2( constexpr auto thread_layout_seq =
[](auto I) { return size(ThreadLayoutTuple{}.At(I)); }, Number<num_dims>{}); generate_sequence_v2([](auto I) { return size<I>(ThreadShape{}); }, Number<num_dims>{});
constexpr auto dim_access_order = generate_sequence_v2( constexpr auto dim_access_order = generate_sequence_v2(
[](auto I) { return DimAccessOrderTuple{}.At(I); }, Number<num_dims>{}); [](auto I) { return DimAccessOrderTuple{}.At(I); }, Number<num_dims>{});
using ThisThreadBlock = ThisThreadBlock<size(ThreadLayoutTuple{})>; using ThisThreadBlock = ThisThreadBlock<size(ThreadShape{})>;
// Perform copy between DynamicBuffers // Perform copy between DynamicBuffers
auto transfer = ThreadGroupTensorSliceTransfer_v7< auto transfer = ThreadGroupTensorSliceTransfer_v7<
......
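blockwise_copy above derives the per-thread work from the tile lengths and the thread layout before delegating to ThreadGroupTensorSliceTransfer_v7. As a rough standalone illustration (not CK code, and assuming the tile divides evenly by the thread layout), each thread covers tile_length / thread_count elements per dimension:

#include <array>
#include <cstdio>

// Per-dimension slice owned by one thread when a tile is split evenly
// across a thread layout (illustrative; assumes divisibility).
template <std::size_t N>
std::array<int, N> thread_slice_lengths(const std::array<int, N>& tile_lengths,
                                        const std::array<int, N>& thread_layout)
{
    std::array<int, N> slice{};
    for(std::size_t i = 0; i < N; ++i)
        slice[i] = tile_lengths[i] / thread_layout[i];
    return slice;
}

int main()
{
    auto slice = thread_slice_lengths<2>({128, 64}, {32, 8});
    std::printf("%d x %d elements per thread\n", slice[0], slice[1]); // 4 x 8
}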
...@@ -48,8 +48,9 @@ __device__ constexpr auto GetBlockDescriptor() ...@@ -48,8 +48,9 @@ __device__ constexpr auto GetBlockDescriptor()
/** /**
* \brief Perform blockwise gemm xdl on tensors stored in lds. Result will be * \brief Perform blockwise gemm xdl on tensors stored in lds. Result will be
* stored in Vgpr register. A data layout must be (MPerBlock, KPerBlock) and B * stored in Vgpr register. A data layout must be (MPerBlock, KPerBlock) or
* data layout must be (NPerBlock, KPerBlock). * (K0PerBlock, MPerBlock, K1) and B data layout must be (NPerBlock, KPerBlock)
* or (K0PerBlock, NPerBlock, K1).
* *
* \note C output Vgpr register layout (8D): * \note C output Vgpr register layout (8D):
* - MXdlPerWave - The number of MFMA instructions run by single wave in M * - MXdlPerWave - The number of MFMA instructions run by single wave in M
...@@ -71,9 +72,9 @@ __device__ constexpr auto GetBlockDescriptor() ...@@ -71,9 +72,9 @@ __device__ constexpr auto GetBlockDescriptor()
* \tparam BlockSize Number of threads in the block. * \tparam BlockSize Number of threads in the block.
* \tparam GemmTraits Traits of gemm xdl operation. * \tparam GemmTraits Traits of gemm xdl operation.
* \param a_local_tile_tensor A tensor in LDS memory for blockwise gemm * \param a_local_tile_tensor A tensor in LDS memory for blockwise gemm
* (MPerBlock, KPerBlock) layout. * (MPerBlock, KPerBlock) or (K0PerBlock, MPerBlock, K1) layout.
* \param b_local_tile_tensor B tensor in LDS memory for blockwise gemm * \param b_local_tile_tensor B tensor in LDS memory for blockwise gemm
* (NPerBlock, KPerBlock) layout. * (NPerBlock, KPerBlock) or (K0PerBlock, NPerBlock, K1) layout.
* \param c_reg_tensor C tensor VGPR memory for blockwise gemm. * \param c_reg_tensor C tensor VGPR memory for blockwise gemm.
*/ */
template <typename DataType, template <typename DataType,
...@@ -86,6 +87,8 @@ __device__ void blockwise_gemm_xdl(const ATensorType& a_local_tile_tensor, ...@@ -86,6 +87,8 @@ __device__ void blockwise_gemm_xdl(const ATensorType& a_local_tile_tensor,
const BTensorType& b_local_tile_tensor, const BTensorType& b_local_tile_tensor,
CTensorType& c_reg_tensor) CTensorType& c_reg_tensor)
{ {
constexpr auto I3 = Number<3>{};
static_assert(ATensorType::TensorBufferAddressSpace == MemoryTypeEnum::Lds); static_assert(ATensorType::TensorBufferAddressSpace == MemoryTypeEnum::Lds);
static_assert(BTensorType::TensorBufferAddressSpace == MemoryTypeEnum::Lds); static_assert(BTensorType::TensorBufferAddressSpace == MemoryTypeEnum::Lds);
static_assert(CTensorType::TensorBufferAddressSpace == MemoryTypeEnum::Vgpr); static_assert(CTensorType::TensorBufferAddressSpace == MemoryTypeEnum::Vgpr);
...@@ -99,10 +102,18 @@ __device__ void blockwise_gemm_xdl(const ATensorType& a_local_tile_tensor, ...@@ -99,10 +102,18 @@ __device__ void blockwise_gemm_xdl(const ATensorType& a_local_tile_tensor,
using ATileLayout = remove_cvref_t<decltype(layout(a_local_tile_tensor))>; using ATileLayout = remove_cvref_t<decltype(layout(a_local_tile_tensor))>;
using BTileLayout = remove_cvref_t<decltype(layout(b_local_tile_tensor))>; using BTileLayout = remove_cvref_t<decltype(layout(b_local_tile_tensor))>;
static_assert(typename ATileLayout::LayoutShape{}.Size() ==
typename BTileLayout::LayoutShape{}.Size());
constexpr bool is_3d_desc = typename ATileLayout::LayoutShape{}.Size() == I3;
using ABlockDesc_K0_M_K1_Type = using ABlockDesc_K0_M_K1_Type =
decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>()); conditional_t<is_3d_desc,
typename ATileLayout::LayoutUnrolledDescriptorType,
decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>())>;
using BBlockDesc_K0_N_K1_Type = using BBlockDesc_K0_N_K1_Type =
decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>()); conditional_t<is_3d_desc,
typename BTileLayout::LayoutUnrolledDescriptorType,
decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>())>;
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize, BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
DataType, DataType,
...@@ -168,14 +179,22 @@ make_blockwise_gemm_xdl_c_local_partition(CTensorType& c_local_tile_tensor) ...@@ -168,14 +179,22 @@ make_blockwise_gemm_xdl_c_local_partition(CTensorType& c_local_tile_tensor)
constexpr auto I6 = Number<6>{}; constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{}; constexpr auto I7 = Number<7>{};
static_assert(typename ATileLayout::LayoutShape{}.Size() ==
typename BTileLayout::LayoutShape{}.Size());
constexpr bool is_integer = constexpr bool is_integer =
is_same_v<DataType, int8_t> || is_same_v<DataType, int16_t> || is_same_v<DataType, int32_t>; is_same_v<DataType, int8_t> || is_same_v<DataType, int16_t> || is_same_v<DataType, int32_t>;
using GemmAccDataType = std::conditional_t<is_integer, int32_t, float>; using GemmAccDataType = std::conditional_t<is_integer, int32_t, float>;
constexpr bool is_3d_desc = typename ATileLayout::LayoutShape{}.Size() == I3;
using ABlockDesc_K0_M_K1_Type = using ABlockDesc_K0_M_K1_Type =
decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>()); conditional_t<is_3d_desc,
typename ATileLayout::LayoutUnrolledDescriptorType,
decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>())>;
using BBlockDesc_K0_N_K1_Type = using BBlockDesc_K0_N_K1_Type =
decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>()); conditional_t<is_3d_desc,
typename BTileLayout::LayoutUnrolledDescriptorType,
decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>())>;
using BlockwiseGemmXdlops = using BlockwiseGemmXdlops =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize, BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
...@@ -233,19 +252,45 @@ make_blockwise_gemm_xdl_c_local_partition(CTensorType& c_local_tile_tensor) ...@@ -233,19 +252,45 @@ make_blockwise_gemm_xdl_c_local_partition(CTensorType& c_local_tile_tensor)
const auto partition_desc = BlockwiseGemmXdlops::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2( const auto partition_desc = BlockwiseGemmXdlops::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(
layout(c_local_tile_tensor).GetUnrolledDescriptor()); layout(c_local_tile_tensor).GetUnrolledDescriptor());
const auto lower_upper_dims =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<8>{});
auto sliced_desc = transform_tensor_descriptor(
partition_desc,
make_tuple(
make_slice_transform(partition_shape.At(Number<0>{}),
m_thread_data_on_grid_idx[I0],
partition_shape.At(Number<0>{}) + m_thread_data_on_grid_idx[I0]),
make_slice_transform(partition_shape.At(Number<1>{}),
n_thread_data_on_grid_idx[I0],
partition_shape.At(Number<1>{}) + n_thread_data_on_grid_idx[I0]),
make_slice_transform(partition_shape.At(Number<2>{}),
m_thread_data_on_grid_idx[I1],
partition_shape.At(Number<2>{}) + m_thread_data_on_grid_idx[I1]),
make_slice_transform(partition_shape.At(Number<3>{}),
n_thread_data_on_grid_idx[I1],
partition_shape.At(Number<3>{}) + n_thread_data_on_grid_idx[I1]),
make_slice_transform(partition_shape.At(Number<4>{}),
m_thread_data_on_grid_idx[I2],
partition_shape.At(Number<4>{}) + m_thread_data_on_grid_idx[I2]),
make_slice_transform(partition_shape.At(Number<5>{}),
m_thread_data_on_grid_idx[I3],
partition_shape.At(Number<5>{}) + m_thread_data_on_grid_idx[I3]),
make_slice_transform(partition_shape.At(Number<6>{}),
m_thread_data_on_grid_idx[I4],
partition_shape.At(Number<6>{}) + m_thread_data_on_grid_idx[I4]),
make_slice_transform(partition_shape.At(Number<7>{}),
n_thread_data_on_grid_idx[I2],
partition_shape.At(Number<7>{}) + n_thread_data_on_grid_idx[I2])),
lower_upper_dims,
lower_upper_dims);
const auto partition_layout = const auto partition_layout =
Layout<remove_reference_t<decltype(partition_shape)>, decltype(partition_desc)>( Layout<remove_reference_t<decltype(partition_shape)>, decltype(sliced_desc)>(
partition_shape, partition_desc); partition_shape, sliced_desc);
auto partition_tensor = make_tensor<CTensorType::TensorBufferAddressSpace>( auto partition_tensor = make_tensor<CTensorType::TensorBufferAddressSpace>(
c_local_tile_tensor.GetPointer(), partition_layout); c_local_tile_tensor.GetPointer(), partition_layout);
partition_tensor.SetMultiIdxOffset(make_multi_index(m_thread_data_on_grid_idx[I0],
n_thread_data_on_grid_idx[I0],
m_thread_data_on_grid_idx[I1],
n_thread_data_on_grid_idx[I1],
m_thread_data_on_grid_idx[I2],
m_thread_data_on_grid_idx[I3],
m_thread_data_on_grid_idx[I4],
n_thread_data_on_grid_idx[I2]));
return partition_tensor; return partition_tensor;
} }
...@@ -292,14 +337,22 @@ __host__ __device__ constexpr auto make_blockwise_gemm_xdl_c_vgpr() ...@@ -292,14 +337,22 @@ __host__ __device__ constexpr auto make_blockwise_gemm_xdl_c_vgpr()
constexpr auto I6 = Number<6>{}; constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{}; constexpr auto I7 = Number<7>{};
static_assert(typename ATileLayout::LayoutShape{}.Size() ==
typename BTileLayout::LayoutShape{}.Size());
constexpr bool is_integer = constexpr bool is_integer =
is_same_v<DataType, int8_t> || is_same_v<DataType, int16_t> || is_same_v<DataType, int32_t>; is_same_v<DataType, int8_t> || is_same_v<DataType, int16_t> || is_same_v<DataType, int32_t>;
using GemmAccDataType = std::conditional_t<is_integer, int32_t, float>; using GemmAccDataType = std::conditional_t<is_integer, int32_t, float>;
constexpr bool is_3d_desc = typename ATileLayout::LayoutShape{}.Size() == I3;
using ABlockDesc_K0_M_K1_Type = using ABlockDesc_K0_M_K1_Type =
decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>()); conditional_t<is_3d_desc,
typename ATileLayout::LayoutUnrolledDescriptorType,
decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>())>;
using BBlockDesc_K0_N_K1_Type = using BBlockDesc_K0_N_K1_Type =
decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>()); conditional_t<is_3d_desc,
typename BTileLayout::LayoutUnrolledDescriptorType,
decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>())>;
using BlockwiseGemmXdlops = using BlockwiseGemmXdlops =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize, BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
...@@ -326,9 +379,8 @@ __host__ __device__ constexpr auto make_blockwise_gemm_xdl_c_vgpr() ...@@ -326,9 +379,8 @@ __host__ __device__ constexpr auto make_blockwise_gemm_xdl_c_vgpr()
const auto vgpr_layout = Layout<remove_reference_t<decltype(vgpr_shape)>, decltype(vgpr_desc)>( const auto vgpr_layout = Layout<remove_reference_t<decltype(vgpr_shape)>, decltype(vgpr_desc)>(
vgpr_shape, vgpr_desc); vgpr_shape, vgpr_desc);
// Get vector type for Vgpr // Get vector type for Vgpr
using BlockwiseGemmCThreadBufferType = constexpr index_t ScalarPerVector = BlockwiseGemmXdlops::xdlops_gemm.GetRegSizePerXdlops();
remove_reference_t<decltype(BlockwiseGemmXdlops{}.GetCThreadBuffer())>; using VgprVectorType = typename vector_type<GemmAccDataType, ScalarPerVector>::type;
using VgprVectorType = typename BlockwiseGemmCThreadBufferType::V;
return ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr, VgprVectorType>( return ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr, VgprVectorType>(
vgpr_layout); vgpr_layout);
} }
......
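Both blockwise_gemm_xdl and the helpers above now pick the K0_M_K1 (and K0_N_K1) block descriptor at compile time: when the tile layout is already 3-D, its unrolled descriptor is reused directly, otherwise it is derived via GetBlockDescriptor. A minimal standalone sketch of that rank-based type selection (plain C++17; the type names here are hypothetical, not CK's):

#include <cstdio>
#include <type_traits>

struct Desc2D { static constexpr int rank = 2; };
struct Desc3D { static constexpr int rank = 3; };

// Hypothetical "derive a 3-D descriptor from a 2-D one" helper.
template <typename D>
struct Derive3DFrom2D { using type = Desc3D; };

// Pick the descriptor type: reuse it when already 3-D, otherwise derive it.
template <typename TileDesc>
using BlockDesc_K0_M_K1 =
    std::conditional_t<TileDesc::rank == 3, TileDesc, typename Derive3DFrom2D<TileDesc>::type>;

int main()
{
    static_assert(std::is_same_v<BlockDesc_K0_M_K1<Desc3D>, Desc3D>);
    static_assert(std::is_same_v<BlockDesc_K0_M_K1<Desc2D>, Desc3D>);
    std::puts("ok");
}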
...@@ -172,10 +172,10 @@ __host__ __device__ constexpr auto GenerateUpperDims(const Tuple<Transforms...>& ...@@ -172,10 +172,10 @@ __host__ __device__ constexpr auto GenerateUpperDims(const Tuple<Transforms...>&
} }
} }
template <typename... Ts, typename Shape, typename FlattenDescriptor> template <typename... Ts, typename Shape, typename UnrolledDescriptor>
__host__ __device__ constexpr auto GenerateSlicedDescriptor(const Tuple<Ts...>& idx, __host__ __device__ constexpr auto GenerateSlicedDescriptor(const Tuple<Ts...>& idx,
const Shape& shape, const Shape& shape,
const FlattenDescriptor& flatten_desc) const UnrolledDescriptor& flatten_desc)
{ {
constexpr auto old_shape_dims = decltype(UnrollNestedTuple(shape))::Size(); constexpr auto old_shape_dims = decltype(UnrollNestedTuple(shape))::Size();
......
...@@ -20,48 +20,57 @@ namespace wrapper { ...@@ -20,48 +20,57 @@ namespace wrapper {
* \tparam K1Value The number of K-dim elements that are packed together as * \tparam K1Value The number of K-dim elements that are packed together as
* a separate logical dimension. Usually aligns with vector load size. * a separate logical dimension. Usually aligns with vector load size.
*/ */
template <index_t MPerXDLValue, template <typename MPerXDLValue,
index_t NPerXDLValue, typename NPerXDLValue,
index_t MXdlPerWaveValue, typename MXdlPerWaveValue,
index_t NXdlPerWaveValue, typename NXdlPerWaveValue,
index_t K1Value> typename K1Value>
struct BlockwisGemmXdlTraits struct BlockwisGemmXdlTraits
{ {
static constexpr index_t MPerXDL = MPerXDLValue; static constexpr auto MPerXDL = MPerXDLValue{};
static constexpr index_t NPerXDL = NPerXDLValue; static constexpr auto NPerXDL = NPerXDLValue{};
static constexpr index_t MXdlPerWave = MXdlPerWaveValue; static constexpr auto MXdlPerWave = MXdlPerWaveValue{};
static constexpr index_t NXdlPerWave = NXdlPerWaveValue; static constexpr auto NXdlPerWave = NXdlPerWaveValue{};
static constexpr index_t K1 = K1Value; static constexpr auto K1 = K1Value{};
}; };
// K1 = 4 // K1 = 4
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 4, 2, 4> struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_4K1
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<4>, Number<2>, Number<4>>
{ {
}; };
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 2, 4, 4> struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_4K1
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<4>, Number<4>>
{ {
}; };
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 2, 2, 4> struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<2>, Number<4>>
{ {
}; };
// K1 = 8 // K1 = 8
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_8K1 : BlockwisGemmXdlTraits<32, 32, 4, 2, 8> struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_8K1
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<4>, Number<2>, Number<8>>
{ {
}; };
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_8K1 : BlockwisGemmXdlTraits<32, 32, 2, 4, 8> struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_8K1
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<4>, Number<8>>
{ {
}; };
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1 : BlockwisGemmXdlTraits<32, 32, 2, 2, 8> struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<2>, Number<8>>
{ {
}; };
// K1 = 16 // K1 = 16
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_16K1 : BlockwisGemmXdlTraits<32, 32, 4, 2, 16> struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_16K1
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<4>, Number<2>, Number<16>>
{ {
}; };
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_16K1 : BlockwisGemmXdlTraits<32, 32, 2, 4, 16> struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_16K1
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<4>, Number<16>>
{ {
}; };
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_16K1 : BlockwisGemmXdlTraits<32, 32, 2, 2, 16> struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_16K1
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<2>, Number<16>>
{ {
}; };
......
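The traits above switch from plain index_t non-type template parameters to Number<> types, so each constant carries its value in its type and an instance of it can still be used as a compile-time integer. A standalone sketch of the same idea, with std::integral_constant standing in for ck::Number (illustrative, not the CK definition):

#include <cstdio>
#include <type_traits>

template <int N>
using Num = std::integral_constant<int, N>; // stand-in for ck::Number<N>

template <typename MPerXDLValue, typename NPerXDLValue, typename K1Value>
struct GemmTraits
{
    // Instances of the constant types; still usable as compile-time integers.
    static constexpr auto MPerXDL = MPerXDLValue{};
    static constexpr auto NPerXDL = NPerXDLValue{};
    static constexpr auto K1      = K1Value{};
};

struct Traits_32x32_4K1 : GemmTraits<Num<32>, Num<32>, Num<4>>
{
};

int main()
{
    static_assert(Traits_32x32_4K1::MPerXDL == 32);
    static_assert(Traits_32x32_4K1::K1 == 4);
    std::printf("%d %d\n", int(Traits_32x32_4K1::MPerXDL), int(Traits_32x32_4K1::K1));
}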
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
namespace ck {
namespace wrapper {
#define __CK_WRAPPER_LAUNCH_BOUNDS__ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
} // namespace wrapper
} // namespace ck
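The new header defines __CK_WRAPPER_LAUNCH_BOUNDS__ in terms of __launch_bounds__ with CK_MAX_THREAD_PER_BLOCK and CK_MIN_BLOCK_PER_CU, telling the compiler the kernel's maximum block size and minimum resident blocks so it can allocate registers accordingly. A minimal HIP/CUDA-style sketch of how such a macro is applied (illustrative; the 256/2 values are placeholders, not CK's configuration):

// Illustrative only: applying a __launch_bounds__ wrapper macro to a kernel.
#define MY_LAUNCH_BOUNDS __launch_bounds__(256 /*max threads per block*/, 2 /*min blocks per CU/SM*/)

__global__ void MY_LAUNCH_BOUNDS scale_kernel(float* data, float factor, int n)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < n)
        data[i] *= factor;
}
// Launched with a block size no larger than the declared bound,
// e.g. scale_kernel<<<num_blocks, 256>>>(ptr, 2.0f, n);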
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
namespace ck { namespace ck {
namespace wrapper { namespace wrapper {
...@@ -29,6 +30,7 @@ template <typename T> ...@@ -29,6 +30,7 @@ template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple()); using is_tuple = decltype(std::declval<T&>().IsTuple());
namespace { namespace {
namespace detail {
/** /**
* \brief Generate packed (column-major) strides if not passed * \brief Generate packed (column-major) strides if not passed
* *
...@@ -83,6 +85,7 @@ __host__ __device__ constexpr auto MakeUnrolledDescriptor(const LayoutShape& sha ...@@ -83,6 +85,7 @@ __host__ __device__ constexpr auto MakeUnrolledDescriptor(const LayoutShape& sha
return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides); return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
} }
} }
} // namespace detail
} // namespace } // namespace
/// @endcond /// @endcond
...@@ -98,8 +101,9 @@ __host__ __device__ constexpr auto MakeUnrolledDescriptor(const LayoutShape& sha ...@@ -98,8 +101,9 @@ __host__ __device__ constexpr auto MakeUnrolledDescriptor(const LayoutShape& sha
template <typename Shape, typename Strides> template <typename Shape, typename Strides>
__host__ __device__ constexpr auto make_layout(const Shape& shape, const Strides& strides) __host__ __device__ constexpr auto make_layout(const Shape& shape, const Strides& strides)
{ {
using UnrolledDescriptorType = decltype(MakeUnrolledDescriptor(Shape{}, Strides{})); using UnrolledDescriptorType = decltype(detail::MakeUnrolledDescriptor(Shape{}, Strides{}));
return Layout<Shape, UnrolledDescriptorType>(shape, MakeUnrolledDescriptor(shape, strides)); return Layout<Shape, UnrolledDescriptorType>(shape,
detail::MakeUnrolledDescriptor(shape, strides));
} }
/** /**
...@@ -112,13 +116,12 @@ __host__ __device__ constexpr auto make_layout(const Shape& shape, const Strides ...@@ -112,13 +116,12 @@ __host__ __device__ constexpr auto make_layout(const Shape& shape, const Strides
template <typename Shape> template <typename Shape>
__host__ __device__ constexpr auto make_layout(const Shape& shape) __host__ __device__ constexpr auto make_layout(const Shape& shape)
{ {
using UnrolledDescriptorType = decltype(MakeUnrolledDescriptor(Shape{}, Tuple<>{})); using UnrolledDescriptorType = decltype(detail::MakeUnrolledDescriptor(Shape{}, Tuple<>{}));
return Layout<Shape, UnrolledDescriptorType>(shape, MakeUnrolledDescriptor(shape, Tuple<>{})); return Layout<Shape, UnrolledDescriptorType>(shape,
detail::MakeUnrolledDescriptor(shape, Tuple<>{}));
} }
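When no strides are given, make_layout above falls back to packed column-major strides (generated by MakeUnrolledDescriptor from the shape): the first dimension has stride 1 and each following stride is the previous stride times the previous extent. A standalone sketch of that stride rule (plain C++, not the CK helper):

#include <cstdio>
#include <vector>

// Packed column-major strides: stride[0] = 1, stride[i] = stride[i-1] * shape[i-1].
std::vector<long> packed_column_major_strides(const std::vector<long>& shape)
{
    std::vector<long> strides(shape.size(), 1);
    for(std::size_t i = 1; i < shape.size(); ++i)
        strides[i] = strides[i - 1] * shape[i - 1];
    return strides;
}

int main()
{
    auto s = packed_column_major_strides({4, 8, 2});
    std::printf("%ld %ld %ld\n", s[0], s[1], s[2]); // 1 4 32
}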
// Layout helpers // Layout helpers
// get // get
/** /**
* \private * \private
* \brief Get dim. * \brief Get dim.
...@@ -152,8 +155,8 @@ __host__ __device__ constexpr auto get(const Tuple<Dims...>& tuple) ...@@ -152,8 +155,8 @@ __host__ __device__ constexpr auto get(const Tuple<Dims...>& tuple)
* \param layout Layout to create sub layout. * \param layout Layout to create sub layout.
* \return Requested sub layout. * \return Requested sub layout.
*/ */
template <index_t idx, typename Shape, typename FlattenDesc> template <index_t idx, typename Shape, typename UnrolledDesc>
__host__ __device__ constexpr auto get(const Layout<Shape, FlattenDesc>& layout) __host__ __device__ constexpr auto get(const Layout<Shape, UnrolledDesc>& layout)
{ {
const auto& shape = layout.GetShape(); const auto& shape = layout.GetShape();
const auto new_shape = get<idx>(shape); const auto new_shape = get<idx>(shape);
...@@ -427,5 +430,91 @@ __host__ __device__ constexpr const auto& shape(const LayoutType& layout) ...@@ -427,5 +430,91 @@ __host__ __device__ constexpr const auto& shape(const LayoutType& layout)
return layout.GetShape(); return layout.GetShape();
} }
// pad
/**
* \brief Pad the layout shape so that each dimension is aligned to the corresponding tile length.
*
*
* \param layout Layout to pad.
* \param tile_lengths Tile lengths to align layout shape.
* \return Padded layout.
*/
template <typename Shape, typename UnrolledDesc, typename TileLengths>
__host__ __device__ constexpr auto pad(const Layout<Shape, UnrolledDesc>& layout,
const TileLengths& tile_lengths)
{
auto& unrolled_desc = layout.GetUnrolledDescriptor();
// Generate sequence with ones to mark that all dims will be padded
constexpr auto do_pads_seq =
generate_sequence_v2([](auto) { return Number<1>{}; }, Number<Shape::Size()>{});
// Create descriptor with padding
auto padded_desc =
tensor_operation::device::PadTensorDescriptor(unrolled_desc, tile_lengths, do_pads_seq);
// Generate padded shape
const auto padded_shape = generate_tuple(
[&](auto i) { return padded_desc.GetLength(Number<i>{}); }, Number<TileLengths::Size()>{});
// Create layout
return Layout<decltype(padded_shape), decltype(padded_desc)>(padded_shape, padded_desc);
}
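The padded descriptor produced above reports each dimension rounded up to the next multiple of the corresponding tile length, i.e. ceil(dim / tile_length) * tile_length. A standalone sketch of that per-dimension arithmetic:

#include <cstdio>

// Round `dim` up to the next multiple of `tile_length`.
long pad_to_tile(long dim, long tile_length)
{
    return ((dim + tile_length - 1) / tile_length) * tile_length;
}

int main()
{
    std::printf("%ld\n", pad_to_tile(1000, 128)); // 1024
    std::printf("%ld\n", pad_to_tile(1024, 128)); // 1024
}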
// unmerge
/**
* \brief Unmerge (split) the selected dim in a layout.
*
* \tparam Idx Index of the dimension being unmerged.
* \param layout Layout to unmerge.
* \param new_lengths Lengths of the dimensions into which the indicated dimension will be split.
* \param new_indexes Indexes used to reorder dims. Indexes of the unmerged dims should be nested.
* \return Unmerged layout.
*/
template <index_t Idx, typename Shape, typename UnrolledDesc, typename NewLengths, typename NewIdxs>
__host__ __device__ constexpr auto unmerge(const Layout<Shape, UnrolledDesc>& layout,
const NewLengths& new_lengths,
[[maybe_unused]] const NewIdxs& new_indexes)
{
const auto& layout_shape = shape(layout);
auto& unrolled_desc = layout.GetUnrolledDescriptor();
constexpr auto dims = Shape::Size();
// Generate transforms
const auto transforms = generate_tuple(
[&](auto i) {
if constexpr(i == Idx)
{
return make_unmerge_transform(new_lengths);
}
else
{
return make_pass_through_transform(layout_shape.At(i));
}
},
Number<dims>{});
constexpr auto lower_dims =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<dims>{});
constexpr auto upper_dims = generate_tuple(
[&](auto i) {
if constexpr(is_detected<is_tuple, tuple_element_t<i.value, NewIdxs>>::value)
{
constexpr auto idxs_tuple = tuple_element_t<i.value, NewIdxs>{};
return to_sequence(idxs_tuple);
}
else
{
constexpr index_t index = tuple_element_t<i.value, NewIdxs>{};
return Sequence<index>{};
}
},
Number<dims>{});
const auto unmerged_desc =
transform_tensor_descriptor(unrolled_desc, transforms, lower_dims, upper_dims);
const auto unmerged_shape =
generate_tuple([&](auto i) { return unmerged_desc.GetLength(Number<i>{}); },
Number<decltype(unmerged_desc)::GetNumOfVisibleDimension()>{});
// Create layout
return Layout<decltype(unmerged_shape), decltype(unmerged_desc)>(unmerged_shape, unmerged_desc);
}
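Unmerging replaces one dimension of length M0*M1 with two dimensions (M0, M1) whose indices recombine into the original index. A standalone sketch of the index math behind such a split under one common convention (the exact ordering in CK is defined by make_unmerge_transform):

#include <cstdio>
#include <utility>

// Split a flat index i in [0, M0*M1) into (i / M1, i % M1);
// the pair recombines as hi * M1 + lo.
std::pair<long, long> split_index(long i, long M1) { return {i / M1, i % M1}; }

int main()
{
    constexpr long M1 = 8;
    auto [hi, lo] = split_index(42, M1);
    std::printf("42 -> (%ld, %ld) -> %ld\n", hi, lo, hi * M1 + lo); // (5, 2) -> 42
}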
} // namespace wrapper } // namespace wrapper
} // namespace ck } // namespace ck
...@@ -6,7 +6,6 @@ ...@@ -6,7 +6,6 @@
#include "tensor_utils.hpp" #include "tensor_utils.hpp"
#include "layout_utils.hpp" #include "layout_utils.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp" #include "ck/tensor_description/cluster_descriptor.hpp"
...@@ -44,8 +43,9 @@ __host__ __device__ constexpr auto CalculateLocalPartitionShape(const Tuple<Ts.. ...@@ -44,8 +43,9 @@ __host__ __device__ constexpr auto CalculateLocalPartitionShape(const Tuple<Ts..
* \brief Apply projection. * \brief Apply projection.
* *
* \param base_tuple Tuple to apply projection. * \param base_tuple Tuple to apply projection.
* \param projection Projection to remove selected dim from partitioning. * \param projection Projection is used to remove selected dim from
* slice(X) to remove, where X is dim size, Number<1>{} to keep. * partitioning. Use `slice(X)` to remove dimension, where X is dim
* size. Use `Number<1>{}` to keep it.
* \return Multi index after projection. * \return Multi index after projection.
*/ */
template <typename MultiIndex, typename ProjectionTuple> template <typename MultiIndex, typename ProjectionTuple>
...@@ -73,7 +73,7 @@ ApplyProjection([[maybe_unused]] const MultiIndex& base_tuple, ...@@ -73,7 +73,7 @@ ApplyProjection([[maybe_unused]] const MultiIndex& base_tuple,
} }
else else
{ {
return base_tuple.At(i_num); return make_tuple(base_tuple.At(i_num));
} }
}, },
Number<MultiIndex::Size()>{}); Number<MultiIndex::Size()>{});
...@@ -86,8 +86,9 @@ ApplyProjection([[maybe_unused]] const MultiIndex& base_tuple, ...@@ -86,8 +86,9 @@ ApplyProjection([[maybe_unused]] const MultiIndex& base_tuple,
* \brief Calculate shape with dims from projection. * \brief Calculate shape with dims from projection.
* *
* \param shape Base tensor shape. * \param shape Base tensor shape.
* \param projection Projection to remove selected dim from partitioning. * \param projection Projection is used to remove selected dim from
* slice(X) to remove, where X is dim size, Number<1>{} to keep. * partitioning. Use `slice(X)` to remove dimension, where X is dim
* size. Use `Number<1>{}` to keep it.
* \return Shape with dims from projection * \return Shape with dims from projection
*/ */
template <typename... Ts, typename... Ps> template <typename... Ts, typename... Ps>
...@@ -119,22 +120,14 @@ __host__ __device__ constexpr auto CalculateShapeWithProjection(const Tuple<Ts.. ...@@ -119,22 +120,14 @@ __host__ __device__ constexpr auto CalculateShapeWithProjection(const Tuple<Ts..
* *
* \param shape Base tensor shape. * \param shape Base tensor shape.
* \param tile_shape Tile shape. * \param tile_shape Tile shape.
* \param projection Projection is used to remove selected dim from
* partitioning. Use `slice(X)` to remove dimension, where X is dim
* size. Use `Number<1>{}` to keep it.
* \return Tuple with blocks number. * \return Tuple with blocks number.
*/ */
template <typename... Ts, typename... Ls, typename... Ps> template <typename... Ts, typename... Ls, typename... Ps>
__host__ __device__ constexpr auto CalculateGridSize(const Tuple<Ts...>& shape, __host__ __device__ constexpr auto CalculateGridSize(const Tuple<Ts...>& shape,
const Tuple<Ls...>& tile_shape, const Tuple<Ls...>& tile_shape)
const Tuple<Ps...>& projection)
{ {
auto shape_with_projection = CalculateShapeWithProjection(shape, projection);
return generate_tuple( return generate_tuple(
[&](auto i) { [&](auto i) { return ck::math::integer_divide_ceil(size<i>(shape), size<i>(tile_shape)); },
return ck::math::integer_divide_ceil(size<i>(shape_with_projection),
size<i>(tile_shape));
},
Number<Tuple<Ls...>::Size()>{}); Number<Tuple<Ls...>::Size()>{});
} }
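CalculateGridSize above maps each dimension to the number of blocks covering it: integer_divide_ceil(shape, tile_shape). A standalone sketch of that computation:

#include <array>
#include <cstdio>

// Number of tiles (blocks) per dimension: ceil(shape / tile_shape).
template <std::size_t N>
std::array<long, N> grid_size(const std::array<long, N>& shape, const std::array<long, N>& tile)
{
    std::array<long, N> grid{};
    for(std::size_t i = 0; i < N; ++i)
        grid[i] = (shape[i] + tile[i] - 1) / tile[i];
    return grid;
}

int main()
{
    auto g = grid_size<2>({1000, 512}, {128, 64});
    std::printf("%ld x %ld blocks\n", g[0], g[1]); // 8 x 8
}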
...@@ -155,6 +148,54 @@ CalculateOffsetMultiIdxs(const ThreadIdxs& thread_idxs, ...@@ -155,6 +148,54 @@ CalculateOffsetMultiIdxs(const ThreadIdxs& thread_idxs,
return thread_idxs * partition_lengths_seq + old_offset_idxs; return thread_idxs * partition_lengths_seq + old_offset_idxs;
} }
/**
* \brief Select dims to partition (skip if slice).
*
* \param block_idxs Input block indexes.
* \return Partitioned dims.
*/
template <typename BlockIdxs>
__host__ __device__ constexpr auto GetDimsToPartition([[maybe_unused]] const BlockIdxs& block_idxs)
{
const auto dims_to_partition = generate_tuple(
[&](auto i) {
if constexpr(!is_detected<is_slice, tuple_element_t<i, BlockIdxs>>::value)
{
return Number<i>{};
}
else
{
return Tuple<>{};
}
},
Number<BlockIdxs::Size()>{});
// Remove empty tuples
return UnrollNestedTuple<0, 1>(dims_to_partition);
}
/**
* \brief Replace slices with zeros (Slice dims are not partitioned).
*
* \param block_idxs Input block indexes.
* \return Parsed dims.
*/
template <typename BlockIdxs>
__host__ __device__ constexpr auto ReplaceSlicesWithZeros(const BlockIdxs& block_idxs)
{
return generate_tuple(
[&](auto i) {
if constexpr(!is_detected<is_slice, tuple_element_t<i, BlockIdxs>>::value)
{
return block_idxs.At(i);
}
else
{
return Number<0>{};
}
},
Number<BlockIdxs::Size()>{});
}
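GetDimsToPartition keeps only the indices of dimensions that are not marked as slices (slice entries contribute an empty tuple that is flattened away), and ReplaceSlicesWithZeros maps slice entries to zero. A standalone sketch of the same compile-time filtering idea using std::tuple_cat (plain C++17; the Slice marker here is hypothetical, not CK's slice type):

#include <cstdio>
#include <tuple>
#include <type_traits>
#include <utility>

// Marker for a sliced dimension (kept whole, not partitioned).
struct Slice { };

// Collect the indices of entries that are NOT Slice; Slice entries contribute
// an empty tuple, mirroring the "remove empty tuples" flattening step above.
template <typename... Ts, std::size_t... Is>
constexpr auto dims_to_partition_impl(const std::tuple<Ts...>&, std::index_sequence<Is...>)
{
    return std::tuple_cat(
        std::conditional_t<std::is_same_v<Ts, Slice>,
                           std::tuple<>,
                           std::tuple<std::integral_constant<std::size_t, Is>>>{}...);
}

template <typename... Ts>
constexpr auto dims_to_partition(const std::tuple<Ts...>& idxs)
{
    return dims_to_partition_impl(idxs, std::index_sequence_for<Ts...>{});
}

int main()
{
    // Block indexes (1, slice, 7): only dims 0 and 2 are partitioned.
    constexpr auto dims = dims_to_partition(std::make_tuple(1, Slice{}, 7));
    static_assert(std::tuple_size_v<decltype(dims)> == 2);
    static_assert(std::get<0>(dims) == 0 && std::get<1>(dims) == 2);
    std::puts("partition dims: 0 and 2");
}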
/** /**
* \brief Calculate default projection. * \brief Calculate default projection.
* *
...@@ -168,6 +209,31 @@ GenerateDefaultProjection([[maybe_unused]] const TileShape tile_shape) ...@@ -168,6 +209,31 @@ GenerateDefaultProjection([[maybe_unused]] const TileShape tile_shape)
return generate_tuple([&](auto) { return Number<1>{}; }, Number<TileShape::Size()>{}); return generate_tuple([&](auto) { return Number<1>{}; }, Number<TileShape::Size()>{});
} }
/**
* \brief Calculate the thread multi-index from a 1D thread index.
*
* \param thread_layout Layout of threads (must not be nested).
* \param thread_id Thread index represented as an integer.
* \return Thread multi-index.
*/
template <typename ThreadShape, typename ThreadUnrolledDesc>
__host__ __device__ constexpr auto CalculateThreadMultiIdx(
[[maybe_unused]] const Layout<ThreadShape, ThreadUnrolledDesc>& thread_layout,
const index_t thread_id)
{
static_assert(ThreadUnrolledDesc::GetNumOfTransform() == 1,
"Thread layout should not be transformed.");
constexpr auto embed_transform = ThreadUnrolledDesc{}.GetTransforms().At(Number<0>{});
constexpr auto shape = ThreadShape{};
constexpr auto strides = embed_transform.coefficients_;
return generate_tuple(
[&](auto i) {
constexpr auto num_i = Number<i>{};
return (thread_id / strides.At(num_i)) % shape.At(num_i);
},
Number<ThreadShape::Size()>{});
}
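CalculateThreadMultiIdx recovers the per-dimension thread coordinates from the flat thread id using the layout's strides: idx[i] = (thread_id / stride[i]) % shape[i]. A standalone sketch of that delinearization:

#include <array>
#include <cstdio>

// Recover a multi-dimensional thread index from a flat thread id,
// given per-dimension strides (idx[i] = (tid / stride[i]) % shape[i]).
template <std::size_t N>
std::array<int, N> thread_multi_idx(int tid,
                                    const std::array<int, N>& shape,
                                    const std::array<int, N>& strides)
{
    std::array<int, N> idx{};
    for(std::size_t i = 0; i < N; ++i)
        idx[i] = (tid / strides[i]) % shape[i];
    return idx;
}

int main()
{
    // 2D thread layout 4 x 64 with column-major strides (1, 4):
    // thread 70 -> (70 / 1) % 4 = 2, (70 / 4) % 64 = 17.
    auto idx = thread_multi_idx<2>(70, {4, 64}, {1, 4});
    std::printf("(%d, %d)\n", idx[0], idx[1]);
}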
} // namespace detail } // namespace detail
} // namespace } // namespace
...@@ -176,51 +242,62 @@ GenerateDefaultProjection([[maybe_unused]] const TileShape tile_shape) ...@@ -176,51 +242,62 @@ GenerateDefaultProjection([[maybe_unused]] const TileShape tile_shape)
* is supported). * is supported).
* *
* \param tensor Tensor for partition. * \param tensor Tensor for partition.
* \param thread_lengths Layout of threads (could not be nested). * \param thread_layout Layout of threads (must not be transformed).
* \param thread_id Thread index represented as integer. * \param thread_id Thread index represented as integer.
* \param projection Projection is used to remove selected dim from * \param projection Projection is used to remove selected dim from
* partitioning. Use `slice(X)` to remove dimension, where X is dim * partitioning. Use `slice(X)` to remove dimension, where X is dim
* size. Use `Number<1>{}` to keep it. * size. Use `Number<1>{}` to keep it.
* \return Partition tensor. * \return Partition tensor.
*/ */
template <typename TensorType, typename ThreadLengthsTuple, typename ProjectionTuple> template <typename TensorType,
typename ThreadShape,
typename ThreadUnrolledDesc,
typename ProjectionTuple>
__host__ __device__ constexpr auto __host__ __device__ constexpr auto
make_local_partition(TensorType& tensor, make_local_partition(TensorType& tensor,
[[maybe_unused]] const ThreadLengthsTuple& thread_lengths, [[maybe_unused]] const Layout<ThreadShape, ThreadUnrolledDesc>& thread_layout,
const index_t thread_id, const index_t thread_id,
const ProjectionTuple& projection) const ProjectionTuple& projection)
{ {
static_assert(!IsNestedTuple(ThreadLengthsTuple{})); static_assert(!IsNestedTuple(ThreadShape{}));
// Calculate new partition shape // Calculate new partition shape
const auto& tensor_shape = shape(tensor); const auto& tensor_shape = shape(tensor);
// Calculate projected thread lengths // Calculate projected thread lengths
constexpr auto projected_thread_lengths = constexpr auto projected_thread_lengths =
detail::ApplyProjection(ThreadLengthsTuple{}, ProjectionTuple{}); detail::ApplyProjection(ThreadShape{}, ProjectionTuple{});
constexpr auto partition_shape = constexpr auto partition_shape =
detail::CalculateLocalPartitionShape(decltype(tensor_shape){}, projected_thread_lengths); detail::CalculateLocalPartitionShape(decltype(tensor_shape){}, projected_thread_lengths);
// Create Thread Cluster Descriptor
constexpr auto partition_shape_seq = constexpr auto partition_shape_seq =
generate_sequence_v2([&](auto I) { return size<I>(partition_shape); }, generate_sequence_v2([&](auto I) { return size<I>(partition_shape); },
Number<decltype(partition_shape)::Size()>{}); Number<decltype(partition_shape)::Size()>{});
constexpr auto thread_lengths_seq =
generate_sequence_v2([&](auto I) { return size<I>(ThreadLengthsTuple{}); },
Number<ThreadLengthsTuple::Size()>{});
constexpr auto thread_cluster_desc_ = make_cluster_descriptor(thread_lengths_seq);
// Calculate thread idxs and offsets // Calculate thread idxs and offsets
const auto thread_idxs = thread_cluster_desc_.CalculateBottomIndex(make_multi_index(thread_id)); const auto thread_idxs = detail::CalculateThreadMultiIdx(thread_layout, thread_id);
// Apply projection on thread idxs to remove not needed idxs // Apply projection on thread idxs to remove not needed idxs
const auto projected_thread_idxs = detail::ApplyProjection(thread_idxs, projection); const auto projected_thread_idxs = detail::ApplyProjection(thread_idxs, projection);
const auto offset_multi_idxs = detail::CalculateOffsetMultiIdxs( const auto offset_multi_idxs = detail::CalculateOffsetMultiIdxs(
projected_thread_idxs, partition_shape_seq, tensor.GetMultiIdxOffsets()); projected_thread_idxs, partition_shape_seq, tensor.GetMultiIdxOffsets());
// Create new layout and tensor // Create new layout and tensor
auto& unrolled_desc = layout(tensor).GetUnrolledDescriptor(); auto& unrolled_desc = layout(tensor).GetUnrolledDescriptor();
// Slice descriptor
const auto transforms = generate_tuple(
[&](auto i) {
return make_slice_transform(partition_shape.At(i),
offset_multi_idxs.At(i),
partition_shape.At(i) + offset_multi_idxs.At(i));
},
Number<remove_reference_t<decltype(tensor_shape)>::Size()>{});
const auto lower_upper_dims =
generate_tuple([&](auto i) { return Sequence<i.value>{}; },
Number<remove_reference_t<decltype(tensor_shape)>::Size()>{});
auto sliced_desc =
transform_tensor_descriptor(unrolled_desc, transforms, lower_upper_dims, lower_upper_dims);
// Create layout
const auto partition_layout = const auto partition_layout =
Layout<remove_reference_t<decltype(partition_shape)>, decltype(unrolled_desc)>( Layout<remove_reference_t<decltype(partition_shape)>, decltype(sliced_desc)>(
partition_shape, unrolled_desc); partition_shape, sliced_desc);
auto partition_tensor = auto partition_tensor =
make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), partition_layout); make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), partition_layout);
// Apply offsets // Apply offsets
partition_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs));
return partition_tensor; return partition_tensor;
} }
...@@ -233,12 +310,13 @@ make_local_partition(TensorType& tensor, ...@@ -233,12 +310,13 @@ make_local_partition(TensorType& tensor,
* \param thread_id Thread index represented as integer. * \param thread_id Thread index represented as integer.
* \return Partition tensor. * \return Partition tensor.
*/ */
template <typename TensorType, typename ThreadLengthsTuple> template <typename TensorType, typename ThreadShape, typename ThreadUnrolledDesc>
__host__ __device__ constexpr auto make_local_partition(TensorType& tensor, __host__ __device__ constexpr auto
const ThreadLengthsTuple& thread_lengths, make_local_partition(TensorType& tensor,
const index_t thread_id) const Layout<ThreadShape, ThreadUnrolledDesc>& thread_lengths,
const index_t thread_id)
{ {
const auto projection = detail::GenerateDefaultProjection(ThreadLengthsTuple{}); const auto projection = detail::GenerateDefaultProjection(ThreadShape{});
return make_local_partition(tensor, thread_lengths, thread_id, projection); return make_local_partition(tensor, thread_lengths, thread_id, projection);
} }
...@@ -252,21 +330,24 @@ __host__ __device__ constexpr auto make_local_partition(TensorType& tensor, ...@@ -252,21 +330,24 @@ __host__ __device__ constexpr auto make_local_partition(TensorType& tensor,
* *
* \param tensor Tensor for partition. * \param tensor Tensor for partition.
* \param tile_shape Shapes of requested tile. * \param tile_shape Shapes of requested tile.
* \param block_id Block index represented as integer. * \param block_idxs Tuple of block indexes represented as integer. If slice,
* \param projection Projection to remove selected dim from partitioning. * then get whole dim.
* slice(X) to remove, where X is dim size, Number<1>{} to keep. * \param projection Projection is used to remove selected dim from
* partitioning. Use `slice(X)` to remove dimension, where X is dim
* size. Use `Number<1>{}` to keep it.
* \return Tile tensor. * \return Tile tensor.
*/ */
template <typename TensorType, typename BlockShapeTuple, typename ProjectionTuple> template <typename TensorType,
typename BlockShapeTuple,
typename BlockIdxs,
typename ProjectionTuple>
__host__ __device__ constexpr auto make_local_tile(const TensorType& tensor, __host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,
const BlockShapeTuple& tile_shape, const BlockShapeTuple& tile_shape,
const index_t block_id, const BlockIdxs& block_idxs,
const ProjectionTuple& projection) const ProjectionTuple& projection)
{ {
static_assert(!IsNestedTuple(BlockShapeTuple{})); static_assert(!IsNestedTuple(BlockShapeTuple{}));
static_assert(!IsNestedTuple(BlockIdxs{}));
constexpr bool is_default_projection =
is_same_v<ProjectionTuple, decltype(detail::GenerateDefaultProjection(BlockShapeTuple{}))>;
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
...@@ -274,49 +355,77 @@ __host__ __device__ constexpr auto make_local_tile(const TensorType& tensor, ...@@ -274,49 +355,77 @@ __host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,
auto& aligned_desc = layout(tensor).GetMergedNestingDescriptor(); auto& aligned_desc = layout(tensor).GetMergedNestingDescriptor();
// TODO: Enable block_2_tile_map partitioning for non-default projection. constexpr auto projected_tile_shape =
if constexpr(BlockShapeTuple::Size() == I2 && is_default_projection) detail::ApplyProjection(BlockShapeTuple{}, ProjectionTuple{});
// Number of dims which are partitioned
constexpr auto dims_to_partition = detail::GetDimsToPartition(BlockIdxs{});
const auto parsed_block_idxs = detail::ReplaceSlicesWithZeros(block_idxs);
if constexpr(decltype(dims_to_partition)::Size() == I2)
{ {
// Optimized version for 2d tile shape [MxK] const auto shape_with_projection_dims =
detail::CalculateShapeWithProjection(shape(tensor), projection);
// Set Value for M, N partition
const auto M = shape_with_projection_dims.At(dims_to_partition.At(I0));
const auto N = shape_with_projection_dims.At(dims_to_partition.At(I1));
constexpr auto MPerBlock = BlockShapeTuple{}.At(dims_to_partition.At(I0));
constexpr auto NPerBlock = BlockShapeTuple{}.At(dims_to_partition.At(I1));
auto m_n_desc = make_naive_tensor_descriptor_packed(make_tuple(M, N));
// Get 1D block id
const auto grid_size = detail::CalculateGridSize(shape_with_projection_dims, tile_shape);
const auto block_lengths_desc = make_naive_tensor_descriptor_packed(grid_size);
const index_t block_id_1d = block_lengths_desc.CalculateOffset(parsed_block_idxs);
// Optimized version for 2d tile shape [MxN]
const auto block_2_tile_map = const auto block_2_tile_map =
BlockToCTileMap_M00_N0_M01Adapt<BlockShapeTuple{}.At(I0), BlockToCTileMap_M00_N0_M01Adapt<MPerBlock,
BlockShapeTuple{}.At(I1), NPerBlock,
remove_cvref_t<decltype(aligned_desc)>>(aligned_desc); remove_cvref_t<decltype(m_n_desc)>>(m_n_desc);
const auto block_work_idx = const auto block_work_idx =
block_2_tile_map.CalculateBottomIndex(make_multi_index(block_id)); block_2_tile_map.CalculateBottomIndex(make_multi_index(block_id_1d));
const index_t m_block_data_idx_on_grid = const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * size<0>(tile_shape)); __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
const index_t k_block_data_idx_on_grid = const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * size<1>(tile_shape)); __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
const auto offset_multi_idxs = // Apply 0 for non partitioned dims
make_tuple(m_block_data_idx_on_grid, k_block_data_idx_on_grid); const auto offset_multi_idxs = generate_tuple(
[&](auto i) {
if constexpr(i == dims_to_partition.At(I0))
{
return m_block_data_idx_on_grid;
}
else if constexpr(i == dims_to_partition.At(I1))
{
return n_block_data_idx_on_grid;
}
else
{
return Number<0>{};
}
},
Number<BlockShapeTuple::Size()>{});
const auto projected_offset_multi_idxs =
detail::ApplyProjection(offset_multi_idxs, projection);
// Create new layout and tensor // Create new layout and tensor
const auto tile_layout = const auto tile_layout =
Layout<remove_reference_t<decltype(tile_shape)>, decltype(aligned_desc)>(tile_shape, Layout<remove_reference_t<decltype(projected_tile_shape)>, decltype(aligned_desc)>(
aligned_desc); projected_tile_shape, aligned_desc);
auto tile_tensor = auto tile_tensor =
make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), tile_layout); make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), tile_layout);
// Apply offsets // Apply offsets
tile_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs)); tile_tensor.SetMultiIdxOffset(to_multi_index(projected_offset_multi_idxs));
return tile_tensor; return tile_tensor;
} }
else else
{ {
// Calculate offsets // Calculate offsets
// Sequence with data to process per block // Sequence with data to process per block
constexpr auto projected_tile_shape =
detail::ApplyProjection(BlockShapeTuple{}, ProjectionTuple{});
using ProjectedTileShapeTuple = decltype(projected_tile_shape); using ProjectedTileShapeTuple = decltype(projected_tile_shape);
constexpr auto projected_tile_shape_seq = constexpr auto projected_tile_shape_seq =
generate_sequence_v2([](auto I) { return ProjectedTileShapeTuple{}.At(I); }, generate_sequence_v2([](auto I) { return ProjectedTileShapeTuple{}.At(I); },
Number<ProjectedTileShapeTuple::Size()>{}); Number<ProjectedTileShapeTuple::Size()>{});
// Tuple with number of blocks // Tuple with number of blocks
const auto block_lengths = detail::CalculateGridSize(shape(tensor), tile_shape, projection); const auto projected_block_idxs =
const auto block_cluster_desc_ = make_cluster_descriptor(block_lengths); to_multi_index(detail::ApplyProjection(parsed_block_idxs, projection));
const auto block_idxs = const auto offset_multi_idxs = detail::CalculateOffsetMultiIdxs(
block_cluster_desc_.CalculateBottomIndex(make_multi_index(block_id));
const auto projected_block_idxs = detail::ApplyProjection(block_idxs, projection);
const auto offset_multi_idxs = detail::CalculateOffsetMultiIdxs(
projected_block_idxs, projected_tile_shape_seq, tensor.GetMultiIdxOffsets()); projected_block_idxs, projected_tile_shape_seq, tensor.GetMultiIdxOffsets());
// Create new layout and tensor // Create new layout and tensor
const auto tile_layout = const auto tile_layout =
...@@ -338,52 +447,17 @@ __host__ __device__ constexpr auto make_local_tile(const TensorType& tensor, ...@@ -338,52 +447,17 @@ __host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,
* *
* \param tensor Tensor for partition. * \param tensor Tensor for partition.
* \param tile_shape Shapes of requested tile. * \param tile_shape Shapes of requested tile.
* \param block_id Block index represented as integer. * \param block_idxs Tuple of block indexes represented as integer. If slice,
* then get whole dim.
* \return Tile tensor. * \return Tile tensor.
*/ */
template <typename TensorType, typename BlockShapeTuple> template <typename TensorType, typename BlockShapeTuple, typename BlockIdxs>
__host__ __device__ constexpr auto __host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,
make_local_tile(const TensorType& tensor, const BlockShapeTuple& tile_shape, const index_t block_id) const BlockShapeTuple& tile_shape,
const BlockIdxs& block_idxs)
{ {
const auto projection = detail::GenerateDefaultProjection(BlockShapeTuple{}); const auto projection = detail::GenerateDefaultProjection(BlockShapeTuple{});
return make_local_tile(tensor, tile_shape, block_id, projection); return make_local_tile(tensor, tile_shape, block_idxs, projection);
}
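In both branches of make_local_tile the block's starting element in a partitioned dimension is its block index times the tile length in that dimension (e.g. block_work_idx[I0] * MPerBlock above). A standalone sketch of that offset computation:

#include <array>
#include <cstdio>

// Offset of a block's tile inside the tensor: block index times tile length
// in every partitioned dimension.
template <std::size_t N>
std::array<long, N> tile_offsets(const std::array<long, N>& block_idxs,
                                 const std::array<long, N>& tile_lengths)
{
    std::array<long, N> offs{};
    for(std::size_t i = 0; i < N; ++i)
        offs[i] = block_idxs[i] * tile_lengths[i];
    return offs;
}

int main()
{
    // Block (3, 5) with a 128 x 64 tile starts at element (384, 320).
    auto offs = tile_offsets<2>({3, 5}, {128, 64});
    std::printf("(%ld, %ld)\n", offs[0], offs[1]);
}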
/**
* \brief Pad tensor shapes to be adjusted to tile lengths.
*
*
* \param tensor Tensor to pad.
* \param tile_lengths Tile lengths to align tensor shape.
* \return Padded tensor.
*/
template <typename TensorType, typename TileLengths>
__host__ __device__ constexpr auto pad(const TensorType& tensor, const TileLengths& tile_lengths)
{
const auto& tensor_shape = shape(tensor);
using TensorShapeType = remove_reference_t<decltype(tensor_shape)>;
auto& unrolled_desc = layout(tensor).GetUnrolledDescriptor();
// Generate sequence with ones to mark that all dims will be padded
constexpr auto do_pads_seq =
generate_sequence_v2([](auto) { return Number<1>{}; }, Number<TensorShapeType::Size()>{});
// Create descriptor with padding
auto padded_desc =
tensor_operation::device::PadTensorDescriptor(unrolled_desc, tile_lengths, do_pads_seq);
// Generate padded shape
const auto padded_shape = generate_tuple(
[&](auto i) {
const auto& dim = size<i>(tensor_shape);
const auto& tile_length = size<i>(tile_lengths);
return ck::math::integer_divide_ceil(dim, tile_length) * tile_length;
},
Number<TileLengths::Size()>{});
// Create layout and tensor
const auto padded_layout =
Layout<decltype(padded_shape), decltype(padded_desc)>(padded_shape, padded_desc);
auto partition_tensor =
make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), padded_layout);
partition_tensor.SetMultiIdxOffset(tensor.GetMultiIdxOffsets());
return partition_tensor;
} }
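// Usage sketch for pad() (illustrative only; shapes are assumptions): padding a
// 256 x 100 tensor to 64 x 64 tile lengths yields a 256 x 128 view, so every tile
// produced by make_local_tile() stays in bounds. Only the descriptor is extended;
// the returned tensor reuses the original pointer.
//
//   const auto tile_lengths = make_tuple(Number<64>{}, Number<64>{});
//   auto padded_tensor      = pad(global_tensor, tile_lengths);
//   // shape(padded_tensor) == (256, 128) when shape(global_tensor) == (256, 100)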
} // namespace wrapper } // namespace wrapper
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -25,25 +25,35 @@ template <ck::index_t NDimSpatial, ...@@ -25,25 +25,35 @@ template <ck::index_t NDimSpatial,
typename InElementwiseOperation, typename InElementwiseOperation,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation, typename OutElementwiseOperation,
ck::index_t NumAElementwiseTensor = 0,
ck::index_t NumBElementwiseTensor = 0,
ck::index_t NumDElementwiseTensor = 0,
typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false> typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
struct ReferenceConvBwdData : public device::BaseOperator struct ReferenceConvBwdData : public device::BaseOperator
{ {
// Argument // Argument
struct Argument : public device::BaseArgument struct Argument : public device::BaseArgument
{ {
Argument(Tensor<InDataType>& input, Argument(
const Tensor<WeiDataType>& weight, Tensor<InDataType>& input,
const Tensor<OutDataType>& output, const Tensor<WeiDataType>& weight,
std::vector<ck::index_t> conv_filter_strides, const Tensor<OutDataType>& output,
std::vector<ck::index_t> conv_filter_dilations, std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> input_left_pads, std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_right_pads, std::vector<ck::index_t> input_left_pads,
InElementwiseOperation in_element_op, std::vector<ck::index_t> input_right_pads,
WeiElementwiseOperation wei_element_op, InElementwiseOperation in_element_op,
OutElementwiseOperation out_element_op) WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op,
const std::array<Tensor<InDataType>, NumAElementwiseTensor>& elementwise_a_tensors,
const std::array<Tensor<WeiDataType>, NumBElementwiseTensor>& elementwise_b_tensors,
const std::array<Tensor<OutDataType>, NumDElementwiseTensor>& elementwise_d_tensors)
: input_{input}, : input_{input},
weight_{weight}, weight_{weight},
output_{output}, output_{output},
elementwise_a_tensors_{elementwise_a_tensors},
elementwise_b_tensors_{elementwise_b_tensors},
elementwise_d_tensors_{elementwise_d_tensors},
conv_strides_{conv_filter_strides}, conv_strides_{conv_filter_strides},
conv_dilations_{conv_filter_dilations}, conv_dilations_{conv_filter_dilations},
in_left_pads_{input_left_pads}, in_left_pads_{input_left_pads},
...@@ -58,6 +68,10 @@ struct ReferenceConvBwdData : public device::BaseOperator ...@@ -58,6 +68,10 @@ struct ReferenceConvBwdData : public device::BaseOperator
const Tensor<WeiDataType>& weight_; const Tensor<WeiDataType>& weight_;
const Tensor<OutDataType>& output_; const Tensor<OutDataType>& output_;
const std::array<Tensor<InDataType>, NumAElementwiseTensor>& elementwise_a_tensors_;
const std::array<Tensor<WeiDataType>, NumBElementwiseTensor>& elementwise_b_tensors_;
const std::array<Tensor<OutDataType>, NumDElementwiseTensor>& elementwise_d_tensors_;
std::vector<index_t> conv_strides_; std::vector<index_t> conv_strides_;
std::vector<index_t> conv_dilations_; std::vector<index_t> conv_dilations_;
std::vector<index_t> in_left_pads_; std::vector<index_t> in_left_pads_;
...@@ -106,26 +120,46 @@ struct ReferenceConvBwdData : public device::BaseOperator ...@@ -106,26 +120,46 @@ struct ReferenceConvBwdData : public device::BaseOperator
{ {
for(std::size_t k = 0; k < K; ++k) for(std::size_t k = 0; k < K; ++k)
{ {
float v_out = 0; OutDataType v_out;
float v_wei = 0; WeiDataType v_wei;
arg.out_element_op_( ExecuteElementwiseOp(arg.out_element_op_,
v_out, ck::type_convert<float>(arg.output_(g, n, k, wo))); arg.elementwise_a_tensors_,
Number<NumAElementwiseTensor>{},
arg.wei_element_op_( v_out,
v_wei, ck::type_convert<float>(arg.weight_(g, k, c, x))); arg.output_(g, n, k, wo),
g,
v_acc += v_out * v_wei; n,
k,
wo);
ExecuteElementwiseOp(arg.wei_element_op_,
arg.elementwise_b_tensors_,
Number<NumBElementwiseTensor>{},
v_wei,
arg.weight_(g, k, c, x),
g,
k,
c,
x);
v_acc += ck::type_convert<float>(v_out) *
ck::type_convert<float>(v_wei);
} }
} }
} }
} }
float v_in; InDataType v_acc_converted = ck::type_convert<InDataType>(v_acc);
InDataType& v_in = arg.input_(g, n, c, wi);
arg.in_element_op_(v_in, v_acc); ExecuteElementwiseOp(arg.in_element_op_,
arg.elementwise_d_tensors_,
arg.input_(g, n, c, wi) = ck::type_convert<InDataType>(v_in); Number<NumDElementwiseTensor>{},
v_in,
v_acc_converted,
g,
n,
c,
wi);
}; };
make_ParallelTensorFunctor(f_ncw, make_ParallelTensorFunctor(f_ncw,
...@@ -175,20 +209,34 @@ struct ReferenceConvBwdData : public device::BaseOperator ...@@ -175,20 +209,34 @@ struct ReferenceConvBwdData : public device::BaseOperator
{ {
for(std::size_t k = 0; k < K; ++k) for(std::size_t k = 0; k < K; ++k)
{ {
float v_out = 0; OutDataType v_out;
float v_wei = 0; WeiDataType v_wei;
arg.out_element_op_( ExecuteElementwiseOp(
arg.out_element_op_,
arg.elementwise_a_tensors_,
Number<NumAElementwiseTensor>{},
v_out, v_out,
ck::type_convert<float>( arg.output_(g, n, k, ho, wo),
arg.output_(g, n, k, ho, wo))); g,
n,
arg.wei_element_op_( k,
ho,
wo);
ExecuteElementwiseOp(
arg.wei_element_op_,
arg.elementwise_b_tensors_,
Number<NumBElementwiseTensor>{},
v_wei, v_wei,
ck::type_convert<float>( arg.weight_(g, k, c, y, x),
arg.weight_(g, k, c, y, x))); g,
k,
v_acc += v_out * v_wei; c,
y,
x);
v_acc += ck::type_convert<float>(v_out) *
ck::type_convert<float>(v_wei);
} }
} }
} }
...@@ -197,11 +245,18 @@ struct ReferenceConvBwdData : public device::BaseOperator ...@@ -197,11 +245,18 @@ struct ReferenceConvBwdData : public device::BaseOperator
} }
} }
float v_in; InDataType v_acc_converted = ck::type_convert<InDataType>(v_acc);
InDataType& v_in = arg.input_(g, n, c, hi, wi);
arg.in_element_op_(v_in, v_acc); ExecuteElementwiseOp(arg.in_element_op_,
arg.elementwise_d_tensors_,
arg.input_(g, n, c, hi, wi) = ck::type_convert<InDataType>(v_in); Number<NumDElementwiseTensor>{},
v_in,
v_acc_converted,
g,
n,
c,
hi,
wi);
}; };
make_ParallelTensorFunctor(f_nchw, make_ParallelTensorFunctor(f_nchw,
...@@ -270,20 +325,37 @@ struct ReferenceConvBwdData : public device::BaseOperator ...@@ -270,20 +325,37 @@ struct ReferenceConvBwdData : public device::BaseOperator
{ {
for(std::size_t k = 0; k < K; ++k) for(std::size_t k = 0; k < K; ++k)
{ {
float v_out = 0; OutDataType v_out;
float v_wei = 0; WeiDataType v_wei;
arg.out_element_op_( ExecuteElementwiseOp(
arg.out_element_op_,
arg.elementwise_a_tensors_,
Number<NumAElementwiseTensor>{},
v_out, v_out,
ck::type_convert<float>(arg.output_( arg.output_(g, n, k, do_, ho, wo),
g, n, k, do_, ho, wo))); g,
n,
arg.wei_element_op_( k,
do_,
ho,
wo);
ExecuteElementwiseOp(
arg.wei_element_op_,
arg.elementwise_b_tensors_,
Number<NumBElementwiseTensor>{},
v_wei, v_wei,
ck::type_convert<float>( arg.weight_(g, k, c, z, y, x),
arg.weight_(g, k, c, z, y, x))); g,
k,
v_acc += v_out * v_wei; c,
z,
y,
x);
v_acc +=
ck::type_convert<float>(v_out) *
ck::type_convert<float>(v_wei);
} }
} }
} }
...@@ -295,11 +367,19 @@ struct ReferenceConvBwdData : public device::BaseOperator ...@@ -295,11 +367,19 @@ struct ReferenceConvBwdData : public device::BaseOperator
} }
} }
float v_in; InDataType v_acc_converted = ck::type_convert<InDataType>(v_acc);
InDataType& v_in = arg.input_(g, n, c, di, hi, wi);
arg.in_element_op_(v_in, v_acc); ExecuteElementwiseOp(arg.in_element_op_,
arg.elementwise_d_tensors_,
arg.input_(g, n, c, di, hi, wi) = ck::type_convert<InDataType>(v_in); Number<NumDElementwiseTensor>{},
v_in,
v_acc_converted,
g,
n,
c,
di,
hi,
wi);
}; };
make_ParallelTensorFunctor(f_ncdhw, make_ParallelTensorFunctor(f_ncdhw,
...@@ -325,6 +405,36 @@ struct ReferenceConvBwdData : public device::BaseOperator ...@@ -325,6 +405,36 @@ struct ReferenceConvBwdData : public device::BaseOperator
} }
}; };
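// Index-relation sketch for the reference loops above (1D case, illustrative): an
// input position wi receives contributions from every (wo, x) pair satisfying
//     wi + in_left_pad == wo * conv_stride + x * conv_dilation,
// accumulated in float. A scalar way to enumerate those pairs (hypothetical names):
//
//   float v_acc = 0.f;
//   for(std::size_t x = 0; x < X; ++x)
//   {
//       const auto w_tmp = static_cast<ck::long_index_t>(wi) + in_left_pad -
//                          static_cast<ck::long_index_t>(x) * conv_dilation;
//       if(w_tmp % conv_stride == 0)
//       {
//           const auto wo = w_tmp / conv_stride;
//           if(wo >= 0 && wo < static_cast<ck::long_index_t>(Wo))
//           {
//               for(std::size_t k = 0; k < K; ++k)
//                   v_acc += ck::type_convert<float>(out(g, n, k, wo)) *
//                            ck::type_convert<float>(wei(g, k, c, x));
//           }
//       }
//   }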
template <typename... Args,
typename ElementwiseOp,
typename ElementwiseTensor,
typename NumTensor,
typename T>
static void ExecuteElementwiseOp(ElementwiseOp& elementwise_op,
ElementwiseTensor& elementwise_tensors,
NumTensor,
T& y,
const T& x,
Args... dims)
{
if constexpr(NumTensor::value == 0)
{
elementwise_op(y, x);
}
else if constexpr(NumTensor::value == 1)
{
elementwise_op(y, x, elementwise_tensors[0](dims...));
}
else if constexpr(NumTensor::value == 2)
{
elementwise_op(y, x, elementwise_tensors[0](dims...), elementwise_tensors[1](dims...));
}
else
{
throw std::runtime_error("ElementOp not supported in reference.");
}
}
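// Dispatch sketch for ExecuteElementwiseOp (illustrative only): with NumTensor == 0 the
// call collapses to the plain elementwise form, e.g. PassThrough{}(y, x); with
// NumTensor == 1 a residual-style op such as Bilinear can consume the extra D tensor at
// the same multi-index. The values below are assumptions; Bilinear is taken here to
// compute y = alpha * x + beta * d.
//
//   ck::tensor_operation::element_wise::Bilinear bilinear{1.f, 1.f};
//   ExecuteElementwiseOp(bilinear, elementwise_d_tensors, Number<1>{}, y, x, g, n, c, wi);
//   // expands to: bilinear(y, x, elementwise_d_tensors[0](g, n, c, wi));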
static constexpr bool IsValidCompilationParameter() static constexpr bool IsValidCompilationParameter()
{ {
// TODO: properly implement this check // TODO: properly implement this check
...@@ -333,16 +443,20 @@ struct ReferenceConvBwdData : public device::BaseOperator ...@@ -333,16 +443,20 @@ struct ReferenceConvBwdData : public device::BaseOperator
bool IsSupportedArgument(const device::BaseArgument*) override { return true; } bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
static auto MakeArgument(Tensor<InDataType>& input, static auto MakeArgument(
const Tensor<WeiDataType>& weight, Tensor<InDataType>& input,
const Tensor<OutDataType>& output, const Tensor<WeiDataType>& weight,
std::vector<ck::index_t> conv_filter_strides, const Tensor<OutDataType>& output,
std::vector<ck::index_t> conv_filter_dilations, std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> input_left_pads, std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_right_pads, std::vector<ck::index_t> input_left_pads,
InElementwiseOperation in_element_op, std::vector<ck::index_t> input_right_pads,
WeiElementwiseOperation wei_element_op, InElementwiseOperation in_element_op,
OutElementwiseOperation out_element_op) WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op,
const std::array<Tensor<InDataType>, NumAElementwiseTensor>& elementwise_a_tensors = {},
const std::array<Tensor<WeiDataType>, NumBElementwiseTensor>& elementwise_b_tensors = {},
const std::array<Tensor<OutDataType>, NumDElementwiseTensor>& elementwise_d_tensors = {})
{ {
return Argument{input, return Argument{input,
weight, weight,
...@@ -353,7 +467,10 @@ struct ReferenceConvBwdData : public device::BaseOperator ...@@ -353,7 +467,10 @@ struct ReferenceConvBwdData : public device::BaseOperator
input_right_pads, input_right_pads,
in_element_op, in_element_op,
wei_element_op, wei_element_op,
out_element_op}; out_element_op,
elementwise_a_tensors,
elementwise_b_tensors,
elementwise_d_tensors};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
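// Usage sketch (illustrative): running the reference operator with one extra D tensor
// consumed by a bilinear-style input elementwise op. Tensor objects and functor values
// below are assumptions for the example only.
//
//   auto ref = ReferenceConvBwdData</* NDimSpatial, data types, elementwise ops, 0, 0, 1 */>{};
//   const std::array<Tensor<InDataType>, 1> d_tensors = {d_tensor};
//   auto argument = ref.MakeArgument(input, weight, output,
//                                    conv_strides, conv_dilations,
//                                    input_left_pads, input_right_pads,
//                                    InElementwiseOperation{},  // e.g. Bilinear{alpha, beta}
//                                    WeiElementwiseOperation{},
//                                    OutElementwiseOperation{},
//                                    {},          // no extra A tensors
//                                    {},          // no extra B tensors
//                                    d_tensors);  // NumDElementwiseTensor == 1
//   ref.MakeInvoker().Run(argument);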
......
...@@ -189,6 +189,11 @@ void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_v2_instances( ...@@ -189,6 +189,11 @@ void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_v2_instances(
DeviceGemmSplitK<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>& DeviceGemmSplitK<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_kpb128_instances(
std::vector<std::unique_ptr<
DeviceGemmSplitK<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_km_kn_mn_instances( void add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_km_kn_mn_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceGemmSplitK<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough, F8>>>& DeviceGemmSplitK<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough, F8>>>&
...@@ -352,6 +357,7 @@ struct DeviceOperationInstanceFactory< ...@@ -352,6 +357,7 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_v1_instances(op_ptrs); add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_v1_instances(op_ptrs);
add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_v1_interwave_instances(op_ptrs); add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_v1_interwave_instances(op_ptrs);
add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_v2_instances(op_ptrs); add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_v2_instances(op_ptrs);
add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_kpb128_instances(op_ptrs);
} }
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> && else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>) is_same_v<CLayout, Row>)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
using BF8 = ck::bf8_t;
using F8 = ck::f8_t;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using namespace ck::tensor_layout::convolution;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Bilinear = ck::tensor_operation::element_wise::Bilinear;
static constexpr auto ConvBwdDataDefault = ConvolutionBackwardDataSpecialization::Default;
static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 =
ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0;
// f16_f16_f32_f16
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionBackwardDataSpecialization ConvSpec>
using device_grouped_conv_bwd_data_xdl_bilinear_f16_instances = std::tuple<
// clang-format off
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer|
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, ck::Tuple<ELayout>, ELayout, F16, F16, F32, F16, Tuple<F16>, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
// instances for small conv.K and conv.C
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, ck::Tuple<ELayout>, ELayout, F16, F16, F32, F16, Tuple<F16>, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, ck::Tuple<ELayout>, ELayout, F16, F16, F32, F16, Tuple<F16>, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, ck::Tuple<ELayout>, ELayout, F16, F16, F32, F16, Tuple<F16>, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>
// clang-format on
>;
// bf16_bf16_f32_bf16
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionBackwardDataSpecialization ConvSpec>
using device_grouped_conv_bwd_data_xdl_bilinear_bf16_instances = std::tuple<
// clang-format off
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer|
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, ck::Tuple<ELayout>, ELayout, BF16, BF16, F32, BF16, Tuple<BF16>, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
// instances for small conv.K and conv.C
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, ck::Tuple<ELayout>, ELayout, BF16, BF16, F32, BF16, Tuple<BF16>, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, ck::Tuple<ELayout>, ELayout, BF16, BF16, F32, BF16, Tuple<BF16>, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, ck::Tuple<ELayout>, ELayout, BF16, BF16, F32, BF16, Tuple<BF16>, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>
// clang-format on
>;
// f32_f32_f32_f32
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionBackwardDataSpecialization ConvSpec>
using device_grouped_conv_bwd_data_xdl_bilinear_f32_instances = std::tuple<
// clang-format off
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer|
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, ck::Tuple<ELayout>, ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1>,
// instances for small conv.K and conv.C
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, ck::Tuple<ELayout>, ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, ck::Tuple<ELayout>, ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 32, 1, 4>, 1>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, ck::Tuple<ELayout>, ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4>
// clang-format on
>;
// f16_f16_f32_f16, compute types bf8/f8
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionBackwardDataSpecialization ConvSpec>
using device_grouped_conv_bwd_data_xdl_bilinear_input_fp16_comp_bf8f8_instances = std::tuple<
// clang-format off
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer|
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, ck::Tuple<ELayout>, ELayout, F16, F16, F32, F32, Tuple<F16>, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, BF8, F8>,
// instances for small conv.K and conv.C
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, ck::Tuple<ELayout>, ELayout, F16, F16, F32, F32, Tuple<F16>, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4, LoopScheduler::Default, BF8, F8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, ck::Tuple<ELayout>, ELayout, F16, F16, F32, F32, Tuple<F16>, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 32, 1, 4>, 1, LoopScheduler::Default, BF8, F8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, ck::Tuple<ELayout>, ELayout, F16, F16, F32, F32, Tuple<F16>, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, LoopScheduler::Default, BF8, F8>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
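// Definition sketch (illustrative): a translation unit would typically register one of
// the instance tuples above via add_device_operation_instances(); the function name
// mirrors the declarations used elsewhere in this change, and the exact body here is an
// assumption.
//
//   void add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f16_instances(
//       std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3, NDHWGK, GKZYXC,
//           Tuple<NDHWGC>, NDHWGC, F16, F16, Tuple<F16>, F16, PassThrough, PassThrough,
//           Bilinear>>>& instances)
//   {
//       add_device_operation_instances(
//           instances,
//           device_grouped_conv_bwd_data_xdl_bilinear_f16_instances<3, NDHWGK, GKZYXC,
//               Tuple<NDHWGC>, NDHWGC, ConvBwdDataDefault>{});
//   }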
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using namespace ck::tensor_layout::convolution;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Bilinear = ck::tensor_operation::element_wise::Bilinear;
static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto ConvFwd1x1P0 = ConvolutionForwardSpecialization::Filter1x1Pad0;
static constexpr auto ConvFwd1x1S1P0 = ConvolutionForwardSpecialization::Filter1x1Stride1Pad0;
static constexpr auto ConvFwdOddC =
ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_bilinear_bf16_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, ck::Tuple<BF16>, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
// instances for small conv.K and conv.C
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, ck::Tuple<BF16>, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, ck::Tuple<BF16>, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, ck::Tuple<BF16>, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_bilinear_f16_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, ck::Tuple<F16>, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
// instances for small conv.K and conv.C
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, ck::Tuple<F16>, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, ck::Tuple<F16>, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, ck::Tuple<F16>, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_bilinear_f32_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, ck::Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1>,
// instances for small conv.K and conv.C
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, ck::Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 1>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, ck::Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, ck::Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_bilinear_int8_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, int8_t, int8_t, int32_t, int8_t, ck::Tuple<int8_t>, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
// instances for small conv.K and conv.C
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, int8_t, int8_t, int32_t, int8_t, ck::Tuple<int8_t>, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, int8_t, int8_t, int32_t, int8_t, ck::Tuple<int8_t>, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, int8_t, int8_t, int32_t, int8_t, ck::Tuple<int8_t>, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Tuple<NDHWGC>,
NDHWGC,
F16,
F16,
Tuple<F16>,
F16,
PassThrough,
PassThrough,
Bilinear>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Tuple<NDHWGC>,
NDHWGC,
F32,
F32,
Tuple<F32>,
F32,
PassThrough,
PassThrough,
Bilinear>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Tuple<NDHWGC>,
NDHWGC,
BF16,
BF16,
Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
Bilinear>>>& instances);
#endif
template <ck::index_t NumDimSpatial,
typename OutLayout,
typename WeiLayout,
typename InLayout,
typename OutDataType,
typename WeiDataType,
typename InDataType,
typename ComputeTypeA,
typename ComputeTypeB>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD<
NumDimSpatial,
OutLayout,
WeiLayout,
Tuple<InLayout>,
InLayout,
OutDataType,
WeiDataType,
Tuple<InDataType>,
InDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::Bilinear,
ComputeTypeA,
ComputeTypeB>>
{
using DeviceOp =
DeviceGroupedConvBwdDataMultipleD<NumDimSpatial,
OutLayout,
WeiLayout,
Tuple<InLayout>,
InLayout,
OutDataType,
WeiDataType,
Tuple<InDataType>,
InDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::Bilinear,
ComputeTypeA,
ComputeTypeB>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 3)
{
if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
is_same_v<OutLayout, NDHWGK>)
{
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
is_same_v<ComputeTypeB, F16>)
{
add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
is_same_v<ComputeTypeB, F32>)
{
add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f32_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
is_same_v<ComputeTypeB, BF16>)
{
add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
op_ptrs);
}
#endif
}
}
return op_ptrs;
}
};
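// Usage sketch (illustrative): client code queries the factory with the full interface
// type; the FP16 NDHWGC/GKZYXC/NDHWGK configuration below matches the branch above and
// is otherwise an assumption.
//
//   using DeviceOp = DeviceGroupedConvBwdDataMultipleD<3, NDHWGK, GKZYXC, Tuple<NDHWGC>,
//       NDHWGC, F16, F16, Tuple<F16>, F16, PassThrough, PassThrough, Bilinear, F16, F16>;
//   const auto op_ptrs = DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
//   // op_ptrs now holds every registered bilinear bwd-data instance for this config.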
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Bilinear = ck::tensor_operation::element_wise::Bilinear;
#ifdef CK_ENABLE_BF16
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
void add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
ck::Tuple<NDHWGK>,
NDHWGK,
BF16,
BF16,
ck::Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
Bilinear>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
ck::Tuple<NDHWGK>,
NDHWGK,
F16,
F16,
ck::Tuple<F16>,
F16,
PassThrough,
PassThrough,
Bilinear>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
ck::Tuple<NDHWGK>,
NDHWGK,
F32,
F32,
ck::Tuple<F32>,
F32,
PassThrough,
PassThrough,
Bilinear>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
ck::Tuple<NDHWGK>,
NDHWGK,
int8_t,
int8_t,
ck::Tuple<int8_t>,
int8_t,
PassThrough,
PassThrough,
Bilinear>>>& instances);
#endif
template <ck::index_t NumDimSpatial,
typename InLayout,
typename WeiLayout,
typename DLayouts,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename DDataTypes,
typename OutDataType,
typename ComputeType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
NumDimSpatial,
InLayout,
WeiLayout,
DLayouts,
OutLayout,
InDataType,
WeiDataType,
DDataTypes,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::Bilinear,
ComputeType>>
{
using DeviceOp =
DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
InLayout,
WeiLayout,
DLayouts,
OutLayout,
InDataType,
WeiDataType,
DDataTypes,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::Bilinear,
ComputeType>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> &&
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, NDHWGK> &&
DLayouts::Size() == 1 && is_same_v<tuple_element_t<0, DLayouts>, NDHWGK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<ComputeType, half_t>)
{
add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> && is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_INT8
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_int8_instances(
op_ptrs);
}
#endif
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -97,6 +97,35 @@ void add_device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_nk_mn_instances( ...@@ -97,6 +97,35 @@ void add_device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_nk_mn_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
// bf16_inputA i8_inputB
#if defined(CK_ENABLE_BF16) && defined(CK_ENABLE_INT8)
void add_device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Row,
Empty_Tuple,
Row,
BF16,
I8,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Col,
Empty_Tuple,
Row,
BF16,
I8,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
template <typename ALayout, template <typename ALayout,
typename BLayout, typename BLayout,
typename ELayout, typename ELayout,
...@@ -180,6 +209,24 @@ struct DeviceOperationInstanceFactory< ...@@ -180,6 +209,24 @@ struct DeviceOperationInstanceFactory<
} }
} }
// bf16_i8_input
#if defined(CK_ENABLE_BF16) && defined(CK_ENABLE_INT8)
if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, int8_t> &&
is_same_v<EDataType, bhalf_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_instances(op_ptrs);
}
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_instances(op_ptrs);
}
}
#endif
return op_ptrs; return op_ptrs;
} }
}; };
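// Usage sketch (illustrative): the branch above makes bf16 x int8 grouped GEMM (fixed
// NK) discoverable through the same factory; the row-major A / column-major B layout
// choice below is an assumption.
//
//   using DeviceOp = DeviceGroupedGemmFixedNK<Row, Col, Empty_Tuple, Row,
//       BF16, I8, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough>;
//   const auto op_ptrs = DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
//   // includes the *_bf16_i8_bf16_mk_nk_mn instances added above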
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -17,7 +17,32 @@ namespace tensor_operation { ...@@ -17,7 +17,32 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
void add_device_permute_scale_f16_instances( #ifdef CK_ENABLE_FP16
void add_device_permute_scale_1d_f16_instances(
std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F16>,
ck::Tuple<F16>,
PassThrough,
element_wise::UnarySquare,
Scale,
1>>>&);
void add_device_permute_scale_2d_f16_instances(
std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F16>,
ck::Tuple<F16>,
PassThrough,
element_wise::UnarySquare,
Scale,
2>>>&);
void add_device_permute_scale_3d_f16_instances(
std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F16>,
ck::Tuple<F16>,
PassThrough,
element_wise::UnarySquare,
Scale,
3>>>&);
void add_device_permute_scale_4d_f16_instances(
std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F16>, std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F16>,
ck::Tuple<F16>, ck::Tuple<F16>,
PassThrough, PassThrough,
...@@ -25,7 +50,50 @@ void add_device_permute_scale_f16_instances( ...@@ -25,7 +50,50 @@ void add_device_permute_scale_f16_instances(
Scale, Scale,
4>>>&); 4>>>&);
void add_device_permute_scale_f32_instances( void add_device_permute_scale_5d_f16_instances(
std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F16>,
ck::Tuple<F16>,
PassThrough,
element_wise::UnarySquare,
Scale,
5>>>&);
void add_device_permute_scale_6d_f16_instances(
std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F16>,
ck::Tuple<F16>,
PassThrough,
element_wise::UnarySquare,
Scale,
6>>>&);
#endif
#ifdef CK_ENABLE_FP32
void add_device_permute_scale_1d_f32_instances(
std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F32>,
ck::Tuple<F32>,
PassThrough,
element_wise::UnarySquare,
Scale,
1>>>&);
void add_device_permute_scale_2d_f32_instances(
std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F32>,
ck::Tuple<F32>,
PassThrough,
element_wise::UnarySquare,
Scale,
2>>>&);
void add_device_permute_scale_3d_f32_instances(
std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F32>,
ck::Tuple<F32>,
PassThrough,
element_wise::UnarySquare,
Scale,
3>>>&);
void add_device_permute_scale_4d_f32_instances(
std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F32>, std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F32>,
ck::Tuple<F32>, ck::Tuple<F32>,
PassThrough, PassThrough,
...@@ -33,6 +101,23 @@ void add_device_permute_scale_f32_instances( ...@@ -33,6 +101,23 @@ void add_device_permute_scale_f32_instances(
Scale, Scale,
4>>>&); 4>>>&);
void add_device_permute_scale_5d_f32_instances(
std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F32>,
ck::Tuple<F32>,
PassThrough,
element_wise::UnarySquare,
Scale,
5>>>&);
void add_device_permute_scale_6d_f32_instances(
std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F32>,
ck::Tuple<F32>,
PassThrough,
element_wise::UnarySquare,
Scale,
6>>>&);
#endif
template <typename InDataTypeTuple, template <typename InDataTypeTuple,
typename OutDataTypeTuple, typename OutDataTypeTuple,
typename ElementwiseOperation, typename ElementwiseOperation,
...@@ -57,15 +142,107 @@ struct DeviceOperationInstanceFactory< ...@@ -57,15 +142,107 @@ struct DeviceOperationInstanceFactory<
static auto GetInstances() static auto GetInstances()
{ {
std::vector<std::unique_ptr<DeviceOp>> op_ptrs; std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<InDataTypeTuple, ck::Tuple<F32>> && if constexpr(NumDim == 1)
is_same_v<OutDataTypeTuple, ck::Tuple<F32>>) {
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataTypeTuple, ck::Tuple<F32>> &&
is_same_v<OutDataTypeTuple, ck::Tuple<F32>>)
{
add_device_permute_scale_1d_f32_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataTypeTuple, ck::Tuple<F16>> &&
is_same_v<OutDataTypeTuple, ck::Tuple<F16>>)
{
add_device_permute_scale_1d_f16_instances(op_ptrs);
}
#endif
}
else if constexpr(NumDim == 2)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataTypeTuple, ck::Tuple<F32>> &&
is_same_v<OutDataTypeTuple, ck::Tuple<F32>>)
{
add_device_permute_scale_2d_f32_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataTypeTuple, ck::Tuple<F16>> &&
is_same_v<OutDataTypeTuple, ck::Tuple<F16>>)
{
add_device_permute_scale_2d_f16_instances(op_ptrs);
}
#endif
}
else if constexpr(NumDim == 3)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataTypeTuple, ck::Tuple<F32>> &&
is_same_v<OutDataTypeTuple, ck::Tuple<F32>>)
{
add_device_permute_scale_3d_f32_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataTypeTuple, ck::Tuple<F16>> &&
is_same_v<OutDataTypeTuple, ck::Tuple<F16>>)
{
add_device_permute_scale_3d_f16_instances(op_ptrs);
}
#endif
}
else if constexpr(NumDim == 4)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataTypeTuple, ck::Tuple<F32>> &&
is_same_v<OutDataTypeTuple, ck::Tuple<F32>>)
{
add_device_permute_scale_4d_f32_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataTypeTuple, ck::Tuple<F16>> &&
is_same_v<OutDataTypeTuple, ck::Tuple<F16>>)
{
add_device_permute_scale_4d_f16_instances(op_ptrs);
}
#endif
}
else if constexpr(NumDim == 5)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataTypeTuple, ck::Tuple<F32>> &&
is_same_v<OutDataTypeTuple, ck::Tuple<F32>>)
{
add_device_permute_scale_5d_f32_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataTypeTuple, ck::Tuple<F16>> &&
is_same_v<OutDataTypeTuple, ck::Tuple<F16>>)
{
add_device_permute_scale_5d_f16_instances(op_ptrs);
}
#endif
}
else if constexpr(NumDim == 6)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataTypeTuple, ck::Tuple<F32>> &&
is_same_v<OutDataTypeTuple, ck::Tuple<F32>>)
{
add_device_permute_scale_6d_f32_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataTypeTuple, ck::Tuple<F16>> &&
is_same_v<OutDataTypeTuple, ck::Tuple<F16>>)
{
add_device_permute_scale_6d_f16_instances(op_ptrs);
}
#endif
}
return op_ptrs;
}
...
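Editor's note: the factory above is the only entry point a caller needs. Below is a minimal usage sketch (not part of this commit) that enumerates the 3-D FP16 permute+scale instances and prints their type strings. The include path is an assumption (a header along the lines of ck/library/tensor_operation_instance/gpu/permute_scale.hpp), and the sketch relies only on GetInstances() shown above and on GetTypeString() from CK's BaseOperator interface.

// Editor's usage sketch, assuming the factory header above is installed as shown.
#include <iostream>

#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/permute_scale.hpp" // assumed include path

int main()
{
    using namespace ck::tensor_operation;

    // Same DeviceElementwise signature as the declarations above, with NumDim == 3.
    using DeviceOp = device::DeviceElementwise<ck::Tuple<ck::half_t>,
                                               ck::Tuple<ck::half_t>,
                                               element_wise::PassThrough,
                                               element_wise::UnarySquare,
                                               element_wise::Scale,
                                               3>;

    // The factory returns every registered instance for this problem description.
    const auto op_ptrs =
        device::instance::DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

    std::cout << "found " << op_ptrs.size() << " permute+scale instances\n";
    for(const auto& op_ptr : op_ptrs)
    {
        std::cout << op_ptr->GetTypeString() << '\n';
    }
    return 0;
}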
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp"
#include "ck/utility/data_type.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using Pass = ck::tensor_operation::element_wise::PassThrough;
using UnaryOp = ck::tensor_operation::element_wise::UnarySquare;
using Scale = ck::tensor_operation::element_wise::Scale;
// clang-format off
template <index_t NDims>
using device_permute_scale_f16_instances =
std::tuple <
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, Pass, UnaryOp, Scale, NDims, 1, ck::Sequence<1>, ck::Sequence<1>>,
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, Pass, UnaryOp, Scale, NDims, 8, ck::Sequence<8>, ck::Sequence<1>>,
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, Pass, UnaryOp, Scale, NDims, 4, ck::Sequence<4>, ck::Sequence<1>>,
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, Pass, UnaryOp, Scale, NDims, 2, ck::Sequence<2>, ck::Sequence<1>>
>;
template <index_t NDims>
using device_permute_scale_f32_instances = std::tuple<
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, Pass, UnaryOp, Scale, NDims, 1, ck::Sequence<1>, ck::Sequence<1>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, Pass, UnaryOp, Scale, NDims, 8, ck::Sequence<8>, ck::Sequence<1>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, Pass, UnaryOp, Scale, NDims, 4, ck::Sequence<4>, ck::Sequence<1>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, Pass, UnaryOp, Scale, NDims, 2, ck::Sequence<2>, ck::Sequence<1>>
>;
// clang-format on
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
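Editor's note: the commit replaces the single 4-D instance list with tuples templated on NDims, but the per-dimension translation units that define the add_device_permute_scale_*d_* functions declared earlier are not part of this excerpt. A plausible definition for one of them, following CK's existing add_device_operation_instances pattern, is sketched below; the file name and the exact include of the header above are assumptions.

// Editor's sketch of an assumed per-dimension source file
// (e.g. device_permute_scale_3d_f16_instance.cpp -- name not taken from this diff).
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
// Assumes the header above (defining device_permute_scale_f16_instances<NDims>,
// F16, Pass, UnaryOp and Scale) is also included here.

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_permute_scale_3d_f16_instances(
    std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F16>,
                                                  ck::Tuple<F16>,
                                                  Pass,
                                                  UnaryOp,
                                                  Scale,
                                                  3>>>& instances)
{
    // Instantiate the templated tuple for NDims == 3 and append one device
    // operation per tuple element to the caller-provided list.
    add_device_operation_instances(instances, device_permute_scale_f16_instances<3>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck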
set(GEMM_SPLITK_INSTANCES)
list(APPEND GEMM_SPLITK_INSTANCES
device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp
device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp
device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp
device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_irregular_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_interwave_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_interwave_irregular_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v2_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v2_irregular_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_irregular_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_interwave_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_interwave_irregular_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v2_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v2_irregular_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp
device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_v1_instance.cpp
device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_v1_interwave_instance.cpp
device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_v2_instance.cpp
device_gemm_xdl_splitk_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp
device_gemm_xdl_splitk_fp8_f16_f16_mk_nk_mn_instance.cpp
device_gemm_xdl_splitk_fp8_f16_f16_km_kn_mn_instance.cpp
device_gemm_xdl_splitk_fp8_f16_f16_km_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_v1_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_v1_interwave_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_v2_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_irregular_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_kpb128_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_v1_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_v1_interwave_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_v2_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_km_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_km_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_mk_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_mk_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_km_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_km_nk_mn_instance.cpp
)
add_instance_library(device_gemm_splitk_instance ${GEMM_SPLITK_INSTANCES})
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F8 = ck::f8_t;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
static constexpr auto GemmKPadding = ck::tensor_operation::device::GemmSpecialization::KPadding;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
template <ck::tensor_operation::device::GemmSpecialization GemmSpec,
ck::PipelineVersion PipVer,
ck::LoopScheduler LoopSche>
using device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances = std::tuple<
// clang-format off
//#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 8, 16, 16, 16, 1, 1, S<1, 8, 8, 2>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche, F16, F8>,
DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 8, 16, 16, 16, 1, 2, S<1, 8, 8, 2>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche, F16, F8>,
DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 8, 16, 16, 16, 1, 4, S<1, 8, 8, 2>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche, F16, F8>
// clang-format on
>;
void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_kpb128_instances(
std::vector<std::unique_ptr<
DeviceGemmSplitK<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
instances)
{
// default
add_device_operation_instances(
instances,
device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances<
GemmDefault,
ck::PipelineVersion::v2,
ck::LoopScheduler::Default>{});
add_device_operation_instances(
instances,
device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances<
GemmDefault,
ck::PipelineVersion::v1,
ck::LoopScheduler::Interwave>{});
add_device_operation_instances(
instances,
device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances<
GemmDefault,
ck::PipelineVersion::v1,
ck::LoopScheduler::Default>{});
// MNKPadding
add_device_operation_instances(
instances,
device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances<
GemmMNKPadding,
ck::PipelineVersion::v2,
ck::LoopScheduler::Default>{});
add_device_operation_instances(
instances,
device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances<
GemmMNKPadding,
ck::PipelineVersion::v1,
ck::LoopScheduler::Interwave>{});
add_device_operation_instances(
instances,
device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances<
GemmMNKPadding,
ck::PipelineVersion::v1,
ck::LoopScheduler::Default>{});
// KPadding
add_device_operation_instances(
instances,
device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances<
GemmKPadding,
ck::PipelineVersion::v2,
ck::LoopScheduler::Default>{});
add_device_operation_instances(
instances,
device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances<
GemmKPadding,
ck::PipelineVersion::v1,
ck::LoopScheduler::Interwave>{});
add_device_operation_instances(
instances,
device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances<
GemmKPadding,
ck::PipelineVersion::v1,
ck::LoopScheduler::Default>{});
// MNPadding
add_device_operation_instances(
instances,
device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances<
GemmMNPadding,
ck::PipelineVersion::v2,
ck::LoopScheduler::Default>{});
add_device_operation_instances(
instances,
device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances<
GemmMNPadding,
ck::PipelineVersion::v1,
ck::LoopScheduler::Interwave>{});
add_device_operation_instances(
instances,
device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances<
GemmMNPadding,
ck::PipelineVersion::v1,
ck::LoopScheduler::Default>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
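Editor's note: this excerpt contains only the new kpb128 instance file and its CMake registration; the matching hook in the split-K GEMM instance factory is not shown. Based on the pattern used for the permute-scale factory earlier in this diff, it is presumably a declaration plus a guarded call along the following lines. This is a sketch of assumed factory-side code, reusing only the function signature defined above; the surrounding condition and type aliases are illustrative.

// Editor's sketch of the assumed factory-side registration (not shown in this diff).
void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_kpb128_instances(
    std::vector<std::unique_ptr<
        DeviceGemmSplitK<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&);

// ...and, inside the corresponding GetInstances(), next to the other f16/f8 split-K cases:
// if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> &&
//              is_same_v<CDataType, half_t> &&
//              is_same_v<ALayout, Row> && is_same_v<BLayout, Col> && is_same_v<CLayout, Row>)
// {
//     add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_kpb128_instances(op_ptrs);
// }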