Commit 2298a1a4 authored by illsilin

sync from public

parents 965b7ba4 2f088b87
@@ -445,6 +445,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
using Block2ETileMap = BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops<MPerBlock, NPerBlock>;
using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMapMLoops<Block2ETileMap>;
// TODO: replace with GroupedGemmKernelArgument
struct GemmBiasTransKernelArg
{
// pointers
@@ -900,40 +901,58 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
return str.str();
}
static void SetDeviceKernelArgs(Argument& arg, const void* kernel_args)
{
arg.grouped_gemm_kernel_args_dev = kernel_args;
}
// polymorphic
void SetDeviceKernelArgs(BaseArgument* p_arg, void* kernel_args) const override
{
auto arg_ptr = dynamic_cast<Argument*>(p_arg);
if(arg_ptr)
{
arg_ptr->grouped_gemm_kernel_args_dev = kernel_args;
}
else
throw std::runtime_error("The argument pointer is not an object of "
"DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!");
}
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
{
auto arg_ptr = dynamic_cast<const Argument*>(p_arg);
if(arg_ptr)
{
return arg_ptr->group_count_ * arg_ptr->barrier_size_grp_ * sizeof(uint32_t);
}
else
throw std::runtime_error("The argument pointer is not an object of "
"DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!");
}
size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
{
auto arg_ptr = dynamic_cast<const Argument*>(p_arg);
if(arg_ptr)
{
return arg_ptr->group_count_ * sizeof(GroupedGemmKernelArgument<NumDTensor>);
}
else
throw std::runtime_error("The argument pointer is not an object of "
"DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!");
}
void SetWorkSpacePointer(BaseArgument* p_arg,
void* p_workspace,
const StreamConfig& stream_config = StreamConfig{}) const override
{
auto arg_ptr = dynamic_cast<Argument*>(p_arg);
if(arg_ptr)
{
arg_ptr->p_workspace_ = p_workspace;
}
else
throw std::runtime_error("The argument pointer is not an object of "
"DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!");
hip_check_error(
hipMemsetAsync(p_workspace, 0, GetWorkSpaceSize(arg_ptr), stream_config.stream_id_));
}
static void SetKBatch(Argument& arg, index_t k_batch) { arg.UpdateKBatch(k_batch); }
@@ -941,7 +960,26 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
// polymorphic
void SetKBatch(BaseArgument* p_arg, index_t k_batch) const override
{
auto arg_ptr = dynamic_cast<Argument*>(p_arg);
if(arg_ptr)
{
arg_ptr->UpdateKBatch(k_batch);
}
else
throw std::runtime_error("The argument pointer is not an object of "
"DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!");
}
void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override
{
auto arg_ptr = dynamic_cast<Argument*>(p_arg);
if(arg_ptr)
{
arg_ptr->UpdateKBatch(kbatch);
}
else
throw std::runtime_error("The argument pointer is not an object of "
"DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!");
}
};
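// --- Illustrative usage note (not part of this commit) -------------------------
// A minimal host-side sketch of driving the checked polymorphic interface above.
// The instance/argument names, the MakeArgumentPointer call and the hipMalloc
// sizes are assumptions for illustration; a foreign BaseArgument now triggers the
// std::runtime_error paths instead of dereferencing a null dynamic_cast result.
//
//   auto arg = gemm.MakeArgumentPointer(/* group descriptors, pointers, ... */);
//   std::size_t workspace_bytes  = gemm.GetWorkSpaceSize(arg.get());
//   std::size_t kernel_arg_bytes = gemm.GetDeviceKernelArgSize(arg.get());
//
//   void* p_workspace   = nullptr;
//   void* p_kernel_args = nullptr;
//   hip_check_error(hipMalloc(&p_workspace, workspace_bytes));
//   hip_check_error(hipMalloc(&p_kernel_args, kernel_arg_bytes));
//
//   gemm.SetWorkSpacePointer(arg.get(), p_workspace);   // also zero-initializes the barriers
//   gemm.SetDeviceKernelArgs(arg.get(), p_kernel_args); // caller fills this device buffer
//   gemm.SetKBatch(arg.get(), 4);
// --------------------------------------------------------------------------------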
...
@@ -421,11 +421,11 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
}
hip_check_error(
hipMemcpyAsync(arg.p_workspace_,
arg.gemm_kernel_args_.data(),
arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
hipMemcpyHostToDevice,
stream_config.stream_id_));
float ave_time = 0;
@@ -538,10 +538,16 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
return false;
}
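// Note (added for clarity): with K_BATCH > 1 the partial results of the split-K
// passes are combined with atomic adds on the E tensor, so bf16 output is
// presumably only usable when the device reports bf16 atomic support; the check
// below rejects the remaining combinations.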
if(std::is_same_v<EDataType, ck::bhalf_t> && arg.K_BATCH > 1 && !is_bf16_atomic_supported())
{
return false;
}
bool supported = true;
for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
{
const auto& a = arg.gemm_kernel_args_[i].karg_;
bool group_arg_valid = GridwiseGemm::CheckValidity(a);
if(not group_arg_valid)
{
@@ -631,16 +637,42 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
{
auto p_arg_ = dynamic_cast<const Argument*>(p_arg);
if(p_arg_)
{
return p_arg_->gemm_kernel_args_.size() * sizeof(GemmTransKernelArg);
}
else
throw std::runtime_error(
"The argument pointer is not an object of "
"DeviceGroupedGemmMultipleDSplitKXdlCShuffle::Argument structure!");
}
size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
{
return GetWorkSpaceSize(p_arg);
}
// TODO: deprecation notice.
static void SetKBatchSize(Argument& arg, index_t kbatch) { arg.UpdateKBatch(kbatch); }
// polymorphic
void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override
{
auto p_arg_ = dynamic_cast<Argument*>(p_arg);
if(p_arg_)
{
p_arg_->UpdateKBatch(kbatch);
}
else
throw std::runtime_error(
"The argument pointer is not an object of "
"DeviceGroupedGemmMultipleDSplitKXdlCShuffle::Argument structure!");
}
void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
{
return this->SetWorkSpacePointer(p_arg, p_dev_kernel_args);
}
};
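// --- Illustrative usage note (not part of this commit) -------------------------
// In this operation the device kernel-argument buffer and the workspace are the
// same allocation: SetDeviceKernelArgs() forwards to SetWorkSpacePointer() and
// GetDeviceKernelArgSize() returns GetWorkSpaceSize(). A host-side sketch with
// assumed names; Invoker::Run() then uploads gemm_kernel_args_ into this buffer
// with the asynchronous copy shown further up.
//
//   auto arg = gemm.MakeArgumentPointer(/* ... */);
//   void* p_dev_kernel_args = nullptr;
//   hip_check_error(hipMalloc(&p_dev_kernel_args, gemm.GetDeviceKernelArgSize(arg.get())));
//   gemm.SetDeviceKernelArgs(arg.get(), p_dev_kernel_args);
//   gemm.SetKBatchSize(arg.get(), 2);
// --------------------------------------------------------------------------------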
...
@@ -14,6 +14,8 @@
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/workgroup_barrier.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
namespace ck {
@@ -38,7 +40,7 @@ __global__ void
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg, karg.p_workspace_);
#else
ignore = karg;
#endif // end of if (defined(__gfx9__))
@@ -62,7 +64,13 @@ __global__ void
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid,
karg.p_b_grid,
karg.p_c_grid,
p_shared_0,
p_shared_1,
karg,
karg.p_workspace_);
#else
ignore = karg;
#endif // end of if (defined(__gfx9__))
@@ -521,7 +529,9 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, Streamk_sel_, Grid_size_},
p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_c_grid{p_c_grid_},
block_2_ctile_map_streamk(
M_, N_, AK0Number * CalculateKPadded(K_, 1), Grid_size_, Streamk_sel_)
{
}
@@ -529,6 +539,13 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
const ADataType* p_a_grid;
const BDataType* p_b_grid;
CDataType* p_c_grid;
BlockToCTileMap_GemmStreamK_v2<MPerBlock,
NPerBlock,
KPerBlock,
StreamKReductionStrategy::Atomic,
8,
4>
block_2_ctile_map_streamk;
};
struct SplitKBatchOffset
@@ -853,6 +870,19 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
}
__host__ __device__ static constexpr auto
GetCBlockDescriptor_MShuffle_MPerShuffle_NShuffle_NPerShuffle()
{
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
return make_naive_tensor_descriptor_packed(
make_tuple(Number<MXdlPerWave / CShuffleMXdlPerWavePerShuffle>{},
Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
Number<NXdlPerWave / CShuffleNXdlPerWavePerShuffle>{},
Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));
}
using BlockwiseGemmPipe =
remove_cvref_t<decltype(BlockGemmPipeline_Selector<
BlkGemmPipelineVer,
@@ -1118,6 +1148,34 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
return c_grid_desc_mblock_mperblock_nblock_nperblock;
}
__host__ __device__ static constexpr auto GetClusterLengthReduction()
{
// TODO: assume C is row major
// TODO: we always first loop over N, then M
constexpr auto NPerBlockPow2 = math::next_power_of_two<NPerBlock>();
constexpr auto NPerBlockReduction =
NPerBlockPow2 / CShuffleBlockTransferScalarPerVector_NPerBlock;
constexpr auto MPerBlockReduction =
(BlockSize + NPerBlockReduction - 1) / NPerBlockReduction;
return Sequence<MPerBlockReduction, NPerBlockReduction>{};
}
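// Worked example (illustrative numbers, not taken from this commit): with
// BlockSize = 256, NPerBlock = 128 and CShuffleBlockTransferScalarPerVector_NPerBlock = 8,
//   next_power_of_two<NPerBlock> = 128
//   NPerBlockReduction           = 128 / 8 = 16
//   MPerBlockReduction           = (256 + 16 - 1) / 16 = 16
// so the reduction threads form a 16 x 16 cluster, each thread consuming an
// 8-wide vector along N per step.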
__host__ __device__ static constexpr auto GetPartialAccBlockDescriptor()
{
const auto c_partial_acc_block_m_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(MPerBlock, NPerBlock),
make_tuple(NPerBlock, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(MPerBlock, NPerBlock),
make_tuple(I1, MPerBlock));
}
}();
return c_partial_acc_block_m_n;
}
using Block2CTileMap_streamk = BlockToCTileMap_GemmStreamK_v2<MPerBlock,
NPerBlock,
KPerBlock,
@@ -1132,22 +1190,42 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
const BDataType* p_b_grid,
CDataType* p_c_grid,
void* p_shared,
Problem& problem,
void* p_workspace)
{
const AElementwiseOperation a_element_op{};
const BElementwiseOperation b_element_op{};
const CElementwiseOperation c_element_op{};
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
Block2CTileMap_streamk block_2_ctile_map_streamk(problem.M,
problem.N,
AK0Number * problem.KPadded,
problem.Grid_size,
problem.Streamk_sel);
uint32_t iter_start, iter_end;
bool is_sk_block, is_dp_block, is_reduction_block;
index_t num_k_block_main_loop;
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
uint32_t* p_semaphore = reinterpret_cast<uint32_t*>(
reinterpret_cast<char*>(p_workspace) +
block_2_ctile_map_streamk.get_workspace_size_for_acc(sizeof(AccDataType)));
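// Note (added for clarity): the reduction workspace is laid out as
//   [ partial-acc tiles, MPerBlock * NPerBlock AccDataType elements each ]
//   [ one uint32_t semaphore per reduced output tile ]
// with the semaphore array starting at get_workspace_size_for_acc(sizeof(AccDataType))
// bytes, which is exactly where p_semaphore points.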
for(auto block_idx = get_block_1d_id();
block_idx < block_2_ctile_map_streamk.get_grid_dims();
block_idx += gridDim.x)
@@ -1163,6 +1241,214 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end);
num_k_block_main_loop = iter_end - iter_start;
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
is_reduction_block = static_cast<uint32_t>(block_idx) >=
block_2_ctile_map_streamk.reduction_start_block_idx;
if(is_reduction_block)
{
// descriptors
constexpr auto cluster_length_reduce = GetClusterLengthReduction();
constexpr auto reduce_desc = make_cluster_descriptor(cluster_length_reduce);
const auto reduce_thread_cluster_idx =
reduce_desc.CalculateBottomIndex(make_multi_index(block_idx));
const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0];
const auto thread_n_cluster_id = reduce_thread_cluster_idx[I1];
constexpr auto MReduceIters = math::integer_divide_ceil(
Number<MPerBlock>{}, cluster_length_reduce.At(I0));
constexpr auto NReduceIters = math::integer_divide_ceil(
Number<NPerBlock>{},
cluster_length_reduce.At(I1) *
Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{});
constexpr auto acc_thread_buf_load_desc = make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{}));
constexpr auto acc_thread_buf_store_desc =
make_naive_tensor_descriptor_packed(make_tuple(
I1, I1, I1, Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{}));
constexpr auto c_partial_acc_block_m_n = GetPartialAccBlockDescriptor();
constexpr auto partial_acc_load_step_n =
make_multi_index(0,
cluster_length_reduce.At(I1) *
CShuffleBlockTransferScalarPerVector_NPerBlock);
constexpr auto partial_acc_load_step_n_reverse = make_multi_index(
0,
-1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
CShuffleBlockTransferScalarPerVector_NPerBlock);
constexpr auto partial_acc_load_step_m =
make_multi_index(cluster_length_reduce.At(I0), 0);
constexpr auto partial_acc_store_step_n =
make_multi_index(0,
0,
0,
cluster_length_reduce.At(I1) *
CShuffleBlockTransferScalarPerVector_NPerBlock);
constexpr auto partial_acc_store_step_n_reverse = make_multi_index(
0,
0,
0,
-1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
CShuffleBlockTransferScalarPerVector_NPerBlock);
constexpr auto partial_acc_store_step_m =
make_multi_index(0, cluster_length_reduce.At(I0), 0, 0);
StaticBuffer<AddressSpaceEnum::Vgpr,
AccDataType,
CShuffleBlockTransferScalarPerVector_NPerBlock,
true>
parcial_acc_buf;
StaticBuffer<AddressSpaceEnum::Vgpr,
AccDataType,
CShuffleBlockTransferScalarPerVector_NPerBlock,
true>
acc_buf;
// start to compute
auto reduction_idx =
block_idx - block_2_ctile_map_streamk.reduction_start_block_idx;
auto spatial_idx = block_2_ctile_map_streamk.tile_to_spatial(
reduction_idx, problem.M, problem.N);
workgroup_barrier wg_barrier(p_semaphore);
uint32_t tile_acc_offset_start =
block_2_ctile_map_streamk.get_acc_buffer_offset_from_tile(reduction_idx);
uint32_t tile_acc_offset_end =
block_2_ctile_map_streamk.get_acc_buffer_offset_from_tile(reduction_idx +
1);
__syncthreads();
auto acc_load = ThreadwiseTensorSliceTransfer_v2<
AccDataType, // SrcData,
AccDataType, // DstData,
decltype(c_partial_acc_block_m_n), // SrcDesc,
decltype(acc_thread_buf_load_desc), // DstDesc,
Sequence<1,
CShuffleBlockTransferScalarPerVector_NPerBlock>, // SliceLengths,
Sequence<0, 1>, // DimAccessOrder,
1, // SrcVectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // SrcScalarPerVector,
1, // SrcScalarStrideInVector,
false // SrcResetCoordinateAfterRun,
>{c_partial_acc_block_m_n,
make_multi_index(thread_m_cluster_id,
thread_n_cluster_id *
CShuffleBlockTransferScalarPerVector_NPerBlock)};
auto acc_store = ThreadwiseTensorSliceTransfer_v1r3<
AccDataType, // SrcData,
CDataType, // DstData,
decltype(acc_thread_buf_store_desc), // SrcDesc,
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), // DstDesc,
CElementwiseOperation, // ElementwiseOperation,
Sequence<1,
1,
1,
CShuffleBlockTransferScalarPerVector_NPerBlock>, // SliceLengths,
Sequence<0, 1, 2, 3>, // DimAccessOrder,
3, // DstVectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // DstScalarPerVector,
InMemoryDataOperationEnum::Set, // InMemoryDataOperationEnum DstInMemOp,
1, // DstScalarStrideInVector,
false // DstResetCoordinateAfterRun,
>{c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]),
thread_m_cluster_id,
__builtin_amdgcn_readfirstlane(spatial_idx[I1]),
thread_n_cluster_id *
CShuffleBlockTransferScalarPerVector_NPerBlock),
CElementwiseOperation{}};
wg_barrier.wait_eq(reduction_idx, tile_acc_offset_end - tile_acc_offset_start);
if(threadIdx.x == 0)
{
p_semaphore[reduction_idx] = 0;
}
using Accumulation = ck::detail::
AccumulateWithNanCheck<false /*PropagateNan*/, reduce::Add, AccDataType>;
for(int i_m = 0; i_m < MReduceIters; i_m++)
{
static_for<0, NReduceIters, 1>{}([&](auto i_n_reduce) {
acc_buf.Clear();
for(auto i = tile_acc_offset_start; i < tile_acc_offset_end; i++)
{
auto c_partial_acc_buf =
make_dynamic_buffer<AddressSpaceEnum::Global,
AmdBufferCoherenceEnum::GLC>(
reinterpret_cast<AccDataType*>(p_workspace) +
i * c_partial_acc_block_m_n.GetElementSpaceSize(),
c_partial_acc_block_m_n.GetElementSpaceSize());
acc_load.Run(c_partial_acc_block_m_n,
c_partial_acc_buf,
acc_thread_buf_load_desc,
make_tuple(I0, I0),
parcial_acc_buf);
static_for<0, CShuffleBlockTransferScalarPerVector_NPerBlock, 1>{}(
[&](auto i_vec) {
constexpr auto offset =
acc_thread_buf_load_desc.CalculateOffset(
make_tuple(0, i_vec));
Accumulation::Calculate(acc_buf(Number<offset>{}),
parcial_acc_buf[Number<offset>{}]);
});
}
if(thread_n_cluster_id *
CShuffleBlockTransferScalarPerVector_NPerBlock <
NPerBlock)
{
acc_store.Run(acc_thread_buf_store_desc,
make_tuple(I0, I0, I0, I0),
acc_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
}
if constexpr(NReduceIters != 1)
{
if constexpr(i_n_reduce != (NReduceIters - 1))
{
acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
partial_acc_load_step_n);
acc_store.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock,
partial_acc_store_step_n);
}
else
{
acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
partial_acc_load_step_n_reverse);
acc_store.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock,
partial_acc_store_step_n_reverse);
}
}
});
{
acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
partial_acc_load_step_m);
acc_store.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock,
partial_acc_store_step_m);
}
}
continue;
}
}
// offset for last acc buffer of this block
uint32_t block_acc_offset =
(block_2_ctile_map_streamk.get_acc_buffer_offset_from_block(block_idx + 1) - 1) *
MPerBlock * NPerBlock;
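// Note (added for clarity): block_acc_offset addresses the last
// MPerBlock * NPerBlock partial-acc tile written by this stream-k block; it is
// decremented by one tile after each finished output tile further down, so earlier
// tiles of the same block occupy the preceding workspace slots.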
while(true)
{
uint32_t current_iter_length = __builtin_amdgcn_readfirstlane(
@@ -1173,33 +1459,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
iter_end - 1, tile_idx, iter_offset);
iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1);
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(problem.M,
problem.MPadded,
problem.K,
problem.KPadded,
problem.StrideA,
problem.AK0);
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(problem.K,
problem.KPadded,
problem.N,
problem.NPadded,
problem.StrideB,
problem.BK0);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
auto block_work_idx =
block_2_ctile_map_streamk.tile_to_spatial(tile_idx, problem.M, problem.N);
@@ -1363,11 +1622,20 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
constexpr auto c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle =
GetCBlockDescriptor_MShuffle_MPerShuffle_NShuffle_NPerShuffle();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<CShuffleDataType*>(p_shared),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
.GetElementSpaceSize());
auto c_partial_acc_buf =
make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
reinterpret_cast<AccDataType*>(p_workspace) + block_acc_offset,
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle
.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
@@ -1477,7 +1745,34 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_m_id, 0, block_n_id, 0),
c_element_op};
// LDS to global partial acc
auto c_block_copy_lds_to_partial_acc = ThreadGroupTensorSliceTransfer_v6r1r2<
ThisThreadBlock, // index_t BlockSize,
CElementwiseOperation, // ElementwiseOperation,
// InMemoryDataOperationEnum::Set, // DstInMemOp,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave *
NPerXdl>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
CShuffleDataType, // typename SrcData,
CShuffleDataType, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
false, // bool ThreadTransferSrcResetCoordinateAfterRun, => need to be
// false, otherwise has scratch
false> // bool ThreadTransferDstResetCoordinateAfterRun, => need to be
// false, otherwise has scratch
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
make_multi_index(0, 0, 0, 0),
c_element_op};
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
@@ -1535,15 +1830,40 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
}
else if(is_sk_block)
{
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Atomic)
{
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global
.template Run<decltype(c_shuffle_block_buf),
decltype(c_grid_buf),
InMemoryDataOperationEnum::AtomicAdd>(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
}
else if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
// constexpr offset
c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(0, 0, 0, 0));
c_block_copy_lds_to_partial_acc.SetDstSliceOrigin(
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
make_tuple(MXdlPerWave, 0, NXdlPerWave, 0));
c_block_copy_lds_to_partial_acc
.template Run<decltype(c_shuffle_block_buf),
decltype(c_partial_acc_buf),
InMemoryDataOperationEnum::Set>(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
c_partial_acc_buf);
}
}
}
if constexpr(access_id < num_access - 1)
@@ -1555,15 +1875,33 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
}
});
}
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
if(is_sk_block)
{
// increase the counter for this tile
workgroup_barrier wg_barrier(p_semaphore);
wg_barrier.inc(tile_idx);
}
}
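// Note (added for clarity): each stream-k block increments p_semaphore[tile_idx]
// only after its partial tile has been written to the workspace, and the matching
// reduction block waits in wg_barrier.wait_eq() until the counter equals the number
// of contributing partials, so the partial accumulators are presumably globally
// visible before they are reduced.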
} // shuffle c and write-out end
// exit condition
iter_end -= current_iter_length;
if(iter_end <= iter_start)
break;
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
block_acc_offset -= MPerBlock * NPerBlock;
}
// make sure next loop LDS is ready for use
block_sync_lds();
} // while loop
}
} // for loop
}
template <bool HasMainKBlockLoop,
@@ -1574,19 +1912,43 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
CDataType* p_c_grid,
void* p_shared_0,
void* p_shared_1,
Problem& problem,
void* p_workspace)
{
const AElementwiseOperation a_element_op{};
const BElementwiseOperation b_element_op{};
const CElementwiseOperation c_element_op{};
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
uint32_t iter_start, iter_end;
bool is_sk_block, is_dp_block, is_reduction_block;
index_t num_k_block_main_loop;
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
Block2CTileMap_streamk block_2_ctile_map_streamk(problem.M,
problem.N,
AK0Number * problem.KPadded,
problem.Grid_size,
problem.Streamk_sel);
for(auto block_idx = get_block_1d_id();
block_idx < block_2_ctile_map_streamk.get_grid_dims();
block_idx += gridDim.x)
@@ -1601,6 +1963,235 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end);
num_k_block_main_loop = iter_end - iter_start;
uint32_t* p_semaphore = reinterpret_cast<uint32_t*>(
reinterpret_cast<char*>(p_workspace) +
block_2_ctile_map_streamk.get_workspace_size_for_acc(sizeof(AccDataType)));
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
is_reduction_block = static_cast<uint32_t>(block_idx) >=
block_2_ctile_map_streamk.reduction_start_block_idx;
if(is_reduction_block)
{
// descriptors
constexpr auto cluster_length_reduce = GetClusterLengthReduction();
constexpr auto reduce_desc = make_cluster_descriptor(cluster_length_reduce);
const auto reduce_thread_cluster_idx =
reduce_desc.CalculateBottomIndex(make_multi_index(block_idx));
const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0];
const auto thread_n_cluster_id = reduce_thread_cluster_idx[I1];
constexpr auto MReduceIters = math::integer_divide_ceil(
Number<MPerBlock>{}, cluster_length_reduce.At(I0));
constexpr auto NReduceIters = math::integer_divide_ceil(
Number<NPerBlock>{},
cluster_length_reduce.At(I1) *
Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{});
constexpr auto acc_thread_buf_load_desc = make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{}));
constexpr auto acc_thread_buf_store_desc =
make_naive_tensor_descriptor_packed(make_tuple(
I1, I1, I1, Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{}));
constexpr auto c_partial_acc_block_m_n = GetPartialAccBlockDescriptor();
constexpr auto partial_acc_load_step_n =
make_multi_index(0,
cluster_length_reduce.At(I1) *
CShuffleBlockTransferScalarPerVector_NPerBlock);
constexpr auto partial_acc_load_step_n_reverse = make_multi_index(
0,
-1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
CShuffleBlockTransferScalarPerVector_NPerBlock);
constexpr auto partial_acc_load_step_m =
make_multi_index(cluster_length_reduce.At(I0), 0);
constexpr auto partial_acc_store_step_n =
make_multi_index(0,
0,
0,
cluster_length_reduce.At(I1) *
CShuffleBlockTransferScalarPerVector_NPerBlock);
constexpr auto partial_acc_store_step_n_reverse = make_multi_index(
0,
0,
0,
-1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
CShuffleBlockTransferScalarPerVector_NPerBlock);
constexpr auto partial_acc_store_step_m =
make_multi_index(0, cluster_length_reduce.At(I0), 0, 0);
StaticBuffer<AddressSpaceEnum::Vgpr,
AccDataType,
CShuffleBlockTransferScalarPerVector_NPerBlock,
true>
parcial_acc_buf;
StaticBuffer<AddressSpaceEnum::Vgpr,
AccDataType,
CShuffleBlockTransferScalarPerVector_NPerBlock,
true>
acc_buf;
// start to compute
auto reduction_idx =
block_idx - block_2_ctile_map_streamk.reduction_start_block_idx;
auto spatial_idx = block_2_ctile_map_streamk.tile_to_spatial(
reduction_idx, problem.M, problem.N);
workgroup_barrier wg_barrier(p_semaphore);
uint32_t tile_acc_offset_start =
block_2_ctile_map_streamk.get_acc_buffer_offset_from_tile(reduction_idx);
uint32_t tile_acc_offset_end =
block_2_ctile_map_streamk.get_acc_buffer_offset_from_tile(reduction_idx +
1);
uint32_t expected_count = tile_acc_offset_end - tile_acc_offset_start;
if(threadIdx.x == 0)
{
p_semaphore[reduction_idx] = 0;
}
__syncthreads();
auto acc_load = ThreadwiseTensorSliceTransfer_v2<
AccDataType, // SrcData,
AccDataType, // DstData,
decltype(c_partial_acc_block_m_n), // SrcDesc,
decltype(acc_thread_buf_load_desc), // DstDesc,
Sequence<1,
CShuffleBlockTransferScalarPerVector_NPerBlock>, // SliceLengths,
Sequence<0, 1>, // DimAccessOrder,
1, // SrcVectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // SrcScalarPerVector,
1, // SrcScalarStrideInVector,
false // SrcResetCoordinateAfterRun,
>{c_partial_acc_block_m_n,
make_multi_index(thread_m_cluster_id,
thread_n_cluster_id *
CShuffleBlockTransferScalarPerVector_NPerBlock)};
auto acc_store = ThreadwiseTensorSliceTransfer_v1r3<
AccDataType, // SrcData,
CDataType, // DstData,
decltype(acc_thread_buf_store_desc), // SrcDesc,
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), // DstDesc,
CElementwiseOperation, // ElementwiseOperation,
Sequence<1,
1,
1,
CShuffleBlockTransferScalarPerVector_NPerBlock>, // SliceLengths,
Sequence<0, 1, 2, 3>, // DimAccessOrder,
3, // DstVectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // DstScalarPerVector,
InMemoryDataOperationEnum::Set, // InMemoryDataOperationEnum DstInMemOp,
1, // DstScalarStrideInVector,
false // DstResetCoordinateAfterRun,
>{c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]),
thread_m_cluster_id,
__builtin_amdgcn_readfirstlane(spatial_idx[I1]),
thread_n_cluster_id *
CShuffleBlockTransferScalarPerVector_NPerBlock),
CElementwiseOperation{}};
#if 0
if(threadIdx.x == 0) {
printf("bid:%d, rid:%d, os:%d,%d, spatial:%d,%d\n", static_cast<int>(blockIdx.x),
reduction_idx, __builtin_amdgcn_readfirstlane(tile_acc_offset_start), __builtin_amdgcn_readfirstlane(tile_acc_offset_end),
__builtin_amdgcn_readfirstlane(spatial_idx[I0]),
__builtin_amdgcn_readfirstlane(spatial_idx[I1]));
}
#endif
if(threadIdx.x == 0)
{
atomicAdd(&p_semaphore[reduction_idx], 1);
}
wg_barrier.wait_eq(p_semaphore[reduction_idx], expected_count);
using Accumulation = ck::detail::
AccumulateWithNanCheck<false /*PropagateNan*/, reduce::Add, AccDataType>;
for(int i_m = 0; i_m < MReduceIters; i_m++)
{
static_for<0, NReduceIters, 1>{}([&](auto i_n_reduce) {
acc_buf.Clear();
for(auto i = tile_acc_offset_start; i < tile_acc_offset_end; i++)
{
auto c_partial_acc_buf =
make_dynamic_buffer<AddressSpaceEnum::Global,
AmdBufferCoherenceEnum::GLC>(
reinterpret_cast<AccDataType*>(p_workspace) +
i * c_partial_acc_block_m_n.GetElementSpaceSize(),
c_partial_acc_block_m_n.GetElementSpaceSize());
acc_load.Run(c_partial_acc_block_m_n,
c_partial_acc_buf,
acc_thread_buf_load_desc,
make_tuple(I0, I0),
parcial_acc_buf);
static_for<0, CShuffleBlockTransferScalarPerVector_NPerBlock, 1>{}(
[&](auto i_vec) {
constexpr auto offset =
acc_thread_buf_load_desc.CalculateOffset(
make_tuple(0, i_vec));
Accumulation::Calculate(acc_buf(Number<offset>{}),
parcial_acc_buf[Number<offset>{}]);
});
}
if(thread_n_cluster_id *
CShuffleBlockTransferScalarPerVector_NPerBlock <
NPerBlock)
{
acc_store.Run(acc_thread_buf_store_desc,
make_tuple(I0, I0, I0, I0),
acc_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
}
if constexpr(NReduceIters != 1)
{
if constexpr(i_n_reduce != (NReduceIters - 1))
{
acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
partial_acc_load_step_n);
acc_store.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock,
partial_acc_store_step_n);
}
else
{
acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
partial_acc_load_step_n_reverse);
acc_store.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock,
partial_acc_store_step_n_reverse);
}
}
});
{
acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
partial_acc_load_step_m);
acc_store.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock,
partial_acc_store_step_m);
}
}
continue;
}
}
// offset for last acc buffer of this block
uint32_t block_acc_offset =
(block_2_ctile_map_streamk.get_acc_buffer_offset_from_block(block_idx + 1) - 1) *
MPerBlock * NPerBlock;
while(true)
{
uint32_t current_iter_length = __builtin_amdgcn_readfirstlane(
@@ -1611,33 +2202,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
iter_end - 1, tile_idx, iter_offset);
iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1);
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(problem.M,
problem.MPadded,
problem.K,
problem.KPadded,
problem.StrideA,
problem.AK0);
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(problem.K,
problem.KPadded,
problem.N,
problem.NPadded,
problem.StrideB,
problem.BK0);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
auto block_work_idx =
block_2_ctile_map_streamk.tile_to_spatial(tile_idx, problem.M, problem.N);
@@ -1811,11 +2375,20 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
constexpr auto c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle =
GetCBlockDescriptor_MShuffle_MPerShuffle_NShuffle_NPerShuffle();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<CShuffleDataType*>(p_shared_0),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
.GetElementSpaceSize());
auto c_partial_acc_buf =
make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
reinterpret_cast<AccDataType*>(p_workspace) + block_acc_offset,
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle
.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
@@ -1925,6 +2498,35 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
make_multi_index(block_m_id, 0, block_n_id, 0),
c_element_op};
// LDS to global partial acc
auto c_block_copy_lds_to_partial_acc = ThreadGroupTensorSliceTransfer_v6r1r2<
ThisThreadBlock, // index_t BlockSize,
CElementwiseOperation, // ElementwiseOperation,
// InMemoryDataOperationEnum::Set, // DstInMemOp,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave *
NPerXdl>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
CShuffleDataType, // typename SrcData,
CShuffleDataType, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
false, // bool ThreadTransferSrcResetCoordinateAfterRun, => need to be
// false, otherwise has scratch
false> // bool ThreadTransferDstResetCoordinateAfterRun, => need to be
// false, otherwise has scratch
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
make_multi_index(0, 0, 0, 0),
c_element_op};
// space filling curve for threadwise C in VGPR // space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr = constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>, SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
...@@ -1982,15 +2584,40 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 ...@@ -1982,15 +2584,40 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
}
else if(is_sk_block)
{
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Atomic)
{
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global
.template Run<decltype(c_shuffle_block_buf),
decltype(c_grid_buf),
InMemoryDataOperationEnum::AtomicAdd>(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
}
else if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
// constexpr offset
c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(0, 0, 0, 0));
c_block_copy_lds_to_partial_acc.SetDstSliceOrigin(
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
make_tuple(MXdlPerWave, 0, NXdlPerWave, 0));
c_block_copy_lds_to_partial_acc
.template Run<decltype(c_shuffle_block_buf),
decltype(c_partial_acc_buf),
InMemoryDataOperationEnum::Set>(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
c_partial_acc_buf);
}
}
if constexpr(access_id < num_access - 1)
{
@@ -2002,6 +2629,27 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
}
});
}
// exit condition
iter_end -= current_iter_length;
if(iter_end <= iter_start)
break;
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
block_acc_offset -= MPerBlock * NPerBlock;
}
// make sure next loop LDS is ready for use
block_sync_lds();
}
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
if(is_sk_block)
{
// increase the counter for this tile
workgroup_barrier wg_barrier(p_semaphore);
wg_barrier.inc(0);
}
}
}
}
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -13,245 +13,614 @@
namespace ck {
namespace tensor_operation {
template <
index_t NDimSpatial,
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization ConvBwdDataSpecialization,
index_t AK1,
index_t BK1,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
bool DoPadGemmM,
bool DoPadGemmN,
typename ALayout,
typename BLayout,
typename CLayout,
bool SplitN = false,
typename ADataType = float,
typename CDataType = float,
index_t NumGroupsToMerge = 1,
typename IndexType = index_t>
struct TransformConvBwdDataToGemm_v1
{
private:
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto NonSpatialDimsNum = Number<3>{};
static constexpr auto DIdx = NonSpatialDimsNum;
static constexpr auto HIdx =
NDimSpatial == 2 ? NonSpatialDimsNum : Number<NonSpatialDimsNum + 1>{};
static constexpr auto WIdx =
NDimSpatial == 2 ? Number<NonSpatialDimsNum + 1>{} : Number<NonSpatialDimsNum + 2>{};
static constexpr auto ZIdx = NonSpatialDimsNum;
static constexpr auto YIdx =
NDimSpatial == 2 ? NonSpatialDimsNum : Number<NonSpatialDimsNum + 1>{};
static constexpr auto XIdx =
NDimSpatial == 2 ? Number<NonSpatialDimsNum + 1>{} : Number<NonSpatialDimsNum + 2>{};
template <typename ConvDimsType>
static long_index_t calculate_element_space_size_impl(const ConvDimsType& lengths,
const ConvDimsType& strides,
index_t i)
{
long_index_t acc = 1;
for(; i < (NDimSpatial + 3); i++)
{
acc +=
static_cast<long_index_t>(lengths[i] - I1) * static_cast<long_index_t>(strides[i]);
}
return acc;
} }
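// Note (added for clarity): starting at dimension i this returns
// 1 + sum_d (lengths[d] - 1) * strides[d], i.e. the offset of the last addressable
// element plus one; GetSplitedNSize below compares this element-space size, scaled
// by the element size, against the 2GB limit.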
template <typename ConvDimsType>
static IndexType GetSplitedNSize(const ConvDimsType& a_g_n_k_wos_lengths,
const ConvDimsType& a_g_n_k_wos_strides,
const ConvDimsType& c_g_n_c_wis_lengths,
const ConvDimsType& c_g_n_c_wis_strides)
{
const long_index_t a_element_space_size =
calculate_element_space_size_impl(a_g_n_k_wos_lengths, a_g_n_k_wos_strides, I1);
const long_index_t c_element_space_size =
calculate_element_space_size_impl(c_g_n_c_wis_lengths, c_g_n_c_wis_strides, I1);
const long_index_t element_space_size = math::max(a_element_space_size * sizeof(ADataType),
c_element_space_size * sizeof(CDataType));
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
const IndexType N = a_g_n_k_wos_lengths[I1];
if(element_space_size > TwoGB)
{ {
// Minimum divisor of N to not exceed 2GB
const auto divisor = math::integer_divide_ceil(element_space_size, TwoGB);
if(divisor <= static_cast<double>(N))
{
// Find least divisor of N larger than element_space_size / TwoGB
// Iterate up to sqrt(N). There are no divisors above this value.
for(IndexType least_divisor = divisor; least_divisor * least_divisor <= N;
least_divisor++)
{
if(N % least_divisor == 0)
{
return N / least_divisor;
}
}
// Not found, process one Convolution N per block
return 1;
}
else
{
// Not possible to support even after split N.
// Too large tensor.
return N;
}
}
else
{
// Split N is not needed.
return N;
}
}
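// Worked example (illustrative numbers only): for a roughly 5 GB tensor with
// N = 64, divisor = ceil(5 GB / 2 GB) = 3; the loop finds the smallest divisor of
// 64 that is >= 3, which is 4, and returns N / 4 = 16, i.e. the convolution is
// processed in 4 chunks of 16 images so each descriptor stays under 2 GB.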
public:
__host__ __device__ constexpr TransformConvBwdDataToGemm_v1() {}
template <typename TransformConvBwdDataToGemm_v1Base>
__host__ __device__ TransformConvBwdDataToGemm_v1(
const TransformConvBwdDataToGemm_v1Base& transform_conv_bwd_data_to_gemm_base)
: N_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.N_)},
Di_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.Di_)},
Hi_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.Hi_)},
Wi_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.Wi_)},
Do_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.Do_)},
Ho_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.Ho_)},
Wo_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.Wo_)},
Z_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.Z_)},
Y_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.Y_)},
X_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.X_)},
K_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.K_)},
C_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.C_)},
DiStride_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.DiStride_)},
HiStride_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.HiStride_)},
WiStride_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.WiStride_)},
DoStride_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.DoStride_)},
HoStride_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.HoStride_)},
WoStride_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.WoStride_)},
CStrideTensorB_{
static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.CStrideTensorB_)},
CStrideTensorC_{
static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.CStrideTensorC_)},
KStrideTensorA_{
static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.KStrideTensorA_)},
KStrideTensorB_{
static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.KStrideTensorB_)},
NStrideTensorA_{
static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.NStrideTensorA_)},
NStrideTensorC_{
static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.NStrideTensorC_)},
ConvStrideD_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ConvStrideD_)},
ConvStrideH_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ConvStrideH_)},
ConvStrideW_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ConvStrideW_)},
ConvDilationD_{
static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ConvDilationD_)},
ConvDilationH_{
static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ConvDilationH_)},
ConvDilationW_{
static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ConvDilationW_)},
InLeftPadD_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.InLeftPadD_)},
InLeftPadH_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.InLeftPadH_)},
InLeftPadW_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.InLeftPadW_)},
InRightPadD_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.InRightPadD_)},
InRightPadH_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.InRightPadH_)},
InRightPadW_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.InRightPadW_)},
IdxZTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.IdxZTilde_)},
IdxYTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.IdxYTilde_)},
IdxXTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.IdxXTilde_)},
GcdStrideDilationD_{
static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.GcdStrideDilationD_)},
GcdStrideDilationH_{
static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.GcdStrideDilationH_)},
GcdStrideDilationW_{
static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.GcdStrideDilationW_)},
ZTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ZTilde_)},
YTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.YTilde_)},
XTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.XTilde_)},
DTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.DTilde_)},
HTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.HTilde_)},
WTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.WTilde_)},
ZDot_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ZDot_)},
YDot_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.YDot_)},
XDot_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.XDot_)}
{
}

template <typename ConvDimsType, typename ConvSpatialDimsType>
__host__ __device__
TransformConvBwdDataToGemm_v1(const ConvDimsType& a_g_n_k_wos_lengths,
const ConvDimsType& a_g_n_k_wos_strides,
const ConvDimsType& b_g_k_c_xs_lengths,
const ConvDimsType& b_g_k_c_xs_strides,
const ConvDimsType& c_g_n_c_wis_lengths,
const ConvDimsType& c_g_n_c_wis_strides,
const ConvSpatialDimsType& conv_filter_strides,
const ConvSpatialDimsType& conv_filter_dilations,
const ConvSpatialDimsType& input_left_pads,
const ConvSpatialDimsType& input_right_pads,
const ConvSpatialDimsType& tildes)
: Hi_{c_g_n_c_wis_lengths[HIdx]},
Wi_{c_g_n_c_wis_lengths[WIdx]},
Ho_{a_g_n_k_wos_lengths[HIdx]},
Wo_{a_g_n_k_wos_lengths[WIdx]},
Y_{b_g_k_c_xs_lengths[YIdx]},
X_{b_g_k_c_xs_lengths[XIdx]},
K_{a_g_n_k_wos_lengths[I2]},
C_{b_g_k_c_xs_lengths[I2]},
HiStride_{c_g_n_c_wis_strides[HIdx]},
WiStride_{c_g_n_c_wis_strides[WIdx]},
HoStride_{a_g_n_k_wos_strides[HIdx]},
WoStride_{a_g_n_k_wos_strides[WIdx]},
CStrideTensorB_{b_g_k_c_xs_strides[I2]},
CStrideTensorC_{c_g_n_c_wis_strides[I2]},
KStrideTensorA_{a_g_n_k_wos_strides[I2]},
KStrideTensorB_{b_g_k_c_xs_strides[I1]},
NStrideTensorA_{a_g_n_k_wos_strides[I1]},
NStrideTensorC_{c_g_n_c_wis_strides[I1]},
ConvStrideH_{conv_filter_strides[HIdx - NonSpatialDimsNum]},
ConvStrideW_{conv_filter_strides[WIdx - NonSpatialDimsNum]},
ConvDilationH_{conv_filter_dilations[HIdx - NonSpatialDimsNum]},
ConvDilationW_{conv_filter_dilations[WIdx - NonSpatialDimsNum]},
InLeftPadH_{input_left_pads[HIdx - NonSpatialDimsNum]},
InLeftPadW_{input_left_pads[WIdx - NonSpatialDimsNum]},
InRightPadH_{input_right_pads[HIdx - NonSpatialDimsNum]},
InRightPadW_{input_right_pads[WIdx - NonSpatialDimsNum]},
IdxYTilde_{tildes[YIdx - NonSpatialDimsNum]},
IdxXTilde_{tildes[XIdx - NonSpatialDimsNum]}
{
static_assert(is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
is_same_v<ConvSpatialDimsType, ck::Array<IndexType, NDimSpatial>>);
static_assert(is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
is_same_v<ConvDimsType, ck::Array<IndexType, NDimSpatial + I3>>);
if constexpr(SplitN)
{
N_ = GetSplitedNSize(
a_g_n_k_wos_lengths, a_g_n_k_wos_strides, c_g_n_c_wis_lengths, c_g_n_c_wis_strides);
}
else
{
N_ = c_g_n_c_wis_lengths[I1];
}
if constexpr(NDimSpatial == 3)
{
Di_ = c_g_n_c_wis_lengths[DIdx];
Do_ = a_g_n_k_wos_lengths[DIdx];
Z_ = b_g_k_c_xs_lengths[ZIdx];
DiStride_ = c_g_n_c_wis_strides[DIdx];
DoStride_ = a_g_n_k_wos_strides[DIdx];
ConvStrideD_ = conv_filter_strides[DIdx - NonSpatialDimsNum];
ConvDilationD_ = conv_filter_dilations[DIdx - NonSpatialDimsNum];
InLeftPadD_ = input_left_pads[DIdx - NonSpatialDimsNum];
InRightPadD_ = input_right_pads[DIdx - NonSpatialDimsNum];
IdxZTilde_ = tildes[ZIdx - NonSpatialDimsNum];
GcdStrideDilationD_ = math::gcd(ConvStrideD_, ConvDilationD_);
ZTilde_ = ConvStrideD_ / GcdStrideDilationD_;
DTilde_ = Do_ + math::integer_divide_ceil(ConvDilationD_ * (Z_ - I1), ConvStrideD_);
ZDot_ = math::integer_divide_ceil(Z_, ZTilde_);
}
else
{
Di_ = Do_ = Z_ = ZTilde_ = ConvStrideD_ = DTilde_ = ZDot_ = 1;
InLeftPadD_ = InRightPadD_ = DiStride_ = DoStride_ = IdxZTilde_ = 0;
}

GcdStrideDilationH_ = math::gcd(ConvStrideH_, ConvDilationH_);
GcdStrideDilationW_ = math::gcd(ConvStrideW_, ConvDilationW_);

YTilde_ = ConvStrideH_ / GcdStrideDilationH_;
XTilde_ = ConvStrideW_ / GcdStrideDilationW_;

HTilde_ = Ho_ + math::integer_divide_ceil(ConvDilationH_ * (Y_ - I1), ConvStrideH_);
WTilde_ = Wo_ + math::integer_divide_ceil(ConvDilationW_ * (X_ - I1), ConvStrideW_);

YDot_ = math::integer_divide_ceil(Y_, YTilde_);
XDot_ = math::integer_divide_ceil(X_, XTilde_);
}
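// Example (hypothetical sizes): for ConvStrideH_ = 2, ConvDilationH_ = 3, Y_ = 3, Ho_ = 5
// the quantities computed above become
//   GcdStrideDilationH_ = gcd(2, 3)           = 1
//   YTilde_             = 2 / 1               = 2   (number of sub-GEMMs along the filter height)
//   YDot_               = ceil(3 / 2)         = 2   (filter taps covered by one sub-GEMM)
//   HTilde_             = 5 + ceil(3 * 2 / 2) = 8   (padded output extent seen by each sub-GEMM)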
#if 0 // Splitting the tensor is not supported for now
__host__ bool AreDescriptorsSmallerThan2GB() const
{
constexpr long_index_t TwoGB = (long_index_t{1} << 31);

const long_index_t in_desc_space_size =
I1 + (N_ - I1) * NStrideTensorC_ + (Di_ - I1) * DiStride_ + (Hi_ - I1) * HiStride_ +
(Wi_ - I1) * WiStride_ + (C_ - I1) * CStrideTensorC_;
const long_index_t out_desc_space_size =
I1 + (N_ - I1) * NStrideTensorA_ + (Do_ - I1) * DoStride_ + (Ho_ - I1) * HoStride_ +
(Wo_ - I1) * WoStride_ + (K_ - I1) * KStrideTensorA_;

bool is_a_descriptor_smaller_than_2GB = (out_desc_space_size * sizeof(ADataType)) <= TwoGB;
bool is_c_descriptor_smaller_than_2GB = (in_desc_space_size * sizeof(CDataType)) <= TwoGB;

return is_a_descriptor_smaller_than_2GB && is_c_descriptor_smaller_than_2GB;
}

__host__ auto SplitConvProblem(const ADataType* a_grid_ptr_base,
                               CDataType* c_grid_ptr_base) const
{
// Create copies
auto conv_to_gemm_transformer_left  = *this;
auto conv_to_gemm_transformer_right = *this;
IndexType a_right_offset = 0;
IndexType c_right_offset = 0;
// Calculate real filter size
const IndexType z_eff = (Z_ - 1) * ConvDilationD_ + 1;
const IndexType y_eff = (Y_ - 1) * ConvDilationH_ + 1;
const IndexType x_eff = (X_ - 1) * ConvDilationW_ + 1;
// Calculate start position in input for right tensor
const IndexType di_right_transformer_start_idx = (Do_ / 2) * ConvStrideD_;
const IndexType hi_right_transformer_start_idx = (Ho_ / 2) * ConvStrideH_;
const IndexType wi_right_transformer_start_idx = (Wo_ / 2) * ConvStrideW_;
// Calculate last position in input for left tensor
const IndexType di_left_transformer_end_idx = (Do_ / 2 - 1) * ConvStrideD_ + z_eff;
const IndexType hi_left_transformer_end_idx = (Ho_ / 2 - 1) * ConvStrideH_ + y_eff;
const IndexType wi_left_transformer_end_idx = (Wo_ / 2 - 1) * ConvStrideW_ + x_eff;
// Allow to split if whole left padding will be in left tensor and right padding in right
// tensor
const bool is_possible_to_split_d = Do_ != 1 &&
di_right_transformer_start_idx > InLeftPadD_ &&
di_left_transformer_end_idx <= (InLeftPadD_ + Di_);
const bool is_possible_to_split_h = Ho_ != 1 &&
hi_right_transformer_start_idx > InLeftPadH_ &&
hi_left_transformer_end_idx <= (InLeftPadH_ + Hi_);
const bool is_possible_to_split_w = Wo_ != 1 &&
wi_right_transformer_start_idx > InLeftPadW_ &&
wi_left_transformer_end_idx <= (InLeftPadW_ + Wi_);
if(is_possible_to_split_d)
{
// Apply new sizes
// Split output on half
conv_to_gemm_transformer_left.Do_ = Do_ / 2;
conv_to_gemm_transformer_right.Do_ = Do_ - Do_ / 2;
// Assign left padding to left convolution
conv_to_gemm_transformer_left.InLeftPadD_ = InLeftPadD_;
conv_to_gemm_transformer_right.InLeftPadD_ = 0;
// Assign right padding to right convolution
conv_to_gemm_transformer_left.InRightPadD_ = 0;
conv_to_gemm_transformer_right.InRightPadD_ = InRightPadD_;
// Calculate new input size
conv_to_gemm_transformer_left.Di_ = di_left_transformer_end_idx - InLeftPadD_;
conv_to_gemm_transformer_right.Di_ =
math::min(Di_ - (di_right_transformer_start_idx - InLeftPadD_),
(conv_to_gemm_transformer_right.Do_ - 1) * ConvStrideD_ + z_eff);
// Calculate offsets
a_right_offset = (Do_ / 2) * DoStride_;
c_right_offset = ((Do_ / 2) * ConvStrideD_ - InLeftPadD_) * DiStride_;
}
else if(is_possible_to_split_h)
{
conv_to_gemm_transformer_left.Ho_ = Ho_ / 2;
conv_to_gemm_transformer_right.Ho_ = Ho_ - Ho_ / 2;
conv_to_gemm_transformer_left.InLeftPadH_  = InLeftPadH_;
conv_to_gemm_transformer_right.InLeftPadH_ = 0;

conv_to_gemm_transformer_left.InRightPadH_  = 0;
conv_to_gemm_transformer_right.InRightPadH_ = InRightPadH_;

conv_to_gemm_transformer_left.Hi_ = hi_left_transformer_end_idx - InLeftPadH_;
conv_to_gemm_transformer_right.Hi_ =
math::min(Hi_ - (hi_right_transformer_start_idx - InLeftPadH_),
          (conv_to_gemm_transformer_right.Ho_ - 1) * ConvStrideH_ + y_eff);

a_right_offset = (Ho_ / 2) * HoStride_;
c_right_offset = ((Ho_ / 2) * ConvStrideH_ - InLeftPadH_) * HiStride_;
}
else if(is_possible_to_split_w)
{
conv_to_gemm_transformer_left.Wo_  = Wo_ / 2;
conv_to_gemm_transformer_right.Wo_ = Wo_ - Wo_ / 2;

conv_to_gemm_transformer_left.InLeftPadW_  = InLeftPadW_;
conv_to_gemm_transformer_right.InLeftPadW_ = 0;

conv_to_gemm_transformer_left.InRightPadW_  = 0;
conv_to_gemm_transformer_right.InRightPadW_ = InRightPadW_;

conv_to_gemm_transformer_left.Wi_ = wi_left_transformer_end_idx - InLeftPadW_;
conv_to_gemm_transformer_right.Wi_ =
math::min(Wi_ - (wi_right_transformer_start_idx - InLeftPadW_),
          (conv_to_gemm_transformer_right.Wo_ - 1) * ConvStrideW_ + x_eff);

a_right_offset = (Wo_ / 2) * WoStride_;
c_right_offset = ((Wo_ / 2) * ConvStrideW_ - InLeftPadW_) * WiStride_;
}
// Return the left transformer, the right transformer, the right offset into the output (A)
// and the right offset into the input (C)
return ck::make_tuple(conv_to_gemm_transformer_left,
                      conv_to_gemm_transformer_right,
                      a_grid_ptr_base + a_right_offset,
                      c_grid_ptr_base + c_right_offset);
}

__host__ auto SplitConvProblem(const ADataType* a_grid_ptr_base,
                               CDataType* c_grid_ptr_base) const
{
// Create copies
auto conv_to_gemm_transformer_left  = *this;
auto conv_to_gemm_transformer_right = *this;
IndexType a_right_offset = 0;
IndexType c_right_offset = 0;
// Calculate start position in input for right tensor
const IndexType do_right_transformer_start_idx = math::integer_divide_ceil((Di_ / 2) + InLeftPadD_ - ((Z_ - 1) * ConvDilationD_), ConvStrideD_);
const IndexType ho_right_transformer_start_idx = math::integer_divide_ceil((Hi_ / 2) + InLeftPadH_ - ((Y_ - 1) * ConvDilationH_), ConvStrideH_);
const IndexType wo_right_transformer_start_idx = math::integer_divide_ceil((Wi_ / 2) + InLeftPadW_ - ((X_ - 1) * ConvDilationW_), ConvStrideW_);
// Calculate last position in input for left tensor
const IndexType do_left_transformer_end_idx = math::integer_divide_ceil((Di_ / 2 - 1) + InLeftPadD_, ConvStrideD_);
const IndexType ho_left_transformer_end_idx = math::integer_divide_ceil((Hi_ / 2 - 1) + InLeftPadH_, ConvStrideH_);
const IndexType wo_left_transformer_end_idx = math::integer_divide_ceil((Wi_ / 2 - 1) + InLeftPadW_, ConvStrideW_);
if(Di_!=1)
{
// Apply new sizes
// Split output on half
conv_to_gemm_transformer_left.Di_ = Di_ / 2;
conv_to_gemm_transformer_right.Di_ = Di_ - Di_ / 2;
// Assign left padding to left convolution
conv_to_gemm_transformer_left.InLeftPadD_ = InLeftPadD_;
conv_to_gemm_transformer_right.InLeftPadD_ = 0;
// Assign right padding to right convolution
conv_to_gemm_transformer_left.InRightPadD_ = 0;
conv_to_gemm_transformer_right.InRightPadD_ = InRightPadD_;
// Calculate new input size
conv_to_gemm_transformer_left.Do_ = do_left_transformer_end_idx;
conv_to_gemm_transformer_right.Do_ = Do_ - do_right_transformer_start_idx;
// Calculate offsets
a_right_offset = do_right_transformer_start_idx * DoStride_;
c_right_offset = (Di_ / 2) * DiStride_;
}
else if(Hi_!=1)
{
// Apply new sizes
// Split output on half
conv_to_gemm_transformer_left.Hi_ = Hi_ / 2;
conv_to_gemm_transformer_right.Hi_ = Hi_ - Hi_ / 2;
// Assign left padding to left convolution
conv_to_gemm_transformer_left.InLeftPadH_ = InLeftPadH_;
conv_to_gemm_transformer_right.InLeftPadH_ = 0;
// Assign right padding to right convolution
conv_to_gemm_transformer_left.InRightPadH_ = 0;
conv_to_gemm_transformer_right.InRightPadH_ = InRightPadH_;
// Calculate new input size
conv_to_gemm_transformer_left.Ho_ = ho_left_transformer_end_idx;
conv_to_gemm_transformer_right.Ho_ = Ho_ - ho_right_transformer_start_idx;
// Calculate offsets
a_right_offset = ho_right_transformer_start_idx * HoStride_;
c_right_offset = (Hi_ / 2) * HiStride_;
}
else if(Wi_!=1)
{
// Apply new sizes
// Split output on half
conv_to_gemm_transformer_left.Wi_ = Wi_ / 2;
conv_to_gemm_transformer_right.Wi_ = Wi_ - Wi_ / 2;
// Assign left padding to left convolution
conv_to_gemm_transformer_left.InLeftPadW_ = InLeftPadW_;
conv_to_gemm_transformer_right.InLeftPadW_ = 0;
// Assign right padding to right convolution
conv_to_gemm_transformer_left.InRightPadW_ = 0;
conv_to_gemm_transformer_right.InRightPadW_ = InRightPadW_;
// Calculate new input size
conv_to_gemm_transformer_left.Wo_ = wo_left_transformer_end_idx;
conv_to_gemm_transformer_right.Wo_ = Wo_ - wo_right_transformer_start_idx;
// Calculate offsets
a_right_offset = wo_right_transformer_start_idx * WoStride_;
c_right_offset = (Wi_ / 2) * WiStride_;
}
// Return left transform, right transformer, right offset to Input and right offset to
// Output
return ck::make_tuple(conv_to_gemm_transformer_left,
conv_to_gemm_transformer_right,
a_grid_ptr_base + a_right_offset,
c_grid_ptr_base + c_right_offset);
}
#endif
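// Sketch (assumed form, mirroring the disabled check above): the 2 GB limit bounds the offset of
// the last addressable element of a strided descriptor, e.g. for the 2D input tensor
//
//   const long_index_t last = I1 + (N_ - I1) * NStrideTensorC_ + (Hi_ - I1) * HiStride_ +
//                             (Wi_ - I1) * WiStride_ + (C_ - I1) * CStrideTensorC_;
//   const bool fits = last * static_cast<long_index_t>(sizeof(CDataType)) <= (long_index_t{1} << 31);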
__host__ __device__ auto MakeOutGridDesc() const
{
if constexpr(is_same_v<ALayout, tensor_layout::convolution::NHWGK>)
{
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{
return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_),
                                    make_tuple(WoStride_, KStrideTensorA_));
}
else
{
return make_naive_tensor_descriptor(
make_tuple(N_, Ho_, Wo_, K_),
make_tuple(NStrideTensorA_, HoStride_, WoStride_, KStrideTensorA_));
}
}
else if constexpr(is_same_v<ALayout, tensor_layout::convolution::NDHWGK>)
{
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{
return make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, K_),
                                    make_tuple(WoStride_, KStrideTensorA_));
}
else
{
return make_naive_tensor_descriptor(
make_tuple(N_, Do_, Ho_, Wo_, K_),
make_tuple(NStrideTensorA_, DoStride_, HoStride_, WoStride_, KStrideTensorA_));
}
}
else if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNHWK>)
{
// assume packed
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{
return make_naive_tensor_descriptor_packed(make_tuple(N_ * Ho_ * Wo_, K_));
}
else
{
return make_naive_tensor_descriptor_packed(make_tuple(N_, Ho_, Wo_, K_));
}
}
else if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNDHWK>)
{
// assume packed
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{
return make_naive_tensor_descriptor_packed(make_tuple(N_ * Do_ * Ho_ * Wo_, K_));
}
else
{
return make_naive_tensor_descriptor_packed(make_tuple(N_, Do_, Ho_, Wo_, K_));
}
}
else
{
throw std::runtime_error("wrong! unsupported layout: " + ALayout::name());
}
}
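// Note: for the strided NHWGK / NDHWGK branches above, the returned descriptor addresses element
// (n, ho, wo, k) at offset n * NStrideTensorA_ + ho * HoStride_ + wo * WoStride_ + k * KStrideTensorA_,
// which is exactly what make_naive_tensor_descriptor encodes through the explicit stride tuple.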
__host__ __device__ auto MakeWeiGridDesc() const
{
if constexpr(is_same_v<BLayout, tensor_layout::convolution::GKYXC>)
{
return make_naive_tensor_descriptor_packed(make_tuple(K_, Y_, X_, C_));
}
else if constexpr(is_same_v<BLayout, tensor_layout::convolution::GKZYXC>)
{
return make_naive_tensor_descriptor_packed(make_tuple(K_, Z_, Y_, X_, C_));
}
else
{
throw std::runtime_error("wrong! unsupported layout: " + BLayout::name());
}
}
__host__ __device__ auto MakeInGridDesc() const
{
if constexpr(is_same_v<CLayout, tensor_layout::convolution::GNHWC> ||
             is_same_v<CLayout, tensor_layout::convolution::NHWGC> ||
             is_same_v<CLayout, tensor_layout::convolution::G_NHW_C>)
{
return make_naive_tensor_descriptor(
make_tuple(N_, Hi_, Wi_, C_),
make_tuple(NStrideTensorC_, HiStride_, WiStride_, CStrideTensorC_));
}
else if constexpr(is_same_v<CLayout, tensor_layout::convolution::GNDHWC> ||
is_same_v<CLayout, tensor_layout::convolution::NDHWGC>)
{
return make_naive_tensor_descriptor(
make_tuple(N_, Di_, Hi_, Wi_, C_),
make_tuple(NStrideTensorC_, DiStride_, HiStride_, WiStride_, CStrideTensorC_));
}
else
{
throw std::runtime_error("wrong! unsupported layout: " + CLayout::name());
}
}
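// Example (hypothetical NDHWGC tensor): with N_ = 2, Di_ = Hi_ = Wi_ = 4, C_ = 8 and 3 groups,
// the channel-innermost layout gives CStrideTensorC_ = 1, WiStride_ = 3 * 8 = 24,
// HiStride_ = 4 * 24 = 96, DiStride_ = 4 * 96 = 384 and NStrideTensorC_ = 4 * 384 = 1536,
// which matches the stride tuple the NDHWGC branch above passes to make_naive_tensor_descriptor.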
template <
typename ALayout_ = ALayout,
typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) &&
(is_same_v<ALayout_, tensor_layout::convolution::GNHWK> ||
is_same_v<ALayout_, tensor_layout::convolution::GNDHWK> ||
is_same_v<ALayout_, tensor_layout::convolution::NHWGK> ||
is_same_v<ALayout_, tensor_layout::convolution::NDHWGK>),
bool>::type = false>
__host__ __device__ auto MakeADescriptor_AK0_M_AK1() const
{
// n_do_ho_wo_k for 3d or n_ho_wo_k for 2d
const auto out_grid_desc = MakeOutGridDesc();
if constexpr(ConvBwdDataSpecialization == if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0) Filter1x1Stride1Pad0)
{ {
const index_t AK0 = math::integer_divide_ceil(K, AK1); const index_t AK0 = math::integer_divide_ceil(K_, AK1);
// A: output tensor // A: output tensor
const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor( const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor(
out_grid_desc, out_grid_desc,
make_tuple(make_pass_through_transform(N * Do * Ho * Wo), make_tuple(make_pass_through_transform(N_ * Do_ * Ho_ * Wo_),
make_unmerge_transform(make_tuple(AK0, AK1))), make_unmerge_transform(make_tuple(AK0, AK1))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{})); make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
...@@ -266,82 +635,63 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -266,82 +635,63 @@ struct TransformConvBwdDataToGemm_v1
} }
else else
{ {
const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto ZTilde = ConvStrideD / GcdStrideDilationD;
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;
const auto ZDot = math::integer_divide_ceil(Z, ZTilde);
const auto YDot = math::integer_divide_ceil(Y, YTilde);
const auto XDot = math::integer_divide_ceil(X, XTilde);
const auto DTilde =
Do + math::integer_divide_ceil(ConvDilationD * (Z - I1), ConvStrideD);
const auto HTilde =
Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
const auto WTilde =
Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// only work on HTilde and WTilde that contribute to non-padding area of input tensor // only work on HTilde and WTilde that contribute to non-padding area of input tensor
const auto IDTildeSliceBegin = math::integer_divide_floor( const auto IDTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadD - ConvDilationD * (ZTilde - I1)), ConvStrideD); math::max(I0, InLeftPadD_ - ConvDilationD_ * (ZTilde_ - I1)), ConvStrideD_);
const auto IHTildeSliceBegin = math::integer_divide_floor( const auto IHTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH); math::max(I0, InLeftPadH_ - ConvDilationH_ * (YTilde_ - I1)), ConvStrideH_);
const auto IWTildeSliceBegin = math::integer_divide_floor( const auto IWTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); math::max(I0, InLeftPadW_ - ConvDilationW_ * (XTilde_ - I1)), ConvStrideW_);
const auto IDTildeSliceEnd = math::min( const auto IDTildeSliceEnd = math::min(
DTilde, math::integer_divide_ceil(InLeftPadD + Di - I1, ConvStrideD) + I1); DTilde_, math::integer_divide_ceil(InLeftPadD_ + Di_ - I1, ConvStrideD_) + I1);
const auto IHTildeSliceEnd = math::min( const auto IHTildeSliceEnd = math::min(
HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); HTilde_, math::integer_divide_ceil(InLeftPadH_ + Hi_ - I1, ConvStrideH_) + I1);
const auto IWTildeSliceEnd = math::min( const auto IWTildeSliceEnd = math::min(
WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); WTilde_, math::integer_divide_ceil(InLeftPadW_ + Wi_ - I1, ConvStrideW_) + I1);
const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin; const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin;
const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
// GemmK is different for each GEMM // GemmK is different for each GEMM
const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde); const auto ZDotSlice = math::integer_divide_ceil(Z_ - IdxZTilde_, ZTilde_);
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); const auto YDotSlice = math::integer_divide_ceil(Y_ - IdxYTilde_, YTilde_);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); const auto XDotSlice = math::integer_divide_ceil(X_ - IdxXTilde_, XTilde_);
if constexpr(NDimSpatial == 2) if constexpr(NDimSpatial == 2)
{ {
// A: output tensor // A: output tensor
const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
out_grid_desc, out_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N_),
make_pad_transform(Ho, I0, I0), make_pad_transform(Ho_, I0, I0),
make_pad_transform(Wo, I0, I0), make_pad_transform(Wo_, I0, I0),
make_pass_through_transform(K)), make_pass_through_transform(K_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
out_n_hop_wop_k_grid_desc, out_n_hop_wop_k_grid_desc,
make_tuple( make_tuple(
make_pass_through_transform(N), make_pass_through_transform(N_),
make_embed_transform(make_tuple(YDot, HTilde), make_embed_transform(make_tuple(YDot_, HTilde_),
make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), make_tuple(-ConvDilationH_ / GcdStrideDilationH_, I1)),
make_embed_transform(make_tuple(XDot, WTilde), make_embed_transform(make_tuple(XDot_, WTilde_),
make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), make_tuple(-ConvDilationW_ / GcdStrideDilationW_, I1)),
make_pass_through_transform(K)), make_pass_through_transform(K_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc = const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc =
transform_tensor_descriptor( transform_tensor_descriptor(
out_n_ydot_htilde_xdot_wtilde_k_grid_desc, out_n_ydot_htilde_xdot_wtilde_k_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N_),
make_slice_transform(YDot, I0, YDotSlice), make_slice_transform(YDot_, I0, YDotSlice),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice),
make_slice_transform(XDot, I0, XDotSlice), make_slice_transform(XDot_, I0, XDotSlice),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(K)), make_pass_through_transform(K_)),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
Sequence<2>{}, Sequence<2>{},
...@@ -357,8 +707,8 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -357,8 +707,8 @@ struct TransformConvBwdDataToGemm_v1
const auto out_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor( const auto out_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor(
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc, out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K)), make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K_)),
make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice))), make_merge_transform(make_tuple(N_, HTildeSlice, WTildeSlice))),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
...@@ -385,11 +735,11 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -385,11 +735,11 @@ struct TransformConvBwdDataToGemm_v1
// A: output tensor // A: output tensor
const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
out_grid_desc, out_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N_),
make_pad_transform(Do, I0, I0), make_pad_transform(Do_, I0, I0),
make_pad_transform(Ho, I0, I0), make_pad_transform(Ho_, I0, I0),
make_pad_transform(Wo, I0, I0), make_pad_transform(Wo_, I0, I0),
make_pass_through_transform(K)), make_pass_through_transform(K_)),
make_tuple( make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple( make_tuple(
...@@ -398,17 +748,17 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -398,17 +748,17 @@ struct TransformConvBwdDataToGemm_v1
const auto out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc = const auto out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc =
transform_tensor_descriptor( transform_tensor_descriptor(
out_n_hop_wop_k_grid_desc, out_n_hop_wop_k_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N_),
make_embed_transform( make_embed_transform(
make_tuple(ZDot, DTilde), make_tuple(ZDot_, DTilde_),
make_tuple(-ConvDilationD / GcdStrideDilationD, I1)), make_tuple(-ConvDilationD_ / GcdStrideDilationD_, I1)),
make_embed_transform( make_embed_transform(
make_tuple(YDot, HTilde), make_tuple(YDot_, HTilde_),
make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), make_tuple(-ConvDilationH_ / GcdStrideDilationH_, I1)),
make_embed_transform( make_embed_transform(
make_tuple(XDot, WTilde), make_tuple(XDot_, WTilde_),
make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), make_tuple(-ConvDilationW_ / GcdStrideDilationW_, I1)),
make_pass_through_transform(K)), make_pass_through_transform(K_)),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
Sequence<2>{}, Sequence<2>{},
...@@ -424,14 +774,15 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -424,14 +774,15 @@ struct TransformConvBwdDataToGemm_v1
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc = out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc =
transform_tensor_descriptor( transform_tensor_descriptor(
out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc, out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(
make_slice_transform(ZDot, I0, ZDotSlice), make_pass_through_transform(N_),
make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice), make_slice_transform(ZDot_, I0, ZDotSlice),
make_slice_transform(YDot, I0, YDotSlice), make_slice_transform(DTilde_, IDTildeSliceBegin, DTildeSlice),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), make_slice_transform(YDot_, I0, YDotSlice),
make_slice_transform(XDot, I0, XDotSlice), make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), make_slice_transform(XDot_, I0, XDotSlice),
make_pass_through_transform(K)), make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(K_)),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
Sequence<2>{}, Sequence<2>{},
...@@ -452,8 +803,9 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -452,8 +803,9 @@ struct TransformConvBwdDataToGemm_v1
const auto out_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor( const auto out_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor(
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc, out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc,
make_tuple( make_tuple(
make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K)), make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K_)),
make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice))), make_merge_transform(
make_tuple(N_, DTildeSlice, HTildeSlice, WTildeSlice))),
make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}), make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
...@@ -482,66 +834,31 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -482,66 +834,31 @@ struct TransformConvBwdDataToGemm_v1
} }
} }
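// Example (hypothetical 3x3 filter, stride 2, dilation 1): GcdStrideDilationH_ = gcd(2, 1) = 1,
// so YTilde_ = 2 and the two sub-GEMMs selected by IdxYTilde_ = 0 and 1 reduce over
//   YDotSlice = ceil((3 - 0) / 2) = 2   and   YDotSlice = ceil((3 - 1) / 2) = 1
// filter taps respectively; the descriptor built above merges those taps with K_ into GemmK.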
template <typename BLayout_ = BLayout,
          typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) &&
                                      (is_same_v<BLayout_, tensor_layout::convolution::GKYXC> ||
                                       is_same_v<BLayout_, tensor_layout::convolution::GKZYXC>),
                                  bool>::type = false>
__host__ __device__ auto MakeBDescriptor_BK0_N_BK1() const
{
// assume packed
// k_y_x_c for 2d or k_z_y_x_c for 3d
const auto wei_grid_desc = MakeWeiGridDesc();
if constexpr(ConvBwdDataSpecialization == if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0) Filter1x1Stride1Pad0)
{ {
const index_t BK0 = math::integer_divide_ceil(K, BK1); const index_t BK0 = math::integer_divide_ceil(K_, BK1);
// B: weight tensor // B: weight tensor
const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc = const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc =
transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)), transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K_, C_)),
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(C)), make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, C), make_tuple(I0, I1)); make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, C_), make_tuple(I0, I1));
const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc = const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor( ck::tensor_operation::device::PadTensorDescriptor(
...@@ -553,22 +870,10 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -553,22 +870,10 @@ struct TransformConvBwdDataToGemm_v1
} }
else else
{ {
const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto ZTilde = ConvStrideD / GcdStrideDilationD;
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;
const auto ZDot = math::integer_divide_ceil(Z, ZTilde);
const auto YDot = math::integer_divide_ceil(Y, YTilde);
const auto XDot = math::integer_divide_ceil(X, XTilde);
// GemmK is different for each GEMM // GemmK is different for each GEMM
const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde); const auto ZDotSlice = math::integer_divide_ceil(Z_ - IdxZTilde_, ZTilde_);
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); const auto YDotSlice = math::integer_divide_ceil(Y_ - IdxYTilde_, YTilde_);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); const auto XDotSlice = math::integer_divide_ceil(X_ - IdxXTilde_, XTilde_);
// B weight tensor // B weight tensor
if constexpr(NDimSpatial == 2) if constexpr(NDimSpatial == 2)
...@@ -576,23 +881,23 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -576,23 +881,23 @@ struct TransformConvBwdDataToGemm_v1
const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
wei_grid_desc, wei_grid_desc,
make_tuple( make_tuple(
make_pass_through_transform(K), make_pass_through_transform(K_),
make_embed_transform(make_tuple(YDot, YTilde), make_embed_transform(make_tuple(YDot_, YTilde_),
make_tuple(ConvStrideH / GcdStrideDilationH, I1)), make_tuple(ConvStrideH_ / GcdStrideDilationH_, I1)),
make_embed_transform(make_tuple(XDot, XTilde), make_embed_transform(make_tuple(XDot_, XTilde_),
make_tuple(ConvStrideW / GcdStrideDilationW, I1)), make_tuple(ConvStrideW_ / GcdStrideDilationW_, I1)),
make_pass_through_transform(C)), make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto wei_k_ydotslice_xdotslice_c_grid_desc = transform_tensor_descriptor( const auto wei_k_ydotslice_xdotslice_c_grid_desc = transform_tensor_descriptor(
wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
make_tuple(make_pass_through_transform(K), make_tuple(make_pass_through_transform(K_),
make_slice_transform(YDot, I0, YDotSlice), make_slice_transform(YDot_, I0, YDotSlice),
make_slice_transform(XDot, I0, XDotSlice), make_slice_transform(XDot_, I0, XDotSlice),
make_freeze_transform(i_ytilde), make_freeze_transform(IdxYTilde_),
make_freeze_transform(i_xtilde), make_freeze_transform(IdxXTilde_),
make_pass_through_transform(C)), make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
Sequence<3>{}, Sequence<3>{},
...@@ -608,8 +913,8 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -608,8 +913,8 @@ struct TransformConvBwdDataToGemm_v1
const auto wei_gemmk_gemmnraw_grid_desc = transform_tensor_descriptor( const auto wei_gemmk_gemmnraw_grid_desc = transform_tensor_descriptor(
wei_k_ydotslice_xdotslice_c_grid_desc, wei_k_ydotslice_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K)), make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K_)),
make_pass_through_transform(C)), make_pass_through_transform(C_)),
make_tuple(Sequence<1, 2, 0>{}, Sequence<3>{}), make_tuple(Sequence<1, 2, 0>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
...@@ -636,15 +941,17 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -636,15 +941,17 @@ struct TransformConvBwdDataToGemm_v1
const auto wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc = const auto wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc =
transform_tensor_descriptor( transform_tensor_descriptor(
wei_grid_desc, wei_grid_desc,
make_tuple( make_tuple(make_pass_through_transform(K_),
make_pass_through_transform(K), make_embed_transform(
make_embed_transform(make_tuple(ZDot, ZTilde), make_tuple(ZDot_, ZTilde_),
make_tuple(ConvStrideD / GcdStrideDilationD, I1)), make_tuple(ConvStrideD_ / GcdStrideDilationD_, I1)),
make_embed_transform(make_tuple(YDot, YTilde), make_embed_transform(
make_tuple(ConvStrideH / GcdStrideDilationH, I1)), make_tuple(YDot_, YTilde_),
make_embed_transform(make_tuple(XDot, XTilde), make_tuple(ConvStrideH_ / GcdStrideDilationH_, I1)),
make_tuple(ConvStrideW / GcdStrideDilationW, I1)), make_embed_transform(
make_pass_through_transform(C)), make_tuple(XDot_, XTilde_),
make_tuple(ConvStrideW_ / GcdStrideDilationW_, I1)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
Sequence<2>{}, Sequence<2>{},
...@@ -659,14 +966,14 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -659,14 +966,14 @@ struct TransformConvBwdDataToGemm_v1
const auto wei_gemmk_zdotslice_ydotslice_xdotslice_c_grid_desc = const auto wei_gemmk_zdotslice_ydotslice_xdotslice_c_grid_desc =
transform_tensor_descriptor( transform_tensor_descriptor(
wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc, wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc,
make_tuple(make_pass_through_transform(K), make_tuple(make_pass_through_transform(K_),
make_slice_transform(ZDot, I0, ZDotSlice), make_slice_transform(ZDot_, I0, ZDotSlice),
make_slice_transform(YDot, I0, YDotSlice), make_slice_transform(YDot_, I0, YDotSlice),
make_slice_transform(XDot, I0, XDotSlice), make_slice_transform(XDot_, I0, XDotSlice),
make_freeze_transform(i_ztilde), make_freeze_transform(IdxZTilde_),
make_freeze_transform(i_ytilde), make_freeze_transform(IdxYTilde_),
make_freeze_transform(i_xtilde), make_freeze_transform(IdxXTilde_),
make_pass_through_transform(C)), make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
Sequence<3>{}, Sequence<3>{},
...@@ -686,8 +993,9 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -686,8 +993,9 @@ struct TransformConvBwdDataToGemm_v1
const auto wei_gemmk_gemmnraw_grid_desc = transform_tensor_descriptor( const auto wei_gemmk_gemmnraw_grid_desc = transform_tensor_descriptor(
wei_gemmk_zdotslice_ydotslice_xdotslice_c_grid_desc, wei_gemmk_zdotslice_ydotslice_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K)), make_tuple(
make_pass_through_transform(C)), make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K_)),
make_pass_through_transform(C_)),
make_tuple(Sequence<1, 2, 3, 0>{}, Sequence<4>{}), make_tuple(Sequence<1, 2, 3, 0>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
...@@ -716,66 +1024,20 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -716,66 +1024,20 @@ struct TransformConvBwdDataToGemm_v1
} }
} }
template <
    typename CLayout_ = CLayout,
    typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) &&
                                (is_same_v<CLayout_, tensor_layout::convolution::GNHWC> ||
                                 is_same_v<CLayout_, tensor_layout::convolution::GNDHWC> ||
                                 is_same_v<CLayout_, tensor_layout::convolution::NHWGC> ||
                                 is_same_v<CLayout_, tensor_layout::convolution::NDHWGC> ||
                                 is_same_v<CLayout_, tensor_layout::convolution::G_NHW_C>),
                            bool>::type = false>
__host__ __device__ auto MakeCDescriptor_M_N() const
{
// assume strided
// n_hi_wi_c for 2d, n_di_hi_wi_c for 3d
const auto in_grid_desc = MakeInGridDesc();
if constexpr(ConvBwdDataSpecialization == if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
...@@ -787,10 +1049,10 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -787,10 +1049,10 @@ struct TransformConvBwdDataToGemm_v1
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_grid_desc, in_grid_desc,
make_tuple( make_tuple(
make_pass_through_transform(N), make_pass_through_transform(N_),
make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)), make_embed_transform(make_tuple(I1, Ho_), make_tuple(I1, ConvStrideH_)),
make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)), make_embed_transform(make_tuple(I1, Wo_), make_tuple(I1, ConvStrideW_)),
make_pass_through_transform(C)), make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
...@@ -798,8 +1060,8 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -798,8 +1060,8 @@ struct TransformConvBwdDataToGemm_v1
in_n_y_ho_x_wo_c_grid_desc, in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_freeze_transform(I0), make_tuple(make_freeze_transform(I0),
make_freeze_transform(I0), make_freeze_transform(I0),
make_merge_transform(make_tuple(N, Ho, Wo)), make_merge_transform(make_tuple(N_, Ho_, Wo_)),
make_pass_through_transform(C)), make_pass_through_transform(C_)),
make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}), make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}),
make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
...@@ -818,11 +1080,11 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -818,11 +1080,11 @@ struct TransformConvBwdDataToGemm_v1
const auto in_n_x_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( const auto in_n_x_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_grid_desc, in_grid_desc,
make_tuple( make_tuple(
make_pass_through_transform(N), make_pass_through_transform(N_),
make_embed_transform(make_tuple(I1, Do), make_tuple(I1, ConvStrideD)), make_embed_transform(make_tuple(I1, Do_), make_tuple(I1, ConvStrideD_)),
make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)), make_embed_transform(make_tuple(I1, Ho_), make_tuple(I1, ConvStrideH_)),
make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)), make_embed_transform(make_tuple(I1, Wo_), make_tuple(I1, ConvStrideW_)),
make_pass_through_transform(C)), make_pass_through_transform(C_)),
make_tuple( make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
...@@ -836,8 +1098,8 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -836,8 +1098,8 @@ struct TransformConvBwdDataToGemm_v1
make_tuple(make_freeze_transform(I0), make_tuple(make_freeze_transform(I0),
make_freeze_transform(I0), make_freeze_transform(I0),
make_freeze_transform(I0), make_freeze_transform(I0),
make_merge_transform(make_tuple(N, Do, Ho, Wo)), make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)),
make_pass_through_transform(C)), make_pass_through_transform(C_)),
make_tuple(Sequence<1>{}, make_tuple(Sequence<1>{},
Sequence<3>{}, Sequence<3>{},
Sequence<5>{}, Sequence<5>{},
...@@ -861,36 +1123,21 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -861,36 +1123,21 @@ struct TransformConvBwdDataToGemm_v1
} }
else else
{ {
const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto ZTilde = ConvStrideD / GcdStrideDilationD;
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;
const auto DTilde =
Do + math::integer_divide_ceil(ConvDilationD * (Z - I1), ConvStrideD);
const auto HTilde =
Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
const auto WTilde =
Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// only work on DTilde, HTilde and WTilde that contribute to // only work on DTilde, HTilde and WTilde that contribute to
// non-padding area of input tensor // non-padding area of input tensor
const auto IDTildeSliceBegin = math::integer_divide_floor( const auto IDTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadD - ConvDilationD * (ZTilde - I1)), ConvStrideD); math::max(I0, InLeftPadD_ - ConvDilationD_ * (ZTilde_ - I1)), ConvStrideD_);
const auto IHTildeSliceBegin = math::integer_divide_floor( const auto IHTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH); math::max(I0, InLeftPadH_ - ConvDilationH_ * (YTilde_ - I1)), ConvStrideH_);
const auto IWTildeSliceBegin = math::integer_divide_floor( const auto IWTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); math::max(I0, InLeftPadW_ - ConvDilationW_ * (XTilde_ - I1)), ConvStrideW_);
const auto IDTildeSliceEnd = math::min( const auto IDTildeSliceEnd = math::min(
DTilde, math::integer_divide_ceil(InLeftPadD + Di - I1, ConvStrideD) + I1); DTilde_, math::integer_divide_ceil(InLeftPadD_ + Di_ - I1, ConvStrideD_) + I1);
const auto IHTildeSliceEnd = math::min( const auto IHTildeSliceEnd = math::min(
HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); HTilde_, math::integer_divide_ceil(InLeftPadH_ + Hi_ - I1, ConvStrideH_) + I1);
const auto IWTildeSliceEnd = math::min( const auto IWTildeSliceEnd = math::min(
WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); WTilde_, math::integer_divide_ceil(InLeftPadW_ + Wi_ - I1, ConvStrideW_) + I1);
const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin; const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin;
const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
...@@ -901,34 +1148,34 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -901,34 +1148,34 @@ struct TransformConvBwdDataToGemm_v1
{ {
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_grid_desc, in_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N_),
make_pad_transform(Hi, InLeftPadH, InRightPadH), make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
make_pad_transform(Wi, InLeftPadW, InRightPadW), make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(C)), make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc =
transform_tensor_descriptor( transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc, in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(YTilde, HTilde), make_embed_transform(make_tuple(YTilde_, HTilde_),
make_tuple(ConvDilationH, ConvStrideH)), make_tuple(ConvDilationH_, ConvStrideH_)),
make_embed_transform(make_tuple(XTilde, WTilde), make_embed_transform(make_tuple(XTilde_, WTilde_),
make_tuple(ConvDilationW, ConvStrideW)), make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(C)), make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple( make_tuple(
Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor(
in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N_),
make_freeze_transform(i_ytilde), make_freeze_transform(IdxYTilde_),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice),
make_freeze_transform(i_xtilde), make_freeze_transform(IdxXTilde_),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(C)), make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
Sequence<2>{}, Sequence<2>{},
...@@ -944,8 +1191,8 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -944,8 +1191,8 @@ struct TransformConvBwdDataToGemm_v1
const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor(
in_n_htildeslice_wtildeslice_c_grid_desc, in_n_htildeslice_wtildeslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), make_tuple(make_merge_transform(make_tuple(N_, HTildeSlice, WTildeSlice)),
make_pass_through_transform(C)), make_pass_through_transform(C_)),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
...@@ -961,11 +1208,11 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -961,11 +1208,11 @@ struct TransformConvBwdDataToGemm_v1
{ {
const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor( const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_grid_desc, in_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N_),
make_pad_transform(Di, InLeftPadD, InRightPadD), make_pad_transform(Di_, InLeftPadD_, InRightPadD_),
make_pad_transform(Hi, InLeftPadH, InRightPadH), make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
make_pad_transform(Wi, InLeftPadW, InRightPadW), make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(C)), make_pass_through_transform(C_)),
make_tuple( make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple( make_tuple(
...@@ -974,14 +1221,14 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -974,14 +1221,14 @@ struct TransformConvBwdDataToGemm_v1
const auto in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc = const auto in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc =
transform_tensor_descriptor( transform_tensor_descriptor(
in_n_dip_hip_wip_c_grid_desc, in_n_dip_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(ZTilde, DTilde), make_embed_transform(make_tuple(ZTilde_, DTilde_),
make_tuple(ConvDilationD, ConvStrideD)), make_tuple(ConvDilationD_, ConvStrideD_)),
make_embed_transform(make_tuple(YTilde, HTilde), make_embed_transform(make_tuple(YTilde_, HTilde_),
make_tuple(ConvDilationH, ConvStrideH)), make_tuple(ConvDilationH_, ConvStrideH_)),
make_embed_transform(make_tuple(XTilde, WTilde), make_embed_transform(make_tuple(XTilde_, WTilde_),
make_tuple(ConvDilationW, ConvStrideW)), make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(C)), make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
Sequence<2>{}, Sequence<2>{},
...@@ -996,14 +1243,14 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -996,14 +1243,14 @@ struct TransformConvBwdDataToGemm_v1
const auto in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc = const auto in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc =
transform_tensor_descriptor( transform_tensor_descriptor(
in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc, in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N_),
make_freeze_transform(i_ztilde), make_freeze_transform(IdxZTilde_),
make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice), make_slice_transform(DTilde_, IDTildeSliceBegin, DTildeSlice),
make_freeze_transform(i_ytilde), make_freeze_transform(IdxYTilde_),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice),
make_freeze_transform(i_xtilde), make_freeze_transform(IdxXTilde_),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(C)), make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
Sequence<2>{}, Sequence<2>{},
...@@ -1024,8 +1271,8 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -1024,8 +1271,8 @@ struct TransformConvBwdDataToGemm_v1
const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor(
in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc, in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc,
make_tuple( make_tuple(
make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)), make_merge_transform(make_tuple(N_, DTildeSlice, HTildeSlice, WTildeSlice)),
make_pass_through_transform(C)), make_pass_through_transform(C_)),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
...@@ -1044,84 +1291,41 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -1044,84 +1291,41 @@ struct TransformConvBwdDataToGemm_v1
} }
// for input bias // for input bias
template <typename CLayout, template <typename CLayout_ = CLayout,
typename std::enable_if<NDimSpatial == 2 && typename std::enable_if<NDimSpatial == 2 &&
(is_same_v<CLayout, tensor_layout::convolution::GC> || (is_same_v<CLayout_, tensor_layout::convolution::GC> ||
is_same_v<CLayout, tensor_layout::convolution::G_C>), is_same_v<CLayout_, tensor_layout::convolution::G_C>),
bool>::type = false> bool>::type = false>
static auto __host__ __device__ auto MakeCDescriptor_M_N() const
MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& /* out_g_n_k_wos_strides */,
const std::array<index_t, NDimSpatial + 3>& wei_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& /* wei_g_k_c_xs_strides */,
const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& /* in_g_n_c_wis_strides */,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& /* input_right_pads */,
const std::array<index_t, NDimSpatial>& /* tildes */)
{ {
const index_t N = in_g_n_c_wis_lengths[1];
const index_t C = wei_g_k_c_xs_lengths[2];
const index_t Hi = in_g_n_c_wis_lengths[3];
const index_t Wi = in_g_n_c_wis_lengths[4];
const index_t Ho = out_g_n_k_wos_lengths[3];
const index_t Wo = out_g_n_k_wos_lengths[4];
const index_t Y = wei_g_k_c_xs_lengths[3];
const index_t X = wei_g_k_c_xs_lengths[4];
const index_t InLeftPadH = input_left_pads[0];
const index_t InLeftPadW = input_left_pads[1];
const index_t ConvStrideH = conv_filter_strides[0];
const index_t ConvStrideW = conv_filter_strides[1];
const index_t ConvDilationH = conv_filter_dilations[0];
const index_t ConvDilationW = conv_filter_dilations[1];
if constexpr(ConvBwdDataSpecialization == if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0) Filter1x1Stride1Pad0)
{ {
const auto in_gemmm_gemmn_grid_desc = const auto in_gemmm_gemmn_grid_desc =
make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, C), make_tuple(I0, I1)); make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, C_), make_tuple(I0, I1));
return in_gemmm_gemmn_grid_desc; return in_gemmm_gemmn_grid_desc;
} }
else else
{ {
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;
const auto HTilde =
Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
const auto WTilde =
Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// only work on HTilde and WTilde that contribute to non-padding area of input tensor // only work on HTilde and WTilde that contribute to non-padding area of input tensor
const auto IHTildeSliceBegin = math::integer_divide_floor( const auto IHTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH); math::max(I0, InLeftPadH_ - ConvDilationH_ * (YTilde_ - I1)), ConvStrideH_);
const auto IWTildeSliceBegin = math::integer_divide_floor( const auto IWTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); math::max(I0, InLeftPadW_ - ConvDilationW_ * (XTilde_ - I1)), ConvStrideW_);
const auto IHTildeSliceEnd = math::min( const auto IHTildeSliceEnd = math::min(
HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); HTilde_, math::integer_divide_ceil(InLeftPadH_ + Hi_ - I1, ConvStrideH_) + I1);
const auto IWTildeSliceEnd = math::min( const auto IWTildeSliceEnd = math::min(
WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); WTilde_, math::integer_divide_ceil(InLeftPadW_ + Wi_ - I1, ConvStrideW_) + I1);
const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
// bias tensor // bias tensor
const auto in_gemmmraw_gemmnraw_grid_desc = make_naive_tensor_descriptor( const auto in_gemmmraw_gemmnraw_grid_desc = make_naive_tensor_descriptor(
make_tuple(N * HTildeSlice * WTildeSlice, C), make_tuple(I0, I1)); make_tuple(N_ * HTildeSlice * WTildeSlice, C_), make_tuple(I0, I1));
const auto in_gemmm_gemmn_grid_desc = ck::tensor_operation::device::PadTensorDescriptor( const auto in_gemmm_gemmn_grid_desc = ck::tensor_operation::device::PadTensorDescriptor(
in_gemmmraw_gemmnraw_grid_desc, in_gemmmraw_gemmnraw_grid_desc,
...@@ -1131,6 +1335,25 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -1131,6 +1335,25 @@ struct TransformConvBwdDataToGemm_v1
return in_gemmm_gemmn_grid_desc; return in_gemmm_gemmn_grid_desc;
} }
} }
IndexType N_;
IndexType Di_, Hi_, Wi_;
IndexType Do_, Ho_, Wo_;
IndexType Z_, Y_, X_;
IndexType K_, C_;
IndexType DiStride_, HiStride_, WiStride_;
IndexType DoStride_, HoStride_, WoStride_;
IndexType CStrideTensorB_, CStrideTensorC_, KStrideTensorA_, KStrideTensorB_;
IndexType NStrideTensorA_, NStrideTensorC_;
IndexType ConvStrideD_, ConvStrideH_, ConvStrideW_;
IndexType ConvDilationD_, ConvDilationH_, ConvDilationW_;
IndexType InLeftPadD_, InLeftPadH_, InLeftPadW_;
IndexType InRightPadD_, InRightPadH_, InRightPadW_;
IndexType IdxZTilde_, IdxYTilde_, IdxXTilde_;
IndexType GcdStrideDilationD_, GcdStrideDilationH_, GcdStrideDilationW_;
IndexType ZTilde_, YTilde_, XTilde_;
IndexType DTilde_, HTilde_, WTilde_;
IndexType ZDot_, YDot_, XDot_;
}; };
} // namespace tensor_operation } // namespace tensor_operation
......
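To make the slice-bound arithmetic in the backward-data transform above easier to follow, here is a minimal host-side sketch for the H dimension. It is illustrative only: the kernel code works on descriptor types with ck::math::gcd / integer_divide_ceil / integer_divide_floor and compile-time constants, while every name below is local to the example.

```cpp
// Standalone sketch of the HTilde slice bounds used by the backward-data transform.
#include <algorithm>
#include <numeric>

struct SliceBounds { int begin, end, size; };

inline SliceBounds htilde_slice(int Hi, int Ho, int Y,
                                int ConvStrideH, int ConvDilationH, int InLeftPadH)
{
    auto ceil_div = [](int a, int b) { return (a + b - 1) / b; };

    const int YTilde = ConvStrideH / std::gcd(ConvStrideH, ConvDilationH);
    const int HTilde = Ho + ceil_div(ConvDilationH * (Y - 1), ConvStrideH);

    // keep only the HTilde range that touches the non-padding area of the input
    const int begin = std::max(0, InLeftPadH - ConvDilationH * (YTilde - 1)) / ConvStrideH;
    const int end   = std::min(HTilde, ceil_div(InLeftPadH + Hi - 1, ConvStrideH) + 1);

    return {begin, end, end - begin};
}
// e.g. Hi=28, Ho=14, Y=3, stride=2, dilation=1, left pad=1 -> begin=0, end=15, size=15
```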
...@@ -13,6 +13,11 @@ namespace ck { ...@@ -13,6 +13,11 @@ namespace ck {
defined(__gfx1103__) || defined(__gfx11_generic__) defined(__gfx1103__) || defined(__gfx11_generic__)
#define __gfx11__ #define __gfx11__
#endif #endif
#if defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__)
#define __gfx12__
#endif
/********************************WAVE32 MODE***********************************************/ /********************************WAVE32 MODE***********************************************/
// src: fp16, dst: fp32 // src: fp16, dst: fp32
...@@ -99,7 +104,7 @@ struct intrin_wmma_bf16_16x16x16_bf16_w32<16, 16, Opsel> ...@@ -99,7 +104,7 @@ struct intrin_wmma_bf16_16x16x16_bf16_w32<16, 16, Opsel>
// opsel usage // opsel usage
// false: D0.[0:15] = result // false: D0.[0:15] = result
// true : D0.[16:31]= result // true : D0.[16:31]= result
#if defined(__gfx11__) #if defined(__gfx11__) || defined(__gfx12__)
reg_c.template AsType<bhalf16_t>()(Number<0>{}) = reg_c.template AsType<bhalf16_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32( __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32(
reg_a, reg_b, reg_c.template AsType<bhalf16_t>()[Number<0>{}], Opsel); reg_a, reg_b, reg_c.template AsType<bhalf16_t>()[Number<0>{}], Opsel);
...@@ -261,10 +266,6 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp> ...@@ -261,10 +266,6 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
// gfx12 // gfx12
/********************************WAVE32 MODE***********************************************/ /********************************WAVE32 MODE***********************************************/
#if defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__)
#define __gfx12__
#endif
// src: fp16, dst: fp32 // src: fp16, dst: fp32
template <index_t MPerWave, index_t NPerWave> template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_f16_w32_gfx12; struct intrin_wmma_f32_16x16x16_f16_w32_gfx12;
......
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
#pragma once #pragma once
#include "ck/utility/common_header.hpp" #include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
namespace ck { namespace ck {
......
# ck_tile [Back to the main page](../../README.md)
# Composable Kernel Tile
## concept ## concept
`ck_tile` provides a programming model with templated abstractions to enable users to implement performance-critical kernels for machine learning workloads. It introduces the following basic concepts to help users build their own operators. `ck_tile` provides a programming model with templated abstractions to enable users to implement performance-critical kernels for machine learning workloads. It introduces the following basic concepts to help users build their own operators.
- tensor coordinate transformation: the core concept of the layout/index transform abstraction, at both compile time and run time. - tensor coordinate transformation: the core concept of the layout/index transform abstraction, at both compile time and run time.
......
...@@ -52,6 +52,7 @@ ...@@ -52,6 +52,7 @@
#include "ck_tile/core/tensor/tile_elementwise.hpp" #include "ck_tile/core/tensor/tile_elementwise.hpp"
#include "ck_tile/core/tensor/tile_window.hpp" #include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/tensor/tile_window_linear.hpp" #include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/tensor/tile_window_utils.hpp"
#include "ck_tile/core/tensor/update_tile.hpp" #include "ck_tile/core/tensor/update_tile.hpp"
#include "ck_tile/core/utility/bit_cast.hpp" #include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp" #include "ck_tile/core/utility/functional.hpp"
...@@ -62,6 +63,7 @@ ...@@ -62,6 +63,7 @@
#include "ck_tile/core/utility/philox_rand.hpp" #include "ck_tile/core/utility/philox_rand.hpp"
#include "ck_tile/core/utility/random.hpp" #include "ck_tile/core/utility/random.hpp"
#include "ck_tile/core/utility/reduce_operator.hpp" #include "ck_tile/core/utility/reduce_operator.hpp"
#include "ck_tile/core/utility/static_counter.hpp"
#include "ck_tile/core/utility/to_sequence.hpp" #include "ck_tile/core/utility/to_sequence.hpp"
#include "ck_tile/core/utility/transpose_vectors.hpp" #include "ck_tile/core/utility/transpose_vectors.hpp"
#include "ck_tile/core/utility/type_traits.hpp" #include "ck_tile/core/utility/type_traits.hpp"
......
...@@ -621,6 +621,65 @@ CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0) ...@@ -621,6 +621,65 @@ CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0)
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
} }
CK_TILE_DEVICE void lds_load_fence(index_t cnt = 0)
{
asm volatile("s_waitcnt lgkmcnt(%0)" : : "n"(cnt) : "memory");
}
template <typename scalar_type, index_t N, bool pre_nop = false>
struct buffer_atomic_add_if;
template <bool pre_nop>
struct buffer_atomic_add_if<bf16_t, 2, pre_nop>
{
template <typename T>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t flag = 1)
{
static_assert(sizeof(T) == 4);
auto save_exec = __builtin_amdgcn_read_exec();
using mbuf_t = float;
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"global_atomic_pk_add_bf16 %0, %1, %2 offset:%3\n"
"s_mov_b64 exec %5"
:
: "v"(v_offset),
"v"(bit_cast<mbuf_t>(value)),
"s"(res.xy),
"n"(i_offset),
"v"(flag),
"s"(save_exec)
: "memory");
}
};
template <typename scalar_type, index_t N, bool pre_nop = false>
struct buffer_atomic_add;
template <bool pre_nop>
struct buffer_atomic_add<bf16_t, 2, pre_nop>
{
template <typename T>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t /*flag = 1*/)
{
static_assert(sizeof(T) == 4);
using mbuf_t = float;
asm volatile("global_atomic_pk_add_bf16 %0, %1, %2 offset:%3"
:
: "v"(v_offset), "v"(bit_cast<mbuf_t>(value)), "s"(res.xy), "n"(i_offset)
: "memory");
}
};
namespace impl { namespace impl {
// below type indicate the data type used for buffer load inline asm // below type indicate the data type used for buffer load inline asm
// clang-format off // clang-format off
...@@ -810,6 +869,11 @@ CK_TILE_DEVICE void buffer_store_fence(index_t cnt = 0) ...@@ -810,6 +869,11 @@ CK_TILE_DEVICE void buffer_store_fence(index_t cnt = 0)
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
} }
CK_TILE_DEVICE auto async_load_fence_raw(index_t cnt = 0)
{
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
// buffer load i8 // buffer load i8
CK_TILE_DEVICE_EXTERN int8_t CK_TILE_DEVICE_EXTERN int8_t
llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc, llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc,
...@@ -2378,6 +2442,45 @@ CK_TILE_DEVICE void amd_buffer_atomic_add(const thread_buffer<T, N>& src_thread_ ...@@ -2378,6 +2442,45 @@ CK_TILE_DEVICE void amd_buffer_atomic_add(const thread_buffer<T, N>& src_thread_
#endif #endif
} }
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE void amd_buffer_atomic_add_raw(const thread_buffer<T, N>& src_thread_data,
T* p_dst_wave,
const index_t dst_thread_element_offset,
const index_t dst_linear_element_offset,
const bool dst_thread_element_valid,
const index_t dst_element_space_size,
bool_constant<pre_nop> = {})
{
const int32x4_t dst_wave_buffer_resource =
make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
index_t dst_linear_addr_offset = dst_linear_element_offset * sizeof(T);
if constexpr(oob_conditional_check)
{
buffer_atomic_add_if<T, N, pre_nop>{}(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
0,
dst_linear_addr_offset,
dst_thread_element_valid);
}
else
{
buffer_atomic_add<T, N, pre_nop>{}(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
0,
dst_linear_addr_offset,
1);
}
}
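The wrapper above selects between the predicated (`buffer_atomic_add_if`) and unpredicated (`buffer_atomic_add`) packed bf16 atomics. A hedged call-site sketch follows; the function and variable names are made up for illustration, and the relevant ck_tile headers are assumed to be included.

```cpp
// Illustrative call site only; assumes p_out is a wavewise global pointer and acc
// holds two bf16 partial sums per thread.
CK_TILE_DEVICE void flush_bf16x2_partials(ck_tile::bf16_t* p_out,
                                          const ck_tile::thread_buffer<ck_tile::bf16_t, 2>& acc,
                                          ck_tile::index_t elem_offset,
                                          bool lane_in_bounds,
                                          ck_tile::index_t out_elems)
{
    // default template args: oob_conditional_check = true -> buffer_atomic_add_if,
    // which gates global_atomic_pk_add_bf16 on the per-lane flag via the exec mask
    ck_tile::amd_buffer_atomic_add_raw<ck_tile::bf16_t, 2>(
        acc, p_out, elem_offset, /*linear offset*/ 0, lane_in_bounds, out_elems);
}
```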
// buffer_atomic_max requires: // buffer_atomic_max requires:
// 1) p_dst_wave must point to global memory // 1) p_dst_wave must point to global memory
// 2) p_dst_wave must be a wavewise pointer. // 2) p_dst_wave must be a wavewise pointer.
......
...@@ -73,6 +73,24 @@ CK_TILE_DEVICE void block_sync_lds() ...@@ -73,6 +73,24 @@ CK_TILE_DEVICE void block_sync_lds()
#endif #endif
} }
CK_TILE_DEVICE void block_sync_load_raw(index_t cnt = 0)
{
#ifdef __gfx12__
asm volatile("s_wait_loadcnt %0 \n"
"s_barrier_signal -1 \n"
"s_barrier_wait -1"
:
: "n"(cnt)
: "memory");
#else
asm volatile("s_waitcnt vmcnt(%0) \n"
"s_barrier"
:
: "n"(cnt)
: "memory");
#endif
}
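For context, a minimal fragment showing where this combined wait-and-barrier is meant to sit; the surrounding loads are elided and the function name is an assumption.

```cpp
// Sketch: loads whose results land in LDS have been issued; every thread waits for
// them and then hits the workgroup barrier before reading the tile.
CK_TILE_DEVICE void sync_before_reading_lds_tile()
{
    // ... global/buffer loads targeting LDS issued above ...
    ck_tile::block_sync_load_raw(0); // wait for all outstanding loads, then barrier
                                     // (split s_barrier_signal/s_barrier_wait on gfx12)
    // ... LDS data written by other lanes is now safe to consume ...
}
```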
CK_TILE_DEVICE void block_sync_lds_direct_load() CK_TILE_DEVICE void block_sync_lds_direct_load()
{ {
asm volatile("\ asm volatile("\
......
...@@ -102,4 +102,28 @@ CK_TILE_DEVICE T warp_shuffle(const T& v_local, uint32_t src_lane) ...@@ -102,4 +102,28 @@ CK_TILE_DEVICE T warp_shuffle(const T& v_local, uint32_t src_lane)
#endif #endif
} }
template <typename T>
CK_TILE_DEVICE auto flag_to_exec(const T& v_flag)
{
static_assert(sizeof(T) == 4);
    // per-lane v_flag is compared against 1 and the result stored into a 2x-sgpr exec mask
uint32x2_t exec_flag;
asm volatile("v_cmp_ge_u32 %[s_exec_flag], %[v_flag], 1"
: [s_exec_flag] "=s"(exec_flag)
: [v_flag] "v"(v_flag));
return exec_flag;
}
template <typename X, typename Y>
CK_TILE_DEVICE auto cmp_lt_to_exec(const X& x, const Y& y)
{
static_assert(sizeof(X) == 4 && sizeof(Y) == 4);
    // per-lane comparison result stored into a 2x-sgpr exec mask
uint32x2_t exec_flag;
asm volatile("v_cmp_lt_u32 %[s_exec_flag], %[v_x], %[v_y]"
: [s_exec_flag] "=s"(exec_flag)
: [v_x] "v"(x), [v_y] "v"(y));
return exec_flag;
}
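Both helpers materialize a wave-wide predicate in a pair of sgprs; a hedged sketch of how such a mask is typically consumed follows. The actual exec manipulation lives in hand-written asm (as in `buffer_atomic_add_if` above), so the example only shows the predicate construction; names are illustrative.

```cpp
// Illustrative only: build per-lane predicates; real kernels feed the returned sgpr
// pair into inline asm (e.g. s_mov_b64 exec, ...) to mask lanes for one instruction.
CK_TILE_DEVICE void build_lane_predicates(uint32_t v_flag, uint32_t row, uint32_t num_rows)
{
    auto exec_from_flag = ck_tile::flag_to_exec(v_flag);          // lanes with v_flag >= 1
    auto exec_from_cmp  = ck_tile::cmp_lt_to_exec(row, num_rows); // lanes with row < num_rows
    (void)exec_from_flag;
    (void)exec_from_cmp;
}
```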
} // namespace ck_tile } // namespace ck_tile
...@@ -64,6 +64,7 @@ ...@@ -64,6 +64,7 @@
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE_WITH_NAN 1 #define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE_WITH_NAN 1
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE 2 #define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE 2
#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD_ASM 3 #define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD_ASM 3
#define CK_TILE_FLOAT_TO_BFLOAT16_RTA_ASM 4
#ifndef CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT #ifndef CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT
#define CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE #define CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE
...@@ -225,3 +226,7 @@ ...@@ -225,3 +226,7 @@
#ifndef CK_TILE_WORKAROUND_SWDEV_383542 #ifndef CK_TILE_WORKAROUND_SWDEV_383542
#define CK_TILE_WORKAROUND_SWDEV_383542 1 #define CK_TILE_WORKAROUND_SWDEV_383542 1
#endif #endif
#ifndef CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
#define CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID 1
#endif
...@@ -18,6 +18,7 @@ enum class bf16_rounding_mode ...@@ -18,6 +18,7 @@ enum class bf16_rounding_mode
truncate_with_nan, truncate_with_nan,
truncate, truncate,
standard_asm, standard_asm,
rta_asm, // round to nearest away
}; };
template <bf16_rounding_mode rounding = template <bf16_rounding_mode rounding =
...@@ -180,6 +181,39 @@ uint16_t float_to_bf16_rtn_asm(float f) ...@@ -180,6 +181,39 @@ uint16_t float_to_bf16_rtn_asm(float f)
return uint16_t(u.int32); return uint16_t(u.int32);
} }
// TODO: do we need this on host?
CK_TILE_HOST
uint16_t float_to_bf16_rta_asm(float f) { return float_to_bf16_rtn_raw(f); }
CK_TILE_DEVICE
uint16_t float_to_bf16_rta_asm(float f)
{
union
{
float fp32;
struct
{
uint16_t lo;
uint16_t hi;
};
} u = {f};
const uint32_t low_nan = 0x7fff;
const uint32_t hi_nan = 0x7fff0000;
using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
uint32x2_t check_nan;
asm volatile("v_cmp_u_f32 %[s_cnan], %[v_x], %[v_x] \n"
"v_add3_u32 %[v_x], %[v_x], %[v_blo], 1 \n"
"v_cndmask_b32 %[v_x], %[v_x], %[v_bhi], %[s_cnan]"
: [s_cnan] "+s"(check_nan), [v_x] "+v"(u.fp32)
: [v_blo] "v"(low_nan), [v_bhi] "v"(hi_nan));
        // Note: in the above code snippet, we use the high 16 bits
return u.hi;
}
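A host-side reference of the same rounding, under the assumption that the asm above reads as "add 0x8000 to the fp32 bit pattern, keep the high half, map NaN to 0x7fff", i.e. round to nearest with ties away from zero on the bf16 grid. This sketch is not part of the commit and ignores the overflow-to-infinity corner case.

```cpp
#include <cmath>
#include <cstdint>
#include <cstring>

inline uint16_t float_to_bf16_rta_ref(float f)
{
    if(std::isnan(f))
        return uint16_t(0x7fff);             // same quiet-NaN payload the asm path selects

    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    return uint16_t((bits + 0x8000u) >> 16); // add half an output ulp, then truncate
}
```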
// Truncate instead of rounding, preserving SNaN // Truncate instead of rounding, preserving SNaN
CK_TILE_HOST_DEVICE CK_TILE_HOST_DEVICE
constexpr uint16_t float_to_bf16_truc_nan_raw(float f) constexpr uint16_t float_to_bf16_truc_nan_raw(float f)
...@@ -213,6 +247,8 @@ CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant<round ...@@ -213,6 +247,8 @@ CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant<round
return float_to_bf16_rtn_asm(f); return float_to_bf16_rtn_asm(f);
else if constexpr(rounding == bf16_rounding_mode::truncate_with_nan) else if constexpr(rounding == bf16_rounding_mode::truncate_with_nan)
return float_to_bf16_truc_nan_raw(f); return float_to_bf16_truc_nan_raw(f);
else if constexpr(rounding == bf16_rounding_mode::rta_asm)
return float_to_bf16_rta_asm(f);
else else
return float_to_bf16_truc_raw(f); return float_to_bf16_truc_raw(f);
} }
......
...@@ -437,34 +437,74 @@ struct buffer_view<address_space_enum::global, ...@@ -437,34 +437,74 @@ struct buffer_view<address_space_enum::global,
// i is offset of T, not X. i should be aligned to X // i is offset of T, not X. i should be aligned to X
template <memory_operation_enum Op, template <memory_operation_enum Op,
typename X, typename X,
bool oob_conditional_check = true,
typename std::enable_if< typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type, std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value, typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false> bool>::type = false>
CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X& x) CK_TILE_DEVICE void update(index_t i,
index_t linear_offset,
bool is_valid_element,
const X& x,
bool_constant<oob_conditional_check> = {})
{ {
if constexpr(Op == memory_operation_enum::set) if constexpr(Op == memory_operation_enum::set)
{ {
this->template set<X>(i, linear_offset, is_valid_element, x); this->template set<X, oob_conditional_check>(i, linear_offset, is_valid_element, x);
} }
else if constexpr(Op == memory_operation_enum::atomic_add) else if constexpr(Op == memory_operation_enum::atomic_add)
{ {
this->template atomic_add<X>(i, linear_offset, is_valid_element, x); this->template atomic_add<X, oob_conditional_check>(
i, linear_offset, is_valid_element, x);
} }
else if constexpr(Op == memory_operation_enum::atomic_max) else if constexpr(Op == memory_operation_enum::atomic_max)
{ {
this->template atomic_max<X>(i, linear_offset, is_valid_element, x); this->template atomic_max<X, oob_conditional_check>(
i, linear_offset, is_valid_element, x);
} }
// FIXME: remove memory_operation_enum::add // FIXME: remove memory_operation_enum::add
else if constexpr(Op == memory_operation_enum::add) else if constexpr(Op == memory_operation_enum::add)
{ {
auto tmp = this->template get<X>(i, linear_offset, is_valid_element); auto tmp =
this->template set<X>(i, linear_offset, is_valid_element, x + tmp); this->template get<X, oob_conditional_check>(i, linear_offset, is_valid_element);
this->template set<X, oob_conditional_check>(
i, linear_offset, is_valid_element, x + tmp);
// tmp += x; // tmp += x;
// this->template set<X>(i, is_valid_element, tmp); // this->template set<X>(i, is_valid_element, tmp);
} }
} }
// i is offset of T, not X. i should be aligned to X
template <memory_operation_enum Op,
typename X,
bool oob_conditional_check = true,
bool pre_nop = false,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE void update_raw(index_t i,
index_t linear_offset,
bool is_valid_element,
const X& x,
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
if constexpr(Op == memory_operation_enum::set)
{
this->template set_raw<X, oob_conditional_check>(i, linear_offset, is_valid_element, x);
}
else if constexpr(Op == memory_operation_enum::atomic_add)
{
this->template atomic_add_raw<X, oob_conditional_check, pre_nop>(
i, linear_offset, is_valid_element, x);
}
else if constexpr(Op == memory_operation_enum::atomic_max)
{
// this->template atomic_max_raw<X>(i, linear_offset, is_valid_element, x);
}
}
// i is offset of T, not X. i should be aligned to X // i is offset of T, not X. i should be aligned to X
template <typename X, template <typename X,
bool oob_conditional_check = true, bool oob_conditional_check = true,
...@@ -533,6 +573,7 @@ struct buffer_view<address_space_enum::global, ...@@ -533,6 +573,7 @@ struct buffer_view<address_space_enum::global,
} }
template <typename X, template <typename X,
bool oob_conditional_check = true,
typename std::enable_if< typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type, std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value, typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
...@@ -585,6 +626,39 @@ struct buffer_view<address_space_enum::global, ...@@ -585,6 +626,39 @@ struct buffer_view<address_space_enum::global,
} }
template <typename X, template <typename X,
bool oob_conditional_check = true,
bool pre_nop = true,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE void
atomic_add_raw(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
{
// using scalar_t = typename vector_traits<remove_cvref_t<T>>::scalar_type;
// X contains multiple T
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X should contain multiple T");
static_assert(get_address_space() == address_space_enum::global, "only support global mem");
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_atomic_add_raw<remove_cvref_t<T>,
t_per_x,
Coherence,
oob_conditional_check,
pre_nop>(
x, p_data_, i, linear_offset, is_valid_element, buffer_size_);
}
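A hedged sketch of how a caller might drive the new raw update path on a global buffer_view; the view and vector names are illustrative and the enclosing ck_tile headers are assumed.

```cpp
// Illustrative only: accumulate a vector into global memory through the raw path,
// skipping the out-of-bounds check when the caller can guarantee validity.
template <typename BufferView, typename Vec>
CK_TILE_DEVICE void accumulate_raw(BufferView& buf, ck_tile::index_t i, const Vec& v)
{
    // Op = atomic_add, oob_conditional_check = false, pre_nop = false
    buf.template update_raw<ck_tile::memory_operation_enum::atomic_add, Vec, false, false>(
        i, /*linear_offset*/ 0, /*is_valid_element*/ true, v);
}
```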
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if< typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type, std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value, typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
......
...@@ -22,28 +22,32 @@ template <typename BottomTensorView_, ...@@ -22,28 +22,32 @@ template <typename BottomTensorView_,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
index_t NumCoord, index_t NumCoord,
index_t i_access = -1,
bool oob_conditional_check = true> bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution<BottomTensorView_, CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_, WindowLengths_,
TileDistribution_, TileDistribution_,
NumCoord>& tile_window, NumCoord>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}) bool_constant<oob_conditional_check> = {})
{ {
return tile_window.load(number<-1>{}, bool_constant<oob_conditional_check>{}); return tile_window.load(number<i_access>{}, bool_constant<oob_conditional_check>{});
} }
template <typename BottomTensorView_, template <typename BottomTensorView_,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
typename LinearBottomDims_, typename LinearBottomDims_,
index_t i_access = -1,
bool oob_conditional_check = true> bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(const tile_window_linear<BottomTensorView_, CK_TILE_DEVICE auto load_tile(const tile_window_linear<BottomTensorView_,
WindowLengths_, WindowLengths_,
TileDistribution_, TileDistribution_,
LinearBottomDims_>& tile_window, LinearBottomDims_>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}) bool_constant<oob_conditional_check> = {})
{ {
return tile_window.load(number<-1>{}, bool_constant<oob_conditional_check>{}); return tile_window.load(number<i_access>{}, bool_constant<oob_conditional_check>{});
} }
template <typename DistributedTensor_, template <typename DistributedTensor_,
...@@ -51,15 +55,35 @@ template <typename DistributedTensor_, ...@@ -51,15 +55,35 @@ template <typename DistributedTensor_,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
index_t NumCoord, index_t NumCoord,
index_t i_access = -1,
bool oob_conditional_check = true> bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile, CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
const tile_window_with_static_distribution<BottomTensorView_, const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_, WindowLengths_,
TileDistribution_, TileDistribution_,
NumCoord>& tile_window, NumCoord>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}) bool_constant<oob_conditional_check> = {})
{ {
return tile_window.load(dst_tile, bool_constant<oob_conditional_check>{}); return tile_window.load(dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
}
template <typename DistributedTensor_,
typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
const tile_window_linear<BottomTensorView_,
WindowLengths_,
TileDistribution_,
LinearBottomDims_>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
return tile_window.load(dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
} }
/** /**
...@@ -76,6 +100,7 @@ template <typename T, ...@@ -76,6 +100,7 @@ template <typename T,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
index_t NumCoord, index_t NumCoord,
index_t i_access = -1,
bool oob_conditional_check = true, bool oob_conditional_check = true,
bool pre_nop = false> bool pre_nop = false>
CK_TILE_DEVICE auto load_tile_raw(T& tile, CK_TILE_DEVICE auto load_tile_raw(T& tile,
...@@ -83,11 +108,12 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile, ...@@ -83,11 +108,12 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
WindowLengths_, WindowLengths_,
TileDistribution_, TileDistribution_,
NumCoord>& tile_window, NumCoord>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}, bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) bool_constant<pre_nop> = {})
{ {
tile_window.load_raw( tile_window.load_raw(
tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{}); tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
} }
template <typename T, template <typename T,
...@@ -95,6 +121,7 @@ template <typename T, ...@@ -95,6 +121,7 @@ template <typename T,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
typename LinearBottomDims_, typename LinearBottomDims_,
index_t i_access = -1,
bool oob_conditional_check = true, bool oob_conditional_check = true,
bool pre_nop = false> bool pre_nop = false>
CK_TILE_DEVICE auto load_tile_raw(T& tile, CK_TILE_DEVICE auto load_tile_raw(T& tile,
...@@ -102,11 +129,12 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile, ...@@ -102,11 +129,12 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
WindowLengths_, WindowLengths_,
TileDistribution_, TileDistribution_,
LinearBottomDims_>& tile_window, LinearBottomDims_>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}, bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) bool_constant<pre_nop> = {})
{ {
tile_window.load_raw( tile_window.load_raw(
tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{}); tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
} }
template <typename LdsTileWindow_, template <typename LdsTileWindow_,
...@@ -114,6 +142,7 @@ template <typename LdsTileWindow_, ...@@ -114,6 +142,7 @@ template <typename LdsTileWindow_,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
index_t NumCoord, index_t NumCoord,
index_t i_access = -1,
bool oob_conditional_check = true, bool oob_conditional_check = true,
bool pre_nop = false> bool pre_nop = false>
CK_TILE_DEVICE auto CK_TILE_DEVICE auto
...@@ -122,11 +151,14 @@ async_load_tile_raw(LdsTileWindow_&& lds_tile, ...@@ -122,11 +151,14 @@ async_load_tile_raw(LdsTileWindow_&& lds_tile,
WindowLengths_, WindowLengths_,
TileDistribution_, TileDistribution_,
NumCoord>& tile_window, NumCoord>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}, bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) bool_constant<pre_nop> = {})
{ {
return tile_window.async_load_raw( return tile_window.async_load_raw(lds_tile,
lds_tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{}); number<i_access>{},
bool_constant<oob_conditional_check>{},
bool_constant<pre_nop>{});
} }
template <typename LdsTileWindow_, template <typename LdsTileWindow_,
...@@ -134,6 +166,7 @@ template <typename LdsTileWindow_, ...@@ -134,6 +166,7 @@ template <typename LdsTileWindow_,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
typename LinearBottomDims_, typename LinearBottomDims_,
index_t i_access = -1,
bool oob_conditional_check = true, bool oob_conditional_check = true,
bool pre_nop = false> bool pre_nop = false>
CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile, CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
...@@ -141,11 +174,14 @@ CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile, ...@@ -141,11 +174,14 @@ CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
WindowLengths_, WindowLengths_,
TileDistribution_, TileDistribution_,
LinearBottomDims_>& tile_window, LinearBottomDims_>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}, bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) bool_constant<pre_nop> = {})
{ {
return tile_window.async_load_raw( return tile_window.async_load_raw(lds_tile,
lds_tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{}); number<i_access>{},
bool_constant<oob_conditional_check>{},
bool_constant<pre_nop>{});
} }
CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0) CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0)
......
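With the new `i_access` parameter, a kernel can issue one vector access per call and interleave unrelated work between issues instead of passing the default `-1`, which issues the whole tile. A hedged sketch; the window/tile types and the access count are assumed to be supplied by the caller.

```cpp
// Illustrative only: issue the tile load access-by-access instead of all at once.
template <ck_tile::index_t NumAccess, typename Window, typename Tile>
CK_TILE_DEVICE void load_tile_interleaved(const Window& win, Tile& t)
{
    ck_tile::static_for<0, NumAccess, 1>{}([&](auto i) {
        ck_tile::load_tile_raw(t, win, i); // only the i-th access is issued
        // ... independent ALU work can be scheduled between accesses here ...
    });
}
```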
...@@ -201,4 +201,30 @@ CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks_from_x_unpacks(YLengths, number ...@@ -201,4 +201,30 @@ CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks_from_x_unpacks(YLengths, number
return unpacks; return unpacks;
} }
namespace detail {
// check whether two static_distributed_tensor types have the same data type and thread-buffer size,
// differing only in their tile distribution
template <typename X, typename Y>
struct is_similiar_distributed_tensor
{
static constexpr bool value = false;
};
template <typename TypeX, typename DistX, typename TypeY, typename DistY>
struct is_similiar_distributed_tensor<static_distributed_tensor<TypeX, DistX>,
static_distributed_tensor<TypeY, DistY>>
{
using Tx = static_distributed_tensor<TypeX, DistX>;
using Ty = static_distributed_tensor<TypeY, DistY>;
static constexpr bool value = std::is_same_v<typename Tx::DataType, typename Ty::DataType> &&
Tx::get_thread_buffer_size() == Ty::get_thread_buffer_size();
};
template <typename X, typename Y>
inline constexpr bool is_similiar_distributed_tensor_v =
is_similiar_distributed_tensor<X, Y>::value;
} // namespace detail
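A hedged example of where the new trait is useful, e.g. as a compile-time guard before treating one distributed tensor's thread buffer as another's; the call-site names are assumptions.

```cpp
// Illustrative only: statically verify two tensors agree on element type and
// thread-buffer size even though their tile distributions differ.
template <typename TensorX, typename TensorY>
CK_TILE_DEVICE constexpr void assert_similar_layouts(const TensorX&, const TensorY&)
{
    static_assert(ck_tile::detail::is_similiar_distributed_tensor_v<TensorX, TensorY>,
                  "tensors must hold the same data type and thread-buffer size");
}
```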
} // namespace ck_tile } // namespace ck_tile
...@@ -333,6 +333,48 @@ struct tensor_view ...@@ -333,6 +333,48 @@ struct tensor_view
coord.get_offset(), linear_offset, is_valid_element, x); coord.get_offset(), linear_offset, is_valid_element, x);
} }
// X is vector of DataType.
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
template <typename X,
bool oob_conditional_check = true,
bool pre_nop = false,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void
update_vectorized_elements_raw(const TensorCoord& coord,
index_t linear_offset,
const X& x,
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
buf_.template update_raw<DstInMemOp, X, oob_conditional_check, pre_nop>(
coord.get_offset(),
linear_offset,
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
x);
}
template <typename X,
bool oob_conditional_check = true,
bool pre_nop = false,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void
update_vectorized_elements_raw(const TensorCoord& coord,
index_t linear_offset,
bool is_valid_element,
const X& x,
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
buf_.template update_raw<DstInMemOp, X, oob_conditional_check, pre_nop>(
coord.get_offset(), linear_offset, is_valid_element, x);
}
CK_TILE_HOST_DEVICE void print() const CK_TILE_HOST_DEVICE void print() const
{ {
printf("tensor_view{"); printf("tensor_view{");
......
...@@ -292,12 +292,15 @@ struct tile_window_with_static_distribution ...@@ -292,12 +292,15 @@ struct tile_window_with_static_distribution
{ {
constexpr auto tile_dstr = TileDstr{}; constexpr auto tile_dstr = TileDstr{};
auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr); auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
load(dst_tensor, bool_constant<oob_conditional_check>{}); load(dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
return dst_tensor; return dst_tensor;
} }
template <typename DistributedTensor, bool oob_conditional_check = true> template <typename DistributedTensor,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor, CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const bool_constant<oob_conditional_check> = {}) const
{ {
using Traits = load_store_traits; using Traits = load_store_traits;
...@@ -785,6 +788,73 @@ struct tile_window_with_static_distribution ...@@ -785,6 +788,73 @@ struct tile_window_with_static_distribution
}); });
} }
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true, bool pre_nop>
CK_TILE_DEVICE void update_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
{
using Traits = load_store_traits;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
/// TODO: use structure binding (to be captured later) if compiled in C++20
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
// read from distributed tensor
vector_t vec_value;
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
: idx_ys_start[jj];
},
number<NDimY>{});
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
vec_value.template get_as<DataType>()(j) =
dstr_tensor.get_thread_buffer().template at<d>();
});
// write into bottom tensor
get_bottom_tensor_view().template update_vectorized_elements_raw<vector_t>(
bottom_tensor_thread_coord,
0,
vec_value,
bool_constant<oob_conditional_check>{},
bool_constant<pre_nop>{});
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
}
});
});
}
    // move thread's bottom tensor coordinate // move thread's bottom tensor coordinate
// [x0', x1', ... ] ==> [offset] // [x0', x1', ... ] ==> [offset]
// also move window-origin // also move window-origin
......
...@@ -432,23 +432,38 @@ struct tile_window_linear ...@@ -432,23 +432,38 @@ struct tile_window_linear
CK_TILE_DEVICE static constexpr index_t get_bottom_linear_offset(number<i_access>) CK_TILE_DEVICE static constexpr index_t get_bottom_linear_offset(number<i_access>)
{ {
constexpr auto linear_coord = get_bottom_linear_coordinate(number<i_access>{}); constexpr auto linear_coord = get_bottom_linear_coordinate(number<i_access>{});
// since this is linear offset, we assum bottom X tensor is always linear constexpr auto is_pure_linear_tensor =
constexpr index_t linear_offset = [&]() { reduce_on_sequence(LinearBottomDims{}, multiplies{}, number<1>{});
constexpr auto x_idx_ = linear_coord; if constexpr(is_pure_linear_tensor)
constexpr auto x_len_ = TileDstr{}.get_lengths(); {
static_assert(x_idx_.size() == x_len_.size()); // this case usually is a LDS window, everything is known at compile tile.
constexpr index_t x_dims_ = x_idx_.size(); // we directly use BottomTensorView transform to compute the offset, in case padding
index_t cu_stride_ = 1; auto bottom_tensor_coord =
index_t cu_offset_ = 0; make_tensor_coordinate(BottomTensorView{}.get_tensor_descriptor(), linear_coord);
static_for<0, x_dims_, 1>{}([&](auto i_) { return bottom_tensor_coord.get_offset();
auto r_i_ = number<x_dims_ - i_ - 1>{}; }
cu_offset_ += x_idx_[r_i_] * cu_stride_; else
cu_stride_ *= x_len_[r_i_]; {
}); // this case usually is a global window, where last dim can be linear
return cu_offset_; // we hack here, that use the original TileDstr to compute the linear offset
}(); // ... hoping that there is no extra padding between other dims, which make sense
// since that would introduce runtime length (so can't use linear offset)
return linear_offset; constexpr index_t linear_offset = [&]() {
constexpr auto x_idx_ = linear_coord;
constexpr auto x_len_ = TileDstr{}.get_lengths();
static_assert(x_idx_.size() == x_len_.size());
constexpr index_t x_dims_ = x_idx_.size();
index_t cu_stride_ = 1;
index_t cu_offset_ = 0;
static_for<0, x_dims_, 1>{}([&](auto i_) {
auto r_i_ = number<x_dims_ - i_ - 1>{};
cu_offset_ += x_idx_[r_i_] * cu_stride_;
cu_stride_ *= x_len_[r_i_];
});
return cu_offset_;
}();
return linear_offset;
}
} }
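The fallback branch above is a compile-time row-major fold over the tile-distribution lengths. A minimal host-side sketch of the same arithmetic, using plain integers instead of descriptor types:

```cpp
// Standalone sketch of the cumulative row-major offset computed in the else-branch.
#include <array>
#include <cstddef>

template <std::size_t N>
constexpr int row_major_offset(const std::array<int, N>& idx, const std::array<int, N>& len)
{
    int stride = 1, offset = 0;
    for(std::size_t i = N; i-- > 0;) // last dimension is stride 1
    {
        offset += idx[i] * stride;
        stride *= len[i];
    }
    return offset;
}
// row_major_offset<3>({1, 2, 3}, {4, 8, 16}) == 1*128 + 2*16 + 3 == 163
```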
CK_TILE_DEVICE constexpr auto get_num_of_access() const { return traits::NumAccess; } CK_TILE_DEVICE constexpr auto get_num_of_access() const { return traits::NumAccess; }
...@@ -509,6 +524,64 @@ struct tile_window_linear ...@@ -509,6 +524,64 @@ struct tile_window_linear
return dst_tensor; return dst_tensor;
} }
template <typename DstTile, index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(DstTile& dst_tensor,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}) const
{
using vector_t = typename traits::vector_t;
using SFC_Ys = typename traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
// auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access_>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
auto bottom_tensor_flag = cached_flags_[IAccess];
constexpr auto linear_offset = get_bottom_linear_offset(IAccess);
// read from bottom tensor
const vector_t vec_value =
get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
linear_offset,
bottom_tensor_flag,
bool_constant<oob_conditional_check>{});
#if 1
// data index [y0, y1, ...]
constexpr auto idx_diff_ys = SFC_Ys::get_index(IAccess);
// write into distributed tensor
static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == traits::VectorDimY ? (idx_diff_ys[jj] + j) : idx_diff_ys[jj];
},
number<NDimY>{});
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
dst_tensor.get_thread_buffer().template at<d>() =
vec_value.template get_as<DataType>()[j];
});
#else
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
static_assert(d % traits::ScalarPerVector == 0);
dst_tensor.get_thread_buffer().template get_as<vector_t>()(
number<d / traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
#endif
};
WINDOW_DISPATCH_ISSUE();
return dst_tensor;
}
template <typename DstTile, template <typename DstTile,
index_t i_access = -1, index_t i_access = -1,
bool oob_conditional_check = true, bool oob_conditional_check = true,
...@@ -849,6 +922,58 @@ struct tile_window_linear ...@@ -849,6 +922,58 @@ struct tile_window_linear
WINDOW_DISPATCH_ISSUE(); WINDOW_DISPATCH_ISSUE();
} }
template <index_t i_access = -1, bool oob_conditional_check = true, bool pre_nop = false>
CK_TILE_DEVICE void update_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
{
using vector_t = typename traits::vector_t;
using SFC_Ys = typename traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
// loop over thread tensor space [y0, y1, ...]
auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access_>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
constexpr auto linear_offset = get_bottom_linear_offset(IAccess);
auto bottom_tensor_flag = cached_flags_[IAccess];
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);
// read from distributed tensor
vector_t vec_value;
static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
},
number<NDimY>{});
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
vec_value.template get_as<DataType>()(j) =
dstr_tensor.get_thread_buffer().template at<d>();
});
// write into bottom tensor
get_bottom_tensor_view().template update_vectorized_elements_raw<vector_t>(
bottom_tensor_thread_coord,
linear_offset,
bottom_tensor_flag,
vec_value,
bool_constant<oob_conditional_check>{},
bool_constant<pre_nop>{});
};
WINDOW_DISPATCH_ISSUE();
}
    // move thread's bottom tensor coordinate // move thread's bottom tensor coordinate
// [x0', x1', ... ] ==> [offset] // [x0', x1', ... ] ==> [offset]
// also move window-origin // also move window-origin
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/utility.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#pragma once
namespace ck_tile {
// given an LDS store-tile window, extract some information from it;
// used to set the m0 value for the gfx9 series
template <typename LdsTileWindow_>
CK_TILE_DEVICE auto get_async_store_smem_info(LdsTileWindow_&& lds_tile)
{
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
using LdsDataType = typename LdsTileWindow::DataType;
// issues * warps * lanes
static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
const index_t size_per_buf =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<0>{}, number<0>{})) *
sizeof(LdsDataType);
const index_t size_per_wave =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<1>{}, number<0>{})) *
sizeof(LdsDataType) -
size_per_buf;
const index_t size_per_issue =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<1>{}, number<0>{}, number<0>{})) *
sizeof(LdsDataType) -
size_per_buf;
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
return make_tuple(m0_init_value, size_per_issue);
}
} // namespace ck_tile
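A hedged sketch of consuming the returned pair. Tuple indexing below follows the `operator[](number<I>)` convention used elsewhere in ck_tile, and the actual m0 programming (s_setreg or a wrapper around it) is left as a comment since it is not part of this file; the helper name is an assumption.

```cpp
// Illustrative only: unpack {m0 init value, size per issue} for an async store prologue.
template <typename LdsWindow>
CK_TILE_DEVICE void async_store_prologue(LdsWindow& lds_window)
{
    const auto smem_info      = ck_tile::get_async_store_smem_info(lds_window);
    const auto m0_init        = smem_info[ck_tile::number<0>{}]; // per-wave m0 seed
    const auto size_per_issue = smem_info[ck_tile::number<1>{}]; // m0 increment per issue
    // program m0 with m0_init here, then add size_per_issue after each async issue
    (void)m0_init;
    (void)size_per_issue;
}
```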