Commit ef5e60f6 authored by illsilin's avatar illsilin

Merge branch 'develop' into gfx950

parents 2cc0fa26 5e93fa9e
......@@ -421,11 +421,11 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
}
hip_check_error(
hipMemcpyWithStream(arg.p_workspace_,
arg.gemm_kernel_args_.data(),
arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
hipMemcpyHostToDevice,
stream_config.stream_id_));
hipMemcpyAsync(arg.p_workspace_,
arg.gemm_kernel_args_.data(),
arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
hipMemcpyHostToDevice,
stream_config.stream_id_));
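// Editor's note (hedged): unlike hipMemcpyWithStream, hipMemcpyAsync returns
// without blocking the host; with pageable host memory the runtime may still
// stage the copy, so this change assumes the stream is synchronized before
// gemm_kernel_args_ is freed or reused.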
float ave_time = 0;
......@@ -538,10 +538,16 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
return false;
}
if(std::is_same_v<EDataType, ck::bhalf_t> && arg.K_BATCH > 1 && !is_bf16_atomic_supported())
{
return false;
}
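// Hedged sketch (not this library's implementation) of what
// is_bf16_atomic_supported() is assumed to check: split-K with bhalf_t output
// accumulates partial results through global AtomicAdd, which only some
// targets implement for bf16, e.g.
//
//   inline bool is_bf16_atomic_supported_sketch(const std::string& arch)
//   {
//       // assumption: newer gfx9-class targets expose bf16 global atomics
//       return arch.find("gfx94") != std::string::npos ||
//              arch.find("gfx95") != std::string::npos;
//   }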
bool supported = true;
for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
{
const auto& a = arg.gemm_kernel_args_[i].karg_;
bool group_arg_valid = GridwiseGemm::CheckValidity(a);
if(not group_arg_valid)
{
......@@ -631,16 +637,42 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
{
return dynamic_cast<const Argument*>(p_arg)->gemm_kernel_args_.size() *
sizeof(GemmTransKernelArg);
auto p_arg_ = dynamic_cast<const Argument*>(p_arg);
if(p_arg_)
{
return p_arg_->gemm_kernel_args_.size() * sizeof(GemmTransKernelArg);
}
else
throw std::runtime_error(
"The argument pointer is not an object of "
"DeviceGroupedGemmXdlSplitKCShuffle::Argument structure!");
}
size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
{
return GetWorkSpaceSize(p_arg);
}
// TODO: deprecation notice.
static void SetKBatchSize(Argument& arg, index_t kbatch) { arg.UpdateKBatch(kbatch); }
// polymorphic
void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override
{
return SetKBatchSize(*dynamic_cast<Argument*>(p_arg), kbatch);
auto p_arg_ = dynamic_cast<Argument*>(p_arg);
if(p_arg_)
{
p_arg_->UpdateKBatch(kbatch);
}
else
throw std::runtime_error(
"The argument pointer is not an object of "
"DeviceGroupedGemmXdlSplitKCShuffle::Argument structure!");
}
void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
{
return this->SetWorkSpacePointer(p_arg, p_dev_kernel_args);
}
};
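// Hedged usage sketch (invented helper name; assumes hipMalloc and
// hip_check_error are in scope): how a caller is expected to stage the
// grouped-GEMM kernel arguments in device memory before invoking the kernel.
template <typename DeviceOp>
void* stage_device_kernel_args_sketch(DeviceOp& op, BaseArgument* p_arg)
{
    void* p_dev_kernel_args = nullptr;
    const std::size_t arg_bytes = op.GetDeviceKernelArgSize(p_arg);
    hip_check_error(hipMalloc(&p_dev_kernel_args, arg_bytes));
    // SetDeviceKernelArgs forwards to SetWorkSpacePointer(), so the
    // hipMemcpyAsync in Run() copies each group's GemmTransKernelArg here.
    op.SetDeviceKernelArgs(p_arg, p_dev_kernel_args);
    return p_dev_kernel_args;
}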
......
......@@ -14,6 +14,8 @@
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/workgroup_barrier.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
namespace ck {
......@@ -38,7 +40,7 @@ __global__ void
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg);
karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg, karg.p_workspace_);
#else
ignore = karg;
#endif // end of if (defined(__gfx9__))
......@@ -62,7 +64,13 @@ __global__ void
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared_0, p_shared_1, karg);
karg.p_a_grid,
karg.p_b_grid,
karg.p_c_grid,
p_shared_0,
p_shared_1,
karg,
karg.p_workspace_);
#else
ignore = karg;
#endif // end of if (defined(__gfx9__))
......@@ -521,7 +529,9 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, Streamk_sel_, Grid_size_},
p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_c_grid{p_c_grid_}
p_c_grid{p_c_grid_},
block_2_ctile_map_streamk(
M_, N_, AK0Number * CalculateKPadded(K_, 1), Grid_size_, Streamk_sel_)
{
}
......@@ -529,6 +539,13 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
const ADataType* p_a_grid;
const BDataType* p_b_grid;
CDataType* p_c_grid;
BlockToCTileMap_GemmStreamK_v2<MPerBlock,
NPerBlock,
KPerBlock,
StreamKReductionStrategy::Atomic,
8,
4>
block_2_ctile_map_streamk;
};
struct SplitKBatchOffset
......@@ -853,6 +870,19 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
}
__host__ __device__ static constexpr auto
GetCBlockDescriptor_MShuffle_MPerShuffle_NShuffle_NPerShuffle()
{
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
return make_naive_tensor_descriptor_packed(
make_tuple(Number<MXdlPerWave / CShuffleMXdlPerWavePerShuffle>{},
Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
Number<NXdlPerWave / CShuffleNXdlPerWavePerShuffle>{},
Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));
}
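// Hedged numeric reading of the descriptor above (assumed tile config, not
// taken from this diff): with MPerBlock=256, NPerBlock=128, MPerXdl=NPerXdl=32,
// MXdlPerWave=4, NXdlPerWave=2 and CShuffle{M,N}XdlPerWavePerShuffle=1:
//   MWave = 256 / (4 * 32) = 2,   NWave = 128 / (2 * 32) = 2
//   lengths = (4 / 1, 1 * 2 * 32, 2 / 1, 1 * 2 * 32) = (4, 64, 2, 64)
// i.e. the partial-acc tile is staged through LDS in 4 x 2 shuffle passes of
// 64 x 64 elements each.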
using BlockwiseGemmPipe =
remove_cvref_t<decltype(BlockGemmPipeline_Selector<
BlkGemmPipelineVer,
......@@ -1118,6 +1148,34 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
return c_grid_desc_mblock_mperblock_nblock_nperblock;
}
__host__ __device__ static constexpr auto GetClusterLengthReduction()
{
// TODO: assume C is row major
// TODO: we always first loop over N, then M
constexpr auto NPerBlockPow2 = math::next_power_of_two<NPerBlock>();
constexpr auto NPerBlockReduction =
NPerBlockPow2 / CShuffleBlockTransferScalarPerVector_NPerBlock;
constexpr auto MPerBlockReduction =
(BlockSize + NPerBlockReduction - 1) / NPerBlockReduction;
return Sequence<MPerBlockReduction, NPerBlockReduction>{};
}
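// Hedged example (assumed config): BlockSize=256, NPerBlock=128,
// CShuffleBlockTransferScalarPerVector_NPerBlock=8:
//   NPerBlockPow2      = 128
//   NPerBlockReduction = 128 / 8 = 16 threads along N
//   MPerBlockReduction = (256 + 15) / 16 = 16 threads along M
// -> Sequence<16, 16>: each thread reduces 8-wide vectors and the cluster
//    sweeps the MPerBlock x NPerBlock tile in MReduceIters x NReduceIters steps.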
__host__ __device__ static constexpr auto GetPartialAccBlockDescriptor()
{
const auto c_partial_acc_block_m_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(MPerBlock, NPerBlock),
make_tuple(NPerBlock, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(MPerBlock, NPerBlock),
make_tuple(I1, MPerBlock));
}
}();
return c_partial_acc_block_m_n;
}
using Block2CTileMap_streamk = BlockToCTileMap_GemmStreamK_v2<MPerBlock,
NPerBlock,
KPerBlock,
......@@ -1132,22 +1190,42 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
const BDataType* p_b_grid,
CDataType* p_c_grid,
void* p_shared,
Problem& problem)
Problem& problem,
void* p_workspace)
{
const AElementwiseOperation a_element_op{};
const BElementwiseOperation b_element_op{};
const CElementwiseOperation c_element_op{};
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
Block2CTileMap_streamk block_2_ctile_map_streamk(problem.M,
problem.N,
AK0Number * problem.KPadded,
problem.Grid_size,
problem.Streamk_sel);
uint32_t iter_start, iter_end;
bool is_sk_block, is_dp_block;
bool is_sk_block, is_dp_block, is_reduction_block;
index_t num_k_block_main_loop;
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
uint32_t* p_semaphore = reinterpret_cast<uint32_t*>(
reinterpret_cast<char*>(p_workspace) +
block_2_ctile_map_streamk.get_workspace_size_for_acc(sizeof(AccDataType)));
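// Hedged sketch of the assumed workspace layout consumed above:
//   [ per-tile partial-acc buffers ........ | uint32_t tile semaphores ... ]
//   ^ p_workspace                           ^ p_semaphore
// The accumulator region spans get_workspace_size_for_acc(sizeof(AccDataType))
// bytes, so the semaphore array starts right behind it; the host is expected
// to zero the whole workspace (e.g. hipMemsetAsync) before launch so the
// wait_eq/inc handshake starts from a clean state.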
for(auto block_idx = get_block_1d_id();
block_idx < block_2_ctile_map_streamk.get_grid_dims();
block_idx += gridDim.x)
......@@ -1163,6 +1241,214 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end);
num_k_block_main_loop = iter_end - iter_start;
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
is_reduction_block = static_cast<uint32_t>(block_idx) >=
block_2_ctile_map_streamk.reduction_start_block_idx;
if(is_reduction_block)
{
// descriptors
constexpr auto cluster_length_reduce = GetClusterLengthReduction();
constexpr auto reduce_desc = make_cluster_descriptor(cluster_length_reduce);
const auto reduce_thread_cluster_idx =
reduce_desc.CalculateBottomIndex(make_multi_index(block_idx));
const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0];
const auto thread_n_cluster_id = reduce_thread_cluster_idx[I1];
constexpr auto MReduceIters = math::integer_divide_ceil(
Number<MPerBlock>{}, cluster_length_reduce.At(I0));
constexpr auto NReduceIters = math::integer_divide_ceil(
Number<NPerBlock>{},
cluster_length_reduce.At(I1) *
Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{});
constexpr auto acc_thread_buf_load_desc = make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{}));
constexpr auto acc_thread_buf_store_desc =
make_naive_tensor_descriptor_packed(make_tuple(
I1, I1, I1, Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{}));
constexpr auto c_partial_acc_block_m_n = GetPartialAccBlockDescriptor();
constexpr auto partial_acc_load_step_n =
make_multi_index(0,
cluster_length_reduce.At(I1) *
CShuffleBlockTransferScalarPerVector_NPerBlock);
constexpr auto partial_acc_load_step_n_reverse = make_multi_index(
0,
-1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
CShuffleBlockTransferScalarPerVector_NPerBlock);
constexpr auto partial_acc_load_step_m =
make_multi_index(cluster_length_reduce.At(I0), 0);
constexpr auto partial_acc_store_step_n =
make_multi_index(0,
0,
0,
cluster_length_reduce.At(I1) *
CShuffleBlockTransferScalarPerVector_NPerBlock);
constexpr auto partial_acc_store_step_n_reverse = make_multi_index(
0,
0,
0,
-1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
CShuffleBlockTransferScalarPerVector_NPerBlock);
constexpr auto partial_acc_store_step_m =
make_multi_index(0, cluster_length_reduce.At(I0), 0, 0);
StaticBuffer<AddressSpaceEnum::Vgpr,
AccDataType,
CShuffleBlockTransferScalarPerVector_NPerBlock,
true>
partial_acc_buf;
StaticBuffer<AddressSpaceEnum::Vgpr,
AccDataType,
CShuffleBlockTransferScalarPerVector_NPerBlock,
true>
acc_buf;
// start to compute
auto reduction_idx =
block_idx - block_2_ctile_map_streamk.reduction_start_block_idx;
auto spatial_idx = block_2_ctile_map_streamk.tile_to_spatial(
reduction_idx, problem.M, problem.N);
workgroup_barrier wg_barrier(p_semaphore);
uint32_t tile_acc_offset_start =
block_2_ctile_map_streamk.get_acc_buffer_offset_from_tile(reduction_idx);
uint32_t tile_acc_offset_end =
block_2_ctile_map_streamk.get_acc_buffer_offset_from_tile(reduction_idx +
1);
__syncthreads();
auto acc_load = ThreadwiseTensorSliceTransfer_v2<
AccDataType, // SrcData,
AccDataType, // DstData,
decltype(c_partial_acc_block_m_n), // SrcDesc,
decltype(acc_thread_buf_load_desc), // DstDesc,
Sequence<1,
CShuffleBlockTransferScalarPerVector_NPerBlock>, // SliceLengths,
Sequence<0, 1>, // DimAccessOrder,
1, // SrcVectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // SrcScalarPerVector,
1, // SrcScalarStrideInVector,
false // SrcResetCoordinateAfterRun,
>{c_partial_acc_block_m_n,
make_multi_index(thread_m_cluster_id,
thread_n_cluster_id *
CShuffleBlockTransferScalarPerVector_NPerBlock)};
auto acc_store = ThreadwiseTensorSliceTransfer_v1r3<
AccDataType, // SrcData,
CDataType, // DstData,
decltype(acc_thread_buf_store_desc), // SrcDesc,
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), // DstDesc,
CElementwiseOperation, // ElementwiseOperation,
Sequence<1,
1,
1,
CShuffleBlockTransferScalarPerVector_NPerBlock>, // SliceLengths,
Sequence<0, 1, 2, 3>, // DimAccessOrder,
3, // DstVectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // DstScalarPerVector,
InMemoryDataOperationEnum::Set, // InMemoryDataOperationEnum DstInMemOp,
1, // DstScalarStrideInVector,
false // DstResetCoordinateAfterRun,
>{c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]),
thread_m_cluster_id,
__builtin_amdgcn_readfirstlane(spatial_idx[I1]),
thread_n_cluster_id *
CShuffleBlockTransferScalarPerVector_NPerBlock),
CElementwiseOperation{}};
wg_barrier.wait_eq(reduction_idx, tile_acc_offset_end - tile_acc_offset_start);
if(threadIdx.x == 0)
{
p_semaphore[reduction_idx] = 0;
}
using Accumulation = ck::detail::
AccumulateWithNanCheck<false /*PropagateNan*/, reduce::Add, AccDataType>;
for(int i_m = 0; i_m < MReduceIters; i_m++)
{
static_for<0, NReduceIters, 1>{}([&](auto i_n_reduce) {
acc_buf.Clear();
for(auto i = tile_acc_offset_start; i < tile_acc_offset_end; i++)
{
auto c_partial_acc_buf =
make_dynamic_buffer<AddressSpaceEnum::Global,
AmdBufferCoherenceEnum::GLC>(
reinterpret_cast<AccDataType*>(p_workspace) +
i * c_partial_acc_block_m_n.GetElementSpaceSize(),
c_partial_acc_block_m_n.GetElementSpaceSize());
acc_load.Run(c_partial_acc_block_m_n,
c_partial_acc_buf,
acc_thread_buf_load_desc,
make_tuple(I0, I0),
partial_acc_buf);
static_for<0, CShuffleBlockTransferScalarPerVector_NPerBlock, 1>{}(
[&](auto i_vec) {
constexpr auto offset =
acc_thread_buf_load_desc.CalculateOffset(
make_tuple(0, i_vec));
Accumulation::Calculate(acc_buf(Number<offset>{}),
partial_acc_buf[Number<offset>{}]);
});
}
if(thread_n_cluster_id *
CShuffleBlockTransferScalarPerVector_NPerBlock <
NPerBlock)
{
acc_store.Run(acc_thread_buf_store_desc,
make_tuple(I0, I0, I0, I0),
acc_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
}
if constexpr(NReduceIters != 1)
{
if constexpr(i_n_reduce != (NReduceIters - 1))
{
acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
partial_acc_load_step_n);
acc_store.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock,
partial_acc_store_step_n);
}
else
{
acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
partial_acc_load_step_n_reverse);
acc_store.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock,
partial_acc_store_step_n_reverse);
}
}
});
{
acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
partial_acc_load_step_m);
acc_store.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock,
partial_acc_store_step_m);
}
}
continue;
}
}
// offset for last acc buffer of this block
uint32_t block_acc_offset =
(block_2_ctile_map_streamk.get_acc_buffer_offset_from_block(block_idx + 1) - 1) *
MPerBlock * NPerBlock;
while(true)
{
uint32_t current_iter_length = __builtin_amdgcn_readfirstlane(
......@@ -1173,33 +1459,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
iter_end - 1, tile_idx, iter_offset);
iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1);
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(problem.M,
problem.MPadded,
problem.K,
problem.KPadded,
problem.StrideA,
problem.AK0);
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(problem.K,
problem.KPadded,
problem.N,
problem.NPadded,
problem.StrideB,
problem.BK0);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
auto block_work_idx =
block_2_ctile_map_streamk.tile_to_spatial(tile_idx, problem.M, problem.N);
......@@ -1363,11 +1622,20 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
constexpr auto c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle =
GetCBlockDescriptor_MShuffle_MPerShuffle_NShuffle_NPerShuffle();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<CShuffleDataType*>(p_shared),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
.GetElementSpaceSize());
auto c_partial_acc_buf =
make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
reinterpret_cast<AccDataType*>(p_workspace) + block_acc_offset,
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle
.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
......@@ -1477,7 +1745,34 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_m_id, 0, block_n_id, 0),
c_element_op};
// LDS to global partial acc
auto c_block_copy_lds_to_partial_acc = ThreadGroupTensorSliceTransfer_v6r1r2<
ThisThreadBlock, // index_t BlockSize,
CElementwiseOperation, // ElementwiseOperation,
// InMemoryDataOperationEnum::Set, // DstInMemOp,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave *
NPerXdl>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
CShuffleDataType, // typename SrcData,
CShuffleDataType, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
false, // bool ThreadTransferSrcResetCoordinateAfterRun => must be
// false, otherwise generates scratch
false> // bool ThreadTransferDstResetCoordinateAfterRun => must be
// false, otherwise generates scratch
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
make_multi_index(0, 0, 0, 0),
c_element_op};
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
......@@ -1535,15 +1830,40 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
}
else if(is_sk_block)
{
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global
.template Run<decltype(c_shuffle_block_buf),
decltype(c_grid_buf),
InMemoryDataOperationEnum::AtomicAdd>(
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Atomic)
{
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global
.template Run<decltype(c_shuffle_block_buf),
decltype(c_grid_buf),
InMemoryDataOperationEnum::AtomicAdd>(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
}
else if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
// constexpr offset
c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
make_tuple(0, 0, 0, 0));
c_block_copy_lds_to_partial_acc.SetDstSliceOrigin(
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
make_tuple(MXdlPerWave, 0, NXdlPerWave, 0));
c_block_copy_lds_to_partial_acc
.template Run<decltype(c_shuffle_block_buf),
decltype(c_partial_acc_buf),
InMemoryDataOperationEnum::Set>(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
c_partial_acc_buf);
}
}
if constexpr(access_id < num_access - 1)
......@@ -1555,15 +1875,33 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
}
});
}
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
if(is_sk_block)
{
// increase the counter for this tile
workgroup_barrier wg_barrier(p_semaphore);
wg_barrier.inc(tile_idx);
}
}
} // shuffle c and write-out end
// exit condition
iter_end -= current_iter_length;
if(iter_end <= iter_start)
break;
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
block_acc_offset -= MPerBlock * NPerBlock;
}
// make sure next loop LDS is ready for use
block_sync_lds();
}
}
} // while loop
} // for loop
}
template <bool HasMainKBlockLoop,
......@@ -1574,19 +1912,43 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
CDataType* p_c_grid,
void* p_shared_0,
void* p_shared_1,
Problem& problem)
Problem& problem,
void* p_workspace)
{
const AElementwiseOperation a_element_op{};
const BElementwiseOperation b_element_op{};
const CElementwiseOperation c_element_op{};
Block2CTileMap_streamk block_2_ctile_map_streamk(
problem.M, problem.N, AK0Number * problem.KPadded, problem.Grid_size);
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
uint32_t iter_start, iter_end;
bool is_sk_block, is_dp_block; //, is_padding_block; //, is_reduction_block;
bool is_sk_block, is_dp_block, is_reduction_block;
index_t num_k_block_main_loop;
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
Block2CTileMap_streamk block_2_ctile_map_streamk(problem.M,
problem.N,
AK0Number * problem.KPadded,
problem.Grid_size,
problem.Streamk_sel);
for(auto block_idx = get_block_1d_id();
block_idx < block_2_ctile_map_streamk.get_grid_dims();
block_idx += gridDim.x)
......@@ -1601,6 +1963,235 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end);
num_k_block_main_loop = iter_end - iter_start;
uint32_t* p_semaphore = reinterpret_cast<uint32_t*>(
reinterpret_cast<char*>(p_workspace) +
block_2_ctile_map_streamk.get_workspace_size_for_acc(sizeof(AccDataType)));
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
is_reduction_block = static_cast<uint32_t>(block_idx) >=
block_2_ctile_map_streamk.reduction_start_block_idx;
if(is_reduction_block)
{
// descriptors
constexpr auto cluster_length_reduce = GetClusterLengthReduction();
constexpr auto reduce_desc = make_cluster_descriptor(cluster_length_reduce);
const auto reduce_thread_cluster_idx =
reduce_desc.CalculateBottomIndex(make_multi_index(block_idx));
const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0];
const auto thread_n_cluster_id = reduce_thread_cluster_idx[I1];
constexpr auto MReduceIters = math::integer_divide_ceil(
Number<MPerBlock>{}, cluster_length_reduce.At(I0));
constexpr auto NReduceIters = math::integer_divide_ceil(
Number<NPerBlock>{},
cluster_length_reduce.At(I1) *
Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{});
constexpr auto acc_thread_buf_load_desc = make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{}));
constexpr auto acc_thread_buf_store_desc =
make_naive_tensor_descriptor_packed(make_tuple(
I1, I1, I1, Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{}));
constexpr auto c_partial_acc_block_m_n = GetPartialAccBlockDescriptor();
constexpr auto partial_acc_load_step_n =
make_multi_index(0,
cluster_length_reduce.At(I1) *
CShuffleBlockTransferScalarPerVector_NPerBlock);
constexpr auto partial_acc_load_step_n_reverse = make_multi_index(
0,
-1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
CShuffleBlockTransferScalarPerVector_NPerBlock);
constexpr auto partial_acc_load_step_m =
make_multi_index(cluster_length_reduce.At(I0), 0);
constexpr auto partial_acc_store_step_n =
make_multi_index(0,
0,
0,
cluster_length_reduce.At(I1) *
CShuffleBlockTransferScalarPerVector_NPerBlock);
constexpr auto partial_acc_store_step_n_reverse = make_multi_index(
0,
0,
0,
-1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
CShuffleBlockTransferScalarPerVector_NPerBlock);
constexpr auto partial_acc_store_step_m =
make_multi_index(0, cluster_length_reduce.At(I0), 0, 0);
StaticBuffer<AddressSpaceEnum::Vgpr,
AccDataType,
CShuffleBlockTransferScalarPerVector_NPerBlock,
true>
partial_acc_buf;
StaticBuffer<AddressSpaceEnum::Vgpr,
AccDataType,
CShuffleBlockTransferScalarPerVector_NPerBlock,
true>
acc_buf;
// start to compute
auto reduction_idx =
block_idx - block_2_ctile_map_streamk.reduction_start_block_idx;
auto spatial_idx = block_2_ctile_map_streamk.tile_to_spatial(
reduction_idx, problem.M, problem.N);
workgroup_barrier wg_barrier(p_semaphore);
uint32_t tile_acc_offset_start =
block_2_ctile_map_streamk.get_acc_buffer_offset_from_tile(reduction_idx);
uint32_t tile_acc_offset_end =
block_2_ctile_map_streamk.get_acc_buffer_offset_from_tile(reduction_idx +
1);
uint32_t expected_count = tile_acc_offset_end - tile_acc_offset_start;
if(threadIdx.x == 0)
{
p_semaphore[reduction_idx] = 0;
}
__syncthreads();
auto acc_load = ThreadwiseTensorSliceTransfer_v2<
AccDataType, // SrcData,
AccDataType, // DstData,
decltype(c_partial_acc_block_m_n), // SrcDesc,
decltype(acc_thread_buf_load_desc), // DstDesc,
Sequence<1,
CShuffleBlockTransferScalarPerVector_NPerBlock>, // SliceLengths,
Sequence<0, 1>, // DimAccessOrder,
1, // SrcVectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // SrcScalarPerVector,
1, // SrcScalarStrideInVector,
false // SrcResetCoordinateAfterRun,
>{c_partial_acc_block_m_n,
make_multi_index(thread_m_cluster_id,
thread_n_cluster_id *
CShuffleBlockTransferScalarPerVector_NPerBlock)};
auto acc_store = ThreadwiseTensorSliceTransfer_v1r3<
AccDataType, // SrcData,
CDataType, // DstData,
decltype(acc_thread_buf_store_desc), // SrcDesc,
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), // DstDesc,
CElementwiseOperation, // ElementwiseOperation,
Sequence<1,
1,
1,
CShuffleBlockTransferScalarPerVector_NPerBlock>, // SliceLengths,
Sequence<0, 1, 2, 3>, // DimAccessOrder,
3, // DstVectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // DstScalarPerVector,
InMemoryDataOperationEnum::Set, // InMemoryDataOperationEnum DstInMemOp,
1, // DstScalarStrideInVector,
false // DstResetCoordinateAfterRun,
>{c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]),
thread_m_cluster_id,
__builtin_amdgcn_readfirstlane(spatial_idx[I1]),
thread_n_cluster_id *
CShuffleBlockTransferScalarPerVector_NPerBlock),
CElementwiseOperation{}};
#if 0
if(threadIdx.x == 0) {
printf("bid:%d, rid:%d, os:%d,%d, spatial:%d,%d\n", static_cast<int>(blockIdx.x),
reduction_idx, __builtin_amdgcn_readfirstlane(tile_acc_offset_start), __builtin_amdgcn_readfirstlane(tile_acc_offset_end),
__builtin_amdgcn_readfirstlane(spatial_idx[I0]),
__builtin_amdgcn_readfirstlane(spatial_idx[I1]));
}
#endif
if(threadIdx.x == 0)
{
atomicAdd(&p_semaphore[reduction_idx], 1);
}
wg_barrier.wait_eq(p_semaphore[reduction_idx], expected_count);
using Accumulation = ck::detail::
AccumulateWithNanCheck<false /*PropagateNan*/, reduce::Add, AccDataType>;
for(int i_m = 0; i_m < MReduceIters; i_m++)
{
static_for<0, NReduceIters, 1>{}([&](auto i_n_reduce) {
acc_buf.Clear();
for(auto i = tile_acc_offset_start; i < tile_acc_offset_end; i++)
{
auto c_partial_acc_buf =
make_dynamic_buffer<AddressSpaceEnum::Global,
AmdBufferCoherenceEnum::GLC>(
reinterpret_cast<AccDataType*>(p_workspace) +
i * c_partial_acc_block_m_n.GetElementSpaceSize(),
c_partial_acc_block_m_n.GetElementSpaceSize());
acc_load.Run(c_partial_acc_block_m_n,
c_partial_acc_buf,
acc_thread_buf_load_desc,
make_tuple(I0, I0),
partial_acc_buf);
static_for<0, CShuffleBlockTransferScalarPerVector_NPerBlock, 1>{}(
[&](auto i_vec) {
constexpr auto offset =
acc_thread_buf_load_desc.CalculateOffset(
make_tuple(0, i_vec));
Accumulation::Calculate(acc_buf(Number<offset>{}),
partial_acc_buf[Number<offset>{}]);
});
}
if(thread_n_cluster_id *
CShuffleBlockTransferScalarPerVector_NPerBlock <
NPerBlock)
{
acc_store.Run(acc_thread_buf_store_desc,
make_tuple(I0, I0, I0, I0),
acc_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
}
if constexpr(NReduceIters != 1)
{
if constexpr(i_n_reduce != (NReduceIters - 1))
{
acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
partial_acc_load_step_n);
acc_store.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock,
partial_acc_store_step_n);
}
else
{
acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
partial_acc_load_step_n_reverse);
acc_store.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock,
partial_acc_store_step_n_reverse);
}
}
});
{
acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
partial_acc_load_step_m);
acc_store.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock,
partial_acc_store_step_m);
}
}
continue;
}
}
// offset for last acc buffer of this block
uint32_t block_acc_offset =
(block_2_ctile_map_streamk.get_acc_buffer_offset_from_block(block_idx + 1) - 1) *
MPerBlock * NPerBlock;
while(true)
{
uint32_t current_iter_length = __builtin_amdgcn_readfirstlane(
......@@ -1611,33 +2202,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
iter_end - 1, tile_idx, iter_offset);
iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1);
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(problem.M,
problem.MPadded,
problem.K,
problem.KPadded,
problem.StrideA,
problem.AK0);
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(problem.K,
problem.KPadded,
problem.N,
problem.NPadded,
problem.StrideB,
problem.BK0);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
auto block_work_idx =
block_2_ctile_map_streamk.tile_to_spatial(tile_idx, problem.M, problem.N);
......@@ -1811,11 +2375,20 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
constexpr auto c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle =
GetCBlockDescriptor_MShuffle_MPerShuffle_NShuffle_NPerShuffle();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<CShuffleDataType*>(p_shared_0),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
.GetElementSpaceSize());
auto c_partial_acc_buf =
make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
reinterpret_cast<AccDataType*>(p_workspace) + block_acc_offset,
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle
.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
......@@ -1925,6 +2498,35 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
make_multi_index(block_m_id, 0, block_n_id, 0),
c_element_op};
// LDS to global partial acc
auto c_block_copy_lds_to_partial_acc = ThreadGroupTensorSliceTransfer_v6r1r2<
ThisThreadBlock, // index_t BlockSize,
CElementwiseOperation, // ElementwiseOperation,
// InMemoryDataOperationEnum::Set, // DstInMemOp,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave *
NPerXdl>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
CShuffleDataType, // typename SrcData,
CShuffleDataType, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
false, // bool ThreadTransferSrcResetCoordinateAfterRun => must be
// false, otherwise generates scratch
false> // bool ThreadTransferDstResetCoordinateAfterRun => must be
// false, otherwise generates scratch
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
make_multi_index(0, 0, 0, 0),
c_element_op};
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
......@@ -1982,15 +2584,40 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
}
else if(is_sk_block)
{
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global
.template Run<decltype(c_shuffle_block_buf),
decltype(c_grid_buf),
InMemoryDataOperationEnum::AtomicAdd>(
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Atomic)
{
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global
.template Run<decltype(c_shuffle_block_buf),
decltype(c_grid_buf),
InMemoryDataOperationEnum::AtomicAdd>(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
}
else if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
// constexpr offset
c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
make_tuple(0, 0, 0, 0));
c_block_copy_lds_to_partial_acc.SetDstSliceOrigin(
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
make_tuple(MXdlPerWave, 0, NXdlPerWave, 0));
c_block_copy_lds_to_partial_acc
.template Run<decltype(c_shuffle_block_buf),
decltype(c_partial_acc_buf),
InMemoryDataOperationEnum::Set>(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
c_partial_acc_buf);
}
}
if constexpr(access_id < num_access - 1)
{
......@@ -2002,6 +2629,27 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
}
});
}
// exit condition
iter_end -= current_iter_length;
if(iter_end <= iter_start)
break;
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
block_acc_offset -= MPerBlock * NPerBlock;
}
// make sure next loop LDS is ready for use
block_sync_lds();
}
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
if(is_sk_block)
{
// increase the counter for this tile
workgroup_barrier wg_barrier(p_semaphore);
wg_barrier.inc(0);
}
}
}
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -13,245 +13,614 @@
namespace ck {
namespace tensor_operation {
namespace {
template <
index_t NDimSpatial,
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization ConvBwdDataSpecialization,
index_t AK1,
index_t BK1,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
bool DoPadGemmM,
bool DoPadGemmN,
typename ALayout,
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization ConvBwdDataSpecialization>
constexpr auto make_out_grid_desc(const index_t N,
const index_t Do,
const index_t Ho,
const index_t Wo,
const index_t K,
const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_strides)
typename BLayout,
typename CLayout,
bool SplitN = false,
typename ADataType = float,
typename CDataType = float,
index_t NumGroupsToMerge = 1,
typename IndexType = index_t>
struct TransformConvBwdDataToGemm_v1
{
const auto KStride = Number<1>{};
private:
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
if constexpr(is_same_v<ALayout, tensor_layout::convolution::NHWGK>)
{
const index_t NStride = out_g_n_k_wos_strides[1];
const index_t HiStride = out_g_n_k_wos_strides[3];
const index_t WiStride = out_g_n_k_wos_strides[4];
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{
static constexpr auto NonSpatialDimsNum = Number<3>{};
return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, K),
make_tuple(WiStride, KStride));
}
else
static constexpr auto DIdx = NonSpatialDimsNum;
static constexpr auto HIdx =
NDimSpatial == 2 ? NonSpatialDimsNum : Number<NonSpatialDimsNum + 1>{};
static constexpr auto WIdx =
NDimSpatial == 2 ? Number<NonSpatialDimsNum + 1>{} : Number<NonSpatialDimsNum + 2>{};
static constexpr auto ZIdx = NonSpatialDimsNum;
static constexpr auto YIdx =
NDimSpatial == 2 ? NonSpatialDimsNum : Number<NonSpatialDimsNum + 1>{};
static constexpr auto XIdx =
NDimSpatial == 2 ? Number<NonSpatialDimsNum + 1>{} : Number<NonSpatialDimsNum + 2>{};
template <typename ConvDimsType>
static long_index_t calculate_element_space_size_impl(const ConvDimsType& lengths,
const ConvDimsType& strides,
index_t i)
{
long_index_t acc = 1;
for(; i < (NDimSpatial + 3); i++)
{
return make_naive_tensor_descriptor(make_tuple(N, Ho, Wo, K),
make_tuple(NStride, HiStride, WiStride, KStride));
acc +=
static_cast<long_index_t>(lengths[i] - I1) * static_cast<long_index_t>(strides[i]);
}
return acc;
}
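// Hedged reading: starting from acc = 1 and adding (length_i - 1) * stride_i
// per dimension yields the offset of the last addressable element plus one,
// i.e. the element-space footprint of a possibly strided/padded tensor.
// Example (assumed): lengths (2, 3), strides (100, 1) -> 1 + 100 + 2 = 103.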
else if constexpr(is_same_v<ALayout, tensor_layout::convolution::NDHWGK>)
template <typename ConvDimsType>
static IndexType GetSplitedNSize(const ConvDimsType& a_g_n_k_wos_lengths,
const ConvDimsType& a_g_n_k_wos_strides,
const ConvDimsType& c_g_n_c_wis_lengths,
const ConvDimsType& c_g_n_c_wis_strides)
{
const index_t NStride = out_g_n_k_wos_strides[1];
const index_t DoStride = out_g_n_k_wos_strides[3];
const index_t HoStride = out_g_n_k_wos_strides[4];
const index_t WoStride = out_g_n_k_wos_strides[5];
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
const long_index_t a_element_space_size =
calculate_element_space_size_impl(a_g_n_k_wos_lengths, a_g_n_k_wos_strides, I1);
const long_index_t c_element_space_size =
calculate_element_space_size_impl(c_g_n_c_wis_lengths, c_g_n_c_wis_strides, I1);
const long_index_t element_space_size = math::max(a_element_space_size * sizeof(ADataType),
c_element_space_size * sizeof(CDataType));
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
const IndexType N = a_g_n_k_wos_lengths[I1];
if(element_space_size > TwoGB)
{
// Minimum divisor of N to not exceed 2GB
const auto divisor = math::integer_divide_ceil(element_space_size, TwoGB);
return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, K),
make_tuple(WoStride, KStride));
if(divisor <= static_cast<double>(N))
{
// Find the least divisor of N not smaller than element_space_size / TwoGB.
// Iterate up to sqrt(N); if none is found by then, fall back below.
for(IndexType least_divisor = divisor; least_divisor * least_divisor <= N;
least_divisor++)
{
if(N % least_divisor == 0)
{
return N / least_divisor;
}
}
// No divisor found; fall back to one convolution batch (N) per launch
return 1;
}
else
{
// Not possible to support even after splitting N.
// The tensor is too large.
return N;
}
}
else
{
return make_naive_tensor_descriptor(
make_tuple(N, Do, Ho, Wo, K),
make_tuple(NStride, DoStride, HoStride, WoStride, KStride));
// Split N is not needed.
return N;
}
}
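// Hedged walk-through of GetSplitedNSize (assumed numbers, not from this
// diff): element_space_size = 4 GiB, N = 8
//   divisor       = ceil(4 GiB / 2 GiB) = 2
//   least_divisor = 2 (first divisor of 8 that is >= 2)
//   returned N    = 8 / 2 = 4, so each launch stays under the 2 GiB limit.
// For a prime N (e.g. 7) no divisor <= sqrt(7) exists, and the function
// falls back to processing one image per launch (returns 1).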
else if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNHWK>)
public:
__host__ __device__ constexpr TransformConvBwdDataToGemm_v1() {}
template <typename TransformConvBwdDataToGemm_v1Base>
__host__ __device__ TransformConvBwdDataToGemm_v1(
const TransformConvBwdDataToGemm_v1Base& transform_conv_bwd_data_to_gemm_base)
: N_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.N_)},
Di_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.Di_)},
Hi_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.Hi_)},
Wi_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.Wi_)},
Do_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.Do_)},
Ho_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.Ho_)},
Wo_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.Wo_)},
Z_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.Z_)},
Y_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.Y_)},
X_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.X_)},
K_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.K_)},
C_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.C_)},
DiStride_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.DiStride_)},
HiStride_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.HiStride_)},
WiStride_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.WiStride_)},
DoStride_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.DoStride_)},
HoStride_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.HoStride_)},
WoStride_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.WoStride_)},
CStrideTensorB_{
static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.CStrideTensorB_)},
CStrideTensorC_{
static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.CStrideTensorC_)},
KStrideTensorA_{
static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.KStrideTensorA_)},
KStrideTensorB_{
static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.KStrideTensorB_)},
NStrideTensorA_{
static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.NStrideTensorA_)},
NStrideTensorC_{
static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.NStrideTensorC_)},
ConvStrideD_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ConvStrideD_)},
ConvStrideH_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ConvStrideH_)},
ConvStrideW_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ConvStrideW_)},
ConvDilationD_{
static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ConvDilationD_)},
ConvDilationH_{
static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ConvDilationH_)},
ConvDilationW_{
static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ConvDilationW_)},
InLeftPadD_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.InLeftPadD_)},
InLeftPadH_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.InLeftPadH_)},
InLeftPadW_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.InLeftPadW_)},
InRightPadD_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.InRightPadD_)},
InRightPadH_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.InRightPadH_)},
InRightPadW_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.InRightPadW_)},
IdxZTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.IdxZTilde_)},
IdxYTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.IdxYTilde_)},
IdxXTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.IdxXTilde_)},
GcdStrideDilationD_{
static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.GcdStrideDilationD_)},
GcdStrideDilationH_{
static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.GcdStrideDilationH_)},
GcdStrideDilationW_{
static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.GcdStrideDilationW_)},
ZTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ZTilde_)},
YTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.YTilde_)},
XTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.XTilde_)},
DTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.DTilde_)},
HTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.HTilde_)},
WTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.WTilde_)},
ZDot_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ZDot_)},
YDot_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.YDot_)},
XDot_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.XDot_)}
{
// assume packed
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
}
template <typename ConvDimsType, typename ConvSpatialDimsType>
__host__ __device__
TransformConvBwdDataToGemm_v1(const ConvDimsType& a_g_n_k_wos_lengths,
const ConvDimsType& a_g_n_k_wos_strides,
const ConvDimsType& b_g_k_c_xs_lengths,
const ConvDimsType& b_g_k_c_xs_strides,
const ConvDimsType& c_g_n_c_wis_lengths,
const ConvDimsType& c_g_n_c_wis_strides,
const ConvSpatialDimsType& conv_filter_strides,
const ConvSpatialDimsType& conv_filter_dilations,
const ConvSpatialDimsType& input_left_pads,
const ConvSpatialDimsType& input_right_pads,
const ConvSpatialDimsType& tildes)
: Hi_{c_g_n_c_wis_lengths[HIdx]},
Wi_{c_g_n_c_wis_lengths[WIdx]},
Ho_{a_g_n_k_wos_lengths[HIdx]},
Wo_{a_g_n_k_wos_lengths[WIdx]},
Y_{b_g_k_c_xs_lengths[YIdx]},
X_{b_g_k_c_xs_lengths[XIdx]},
K_{a_g_n_k_wos_lengths[I2]},
C_{b_g_k_c_xs_lengths[I2]},
HiStride_{c_g_n_c_wis_strides[HIdx]},
WiStride_{c_g_n_c_wis_strides[WIdx]},
HoStride_{a_g_n_k_wos_strides[HIdx]},
WoStride_{a_g_n_k_wos_strides[WIdx]},
CStrideTensorB_{b_g_k_c_xs_strides[I2]},
CStrideTensorC_{c_g_n_c_wis_strides[I2]},
KStrideTensorA_{a_g_n_k_wos_strides[I2]},
KStrideTensorB_{b_g_k_c_xs_strides[I1]},
NStrideTensorA_{a_g_n_k_wos_strides[I1]},
NStrideTensorC_{c_g_n_c_wis_strides[I1]},
ConvStrideH_{conv_filter_strides[HIdx - NonSpatialDimsNum]},
ConvStrideW_{conv_filter_strides[WIdx - NonSpatialDimsNum]},
ConvDilationH_{conv_filter_dilations[HIdx - NonSpatialDimsNum]},
ConvDilationW_{conv_filter_dilations[WIdx - NonSpatialDimsNum]},
InLeftPadH_{input_left_pads[HIdx - NonSpatialDimsNum]},
InLeftPadW_{input_left_pads[WIdx - NonSpatialDimsNum]},
InRightPadH_{input_right_pads[HIdx - NonSpatialDimsNum]},
InRightPadW_{input_right_pads[WIdx - NonSpatialDimsNum]},
IdxYTilde_{tildes[YIdx - NonSpatialDimsNum]},
IdxXTilde_{tildes[XIdx - NonSpatialDimsNum]}
{
static_assert(is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
is_same_v<ConvSpatialDimsType, ck::Array<IndexType, NDimSpatial>>);
static_assert(is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
is_same_v<ConvDimsType, ck::Array<IndexType, NDimSpatial + I3>>);
if constexpr(SplitN)
{
return make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
N_ = GetSplitedNSize(
a_g_n_k_wos_lengths, a_g_n_k_wos_strides, c_g_n_c_wis_lengths, c_g_n_c_wis_strides);
}
else
{
return make_naive_tensor_descriptor_packed(make_tuple(N, Ho, Wo, K));
N_ = c_g_n_c_wis_lengths[I1];
}
}
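// Hedged caller-side note: with SplitN enabled, N_ shrinks to the largest
// divisor of the batch that keeps both A and C under 2 GiB (GetSplitedNSize
// above), and the caller is assumed to re-launch the same grid over the
// remaining N_total / N_ batch slices, advancing the A/C base pointers by
// N_ * NStrideTensorA_ and N_ * NStrideTensorC_ per slice.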
else if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNDHWK>)
{
// assume packed
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
if constexpr(NDimSpatial == 3)
{
return make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K));
Di_ = c_g_n_c_wis_lengths[DIdx];
Do_ = a_g_n_k_wos_lengths[DIdx];
Z_ = b_g_k_c_xs_lengths[ZIdx];
DiStride_ = c_g_n_c_wis_strides[DIdx];
DoStride_ = a_g_n_k_wos_strides[DIdx];
ConvStrideD_ = conv_filter_strides[DIdx - NonSpatialDimsNum];
ConvDilationD_ = conv_filter_dilations[DIdx - NonSpatialDimsNum];
InLeftPadD_ = input_left_pads[DIdx - NonSpatialDimsNum];
InRightPadD_ = input_right_pads[DIdx - NonSpatialDimsNum];
IdxZTilde_ = tildes[ZIdx - NonSpatialDimsNum];
GcdStrideDilationD_ = math::gcd(ConvStrideD_, ConvDilationD_);
ZTilde_ = ConvStrideD_ / GcdStrideDilationD_;
DTilde_ = Do_ + math::integer_divide_ceil(ConvDilationD_ * (Z_ - I1), ConvStrideD_);
ZDot_ = math::integer_divide_ceil(Z_, ZTilde_);
}
else
{
return make_naive_tensor_descriptor_packed(make_tuple(N, Do, Ho, Wo, K));
Di_ = Do_ = Z_ = ZTilde_ = ConvStrideD_ = DTilde_ = ZDot_ = 1;
InLeftPadD_ = InRightPadD_ = DiStride_ = DoStride_ = IdxZTilde_ = 0;
}
}
else
{
throw std::runtime_error("wrong! unsupported layout: " + ALayout::name());
}
}
template <typename BLayout>
constexpr auto make_wei_grid_desc(
const index_t K, const index_t Z, const index_t Y, const index_t X, const index_t C)
{
GcdStrideDilationH_ = math::gcd(ConvStrideH_, ConvDilationH_);
GcdStrideDilationW_ = math::gcd(ConvStrideW_, ConvDilationW_);
if constexpr(is_same_v<BLayout, tensor_layout::convolution::GKYXC>)
{
return make_naive_tensor_descriptor_packed(make_tuple(K, Y, X, C));
}
else if constexpr(is_same_v<BLayout, tensor_layout::convolution::GKZYXC>)
{
return make_naive_tensor_descriptor_packed(make_tuple(K, Z, Y, X, C));
}
else
{
throw std::runtime_error("wrong! unsupported layout: " + BLayout::name());
}
}
template <index_t NDimSpatial, typename CLayout>
constexpr auto make_in_grid_desc(const index_t N,
const index_t Di,
const index_t Hi,
const index_t Wi,
const index_t C,
const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_strides)
{
YTilde_ = ConvStrideH_ / GcdStrideDilationH_;
XTilde_ = ConvStrideW_ / GcdStrideDilationW_;
if constexpr(is_same_v<CLayout, tensor_layout::convolution::GNHWC> ||
is_same_v<CLayout, tensor_layout::convolution::NHWGC> ||
is_same_v<CLayout, tensor_layout::convolution::G_NHW_C>)
{
return make_naive_tensor_descriptor(make_tuple(N, Hi, Wi, C),
make_tuple(in_g_n_c_wis_strides[1],
in_g_n_c_wis_strides[3],
in_g_n_c_wis_strides[4],
in_g_n_c_wis_strides[2]));
HTilde_ = Ho_ + math::integer_divide_ceil(ConvDilationH_ * (Y_ - I1), ConvStrideH_);
WTilde_ = Wo_ + math::integer_divide_ceil(ConvDilationW_ * (X_ - I1), ConvStrideW_);
YDot_ = math::integer_divide_ceil(Y_, YTilde_);
XDot_ = math::integer_divide_ceil(X_, XTilde_);
}
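// Hedged numeric example of the tilde decomposition (assumed values):
//   ConvStrideH = 2, ConvDilationH = 1, Y = 3, Ho = 14
//   GcdStrideDilationH = gcd(2, 1) = 1
//   YTilde = 2 / 1 = 2                      (filter-tap equivalence classes)
//   HTilde = 14 + ceil(1 * (3 - 1) / 2) = 15
//   YDot   = ceil(3 / 2) = 2                (taps per class)
// Each (i_ytilde, i_xtilde) pair then maps to one GEMM over the
// YDot x XDot x HTilde x WTilde window.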
else if constexpr(is_same_v<CLayout, tensor_layout::convolution::GNDHWC> ||
is_same_v<CLayout, tensor_layout::convolution::NDHWGC>)
#if 0 // Splitting the tensor is not supported for now
__host__ bool AreDescriptorsSmallerThan2GB() const
{
return make_naive_tensor_descriptor(make_tuple(N, Di, Hi, Wi, C),
make_tuple(in_g_n_c_wis_strides[1],
in_g_n_c_wis_strides[3],
in_g_n_c_wis_strides[4],
in_g_n_c_wis_strides[5],
in_g_n_c_wis_strides[2]));
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
const long_index_t in_desc_space_size =
I1 + (N_ - I1) * NStrideTensorC_ + (Di_ - I1) * DiStride_ + (Hi_ - I1) * HiStride_ +
(Wi_ - I1) * WiStride_ + (C_ - I1) * CStrideTensorC_;
const long_index_t out_desc_space_size =
I1 + (N_ - I1) * NStrideTensorA_ + (Do_ - I1) * DoStride_ + (Ho_ - I1) * HoStride_ +
(Wo_ - I1) * WoStride_ + (K_ - I1) * KStrideTensorA_;
bool is_a_descriptor_smaller_than_2GB = (out_desc_space_size * sizeof(ADataType)) <= TwoGB;
bool is_c_descriptor_smaller_than_2GB = (in_desc_space_size * sizeof(CDataType)) <= TwoGB;
return is_a_descriptor_smaller_than_2GB && is_c_descriptor_smaller_than_2GB;
}
else
__host__ auto SplitConvProblem(const ADataType* a_grid_ptr_base,
CDataType* c_grid_ptr_base) const
{
throw std::runtime_error("wrong! unsupported layout: " + CLayout::name());
}
}
// Create copies
auto conv_to_gemm_transformer_left = *this;
auto conv_to_gemm_transformer_right = *this;
IndexType a_right_offset = 0;
IndexType c_right_offset = 0;
// Calculate the effective (dilated) filter size
const IndexType z_eff = (Z_ - 1) * ConvDilationD_ + 1;
const IndexType y_eff = (Y_ - 1) * ConvDilationH_ + 1;
const IndexType x_eff = (X_ - 1) * ConvDilationW_ + 1;
// Calculate start position in input for right tensor
const IndexType di_right_transformer_start_idx = (Do_ / 2) * ConvStrideD_;
const IndexType hi_right_transformer_start_idx = (Ho_ / 2) * ConvStrideH_;
const IndexType wi_right_transformer_start_idx = (Wo_ / 2) * ConvStrideW_;
// Calculate last position in input for left tensor
const IndexType di_left_transformer_end_idx = (Do_ / 2 - 1) * ConvStrideD_ + z_eff;
const IndexType hi_left_transformer_end_idx = (Ho_ / 2 - 1) * ConvStrideH_ + y_eff;
const IndexType wi_left_transformer_end_idx = (Wo_ / 2 - 1) * ConvStrideW_ + x_eff;
// Allow the split only if the whole left padding lands in the left tensor
// and the whole right padding in the right tensor
const bool is_possible_to_split_d = Do_ != 1 &&
di_right_transformer_start_idx > InLeftPadD_ &&
di_left_transformer_end_idx <= (InLeftPadD_ + Di_);
const bool is_possible_to_split_h = Ho_ != 1 &&
hi_right_transformer_start_idx > InLeftPadH_ &&
hi_left_transformer_end_idx <= (InLeftPadH_ + Hi_);
const bool is_possible_to_split_w = Wo_ != 1 &&
wi_right_transformer_start_idx > InLeftPadW_ &&
wi_left_transformer_end_idx <= (InLeftPadW_ + Wi_);
if(is_possible_to_split_d)
{
// Apply new sizes
// Split output on half
conv_to_gemm_transformer_left.Do_ = Do_ / 2;
conv_to_gemm_transformer_right.Do_ = Do_ - Do_ / 2;
// Assign left padding to left convolution
conv_to_gemm_transformer_left.InLeftPadD_ = InLeftPadD_;
conv_to_gemm_transformer_right.InLeftPadD_ = 0;
// Assign right padding to right convolution
conv_to_gemm_transformer_left.InRightPadD_ = 0;
conv_to_gemm_transformer_right.InRightPadD_ = InRightPadD_;
// Calculate new input size
conv_to_gemm_transformer_left.Di_ = di_left_transformer_end_idx - InLeftPadD_;
conv_to_gemm_transformer_right.Di_ =
math::min(Di_ - (di_right_transformer_start_idx - InLeftPadD_),
(conv_to_gemm_transformer_right.Do_ - 1) * ConvStrideD_ + z_eff);
// Calculate offsets
a_right_offset = (Do_ / 2) * DoStride_;
c_right_offset = ((Do_ / 2) * ConvStrideD_ - InLeftPadD_) * DiStride_;
}
        else if(is_possible_to_split_h)
        {
            // Apply new sizes
            // Split output in half
            conv_to_gemm_transformer_left.Ho_  = Ho_ / 2;
            conv_to_gemm_transformer_right.Ho_ = Ho_ - Ho_ / 2;
            // Assign left padding to left convolution
            conv_to_gemm_transformer_left.InLeftPadH_  = InLeftPadH_;
            conv_to_gemm_transformer_right.InLeftPadH_ = 0;
            // Assign right padding to right convolution
            conv_to_gemm_transformer_left.InRightPadH_  = 0;
            conv_to_gemm_transformer_right.InRightPadH_ = InRightPadH_;
            // Calculate new input size
            conv_to_gemm_transformer_left.Hi_ = hi_left_transformer_end_idx - InLeftPadH_;
            conv_to_gemm_transformer_right.Hi_ =
                math::min(Hi_ - (hi_right_transformer_start_idx - InLeftPadH_),
                          (conv_to_gemm_transformer_right.Ho_ - 1) * ConvStrideH_ + y_eff);
            // Calculate offsets
            a_right_offset = (Ho_ / 2) * HoStride_;
            c_right_offset = ((Ho_ / 2) * ConvStrideH_ - InLeftPadH_) * HiStride_;
        }
        else if(is_possible_to_split_w)
        {
            // Apply new sizes
            // Split output in half
            conv_to_gemm_transformer_left.Wo_  = Wo_ / 2;
            conv_to_gemm_transformer_right.Wo_ = Wo_ - Wo_ / 2;
            // Assign left padding to left convolution
            conv_to_gemm_transformer_left.InLeftPadW_  = InLeftPadW_;
            conv_to_gemm_transformer_right.InLeftPadW_ = 0;
            // Assign right padding to right convolution
            conv_to_gemm_transformer_left.InRightPadW_  = 0;
            conv_to_gemm_transformer_right.InRightPadW_ = InRightPadW_;
            // Calculate new input size
            conv_to_gemm_transformer_left.Wi_ = wi_left_transformer_end_idx - InLeftPadW_;
            conv_to_gemm_transformer_right.Wi_ =
                math::min(Wi_ - (wi_right_transformer_start_idx - InLeftPadW_),
                          (conv_to_gemm_transformer_right.Wo_ - 1) * ConvStrideW_ + x_eff);
            // Calculate offsets
            a_right_offset = (Wo_ / 2) * WoStride_;
            c_right_offset = ((Wo_ / 2) * ConvStrideW_ - InLeftPadW_) * WiStride_;
        }
        // Return left transformer, right transformer, right offset into the output (A) and
        // right offset into the input (C)
        return ck::make_tuple(conv_to_gemm_transformer_left,
                              conv_to_gemm_transformer_right,
                              a_grid_ptr_base + a_right_offset,
                              c_grid_ptr_base + c_right_offset);
    }

template <
    index_t NDimSpatial,
    ck::tensor_operation::device::ConvolutionBackwardDataSpecialization ConvBwdDataSpecialization,
    index_t AK1,
    index_t BK1,
    index_t GemmMPerBlock,
    index_t GemmNPerBlock,
    index_t GemmKPerBlock,
    bool DoPadGemmM,
    bool DoPadGemmN>
struct TransformConvBwdDataToGemm_v1
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    static constexpr auto NonSpatialDimsNum = Number<3>{};

    static constexpr auto DIdx = Number<NonSpatialDimsNum>{};
    static constexpr auto HIdx =
        NDimSpatial == 2 ? Number<NonSpatialDimsNum>{} : Number<NonSpatialDimsNum + 1>{};
    static constexpr auto WIdx =
        NDimSpatial == 2 ? Number<NonSpatialDimsNum + 1>{} : Number<NonSpatialDimsNum + 2>{};

    static constexpr auto ZIdx = Number<NonSpatialDimsNum>{};
    static constexpr auto YIdx =
        NDimSpatial == 2 ? Number<NonSpatialDimsNum>{} : Number<NonSpatialDimsNum + 1>{};
    static constexpr auto XIdx =
        NDimSpatial == 2 ? Number<NonSpatialDimsNum + 1>{} : Number<NonSpatialDimsNum + 2>{};

    template <typename ALayout,
              typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) &&
                                          (is_same_v<ALayout, tensor_layout::convolution::GNHWK> ||
                                           is_same_v<ALayout, tensor_layout::convolution::GNDHWK> ||
                                           is_same_v<ALayout, tensor_layout::convolution::NHWGK> ||
                                           is_same_v<ALayout, tensor_layout::convolution::NDHWGK>),
                                      bool>::type = false>
    static auto MakeADescriptor_AK0_M_AK1(
        const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_lengths,
        const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_strides,
        const std::array<index_t, NDimSpatial + 3>& wei_g_k_c_xs_lengths,
        const std::array<index_t, NDimSpatial + 3>& /* wei_g_k_c_xs_strides */,
        const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_lengths,
        const std::array<index_t, NDimSpatial + 3>& /* in_g_n_c_wis_strides */,
        const std::array<index_t, NDimSpatial>& conv_filter_strides,
        const std::array<index_t, NDimSpatial>& conv_filter_dilations,
        const std::array<index_t, NDimSpatial>& input_left_pads,
        const std::array<index_t, NDimSpatial>& /* input_right_pads */,
        const std::array<index_t, NDimSpatial>& tildes)
        index_t i_ztilde = tildes[ZIdx - NonSpatialDimsNum];
        index_t i_ytilde = tildes[YIdx - NonSpatialDimsNum];
        index_t i_xtilde = tildes[XIdx - NonSpatialDimsNum];

    __host__ auto SplitConvProblem(const ADataType* a_grid_ptr_base,
                                   CDataType* c_grid_ptr_base) const
    {
// Create copies
auto conv_to_gemm_transformer_left = *this;
auto conv_to_gemm_transformer_right = *this;
IndexType a_right_offset = 0;
IndexType c_right_offset = 0;
        // Calculate start position in output for right tensor
const IndexType do_right_transformer_start_idx = math::integer_divide_ceil((Di_ / 2) + InLeftPadD_ - ((Z_ - 1) * ConvDilationD_), ConvStrideD_);
const IndexType ho_right_transformer_start_idx = math::integer_divide_ceil((Hi_ / 2) + InLeftPadH_ - ((Y_ - 1) * ConvDilationH_), ConvStrideH_);
const IndexType wo_right_transformer_start_idx = math::integer_divide_ceil((Wi_ / 2) + InLeftPadW_ - ((X_ - 1) * ConvDilationW_), ConvStrideW_);
        // Calculate last position in output for left tensor
const IndexType do_left_transformer_end_idx = math::integer_divide_ceil((Di_ / 2 - 1) + InLeftPadD_, ConvStrideD_);
const IndexType ho_left_transformer_end_idx = math::integer_divide_ceil((Hi_ / 2 - 1) + InLeftPadH_, ConvStrideH_);
const IndexType wo_left_transformer_end_idx = math::integer_divide_ceil((Wi_ / 2 - 1) + InLeftPadW_, ConvStrideW_);
        if(Di_ != 1)
        {
            // Apply new sizes
            // Split input in half
            conv_to_gemm_transformer_left.Di_  = Di_ / 2;
            conv_to_gemm_transformer_right.Di_ = Di_ - Di_ / 2;
            // Assign left padding to left convolution
            conv_to_gemm_transformer_left.InLeftPadD_  = InLeftPadD_;
            conv_to_gemm_transformer_right.InLeftPadD_ = 0;
            // Assign right padding to right convolution
            conv_to_gemm_transformer_left.InRightPadD_  = 0;
            conv_to_gemm_transformer_right.InRightPadD_ = InRightPadD_;
            // Calculate new output size
            conv_to_gemm_transformer_left.Do_  = do_left_transformer_end_idx;
            conv_to_gemm_transformer_right.Do_ = Do_ - do_right_transformer_start_idx;
            // Calculate offsets
            a_right_offset = do_right_transformer_start_idx * DoStride_;
            c_right_offset = (Di_ / 2) * DiStride_;
        }
        else if(Hi_ != 1)
        {
            // Apply new sizes
            // Split input in half
            conv_to_gemm_transformer_left.Hi_  = Hi_ / 2;
            conv_to_gemm_transformer_right.Hi_ = Hi_ - Hi_ / 2;
            // Assign left padding to left convolution
            conv_to_gemm_transformer_left.InLeftPadH_  = InLeftPadH_;
            conv_to_gemm_transformer_right.InLeftPadH_ = 0;
            // Assign right padding to right convolution
            conv_to_gemm_transformer_left.InRightPadH_  = 0;
            conv_to_gemm_transformer_right.InRightPadH_ = InRightPadH_;
            // Calculate new output size
            conv_to_gemm_transformer_left.Ho_  = ho_left_transformer_end_idx;
            conv_to_gemm_transformer_right.Ho_ = Ho_ - ho_right_transformer_start_idx;
            // Calculate offsets
            a_right_offset = ho_right_transformer_start_idx * HoStride_;
            c_right_offset = (Hi_ / 2) * HiStride_;
        }
        else if(Wi_ != 1)
        {
            // Apply new sizes
            // Split input in half
            conv_to_gemm_transformer_left.Wi_  = Wi_ / 2;
            conv_to_gemm_transformer_right.Wi_ = Wi_ - Wi_ / 2;
            // Assign left padding to left convolution
            conv_to_gemm_transformer_left.InLeftPadW_  = InLeftPadW_;
            conv_to_gemm_transformer_right.InLeftPadW_ = 0;
            // Assign right padding to right convolution
            conv_to_gemm_transformer_left.InRightPadW_  = 0;
            conv_to_gemm_transformer_right.InRightPadW_ = InRightPadW_;
            // Calculate new output size
            conv_to_gemm_transformer_left.Wo_  = wo_left_transformer_end_idx;
            conv_to_gemm_transformer_right.Wo_ = Wo_ - wo_right_transformer_start_idx;
            // Calculate offsets
            a_right_offset = wo_right_transformer_start_idx * WoStride_;
            c_right_offset = (Wi_ / 2) * WiStride_;
        }
        // Return left transformer, right transformer, right offset into the output (A) and
        // right offset into the input (C)
return ck::make_tuple(conv_to_gemm_transformer_left,
conv_to_gemm_transformer_right,
a_grid_ptr_base + a_right_offset,
c_grid_ptr_base + c_right_offset);
}
#endif
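As a sanity check on the split arithmetic above, here is a rough standalone model (not part of the commit; plain `int` instead of `IndexType`, hypothetical sizes) of the W-direction halving and the one-column overlap it produces:

```cpp
#include <cstdio>

// Standalone model of SplitConvProblem's W split (hypothetical sizes, stride 1, no padding).
int main()
{
    const int Wi = 32, X = 3, stride = 1, dilation = 1, pad_l = 0;
    const int Wo = (Wi + pad_l - (X - 1) * dilation - 1) / stride + 1; // 30

    // First output column whose filter window reaches into the right half of the input.
    const int wo_right_start = (Wi / 2 + pad_l - (X - 1) * dilation + stride - 1) / stride; // 14
    // Output columns assigned to the left sub-problem (exclusive end).
    const int wo_left_end = (Wi / 2 - 1 + pad_l + stride - 1) / stride; // 15

    const int Wo_left = wo_left_end, Wo_right = Wo - wo_right_start;
    std::printf("left Wo=%d, right Wo=%d, overlap=%d\n",
                Wo_left, Wo_right, Wo_left + Wo_right - Wo); // 15, 16, 1
    return 0;
}
```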
const index_t N = in_g_n_c_wis_lengths[1];
const index_t K = wei_g_k_c_xs_lengths[1];
__host__ __device__ auto MakeOutGridDesc() const
{
if constexpr(is_same_v<ALayout, tensor_layout::convolution::NHWGK>)
{
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{
const index_t Di = NDimSpatial == 3 ? in_g_n_c_wis_lengths[DIdx] : 1;
const index_t Hi = in_g_n_c_wis_lengths[HIdx];
const index_t Wi = in_g_n_c_wis_lengths[WIdx];
return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_),
make_tuple(WoStride_, KStrideTensorA_));
}
else
{
return make_naive_tensor_descriptor(
make_tuple(N_, Ho_, Wo_, K_),
make_tuple(NStrideTensorA_, HoStride_, WoStride_, KStrideTensorA_));
}
}
else if constexpr(is_same_v<ALayout, tensor_layout::convolution::NDHWGK>)
{
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{
const index_t Do = NDimSpatial == 3 ? out_g_n_k_wos_lengths[DIdx] : 1;
const index_t Ho = out_g_n_k_wos_lengths[HIdx];
const index_t Wo = out_g_n_k_wos_lengths[WIdx];
return make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, K_),
make_tuple(WoStride_, KStrideTensorA_));
}
else
{
return make_naive_tensor_descriptor(
make_tuple(N_, Do_, Ho_, Wo_, K_),
make_tuple(NStrideTensorA_, DoStride_, HoStride_, WoStride_, KStrideTensorA_));
}
}
else if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNHWK>)
{
// assume packed
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{
return make_naive_tensor_descriptor_packed(make_tuple(N_ * Ho_ * Wo_, K_));
}
else
{
return make_naive_tensor_descriptor_packed(make_tuple(N_, Ho_, Wo_, K_));
}
}
else if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNDHWK>)
{
// assume packed
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{
return make_naive_tensor_descriptor_packed(make_tuple(N_ * Do_ * Ho_ * Wo_, K_));
}
else
{
return make_naive_tensor_descriptor_packed(make_tuple(N_, Do_, Ho_, Wo_, K_));
}
}
else
{
throw std::runtime_error("wrong! unsupported layout: " + ALayout::name());
}
}
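For readers new to these descriptors, a minimal model of what the `make_naive_tensor_descriptor` call above encodes (illustrative only; the real descriptor is a compile-time expression tree):

```cpp
// offset(i0, i1) = i0 * stride0 + i1 * stride1; the Filter1x1Stride1Pad0 branch above views
// the output tensor as a 2-D (N*Ho*Wo, K) matrix with strides (WoStride, KStride).
struct NaiveDesc2D
{
    long stride0, stride1;
    long CalculateOffset(long i0, long i1) const { return i0 * stride0 + i1 * stride1; }
};
```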
const index_t Z = NDimSpatial == 3 ? wei_g_k_c_xs_lengths[ZIdx] : 1;
const index_t Y = wei_g_k_c_xs_lengths[YIdx];
const index_t X = wei_g_k_c_xs_lengths[XIdx];
__host__ __device__ auto MakeWeiGridDesc() const
{
const index_t InLeftPadD = input_left_pads[DIdx - NonSpatialDimsNum];
const index_t InLeftPadH = input_left_pads[HIdx - NonSpatialDimsNum];
const index_t InLeftPadW = input_left_pads[WIdx - NonSpatialDimsNum];
if constexpr(is_same_v<BLayout, tensor_layout::convolution::GKYXC>)
{
return make_naive_tensor_descriptor_packed(make_tuple(K_, Y_, X_, C_));
}
else if constexpr(is_same_v<BLayout, tensor_layout::convolution::GKZYXC>)
{
return make_naive_tensor_descriptor_packed(make_tuple(K_, Z_, Y_, X_, C_));
}
else
{
throw std::runtime_error("wrong! unsupported layout: " + BLayout::name());
}
}
const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum];
const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum];
const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum];
__host__ __device__ auto MakeInGridDesc() const
{
const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum];
const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum];
const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum];
if constexpr(is_same_v<CLayout, tensor_layout::convolution::GNHWC> ||
is_same_v<CLayout, tensor_layout::convolution::NHWGC> ||
is_same_v<CLayout, tensor_layout::convolution::G_NHW_C>)
{
return make_naive_tensor_descriptor(
make_tuple(N_, Hi_, Wi_, C_),
make_tuple(NStrideTensorC_, HiStride_, WiStride_, CStrideTensorC_));
}
else if constexpr(is_same_v<CLayout, tensor_layout::convolution::GNDHWC> ||
is_same_v<CLayout, tensor_layout::convolution::NDHWGC>)
{
return make_naive_tensor_descriptor(
make_tuple(N_, Di_, Hi_, Wi_, C_),
make_tuple(NStrideTensorC_, DiStride_, HiStride_, WiStride_, CStrideTensorC_));
}
else
{
throw std::runtime_error("wrong! unsupported layout: " + CLayout::name());
}
}
template <
typename ALayout_ = ALayout,
typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) &&
(is_same_v<ALayout_, tensor_layout::convolution::GNHWK> ||
is_same_v<ALayout_, tensor_layout::convolution::GNDHWK> ||
is_same_v<ALayout_, tensor_layout::convolution::NHWGK> ||
is_same_v<ALayout_, tensor_layout::convolution::NDHWGK>),
bool>::type = false>
__host__ __device__ auto MakeADescriptor_AK0_M_AK1() const
{
// n_do_ho_wo_k for 3d or n_ho_wo_k for 2d
const auto out_grid_desc =
make_out_grid_desc<NDimSpatial, ALayout, ConvBwdDataSpecialization>(
N, Do, Ho, Wo, K, out_g_n_k_wos_strides);
const auto out_grid_desc = MakeOutGridDesc();
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{
const index_t AK0 = math::integer_divide_ceil(K, AK1);
const index_t AK0 = math::integer_divide_ceil(K_, AK1);
// A: output tensor
const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor(
out_grid_desc,
make_tuple(make_pass_through_transform(N * Do * Ho * Wo),
make_tuple(make_pass_through_transform(N_ * Do_ * Ho_ * Wo_),
make_unmerge_transform(make_tuple(AK0, AK1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
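The unmerge transform above splits the flat K index into `(AK0, AK1)` with `k = ak0 * AK1 + ak1`; a quick standalone check with hypothetical sizes:

```cpp
constexpr int AK1_v = 8, K_v = 64;
constexpr int AK0_v = (K_v + AK1_v - 1) / AK1_v; // integer_divide_ceil(K, AK1) == 8
// unmerge: k -> (k / AK1, k % AK1), e.g. k = 37 -> (4, 5) since 4 * 8 + 5 == 37.
static_assert(AK0_v == 8 && 37 / AK1_v == 4 && 37 % AK1_v == 5, "unmerge round-trip");
```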
......@@ -266,82 +635,63 @@ struct TransformConvBwdDataToGemm_v1
}
else
{
const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto ZTilde = ConvStrideD / GcdStrideDilationD;
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;
const auto ZDot = math::integer_divide_ceil(Z, ZTilde);
const auto YDot = math::integer_divide_ceil(Y, YTilde);
const auto XDot = math::integer_divide_ceil(X, XTilde);
const auto DTilde =
Do + math::integer_divide_ceil(ConvDilationD * (Z - I1), ConvStrideD);
const auto HTilde =
Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
const auto WTilde =
Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// only work on HTilde and WTilde that contribute to non-padding area of input tensor
const auto IDTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadD - ConvDilationD * (ZTilde - I1)), ConvStrideD);
math::max(I0, InLeftPadD_ - ConvDilationD_ * (ZTilde_ - I1)), ConvStrideD_);
const auto IHTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
math::max(I0, InLeftPadH_ - ConvDilationH_ * (YTilde_ - I1)), ConvStrideH_);
const auto IWTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
math::max(I0, InLeftPadW_ - ConvDilationW_ * (XTilde_ - I1)), ConvStrideW_);
const auto IDTildeSliceEnd = math::min(
DTilde, math::integer_divide_ceil(InLeftPadD + Di - I1, ConvStrideD) + I1);
DTilde_, math::integer_divide_ceil(InLeftPadD_ + Di_ - I1, ConvStrideD_) + I1);
const auto IHTildeSliceEnd = math::min(
HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
HTilde_, math::integer_divide_ceil(InLeftPadH_ + Hi_ - I1, ConvStrideH_) + I1);
const auto IWTildeSliceEnd = math::min(
WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
WTilde_, math::integer_divide_ceil(InLeftPadW_ + Wi_ - I1, ConvStrideW_) + I1);
const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin;
const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
// GemmK is different for each GEMM
const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde);
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
const auto ZDotSlice = math::integer_divide_ceil(Z_ - IdxZTilde_, ZTilde_);
const auto YDotSlice = math::integer_divide_ceil(Y_ - IdxYTilde_, YTilde_);
const auto XDotSlice = math::integer_divide_ceil(X_ - IdxXTilde_, XTilde_);
if constexpr(NDimSpatial == 2)
{
// A: output tensor
const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
out_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Ho, I0, I0),
make_pad_transform(Wo, I0, I0),
make_pass_through_transform(K)),
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Ho_, I0, I0),
make_pad_transform(Wo_, I0, I0),
make_pass_through_transform(K_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
out_n_hop_wop_k_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(YDot, HTilde),
make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, WTilde),
make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
make_pass_through_transform(K)),
make_pass_through_transform(N_),
make_embed_transform(make_tuple(YDot_, HTilde_),
make_tuple(-ConvDilationH_ / GcdStrideDilationH_, I1)),
make_embed_transform(make_tuple(XDot_, WTilde_),
make_tuple(-ConvDilationW_ / GcdStrideDilationW_, I1)),
make_pass_through_transform(K_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc =
transform_tensor_descriptor(
out_n_ydot_htilde_xdot_wtilde_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(K)),
make_tuple(make_pass_through_transform(N_),
make_slice_transform(YDot_, I0, YDotSlice),
make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice),
make_slice_transform(XDot_, I0, XDotSlice),
make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(K_)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
......@@ -357,8 +707,8 @@ struct TransformConvBwdDataToGemm_v1
const auto out_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor(
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K)),
make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice))),
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K_)),
make_merge_transform(make_tuple(N_, HTildeSlice, WTildeSlice))),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
......@@ -385,11 +735,11 @@ struct TransformConvBwdDataToGemm_v1
// A: output tensor
const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
out_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Do, I0, I0),
make_pad_transform(Ho, I0, I0),
make_pad_transform(Wo, I0, I0),
make_pass_through_transform(K)),
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Do_, I0, I0),
make_pad_transform(Ho_, I0, I0),
make_pad_transform(Wo_, I0, I0),
make_pass_through_transform(K_)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
......@@ -398,17 +748,17 @@ struct TransformConvBwdDataToGemm_v1
const auto out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc =
transform_tensor_descriptor(
out_n_hop_wop_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_tuple(make_pass_through_transform(N_),
make_embed_transform(
make_tuple(ZDot, DTilde),
make_tuple(-ConvDilationD / GcdStrideDilationD, I1)),
make_tuple(ZDot_, DTilde_),
make_tuple(-ConvDilationD_ / GcdStrideDilationD_, I1)),
make_embed_transform(
make_tuple(YDot, HTilde),
make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
make_tuple(YDot_, HTilde_),
make_tuple(-ConvDilationH_ / GcdStrideDilationH_, I1)),
make_embed_transform(
make_tuple(XDot, WTilde),
make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
make_pass_through_transform(K)),
make_tuple(XDot_, WTilde_),
make_tuple(-ConvDilationW_ / GcdStrideDilationW_, I1)),
make_pass_through_transform(K_)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
......@@ -424,14 +774,15 @@ struct TransformConvBwdDataToGemm_v1
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc =
transform_tensor_descriptor(
out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_slice_transform(ZDot, I0, ZDotSlice),
make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(K)),
make_tuple(
make_pass_through_transform(N_),
make_slice_transform(ZDot_, I0, ZDotSlice),
make_slice_transform(DTilde_, IDTildeSliceBegin, DTildeSlice),
make_slice_transform(YDot_, I0, YDotSlice),
make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice),
make_slice_transform(XDot_, I0, XDotSlice),
make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(K_)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
......@@ -452,8 +803,9 @@ struct TransformConvBwdDataToGemm_v1
const auto out_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor(
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc,
make_tuple(
make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K)),
make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice))),
make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K_)),
make_merge_transform(
make_tuple(N_, DTildeSlice, HTildeSlice, WTildeSlice))),
make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
......@@ -482,66 +834,31 @@ struct TransformConvBwdDataToGemm_v1
}
}
template <typename BLayout,
template <typename BLayout_ = BLayout,
typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) &&
(is_same_v<BLayout, tensor_layout::convolution::GKYXC> ||
is_same_v<BLayout, tensor_layout::convolution::GKZYXC>),
(is_same_v<BLayout_, tensor_layout::convolution::GKYXC> ||
is_same_v<BLayout_, tensor_layout::convolution::GKZYXC>),
bool>::type = false>
static auto MakeBDescriptor_BK0_N_BK1(
const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& /* out_g_n_k_wos_strides */,
const std::array<index_t, NDimSpatial + 3>& wei_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& /* wei_g_k_c_xs_strides */,
const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& /* in_g_n_c_wis_strides */,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& /* input_left_pads */,
const std::array<index_t, NDimSpatial>& /* input_right_pads */,
const std::array<index_t, NDimSpatial>& tildes)
__host__ __device__ auto MakeBDescriptor_BK0_N_BK1() const
{
index_t i_ztilde = tildes[ZIdx - NonSpatialDimsNum];
index_t i_ytilde = tildes[YIdx - NonSpatialDimsNum];
index_t i_xtilde = tildes[XIdx - NonSpatialDimsNum];
const index_t N = in_g_n_c_wis_lengths[1];
const index_t K = wei_g_k_c_xs_lengths[1];
const index_t C = wei_g_k_c_xs_lengths[2];
const index_t Do = NDimSpatial == 3 ? out_g_n_k_wos_lengths[DIdx] : 1;
const index_t Ho = out_g_n_k_wos_lengths[HIdx];
const index_t Wo = out_g_n_k_wos_lengths[WIdx];
const index_t Z = NDimSpatial == 3 ? wei_g_k_c_xs_lengths[ZIdx] : 1;
const index_t Y = wei_g_k_c_xs_lengths[YIdx];
const index_t X = wei_g_k_c_xs_lengths[XIdx];
const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum];
const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum];
const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum];
const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum];
const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum];
const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum];
// assume packed
// k_y_x_c for 2d or k_z_y_x_c for 3d
const auto wei_grid_desc = make_wei_grid_desc<BLayout>(K, Z, Y, X, C);
const auto wei_grid_desc = MakeWeiGridDesc();
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{
const index_t BK0 = math::integer_divide_ceil(K, BK1);
const index_t BK0 = math::integer_divide_ceil(K_, BK1);
// B: weight tensor
const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc =
transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)),
transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K_, C_)),
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(C)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, C), make_tuple(I0, I1));
make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, C_), make_tuple(I0, I1));
const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
......@@ -553,22 +870,10 @@ struct TransformConvBwdDataToGemm_v1
}
else
{
const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto ZTilde = ConvStrideD / GcdStrideDilationD;
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;
const auto ZDot = math::integer_divide_ceil(Z, ZTilde);
const auto YDot = math::integer_divide_ceil(Y, YTilde);
const auto XDot = math::integer_divide_ceil(X, XTilde);
// GemmK is different for each GEMM
const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde);
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
const auto ZDotSlice = math::integer_divide_ceil(Z_ - IdxZTilde_, ZTilde_);
const auto YDotSlice = math::integer_divide_ceil(Y_ - IdxYTilde_, YTilde_);
const auto XDotSlice = math::integer_divide_ceil(X_ - IdxXTilde_, XTilde_);
// B weight tensor
if constexpr(NDimSpatial == 2)
......@@ -576,23 +881,23 @@ struct TransformConvBwdDataToGemm_v1
const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
wei_grid_desc,
make_tuple(
make_pass_through_transform(K),
make_embed_transform(make_tuple(YDot, YTilde),
make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, XTilde),
make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
make_pass_through_transform(C)),
make_pass_through_transform(K_),
make_embed_transform(make_tuple(YDot_, YTilde_),
make_tuple(ConvStrideH_ / GcdStrideDilationH_, I1)),
make_embed_transform(make_tuple(XDot_, XTilde_),
make_tuple(ConvStrideW_ / GcdStrideDilationW_, I1)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto wei_k_ydotslice_xdotslice_c_grid_desc = transform_tensor_descriptor(
wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
make_tuple(make_pass_through_transform(K),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_freeze_transform(i_ytilde),
make_freeze_transform(i_xtilde),
make_pass_through_transform(C)),
make_tuple(make_pass_through_transform(K_),
make_slice_transform(YDot_, I0, YDotSlice),
make_slice_transform(XDot_, I0, XDotSlice),
make_freeze_transform(IdxYTilde_),
make_freeze_transform(IdxXTilde_),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<3>{},
......@@ -608,8 +913,8 @@ struct TransformConvBwdDataToGemm_v1
const auto wei_gemmk_gemmnraw_grid_desc = transform_tensor_descriptor(
wei_k_ydotslice_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K)),
make_pass_through_transform(C)),
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K_)),
make_pass_through_transform(C_)),
make_tuple(Sequence<1, 2, 0>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
......@@ -636,15 +941,17 @@ struct TransformConvBwdDataToGemm_v1
const auto wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc =
transform_tensor_descriptor(
wei_grid_desc,
make_tuple(
make_pass_through_transform(K),
make_embed_transform(make_tuple(ZDot, ZTilde),
make_tuple(ConvStrideD / GcdStrideDilationD, I1)),
make_embed_transform(make_tuple(YDot, YTilde),
make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, XTilde),
make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
make_pass_through_transform(C)),
make_tuple(make_pass_through_transform(K_),
make_embed_transform(
make_tuple(ZDot_, ZTilde_),
make_tuple(ConvStrideD_ / GcdStrideDilationD_, I1)),
make_embed_transform(
make_tuple(YDot_, YTilde_),
make_tuple(ConvStrideH_ / GcdStrideDilationH_, I1)),
make_embed_transform(
make_tuple(XDot_, XTilde_),
make_tuple(ConvStrideW_ / GcdStrideDilationW_, I1)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
......@@ -659,14 +966,14 @@ struct TransformConvBwdDataToGemm_v1
const auto wei_gemmk_zdotslice_ydotslice_xdotslice_c_grid_desc =
transform_tensor_descriptor(
wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc,
make_tuple(make_pass_through_transform(K),
make_slice_transform(ZDot, I0, ZDotSlice),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_freeze_transform(i_ztilde),
make_freeze_transform(i_ytilde),
make_freeze_transform(i_xtilde),
make_pass_through_transform(C)),
make_tuple(make_pass_through_transform(K_),
make_slice_transform(ZDot_, I0, ZDotSlice),
make_slice_transform(YDot_, I0, YDotSlice),
make_slice_transform(XDot_, I0, XDotSlice),
make_freeze_transform(IdxZTilde_),
make_freeze_transform(IdxYTilde_),
make_freeze_transform(IdxXTilde_),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<3>{},
......@@ -686,8 +993,9 @@ struct TransformConvBwdDataToGemm_v1
const auto wei_gemmk_gemmnraw_grid_desc = transform_tensor_descriptor(
wei_gemmk_zdotslice_ydotslice_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K)),
make_pass_through_transform(C)),
make_tuple(
make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K_)),
make_pass_through_transform(C_)),
make_tuple(Sequence<1, 2, 3, 0>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
......@@ -716,66 +1024,20 @@ struct TransformConvBwdDataToGemm_v1
}
}
template <typename CLayout,
typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) &&
(is_same_v<CLayout, tensor_layout::convolution::GNHWC> ||
is_same_v<CLayout, tensor_layout::convolution::GNDHWC> ||
is_same_v<CLayout, tensor_layout::convolution::NHWGC> ||
is_same_v<CLayout, tensor_layout::convolution::NDHWGC> ||
is_same_v<CLayout, tensor_layout::convolution::G_NHW_C>),
bool>::type = false>
static auto
MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& /* out_g_n_k_wos_strides */,
const std::array<index_t, NDimSpatial + 3>& wei_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& /* wei_g_k_c_xs_strides */,
const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const std::array<index_t, NDimSpatial>& tildes)
template <
typename CLayout_ = CLayout,
typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) &&
(is_same_v<CLayout_, tensor_layout::convolution::GNHWC> ||
is_same_v<CLayout_, tensor_layout::convolution::GNDHWC> ||
is_same_v<CLayout_, tensor_layout::convolution::NHWGC> ||
is_same_v<CLayout_, tensor_layout::convolution::NDHWGC> ||
is_same_v<CLayout_, tensor_layout::convolution::G_NHW_C>),
bool>::type = false>
__host__ __device__ auto MakeCDescriptor_M_N() const
{
index_t i_ztilde = tildes[ZIdx - NonSpatialDimsNum];
index_t i_ytilde = tildes[YIdx - NonSpatialDimsNum];
index_t i_xtilde = tildes[XIdx - NonSpatialDimsNum];
const index_t N = in_g_n_c_wis_lengths[1];
const index_t C = wei_g_k_c_xs_lengths[2];
const index_t Di = NDimSpatial == 3 ? in_g_n_c_wis_lengths[DIdx] : 1;
const index_t Hi = in_g_n_c_wis_lengths[HIdx];
const index_t Wi = in_g_n_c_wis_lengths[WIdx];
const index_t Do = NDimSpatial == 3 ? out_g_n_k_wos_lengths[DIdx] : 1;
const index_t Ho = out_g_n_k_wos_lengths[HIdx];
const index_t Wo = out_g_n_k_wos_lengths[WIdx];
const index_t Z = NDimSpatial == 3 ? wei_g_k_c_xs_lengths[ZIdx] : 1;
const index_t Y = wei_g_k_c_xs_lengths[YIdx];
const index_t X = wei_g_k_c_xs_lengths[XIdx];
const index_t InLeftPadD = input_left_pads[DIdx - NonSpatialDimsNum];
const index_t InLeftPadH = input_left_pads[HIdx - NonSpatialDimsNum];
const index_t InLeftPadW = input_left_pads[WIdx - NonSpatialDimsNum];
const index_t InRightPadD = input_right_pads[DIdx - NonSpatialDimsNum];
const index_t InRightPadH = input_right_pads[HIdx - NonSpatialDimsNum];
const index_t InRightPadW = input_right_pads[WIdx - NonSpatialDimsNum];
const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum];
const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum];
const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum];
const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum];
const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum];
const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum];
// assume strided
// n_hi_wi_c for 2d n_di_hi_wi_c for 3d
const auto in_grid_desc =
make_in_grid_desc<NDimSpatial, CLayout>(N, Di, Hi, Wi, C, in_g_n_c_wis_strides);
const auto in_grid_desc = MakeInGridDesc();
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
......@@ -787,10 +1049,10 @@ struct TransformConvBwdDataToGemm_v1
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
make_pass_through_transform(C)),
make_pass_through_transform(N_),
make_embed_transform(make_tuple(I1, Ho_), make_tuple(I1, ConvStrideH_)),
make_embed_transform(make_tuple(I1, Wo_), make_tuple(I1, ConvStrideW_)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
......@@ -798,8 +1060,8 @@ struct TransformConvBwdDataToGemm_v1
in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_freeze_transform(I0),
make_freeze_transform(I0),
make_merge_transform(make_tuple(N, Ho, Wo)),
make_pass_through_transform(C)),
make_merge_transform(make_tuple(N_, Ho_, Wo_)),
make_pass_through_transform(C_)),
make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}),
make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
......@@ -818,11 +1080,11 @@ struct TransformConvBwdDataToGemm_v1
const auto in_n_x_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(I1, Do), make_tuple(I1, ConvStrideD)),
make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
make_pass_through_transform(C)),
make_pass_through_transform(N_),
make_embed_transform(make_tuple(I1, Do_), make_tuple(I1, ConvStrideD_)),
make_embed_transform(make_tuple(I1, Ho_), make_tuple(I1, ConvStrideH_)),
make_embed_transform(make_tuple(I1, Wo_), make_tuple(I1, ConvStrideW_)),
make_pass_through_transform(C_)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{},
......@@ -836,8 +1098,8 @@ struct TransformConvBwdDataToGemm_v1
make_tuple(make_freeze_transform(I0),
make_freeze_transform(I0),
make_freeze_transform(I0),
make_merge_transform(make_tuple(N, Do, Ho, Wo)),
make_pass_through_transform(C)),
make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)),
make_pass_through_transform(C_)),
make_tuple(Sequence<1>{},
Sequence<3>{},
Sequence<5>{},
......@@ -861,36 +1123,21 @@ struct TransformConvBwdDataToGemm_v1
}
else
{
const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto ZTilde = ConvStrideD / GcdStrideDilationD;
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;
const auto DTilde =
Do + math::integer_divide_ceil(ConvDilationD * (Z - I1), ConvStrideD);
const auto HTilde =
Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
const auto WTilde =
Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// only work on DTilde, HTilde and WTilde that contribute to
// non-padding area of input tensor
const auto IDTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadD - ConvDilationD * (ZTilde - I1)), ConvStrideD);
math::max(I0, InLeftPadD_ - ConvDilationD_ * (ZTilde_ - I1)), ConvStrideD_);
const auto IHTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
math::max(I0, InLeftPadH_ - ConvDilationH_ * (YTilde_ - I1)), ConvStrideH_);
const auto IWTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
math::max(I0, InLeftPadW_ - ConvDilationW_ * (XTilde_ - I1)), ConvStrideW_);
const auto IDTildeSliceEnd = math::min(
DTilde, math::integer_divide_ceil(InLeftPadD + Di - I1, ConvStrideD) + I1);
DTilde_, math::integer_divide_ceil(InLeftPadD_ + Di_ - I1, ConvStrideD_) + I1);
const auto IHTildeSliceEnd = math::min(
HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
HTilde_, math::integer_divide_ceil(InLeftPadH_ + Hi_ - I1, ConvStrideH_) + I1);
const auto IWTildeSliceEnd = math::min(
WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
WTilde_, math::integer_divide_ceil(InLeftPadW_ + Wi_ - I1, ConvStrideW_) + I1);
const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin;
const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
......@@ -901,34 +1148,34 @@ struct TransformConvBwdDataToGemm_v1
{
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc =
transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(YTilde, HTilde),
make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(XTilde, WTilde),
make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(YTilde_, HTilde_),
make_tuple(ConvDilationH_, ConvStrideH_)),
make_embed_transform(make_tuple(XTilde_, WTilde_),
make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor(
in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_freeze_transform(i_ytilde),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_freeze_transform(i_xtilde),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(C)),
make_tuple(make_pass_through_transform(N_),
make_freeze_transform(IdxYTilde_),
make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice),
make_freeze_transform(IdxXTilde_),
make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
......@@ -944,8 +1191,8 @@ struct TransformConvBwdDataToGemm_v1
const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor(
in_n_htildeslice_wtildeslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
make_pass_through_transform(C)),
make_tuple(make_merge_transform(make_tuple(N_, HTildeSlice, WTildeSlice)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
......@@ -961,11 +1208,11 @@ struct TransformConvBwdDataToGemm_v1
{
const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Di, InLeftPadD, InRightPadD),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Di_, InLeftPadD_, InRightPadD_),
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(C_)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
......@@ -974,14 +1221,14 @@ struct TransformConvBwdDataToGemm_v1
const auto in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc =
transform_tensor_descriptor(
in_n_dip_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(ZTilde, DTilde),
make_tuple(ConvDilationD, ConvStrideD)),
make_embed_transform(make_tuple(YTilde, HTilde),
make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(XTilde, WTilde),
make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(ZTilde_, DTilde_),
make_tuple(ConvDilationD_, ConvStrideD_)),
make_embed_transform(make_tuple(YTilde_, HTilde_),
make_tuple(ConvDilationH_, ConvStrideH_)),
make_embed_transform(make_tuple(XTilde_, WTilde_),
make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
......@@ -996,14 +1243,14 @@ struct TransformConvBwdDataToGemm_v1
const auto in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc =
transform_tensor_descriptor(
in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_freeze_transform(i_ztilde),
make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice),
make_freeze_transform(i_ytilde),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_freeze_transform(i_xtilde),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(C)),
make_tuple(make_pass_through_transform(N_),
make_freeze_transform(IdxZTilde_),
make_slice_transform(DTilde_, IDTildeSliceBegin, DTildeSlice),
make_freeze_transform(IdxYTilde_),
make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice),
make_freeze_transform(IdxXTilde_),
make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
......@@ -1024,8 +1271,8 @@ struct TransformConvBwdDataToGemm_v1
const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor(
in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc,
make_tuple(
make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)),
make_pass_through_transform(C)),
make_merge_transform(make_tuple(N_, DTildeSlice, HTildeSlice, WTildeSlice)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
......@@ -1044,84 +1291,41 @@ struct TransformConvBwdDataToGemm_v1
}
// for input bias
template <typename CLayout,
template <typename CLayout_ = CLayout,
typename std::enable_if<NDimSpatial == 2 &&
(is_same_v<CLayout, tensor_layout::convolution::GC> ||
is_same_v<CLayout, tensor_layout::convolution::G_C>),
(is_same_v<CLayout_, tensor_layout::convolution::GC> ||
is_same_v<CLayout_, tensor_layout::convolution::G_C>),
bool>::type = false>
static auto
MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& /* out_g_n_k_wos_strides */,
const std::array<index_t, NDimSpatial + 3>& wei_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& /* wei_g_k_c_xs_strides */,
const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& /* in_g_n_c_wis_strides */,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& /* input_right_pads */,
const std::array<index_t, NDimSpatial>& /* tildes */)
__host__ __device__ auto MakeCDescriptor_M_N() const
{
const index_t N = in_g_n_c_wis_lengths[1];
const index_t C = wei_g_k_c_xs_lengths[2];
const index_t Hi = in_g_n_c_wis_lengths[3];
const index_t Wi = in_g_n_c_wis_lengths[4];
const index_t Ho = out_g_n_k_wos_lengths[3];
const index_t Wo = out_g_n_k_wos_lengths[4];
const index_t Y = wei_g_k_c_xs_lengths[3];
const index_t X = wei_g_k_c_xs_lengths[4];
const index_t InLeftPadH = input_left_pads[0];
const index_t InLeftPadW = input_left_pads[1];
const index_t ConvStrideH = conv_filter_strides[0];
const index_t ConvStrideW = conv_filter_strides[1];
const index_t ConvDilationH = conv_filter_dilations[0];
const index_t ConvDilationW = conv_filter_dilations[1];
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{
const auto in_gemmm_gemmn_grid_desc =
make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, C), make_tuple(I0, I1));
make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, C_), make_tuple(I0, I1));
return in_gemmm_gemmn_grid_desc;
}
else
{
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;
const auto HTilde =
Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
const auto WTilde =
Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// only work on HTilde and WTilde that contribute to non-padding area of input tensor
const auto IHTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
math::max(I0, InLeftPadH_ - ConvDilationH_ * (YTilde_ - I1)), ConvStrideH_);
const auto IWTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
math::max(I0, InLeftPadW_ - ConvDilationW_ * (XTilde_ - I1)), ConvStrideW_);
const auto IHTildeSliceEnd = math::min(
HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
HTilde_, math::integer_divide_ceil(InLeftPadH_ + Hi_ - I1, ConvStrideH_) + I1);
const auto IWTildeSliceEnd = math::min(
WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
WTilde_, math::integer_divide_ceil(InLeftPadW_ + Wi_ - I1, ConvStrideW_) + I1);
const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
// bias tensor
const auto in_gemmmraw_gemmnraw_grid_desc = make_naive_tensor_descriptor(
make_tuple(N * HTildeSlice * WTildeSlice, C), make_tuple(I0, I1));
make_tuple(N_ * HTildeSlice * WTildeSlice, C_), make_tuple(I0, I1));
const auto in_gemmm_gemmn_grid_desc = ck::tensor_operation::device::PadTensorDescriptor(
in_gemmmraw_gemmnraw_grid_desc,
......@@ -1131,6 +1335,25 @@ struct TransformConvBwdDataToGemm_v1
return in_gemmm_gemmn_grid_desc;
}
}
IndexType N_;
IndexType Di_, Hi_, Wi_;
IndexType Do_, Ho_, Wo_;
IndexType Z_, Y_, X_;
IndexType K_, C_;
IndexType DiStride_, HiStride_, WiStride_;
IndexType DoStride_, HoStride_, WoStride_;
IndexType CStrideTensorB_, CStrideTensorC_, KStrideTensorA_, KStrideTensorB_;
IndexType NStrideTensorA_, NStrideTensorC_;
IndexType ConvStrideD_, ConvStrideH_, ConvStrideW_;
IndexType ConvDilationD_, ConvDilationH_, ConvDilationW_;
IndexType InLeftPadD_, InLeftPadH_, InLeftPadW_;
IndexType InRightPadD_, InRightPadH_, InRightPadW_;
IndexType IdxZTilde_, IdxYTilde_, IdxXTilde_;
IndexType GcdStrideDilationD_, GcdStrideDilationH_, GcdStrideDilationW_;
IndexType ZTilde_, YTilde_, XTilde_;
IndexType DTilde_, HTilde_, WTilde_;
IndexType ZDot_, YDot_, XDot_;
};
} // namespace tensor_operation
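A worked instance of the tilde decomposition cached in the members above (hypothetical stride and dilation; this models the standard backward-data sub-GEMM split, it is not code from the commit):

```cpp
#include <numeric>

// With ConvStrideH = 2 and ConvDilationH = 3 along H:
constexpr int ConvStrideH = 2, ConvDilationH = 3, Y = 3;
constexpr int GcdStrideDilationH = std::gcd(ConvStrideH, ConvDilationH); // 1
constexpr int YTilde = ConvStrideH / GcdStrideDilationH;                 // 2 sub-GEMMs along H
constexpr int YDot   = (Y + YTilde - 1) / YTilde;                        // 2 filter taps per sub-GEMM
static_assert(GcdStrideDilationH == 1 && YTilde == 2 && YDot == 2);
// Each IdxYTilde_ in [0, YTilde) selects one of these sub-GEMMs at kernel launch.
```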
......
......@@ -580,7 +580,7 @@ __device__ void amd_global_atomic_add_impl(const typename vector_type<T, N>::typ
tmp.template AsType<half2_t>()[i]);
});
}
#if defined(__gfx942__) || defined(__gfx950__)
#if defined(__gfx942__)
else if constexpr(is_same<T, bhalf_t>::value)
{
vector_type<bhalf_t, N> tmp{src_thread_data};
......
......@@ -24,14 +24,14 @@ using f8_fnuz_t = _BitInt(8);
using bf8_fnuz_t = unsigned _BitInt(8);
#if(defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx1200__) || \
defined(__gfx1201__) || defined(__gfx950__)) && \
defined(__gfx1201__)) && \
__HIP_DEVICE_COMPILE__
#define CK_FP8_CVT_FAST_PATH 1
#else
#define CK_FP8_CVT_FAST_PATH 0
#endif
#if(defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx950__)) && __HIP_DEVICE_COMPILE__
#if(defined(__gfx1200__) || defined(__gfx1201__)) && __HIP_DEVICE_COMPILE__
#define CK_OCP_FP8_CVT_FAST_PATH 1
#else
#define CK_OCP_FP8_CVT_FAST_PATH 0
......@@ -363,7 +363,7 @@ struct bf8_ocp_t
__host__ explicit operator float() const
#endif
{
#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
#if defined(__gfx1200__) || defined(__gfx1201__)
return fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data);
#else
return fp8_impl::cast_from_f8<float, wm, we, false>(
......@@ -377,7 +377,7 @@ struct bf8_ocp_t
__host__ explicit operator _Float16() const
#endif
{
#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
#if defined(__gfx1200__) || defined(__gfx1201__)
return static_cast<_Float16>(fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data));
#else
return fp8_impl::cast_from_f8<_Float16, wm, we, false>(
......
......@@ -13,6 +13,11 @@ namespace ck {
defined(__gfx1103__) || defined(__gfx11_generic__)
#define __gfx11__
#endif
#if defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__)
#define __gfx12__
#endif
/********************************WAVE32 MODE***********************************************/
// src: fp16, dst: fp32
......@@ -99,7 +104,7 @@ struct intrin_wmma_bf16_16x16x16_bf16_w32<16, 16, Opsel>
// opsel usage
// false: D0.[0:15] = result
// true : D0.[16:31]= result
#if defined(__gfx11__)
#if defined(__gfx11__) || defined(__gfx12__)
reg_c.template AsType<bhalf16_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32(
reg_a, reg_b, reg_c.template AsType<bhalf16_t>()[Number<0>{}], Opsel);
......@@ -261,10 +266,6 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
// gfx12
/********************************WAVE32 MODE***********************************************/
#if defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__)
#define __gfx12__
#endif
// src: fp16, dst: fp32
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_f16_w32_gfx12;
......
......@@ -5,7 +5,7 @@
namespace ck {
// Define the common macro for MI300 models
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#define __gfx94__
#endif
......
......@@ -5,7 +5,6 @@
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
namespace ck {
......
......@@ -12,7 +12,7 @@
namespace ck {
// Define the common macro for MI300 models
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#define __gfx94__
#endif
......
# ck_tile
[Back to the main page](../../README.md)
# Composable Kernel Tile
## concept
`ck_tile` provides a programming model with templated abstractions that enables users to implement performance-critical kernels for machine learning workloads. It introduces the following basic concepts to help users build their own operators:
- tensor coordinate transformation: the core concept behind the layout/index transform abstraction, at both compile time and run time (see the sketch after this list).
......
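A rough illustration of the coordinate-transform idea (hypothetical standalone code, not the `ck_tile` API):

```cpp
// A merge transform folds (n, h, w) into a single GEMM row index; the descriptor
// machinery applies such maps (and their inverses) at compile time where possible.
constexpr int H = 4, W = 8;
constexpr int merge_nhw(int n, int h, int w) { return (n * H + h) * W + w; }
static_assert(merge_nhw(1, 2, 3) == 51, "m = (1*4 + 2)*8 + 3");
```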
......@@ -52,6 +52,7 @@
#include "ck_tile/core/tensor/tile_elementwise.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/tensor/tile_window_utils.hpp"
#include "ck_tile/core/tensor/update_tile.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp"
......@@ -62,6 +63,7 @@
#include "ck_tile/core/utility/philox_rand.hpp"
#include "ck_tile/core/utility/random.hpp"
#include "ck_tile/core/utility/reduce_operator.hpp"
#include "ck_tile/core/utility/static_counter.hpp"
#include "ck_tile/core/utility/to_sequence.hpp"
#include "ck_tile/core/utility/transpose_vectors.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
......
......@@ -621,6 +621,65 @@ CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0)
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
CK_TILE_DEVICE void lds_load_fence(index_t cnt = 0)
{
asm volatile("s_waitcnt lgkmcnt(%0)" : : "n"(cnt) : "memory");
}
template <typename scalar_type, index_t N, bool pre_nop = false>
struct buffer_atomic_add_if;
template <bool pre_nop>
struct buffer_atomic_add_if<bf16_t, 2, pre_nop>
{
template <typename T>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t flag = 1)
{
static_assert(sizeof(T) == 4);
auto save_exec = __builtin_amdgcn_read_exec();
using mbuf_t = float;
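        // Predicated atomic: v_cmpx shrinks exec to lanes with flag >= 1, the packed-bf16
        // global atomic runs for those lanes only, then the saved exec mask is restored.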
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"global_atomic_pk_add_bf16 %0, %1, %2 offset:%3\n"
"s_mov_b64 exec %5"
:
: "v"(v_offset),
"v"(bit_cast<mbuf_t>(value)),
"s"(res.xy),
"n"(i_offset),
"v"(flag),
"s"(save_exec)
: "memory");
}
};
template <typename scalar_type, index_t N, bool pre_nop = false>
struct buffer_atomic_add;
template <bool pre_nop>
struct buffer_atomic_add<bf16_t, 2, pre_nop>
{
template <typename T>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t /*flag = 1*/)
{
static_assert(sizeof(T) == 4);
using mbuf_t = float;
asm volatile("global_atomic_pk_add_bf16 %0, %1, %2 offset:%3"
:
: "v"(v_offset), "v"(bit_cast<mbuf_t>(value)), "s"(res.xy), "n"(i_offset)
: "memory");
}
};
namespace impl {
// below type indicate the data type used for buffer load inline asm
// clang-format off
......@@ -810,6 +869,11 @@ CK_TILE_DEVICE void buffer_store_fence(index_t cnt = 0)
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
CK_TILE_DEVICE auto async_load_fence_raw(index_t cnt = 0)
{
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
// buffer load i8
CK_TILE_DEVICE_EXTERN int8_t
llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc,
......@@ -2378,6 +2442,45 @@ CK_TILE_DEVICE void amd_buffer_atomic_add(const thread_buffer<T, N>& src_thread_
#endif
}
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE void amd_buffer_atomic_add_raw(const thread_buffer<T, N>& src_thread_data,
T* p_dst_wave,
const index_t dst_thread_element_offset,
const index_t dst_linear_element_offset,
const bool dst_thread_element_valid,
const index_t dst_element_space_size,
bool_constant<pre_nop> = {})
{
const int32x4_t dst_wave_buffer_resource =
make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
index_t dst_linear_addr_offset = dst_linear_element_offset * sizeof(T);
if constexpr(oob_conditional_check)
{
buffer_atomic_add_if<T, N, pre_nop>{}(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
0,
dst_linear_addr_offset,
dst_thread_element_valid);
}
else
{
buffer_atomic_add<T, N, pre_nop>{}(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
0,
dst_linear_addr_offset,
1);
}
}
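A hypothetical call site for the raw atomic-add path above (sketch only, device code; `p_out`, `lane_offset`, and `num_elems` are made-up names):

```cpp
// Each lane accumulates a packed pair of bf16 partials into global memory; the
// per-lane validity flag steers the exec-mask-predicated variant.
ck_tile::thread_buffer<ck_tile::bf16_t, 2> partial{}; // filled with per-thread results
ck_tile::amd_buffer_atomic_add_raw<ck_tile::bf16_t, 2>(
    partial, p_out, lane_offset, /*linear offset*/ 0, lane_offset < num_elems, num_elems);
```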
// buffer_atomic_max requires:
// 1) p_dst_wave must point to global memory
// 2) p_dst_wave must be a wavewise pointer.
......
......@@ -73,6 +73,24 @@ CK_TILE_DEVICE void block_sync_lds()
#endif
}
CK_TILE_DEVICE void block_sync_load_raw(index_t cnt = 0)
{
#ifdef __gfx12__
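    // gfx12: loads retire through a dedicated load counter, and the barrier is an
    // explicit signal/wait pair rather than a single s_barrier.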
asm volatile("s_wait_loadcnt %0 \n"
"s_barrier_signal -1 \n"
"s_barrier_wait -1"
:
: "n"(cnt)
: "memory");
#else
asm volatile("s_waitcnt vmcnt(%0) \n"
"s_barrier"
:
: "n"(cnt)
: "memory");
#endif
}
CK_TILE_DEVICE void block_sync_lds_direct_load()
{
asm volatile("\
......
......@@ -102,4 +102,28 @@ CK_TILE_DEVICE T warp_shuffle(const T& v_local, uint32_t src_lane)
#endif
}
template <typename T>
CK_TILE_DEVICE auto flag_to_exec(const T& v_flag)
{
static_assert(sizeof(T) == 4);
// per-thread v_flag store into 2x sgpr
uint32x2_t exec_flag;
asm volatile("v_cmp_ge_u32 %[s_exec_flag], %[v_flag], 1"
: [s_exec_flag] "=s"(exec_flag)
: [v_flag] "v"(v_flag));
return exec_flag;
}
template <typename X, typename Y>
CK_TILE_DEVICE auto cmp_lt_to_exec(const X& x, const Y& y)
{
static_assert(sizeof(X) == 4 && sizeof(Y) == 4);
// per-thread comparison condensed into a lane mask held in 2x SGPRs
uint32x2_t exec_flag;
asm volatile("v_cmp_lt_u32 %[s_exec_flag], %[v_x], %[v_y]"
: [s_exec_flag] "=s"(exec_flag)
: [v_x] "v"(x), [v_y] "v"(y));
return exec_flag;
}
} // namespace ck_tile
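Both helpers materialize a wave-wide lane mask in an SGPR pair from a per-lane VGPR comparison; the mask is typically moved into exec afterwards (compare the buffer_atomic_add_if path above). A host-side model for a 64-lane wave, purely for illustration:

#include <cstdint>

// model (sketch) of flag_to_exec: bit i is set iff lane i's flag >= 1
inline uint64_t flag_to_exec_ref(const uint32_t flag[64])
{
    uint64_t mask = 0;
    for(int lane = 0; lane < 64; ++lane)
        if(flag[lane] >= 1)
            mask |= uint64_t(1) << lane;
    return mask;
}

// model (sketch) of cmp_lt_to_exec: bit i is set iff x[i] < y[i] (unsigned)
inline uint64_t cmp_lt_to_exec_ref(const uint32_t x[64], const uint32_t y[64])
{
    uint64_t mask = 0;
    for(int lane = 0; lane < 64; ++lane)
        if(x[lane] < y[lane])
            mask |= uint64_t(1) << lane;
    return mask;
}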
......@@ -4,10 +4,10 @@
#pragma once
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
defined(__gfx942__) || defined(__gfx950__)
defined(__gfx942__)
#define __gfx9__
#endif
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#define __gfx94__
#endif
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \
......@@ -64,6 +64,7 @@
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE_WITH_NAN 1
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE 2
#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD_ASM 3
#define CK_TILE_FLOAT_TO_BFLOAT16_RTA_ASM 4
#ifndef CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT
#define CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE
......@@ -225,3 +226,7 @@
#ifndef CK_TILE_WORKAROUND_SWDEV_383542
#define CK_TILE_WORKAROUND_SWDEV_383542 1
#endif
#ifndef CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
#define CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID 1
#endif
......@@ -18,6 +18,7 @@ enum class bf16_rounding_mode
truncate_with_nan,
truncate,
standard_asm,
rta_asm, // round to nearest, ties away from zero
};
template <bf16_rounding_mode rounding =
......@@ -180,6 +181,39 @@ uint16_t float_to_bf16_rtn_asm(float f)
return uint16_t(u.int32);
}
// TODO: do we need this on host?
// NB: the host fallback reuses the round-to-nearest path, so host and device
// results may differ for values exactly halfway between two bf16 numbers.
CK_TILE_HOST
uint16_t float_to_bf16_rta_asm(float f) { return float_to_bf16_rtn_raw(f); }
CK_TILE_DEVICE
uint16_t float_to_bf16_rta_asm(float f)
{
union
{
float fp32;
struct
{
uint16_t lo;
uint16_t hi;
};
} u = {f};
const uint32_t low_nan = 0x7fff;
const uint32_t hi_nan = 0x7fff0000;
using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
uint32x2_t check_nan;
asm volatile("v_cmp_u_f32 %[s_cnan], %[v_x], %[v_x] \n"
"v_add3_u32 %[v_x], %[v_x], %[v_blo], 1 \n"
"v_cndmask_b32 %[v_x], %[v_x], %[v_bhi], %[s_cnan]"
: [s_cnan] "+s"(check_nan), [v_x] "+v"(u.fp32)
: [v_blo] "v"(low_nan), [v_bhi] "v"(hi_nan));
// Note: in the code above, the result is taken from the high 16 bits
return u.hi;
}
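The three instructions implement round-to-nearest-away on the raw bits: v_cmp_u_f32 flags NaN inputs (unordered compare of x with itself), v_add3_u32 adds 0x7fff + 1 = 0x8000, i.e. half of the discarded 16-bit field, and v_cndmask_b32 substitutes a quiet-NaN pattern whose high half is 0x7fff. An equivalent scalar reference, added here only as a sketch:

#include <cstdint>
#include <cstring>

// scalar model (sketch) of float_to_bf16_rta_asm
inline uint16_t float_to_bf16_rta_ref(float f)
{
    uint32_t u;
    std::memcpy(&u, &f, 4);
    if((u & 0x7fffffffu) > 0x7f800000u) // NaN in, quiet-NaN pattern out
        return 0x7fff;
    return uint16_t((u + 0x8000u) >> 16); // add half an ulp of bf16, truncate
}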
// Truncate instead of rounding, preserving SNaN
CK_TILE_HOST_DEVICE
constexpr uint16_t float_to_bf16_truc_nan_raw(float f)
......@@ -213,6 +247,8 @@ CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant<round
return float_to_bf16_rtn_asm(f);
else if constexpr(rounding == bf16_rounding_mode::truncate_with_nan)
return float_to_bf16_truc_nan_raw(f);
else if constexpr(rounding == bf16_rounding_mode::rta_asm)
return float_to_bf16_rta_asm(f);
else
return float_to_bf16_truc_raw(f);
}
......
......@@ -437,34 +437,74 @@ struct buffer_view<address_space_enum::global,
// i is offset of T, not X. i should be aligned to X
template <memory_operation_enum Op,
typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
CK_TILE_DEVICE void update(index_t i,
index_t linear_offset,
bool is_valid_element,
const X& x,
bool_constant<oob_conditional_check> = {})
{
if constexpr(Op == memory_operation_enum::set)
{
this->template set<X>(i, linear_offset, is_valid_element, x);
this->template set<X, oob_conditional_check>(i, linear_offset, is_valid_element, x);
}
else if constexpr(Op == memory_operation_enum::atomic_add)
{
this->template atomic_add<X>(i, linear_offset, is_valid_element, x);
this->template atomic_add<X, oob_conditional_check>(
i, linear_offset, is_valid_element, x);
}
else if constexpr(Op == memory_operation_enum::atomic_max)
{
this->template atomic_max<X>(i, linear_offset, is_valid_element, x);
this->template atomic_max<X, oob_conditional_check>(
i, linear_offset, is_valid_element, x);
}
// FIXME: remove memory_operation_enum::add
else if constexpr(Op == memory_operation_enum::add)
{
auto tmp = this->template get<X>(i, linear_offset, is_valid_element);
this->template set<X>(i, linear_offset, is_valid_element, x + tmp);
auto tmp =
this->template get<X, oob_conditional_check>(i, linear_offset, is_valid_element);
this->template set<X, oob_conditional_check>(
i, linear_offset, is_valid_element, x + tmp);
// tmp += x;
// this->template set<X>(i, is_valid_element, tmp);
}
}
// i is offset of T, not X. i should be aligned to X
template <memory_operation_enum Op,
typename X,
bool oob_conditional_check = true,
bool pre_nop = false,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE void update_raw(index_t i,
index_t linear_offset,
bool is_valid_element,
const X& x,
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
if constexpr(Op == memory_operation_enum::set)
{
this->template set_raw<X, oob_conditional_check>(i, linear_offset, is_valid_element, x);
}
else if constexpr(Op == memory_operation_enum::atomic_add)
{
this->template atomic_add_raw<X, oob_conditional_check, pre_nop>(
i, linear_offset, is_valid_element, x);
}
else if constexpr(Op == memory_operation_enum::atomic_max)
{
// this->template atomic_max_raw<X>(i, linear_offset, is_valid_element, x);
}
}
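// Dispatch sketch (editorial note): update_raw routes X to the raw buffer
// intrinsics selected by Op at compile time; e.g. an atomic_add of a packed
// bf16 pair lowers to atomic_add_raw and, in turn, to the
// global_atomic_pk_add_bf16 path. atomic_max has no raw variant yet, so that
// branch is currently a no-op placeholder.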
// i is offset of T, not X. i should be aligned to X
template <typename X,
bool oob_conditional_check = true,
......@@ -533,6 +573,7 @@ struct buffer_view<address_space_enum::global,
}
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
......@@ -585,6 +626,39 @@ struct buffer_view<address_space_enum::global,
}
template <typename X,
bool oob_conditional_check = true,
bool pre_nop = true,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE void
atomic_add_raw(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
{
// using scalar_t = typename vector_traits<remove_cvref_t<T>>::scalar_type;
// X contains multiple T
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X should contain multiple T");
static_assert(get_address_space() == address_space_enum::global, "only support global mem");
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_atomic_add_raw<remove_cvref_t<T>,
t_per_x,
Coherence,
oob_conditional_check,
pre_nop>(
x, p_data_, i, linear_offset, is_valid_element, buffer_size_);
}
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
......
......@@ -22,28 +22,32 @@ template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
return tile_window.load(number<-1>{}, bool_constant<oob_conditional_check>{});
return tile_window.load(number<i_access>{}, bool_constant<oob_conditional_check>{});
}
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(const tile_window_linear<BottomTensorView_,
WindowLengths_,
TileDistribution_,
LinearBottomDims_>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
return tile_window.load(number<-1>{}, bool_constant<oob_conditional_check>{});
return tile_window.load(number<i_access>{}, bool_constant<oob_conditional_check>{});
}
template <typename DistributedTensor_,
......@@ -51,15 +55,35 @@ template <typename DistributedTensor_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
return tile_window.load(dst_tile, bool_constant<oob_conditional_check>{});
return tile_window.load(dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
}
template <typename DistributedTensor_,
typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
const tile_window_linear<BottomTensorView_,
WindowLengths_,
TileDistribution_,
LinearBottomDims_>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
return tile_window.load(dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
}
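All four overloads now thread an optional compile-time i_access index through to the window, keeping number<-1> ("issue every access") as the default, so existing call sites compile unchanged. A usage sketch with an illustrative window:

// default: issue every access of the window
auto tile = load_tile(a_window);

// hypothetical split: issue only access 0 now, to interleave loads with math
auto tile0 = load_tile(a_window, number<0>{});

// drop out-of-bounds predication when the window is known to be in range
auto tile_fast = load_tile(a_window, number<-1>{}, bool_constant<false>{});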
/**
......@@ -76,6 +100,7 @@ template <typename T,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto load_tile_raw(T& tile,
......@@ -83,11 +108,12 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
tile_window.load_raw(
tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}
template <typename T,
......@@ -95,6 +121,7 @@ template <typename T,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto load_tile_raw(T& tile,
......@@ -102,11 +129,12 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
WindowLengths_,
TileDistribution_,
LinearBottomDims_>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
tile_window.load_raw(
tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}
template <typename LdsTileWindow_,
......@@ -114,6 +142,7 @@ template <typename LdsTileWindow_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto
......@@ -122,11 +151,14 @@ async_load_tile_raw(LdsTileWindow_&& lds_tile,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
return tile_window.async_load_raw(
lds_tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
return tile_window.async_load_raw(lds_tile,
number<i_access>{},
bool_constant<oob_conditional_check>{},
bool_constant<pre_nop>{});
}
template <typename LdsTileWindow_,
......@@ -134,6 +166,7 @@ template <typename LdsTileWindow_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
......@@ -141,11 +174,14 @@ CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
WindowLengths_,
TileDistribution_,
LinearBottomDims_>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
return tile_window.async_load_raw(
lds_tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
return tile_window.async_load_raw(lds_tile,
number<i_access>{},
bool_constant<oob_conditional_check>{},
bool_constant<pre_nop>{});
}
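A sketch of the intended async pipeline (assumed usage; lds_win and gmem_win are illustrative): start the global-to-LDS copy, overlap independent work, then fence and barrier before consuming the tile.

async_load_tile_raw(lds_win, gmem_win); // kick off global -> LDS copy
// ... independent math on the previous tile ...
async_load_fence();                     // wait for the outstanding copies
block_sync_lds();                       // make the LDS tile visible to all waves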
CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0)
......
......@@ -201,4 +201,30 @@ CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks_from_x_unpacks(YLengths, number
return unpacks;
}
namespace detail {
// check whether two static_distributed_tensor types have the same data type
// and thread-buffer size, differing only in distribution
template <typename X, typename Y>
struct is_similiar_distributed_tensor
{
static constexpr bool value = false;
};
template <typename TypeX, typename DistX, typename TypeY, typename DistY>
struct is_similiar_distributed_tensor<static_distributed_tensor<TypeX, DistX>,
static_distributed_tensor<TypeY, DistY>>
{
using Tx = static_distributed_tensor<TypeX, DistX>;
using Ty = static_distributed_tensor<TypeY, DistY>;
static constexpr bool value = std::is_same_v<typename Tx::DataType, typename Ty::DataType> &&
Tx::get_thread_buffer_size() == Ty::get_thread_buffer_size();
};
template <typename X, typename Y>
inline constexpr bool is_similiar_distributed_tensor_v =
is_similiar_distributed_tensor<X, Y>::value;
} // namespace detail
} // namespace ck_tile
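A usage sketch for the trait (copy_thread_buffer is illustrative, not part of the source, and assumes static_distributed_tensor exposes get_thread_buffer):

template <typename Dst, typename Src>
CK_TILE_DEVICE void copy_thread_buffer(Dst& dst, const Src& src)
{
    static_assert(ck_tile::detail::is_similiar_distributed_tensor_v<Dst, Src>,
                  "tensors must share data type and thread-buffer size");
    // same per-thread payload, possibly a different distribution
    dst.get_thread_buffer() = src.get_thread_buffer();
}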
......@@ -333,6 +333,48 @@ struct tensor_view
coord.get_offset(), linear_offset, is_valid_element, x);
}
// X is vector of DataType.
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
template <typename X,
bool oob_conditional_check = true,
bool pre_nop = false,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void
update_vectorized_elements_raw(const TensorCoord& coord,
index_t linear_offset,
const X& x,
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
buf_.template update_raw<DstInMemOp, X, oob_conditional_check, pre_nop>(
coord.get_offset(),
linear_offset,
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
x);
}
template <typename X,
bool oob_conditional_check = true,
bool pre_nop = false,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void
update_vectorized_elements_raw(const TensorCoord& coord,
index_t linear_offset,
bool is_valid_element,
const X& x,
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
buf_.template update_raw<DstInMemOp, X, oob_conditional_check, pre_nop>(
coord.get_offset(), linear_offset, is_valid_element, x);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("tensor_view{");
......