Unverified Commit f3b6c23a authored by Bartłomiej Kocot, committed by GitHub

Add blockwise gemm to ck wrapper (#1139)

* Add blockwise gemm to ck wrapper

* Add blockwise gemm traits

* Disable test_gemm for non xdl devices

* Fixes

* Add c layout descriptions
parent 6651a124
...@@ -11,7 +11,7 @@ None
None
### Additions
* Introduced wrapper sublibrary (limited functionality). (#1071, #1098, #1108, #1126, #1139)
### Changes
None
...
...@@ -89,3 +89,4 @@ Operations
-------------------------------------
.. doxygenfile:: copy.hpp
.. doxygenfile:: gemm.hpp
...@@ -248,6 +248,9 @@ struct Layout
using DefaultIdxsTupleType = remove_cvref_t<decltype(GenerateDefaultIdxsTuple(Shape{}))>;
public:
using LayoutShape = Shape;
using LayoutUnrolledDescriptorType = UnrolledDescriptorType;
/**
* \brief Transform descriptor to align to passed indexes.
*
...
...@@ -3,45 +3,18 @@
#pragma once
#include "ck/wrapper/utils/tensor_utils.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_description/tensor_space_filling_curve.hpp"
namespace ck {
namespace wrapper {
/**
* \brief Perform generic copy between two tensors partitions (threadwise copy).
* Tensors must have the same size.
*
* \param src_tensor Source tensor.
* \param dst_tensor Destination tensor.
*/
template <typename SrcTensorType, typename DstTensorType>
__host__ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor)
{
if constexpr(!SrcTensorType::IsDynamicBuffer)
{
using SizeType = decltype(size(src_tensor));
static_for<0, SizeType{}, 1>{}([&](auto i) { dst_tensor(i) = src_tensor(i); });
}
else if constexpr(!DstTensorType::IsDynamicBuffer)
{
using SizeType = decltype(size(dst_tensor));
static_for<0, SizeType{}, 1>{}([&](auto i) { dst_tensor(i) = src_tensor(i); });
}
else
{
for(int i = 0; i < size(src_tensor); i++)
{
dst_tensor(i) = src_tensor(i);
}
}
}
/**
* \brief Perform optimized copy between two tensor partitions (threadwise copy).
* Tensors must have the same size.
...@@ -167,9 +140,99 @@ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor)
else
{
// Perform copy between StaticBuffers
static_for<0, SrcShapeType::Size(), 1>{}([&](auto i) { dst_tensor(i) = src_tensor(i); });
}
}
/**
* \brief Perform generic copy between two tensor partitions (threadwise copy).
* Tensors must have the same size.
*
* \param src_tensor Source tensor.
* \param dst_tensor Destination tensor.
*/
template <typename SrcTensorType, typename DstTensorType>
__host__ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor)
{
// Generate default params
using SrcShapeType = remove_cvref_t<decltype(shape(src_tensor))>;
constexpr index_t num_dims = SrcShapeType::Size();
// Incrementing dims 0, 1, 2 ... num_dims - 1
constexpr auto dim_access_order_tuple =
generate_tuple([](auto i) { return Number<i>{}; }, Number<num_dims>{});
constexpr index_t vector_dim = num_dims - 1;
constexpr index_t scalar_per_vector = 1;
copy<decltype(dim_access_order_tuple), vector_dim, scalar_per_vector>(src_tensor, dst_tensor);
}
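A minimal usage sketch of this generic overload (the tensor names come from the DeviceGemm test kernel added later in this change and are only illustrative here):
// Copy the per-thread Vgpr accumulator into the matching per-thread partition
// of the global C tensor; both tensors have the same size. A default dimension
// access order and ScalarPerVector = 1 are generated, and the call is forwarded
// to the optimized threadwise copy above.
ck::wrapper::copy(c_vgpr_reg, c_global_local_partition);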
/**
* \brief Perform optimized blockwise copy between two tensors. Tensors must have the
* same size.
*
* \note Currently Vgpr and Sgpr tensors are not supported.
*
* \tparam DimAccessOrderTuple Tuple with dimension access order.
* \tparam VectorDim Dimension for vectorized read and write.
* \tparam ScalarPerVector Number of scalars per vectorized read and write.
* \param src_tensor Source tensor.
* \param dst_tensor Destination tensor.
* \param thread_layout Thread layout for each dimension of the copy.
*/
template <typename DimAccessOrderTuple,
index_t VectorDim,
index_t ScalarPerVector,
typename SrcTensorType,
typename DstTensorType,
typename ThreadLayoutTuple>
__device__ void blockwise_copy(const SrcTensorType& src_tensor,
DstTensorType& dst_tensor,
[[maybe_unused]] ThreadLayoutTuple& thread_layout)
{
static_assert(SrcTensorType::IsDynamicBuffer && DstTensorType::IsDynamicBuffer);
static_assert(is_detected<is_tuple, DimAccessOrderTuple>::value);
const auto& in_grid_desc = layout(src_tensor).GetUnrolledDescriptor();
const auto& out_grid_desc = layout(dst_tensor).GetUnrolledDescriptor();
using SrcShapeType = remove_cvref_t<decltype(shape(src_tensor))>;
constexpr index_t num_dims = SrcShapeType::Size();
constexpr auto tile_lengths_seq =
generate_sequence_v2([](auto I) { return size(SrcShapeType{}.At(I)); }, Number<num_dims>{});
constexpr auto thread_layout_seq = generate_sequence_v2(
[](auto I) { return size(ThreadLayoutTuple{}.At(I)); }, Number<num_dims>{});
constexpr auto dim_access_order = generate_sequence_v2(
[](auto I) { return DimAccessOrderTuple{}.At(I); }, Number<num_dims>{});
using ThisThreadBlock = ThisThreadBlock<size(ThreadLayoutTuple{})>;
// Perform copy between DynamicBuffers
auto transfer = ThreadGroupTensorSliceTransfer_v7<
ThisThreadBlock,
Tuple<typename SrcTensorType::TensorElementType>,
Tuple<typename DstTensorType::TensorElementType>,
decltype(tie(in_grid_desc)),
decltype(tie(out_grid_desc)),
tensor_operation::element_wise::PassThrough,
Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
std::remove_const_t<decltype(tile_lengths_seq)>,
std::remove_const_t<decltype(thread_layout_seq)>,
std::remove_const_t<decltype(dim_access_order)>,
std::remove_const_t<decltype(dim_access_order)>,
VectorDim,
ScalarPerVector,
Sequence<true>,
Sequence<true>>{in_grid_desc,
make_tuple(src_tensor.GetMultiIdxOffsets()),
out_grid_desc,
make_tuple(dst_tensor.GetMultiIdxOffsets()),
tensor_operation::element_wise::PassThrough{}};
transfer.Run(tie(in_grid_desc),
tie(src_tensor.GetBuffer()),
tie(out_grid_desc),
tie(dst_tensor.GetBuffer()));
}
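A usage sketch of blockwise_copy, taken from the DeviceGemm test kernel added later in this change (the global tile, LDS tensor and thread_layout are created there):
// Cooperative copy of a global-memory tile into LDS, vectorized along dimension 1.
using DimAccessOrder = ck::Tuple<ck::Number<0>, ck::Number<1>>;
constexpr ck::index_t vector_dim = 1;
ck::wrapper::blockwise_copy<DimAccessOrder, vector_dim, scalar_per_vector>(
    a_global_local_tile, a_lds_tensor, thread_layout);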
} // namespace wrapper
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/wrapper/utils/tensor_utils.hpp"
#include "ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
namespace ck {
namespace wrapper {
namespace {
namespace detail {
/**
* \brief Create block descriptor (K0, MPerBlock or NPerBlock, K1).
*
*
* \tparam K1 The number of K-dim elements that are packed together as a separate logical dimension.
* \tparam TileLayout Tensor data tile layout (M,K) or (N,K).
*
* \return Block descriptor (K0, MPerBlock or NPerBlock, K1)
*/
template <index_t K1, typename TileLayout>
__device__ constexpr auto GetBlockDescriptor()
{
using TileLayoutShape = typename TileLayout::LayoutShape;
using TileLayoutDescriptor = typename TileLayout::LayoutUnrolledDescriptorType;
constexpr auto K0PerBlock = Number<size<1>(TileLayoutShape{})>{} / Number<K1>{};
// MPerBlock or NPerBlock
constexpr auto Dim0 = Number<size<0>(TileLayoutShape{})>{};
constexpr auto a_block_desc_k0_m_k1 = transform_tensor_descriptor(
TileLayoutDescriptor{},
make_tuple(make_unmerge_transform(make_tuple(K0PerBlock, Number<K1>{})),
make_pass_through_transform(Dim0)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_block_desc_k0_m_k1;
}
} // namespace detail
} // namespace
/**
* \brief Perform blockwise gemm xdl on tensors stored in LDS. The result is
* stored in a Vgpr register. The A data layout must be (MPerBlock, KPerBlock)
* and the B data layout must be (NPerBlock, KPerBlock).
*
* \note C output Vgpr register layout (8D):
* - MXdlPerWave - The number of MFMA instructions run by single wave in M
* dimension per tile.
* - NXdlPerWave - The number of MFMA instructions run by single wave in N
* dimension per tile.
* - MWave - Equals to 1 since this is for single wave.
* - NWave - Equals to 1 since this is for single wave.
* - NumGroupsPerBlock - Mfma instruction internal layout (depends on the
* instruction size).
* - NumInputsBlock - Mfma instruction internal layout (depends on the
* instruction size).
* - GroupSize - Mfma instruction internal layout (depends on the
* instruction size).
* - NumThreadsPerBlock - Mfma instruction internal layout (depends on the
* instruction size).
*
* \tparam DataType Input data types.
* \tparam BlockSize Number of threads in block.
* \tparam GemmTraits Traits of gemm xdl operation.
* \param a_local_tile_tensor A tensor in LDS memory for blockwise gemm
* with (MPerBlock, KPerBlock) layout.
* \param b_local_tile_tensor B tensor in LDS memory for blockwise gemm
* with (NPerBlock, KPerBlock) layout.
* \param c_reg_tensor C tensor in Vgpr memory for blockwise gemm.
*/
template <typename DataType,
index_t BlockSize,
typename GemmTraits,
typename ATensorType,
typename BTensorType,
typename CTensorType>
__device__ void blockwise_gemm_xdl(const ATensorType& a_local_tile_tensor,
const BTensorType& b_local_tile_tensor,
CTensorType& c_reg_tensor)
{
static_assert(ATensorType::TensorBufferAddressSpace == MemoryTypeEnum::Lds);
static_assert(BTensorType::TensorBufferAddressSpace == MemoryTypeEnum::Lds);
static_assert(CTensorType::TensorBufferAddressSpace == MemoryTypeEnum::Vgpr);
static_assert(is_same_v<DataType, typename ATensorType::TensorElementType>);
static_assert(is_same_v<DataType, typename BTensorType::TensorElementType>);
constexpr bool is_integer =
is_same_v<DataType, int8_t> || is_same_v<DataType, int16_t> || is_same_v<DataType, int32_t>;
using GemmAccDataType = std::conditional_t<is_integer, int32_t, float>;
using ATileLayout = remove_cvref_t<decltype(layout(a_local_tile_tensor))>;
using BTileLayout = remove_cvref_t<decltype(layout(b_local_tile_tensor))>;
using ABlockDesc_K0_M_K1_Type =
decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>());
using BBlockDesc_K0_N_K1_Type =
decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>());
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
DataType,
DataType,
GemmAccDataType,
ABlockDesc_K0_M_K1_Type,
BBlockDesc_K0_N_K1_Type,
GemmTraits::MPerXDL,
GemmTraits::NPerXDL,
GemmTraits::MXdlPerWave,
GemmTraits::NXdlPerWave,
GemmTraits::K1>
blockwise_gemm_xdl_op{};
blockwise_gemm_xdl_op.Run(
a_local_tile_tensor.GetBuffer(), b_local_tile_tensor.GetBuffer(), c_reg_tensor.GetBuffer());
}
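Usage sketch from the test kernel added in this change, with the A and B tiles already staged in LDS and the accumulator kept in Vgpr:
// One K-step of the block GEMM; results are accumulated into c_vgpr_reg.
ck::wrapper::blockwise_gemm_xdl<DataType, ck::wrapper::size(thread_layout), GemmTraits>(
    a_lds_tensor, b_lds_tensor, c_vgpr_reg);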
/**
* \brief Create local partition per thread for C tensor.
*
* \note C output global memory layout (8D):
* - MXdlPerWave - The number of MFMA instructions run by single wave in M
* dimension.
* - NXdlPerWave - The number of MFMA instructions run by single wave in N
* dimension.
* - MWave - The number of waves in single tile M dimension per tile.
* - NWave - The number of waves in single tile N dimension per tile.
* - NumGroupsPerBlock - Mfma instruction internal layout (depends on the
* instruction size).
* - NumInputsBlock - Mfma instruction internal layout (depends on the
* instruction size).
* - GroupSize - Mfma instruction internal layout (depends on the
* instruction size).
* - NumThreadsPerBlock - Mfma instruction internal layout (depends on the
* instruction size).
*
* \tparam DataType Input data types.
* \tparam ATileLayout A tensor layout.
* \tparam BTileLayout B tensor layout.
* \tparam BlockSize Number of threads in block.
* \tparam GemmTraits Traits of gemm xdl operation.
* \param c_local_tile_tensor C local tile tensor (in global memory) for
* blockwise gemm with (MPerBlock, NPerBlock) layout.
*
* \return Partitioned C tensor for blockwise gemm.
*/
template <typename DataType,
typename ATileLayout,
typename BTileLayout,
index_t BlockSize,
typename GemmTraits,
typename CTensorType>
__host__ __device__ constexpr auto
make_blockwise_gemm_xdl_c_local_partition(CTensorType& c_local_tile_tensor)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{};
constexpr bool is_integer =
is_same_v<DataType, int8_t> || is_same_v<DataType, int16_t> || is_same_v<DataType, int32_t>;
using GemmAccDataType = std::conditional_t<is_integer, int32_t, float>;
using ABlockDesc_K0_M_K1_Type =
decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>());
using BBlockDesc_K0_N_K1_Type =
decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>());
using BlockwiseGemmXdlops =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
DataType,
DataType,
GemmAccDataType,
ABlockDesc_K0_M_K1_Type,
BBlockDesc_K0_N_K1_Type,
GemmTraits::MPerXDL,
GemmTraits::NPerXDL,
GemmTraits::MXdlPerWave,
GemmTraits::NXdlPerWave,
GemmTraits::K1>;
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
BlockwiseGemmXdlops::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2);
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4);
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5);
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);
// Calculate offset on grid
const auto c_thread_mtx_on_block =
BlockwiseGemmXdlops::CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_grid =
c_local_tile_tensor.GetMultiIdxOffsets()[I0] + c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_grid =
c_local_tile_tensor.GetMultiIdxOffsets()[I1] + c_thread_mtx_on_block[I1];
const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_grid_idx =
m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_grid));
const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor =
make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_grid_idx =
n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_grid));
// Create partition shape based on descriptor dims.
const auto partition_shape = make_tuple(M0, N0, I1, I1, M2, I1, M4, I1);
const auto partition_desc = BlockwiseGemmXdlops::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(
layout(c_local_tile_tensor).GetUnrolledDescriptor());
const auto partition_layout =
Layout<remove_reference_t<decltype(partition_shape)>, decltype(partition_desc)>(
partition_shape, partition_desc);
auto partition_tensor = make_tensor<CTensorType::TensorBufferAddressSpace>(
c_local_tile_tensor.GetPointer(), partition_layout);
partition_tensor.SetMultiIdxOffset(make_multi_index(m_thread_data_on_grid_idx[I0],
n_thread_data_on_grid_idx[I0],
m_thread_data_on_grid_idx[I1],
n_thread_data_on_grid_idx[I1],
m_thread_data_on_grid_idx[I2],
m_thread_data_on_grid_idx[I3],
m_thread_data_on_grid_idx[I4],
n_thread_data_on_grid_idx[I2]));
return partition_tensor;
}
/**
* \brief Create Vgpr C tensor per thread for blockwise gemm.
*
* \note C output Vgpr register layout (8D):
* - MXdlPerWave - The number of MFMA instructions run by single wave in M
* dimension per tile.
* - NXdlPerWave - The number of MFMA instructions run by single wave in N
* dimension per tile.
* - MWave - Equals to 1 since this is for single wave.
* - NWave - Equals to 1 since this is for single wave.
* - NumGroupsPerBlock - Mfma instruction internal layout (depends on the
* instruction size).
* - NumInputsBlock - Mfma instruction internal layout (depends on the
* instruction size).
* - GroupSize - Mfma instruction internal layout (depends on the
* instruction size).
* - NumThreadsPerBlock - Mfma instruction internal layout (depends on the
* instruction size).
*
* \tparam DataType Input data types.
* \tparam ATileLayout A tensor layout.
* \tparam BTileLayout B tensor layout.
* \tparam BlockSize Number of threads in block.
* \tparam GemmTraits Traits of gemm xdl operation.
*
* \return Vgpr c tensor for blockwise gemm.
*/
template <typename DataType,
typename ATileLayout,
typename BTileLayout,
index_t BlockSize,
typename GemmTraits>
__host__ __device__ constexpr auto make_blockwise_gemm_xdl_c_vgpr()
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{};
constexpr bool is_integer =
is_same_v<DataType, int8_t> || is_same_v<DataType, int16_t> || is_same_v<DataType, int32_t>;
using GemmAccDataType = std::conditional_t<is_integer, int32_t, float>;
using ABlockDesc_K0_M_K1_Type =
decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>());
using BBlockDesc_K0_N_K1_Type =
decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>());
using BlockwiseGemmXdlops =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
DataType,
DataType,
GemmAccDataType,
ABlockDesc_K0_M_K1_Type,
BBlockDesc_K0_N_K1_Type,
GemmTraits::MPerXDL,
GemmTraits::NPerXDL,
GemmTraits::MXdlPerWave,
GemmTraits::NXdlPerWave,
GemmTraits::K1>;
// Calculate descriptor, shape and layout
constexpr auto vgpr_desc = BlockwiseGemmXdlops::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
const auto vgpr_shape = make_tuple(vgpr_desc.GetLengths()[I0],
vgpr_desc.GetLengths()[I1],
vgpr_desc.GetLengths()[I2],
vgpr_desc.GetLengths()[I3],
vgpr_desc.GetLengths()[I4],
vgpr_desc.GetLengths()[I5],
vgpr_desc.GetLengths()[I6],
vgpr_desc.GetLengths()[I7]);
const auto vgpr_layout = Layout<remove_reference_t<decltype(vgpr_shape)>, decltype(vgpr_desc)>(
vgpr_shape, vgpr_desc);
// Get vector type for Vgpr
using BlockwiseGemmCThreadBufferType =
remove_reference_t<decltype(BlockwiseGemmXdlops{}.GetCThreadBuffer())>;
using VgprVectorType = typename BlockwiseGemmCThreadBufferType::V;
return ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr, VgprVectorType>(
vgpr_layout);
}
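The two helpers above are combined in the test kernel added in this change as follows (a_tile_layout, b_tile_layout, thread_layout and c_global_local_tile are defined there):
// Per-thread view of the global C tile matching the 8D MFMA output layout.
auto c_global_local_partition =
    ck::wrapper::make_blockwise_gemm_xdl_c_local_partition<DataType,
                                                           decltype(a_tile_layout),
                                                           decltype(b_tile_layout),
                                                           ck::wrapper::size(thread_layout),
                                                           GemmTraits>(c_global_local_tile);
// Matching Vgpr accumulator, cleared before the K loop and copied out at the end.
auto c_vgpr_reg = ck::wrapper::make_blockwise_gemm_xdl_c_vgpr<DataType,
                                                              decltype(a_tile_layout),
                                                              decltype(b_tile_layout),
                                                              ck::wrapper::size(thread_layout),
                                                              GemmTraits>();
ck::wrapper::clear(c_vgpr_reg);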
} // namespace wrapper
} // namespace ck
...@@ -10,8 +10,8 @@
namespace ck {
namespace wrapper {
namespace detail {
namespace {
namespace detail {
/**
* \brief Check if Tuple contains Slice object
*
...@@ -187,8 +187,8 @@ __host__ __device__ constexpr auto GenerateSlicedDescriptor(const Tuple<Ts...>&
const auto upper_dims = decltype(GenerateUpperDims<0>(TransformsTupleType{})){};
return transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims);
}
} // namespace
} // namespace detail
} // namespace
/**
* \brief Tensor wrapper that performs static and dynamic buffer logic.
...@@ -209,7 +209,10 @@ struct Tensor
public:
using ElementSpaceSize = decltype(Layout<Shape, UnrolledDescriptorType>{
Shape{}, UnrolledDescriptorType{}}.GetElementSpaceSize()); // SpaceSize type for buffer
using TensorElementType = std::conditional_t<
is_scalar_type<ElementType>::value,
ElementType,
typename scalar_type<std::remove_const_t<ElementType>>::type>; // DataType
static constexpr MemoryTypeEnum TensorBufferAddressSpace = BufferAddressSpace;
static constexpr bool IsDynamicBuffer = !(BufferAddressSpace == MemoryTypeEnum ::Sgpr ||
...@@ -280,7 +283,7 @@ struct Tensor
* \return Requested value.
*/
template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
__host__ __device__ const TensorElementType& operator[](const Tuple<Ts...>& idx) const
{
if constexpr(IsDynamicBuffer)
{
...@@ -301,13 +304,13 @@ struct Tensor
}
template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
__host__ __device__ const TensorElementType& operator()(const Tuple<Ts...>& idx) const
{
return this->operator[](idx);
}
template <typename... Idxs, enable_if_t<!detail::HasSlice(Tuple<Idxs...>{}), bool> = false>
__host__ __device__ const TensorElementType& operator()(Idxs... idxs) const
{
return this->operator[](make_tuple(idxs...));
}
...@@ -319,7 +322,7 @@ struct Tensor
* \return Requested value.
*/
template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
__host__ __device__ TensorElementType& operator[](const Tuple<Ts...>& idx)
{
if constexpr(IsDynamicBuffer)
{
...@@ -340,13 +343,13 @@ struct Tensor
}
template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
__host__ __device__ TensorElementType& operator()(const Tuple<Ts...>& idx)
{
return this->operator[](idx);
}
template <typename... Idxs, enable_if_t<!detail::HasSlice(Tuple<Idxs...>{}), bool> = false>
__host__ __device__ TensorElementType& operator()(Idxs... idxs)
{
return this->operator[](make_tuple(idxs...));
}
...@@ -366,7 +369,7 @@ struct Tensor
*
* \return Pointer.
*/
__host__ __device__ TensorElementType* GetPointer() const { return buffer_.p_data_; }
__host__ __device__ constexpr auto& GetBuffer() { return buffer_; }
__host__ __device__ constexpr auto& GetBuffer() const { return buffer_; }
...@@ -395,10 +398,18 @@ struct Tensor
ElementType,
ElementSpaceSize,
true /*InvalidElementUseNumericalZeroValue*/>;
using StaticBufferType = std::conditional_t<
is_scalar_type<ElementType>::value,
StaticBuffer<BufferAddressSpace,
ElementType,
size(Shape{}),
true /*InvalidElementUseNumericalZeroValue*/>,
StaticBufferTupleOfVector<BufferAddressSpace,
TensorElementType,
size(Shape{}) /
scalar_type<std::remove_const_t<ElementType>>::vector_size,
scalar_type<std::remove_const_t<ElementType>>::vector_size,
true /*InvalidElementUseNumericalZeroValue*/>>;
// If register use static buffer, else use dynamic buffer
using Buffer = std::conditional_t<IsDynamicBuffer, DynamicBufferType, StaticBufferType>;
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
namespace ck {
namespace wrapper {
/**
* \brief Traits for blockwise gemm xdl.
*
* \tparam MPerXDLValue The MFMA instruction size in M dimension.
* \tparam NPerXDLValue The MFMA instruction size in N dimension.
* \tparam MXdlPerWaveValue The number of MFMA instructions run by single
* wave in M dimension.
* \tparam NXdlPerWaveValue The number of MFMA instructions run by single
* wave in N dimension.
* \tparam K1Value The number of K-dim elements that are packed together as
* a separate logical dimension. Usually aligns with vector load size.
*/
template <index_t MPerXDLValue,
index_t NPerXDLValue,
index_t MXdlPerWaveValue,
index_t NXdlPerWaveValue,
index_t K1Value>
struct BlockwisGemmXdlTraits
{
static constexpr index_t MPerXDL = MPerXDLValue;
static constexpr index_t NPerXDL = NPerXDLValue;
static constexpr index_t MXdlPerWave = MXdlPerWaveValue;
static constexpr index_t NXdlPerWave = NXdlPerWaveValue;
static constexpr index_t K1 = K1Value;
};
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 4, 2, 4>
{
};
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 2, 4, 4>
{
};
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 2, 2, 4>
{
};
} // namespace wrapper
} // namespace ck
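For example, the test added in this change passes one of the presets above as the GemmTraits argument of blockwise_gemm_xdl:
// 32x32 MFMA instruction, 2x2 XDL instructions per wave, K1 = 4.
using GemmTraits = ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1;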
...@@ -6,6 +6,7 @@
#include "tensor_utils.hpp"
#include "layout_utils.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
...@@ -14,6 +15,8 @@ namespace wrapper {
namespace {
namespace detail {
/**
* \brief Calculate shape for partition based on number of threads per each dim and
* previous shape
...@@ -30,26 +33,109 @@ __host__ __device__ constexpr auto CalculateLocalPartitionShape(const Tuple<Ts..
return generate_tuple(
[&](auto i) {
constexpr auto num_i = Number<i>{};
const auto slice_len =
ck::math::integer_divide_ceil(size<num_i>(shape), thread_lengths.At(num_i));
return slice_len;
},
Number<Tuple<Ls...>::Size()>{});
}
/**
* \brief Apply projection.
*
* \param base_tuple Tuple to apply projection.
* \param projection Projection to remove selected dim from partitioning.
* slice(X) to remove, where X is dim size, Number<1>{} to keep.
* \return Multi index after projection.
*/
template <typename MultiIndex, typename ProjectionTuple>
__host__ __device__ constexpr auto
ApplyProjection([[maybe_unused]] const MultiIndex& base_tuple,
[[maybe_unused]] const ProjectionTuple& projection)
{
if constexpr(is_same_v<ProjectionTuple, Tuple<>>)
{
return Tuple<>{};
}
else
{
auto base_tuple_after_projection = generate_tuple(
[&](auto i) {
const auto i_num = Number<i.value>{};
static_assert(
is_detected<is_slice, tuple_element_t<i_num, ProjectionTuple>>::value ||
is_same_v<tuple_element_t<i_num, ProjectionTuple>, Number<1>>);
if constexpr(is_detected<is_slice, tuple_element_t<i_num, ProjectionTuple>>::value)
{
// When slice (to remove), then insert empty tuple (will be removed in next
// step).
return Tuple<>{};
}
else
{
return base_tuple.At(i_num);
}
},
Number<MultiIndex::Size()>{});
// Remove empty tuples
return UnrollNestedTuple<0, 1>(base_tuple_after_projection);
}
}
/**
* \brief Calculate shape with dims from projection.
*
* \param shape Base tensor shape.
* \param projection Projection used to remove selected dims from partitioning.
* Use slice(X) to remove a dimension (where X is the dim size) and Number<1>{} to keep it.
* \return Shape with dims from projection.
*/
template <typename... Ts, typename... Ps>
__host__ __device__ constexpr auto CalculateShapeWithProjection(const Tuple<Ts...>& shape,
const Tuple<Ps...>& projection)
{
return generate_tuple(
[&](auto i) {
if constexpr(is_detected<is_slice, tuple_element_t<i, Tuple<Ps...>>>::value)
{
return size<i>(projection).to_;
}
else
{
// Number of shape elements in the already-processed fragment of shape and projection
// (used to calculate the shape index)
constexpr index_t shape_i =
detail::ApplyProjection(TupleSlice<0, i>(Tuple<Ts...>{}),
TupleSlice<0, i>(Tuple<Ps...>{}))
.Size();
return size<shape_i>(shape);
}
},
Number<Tuple<Ps...>::Size()>{});
}
/**
* \brief Calculate total number of blocks.
*
* \param shape Base tensor shape.
* \param tile_shape Tile shape.
* \param projection Projection is used to remove selected dim from
* partitioning. Use `slice(X)` to remove dimension, where X is dim
* size. Use `Number<1>{}` to keep it.
* \return Tuple with blocks number.
*/
template <typename... Ts, typename... Ls, typename... Ps>
__host__ __device__ constexpr auto CalculateGridSize(const Tuple<Ts...>& shape,
const Tuple<Ls...>& tile_shape,
const Tuple<Ps...>& projection)
{
auto shape_with_projection = CalculateShapeWithProjection(shape, projection);
return generate_tuple(
[&](auto i) {
return ck::math::integer_divide_ceil(size<i>(shape_with_projection),
size<i>(tile_shape));
},
Number<Tuple<Ls...>::Size()>{});
}
/**
...@@ -69,6 +155,20 @@ CalculateOffsetMultiIdxs(const ThreadIdxs& thread_idxs,
return thread_idxs * partition_lengths_seq + old_offset_idxs;
}
/**
* \brief Generate default projection.
*
* \param tile_shape Tile shape.
* \return Default projection (filled with Number<1>{}).
*/
template <typename TileShape>
__host__ __device__ constexpr auto
GenerateDefaultProjection([[maybe_unused]] const TileShape tile_shape)
{
return generate_tuple([&](auto) { return Number<1>{}; }, Number<TileShape::Size()>{});
}
} // namespace detail
} // namespace
/**
...@@ -78,35 +178,45 @@ CalculateOffsetMultiIdxs(const ThreadIdxs& thread_idxs,
* \param tensor Tensor for partition.
* \param thread_lengths Layout of threads (cannot be nested).
* \param thread_id Thread index represented as integer.
* \param projection Projection is used to remove selected dim from
* partitioning. Use `slice(X)` to remove dimension, where X is dim
* size. Use `Number<1>{}` to keep it.
* \return Partition tensor.
*/
template <typename TensorType, typename ThreadLengthsTuple, typename ProjectionTuple>
__host__ __device__ constexpr auto
make_local_partition(TensorType& tensor,
[[maybe_unused]] const ThreadLengthsTuple& thread_lengths,
const index_t thread_id,
const ProjectionTuple& projection)
{
static_assert(!IsNestedTuple(ThreadLengthsTuple{}));
// Calculate new partition shape
const auto& tensor_shape = shape(tensor);
// Calculate projected thread lengths
constexpr auto projected_thread_lengths =
detail::ApplyProjection(ThreadLengthsTuple{}, ProjectionTuple{});
constexpr auto partition_shape =
detail::CalculateLocalPartitionShape(decltype(tensor_shape){}, projected_thread_lengths);
// Create Thread Cluster Descriptor
constexpr auto partition_shape_seq =
generate_sequence_v2([&](auto I) { return size<I>(partition_shape); },
Number<decltype(partition_shape)::Size()>{});
constexpr auto thread_lengths_seq =
generate_sequence_v2([&](auto I) { return size<I>(ThreadLengthsTuple{}); },
Number<ThreadLengthsTuple::Size()>{});
constexpr auto thread_cluster_desc_ = make_cluster_descriptor(thread_lengths_seq);
// Calculate thread idxs and offsets
const auto thread_idxs = thread_cluster_desc_.CalculateBottomIndex(make_multi_index(thread_id));
// Apply projection on thread idxs to remove unneeded idxs
const auto projected_thread_idxs = detail::ApplyProjection(thread_idxs, projection);
const auto offset_multi_idxs = detail::CalculateOffsetMultiIdxs(
projected_thread_idxs, partition_shape_seq, tensor.GetMultiIdxOffsets());
// Create new layout and tensor
auto& unrolled_desc = layout(tensor).GetUnrolledDescriptor();
const auto partition_layout =
Layout<remove_reference_t<decltype(partition_shape)>, decltype(unrolled_desc)>(
partition_shape, unrolled_desc);
auto partition_tensor =
make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), partition_layout);
// Apply offsets
...@@ -114,6 +224,24 @@ make_local_partition(TensorType& tensor,
return partition_tensor;
}
/**
* \brief Create local partition for thread (currently only packed
* partitioning is supported).
*
* \param tensor Tensor for partition.
* \param thread_lengths Layout of threads (cannot be nested).
* \param thread_id Thread index represented as integer.
* \return Partition tensor.
*/
template <typename TensorType, typename ThreadLengthsTuple>
__host__ __device__ constexpr auto make_local_partition(TensorType& tensor,
const ThreadLengthsTuple& thread_lengths,
const index_t thread_id)
{
const auto projection = detail::GenerateDefaultProjection(ThreadLengthsTuple{});
return make_local_partition(tensor, thread_lengths, thread_id, projection);
}
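The projection overload is exercised by the updated TestPartition.LocalPartition test below; a condensed sketch (tensor and thread_id are set up in that test):
// 3d thread layout over a 2d tensor: slice(4) drops the first thread dimension
// from the partitioning, while dims marked with Number<1>{} are kept.
const auto thread_layout = ck::make_tuple(ck::Number<4>{}, ck::Number<8>{}, ck::Number<1>{});
const auto thread_projection =
    ck::make_tuple(ck::wrapper::slice(4), ck::Number<1>{}, ck::Number<1>{});
const auto partition =
    ck::wrapper::make_local_partition(tensor, thread_layout, thread_id, thread_projection);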
/**
* \brief Create local tile for thread block (currently only packed tiles
* are supported).
...@@ -125,22 +253,29 @@ make_local_partition(TensorType& tensor,
* \param tensor Tensor for partition.
* \param tile_shape Shapes of requested tile.
* \param block_id Block index represented as integer.
* \param projection Projection to remove selected dim from partitioning.
* slice(X) to remove, where X is dim size, Number<1>{} to keep.
* \return Tile tensor.
*/
template <typename TensorType, typename BlockShapeTuple, typename ProjectionTuple>
__host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,
const BlockShapeTuple& tile_shape,
const index_t block_id,
const ProjectionTuple& projection)
{
static_assert(!IsNestedTuple(BlockShapeTuple{}));
constexpr bool is_default_projection =
is_same_v<ProjectionTuple, decltype(detail::GenerateDefaultProjection(BlockShapeTuple{}))>;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
auto& aligned_desc = layout(tensor).GetMergedNestingDescriptor();
// TODO: Enable block_2_tile_map partitioning for non-default projection.
if constexpr(BlockShapeTuple::Size() == I2 && is_default_projection)
{
// Optimized version for 2d tile shape [MxK]
const auto block_2_tile_map =
...@@ -169,20 +304,24 @@ make_local_tile(const TensorType& tensor, const BlockShapeTuple& tile_shape, con
{
// Calculate offsets
// Sequence with data to process per block
constexpr auto projected_tile_shape =
detail::ApplyProjection(BlockShapeTuple{}, ProjectionTuple{});
using ProjectedTileShapeTuple = decltype(projected_tile_shape);
constexpr auto projected_tile_shape_seq =
generate_sequence_v2([](auto I) { return ProjectedTileShapeTuple{}.At(I); },
Number<ProjectedTileShapeTuple::Size()>{});
// Tuple with number of blocks
const auto block_lengths = detail::CalculateGridSize(shape(tensor), tile_shape, projection);
const auto block_cluster_desc_ = make_cluster_descriptor(block_lengths);
const auto block_idxs =
block_cluster_desc_.CalculateBottomIndex(make_multi_index(block_id));
const auto projected_block_idxs = detail::ApplyProjection(block_idxs, projection);
const auto offset_multi_idxs = detail::CalculateOffsetMultiIdxs(
projected_block_idxs, projected_tile_shape_seq, tensor.GetMultiIdxOffsets());
// Create new layout and tensor
const auto tile_layout =
Layout<remove_reference_t<ProjectedTileShapeTuple>, decltype(aligned_desc)>(
projected_tile_shape, aligned_desc);
auto tile_tensor =
make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), tile_layout);
// Apply offsets
...@@ -191,5 +330,61 @@ make_local_tile(const TensorType& tensor, const BlockShapeTuple& tile_shape, con
}
}
/**
* \brief Create local tile for thread block (currently only packed tiles
* are supported).
*
* \note Currently, a 2d tile shape gives the best performance.
*
* \param tensor Tensor for partition.
* \param tile_shape Shapes of requested tile.
* \param block_id Block index represented as integer.
* \return Tile tensor.
*/
template <typename TensorType, typename BlockShapeTuple>
__host__ __device__ constexpr auto
make_local_tile(const TensorType& tensor, const BlockShapeTuple& tile_shape, const index_t block_id)
{
const auto projection = detail::GenerateDefaultProjection(BlockShapeTuple{});
return make_local_tile(tensor, tile_shape, block_id, projection);
}
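In the DeviceGemm test kernel added in this change, the projection overload partitions 2d A/B/C tensors with a 3d (M, N, K) tile shape; for the A tile the N dimension is sliced away:
auto a_global_local_tile = ck::wrapper::make_local_tile(
    a_padded_global_tensor_k_slice,
    tile_shape,
    block_idx,
    make_tuple(ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{}));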
/**
* \brief Pad tensor shape so that it is aligned to the tile lengths.
*
* \param tensor Tensor to pad.
* \param tile_lengths Tile lengths to align tensor shape.
* \return Padded tensor.
*/
template <typename TensorType, typename TileLengths>
__host__ __device__ constexpr auto pad(const TensorType& tensor, const TileLengths& tile_lengths)
{
const auto& tensor_shape = shape(tensor);
using TensorShapeType = remove_reference_t<decltype(tensor_shape)>;
auto& unrolled_desc = layout(tensor).GetUnrolledDescriptor();
// Generate sequence with ones to mark that all dims will be padded
constexpr auto do_pads_seq =
generate_sequence_v2([](auto) { return Number<1>{}; }, Number<TensorShapeType::Size()>{});
// Create descriptor with padding
auto padded_desc =
tensor_operation::device::PadTensorDescriptor(unrolled_desc, tile_lengths, do_pads_seq);
// Generate padded shape
const auto padded_shape = generate_tuple(
[&](auto i) {
const auto& dim = size<i>(tensor_shape);
const auto& tile_length = size<i>(tile_lengths);
return ck::math::integer_divide_ceil(dim, tile_length) * tile_length;
},
Number<TileLengths::Size()>{});
// Create layout and tensor
const auto padded_layout =
Layout<decltype(padded_shape), decltype(padded_desc)>(padded_shape, padded_desc);
auto partition_tensor =
make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), padded_layout);
partition_tensor.SetMultiIdxOffset(tensor.GetMultiIdxOffsets());
return partition_tensor;
}
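Usage from the test kernel added in this change, where the global tensors are padded up to the tile layouts before tiling (this is what makes the irregular 129x129x67 case work):
auto a_padded_global_tensor = ck::wrapper::pad(a_global_tensor, shape(a_tile_layout));
auto b_padded_global_tensor = ck::wrapper::pad(b_global_tensor, shape(b_tile_layout));
auto c_padded_global_tensor = ck::wrapper::pad(c_global_tensor, shape(c_tile_layout));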
} // namespace wrapper
} // namespace ck
...@@ -5,6 +5,7 @@
#include "ck/ck.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/number.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/utility/tuple_helper.hpp"
...@@ -19,9 +20,9 @@ namespace wrapper {
* \brief Memory type, allowed members:
* - Generic,
* - Global,
* - Lds,
* - Sgpr,
* - Vgpr,
*/
using MemoryTypeEnum = AddressSpaceEnum;
...@@ -52,12 +53,8 @@ struct Slice
__host__ __device__ constexpr auto range(const T& dim) const
{
if constexpr(is_same_v<FromType, index_t> || is_same_v<ToType, index_t> ||
is_same_v<std::remove_const_t<T>, index_t>)
{
if(!(dim >= to_ && from_ >= 0 && (to_ < 0 || to_ > from_)))
{
throw std::runtime_error("Invalid range");
}
if(to_ < 0)
{
return dim - from_ + to_ + 1;
...@@ -70,9 +67,10 @@ struct Slice
}
else
{
static_assert(T{} >= ToType{} && FromType{} >= Number<0>{} &&
(ToType{} < 0 || ToType{} > FromType{}),
"Invalid range");
if constexpr(ToType{} < 0)
{
return dim - from_ + to_ + Number<1>{};
}
...@@ -130,6 +128,23 @@ constexpr auto make_register_tensor(const Layout<Shape, UnrolledDescriptorType>&
return Tensor<MemoryType, ElementType, Shape, UnrolledDescriptorType>(layout);
}
/**
* \brief Clear tensor. (Only for Vgpr/Sgpr)
*
* \param tensor Tensor to be cleared.
*/
template <MemoryTypeEnum BufferAddressSpace,
typename ElementType,
typename Shape,
typename UnrolledDescriptorType>
__host__ __device__ void
clear(Tensor<BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType>& tensor)
{
static_assert(
!Tensor<BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType>::IsDynamicBuffer);
return tensor.GetBuffer().Clear();
}
/**
* \brief Get Tensor Layout.
*
...
...@@ -6,3 +6,9 @@ add_gtest_executable(test_copy test_copy.cpp)
target_link_libraries(test_copy PRIVATE utility)
add_gtest_executable(test_partition test_partition.cpp)
target_link_libraries(test_partition PRIVATE utility)
if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR
GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR
GPU_TARGETS MATCHES "gfx942")
add_gtest_executable(test_gemm test_gemm.cpp)
target_link_libraries(test_gemm PRIVATE utility)
endif()
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <numeric>
#include <cstdlib>
#include <iostream>
#include <initializer_list>
#include <vector>
#include <gtest/gtest.h>
#include "ck/library/utility/host_tensor.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/wrapper/layout.hpp"
#include "ck/wrapper/tensor.hpp"
#include "ck/wrapper/operations/copy.hpp"
#include "ck/wrapper/operations/gemm.hpp"
template <typename DataType>
void CheckResult(const std::vector<DataType>& a_data,
const std::vector<DataType>& b_data,
std::vector<DataType>& c_m_n_device_result,
const ck::index_t M,
const ck::index_t N,
const ck::index_t K)
{
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<DataType, DataType, DataType, float, PassThrough, PassThrough, PassThrough>;
Tensor<DataType> a_m_k(HostTensorDescriptor({M, K}));
Tensor<DataType> b_k_n(HostTensorDescriptor({K, N}, {1, K}));
Tensor<DataType> c_m_n_host_result(HostTensorDescriptor({M, N}));
a_m_k.mData = a_data;
b_k_n.mData = b_data;
auto ref_op = ReferenceGemmInstance{};
auto ref_invoker = ref_op.MakeInvoker();
auto ref_argument = ref_op.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{});
ref_invoker.Run(ref_argument);
EXPECT_TRUE(ck::utils::check_err(c_m_n_device_result, c_m_n_host_result.mData));
}
template <typename DataType,
typename GemmTraits,
ck::index_t scalar_per_vector,
typename BlockShape,
typename ThreadLayoutShape>
__global__ void DeviceGemm(const void* p_a,
const void* p_b,
void* p_c,
const ck::index_t M,
const ck::index_t N,
const ck::index_t K,
const BlockShape tile_shape,
const ThreadLayoutShape thread_layout)
{
constexpr auto MPerBlock = ck::wrapper::size<0>(tile_shape);
constexpr auto NPerBlock = ck::wrapper::size<1>(tile_shape);
constexpr auto KPerBlock = ck::wrapper::size<2>(tile_shape);
const auto a_global_layout =
ck::wrapper::make_layout(ck::make_tuple(M, K), ck::make_tuple(K, 1));
const auto b_global_layout =
ck::wrapper::make_layout(ck::make_tuple(N, K), ck::make_tuple(K, 1));
const auto c_global_layout =
ck::wrapper::make_layout(ck::make_tuple(M, N), ck::make_tuple(N, 1));
constexpr auto a_tile_layout = ck::wrapper::make_layout(
ck::make_tuple(MPerBlock, KPerBlock), ck::make_tuple(KPerBlock, ck::Number<1>{}));
constexpr auto b_tile_layout = ck::wrapper::make_layout(
ck::make_tuple(NPerBlock, KPerBlock), ck::make_tuple(KPerBlock, ck::Number<1>{}));
constexpr auto c_tile_layout = ck::wrapper::make_layout(
ck::make_tuple(MPerBlock, NPerBlock), ck::make_tuple(NPerBlock, ck::Number<1>{}));
auto a_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
static_cast<const DataType*>(p_a), a_global_layout);
auto b_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
static_cast<const DataType*>(p_b), b_global_layout);
auto c_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
static_cast<DataType*>(p_c), c_global_layout);
auto a_padded_global_tensor = ck::wrapper::pad(a_global_tensor, shape(a_tile_layout));
auto b_padded_global_tensor = ck::wrapper::pad(b_global_tensor, shape(b_tile_layout));
auto c_padded_global_tensor = ck::wrapper::pad(c_global_tensor, shape(c_tile_layout));
__shared__ DataType lds_a[ck::wrapper::size(a_tile_layout)];
__shared__ DataType lds_b[ck::wrapper::size(b_tile_layout)];
auto a_lds_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
static_cast<DataType*>(lds_a), a_tile_layout);
auto b_lds_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
static_cast<DataType*>(lds_b), b_tile_layout);
const ck::index_t block_idx = static_cast<ck::index_t>(blockIdx.x);
using DimAccessOrder = ck::Tuple<ck::Number<0>, ck::Number<1>>;
constexpr ck::index_t vector_dim = 1;
auto c_global_local_tile = ck::wrapper::make_local_tile(
c_padded_global_tensor,
tile_shape,
block_idx,
make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(KPerBlock)));
auto c_global_local_partition =
ck::wrapper::make_blockwise_gemm_xdl_c_local_partition<DataType,
decltype(a_tile_layout),
decltype(b_tile_layout),
ck::wrapper::size(thread_layout),
GemmTraits>(c_global_local_tile);
auto c_vgpr_reg = ck::wrapper::make_blockwise_gemm_xdl_c_vgpr<DataType,
decltype(a_tile_layout),
decltype(b_tile_layout),
ck::wrapper::size(thread_layout),
GemmTraits>();
ck::wrapper::clear(c_vgpr_reg);
const ck::index_t num_loop = ck::math::integer_divide_ceil(K, KPerBlock);
ck::index_t i = 0;
do
{
const auto k_slice = ck::wrapper::slice(i * KPerBlock, (i + 1) * KPerBlock);
auto a_padded_global_tensor_k_slice = a_padded_global_tensor(ck::wrapper::slice(), k_slice);
auto b_padded_global_tensor_k_slice = b_padded_global_tensor(ck::wrapper::slice(), k_slice);
auto a_global_local_tile = ck::wrapper::make_local_tile(
a_padded_global_tensor_k_slice,
tile_shape,
block_idx,
make_tuple(ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{}));
auto b_global_local_tile = ck::wrapper::make_local_tile(
b_padded_global_tensor_k_slice,
tile_shape,
block_idx,
make_tuple(ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{}));
ck::wrapper::blockwise_copy<DimAccessOrder, vector_dim, scalar_per_vector>(
a_global_local_tile, a_lds_tensor, thread_layout);
ck::wrapper::blockwise_copy<DimAccessOrder, vector_dim, scalar_per_vector>(
b_global_local_tile, b_lds_tensor, thread_layout);
ck::block_sync_lds();
ck::wrapper::blockwise_gemm_xdl<DataType, ck::wrapper::size(thread_layout), GemmTraits>(
a_lds_tensor, b_lds_tensor, c_vgpr_reg);
++i;
} while(i < num_loop);
ck::wrapper::copy(c_vgpr_reg, c_global_local_partition);
}
template <typename DataType,
typename GemmTraits,
ck::index_t scalar_per_vector,
typename BlockShape,
typename ThreadLayoutShape>
void PerformGemm(const ck::index_t M,
const ck::index_t N,
const ck::index_t K,
const BlockShape& tile_shape,
const ThreadLayoutShape& thread_layout)
{
// Global memory buffers
DeviceMem a_mem(M * K * sizeof(DataType));
DeviceMem b_mem(K * N * sizeof(DataType));
DeviceMem c_mem(M * N * sizeof(DataType));
std::vector<DataType> a_data(M * K);
std::vector<DataType> b_data(K * N);
ck::utils::FillUniformDistributionIntegerValue<DataType>{-5.f, 5.f}(a_data);
ck::utils::FillUniformDistributionIntegerValue<DataType>{-5.f, 5.f}(b_data);
a_mem.ToDevice(a_data.data());
b_mem.ToDevice(b_data.data());
c_mem.SetZero();
const ck::index_t grid_size =
ck::math::integer_divide_ceil(M, ck::wrapper::size<0>(tile_shape)) *
ck::math::integer_divide_ceil(N, ck::wrapper::size<1>(tile_shape));
const auto kernel =
DeviceGemm<DataType, GemmTraits, scalar_per_vector, BlockShape, ThreadLayoutShape>;
launch_and_time_kernel(StreamConfig{nullptr},
kernel,
dim3(grid_size),
dim3(ck::wrapper::size(thread_layout)),
0,
a_mem.GetDeviceBuffer(),
b_mem.GetDeviceBuffer(),
c_mem.GetDeviceBuffer(),
M,
N,
K,
tile_shape,
thread_layout);
std::vector<DataType> c_data(M * N);
c_mem.FromDevice(c_data.data());
CheckResult<DataType>(a_data, b_data, c_data, M, N, K);
}
TEST(TestGemm, Float)
{
using DataType = float;
const auto thread_layout = ck::make_tuple(ck::Number<16>{}, ck::Number<16>{});
const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{});
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1, 4>(
512, 512, 128, tile_shape, thread_layout);
// Irregular case
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1, 1>(
129, 129, 67, tile_shape, thread_layout);
}
TEST(TestGemm, Int8)
{
using DataType = int8_t;
const auto thread_layout = ck::make_tuple(ck::Number<64>{}, ck::Number<4>{});
const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{});
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1, 16>(
512, 512, 128, tile_shape, thread_layout);
// Irregular case
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1, 1>(
129, 129, 67, tile_shape, thread_layout);
}
TEST(TestGemm, Half)
{
using DataType = ck::half_t;
const auto thread_layout = ck::make_tuple(ck::Number<32>{}, ck::Number<8>{});
const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{});
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1, 8>(
512, 512, 128, tile_shape, thread_layout);
// Irregular case
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1, 1>(
129, 129, 67, tile_shape, thread_layout);
}
TEST(TestGemm, Float_2x4_4x2_XdlPerWave)
{
using DataType = float;
const auto thread_layout_4x2_xdl_per_wave = ck::make_tuple(ck::Number<16>{}, ck::Number<8>{});
const auto thread_layout_2x4_xdl_per_wave = ck::make_tuple(ck::Number<8>{}, ck::Number<16>{});
const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{});
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_4K1, 4>(
512, 512, 128, tile_shape, thread_layout_4x2_xdl_per_wave);
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_4K1, 4>(
512, 512, 128, tile_shape, thread_layout_2x4_xdl_per_wave);
}
...@@ -29,17 +29,24 @@ TEST(TestPartition, LocalPartition)
const auto tensor =
ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Generic>(data.data(), layout);
const auto thread_steps = ck::make_tuple(ck::Number<1>{}, ck::Number<8>{}, ck::Number<1>{});
const auto thread_layout = ck::make_tuple(ck::Number<4>{}, ck::Number<8>{}, ck::Number<1>{});
// 3d partition on 2d shape (calculate partition on 3d thread layout, and then skip first dim)
const auto thread_projection =
ck::make_tuple(ck::wrapper::slice(4), ck::Number<1>{}, ck::Number<1>{});
constexpr ck::index_t projection_thread_length = ck::Number<4>{};
for(ck::index_t thread_id = 0;
thread_id < ck::wrapper::size(thread_layout) / projection_thread_length;
thread_id++)
{
const auto packed_partition =
ck::wrapper::make_local_partition(tensor, thread_layout, thread_id, thread_projection);
const auto expected_partition_size =
ck::wrapper::size(tensor) /
(ck::wrapper::size(thread_layout) / projection_thread_length);
const auto expected_partition_first_val = thread_id * ck::wrapper::size<1>(thread_steps);
const auto expected_partition_second_val = expected_partition_first_val + 1;
EXPECT_EQ(ck::wrapper::size(packed_partition), expected_partition_size);
EXPECT_EQ(packed_partition(0), expected_partition_first_val);
...@@ -58,8 +65,12 @@ TEST(TestPartition, LocalTile)
const auto tensor =
ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Generic>(data.data(), layout);
// 4d tile partitioning on 3d shape (calculate tile on 4d tile layout, and then skip last dim)
const auto block_shape =
ck::make_tuple(ck::Number<2>{}, ck::Number<4>{}, ck::Number<2>{}, ck::Number<2>{});
const auto block_projection =
ck::make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(2));
constexpr ck::index_t projection_block_dim = ck::Number<2>{};
const auto num_blocks =
ck::make_tuple(ck::wrapper::size<0>(shape) / ck::wrapper::size<0>(block_shape),
ck::wrapper::size<1>(shape) / ck::wrapper::size<1>(block_shape),
...@@ -69,9 +80,10 @@ TEST(TestPartition, LocalTile)
for(auto block_idx : block_idxs)
{
const auto packed_tile =
ck::wrapper::make_local_tile(tensor, block_shape, block_idx, block_projection);
const auto expected_tile_size = ck::wrapper::size(block_shape) / projection_block_dim;
auto expected_tile_first_val = (block_idx % ck::wrapper::size<2>(num_blocks)) *
ck::wrapper::size<2>(block_shape) *
ck::wrapper::size<2>(strides);
...