Unverified Commit 29dcb956 authored by Illia Silin's avatar Illia Silin Committed by GitHub
Browse files

Merge pull request #33 from ROCm/lwpck-1292

Merge from the public repo.
parents 29deceb6 cbcc844e
......@@ -54,7 +54,7 @@ __global__ void
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
defined(__gfx94__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(
......
......@@ -35,9 +35,8 @@ __global__ void
const Block2ETileMap block_2_tile_map,
const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx1030__) || defined(__gfx1100__) || \
defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx941__) || defined(__gfx942__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__))
GridwiseTensorRearrangeKernel::Run(in_grid_desc,
p_in_global,
out_grid_desc,
......@@ -50,7 +49,9 @@ __global__ void
ignore = p_in_global;
ignore = out_grid_desc;
ignore = p_out_global;
ignore = batch_count;
ignore = block_2_tile_map;
ignore = compute_ptr_offset_of_batch;
#endif
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"
namespace ck {
// Tensor Shape
// dy, x = [M, K], gamma = [1, K], x_mean, inv_std = [M, 1]
// Flow:
// def normalization_backward_x(dy, x, gamma, x_mean, inv_std, reduce_axis, reduce_size):
// ds = np.sum(dy * gamma * x, axis=reduce_axis, keepdims=True)
// db = np.sum(dy * gamma, axis=reduce_axis, keepdims=True)
// b = (db * x_mean - ds) * inv_std ** (3) / reduce_size
// c = -b * x_mean - db * inv_std / reduce_size
// dx = inv_std * dy * gamma + b * x + c
// return dx
template <typename DYDataType,
typename XDataType,
typename GammaDataType,
typename MeanInvStdDataType,
typename ComputeDataType,
typename DXDataType,
typename GridDesc_M_K,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t DYSrcVectorDim,
index_t DYSrcVectorSize,
index_t XSrcVectorDim,
index_t XSrcVectorSize,
index_t GammaSrcVectorDim,
index_t GammaSrcVectorSize,
index_t MeanInvStdSrcVectorDim,
index_t MeanInvStdSrcVectorSize,
index_t DXDstVectorDim,
index_t DXDstVectorSize,
bool SweepOnce>
struct GridwiseNormalizationBwdData_mk_to_mk
{
// if we just check ThreadSliceSize % VectorSize == 0, the performance may be poor (coalesce)
static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize == DYSrcVectorSize) ||
(DYSrcVectorDim == 1 && KThreadSliceSize == DYSrcVectorSize)),
"Invalid thread slice sizes and/or dy vector sizes configuration, please check!");
static_assert(((XSrcVectorDim == 0 && MThreadSliceSize == XSrcVectorSize) ||
(XSrcVectorDim == 1 && KThreadSliceSize == XSrcVectorSize)),
"Invalid thread slice sizes and/or x vector sizes configuration, please check!");
static_assert(
((GammaSrcVectorDim == 0 && MThreadSliceSize == GammaSrcVectorSize) ||
(GammaSrcVectorDim == 1 && KThreadSliceSize == GammaSrcVectorSize)),
"Invalid thread slice sizes and/or gamma vector sizes configuration, please check!");
static_assert(
((MeanInvStdSrcVectorDim == 0 && MThreadSliceSize == MeanInvStdSrcVectorSize) ||
(MeanInvStdSrcVectorDim == 1 && KThreadSliceSize == MeanInvStdSrcVectorSize)),
"Invalid thread slice sizes and/or mean/inv_std vector sizes configuration, please check!");
static_assert(((DXDstVectorDim == 0 && MThreadSliceSize == DXDstVectorSize) ||
(DXDstVectorDim == 1 && KThreadSliceSize == DXDstVectorSize)),
"Invalid thread slice sizes and/or dx vector sizes configuration, please check!");
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using DYThreadBufferDimAccessOrder =
typename conditional<DYSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;
using XThreadBufferDimAccessOrder =
typename conditional<XSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;
using GammaThreadBufferDimAccessOrder =
typename conditional<GammaSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;
using MeanInvStdThreadBufferDimAccessOrder =
typename conditional<MeanInvStdSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;
using DXThreadBufferDimAccessOrder =
typename conditional<DXDstVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder = DYThreadBufferDimAccessOrder;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
static constexpr auto thread_buffer_desc_m =
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
using PassThroughOp = tensor_operation::element_wise::PassThrough;
using BlockwiseSumReduce = PartitionedBlockwiseReduction<ComputeDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
reduce::Add,
true>;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
__device__ static void Run(const GridDesc_M_K& dy_grid_desc_m_k,
const GridDesc_M_K& x_grid_desc_m_k,
const GridDesc_M_K& gamma_grid_desc_m_k,
const GridDesc_M_K& mean_grid_desc_m_k,
const GridDesc_M_K& inv_std_grid_desc_m_k,
const GridDesc_M_K& dx_grid_desc_m_k,
index_t num_k_block_tile_iteration,
const DYDataType* const __restrict__ p_dy_global,
const XDataType* const __restrict__ p_x_global,
const GammaDataType* const __restrict__ p_gamma_global,
const MeanInvStdDataType* const __restrict__ p_mean_global,
const MeanInvStdDataType* const __restrict__ p_inv_std_global,
DXDataType* const __restrict__ p_dx_global)
{
// LDS
__shared__ ComputeDataType p_reduce_work_buffer[BlockSize];
auto reduce_work_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
// Global
const auto dy_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dy_global, dy_grid_desc_m_k.GetElementSpaceSize());
const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_x_global, x_grid_desc_m_k.GetElementSpaceSize());
auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize());
const auto mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_mean_global, mean_grid_desc_m_k.GetElementSpaceSize());
const auto inv_std_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_inv_std_global, inv_std_grid_desc_m_k.GetElementSpaceSize());
auto dx_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dx_global, dx_grid_desc_m_k.GetElementSpaceSize());
// VGPR
auto dy_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
ComputeDataType,
MThreadSliceSize * KThreadSliceSize,
true>{};
auto x_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
ComputeDataType,
MThreadSliceSize * KThreadSliceSize,
true>{};
auto gamma_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
ComputeDataType,
MThreadSliceSize * KThreadSliceSize,
true>{};
auto mean_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
ComputeDataType,
MThreadSliceSize * KThreadSliceSize,
true>{};
auto inv_std_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
ComputeDataType,
MThreadSliceSize * KThreadSliceSize,
true>{};
auto dx_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
ComputeDataType,
MThreadSliceSize * KThreadSliceSize,
true>{};
auto ds_thread_buf =
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>{};
auto db_thread_buf =
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>{};
// thread id
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
// IO
auto threadwise_dy_load = ThreadwiseTensorSliceTransfer_v2<DYDataType,
ComputeDataType,
GridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
DYThreadBufferDimAccessOrder,
DYSrcVectorDim,
DYSrcVectorSize,
1,
false>(
dy_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
ComputeDataType,
GridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
XThreadBufferDimAccessOrder,
XSrcVectorDim,
XSrcVectorSize,
1,
false>(
x_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_gamma_load =
ThreadwiseTensorSliceTransfer_v2<GammaDataType,
ComputeDataType,
GridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
XThreadBufferDimAccessOrder,
GammaSrcVectorDim,
GammaSrcVectorSize,
1,
false>(
gamma_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_mean_load =
ThreadwiseTensorSliceTransfer_v2<MeanInvStdDataType,
ComputeDataType,
GridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
MeanInvStdThreadBufferDimAccessOrder,
MeanInvStdSrcVectorDim,
MeanInvStdSrcVectorSize,
1,
false>(
mean_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_inv_std_load =
ThreadwiseTensorSliceTransfer_v2<MeanInvStdDataType,
ComputeDataType,
GridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
MeanInvStdThreadBufferDimAccessOrder,
MeanInvStdSrcVectorDim,
MeanInvStdSrcVectorSize,
1,
false>(
inv_std_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_dx_store =
ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
DXDataType,
decltype(thread_buffer_desc_m_k),
GridDesc_M_K,
PassThroughOp,
ThreadBufferLengths_M_K,
DXThreadBufferDimAccessOrder,
DXDstVectorDim,
DXDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
false>(
dx_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize),
PassThroughOp{});
ComputeDataType reduce_size = type_convert<ComputeDataType>(
dy_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0]);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
ds_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
db_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
});
// Separate sweep once and sweep twice pipeline
// Sweep once: for small k, if KThreadClusterSize * KThreadSliceSize > K
// we don't need to use loop to read x, dy, gamma twice
if constexpr(SweepOnce)
{
threadwise_dy_load.Run(dy_grid_desc_m_k,
dy_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
dy_thread_buf);
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf);
threadwise_gamma_load.Run(gamma_grid_desc_m_k,
gamma_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
gamma_thread_buf);
threadwise_mean_load.Run(mean_grid_desc_m_k,
mean_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
mean_thread_buf);
threadwise_inv_std_load.Run(inv_std_grid_desc_m_k,
inv_std_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
inv_std_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
constexpr auto offset_m =
Number<thread_buffer_desc_m.CalculateOffset(make_tuple(iM))>{};
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset_m_k =
Number<thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK))>{};
ds_thread_buf(offset_m) += dy_thread_buf[offset_m_k] *
gamma_thread_buf[offset_m_k] *
x_thread_buf[offset_m_k];
db_thread_buf(offset_m) +=
dy_thread_buf[offset_m_k] * gamma_thread_buf[offset_m_k];
});
});
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
BlockwiseSumReduce::Reduce(reduce_work_buf, ds_thread_buf(I));
block_sync_lds();
BlockwiseSumReduce::Reduce(reduce_work_buf, db_thread_buf(I));
});
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
constexpr auto offset_m =
Number<thread_buffer_desc_m.CalculateOffset(make_tuple(iM))>{};
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset_m_k =
Number<thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK))>{};
// b = (db * x_mean - ds) * rstd ** (3) / reduce_size
// c = -b * x_mean - db * rstd / reduce_size
// dx = rstd * dy * gamma + b * x + c
ComputeDataType b = db_thread_buf[offset_m] * mean_thread_buf[offset_m_k] -
ds_thread_buf[offset_m];
b *= inv_std_thread_buf[offset_m_k] * inv_std_thread_buf[offset_m_k] *
inv_std_thread_buf[offset_m_k] / reduce_size;
ComputeDataType c = -b * mean_thread_buf(offset_m_k);
c -= db_thread_buf[offset_m] * inv_std_thread_buf[offset_m_k] / reduce_size;
dx_thread_buf(offset_m_k) = dy_thread_buf[offset_m_k] *
gamma_thread_buf[offset_m_k] *
inv_std_thread_buf[offset_m_k] +
b * x_thread_buf[offset_m_k] + c;
});
});
threadwise_dx_store.Run(thread_buffer_desc_m_k,
make_tuple(I0, I0),
dx_thread_buf,
dx_grid_desc_m_k,
dx_global_val_buf);
} // end of sweep once
else // Sweep Twice pipeline
{
constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize);
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
{
threadwise_dy_load.Run(dy_grid_desc_m_k,
dy_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
dy_thread_buf);
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf);
threadwise_gamma_load.Run(gamma_grid_desc_m_k,
gamma_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
gamma_thread_buf);
threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_fwd_step_m_k);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
thread_copy_fwd_step_m_k);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
constexpr auto offset_m =
Number<thread_buffer_desc_m.CalculateOffset(make_tuple(iM))>{};
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset_m_k =
Number<thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK))>{};
ds_thread_buf(offset_m) += dy_thread_buf[offset_m_k] *
gamma_thread_buf[offset_m_k] *
x_thread_buf[offset_m_k];
db_thread_buf(offset_m) +=
dy_thread_buf[offset_m_k] * gamma_thread_buf[offset_m_k];
});
});
} // end of first sweep
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
BlockwiseSumReduce::Reduce(reduce_work_buf, ds_thread_buf(I));
block_sync_lds();
BlockwiseSumReduce::Reduce(reduce_work_buf, db_thread_buf(I));
});
// reverse read for using dy, gamma and x in the cache
constexpr auto thread_copy_bwd_step_m_k = make_multi_index(0, -K_BlockTileSize);
auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k;
// move to tail
threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_bwd_step_m_k);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_bwd_step_m_k);
// move from start to tail
threadwise_mean_load.MoveSrcSliceWindow(mean_grid_desc_m_k, thread_copy_tail_m_k);
threadwise_inv_std_load.MoveSrcSliceWindow(inv_std_grid_desc_m_k, thread_copy_tail_m_k);
threadwise_dx_store.MoveDstSliceWindow(dx_grid_desc_m_k, thread_copy_tail_m_k);
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
{
threadwise_dy_load.Run(dy_grid_desc_m_k,
dy_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
dy_thread_buf);
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf);
threadwise_gamma_load.Run(gamma_grid_desc_m_k,
gamma_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
gamma_thread_buf);
threadwise_mean_load.Run(mean_grid_desc_m_k,
mean_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
mean_thread_buf);
threadwise_inv_std_load.Run(inv_std_grid_desc_m_k,
inv_std_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
inv_std_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
constexpr auto offset_m =
Number<thread_buffer_desc_m.CalculateOffset(make_tuple(iM))>{};
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset_m_k =
Number<thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK))>{};
// b = (db * x_mean - ds) * rstd ** (3) / reduce_size
// c = -b * x_mean - db * rstd / reduce_size
// dx = rstd * dy * gamma + b * x + c
ComputeDataType b = db_thread_buf[offset_m] * mean_thread_buf[offset_m_k] -
ds_thread_buf[offset_m];
b *= inv_std_thread_buf[offset_m_k] * inv_std_thread_buf[offset_m_k] *
inv_std_thread_buf[offset_m_k] / reduce_size;
ComputeDataType c = -b * mean_thread_buf(offset_m_k);
c -= db_thread_buf[offset_m] * inv_std_thread_buf[offset_m_k] / reduce_size;
dx_thread_buf(offset_m_k) = dy_thread_buf[offset_m_k] *
gamma_thread_buf[offset_m_k] *
inv_std_thread_buf[offset_m_k] +
b * x_thread_buf[offset_m_k] + c;
});
});
threadwise_dx_store.Run(thread_buffer_desc_m_k,
make_tuple(I0, I0),
dx_thread_buf,
dx_grid_desc_m_k,
dx_global_val_buf);
threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_bwd_step_m_k);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
thread_copy_bwd_step_m_k);
threadwise_mean_load.MoveSrcSliceWindow(mean_grid_desc_m_k,
thread_copy_bwd_step_m_k);
threadwise_inv_std_load.MoveSrcSliceWindow(inv_std_grid_desc_m_k,
thread_copy_bwd_step_m_k);
threadwise_dx_store.MoveDstSliceWindow(dx_grid_desc_m_k, thread_copy_bwd_step_m_k);
}
}
}
};
} // namespace ck
......@@ -35,7 +35,7 @@ template <typename DYDataType,
index_t DBetaDstVectorSize>
struct GridwiseNormalizationBwdGammaBeta_mk_to_k
{
// if we just check ThreadSliceSize & VectorSize == 0, the performance may be poor
// if we just check ThreadSliceSize % VectorSize == 0, the performance may be poor (coalesce)
static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize == DYSrcVectorSize) ||
(DYSrcVectorDim == 1 && KThreadSliceSize == DYSrcVectorSize)),
"Invalid thread slice sizes and/or dy vector sizes configuration, please check!");
......@@ -44,6 +44,15 @@ struct GridwiseNormalizationBwdGammaBeta_mk_to_k
(XSrcVectorDim == 1 && KThreadSliceSize == XSrcVectorSize)),
"Invalid thread slice sizes and/or x vector sizes configuration, please check!");
// do not force SliceSize == MeanInvStdSrcVectorSize for groupnorm
static_assert(
((MeanInvStdSrcVectorDim == 0 && MThreadSliceSize % MeanInvStdSrcVectorSize == 0) ||
(MeanInvStdSrcVectorDim == 1 && KThreadSliceSize % MeanInvStdSrcVectorSize == 0)),
"Invalid thread slice sizes and/or mean/inv_std vector sizes configuration, please check!");
static_assert(MThreadSliceSize == DGammaDstVectorSize && MThreadSliceSize == DBetaDstVectorSize,
"Invalid thread slice sizes and/or dx vector sizes configuration, please check!");
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using DYThreadBufferDimAccessOrder =
......
......@@ -328,7 +328,7 @@ struct WmmaSelector
}
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
static constexpr auto GetWmma<int4_t, int, 16, 16>()
static constexpr auto GetWmma<int4_t, int4_t, int, 16, 16>()
{
return WmmaInstr::wmma_i32_16x16x16_iu4;
}
......
......@@ -522,22 +522,21 @@ struct TransformConvFwdToGemm
// for output bias
template <typename CLayout,
typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::GK> ||
is_same_v<CLayout, tensor_layout::convolution::G_K>,
typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::G_K>,
bool>::type = false>
static auto
MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */)
static auto MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides)
{
const index_t N = c_g_n_k_wos_lengths[1];
const index_t K = c_g_n_k_wos_lengths[2];
const index_t N = c_g_n_k_wos_lengths[1];
const index_t K = c_g_n_k_wos_lengths[2];
const index_t KStride = c_g_n_k_wos_strides[2];
const index_t NHoWo =
N * ck::accumulate_n<index_t>(
c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
const auto out_gemmm_gemmn_desc =
make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(I0, I1));
make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(I0, KStride));
return out_gemmm_gemmn_desc;
}
......
......@@ -972,6 +972,15 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
const int32x4_t src_resource = make_wave_buffer_resource(global_ptr, src_element_space_size);
const index_t global_offset_bytes = is_valid ? global_offset * sizeof(T) : 0x80000000;
#if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM
T* lds_ptr = lds_base_ptr + lds_offset;
auto const lds_ptr_sgpr =
__builtin_amdgcn_readfirstlane((reinterpret_cast<uintptr_t>(lds_ptr)));
asm volatile("s_mov_b32 m0, %0; \n\t"
"buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr),
"v"(global_offset_bytes),
"s"(src_resource));
#else
// LDS pointer must be attributed with the LDS address space.
__attribute__((address_space(3))) uint32_t* lds_ptr =
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
......@@ -979,6 +988,7 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
llvm_amdgcn_raw_buffer_load_lds(
src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0);
#endif
}
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/amd_address_space.hpp"
#include "ck/utility/dynamic_buffer.hpp"
#include "ck/utility/math.hpp"
namespace ck {
namespace lds_utils {
/** \brief Allocate a given number of buffers in LDS and return them as a tuple.
*
* \tparam DataType Data type of elements to be stored in LDS.
* \tparam NumBuffers Number of buffers to be allocated.
* \param lds_ptr Address of the beginning of LDS space.
* \param num_elems_per_buffer Number of elements to allocate per single buffer.
* \param start_offset_elems Number of elements to move from the start of LDS for the allocation of
* the first buffer. \param lds_alignment Alignment of every buffer allocation given as a number of
* elements. \return Tuple of dynamic buffers representing memory allocated in LDS.
*/
template <typename DataType, index_t NumBuffers>
__device__ static auto AllocateLdsBuffers(void* lds_ptr,
int32_t num_elems_per_buffer,
int32_t start_offset_elems,
int32_t lds_alignment)
{
const DataType* lds_start = static_cast<DataType*>(lds_ptr) + start_offset_elems;
const int32_t single_buffer_offset =
math::integer_least_multiple(num_elems_per_buffer, lds_alignment);
return generate_tuple(
[&](auto i) {
const int32_t local_offset = i * single_buffer_offset;
return make_dynamic_buffer<AddressSpaceEnum::Lds>(lds_start + local_offset,
num_elems_per_buffer);
},
Number<NumBuffers>{});
}
} // namespace lds_utils
} // namespace ck
......@@ -9,6 +9,9 @@
// TODO: Add arch limitation
namespace ck {
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
#define __gfx11__
#endif
/********************************WAVE32 MODE***********************************************/
// src: fp16, dst: fp32
......@@ -25,7 +28,7 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16>
// delete them.
// amd_assembly_wmma_f32_16x16x16_f16_w32(
// reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
#if defined(__gfx11__)
reg_c.template AsType<float8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
#else
......@@ -46,7 +49,7 @@ struct intrin_wmma_f32_16x16x16_bf16_w32<16, 16>
template <class FloatC>
__device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
#if defined(__gfx11__)
reg_c.template AsType<float8_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(
reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
......@@ -71,7 +74,7 @@ struct intrin_wmma_f16_16x16x16_f16_w32<16, 16, Opsel>
// opsel usage
// false: D0.[0:15] = result
// true : D0.[16:31]= result
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
#if defined(__gfx11__)
reg_c.template AsType<half16_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(
reg_a, reg_b, reg_c.template AsType<half16_t>()[Number<0>{}], Opsel);
#else
......@@ -95,7 +98,7 @@ struct intrin_wmma_bf16_16x16x16_bf16_w32<16, 16, Opsel>
// opsel usage
// false: D0.[0:15] = result
// true : D0.[16:31]= result
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
#if defined(__gfx11__)
reg_c.template AsType<bhalf16_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32(
reg_a, reg_b, reg_c.template AsType<bhalf16_t>()[Number<0>{}], Opsel);
......@@ -117,7 +120,7 @@ struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp>
template <class FloatC>
__device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
#if defined(__gfx11__)
reg_c.template AsType<int32x8_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
neg_a,
......@@ -145,7 +148,7 @@ struct intrin_wmma_f32_16x16x16_f16_w64<16, 16>
template <class FloatC>
__device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
#if defined(__gfx11__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w64(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
#else
......@@ -166,7 +169,7 @@ struct intrin_wmma_f32_16x16x16_bf16_w64<16, 16>
template <class FloatC>
__device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
#if defined(__gfx11__)
reg_c.template AsType<float4_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_f32_16x16x16_bf16_w64(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
......@@ -191,7 +194,7 @@ struct intrin_wmma_f16_16x16x16_f16_w64<16, 16, Opsel>
// opsel usage
// false: D0.[0:15] = result
// true : D0.[16:31]= result
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
#if defined(__gfx11__)
reg_c.template AsType<half8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w64(
reg_a, reg_b, reg_c.template AsType<half8_t>()[Number<0>{}], Opsel);
#else
......@@ -215,7 +218,7 @@ struct intrin_wmma_bf16_16x16x16_bf16_w64<16, 16, Opsel>
// opsel usage
// false: D0.[0:15] = result
// true : D0.[16:31]= result
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
#if defined(__gfx11__)
reg_c.template AsType<bhalf8_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64(
reg_a, reg_b, reg_c.template AsType<bhalf8_t>()[Number<0>{}], Opsel);
......@@ -237,7 +240,7 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
template <class FloatC>
__device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
#if defined(__gfx11__)
reg_c.template AsType<int32x4_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64(
neg_a,
......
......@@ -4,6 +4,10 @@
#pragma once
namespace ck {
// Define the common macro for MI300 models
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#define __gfx94__
#endif
// fp32
template <index_t MPerWave, index_t NPerWave>
......@@ -341,7 +345,7 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
template <class FloatC>
__device__ static void Run(const double& reg_a, const double& reg_b, FloatC& reg_c)
{
#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx90a__) || defined(__gfx94__)
reg_c.template AsType<double4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f64_16x16x4f64(
reg_a, reg_b, reg_c.template AsType<double4_t>()[Number<0>{}], 0, 0, 0);
#else
......@@ -361,7 +365,7 @@ struct intrin_mfma_f32_32x32x16f8f8<32, 32>
template <class FloatC>
__device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx94__)
reg_c.template AsType<float16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
bit_cast<long>(reg_a),
......@@ -393,7 +397,7 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16>
template <class FloatC>
__device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx94__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
......@@ -424,7 +428,7 @@ struct intrin_mfma_f32_32x32x16bf8bf8<32, 32>
template <class FloatC>
__device__ static void Run(const bf8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx94__)
reg_c.template AsType<float16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
bit_cast<long>(reg_a),
......@@ -456,7 +460,7 @@ struct intrin_mfma_f32_16x16x32bf8bf8<16, 16>
template <class FloatC>
__device__ static void Run(const bf8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx94__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
......@@ -487,7 +491,7 @@ struct intrin_mfma_f32_32x32x16f8bf8<32, 32>
template <class FloatC>
__device__ static void Run(const f8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx94__)
reg_c.template AsType<float16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
bit_cast<long>(reg_a),
......@@ -519,7 +523,7 @@ struct intrin_mfma_f32_16x16x32f8bf8<16, 16>
template <class FloatC>
__device__ static void Run(const f8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx94__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
......@@ -550,7 +554,7 @@ struct intrin_mfma_f32_32x32x16bf8f8<32, 32>
template <class FloatC>
__device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx94__)
reg_c.template AsType<float16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
bit_cast<long>(reg_a),
......@@ -582,7 +586,7 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16>
template <class FloatC>
__device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx94__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
......
......@@ -189,6 +189,7 @@ struct vector_type<T, 1>
}
};
int static err = 0;
template <typename T>
struct vector_type<T, 2>
{
......@@ -221,6 +222,10 @@ struct vector_type<T, 2>
{
return data_.d2x1_;
}
else
{
return err;
}
}
template <typename X>
......@@ -236,6 +241,10 @@ struct vector_type<T, 2>
{
return data_.d2x1_;
}
else
{
return err;
}
}
};
......@@ -278,6 +287,10 @@ struct vector_type<T, 4>
{
return data_.d4x1_;
}
else
{
return err;
}
}
template <typename X>
......@@ -298,6 +311,10 @@ struct vector_type<T, 4>
{
return data_.d4x1_;
}
else
{
return err;
}
}
};
......@@ -347,6 +364,10 @@ struct vector_type<T, 8>
{
return data_.d8x1_;
}
else
{
return err;
}
}
template <typename X>
......@@ -372,6 +393,10 @@ struct vector_type<T, 8>
{
return data_.d8x1_;
}
else
{
return err;
}
}
};
......@@ -428,6 +453,10 @@ struct vector_type<T, 16>
{
return data_.d16x1_;
}
else
{
return err;
}
}
template <typename X>
......@@ -458,6 +487,10 @@ struct vector_type<T, 16>
{
return data_.d16x1_;
}
else
{
return err;
}
}
};
......@@ -520,6 +553,10 @@ struct vector_type<T, 32>
{
return data_.d32x1_;
}
else
{
return err;
}
}
template <typename X>
......@@ -554,6 +591,10 @@ struct vector_type<T, 32>
{
return data_.d32x1_;
}
else
{
return err;
}
}
};
......@@ -623,6 +664,10 @@ struct vector_type<T, 64>
{
return data_.d64x1_;
}
else
{
return err;
}
}
template <typename X>
......@@ -662,6 +707,10 @@ struct vector_type<T, 64>
{
return data_.d64x1_;
}
else
{
return err;
}
}
};
......@@ -737,6 +786,10 @@ struct vector_type<T, 128>
{
return data_.d128x1_;
}
else
{
return err;
}
}
template <typename X>
......@@ -780,6 +833,10 @@ struct vector_type<T, 128>
{
return data_.d128x1_;
}
else
{
return err;
}
}
};
......@@ -861,6 +918,10 @@ struct vector_type<T, 256>
{
return data_.d256x1_;
}
else
{
return err;
}
}
template <typename X>
......@@ -908,6 +969,10 @@ struct vector_type<T, 256>
{
return data_.d256x1_;
}
else
{
return err;
}
}
};
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -19,6 +19,12 @@ struct is_known_at_compile_time<index_t>
static constexpr bool value = false;
};
template <>
struct is_known_at_compile_time<unsigned int>
{
static constexpr bool value = false;
};
template <>
struct is_known_at_compile_time<long_index_t>
{
......
......@@ -5,6 +5,7 @@
#include "functional4.hpp"
#include "tuple.hpp"
#include "is_detected.hpp"
namespace ck {
......@@ -33,6 +34,28 @@ __host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple<X&...>&
ty);
}
template <typename... X, typename... Y>
__host__ __device__ constexpr auto concat_tuple(const Tuple<X...>& tx, const Tuple<Y...>& ty)
{
return unpack2(
[&](auto... zs) { return Tuple<decltype(zs)...>{std::forward<decltype(zs)>(zs)...}; },
tx,
ty);
}
// Support any number of tuples to concat (also 1)
template <typename... X>
__host__ __device__ constexpr auto concat_tuple(const Tuple<X...>& tx)
{
return tx;
}
template <typename... X, typename... Tuples>
__host__ __device__ constexpr auto concat_tuple(const Tuple<X...>& tx, const Tuples&... tuples)
{
return concat_tuple(tx, concat_tuple(tuples...));
}
namespace detail {
template <typename F, typename X, index_t... Is>
......@@ -78,4 +101,92 @@ __host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y,
f, x, y, z, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
}
// By default unroll to the flatten
template <index_t Depth = 0, index_t MaxDepth = -1>
__host__ __device__ constexpr auto UnrollNestedTuple(const Tuple<>& element)
{
return element;
}
template <index_t Depth = 0, index_t MaxDepth = -1, typename T>
__host__ __device__ constexpr auto UnrollNestedTuple(const T& element)
{
return make_tuple(element);
}
template <index_t Depth = 0, index_t MaxDepth = -1, typename... Ts>
__host__ __device__ constexpr auto UnrollNestedTuple(const Tuple<Ts...>& tuple)
{
if constexpr(Depth == MaxDepth)
{
return tuple;
}
else
{
return unpack(
[&](auto&&... ts) {
return concat_tuple(UnrollNestedTuple<Depth + 1, MaxDepth>(ts)...);
},
tuple);
}
}
template <typename... Ts>
__host__ __device__ constexpr auto TupleReverse(const Tuple<Ts...>& tuple)
{
return generate_tuple(
[&](auto i) {
using Idx = Number<Tuple<Ts...>::Size() - i - 1>;
return tuple.At(Idx{});
},
Number<Tuple<Ts...>::Size()>{});
}
// Reduce tuple values in specific range using Function
template <index_t Idx, index_t End, typename F, typename... Ts>
__host__ __device__ constexpr auto TupleReduce(F&& f, const Tuple<Ts...>& tuple)
{
static_assert(Idx < End, "Wrong parameters for TupleReduce");
if constexpr(Idx + 1 == End)
{
return tuple.At(Number<Idx>{});
}
else
{
return f(tuple.At(Number<Idx>{}), TupleReduce<Idx + 1, End>(f, tuple));
}
}
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());
template <typename... Ts>
__host__ __device__ constexpr auto IsNestedTuple(const Tuple<Ts...>&)
{
return (is_detected<is_tuple, Ts>::value || ...);
}
template <index_t depth = 0, typename T>
__host__ __device__ constexpr auto TupleDepth(const T&)
{
return depth;
}
template <index_t depth = 0, typename... Ts>
__host__ __device__ constexpr auto TupleDepth(const Tuple<Ts...>&)
{
return math::max(TupleDepth<depth + 1>(Ts{})...);
}
template <index_t from, index_t to, typename... Ts>
__host__ __device__ constexpr auto TupleSlice(const Tuple<Ts...>& tuple)
{
return generate_tuple(
[&](auto i) {
using Idx = Number<from + i>;
return tuple.At(Idx{});
},
Number<to - from>{});
}
} // namespace ck
......@@ -8,6 +8,10 @@
#include "ck/utility/random_gen.hpp"
namespace ck {
// Define the common macro for MI300 models
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#define __gfx94__
#endif
// Convert X to Y, both X and Y are non-const data types.
template <typename Y,
......@@ -105,7 +109,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
{
constexpr int seed = 42;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx94__)
float max_fp8 = 240.0f;
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
union
......@@ -133,7 +137,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
template <>
inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx94__)
// convert to float and use native converion
return f8_convert_sr<f8_t>(type_convert<float>(x));
#else
......@@ -154,7 +158,7 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
{
constexpr int seed = 42;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx94__)
union
{
float fval;
......@@ -180,9 +184,9 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
template <>
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx94__)
// convert to float and use native converion
return f8_convert_sr<f8_t>(type_convert<float>(x));
return f8_convert_sr<bf8_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
......@@ -203,7 +207,7 @@ __host__ __device__ constexpr Y f8_convert_rne(X x);
template <>
inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx94__)
float max_fp8 = 240.0f;
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
union
......@@ -232,7 +236,7 @@ inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
template <>
inline __host__ __device__ f8_t f8_convert_rne<f8_t, half_t>(half_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx94__)
// convert to float and use native converion
return f8_convert_rne<f8_t>(type_convert<float>(x));
#else
......@@ -250,7 +254,7 @@ inline __host__ __device__ f8_t f8_convert_rne<f8_t, half_t>(half_t x)
template <>
inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, float>(float x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx94__)
union
{
float fval;
......@@ -277,7 +281,7 @@ inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, float>(float x)
template <>
inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, half_t>(half_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx94__)
// convert to float and use native converion
return f8_convert_rne<bf8_t>(type_convert<float>(x));
#else
......@@ -295,7 +299,7 @@ inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, half_t>(half_t x)
template <>
inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
{
#if defined CK_USE_SR_F8_CONVERSION
#if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<f8_t>(x);
#else
return f8_convert_rne<f8_t>(x);
......@@ -306,7 +310,7 @@ inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
template <>
inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx94__)
float fval;
uint32_t i32val = static_cast<uint32_t>(x);
fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0);
......@@ -321,7 +325,7 @@ inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
template <>
inline __host__ __device__ float2_t type_convert<float2_t, f8x2_t>(f8x2_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx94__)
const auto i16val = bit_cast<uint16_t>(x);
return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, 0);
#else
......@@ -352,10 +356,10 @@ inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x)
template <>
inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
{
#if defined CK_USE_SR_F8_CONVERSION
#if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<f8_t>(x);
#else
return f8_convert_nre<f8_t>(x);
return f8_convert_rne<f8_t>(x);
#endif
}
......@@ -363,7 +367,7 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
template <>
inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx94__)
// use native conversion to float and convert to fp16
return type_convert<half_t>(type_convert<float>(x));
#else
......@@ -376,7 +380,7 @@ inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
template <>
inline __host__ __device__ bf8_t type_convert<bf8_t, float>(float x)
{
#if defined CK_USE_SR_F8_CONVERSION
#if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<bf8_t>(x);
#else
return f8_convert_rne<bf8_t>(x);
......@@ -387,7 +391,7 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, float>(float x)
template <>
inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx94__)
float fval;
uint32_t i32val = static_cast<uint32_t>(x);
fval = __builtin_amdgcn_cvt_f32_bf8(i32val, 0);
......@@ -403,7 +407,7 @@ inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x)
template <>
inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
{
#if defined CK_USE_SR_F8_CONVERSION
#if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<bf8_t>(x);
#else
return f8_convert_rne<bf8_t>(x);
......@@ -414,7 +418,7 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
template <>
inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx94__)
// use native conversion to float and convert to fp16
return type_convert<half_t>(type_convert<float>(x));
#else
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/wrapper/utils/layout_utils.hpp"
namespace ck {
namespace wrapper {
/**
* \brief Layout wrapper that performs the tensor descriptor logic.
*
* \tparam Shape Tuple of Number<> (for compile-time layout) or index_t
* (dynamic layout). It is possible to pass nested shapes
* (e.g. ((4, 2), 2)), nested dimensions are merged.
* \tparam UnrolledDescriptorType Tensor descriptor for unnested shape dims.
*/
template <typename Shape, typename UnrolledDescriptorType>
struct Layout
{
private:
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
/**
* \brief Generate default indices tuple (idx with all merged nested shapes)
*
* \param shape Shape to align.
* \return Multi idx tuple with zeros.
*/
template <typename... Ts>
__host__ __device__ constexpr static auto
GenerateDefaultIdxsTuple([[maybe_unused]] const Tuple<Ts...>& shape)
{
return generate_tuple(
[&](auto) {
if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
{
// runtime layout
return index_t(0);
}
else
{
// compiletime layout
return I0;
}
},
Number<Tuple<Ts...>::Size()>{});
}
/**
* \brief Generate lower dims in compile-time for the Merge transform using
* provided type. If element of nested Tuple<Ts...> is also a tuple, then
* merge (generate sequence for merge). If tuple is element, then pass
* through (sequence with one element).
*
* \param shape Shape to align.
* \return LowerDims for MergeTrasform.
*/
template <typename Idx, typename... Ts>
__host__ __device__ constexpr static auto
GenerateLowerDim([[maybe_unused]] const Tuple<Ts...>& shape)
{
if constexpr(Idx::value == 0)
{
if constexpr(is_detected<is_tuple, tuple_element_t<Idx::value, Tuple<Ts...>>>::value)
{
// Return Sequence for the first tuple
constexpr index_t merge_nelems = decltype(UnrollNestedTuple(
tuple_element_t<Idx::value, Tuple<Ts...>>{}))::Size();
using LowerDimsSequence =
typename arithmetic_sequence_gen<0, merge_nelems, 1>::type;
return LowerDimsSequence::Reverse();
}
else
{
// Return first element
return Sequence<0>{};
}
}
else
{
// Get previous element using recurence (in compile-time)
using PreviousSeqT = decltype(GenerateLowerDim<Number<Idx::value - 1>>(Tuple<Ts...>{}));
const auto next_seq_val = PreviousSeqT::At(I0) + 1;
if constexpr(is_detected<is_tuple, tuple_element_t<Idx::value, Tuple<Ts...>>>::value)
{
constexpr index_t merge_nelems = decltype(UnrollNestedTuple(
tuple_element_t<Idx::value, Tuple<Ts...>>{}))::Size();
using LowerDimsSequence =
typename arithmetic_sequence_gen<next_seq_val, next_seq_val + merge_nelems, 1>::
type;
return LowerDimsSequence::Reverse();
}
else
{
return Sequence<next_seq_val>{};
}
}
}
/**
* \brief Iterate over the nested tuples in the shape.
* Unroll nested tuples to align Tuple<ShapeDims...> to Tuple<IdxDims...>
* Example idx: (1, 1), 1, 1
* Example shape: (2, (2, 2)), 2, (2, 2)
* Unrolled shape: 2, (2, 2), 2, (2, 2)
*
* \param shape Layout shape.
* \param idx Idx to align.
* \return Algined shape.
*/
template <typename... ShapeDims, typename... IdxDims>
__host__ __device__ constexpr static auto AlignShapeToIdx(const Tuple<ShapeDims...>& shape,
const Tuple<IdxDims...>& idx)
{
if constexpr(!IsNestedTuple(Tuple<IdxDims...>{}))
{
// Index unrolled to flatten, return shape
return shape;
}
else
{
// Iterate over shape tuple elements:
// 1. If corresponding idx element is tuple then return (will be unrolled)
// 2. If no, pack in tuple. It will be restored during unroll.
auto aligned_shape = generate_tuple(
[&](auto i) {
if constexpr(is_detected<is_tuple,
tuple_element_t<i, Tuple<IdxDims...>>>::value)
{
return shape.At(i);
}
else
{
return make_tuple(shape.At(i));
}
},
Number<Tuple<IdxDims...>::Size()>{});
// Unroll and process next step
return AlignShapeToIdx(UnrollNestedTuple<0, 1>(aligned_shape),
UnrollNestedTuple<0, 1>(idx));
}
}
/**
* \brief Merge descriptor to 1D.
*
* \param shape Layout shape.
* \param desc Descriptor to merge.
* \return 1D descriptor.
*/
template <typename... ShapeDims, typename DescriptorToMerge>
__host__ __device__ constexpr static auto MakeMerge1d(const Tuple<ShapeDims...>& shape,
const DescriptorToMerge& desc)
{
// Reverse each element in tuple
const auto merge_elems = TupleReverse(UnrollNestedTuple(shape));
// Generate reverted indexes (column major traverse)
using MergeElemsSequence = typename arithmetic_sequence_gen<0, merge_elems.Size(), 1>::type;
const auto lower_dims = make_tuple(MergeElemsSequence::Reverse());
const auto upper_dims = make_tuple(Sequence<0>{});
// Merge to 1d
if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
{
return transform_tensor_descriptor(
desc, make_tuple(make_merge_transform(merge_elems)), lower_dims, upper_dims);
}
else
{
// If the descriptor is known at the compilation time,
// use `make_merge_transform_v1_carry_check` because it doesn't use
// memcpy.
return transform_tensor_descriptor(
desc,
make_tuple(make_merge_transform_v1_carry_check(merge_elems)),
lower_dims,
upper_dims);
}
}
/**
* \brief Merge nested shape dims when corresponding index is also merged.
* Input desc shape: 2, 2, 2, 2, 2, 2
* Example idx: 1, 1, 1, (1, 1)
* Example shape: 2, (2, 2), 2, (2, 2)
* Merged shape: 2, 4, 2, 2, 2
*
* \param shape Layout shape.
* \param idxs Indexes to align descriptor.
* \param desc Descriptor to merge.
* \return Aligned descriptor to idx.
*/
template <typename... ShapeDims, typename... IdxDims, typename DescriptorToMerge>
__host__ __device__ constexpr static auto
CreateMergedDescriptor(const Tuple<ShapeDims...>& shape,
[[maybe_unused]] const Tuple<IdxDims...>& idxs,
DescriptorToMerge& desc)
{
const auto transforms = generate_tuple(
[&](auto i) {
// Compare Idx with shape
if constexpr(is_detected<is_tuple,
tuple_element_t<i, Tuple<ShapeDims...>>>::value &&
!is_detected<is_tuple, tuple_element_t<i, Tuple<IdxDims...>>>::value)
{
// If shape element is tuple and idx element is Number, then merge
// Unroll and reverse tuple to traverse column-major
const auto merge_elems = TupleReverse(UnrollNestedTuple(shape.At(i)));
if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
{
return make_merge_transform(merge_elems);
}
else
{
// If the descriptor is known at the compilation time,
// use `make_merge_transform_v1_carry_check` because
// it doesn't use memcpy.
return make_merge_transform_v1_carry_check(merge_elems);
}
}
else
{
// If shape element is integer and idx element is tuple, passed idx is wrong
static_assert(
!(!is_detected<is_tuple, tuple_element_t<i, Tuple<ShapeDims...>>>::value &&
is_detected<is_tuple, tuple_element_t<i, Tuple<IdxDims...>>>::value),
"Wrong Idx for layout()");
// If shape element has the same type as idx element, then pass through
return make_pass_through_transform(shape.At(i));
}
},
Number<Tuple<ShapeDims...>::Size()>{});
const auto lower_dims =
generate_tuple([&](auto i) { return GenerateLowerDim<Number<i>>(shape); },
Number<Tuple<ShapeDims...>::Size()>{});
const auto upper_dims = generate_tuple([&](auto i) { return Sequence<i.value>{}; },
Number<Tuple<ShapeDims...>::Size()>{});
return transform_tensor_descriptor(desc, transforms, lower_dims, upper_dims);
}
using Descriptor1dType =
remove_cvref_t<decltype(MakeMerge1d(Shape{}, UnrolledDescriptorType{}))>;
using DefaultIdxsTupleType = remove_cvref_t<decltype(GenerateDefaultIdxsTuple(Shape{}))>;
public:
using LayoutShape = Shape;
using LayoutUnrolledDescriptorType = UnrolledDescriptorType;
/**
* \brief Transform descriptor to align to passed indexes.
*
* \param shape Layout shape.
* \param idxs Indexes to align descriptor.
* \param naive_descriptor Descriptor to merge.
* \return Aligned descriptor to idx.
*/
template <typename... ShapeDims, typename... IdxDims>
__host__ __device__ constexpr static auto
TransformDesc(const Tuple<ShapeDims...>& shape,
const Tuple<IdxDims...>& idxs,
const UnrolledDescriptorType& naive_descriptor)
{
if constexpr(Tuple<IdxDims...>::Size() == I1)
{
// 1d idx path
return MakeMerge1d(shape, naive_descriptor);
}
else
{
// Merge nested shape dims
// Example idx: (1, 1), 1, 1
// Example shape: (2, (2, 2)), 2, (2, 2)
// Merged shape: (2, 4), 2, 4
static_assert(Tuple<ShapeDims...>::Size() == Tuple<IdxDims...>::Size(),
"Idx rank and Shape rank must be the same (except 1d).");
// Unroll while IdxDims is nested
const auto aligned_shape = AlignShapeToIdx(shape, idxs);
// Transform correct form of shape
return CreateMergedDescriptor(aligned_shape, UnrollNestedTuple(idxs), naive_descriptor);
}
}
using MergedNestsDescriptorType = remove_cvref_t<decltype(TransformDesc(
Shape{}, DefaultIdxsTupleType{}, UnrolledDescriptorType{}))>;
__host__ __device__ constexpr auto GetElementSpaceSize() const
{
return unrolled_descriptor_.GetElementSpaceSize();
}
__host__ __device__ Layout() = delete;
/**
* \brief Layout constructor.
*
* \param shape Shape for layout.
* \param unnested_descriptor Descriptor
*/
__host__ __device__ constexpr Layout(const Shape& shape,
const UnrolledDescriptorType& unnested_descriptor)
: unrolled_descriptor_(unnested_descriptor), shape_(shape)
{
// Construct if runtime mode
if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
{
descriptor_1d_ = MakeMerge1d(shape_, unrolled_descriptor_);
merged_nests_descriptor_ =
TransformDesc(shape_, DefaultIdxsTupleType{}, unrolled_descriptor_);
}
}
/**
* \brief Returns real offset to element in runtime.
*
* \tparam Idxs Tuple of indexes.
* \return Calculated offset.
*/
template <typename Idxs>
__host__ __device__ constexpr index_t operator()() const
{
static_assert(remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime(),
"Compiletime operator used on runtime layout.");
using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}, UnrolledDescriptorType{}));
using UnrolledIdx = decltype(UnrollNestedTuple(Idxs{}));
return TransformedDesc{}.CalculateOffset(UnrolledIdx{});
}
/**
* \brief Returns real offset to element in compile time.
*
* \param Idx Tuple of indexes.
* \return Calculated offset.
*/
template <typename... Ts>
__host__ __device__ index_t operator()(const Tuple<Ts...>& Idx) const
{
if constexpr(!IsNestedTuple(Tuple<Ts...>{}) && Tuple<Ts...>::Size() == 1)
{
// if 1d access
return descriptor_1d_.CalculateOffset(Idx);
}
else if constexpr(!IsNestedTuple(Tuple<Ts...>{}) && Tuple<Ts...>::Size() == Shape::Size())
{
// if Shape::Size() access (merged nested shapes)
return merged_nests_descriptor_.CalculateOffset(UnrollNestedTuple(Idx));
}
else
{
// Custom index, need to transform descriptor
const auto transformed_desc = TransformDesc(shape_, Idx, unrolled_descriptor_);
return transformed_desc.CalculateOffset(UnrollNestedTuple(Idx));
}
}
/**
* \brief Length getter (product if tuple).
*
* \tparam IDim Tuple of indexes or index.
* \return Calculated size.
*/
template <index_t IDim>
__host__ __device__ constexpr auto GetLength() const
{
const auto elem = shape_.At(Number<IDim>{});
if constexpr(is_detected<is_tuple, tuple_element_t<IDim, Shape>>::value)
{
const auto unrolled_element = UnrollNestedTuple(elem);
return TupleReduce<I0.value, unrolled_element.Size()>(
[](auto x, auto y) { return x * y; }, unrolled_element);
}
else
{
return elem;
}
}
/**
* \brief Layout size getter (product of shape).
*
* \return Calculated size.
*/
__host__ __device__ constexpr auto GetLengths() const
{
const auto unrolled_shape = UnrollNestedTuple(shape_);
return TupleReduce<I0.value, unrolled_shape.Size()>([](auto x, auto y) { return x * y; },
unrolled_shape);
}
/**
* \brief Shape getter.
*
* \return Shape.
*/
__host__ __device__ constexpr const Shape& GetShape() const { return shape_; }
/**
* \brief Get default lengths (tuple filled with Shape length elements).
*
* \return Default lengths.
*/
__host__ __device__ constexpr auto GetDefaultLengthsTuple() const
{
return generate_tuple([&](auto i) { return GetLength<i>(); }, Number<Shape::Size()>{});
}
/**
* \brief Get default start idx (tuple filled with 0s of the same size as Shape).
*
* \return Default start idx.
*/
__host__ __device__ constexpr auto GetDefaultStartIdxs() const
{
return GenerateDefaultIdxsTuple(shape_);
}
/**
* \brief Get descriptor with all nested dimensions merged.
* Example, shape: ((2, 2), 2)
* Descriptor lengths: (4, 2)
*
* \note The size of merged descriptor is the same as Layout's shape.
*
* \return Merged nests descriptor.
*/
__host__ __device__ constexpr const MergedNestsDescriptorType&
GetMergedNestingDescriptor() const
{
return merged_nests_descriptor_;
}
/**
* \brief Get descriptor with all dimensions are merged (1D).
* Example, shape: ((2, 2), 2)
* Descriptor lengths: (8)
*
* \return 1D descriptor.
*/
__host__ __device__ constexpr const Descriptor1dType& Get1DDescriptor() const
{
return descriptor_1d_;
}
/**
* \brief Get unnested descriptor (with unrolled dims)
* Example, shape: ((2, 2), 2)
* Descriptor lengths: (2, 2, 2)
*
* \return Flattened descriptor.
*/
__host__ __device__ constexpr const UnrolledDescriptorType& GetUnrolledDescriptor() const
{
return unrolled_descriptor_;
}
private:
// All dimensions are unrolled
UnrolledDescriptorType unrolled_descriptor_;
// 1D descriptor
Descriptor1dType descriptor_1d_;
// All nesting are merged
MergedNestsDescriptorType merged_nests_descriptor_;
// Example, shape: ((2, 2), 2)
// UnrolledDescriptorType lengths: (2, 2, 2)
// Descriptor1dType lengths: (8)
// MergedNestsDescriptorType lengths: (4, 2)
const Shape shape_;
};
} // namespace wrapper
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/wrapper/utils/tensor_utils.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_description/tensor_space_filling_curve.hpp"
namespace ck {
namespace wrapper {
/**
* \brief Perform optimized copy between two tensors partitions (threadwise copy).
* Tensors must have the same size.
*
* \tparam DimAccessOrderTuple Tuple with dimension access order.
* \tparam VectorDim Dimension for vectorized read and write.
* \tparam ScalarPerVector Number of scalar per vectorized read and write.
* \param src_tensor Source tensor.
* \param dst_tensor Destination tensor.
*/
template <typename DimAccessOrderTuple,
index_t VectorDim,
index_t ScalarPerVector,
typename SrcTensorType,
typename DstTensorType>
__device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor)
{
static_assert(is_detected<is_tuple, DimAccessOrderTuple>::value);
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
const auto& in_grid_desc = layout(src_tensor).GetUnrolledDescriptor();
const auto& out_grid_desc = layout(dst_tensor).GetUnrolledDescriptor();
using SrcShapeType = remove_cvref_t<decltype(shape(src_tensor))>;
constexpr index_t num_dims = SrcShapeType::Size();
constexpr auto thread_slice_lengths =
generate_sequence_v2([](auto I) { return size(SrcShapeType{}.At(I)); }, Number<num_dims>{});
constexpr auto dim_access_order = generate_sequence_v2(
[](auto I) { return DimAccessOrderTuple{}.At(I); }, Number<num_dims>{});
if constexpr(SrcTensorType::IsDynamicBuffer && DstTensorType::IsDynamicBuffer)
{
// Perform a copy between DynamicBuffers
auto transfer = ThreadwiseTensorSliceTransfer_v7<
Tuple<typename SrcTensorType::TensorElementType>,
Tuple<typename DstTensorType::TensorElementType>,
decltype(tie(in_grid_desc)),
decltype(tie(out_grid_desc)),
tensor_operation::element_wise::PassThrough,
Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
decltype(thread_slice_lengths),
decltype(dim_access_order),
VectorDim,
ScalarPerVector,
Sequence<false>,
Sequence<false>>{in_grid_desc,
make_tuple(src_tensor.GetMultiIdxOffsets()),
out_grid_desc,
make_tuple(dst_tensor.GetMultiIdxOffsets()),
tensor_operation::element_wise::PassThrough{}};
transfer.Run(tie(in_grid_desc),
tie(src_tensor.GetBuffer()),
tie(out_grid_desc),
tie(dst_tensor.GetBuffer()));
}
else if constexpr(!SrcTensorType::IsDynamicBuffer && DstTensorType::IsDynamicBuffer)
{
// Perform copy from StaticBuffer to DynamicBuffer
const auto src_slice_origin_idxs =
generate_tuple([&](auto) { return I0; }, Number<num_dims>{});
auto transfer =
ThreadwiseTensorSliceTransfer_v1r3<typename SrcTensorType::TensorElementType,
typename DstTensorType::TensorElementType,
remove_cvref_t<decltype(in_grid_desc)>,
remove_cvref_t<decltype(out_grid_desc)>,
tensor_operation::element_wise::PassThrough,
decltype(thread_slice_lengths),
decltype(dim_access_order),
VectorDim,
ScalarPerVector,
InMemoryDataOperationEnum::Set,
I1,
true>{out_grid_desc,
dst_tensor.GetMultiIdxOffsets(),
tensor_operation::element_wise::PassThrough{}};
transfer.Run(in_grid_desc,
src_slice_origin_idxs,
src_tensor.GetBuffer(),
out_grid_desc,
dst_tensor.GetBuffer());
}
else if constexpr(SrcTensorType::IsDynamicBuffer && !DstTensorType::IsDynamicBuffer)
{
// Perform copy from DynamicBuffer to StaticBuffer
const auto src_dst_slice_origin =
generate_tuple([&](auto) { return I0; }, Number<num_dims>{});
constexpr auto src_vector_tensor_lengths = generate_sequence_v2(
[&](auto I) {
if constexpr(I == VectorDim)
{
return Number<ScalarPerVector>{};
}
else
{
return I1;
}
},
Number<num_dims>{});
auto transfer =
ThreadwiseTensorSliceTransfer_v4r1<typename SrcTensorType::TensorElementType,
typename DstTensorType::TensorElementType,
remove_cvref_t<decltype(in_grid_desc)>,
remove_cvref_t<decltype(out_grid_desc)>,
decltype(thread_slice_lengths),
decltype(dim_access_order),
decltype(src_vector_tensor_lengths),
decltype(dim_access_order)>{
src_tensor.GetMultiIdxOffsets()};
transfer.Run(in_grid_desc,
src_dst_slice_origin,
src_tensor.GetBuffer(),
out_grid_desc,
src_dst_slice_origin,
dst_tensor.GetBuffer());
}
else
{
// Perform copy between StaticBuffers
static_for<0, SrcShapeType::Size(), 1>{}([&](auto i) { dst_tensor(i) = src_tensor(i); });
}
}
/**
* \brief Perform generic copy between two tensors partitions (threadwise copy).
* Tensors must have the same size.
*
* \param src_tensor Source tensor.
* \param dst_tensor Destination tensor.
*/
template <typename SrcTensorType, typename DstTensorType>
__host__ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor)
{
// Generate default params
using SrcShapeType = remove_cvref_t<decltype(shape(src_tensor))>;
constexpr index_t num_dims = SrcShapeType::Size();
// Incrementing dims 0, 1, 2 ... num_dims - 1
constexpr auto dim_access_order_tuple =
generate_tuple([](auto i) { return Number<i>{}; }, Number<num_dims>{});
constexpr index_t vector_dim = num_dims - 1;
constexpr index_t scalar_per_vector = 1;
copy<decltype(dim_access_order_tuple), vector_dim, scalar_per_vector>(src_tensor, dst_tensor);
}
/**
* \brief Perform optimized blockwise copy between two tensors. Tensors must have the
* same size.
*
* \note At now Vgpr and Sgpr are not supported.
*
* \tparam DimAccessOrderTuple Tuple with dimension access order.
* \tparam VectorDim Dimension for vectorize read and write.
* \tparam ScalarPerVector Number of scalar per vectorize read and write.
* \param src_tensor Source tensor.
* \param dst_tensor Destination tensor.
* \param thread_layout Thread layout per each dimension for copy.
*/
template <typename DimAccessOrderTuple,
index_t VectorDim,
index_t ScalarPerVector,
typename SrcTensorType,
typename DstTensorType,
typename ThreadLayoutTuple>
__device__ void blockwise_copy(const SrcTensorType& src_tensor,
DstTensorType& dst_tensor,
[[maybe_unused]] ThreadLayoutTuple& thread_layout)
{
static_assert(SrcTensorType::IsDynamicBuffer && DstTensorType::IsDynamicBuffer);
static_assert(is_detected<is_tuple, DimAccessOrderTuple>::value);
const auto& in_grid_desc = layout(src_tensor).GetUnrolledDescriptor();
const auto& out_grid_desc = layout(dst_tensor).GetUnrolledDescriptor();
using SrcShapeType = remove_cvref_t<decltype(shape(src_tensor))>;
constexpr index_t num_dims = SrcShapeType::Size();
constexpr auto tile_lengths_seq =
generate_sequence_v2([](auto I) { return size(SrcShapeType{}.At(I)); }, Number<num_dims>{});
constexpr auto thread_layout_seq = generate_sequence_v2(
[](auto I) { return size(ThreadLayoutTuple{}.At(I)); }, Number<num_dims>{});
constexpr auto dim_access_order = generate_sequence_v2(
[](auto I) { return DimAccessOrderTuple{}.At(I); }, Number<num_dims>{});
using ThisThreadBlock = ThisThreadBlock<size(ThreadLayoutTuple{})>;
// Perform copy between DynamicBuffers
auto transfer = ThreadGroupTensorSliceTransfer_v7<
ThisThreadBlock,
Tuple<typename SrcTensorType::TensorElementType>,
Tuple<typename DstTensorType::TensorElementType>,
decltype(tie(in_grid_desc)),
decltype(tie(out_grid_desc)),
tensor_operation::element_wise::PassThrough,
Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
std::remove_const_t<decltype(tile_lengths_seq)>,
std::remove_const_t<decltype(thread_layout_seq)>,
std::remove_const_t<decltype(dim_access_order)>,
std::remove_const_t<decltype(dim_access_order)>,
VectorDim,
ScalarPerVector,
Sequence<true>,
Sequence<true>>{in_grid_desc,
make_tuple(src_tensor.GetMultiIdxOffsets()),
out_grid_desc,
make_tuple(dst_tensor.GetMultiIdxOffsets()),
tensor_operation::element_wise::PassThrough{}};
transfer.Run(tie(in_grid_desc),
tie(src_tensor.GetBuffer()),
tie(out_grid_desc),
tie(dst_tensor.GetBuffer()));
}
} // namespace wrapper
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/wrapper/utils/tensor_utils.hpp"
#include "ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
namespace ck {
namespace wrapper {
namespace {
namespace detail {
/**
* \brief Create block descriptor (K0, MPerBlock or NPerBlock, K1).
*
*
* \tparam K1 The number of K-dim elements that are packed together as a separate logical dimension.
* \tparam TileLayout Tensor data tile layout (M,K) or (N,K).
*
* \return Block descriptor (K0, MPerBlock or NPerBlock, K1)
*/
template <index_t K1, typename TileLayout>
__device__ constexpr auto GetBlockDescriptor()
{
using TileLayoutShape = typename TileLayout::LayoutShape;
using TileLayoutDescriptor = typename TileLayout::LayoutUnrolledDescriptorType;
constexpr auto K0PerBlock = Number<size<1>(TileLayoutShape{})>{} / Number<K1>{};
// MPerBlock or NPerBlock
constexpr auto Dim0 = Number<size<0>(TileLayoutShape{})>{};
constexpr auto a_block_desc_k0_m_k1 = transform_tensor_descriptor(
TileLayoutDescriptor{},
make_tuple(make_unmerge_transform(make_tuple(K0PerBlock, Number<K1>{})),
make_pass_through_transform(Dim0)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_block_desc_k0_m_k1;
}
} // namespace detail
} // namespace
/**
* \brief Perform blockwise gemm xdl on tensors stored in lds. Result will be
* stored in Vgpr register. A data layout must be (MPerBlock, KPerBlock) and B
* data layout must be (NPerBlock, KPerBlock).
*
* \note C output Vgpr register layout (8D):
* - MXdlPerWave - The number of MFMA instructions run by single wave in M
* dimension per tile.
* - NXdlPerWave - The number of MFMA instructions run by single wave in N
* dimension per tile.
* - MWave - Equals to 1 since this is for single wave.
* - NWave - Equals to 1 since this is for single wave.
* - NumGroupsPerBlock - Mfma instruction internal layout (depeneds on the
* instruction size).
* - NumInputsBlock - Mfma instruction internal layout (depeneds on the
* instruction size).
* - GroupSize - Mfma instruction internal layout (depeneds on the
* instruction size).
* - NumThreadsPerBlock - Mfma instruction internal layout (depeneds on the
* instruction size).
*
* \tparam DataType Input data types.
* \tparam BlockSize Tensor to pad.
* \tparam GemmTraits Traits of gemm xdl operation.
* \param a_local_tile_tensor A tensor in LDS memory for blockwise gemm
* (MPerBlock, KPerBlock) layout.
* \param b_local_tile_tensor B tensor in LDS memory for blockwise gemm
* (NPerBlock, KPerBlock) layout.
* \param c_reg_tensor C tensor VGPR memory for blockwise gemm.
*/
template <typename DataType,
index_t BlockSize,
typename GemmTraits,
typename ATensorType,
typename BTensorType,
typename CTensorType>
__device__ void blockwise_gemm_xdl(const ATensorType& a_local_tile_tensor,
const BTensorType& b_local_tile_tensor,
CTensorType& c_reg_tensor)
{
static_assert(ATensorType::TensorBufferAddressSpace == MemoryTypeEnum::Lds);
static_assert(BTensorType::TensorBufferAddressSpace == MemoryTypeEnum::Lds);
static_assert(CTensorType::TensorBufferAddressSpace == MemoryTypeEnum::Vgpr);
static_assert(is_same_v<DataType, typename ATensorType::TensorElementType>);
static_assert(is_same_v<DataType, typename BTensorType::TensorElementType>);
constexpr bool is_integer =
is_same_v<DataType, int8_t> || is_same_v<DataType, int16_t> || is_same_v<DataType, int32_t>;
using GemmAccDataType = std::conditional_t<is_integer, int32_t, float>;
using ATileLayout = remove_cvref_t<decltype(layout(a_local_tile_tensor))>;
using BTileLayout = remove_cvref_t<decltype(layout(b_local_tile_tensor))>;
using ABlockDesc_K0_M_K1_Type =
decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>());
using BBlockDesc_K0_N_K1_Type =
decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>());
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
DataType,
DataType,
GemmAccDataType,
ABlockDesc_K0_M_K1_Type,
BBlockDesc_K0_N_K1_Type,
GemmTraits::MPerXDL,
GemmTraits::NPerXDL,
GemmTraits::MXdlPerWave,
GemmTraits::NXdlPerWave,
GemmTraits::K1>
blockwise_gemm_xdl_op{};
blockwise_gemm_xdl_op.Run(
a_local_tile_tensor.GetBuffer(), b_local_tile_tensor.GetBuffer(), c_reg_tensor.GetBuffer());
}
/**
* \brief Create local partition per thread for C tensor.
*
* \note C output global memory layout (8D):
* - MXdlPerWave - The number of MFMA instructions run by single wave in M
* dimension.
* - NXdlPerWave - The number of MFMA instructions run by single wave in N
* dimension.
* - MWave - The number of waves in single tile M dimension per tile.
* - NWave - The number of waves in single tile N dimension per tile.
* - NumGroupsPerBlock - Mfma instruction internal layout (depeneds on the
* instruction size).
* - NumInputsBlock - Mfma instruction internal layout (depeneds on the
* instruction size).
* - GroupSize - Mfma instruction internal layout (depeneds on the
* instruction size).
* - NumThreadsPerBlock - Mfma instruction internal layout (depeneds on the
* instruction size).
*
* \tparam DataType Input data types.
* \tparam ATileLayout A tensor layout.
* \tparam BTileLayout B tensor layout.
* \tparam BlockSize Number of threads in block.
* \tparam GemmTraits Traits of gemm xdl operation.
* \param c_local_tile_tensor C tensor in LDS memory for blockwise gemm
* (MPerBlock, NPerBlock) layout.
*
* \return Partition c tensor for blockwise gemm.
*/
template <typename DataType,
typename ATileLayout,
typename BTileLayout,
index_t BlockSize,
typename GemmTraits,
typename CTensorType>
__host__ __device__ constexpr auto
make_blockwise_gemm_xdl_c_local_partition(CTensorType& c_local_tile_tensor)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{};
constexpr bool is_integer =
is_same_v<DataType, int8_t> || is_same_v<DataType, int16_t> || is_same_v<DataType, int32_t>;
using GemmAccDataType = std::conditional_t<is_integer, int32_t, float>;
using ABlockDesc_K0_M_K1_Type =
decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>());
using BBlockDesc_K0_N_K1_Type =
decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>());
using BlockwiseGemmXdlops =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
DataType,
DataType,
GemmAccDataType,
ABlockDesc_K0_M_K1_Type,
BBlockDesc_K0_N_K1_Type,
GemmTraits::MPerXDL,
GemmTraits::NPerXDL,
GemmTraits::MXdlPerWave,
GemmTraits::NXdlPerWave,
GemmTraits::K1>;
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
BlockwiseGemmXdlops::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2);
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4);
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5);
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);
// Calculate offset on grid
const auto c_thread_mtx_on_block =
BlockwiseGemmXdlops::CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_grid =
c_local_tile_tensor.GetMultiIdxOffsets()[I0] + c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_grid =
c_local_tile_tensor.GetMultiIdxOffsets()[I1] + c_thread_mtx_on_block[I1];
const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_grid_idx =
m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_grid));
const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor =
make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_grid_idx =
n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_grid));
// Create partition shape based on descriptor dims.
const auto partition_shape = make_tuple(M0, N0, I1, I1, M2, I1, M4, I1);
const auto partition_desc = BlockwiseGemmXdlops::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(
layout(c_local_tile_tensor).GetUnrolledDescriptor());
const auto partition_layout =
Layout<remove_reference_t<decltype(partition_shape)>, decltype(partition_desc)>(
partition_shape, partition_desc);
auto partition_tensor = make_tensor<CTensorType::TensorBufferAddressSpace>(
c_local_tile_tensor.GetPointer(), partition_layout);
partition_tensor.SetMultiIdxOffset(make_multi_index(m_thread_data_on_grid_idx[I0],
n_thread_data_on_grid_idx[I0],
m_thread_data_on_grid_idx[I1],
n_thread_data_on_grid_idx[I1],
m_thread_data_on_grid_idx[I2],
m_thread_data_on_grid_idx[I3],
m_thread_data_on_grid_idx[I4],
n_thread_data_on_grid_idx[I2]));
return partition_tensor;
}
/**
* \brief Create local partition per thread for C tensor.
*
* \note C output Vgpr register layout (8D):
* - MXdlPerWave - The number of MFMA instructions run by single wave in M
* dimension per tile.
* - NXdlPerWave - The number of MFMA instructions run by single wave in N
* dimension per tile.
* - MWave - Equals to 1 since this is for single wave.
* - NWave - Equals to 1 since this is for single wave.
* - NumGroupsPerBlock - Mfma instruction internal layout (depeneds on the
* instruction size).
* - NumInputsBlock - Mfma instruction internal layout (depeneds on the
* instruction size).
* - GroupSize - Mfma instruction internal layout (depeneds on the
* instruction size).
* - NumThreadsPerBlock - Mfma instruction internal layout (depeneds on the
* instruction size).
*
* \tparam DataType Input data types.
* \tparam ATileLayout A tensor layout.
* \tparam BTileLayout B tensor layout.
* \tparam BlockSize Number of threads in block.
* \tparam GemmTraits Traits of gemm xdl operation.
*
* \return Vgpr c tensor for blockwise gemm.
*/
template <typename DataType,
typename ATileLayout,
typename BTileLayout,
index_t BlockSize,
typename GemmTraits>
__host__ __device__ constexpr auto make_blockwise_gemm_xdl_c_vgpr()
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{};
constexpr bool is_integer =
is_same_v<DataType, int8_t> || is_same_v<DataType, int16_t> || is_same_v<DataType, int32_t>;
using GemmAccDataType = std::conditional_t<is_integer, int32_t, float>;
using ABlockDesc_K0_M_K1_Type =
decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>());
using BBlockDesc_K0_N_K1_Type =
decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>());
using BlockwiseGemmXdlops =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
DataType,
DataType,
GemmAccDataType,
ABlockDesc_K0_M_K1_Type,
BBlockDesc_K0_N_K1_Type,
GemmTraits::MPerXDL,
GemmTraits::NPerXDL,
GemmTraits::MXdlPerWave,
GemmTraits::NXdlPerWave,
GemmTraits::K1>;
// Calcualte descriptor, shape and layout
constexpr auto vgpr_desc = BlockwiseGemmXdlops::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
const auto vgpr_shape = make_tuple(vgpr_desc.GetLengths()[I0],
vgpr_desc.GetLengths()[I1],
vgpr_desc.GetLengths()[I2],
vgpr_desc.GetLengths()[I3],
vgpr_desc.GetLengths()[I4],
vgpr_desc.GetLengths()[I5],
vgpr_desc.GetLengths()[I6],
vgpr_desc.GetLengths()[I7]);
const auto vgpr_layout = Layout<remove_reference_t<decltype(vgpr_shape)>, decltype(vgpr_desc)>(
vgpr_shape, vgpr_desc);
// Get vector type for Vgpr
using BlockwiseGemmCThreadBufferType =
remove_reference_t<decltype(BlockwiseGemmXdlops{}.GetCThreadBuffer())>;
using VgprVectorType = typename BlockwiseGemmCThreadBufferType::V;
return ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr, VgprVectorType>(
vgpr_layout);
}
} // namespace wrapper
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "utils/tensor_utils.hpp"
#include "utils/tensor_partition.hpp"
#include "utils/layout_utils.hpp"
namespace ck {
namespace wrapper {
namespace {
namespace detail {
/**
* \brief Check if Tuple contains Slice object
*
* \return True if tuple contains Slice object.
*/
template <typename T>
__host__ __device__ constexpr bool HasSlice(T&&)
{
return is_detected<is_slice, T>::value;
}
template <typename... Ts>
__host__ __device__ constexpr bool HasSlice(Tuple<Ts...>&&)
{
return (HasSlice(Ts{}) || ...);
}
/**
* \brief Calculate new shape after slice from parent shape.
*
* \param idxs Tuple of indexes defining slice ranges.
* \param shape Shape which will be sliced.
* \return New tensor shape.
*/
template <typename... Ts, typename SlicedShape>
__host__ __device__ constexpr auto GetSlicedShape(const Tuple<Ts...>& idxs,
const SlicedShape& shape)
{
// Pack each value in tuple to remove empty tuples after generation
auto new_shape = generate_tuple(
[&](auto i) {
constexpr auto num_i = Number<i>{};
if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
{
if constexpr(!detail::HasSlice(tuple_element_t<i.value, Tuple<Ts...>>{}))
{
// if tuple does not have any slice then we can remove dimension
return Tuple<>{};
}
else
{
// if tuple then recurrence
return make_tuple(GetSlicedShape(idxs.At(num_i), shape.At(num_i)));
}
}
else if constexpr(is_detected<is_slice, tuple_element_t<i.value, Tuple<Ts...>>>::value)
{
// calculate new dimension
const auto& dim = size(shape.At(num_i));
const auto val = idxs.At(num_i).range(dim);
return make_tuple(val);
}
else
{
// remove dimension for just value
return Tuple<>{};
}
},
Number<Tuple<Ts...>::Size()>{});
// Remove empty tuples (deleted elements) and return
return UnrollNestedTuple<0, 1>(new_shape);
}
/**
* \brief Generate Freeze for each of nested shape.
*
* \param idx Tuple of start indices for slice.
* \param shape Shape which will be freezed.
* \return Generated freeze transforms.
*/
template <typename T, typename Shape>
__host__ __device__ constexpr auto GenerateMultipleFreeze(T idx, const Shape& shape)
{
const auto unrolled_shape = UnrollNestedTuple(shape);
return generate_tuple(
[&](auto i) {
// dimension offset from idx
const auto dim = unrolled_shape.At(Number<i>{});
const auto dim_idx = idx % dim;
idx /= dim;
return make_freeze_transform(dim_idx);
},
Number<decltype(unrolled_shape)::Size()>{});
}
/**
* \brief Generate transforms for slice tensor.
*
* \param idx Tuple of start indices for slice.
* \param shape Shape which will be sliced.
* \return Generated transforms.
*/
template <typename... Ts, typename Shape>
__host__ __device__ constexpr auto GenerateSliceTransforms(const Tuple<Ts...>& idx,
const Shape& shape)
{
// Pack each value in tuple to remove empty tuples after generation
auto transforms = generate_tuple(
[&](auto i) {
constexpr auto num_i = Number<i>{};
if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
{
return GenerateSliceTransforms(idx.At(num_i), shape.At(num_i));
}
else if constexpr(is_detected<is_slice, tuple_element_t<i.value, Tuple<Ts...>>>::value)
{
const auto from = idx.At(num_i).from_;
const auto dim = size<num_i>(shape);
const auto range = idx.At(num_i).range(dim);
return make_slice_transform(range, from, from + range);
}
else
{
// remove dimension for just value
return GenerateMultipleFreeze(idx.At(num_i), shape.At(num_i));
}
},
Number<Tuple<Ts...>::Size()>{});
// Remove empty tuples (deleted elements) and return
return UnrollNestedTuple(transforms);
}
template <index_t i, typename LowerIndex>
__host__ __device__ constexpr auto GetSequenceVal(const ck::Freeze<LowerIndex>&)
{
// There is no output for Freeze transform
return Sequence<>{};
}
template <index_t i, typename LowLength, typename SliceBegin, typename SliceEnd>
__host__ __device__ constexpr auto GetSequenceVal(const ck::Slice<LowLength, SliceBegin, SliceEnd>&)
{
return Sequence<i>{};
}
template <index_t i>
__host__ __device__ constexpr auto GenerateUpperDims(const Tuple<>&)
{
return Tuple<>{};
}
template <index_t i, typename... Transforms>
__host__ __device__ constexpr auto GenerateUpperDims(const Tuple<Transforms...>& transforms)
{
constexpr auto num_transforms = Tuple<Transforms...>::Size();
// Deduce Sequence element for specific transform
const auto current_elem = GetSequenceVal<i>(transforms.At(Number<0>{}));
if constexpr(is_same_v<decltype(current_elem), const Sequence<>>)
{
const auto next_tuple = GenerateUpperDims<i>(TupleSlice<1, num_transforms>(transforms));
return concat_tuple(make_tuple(current_elem), next_tuple);
}
else
{
// Increase i if current_elem is Slice transform
const auto next_tuple = GenerateUpperDims<i + 1>(TupleSlice<1, num_transforms>(transforms));
return concat_tuple(make_tuple(current_elem), next_tuple);
}
}
template <typename... Ts, typename Shape, typename FlattenDescriptor>
__host__ __device__ constexpr auto GenerateSlicedDescriptor(const Tuple<Ts...>& idx,
const Shape& shape,
const FlattenDescriptor& flatten_desc)
{
constexpr auto old_shape_dims = decltype(UnrollNestedTuple(shape))::Size();
const auto transforms = GenerateSliceTransforms(idx, shape);
using TransformsTupleType = decltype(transforms);
const auto lower_dims =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<old_shape_dims>{});
const auto upper_dims = decltype(GenerateUpperDims<0>(TransformsTupleType{})){};
return transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims);
}
} // namespace detail
} // namespace
/**
* \brief Tensor wrapper that performs static and dynamic buffer logic.
* The tensor is based on a descriptor stored in the Layout. Additionally,
* tensor can be sliced or shifted using multi-index offset.
*
* \tparam BufferAddressSpace Memory type (Generic, Global, LDS, VGPR, SGPR).
* \tparam ElementType Element data type.
* \tparam Shape Tensor shape (layout component).
* \tparam UnrolledDescriptorType Flatten descriptor (layout component).
*/
template <MemoryTypeEnum BufferAddressSpace,
typename ElementType,
typename Shape,
typename UnrolledDescriptorType>
struct Tensor
{
public:
using ElementSpaceSize = decltype(Layout<Shape, UnrolledDescriptorType>{
Shape{}, UnrolledDescriptorType{}}.GetElementSpaceSize()); // SpaceSize type for buffer
using TensorElementType = std::conditional_t<
is_scalar_type<ElementType>::value,
ElementType,
typename scalar_type<std::remove_const_t<ElementType>>::type>; // DataType
static constexpr MemoryTypeEnum TensorBufferAddressSpace = BufferAddressSpace;
static constexpr bool IsDynamicBuffer = !(BufferAddressSpace == MemoryTypeEnum ::Sgpr ||
BufferAddressSpace == MemoryTypeEnum ::Vgpr);
__host__ __device__ Tensor() = delete;
__host__ __device__ constexpr Tensor(ElementType* pointer,
const Layout<Shape, UnrolledDescriptorType>& layout)
: layout_(layout),
buffer_(make_dynamic_buffer<BufferAddressSpace>(pointer, layout.GetElementSpaceSize())),
multi_idx_offset_(make_zero_multi_index<Shape::Size()>()),
base_offset_(0)
{
static_assert(IsDynamicBuffer, "Wrong BufferAddressSpace for register.");
}
__host__ __device__ constexpr Tensor(const Layout<Shape, UnrolledDescriptorType>& layout)
: layout_(layout),
multi_idx_offset_(make_zero_multi_index<Shape::Size()>()),
base_offset_(0)
{
static_assert(!IsDynamicBuffer, "Wrong BufferAddressSpace for register.");
}
__host__ __device__ constexpr const Layout<Shape, UnrolledDescriptorType>& GetLayout() const
{
return layout_;
}
/**
* \brief Get the new sliced tensor.
*
* \param idx Tuple of indices: slice(from,to) or scalar.
* \return Sliced tensor.
*/
template <typename... Ts, enable_if_t<detail::HasSlice(Tuple<Ts...>{}), bool> = false>
__host__ __device__ auto operator[](const Tuple<Ts...>& idx)
{
static_assert(IsDynamicBuffer, "Register slice is not supported");
const auto& shape = layout_.GetShape();
auto new_shape = detail::GetSlicedShape(idx, shape);
const auto& flatten_desc = layout_.GetUnrolledDescriptor();
auto new_desc = detail::GenerateSlicedDescriptor(idx, shape, flatten_desc);
const auto new_layout =
Layout<decltype(new_shape), decltype(new_desc)>(new_shape, new_desc);
// Update embed offset
base_offset_ -= new_layout(make_tuple(Number<0>{}));
return make_tensor<BufferAddressSpace>(buffer_.p_data_, new_layout);
}
template <typename... Ts, enable_if_t<detail::HasSlice(Tuple<Ts...>{}), bool> = false>
__host__ __device__ auto operator()(const Tuple<Ts...>& idx)
{
return this->operator[](idx);
}
template <typename... Idxs, enable_if_t<detail::HasSlice(Tuple<Idxs...>{}), bool> = false>
__host__ __device__ auto operator()(Idxs... idxs)
{
return this->operator[](make_tuple(idxs...));
}
/**
* \brief Getter of the tensor's const value reference.
*
* \param idx Tuple of indices.
* \return Requested value.
*/
template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
__host__ __device__ const TensorElementType& operator[](const Tuple<Ts...>& idx) const
{
if constexpr(IsDynamicBuffer)
{
const index_t offset = layout_(idx) + base_offset_;
return buffer_[offset];
}
else
{
constexpr index_t index_offset = Layout<Shape, UnrolledDescriptorType>{
Shape{},
UnrolledDescriptorType{}}.template operator()<Tuple<Ts...>>();
// Calculate and apply base offset in compile-time
constexpr index_t base_offset = Layout<Shape, UnrolledDescriptorType>{
Shape{},
UnrolledDescriptorType{}}.template operator()<MultiIndex<Shape::Size()>>();
return buffer_[Number<index_offset + base_offset>{}];
}
}
template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
__host__ __device__ const TensorElementType& operator()(const Tuple<Ts...>& idx) const
{
return this->operator[](idx);
}
template <typename... Idxs, enable_if_t<!detail::HasSlice(Tuple<Idxs...>{}), bool> = false>
__host__ __device__ const TensorElementType& operator()(Idxs... idxs) const
{
return this->operator[](make_tuple(idxs...));
}
/**
* \brief Getter of tensor value reference.
*
* \param idx Tuple of indices.
* \return Requested value.
*/
template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
__host__ __device__ TensorElementType& operator[](const Tuple<Ts...>& idx)
{
if constexpr(IsDynamicBuffer)
{
const index_t offset = layout_(idx) + base_offset_;
return buffer_(offset);
}
else
{
constexpr index_t index_offset = Layout<Shape, UnrolledDescriptorType>{
Shape{},
UnrolledDescriptorType{}}.template operator()<Tuple<Ts...>>();
// Apply embed offset (calculate in compiletime)
constexpr index_t base_offset = Layout<Shape, UnrolledDescriptorType>{
Shape{},
UnrolledDescriptorType{}}.template operator()<MultiIndex<Shape::Size()>>();
return buffer_(Number<index_offset + base_offset>{});
}
}
template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
__host__ __device__ TensorElementType& operator()(const Tuple<Ts...>& idx)
{
return this->operator[](idx);
}
template <typename... Idxs, enable_if_t<!detail::HasSlice(Tuple<Idxs...>{}), bool> = false>
__host__ __device__ TensorElementType& operator()(Idxs... idxs)
{
return this->operator[](make_tuple(idxs...));
}
/**
* \brief Get descriptor with all nested dimensions merged.
*
* \return Merged nests descriptor.
*/
__host__ __device__ constexpr auto GetMergedNestingDescriptor()
{
return layout_.GetMergedNestingDescriptor();
}
/**
* \brief Get pointer to the data.
*
* \return Pointer.
*/
__host__ __device__ TensorElementType* GetPointer() const { return buffer_.p_data_; }
__host__ __device__ constexpr auto& GetBuffer() { return buffer_; }
__host__ __device__ constexpr auto& GetBuffer() const { return buffer_; }
/**
* \brief Get multi index offset to the data.
*
* \return Multi index offset.
*/
__host__ __device__ constexpr auto& GetMultiIdxOffsets() const { return multi_idx_offset_; }
/**
* \brief Apply multi index offset on the tensor.
*
* \param multi_idx_offset Multi index offset.
*/
template <typename MultiIdxOffsets>
__host__ __device__ constexpr void SetMultiIdxOffset(const MultiIdxOffsets multi_idx_offset)
{
multi_idx_offset_ = multi_idx_offset;
base_offset_ += layout_(multi_idx_offset);
}
private:
using DynamicBufferType = DynamicBuffer<BufferAddressSpace,
ElementType,
ElementSpaceSize,
true /*InvalidElementUseNumericalZeroValue*/>;
using StaticBufferType = std::conditional_t<
is_scalar_type<ElementType>::value,
StaticBuffer<BufferAddressSpace,
ElementType,
size(Shape{}),
true /*InvalidElementUseNumericalZeroValue*/>,
StaticBufferTupleOfVector<BufferAddressSpace,
TensorElementType,
size(Shape{}) /
scalar_type<std::remove_const_t<ElementType>>::vector_size,
scalar_type<std::remove_const_t<ElementType>>::vector_size,
true /*InvalidElementUseNumericalZeroValue*/>>;
// If register use static buffer, else use dynamic buffer
using Buffer = std::conditional_t<IsDynamicBuffer, DynamicBufferType, StaticBufferType>;
const Layout<Shape, UnrolledDescriptorType> layout_;
Buffer buffer_;
// We use multi_idx_offset_ to enable the creation of a descriptor in
// compile time for partitions or tiles if tile shape and thread layout
// is known at compile time (We can use the same descriptor for each
// thread). Additionally, the copy between the static and dynamic buffer
// requires a descriptor known at compile time, so we can shift data using
// such multi_idx_offset_.
MultiIndex<Shape::Size()> multi_idx_offset_;
// Base offset and multi index offset are corresponding to exactly the
// same element in tensor ( and in physical memory ). Multi index offset
// is multi dimensional index. However base offset is calculated using
// tensor descriptor (thus all it's transforms) and is linear (1D).
// We store base_offset_ to avoid multiple recalculations.
index_t base_offset_;
};
} // namespace wrapper
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
namespace ck {
namespace wrapper {
/**
* \brief Traits for blockwise gemm xdl.
*
* \tparam MPerXDLValue The MFMA instruction size in M dimension.
* \tparam NPerXDLValue The MFMA instruction size in N dimension.
* \tparam MXdlPerWaveValue The number of MFMA instructions run by single
* wave in M dimension.
* \tparam NXdlPerWaveValue The number of MFMA instructions run by single
* wave in N dimension.
* \tparam K1Value The number of K-dim elements that are packed together as
* a separate logical dimension. Usually aligns with vector load size.
*/
template <index_t MPerXDLValue,
index_t NPerXDLValue,
index_t MXdlPerWaveValue,
index_t NXdlPerWaveValue,
index_t K1Value>
struct BlockwisGemmXdlTraits
{
static constexpr index_t MPerXDL = MPerXDLValue;
static constexpr index_t NPerXDL = NPerXDLValue;
static constexpr index_t MXdlPerWave = MXdlPerWaveValue;
static constexpr index_t NXdlPerWave = NXdlPerWaveValue;
static constexpr index_t K1 = K1Value;
};
// K1 = 4
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 4, 2, 4>
{
};
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 2, 4, 4>
{
};
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 2, 2, 4>
{
};
// K1 = 8
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_8K1 : BlockwisGemmXdlTraits<32, 32, 4, 2, 8>
{
};
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_8K1 : BlockwisGemmXdlTraits<32, 32, 2, 4, 8>
{
};
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1 : BlockwisGemmXdlTraits<32, 32, 2, 2, 8>
{
};
// K1 = 16
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_16K1 : BlockwisGemmXdlTraits<32, 32, 4, 2, 16>
{
};
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_16K1 : BlockwisGemmXdlTraits<32, 32, 2, 4, 16>
{
};
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_16K1 : BlockwisGemmXdlTraits<32, 32, 2, 2, 16>
{
};
} // namespace wrapper
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/utility/number.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/utility/tuple_helper.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/utility/sequence_helper.hpp"
#include "ck/utility/is_detected.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
namespace ck {
namespace wrapper {
// Disable from doxygen docs generation
/// @cond
// forward declaration
template <typename Shape, typename UnrolledDescriptorType>
struct Layout;
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());
namespace {
/**
* \brief Generate packed (column-major) strides if not passed
*
* \param shape Tensor shape.
* \return Generated column-major strides.
*/
template <typename... Ts>
__host__ __device__ constexpr static auto
GenerateColumnMajorPackedStrides(const Tuple<Ts...>& shape)
{
const auto unrolled_shape = UnrollNestedTuple(shape);
return generate_tuple(
[&](auto i) {
if constexpr(i.value == 0)
{
return Number<1>{};
}
else
{
return TupleReduce<Number<0>{}.value, i.value>([](auto x, auto y) { return x * y; },
unrolled_shape);
}
},
Number<decltype(unrolled_shape)::Size()>{});
}
/**
* \brief Create naive tensor descriptor from nested shape.
*
* \param shape Tensor shape.
* \param strides Tensor strides.
* \return Unrolled descriptor
*/
template <typename LayoutShape, typename LayoutStrides>
__host__ __device__ constexpr auto MakeUnrolledDescriptor(const LayoutShape& shape,
const LayoutStrides& strides)
{
const auto unrolled_shape = UnrollNestedTuple(shape);
if constexpr(is_same_v<LayoutStrides, Tuple<>>)
{
// if not passed, then generate
const auto unrolled_strides = GenerateColumnMajorPackedStrides(unrolled_shape);
static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
"Size of strides and shape are not consistent.");
return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
}
else
{
const auto unrolled_strides = UnrollNestedTuple(strides);
static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
"Size of strides and shape are not consistent.");
return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
}
}
} // namespace
/// @endcond
// make_*
/**
* \brief Make layout function.
*
* \tparam Shape Shape for layout.
* \tparam Strides Strides for layout.
* \return Constructed layout.
*/
template <typename Shape, typename Strides>
__host__ __device__ constexpr auto make_layout(const Shape& shape, const Strides& strides)
{
using UnrolledDescriptorType = decltype(MakeUnrolledDescriptor(Shape{}, Strides{}));
return Layout<Shape, UnrolledDescriptorType>(shape, MakeUnrolledDescriptor(shape, strides));
}
/**
* \brief Make layout function with packed strides
* (column-major).
*
* \tparam Shape Shape for layout.
* \return Constructed layout.
*/
template <typename Shape>
__host__ __device__ constexpr auto make_layout(const Shape& shape)
{
using UnrolledDescriptorType = decltype(MakeUnrolledDescriptor(Shape{}, Tuple<>{}));
return Layout<Shape, UnrolledDescriptorType>(shape, MakeUnrolledDescriptor(shape, Tuple<>{}));
}
// Layout helpers
// get
/**
* \private
* \brief Get dim.
*
* \param dim Dimension.
* \return Returned the same dimension.
*/
template <typename T>
__host__ __device__ T constexpr get(const T& dim)
{
return dim;
}
/**
* \brief Get element from tuple (Shape/Strides/Idxs).
*
* \tparam idx Index to lookup.
* \param tuple Tuple to lookup.
* \return Requsted element.
*/
template <index_t idx, typename... Dims>
__host__ __device__ constexpr auto get(const Tuple<Dims...>& tuple)
{
return tuple.At(Number<idx>{});
}
/**
* \brief Get sub layout.
*
* \tparam idx Index to lookup.
* \param layout Layout to create sub layout.
* \return Requsted sub layout.
*/
template <index_t idx, typename Shape, typename FlattenDesc>
__host__ __device__ constexpr auto get(const Layout<Shape, FlattenDesc>& layout)
{
const auto& shape = layout.GetShape();
const auto new_shape = get<idx>(shape);
static_assert(is_detected<is_tuple, decltype(new_shape)>::value,
"Shape of sub layout must be tuple");
constexpr auto old_shape_dims = decltype(UnrollNestedTuple(shape))::Size();
constexpr auto new_shape_dims = decltype(UnrollNestedTuple(new_shape))::Size();
constexpr auto shape_offset = decltype(UnrollNestedTuple(TupleSlice<0, idx>(shape)))::Size();
const auto unrolled_shape = UnrollNestedTuple(shape);
const auto transforms = generate_tuple(
[&](auto i) {
// Compare Idx with shape
if constexpr(i < shape_offset || i >= shape_offset + new_shape_dims)
{
// Remove dimension
return make_freeze_transform(Number<0>{});
}
else
{
return make_pass_through_transform(unrolled_shape.At(i));
}
},
Number<old_shape_dims>{});
const auto lower_dims =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<old_shape_dims>{});
const auto upper_dims = generate_tuple(
[&](auto i) {
if constexpr(i < shape_offset || i >= shape_offset + new_shape_dims)
return Sequence<>{};
else
{
return Sequence<i.value - shape_offset>{};
}
},
Number<old_shape_dims>{});
const auto& flatten_desc = layout.GetUnrolledDescriptor();
auto new_desc = transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims);
return Layout<decltype(new_shape), decltype(new_desc)>(new_shape, new_desc);
}
/**
* \brief Hierarchical get.
*
* \tparam Idxs Indexes to lookup.
* \param elem Element to lookup.
* \return Requsted element.
*/
template <index_t Idx, index_t... Idxs, typename T>
__host__ __device__ constexpr auto get(const T& elem)
{
return get<Idxs...>(get<Idx>(elem));
}
// size
/**
* \private
* \brief Get size.
*
* \param dim Size.
* \return Returned the same size.
*/
template <typename T>
__host__ __device__ T constexpr size(const T& dim)
{
return dim;
}
/**
* \brief Length get (product if tuple).
*
* \tparam idx Index to lookup.
* \param layout Layout to get Shape of.
* \return Requsted length.
*/
template <index_t idx, typename Shape, typename UnrolledDescriptorType>
__host__ __device__ constexpr auto size(const Layout<Shape, UnrolledDescriptorType>& layout)
{
return layout.template GetLength<idx>();
}
/**
* \brief Shape size (product of dims).
*
* \param shape Shape to lookup.
* \return Requsted size.
*/
template <typename... ShapeDims>
__host__ __device__ constexpr auto size(const Tuple<ShapeDims...>& shape)
{
const auto unrolled_shape = UnrollNestedTuple(shape);
return TupleReduce<0, unrolled_shape.Size()>([](auto x, auto y) { return x * y; },
unrolled_shape);
}
/**
* \brief Layout size (product of dims).
*
* \param layout Layout to calculate shape size.
* \return Requsted size.
*/
template <typename Shape, typename UnrolledDescriptorType>
__host__ __device__ constexpr auto size(const Layout<Shape, UnrolledDescriptorType>& layout)
{
return layout.GetLengths();
}
/**
* \brief Length get from tuple (product if tuple).
*
* \tparam idx Index to lookup.
* \param tuple Tuple to lookup.
* \return Requsted length.
*/
template <index_t idx, typename... Ts>
__host__ __device__ constexpr auto size(const Tuple<Ts...>& tuple)
{
return size(tuple.At(Number<idx>{}));
}
/**
* \brief Hierarchical size.
*
* \tparam Idx First index to lookup (to avoid empty Idxs).
* \tparam Idxs Next indexes to lookup.
* \param elem Element to lookup.
* \return Requsted element.
*/
template <index_t Idx, index_t... Idxs, typename T>
__host__ __device__ constexpr auto size(const T& elem)
{
return size(get<Idx, Idxs...>(elem));
}
// rank
/**
* \brief Get layout rank (num elements in shape).
*
* \param layout Layout to calculate rank.
* \return Requsted rank.
*/
template <typename Shape, typename UnrolledDescriptorType>
__host__ __device__ constexpr auto
rank([[maybe_unused]] const Layout<Shape, UnrolledDescriptorType>& layout)
{
return Shape::Size();
}
/**
* \brief Get tuple rank (num elements in tuple).
* Return 1 if scalar passed.
*
* \param tuple Tuple to calculate rank.
* \return Requsted rank.
*/
template <typename... Dims>
__host__ __device__ constexpr auto rank([[maybe_unused]] const Tuple<Dims...>& tuple)
{
return Tuple<Dims...>::Size();
}
/**
* \private
* \brief Rank for scalar
*
* \param dim Dimension scalar.
* \return Returned 1.
*/
template <index_t IDim>
__host__ __device__ constexpr index_t rank([[maybe_unused]] const Number<IDim>& dim)
{
return 1;
}
/**
* \private
* \brief Rank for scalar
*
* \param dim Dimension scalar.
* \return Returned 1.
*/
__host__ __device__ constexpr index_t rank([[maybe_unused]] const index_t& dim) { return 1; }
/**
* \brief Hierarchical rank.
*
* \tparam Idxs Indexes to lookup.
* \param elem Element to lookup.
* \return Requsted rank.
*/
template <index_t... Idxs, typename T>
__host__ __device__ constexpr auto rank(const T& elem)
{
return rank(get<Idxs...>(elem));
}
// depth
/**
* \brief Get depth of the layout shape (return 0 if scalar).
*
* \param layout Layout to calculate depth.
* \return Requsted depth.
*/
template <typename Shape, typename UnrolledDescriptorType>
__host__ __device__ constexpr auto depth(const Layout<Shape, UnrolledDescriptorType>& layout)
{
const auto& shape = layout.GetShape();
return TupleDepth(shape);
}
/**
* \brief Get depth of the tuple. (return 0 if scalar)
*
* \param tuple Tuple to calculate depth.
* \return Requsted depth.
*/
template <typename... Dims>
__host__ __device__ constexpr auto depth(const Tuple<Dims...>& tuple)
{
return TupleDepth(tuple);
}
/**
* \private
* \brief Depth for scalar
*
* \param dim Scalar.
* \return Returned 0.
*/
template <index_t IDim>
__host__ __device__ constexpr index_t depth([[maybe_unused]] const Number<IDim>& dim)
{
return 0;
}
/**
* \private
* \brief Depth for scalar
*
* \param dim Scalar.
* \return Returned 0.
*/
__host__ __device__ constexpr index_t depth([[maybe_unused]] const index_t& dim) { return 0; }
/**
* \brief Hierarchical depth.
*
* \tparam Idxs Indexes to lookup.
* \param elem Element to lookup.
* \return Requsted depth.
*/
template <index_t... Idxs, typename T>
__host__ __device__ constexpr auto depth(const T& elem)
{
return depth(get<Idxs...>(elem));
}
/**
* \brief Get Layout shape.
*
* \param layout Layout to get shape from.
* \return Requsted shape.
*/
template <typename LayoutType>
__host__ __device__ constexpr const auto& shape(const LayoutType& layout)
{
return layout.GetShape();
}
} // namespace wrapper
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment