Commit c8c016dd authored by aska-0096's avatar aska-0096
Browse files

Merge branch 'develop' of https://github.com/ROCm/composable_kernel into update_cka8w8

parents e8ca3daf 4e731776
......@@ -20,7 +20,6 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp" // stare wywalic
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
namespace ck {
......@@ -522,7 +521,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
ComputeTypeA,
ComputeTypeB>;
using KernelArguments = GroupedGemmTileLoopKernelArguments<NumDTensor>;
using KernelArguments = GroupedGemmKernelArgument<NumDTensor>;
using Block2ETileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
using OffsettedLocalBlock2ETileMap = OffsettedBlockToCTileMap2<Block2ETileMap>;
......@@ -936,12 +935,31 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
return str.str();
}
void SetDeviceKernelArgs(Argument& arg,
void* p_dev_kernel_args,
const void* p_host_kernel_args) const
{
arg.p_dev_gemm_args_ = p_dev_kernel_args;
hip_check_error(hipMemcpyAsync(p_dev_kernel_args,
p_host_kernel_args,
GetDeviceKernelArgSize(&arg),
hipMemcpyHostToDevice));
}
virtual void SetDeviceKernelArgs(BaseArgument* p_arg,
void* p_dev_kernel_args,
const void* p_host_kernel_args) const override
{
return SetDeviceKernelArgs(
*dynamic_cast<Argument*>(p_arg), p_dev_kernel_args, p_host_kernel_args);
}
void SetDeviceKernelArgs(Argument& arg, void* p_dev_kernel_args) const
{
arg.p_dev_gemm_args_ = p_dev_kernel_args;
}
void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
virtual void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
{
return SetDeviceKernelArgs(*dynamic_cast<Argument*>(p_arg), p_dev_kernel_args);
}
......
#pragma once
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -557,12 +557,12 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
}
}
hipGetErrorString(hipMemcpyWithStream(arg.p_workspace_,
arg.gemm_desc_kernel_arg_.data(),
arg.gemm_desc_kernel_arg_.size() *
sizeof(GemmBiasTransKernelArg),
hipMemcpyHostToDevice,
stream_config.stream_id_));
hipGetErrorString(
hipMemcpyAsync(arg.p_workspace_,
arg.gemm_desc_kernel_arg_.data(),
arg.gemm_desc_kernel_arg_.size() * sizeof(GemmBiasTransKernelArg),
hipMemcpyHostToDevice,
stream_config.stream_id_));
float ave_time = 0;
......@@ -717,7 +717,24 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
{
return dynamic_cast<const Argument*>(p_arg)->group_count_ * sizeof(GemmBiasTransKernelArg);
auto p_arg_ = dynamic_cast<const Argument*>(p_arg);
if(p_arg_)
{
return p_arg_->group_count_ * sizeof(GemmBiasTransKernelArg);
}
else
throw std::runtime_error("The argument pointer is not an object of "
"DeviceGroupedGemmMultipleDXdlCShuffle::Argument structure!");
}
size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
{
return GetWorkSpaceSize(p_arg);
}
void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
{
return this->SetWorkSpacePointer(p_arg, p_dev_kernel_args);
}
};
......
......@@ -445,6 +445,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
using Block2ETileMap = BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops<MPerBlock, NPerBlock>;
using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMapMLoops<Block2ETileMap>;
// TODO: replace with GroupedGemmKernelArgument
struct GemmBiasTransKernelArg
{
// pointers
......@@ -900,40 +901,58 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
return str.str();
}
static void SetDeviceKernelArgs(Argument& arg, const void* kernel_args)
{
arg.grouped_gemm_kernel_args_dev = kernel_args;
}
// polymorphic
void SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const override
void SetDeviceKernelArgs(BaseArgument* p_arg, void* kernel_args) const override
{
return SetDeviceKernelArgs(*dynamic_cast<Argument*>(p_arg), kernel_args);
auto arg_ptr = dynamic_cast<Argument*>(p_arg);
if(arg_ptr)
{
arg_ptr->grouped_gemm_kernel_args_dev = kernel_args;
}
else
throw std::runtime_error("The argument pointer is not an object of "
"DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!");
}
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
{
auto arg = *dynamic_cast<const Argument*>(p_arg);
return arg.group_count_ * arg.barrier_size_grp_ * sizeof(uint32_t);
auto arg_ptr = dynamic_cast<const Argument*>(p_arg);
if(arg_ptr)
{
return arg_ptr->group_count_ * arg_ptr->barrier_size_grp_ * sizeof(uint32_t);
}
else
throw std::runtime_error("The argument pointer is not an object of "
"DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!");
}
size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
{
auto arg = *dynamic_cast<const Argument*>(p_arg);
return arg.group_count_ * sizeof(GroupedGemmKernelArgument<NumDTensor>);
auto arg_ptr = dynamic_cast<const Argument*>(p_arg);
if(arg_ptr)
{
return arg_ptr->group_count_ * sizeof(GroupedGemmKernelArgument<NumDTensor>);
}
else
throw std::runtime_error("The argument pointer is not an object of "
"DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!");
}
void SetWorkSpacePointer(BaseArgument* p_arg,
void* p_workspace,
const StreamConfig& stream_config = StreamConfig{}) const override
{
auto p_arg_ = dynamic_cast<Argument*>(p_arg);
p_arg_->p_workspace_ = p_workspace;
auto arg_ptr = dynamic_cast<Argument*>(p_arg);
if(arg_ptr)
{
arg_ptr->p_workspace_ = p_workspace;
}
else
throw std::runtime_error("The argument pointer is not an object of "
"DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!");
hip_check_error(
hipMemsetAsync(p_workspace, 0, GetWorkSpaceSize(p_arg), stream_config.stream_id_));
hipMemsetAsync(p_workspace, 0, GetWorkSpaceSize(arg_ptr), stream_config.stream_id_));
}
static void SetKBatch(Argument& arg, index_t k_batch) { arg.UpdateKBatch(k_batch); }
......@@ -941,7 +960,26 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
// polymorphic
void SetKBatch(BaseArgument* p_arg, index_t k_batch) const override
{
return SetKBatch(*dynamic_cast<Argument*>(p_arg), k_batch);
auto arg_ptr = dynamic_cast<Argument*>(p_arg);
if(arg_ptr)
{
arg_ptr->UpdateKBatch(k_batch);
}
else
throw std::runtime_error("The argument pointer is not an object of "
"DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!");
}
void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override
{
auto arg_ptr = dynamic_cast<Argument*>(p_arg);
if(arg_ptr)
{
arg_ptr->UpdateKBatch(kbatch);
}
else
throw std::runtime_error("The argument pointer is not an object of "
"DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!");
}
};
......
......@@ -421,11 +421,11 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
}
hip_check_error(
hipMemcpyWithStream(arg.p_workspace_,
arg.gemm_kernel_args_.data(),
arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
hipMemcpyHostToDevice,
stream_config.stream_id_));
hipMemcpyAsync(arg.p_workspace_,
arg.gemm_kernel_args_.data(),
arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
hipMemcpyHostToDevice,
stream_config.stream_id_));
float ave_time = 0;
......@@ -538,10 +538,16 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
return false;
}
if(std::is_same_v<EDataType, ck::bhalf_t> && arg.K_BATCH > 1 && !is_bf16_atomic_supported())
{
return false;
}
bool supported = true;
for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
{
const auto& a = arg.gemm_kernel_args_[i].karg_;
const auto& a = arg.gemm_kernel_args_[i].karg_;
bool group_arg_valid = GridwiseGemm::CheckValidity(a);
if(not group_arg_valid)
{
......@@ -631,16 +637,42 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
{
return dynamic_cast<const Argument*>(p_arg)->gemm_kernel_args_.size() *
sizeof(GemmTransKernelArg);
auto p_arg_ = dynamic_cast<const Argument*>(p_arg);
if(p_arg_)
{
return p_arg_->gemm_kernel_args_.size() * sizeof(GemmTransKernelArg);
}
else
throw std::runtime_error(
"The argument pointer is not an object of "
"DeviceGroupedGemmMultipleDSplitKXdlCShuffle::Argument structure!");
}
size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
{
return GetWorkSpaceSize(p_arg);
}
// TODO: deperecation notice.
static void SetKBatchSize(Argument& arg, index_t kbatch) { arg.UpdateKBatch(kbatch); }
// polymorphic
void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override
{
return SetKBatchSize(*dynamic_cast<Argument*>(p_arg), kbatch);
auto p_arg_ = dynamic_cast<Argument*>(p_arg);
if(p_arg_)
{
p_arg_->UpdateKBatch(kbatch);
}
else
throw std::runtime_error(
"The argument pointer is not an object of "
"DeviceGroupedGemmMultipleDSplitKXdlCShuffle::Argument structure!");
}
void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
{
return this->SetWorkSpacePointer(p_arg, p_dev_kernel_args);
}
};
......
......@@ -14,6 +14,8 @@
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/workgroup_barrier.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
namespace ck {
......@@ -38,7 +40,7 @@ __global__ void
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg);
karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg, karg.p_workspace_);
#else
ignore = karg;
#endif // end of if (defined(__gfx9__))
......@@ -62,7 +64,13 @@ __global__ void
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared_0, p_shared_1, karg);
karg.p_a_grid,
karg.p_b_grid,
karg.p_c_grid,
p_shared_0,
p_shared_1,
karg,
karg.p_workspace_);
#else
ignore = karg;
#endif // end of if (defined(__gfx9__))
......@@ -521,7 +529,9 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, Streamk_sel_, Grid_size_},
p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_c_grid{p_c_grid_}
p_c_grid{p_c_grid_},
block_2_ctile_map_streamk(
M_, N_, AK0Number * CalculateKPadded(K_, 1), Grid_size_, Streamk_sel_)
{
}
......@@ -529,6 +539,13 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
const ADataType* p_a_grid;
const BDataType* p_b_grid;
CDataType* p_c_grid;
BlockToCTileMap_GemmStreamK_v2<MPerBlock,
NPerBlock,
KPerBlock,
StreamKReductionStrategy::Atomic,
8,
4>
block_2_ctile_map_streamk;
};
struct SplitKBatchOffset
......@@ -853,6 +870,19 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
}
__host__ __device__ static constexpr auto
GetCBlockDescriptor_MShuffle_MPerShuffle_NShuffle_NPerShuffle()
{
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
return make_naive_tensor_descriptor_packed(
make_tuple(Number<MXdlPerWave / CShuffleMXdlPerWavePerShuffle>{},
Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
Number<NXdlPerWave / CShuffleNXdlPerWavePerShuffle>{},
Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));
}
using BlockwiseGemmPipe =
remove_cvref_t<decltype(BlockGemmPipeline_Selector<
BlkGemmPipelineVer,
......@@ -1118,6 +1148,34 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
return c_grid_desc_mblock_mperblock_nblock_nperblock;
}
__host__ __device__ static constexpr auto GetClusterLengthReduction()
{
// TODO: assume C is row major
// TODO: we always first loop over N, then M
constexpr auto NPerBlockPow2 = math::next_power_of_two<NPerBlock>();
constexpr auto NPerBlockReduction =
NPerBlockPow2 / CShuffleBlockTransferScalarPerVector_NPerBlock;
constexpr auto MPerBlockReduction =
(BlockSize + NPerBlockReduction - 1) / NPerBlockReduction;
return Sequence<MPerBlockReduction, NPerBlockReduction>{};
}
__host__ __device__ static constexpr auto GetPartialAccBlockDescriptor()
{
const auto c_partial_acc_block_m_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(MPerBlock, NPerBlock),
make_tuple(NPerBlock, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(MPerBlock, NPerBlock),
make_tuple(I1, MPerBlock));
}
}();
return c_partial_acc_block_m_n;
}
using Block2CTileMap_streamk = BlockToCTileMap_GemmStreamK_v2<MPerBlock,
NPerBlock,
KPerBlock,
......@@ -1132,22 +1190,42 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
const BDataType* p_b_grid,
CDataType* p_c_grid,
void* p_shared,
Problem& problem)
Problem& problem,
void* p_workspace)
{
const AElementwiseOperation a_element_op{};
const BElementwiseOperation b_element_op{};
const CElementwiseOperation c_element_op{};
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
Block2CTileMap_streamk block_2_ctile_map_streamk(problem.M,
problem.N,
AK0Number * problem.KPadded,
problem.Grid_size,
problem.Streamk_sel);
uint32_t iter_start, iter_end;
bool is_sk_block, is_dp_block;
bool is_sk_block, is_dp_block, is_reduction_block;
index_t num_k_block_main_loop;
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
uint32_t* p_semaphore = reinterpret_cast<uint32_t*>(
reinterpret_cast<char*>(p_workspace) +
block_2_ctile_map_streamk.get_workspace_size_for_acc(sizeof(AccDataType)));
for(auto block_idx = get_block_1d_id();
block_idx < block_2_ctile_map_streamk.get_grid_dims();
block_idx += gridDim.x)
......@@ -1163,6 +1241,214 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end);
num_k_block_main_loop = iter_end - iter_start;
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
is_reduction_block = static_cast<uint32_t>(block_idx) >=
block_2_ctile_map_streamk.reduction_start_block_idx;
if(is_reduction_block)
{
// descriptors
constexpr auto cluster_length_reduce = GetClusterLengthReduction();
constexpr auto reduce_desc = make_cluster_descriptor(cluster_length_reduce);
const auto reduce_thread_cluster_idx =
reduce_desc.CalculateBottomIndex(make_multi_index(block_idx));
const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0];
const auto thread_n_cluster_id = reduce_thread_cluster_idx[I1];
constexpr auto MReduceIters = math::integer_divide_ceil(
Number<MPerBlock>{}, cluster_length_reduce.At(I0));
constexpr auto NReduceIters = math::integer_divide_ceil(
Number<NPerBlock>{},
cluster_length_reduce.At(I1) *
Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{});
constexpr auto acc_thread_buf_load_desc = make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{}));
constexpr auto acc_thread_buf_store_desc =
make_naive_tensor_descriptor_packed(make_tuple(
I1, I1, I1, Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{}));
constexpr auto c_partial_acc_block_m_n = GetPartialAccBlockDescriptor();
constexpr auto partial_acc_load_step_n =
make_multi_index(0,
cluster_length_reduce.At(I1) *
CShuffleBlockTransferScalarPerVector_NPerBlock);
constexpr auto partial_acc_load_step_n_reverse = make_multi_index(
0,
-1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
CShuffleBlockTransferScalarPerVector_NPerBlock);
constexpr auto partial_acc_load_step_m =
make_multi_index(cluster_length_reduce.At(I0), 0);
constexpr auto partial_acc_store_step_n =
make_multi_index(0,
0,
0,
cluster_length_reduce.At(I1) *
CShuffleBlockTransferScalarPerVector_NPerBlock);
constexpr auto partial_acc_store_step_n_reverse = make_multi_index(
0,
0,
0,
-1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
CShuffleBlockTransferScalarPerVector_NPerBlock);
constexpr auto partial_acc_store_step_m =
make_multi_index(0, cluster_length_reduce.At(I0), 0, 0);
StaticBuffer<AddressSpaceEnum::Vgpr,
AccDataType,
CShuffleBlockTransferScalarPerVector_NPerBlock,
true>
parcial_acc_buf;
StaticBuffer<AddressSpaceEnum::Vgpr,
AccDataType,
CShuffleBlockTransferScalarPerVector_NPerBlock,
true>
acc_buf;
// start to compute
auto reduction_idx =
block_idx - block_2_ctile_map_streamk.reduction_start_block_idx;
auto spatial_idx = block_2_ctile_map_streamk.tile_to_spatial(
reduction_idx, problem.M, problem.N);
workgroup_barrier wg_barrier(p_semaphore);
uint32_t tile_acc_offset_start =
block_2_ctile_map_streamk.get_acc_buffer_offset_from_tile(reduction_idx);
uint32_t tile_acc_offset_end =
block_2_ctile_map_streamk.get_acc_buffer_offset_from_tile(reduction_idx +
1);
__syncthreads();
auto acc_load = ThreadwiseTensorSliceTransfer_v2<
AccDataType, // SrcData,
AccDataType, // DstData,
decltype(c_partial_acc_block_m_n), // SrcDesc,
decltype(acc_thread_buf_load_desc), // DstDesc,
Sequence<1,
CShuffleBlockTransferScalarPerVector_NPerBlock>, // SliceLengths,
Sequence<0, 1>, // DimAccessOrder,
1, // SrcVectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // SrcScalarPerVector,
1, // SrcScalarStrideInVector,
false // SrcResetCoordinateAfterRun,
>{c_partial_acc_block_m_n,
make_multi_index(thread_m_cluster_id,
thread_n_cluster_id *
CShuffleBlockTransferScalarPerVector_NPerBlock)};
auto acc_store = ThreadwiseTensorSliceTransfer_v1r3<
AccDataType, // SrcData,
CDataType, // DstData,
decltype(acc_thread_buf_store_desc), // SrcDesc,
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), // DstDesc,
CElementwiseOperation, // ElementwiseOperation,
Sequence<1,
1,
1,
CShuffleBlockTransferScalarPerVector_NPerBlock>, // SliceLengths,
Sequence<0, 1, 2, 3>, // DimAccessOrder,
3, // DstVectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // DstScalarPerVector,
InMemoryDataOperationEnum::Set, // InMemoryDataOperationEnum DstInMemOp,
1, // DstScalarStrideInVector,
false // DstResetCoordinateAfterRun,
>{c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]),
thread_m_cluster_id,
__builtin_amdgcn_readfirstlane(spatial_idx[I1]),
thread_n_cluster_id *
CShuffleBlockTransferScalarPerVector_NPerBlock),
CElementwiseOperation{}};
wg_barrier.wait_eq(reduction_idx, tile_acc_offset_end - tile_acc_offset_start);
if(threadIdx.x == 0)
{
p_semaphore[reduction_idx] = 0;
}
using Accumulation = ck::detail::
AccumulateWithNanCheck<false /*PropagateNan*/, reduce::Add, AccDataType>;
for(int i_m = 0; i_m < MReduceIters; i_m++)
{
static_for<0, NReduceIters, 1>{}([&](auto i_n_reduce) {
acc_buf.Clear();
for(auto i = tile_acc_offset_start; i < tile_acc_offset_end; i++)
{
auto c_partial_acc_buf =
make_dynamic_buffer<AddressSpaceEnum::Global,
AmdBufferCoherenceEnum::GLC>(
reinterpret_cast<AccDataType*>(p_workspace) +
i * c_partial_acc_block_m_n.GetElementSpaceSize(),
c_partial_acc_block_m_n.GetElementSpaceSize());
acc_load.Run(c_partial_acc_block_m_n,
c_partial_acc_buf,
acc_thread_buf_load_desc,
make_tuple(I0, I0),
parcial_acc_buf);
static_for<0, CShuffleBlockTransferScalarPerVector_NPerBlock, 1>{}(
[&](auto i_vec) {
constexpr auto offset =
acc_thread_buf_load_desc.CalculateOffset(
make_tuple(0, i_vec));
Accumulation::Calculate(acc_buf(Number<offset>{}),
parcial_acc_buf[Number<offset>{}]);
});
}
if(thread_n_cluster_id *
CShuffleBlockTransferScalarPerVector_NPerBlock <
NPerBlock)
{
acc_store.Run(acc_thread_buf_store_desc,
make_tuple(I0, I0, I0, I0),
acc_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
}
if constexpr(NReduceIters != 1)
{
if constexpr(i_n_reduce != (NReduceIters - 1))
{
acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
partial_acc_load_step_n);
acc_store.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock,
partial_acc_store_step_n);
}
else
{
acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
partial_acc_load_step_n_reverse);
acc_store.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock,
partial_acc_store_step_n_reverse);
}
}
});
{
acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
partial_acc_load_step_m);
acc_store.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock,
partial_acc_store_step_m);
}
}
continue;
}
}
// offset for last acc buffer of this block
uint32_t block_acc_offset =
(block_2_ctile_map_streamk.get_acc_buffer_offset_from_block(block_idx + 1) - 1) *
MPerBlock * NPerBlock;
while(true)
{
uint32_t current_iter_length = __builtin_amdgcn_readfirstlane(
......@@ -1173,33 +1459,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
iter_end - 1, tile_idx, iter_offset);
iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1);
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(problem.M,
problem.MPadded,
problem.K,
problem.KPadded,
problem.StrideA,
problem.AK0);
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(problem.K,
problem.KPadded,
problem.N,
problem.NPadded,
problem.StrideB,
problem.BK0);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
auto block_work_idx =
block_2_ctile_map_streamk.tile_to_spatial(tile_idx, problem.M, problem.N);
......@@ -1363,11 +1622,20 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
constexpr auto c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle =
GetCBlockDescriptor_MShuffle_MPerShuffle_NShuffle_NPerShuffle();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<CShuffleDataType*>(p_shared),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
.GetElementSpaceSize());
auto c_partial_acc_buf =
make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
reinterpret_cast<AccDataType*>(p_workspace) + block_acc_offset,
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle
.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
......@@ -1477,7 +1745,34 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_m_id, 0, block_n_id, 0),
c_element_op};
// LDS to global partial acc
auto c_block_copy_lds_to_partial_acc = ThreadGroupTensorSliceTransfer_v6r1r2<
ThisThreadBlock, // index_t BlockSize,
CElementwiseOperation, // ElementwiseOperation,
// InMemoryDataOperationEnum::Set, // DstInMemOp,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave *
NPerXdl>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
CShuffleDataType, // typename SrcData,
CShuffleDataType, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
false, // bool ThreadTransferSrcResetCoordinateAfterRun, => need to be
// false, othre wise has scratch
false> // bool ThreadTransferDstResetCoordinateAfterRun, => need to be
// false, othre wise has scratch
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
make_multi_index(0, 0, 0, 0),
c_element_op};
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
......@@ -1535,15 +1830,40 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
}
else if(is_sk_block)
{
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global
.template Run<decltype(c_shuffle_block_buf),
decltype(c_grid_buf),
InMemoryDataOperationEnum::AtomicAdd>(
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Atomic)
{
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global
.template Run<decltype(c_shuffle_block_buf),
decltype(c_grid_buf),
InMemoryDataOperationEnum::AtomicAdd>(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
}
else if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
// constexpr offset
c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
make_tuple(0, 0, 0, 0));
c_block_copy_lds_to_partial_acc.SetDstSliceOrigin(
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
make_tuple(MXdlPerWave, 0, NXdlPerWave, 0));
c_block_copy_lds_to_partial_acc
.template Run<decltype(c_shuffle_block_buf),
decltype(c_partial_acc_buf),
InMemoryDataOperationEnum::Set>(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
c_partial_acc_buf);
}
}
if constexpr(access_id < num_access - 1)
......@@ -1555,15 +1875,33 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
}
});
}
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
if(is_sk_block)
{
// increase the counter for this tile
workgroup_barrier wg_barrier(p_semaphore);
wg_barrier.inc(tile_idx);
}
}
} // shuffle c and write-out end
// exit condition
iter_end -= current_iter_length;
if(iter_end <= iter_start)
break;
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
block_acc_offset -= MPerBlock * NPerBlock;
}
// make sure next loop LDS is ready for use
block_sync_lds();
}
}
} // while loop
} // for loop
}
template <bool HasMainKBlockLoop,
......@@ -1574,19 +1912,43 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
CDataType* p_c_grid,
void* p_shared_0,
void* p_shared_1,
Problem& problem)
Problem& problem,
void* p_workspace)
{
const AElementwiseOperation a_element_op{};
const BElementwiseOperation b_element_op{};
const CElementwiseOperation c_element_op{};
Block2CTileMap_streamk block_2_ctile_map_streamk(
problem.M, problem.N, AK0Number * problem.KPadded, problem.Grid_size);
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
uint32_t iter_start, iter_end;
bool is_sk_block, is_dp_block; //, is_padding_block; //, is_reduction_block;
bool is_sk_block, is_dp_block, is_reduction_block;
index_t num_k_block_main_loop;
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
Block2CTileMap_streamk block_2_ctile_map_streamk(problem.M,
problem.N,
AK0Number * problem.KPadded,
problem.Grid_size,
problem.Streamk_sel);
for(auto block_idx = get_block_1d_id();
block_idx < block_2_ctile_map_streamk.get_grid_dims();
block_idx += gridDim.x)
......@@ -1601,6 +1963,235 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end);
num_k_block_main_loop = iter_end - iter_start;
uint32_t* p_semaphore = reinterpret_cast<uint32_t*>(
reinterpret_cast<char*>(p_workspace) +
block_2_ctile_map_streamk.get_workspace_size_for_acc(sizeof(AccDataType)));
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
is_reduction_block = static_cast<uint32_t>(block_idx) >=
block_2_ctile_map_streamk.reduction_start_block_idx;
if(is_reduction_block)
{
// descriptors
constexpr auto cluster_length_reduce = GetClusterLengthReduction();
constexpr auto reduce_desc = make_cluster_descriptor(cluster_length_reduce);
const auto reduce_thread_cluster_idx =
reduce_desc.CalculateBottomIndex(make_multi_index(block_idx));
const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0];
const auto thread_n_cluster_id = reduce_thread_cluster_idx[I1];
constexpr auto MReduceIters = math::integer_divide_ceil(
Number<MPerBlock>{}, cluster_length_reduce.At(I0));
constexpr auto NReduceIters = math::integer_divide_ceil(
Number<NPerBlock>{},
cluster_length_reduce.At(I1) *
Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{});
constexpr auto acc_thread_buf_load_desc = make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{}));
constexpr auto acc_thread_buf_store_desc =
make_naive_tensor_descriptor_packed(make_tuple(
I1, I1, I1, Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{}));
constexpr auto c_partial_acc_block_m_n = GetPartialAccBlockDescriptor();
constexpr auto partial_acc_load_step_n =
make_multi_index(0,
cluster_length_reduce.At(I1) *
CShuffleBlockTransferScalarPerVector_NPerBlock);
constexpr auto partial_acc_load_step_n_reverse = make_multi_index(
0,
-1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
CShuffleBlockTransferScalarPerVector_NPerBlock);
constexpr auto partial_acc_load_step_m =
make_multi_index(cluster_length_reduce.At(I0), 0);
constexpr auto partial_acc_store_step_n =
make_multi_index(0,
0,
0,
cluster_length_reduce.At(I1) *
CShuffleBlockTransferScalarPerVector_NPerBlock);
constexpr auto partial_acc_store_step_n_reverse = make_multi_index(
0,
0,
0,
-1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
CShuffleBlockTransferScalarPerVector_NPerBlock);
constexpr auto partial_acc_store_step_m =
make_multi_index(0, cluster_length_reduce.At(I0), 0, 0);
StaticBuffer<AddressSpaceEnum::Vgpr,
AccDataType,
CShuffleBlockTransferScalarPerVector_NPerBlock,
true>
parcial_acc_buf;
StaticBuffer<AddressSpaceEnum::Vgpr,
AccDataType,
CShuffleBlockTransferScalarPerVector_NPerBlock,
true>
acc_buf;
// start to compute
auto reduction_idx =
block_idx - block_2_ctile_map_streamk.reduction_start_block_idx;
auto spatial_idx = block_2_ctile_map_streamk.tile_to_spatial(
reduction_idx, problem.M, problem.N);
workgroup_barrier wg_barrier(p_semaphore);
uint32_t tile_acc_offset_start =
block_2_ctile_map_streamk.get_acc_buffer_offset_from_tile(reduction_idx);
uint32_t tile_acc_offset_end =
block_2_ctile_map_streamk.get_acc_buffer_offset_from_tile(reduction_idx +
1);
uint32_t expected_count = tile_acc_offset_end - tile_acc_offset_start;
if(threadIdx.x == 0)
{
p_semaphore[reduction_idx] = 0;
}
__syncthreads();
auto acc_load = ThreadwiseTensorSliceTransfer_v2<
AccDataType, // SrcData,
AccDataType, // DstData,
decltype(c_partial_acc_block_m_n), // SrcDesc,
decltype(acc_thread_buf_load_desc), // DstDesc,
Sequence<1,
CShuffleBlockTransferScalarPerVector_NPerBlock>, // SliceLengths,
Sequence<0, 1>, // DimAccessOrder,
1, // SrcVectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // SrcScalarPerVector,
1, // SrcScalarStrideInVector,
false // SrcResetCoordinateAfterRun,
>{c_partial_acc_block_m_n,
make_multi_index(thread_m_cluster_id,
thread_n_cluster_id *
CShuffleBlockTransferScalarPerVector_NPerBlock)};
auto acc_store = ThreadwiseTensorSliceTransfer_v1r3<
AccDataType, // SrcData,
CDataType, // DstData,
decltype(acc_thread_buf_store_desc), // SrcDesc,
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), // DstDesc,
CElementwiseOperation, // ElementwiseOperation,
Sequence<1,
1,
1,
CShuffleBlockTransferScalarPerVector_NPerBlock>, // SliceLengths,
Sequence<0, 1, 2, 3>, // DimAccessOrder,
3, // DstVectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // DstScalarPerVector,
InMemoryDataOperationEnum::Set, // InMemoryDataOperationEnum DstInMemOp,
1, // DstScalarStrideInVector,
false // DstResetCoordinateAfterRun,
>{c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]),
thread_m_cluster_id,
__builtin_amdgcn_readfirstlane(spatial_idx[I1]),
thread_n_cluster_id *
CShuffleBlockTransferScalarPerVector_NPerBlock),
CElementwiseOperation{}};
#if 0
if(threadIdx.x == 0) {
printf("bid:%d, rid:%d, os:%d,%d, spatial:%d,%d\n", static_cast<int>(blockIdx.x),
reduction_idx, __builtin_amdgcn_readfirstlane(tile_acc_offset_start), __builtin_amdgcn_readfirstlane(tile_acc_offset_end),
__builtin_amdgcn_readfirstlane(spatial_idx[I0]),
__builtin_amdgcn_readfirstlane(spatial_idx[I1]));
}
#endif
if(threadIdx.x == 0)
{
atomicAdd(&p_semaphore[reduction_idx], 1);
}
wg_barrier.wait_eq(p_semaphore[reduction_idx], expected_count);
using Accumulation = ck::detail::
AccumulateWithNanCheck<false /*PropagateNan*/, reduce::Add, AccDataType>;
for(int i_m = 0; i_m < MReduceIters; i_m++)
{
static_for<0, NReduceIters, 1>{}([&](auto i_n_reduce) {
acc_buf.Clear();
for(auto i = tile_acc_offset_start; i < tile_acc_offset_end; i++)
{
auto c_partial_acc_buf =
make_dynamic_buffer<AddressSpaceEnum::Global,
AmdBufferCoherenceEnum::GLC>(
reinterpret_cast<AccDataType*>(p_workspace) +
i * c_partial_acc_block_m_n.GetElementSpaceSize(),
c_partial_acc_block_m_n.GetElementSpaceSize());
acc_load.Run(c_partial_acc_block_m_n,
c_partial_acc_buf,
acc_thread_buf_load_desc,
make_tuple(I0, I0),
parcial_acc_buf);
static_for<0, CShuffleBlockTransferScalarPerVector_NPerBlock, 1>{}(
[&](auto i_vec) {
constexpr auto offset =
acc_thread_buf_load_desc.CalculateOffset(
make_tuple(0, i_vec));
Accumulation::Calculate(acc_buf(Number<offset>{}),
parcial_acc_buf[Number<offset>{}]);
});
}
if(thread_n_cluster_id *
CShuffleBlockTransferScalarPerVector_NPerBlock <
NPerBlock)
{
acc_store.Run(acc_thread_buf_store_desc,
make_tuple(I0, I0, I0, I0),
acc_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
}
if constexpr(NReduceIters != 1)
{
if constexpr(i_n_reduce != (NReduceIters - 1))
{
acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
partial_acc_load_step_n);
acc_store.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock,
partial_acc_store_step_n);
}
else
{
acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
partial_acc_load_step_n_reverse);
acc_store.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock,
partial_acc_store_step_n_reverse);
}
}
});
{
acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
partial_acc_load_step_m);
acc_store.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock,
partial_acc_store_step_m);
}
}
continue;
}
}
// offset for last acc buffer of this block
uint32_t block_acc_offset =
(block_2_ctile_map_streamk.get_acc_buffer_offset_from_block(block_idx + 1) - 1) *
MPerBlock * NPerBlock;
while(true)
{
uint32_t current_iter_length = __builtin_amdgcn_readfirstlane(
......@@ -1611,33 +2202,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
iter_end - 1, tile_idx, iter_offset);
iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1);
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(problem.M,
problem.MPadded,
problem.K,
problem.KPadded,
problem.StrideA,
problem.AK0);
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(problem.K,
problem.KPadded,
problem.N,
problem.NPadded,
problem.StrideB,
problem.BK0);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
auto block_work_idx =
block_2_ctile_map_streamk.tile_to_spatial(tile_idx, problem.M, problem.N);
......@@ -1811,11 +2375,20 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
constexpr auto c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle =
GetCBlockDescriptor_MShuffle_MPerShuffle_NShuffle_NPerShuffle();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<CShuffleDataType*>(p_shared_0),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
.GetElementSpaceSize());
auto c_partial_acc_buf =
make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
reinterpret_cast<AccDataType*>(p_workspace) + block_acc_offset,
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle
.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
......@@ -1925,6 +2498,35 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
make_multi_index(block_m_id, 0, block_n_id, 0),
c_element_op};
// LDS to global partial acc
auto c_block_copy_lds_to_partial_acc = ThreadGroupTensorSliceTransfer_v6r1r2<
ThisThreadBlock, // index_t BlockSize,
CElementwiseOperation, // ElementwiseOperation,
// InMemoryDataOperationEnum::Set, // DstInMemOp,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave *
NPerXdl>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
CShuffleDataType, // typename SrcData,
CShuffleDataType, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
false, // bool ThreadTransferSrcResetCoordinateAfterRun, => need to be
// false, othre wise has scratch
false> // bool ThreadTransferDstResetCoordinateAfterRun, => need to be
// false, othre wise has scratch
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
make_multi_index(0, 0, 0, 0),
c_element_op};
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
......@@ -1982,15 +2584,40 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
}
else if(is_sk_block)
{
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global
.template Run<decltype(c_shuffle_block_buf),
decltype(c_grid_buf),
InMemoryDataOperationEnum::AtomicAdd>(
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Atomic)
{
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global
.template Run<decltype(c_shuffle_block_buf),
decltype(c_grid_buf),
InMemoryDataOperationEnum::AtomicAdd>(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
}
else if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
// constexpr offset
c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
make_tuple(0, 0, 0, 0));
c_block_copy_lds_to_partial_acc.SetDstSliceOrigin(
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
make_tuple(MXdlPerWave, 0, NXdlPerWave, 0));
c_block_copy_lds_to_partial_acc
.template Run<decltype(c_shuffle_block_buf),
decltype(c_partial_acc_buf),
InMemoryDataOperationEnum::Set>(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
c_partial_acc_buf);
}
}
if constexpr(access_id < num_access - 1)
{
......@@ -2002,6 +2629,27 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
}
});
}
// exit condition
iter_end -= current_iter_length;
if(iter_end <= iter_start)
break;
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
block_acc_offset -= MPerBlock * NPerBlock;
}
// make sure next loop LDS is ready for use
block_sync_lds();
}
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
if(is_sk_block)
{
// increase the counter for this tile
workgroup_barrier wg_barrier(p_semaphore);
wg_barrier.inc(0);
}
}
}
}
......
......@@ -38,8 +38,7 @@ __global__ void
// __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -13,245 +13,614 @@
namespace ck {
namespace tensor_operation {
namespace {
template <
index_t NDimSpatial,
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization ConvBwdDataSpecialization,
index_t AK1,
index_t BK1,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
bool DoPadGemmM,
bool DoPadGemmN,
typename ALayout,
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization ConvBwdDataSpecialization>
constexpr auto make_out_grid_desc(const index_t N,
const index_t Do,
const index_t Ho,
const index_t Wo,
const index_t K,
const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_strides)
typename BLayout,
typename CLayout,
bool SplitN = false,
typename ADataType = float,
typename CDataType = float,
index_t NumGroupsToMerge = 1,
typename IndexType = index_t>
struct TransformConvBwdDataToGemm_v1
{
const auto KStride = Number<1>{};
private:
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
if constexpr(is_same_v<ALayout, tensor_layout::convolution::NHWGK>)
{
const index_t NStride = out_g_n_k_wos_strides[1];
const index_t HiStride = out_g_n_k_wos_strides[3];
const index_t WiStride = out_g_n_k_wos_strides[4];
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{
static constexpr auto NonSpatialDimsNum = Number<3>{};
return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, K),
make_tuple(WiStride, KStride));
}
else
static constexpr auto DIdx = NonSpatialDimsNum;
static constexpr auto HIdx =
NDimSpatial == 2 ? NonSpatialDimsNum : Number<NonSpatialDimsNum + 1>{};
static constexpr auto WIdx =
NDimSpatial == 2 ? Number<NonSpatialDimsNum + 1>{} : Number<NonSpatialDimsNum + 2>{};
static constexpr auto ZIdx = NonSpatialDimsNum;
static constexpr auto YIdx =
NDimSpatial == 2 ? NonSpatialDimsNum : Number<NonSpatialDimsNum + 1>{};
static constexpr auto XIdx =
NDimSpatial == 2 ? Number<NonSpatialDimsNum + 1>{} : Number<NonSpatialDimsNum + 2>{};
template <typename ConvDimsType>
static long_index_t calculate_element_space_size_impl(const ConvDimsType& lengths,
const ConvDimsType& strides,
index_t i)
{
long_index_t acc = 1;
for(; i < (NDimSpatial + 3); i++)
{
return make_naive_tensor_descriptor(make_tuple(N, Ho, Wo, K),
make_tuple(NStride, HiStride, WiStride, KStride));
acc +=
static_cast<long_index_t>(lengths[i] - I1) * static_cast<long_index_t>(strides[i]);
}
return acc;
}
else if constexpr(is_same_v<ALayout, tensor_layout::convolution::NDHWGK>)
template <typename ConvDimsType>
static IndexType GetSplitedNSize(const ConvDimsType& a_g_n_k_wos_lengths,
const ConvDimsType& a_g_n_k_wos_strides,
const ConvDimsType& c_g_n_c_wis_lengths,
const ConvDimsType& c_g_n_c_wis_strides)
{
const index_t NStride = out_g_n_k_wos_strides[1];
const index_t DoStride = out_g_n_k_wos_strides[3];
const index_t HoStride = out_g_n_k_wos_strides[4];
const index_t WoStride = out_g_n_k_wos_strides[5];
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
const long_index_t a_element_space_size =
calculate_element_space_size_impl(a_g_n_k_wos_lengths, a_g_n_k_wos_strides, I1);
const long_index_t c_element_space_size =
calculate_element_space_size_impl(c_g_n_c_wis_lengths, c_g_n_c_wis_strides, I1);
const long_index_t element_space_size = math::max(a_element_space_size * sizeof(ADataType),
c_element_space_size * sizeof(CDataType));
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
const IndexType N = a_g_n_k_wos_lengths[I1];
if(element_space_size > TwoGB)
{
// Minimum divisor of N to not exceed 2GB
const auto divisor = math::integer_divide_ceil(element_space_size, TwoGB);
return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, K),
make_tuple(WoStride, KStride));
if(divisor <= static_cast<double>(N))
{
// Find least divisor of N larger than element_space_size / TwoGB
// Iterate up to sqrt(N). There are no divisors above this value.
for(IndexType least_divisor = divisor; least_divisor * least_divisor <= N;
least_divisor++)
{
if(N % least_divisor == 0)
{
return N / least_divisor;
}
}
// Not found, process one Convolution N per block
return 1;
}
else
{
// Not possible to support even after split N.
// Too large tensor.
return N;
}
}
else
{
return make_naive_tensor_descriptor(
make_tuple(N, Do, Ho, Wo, K),
make_tuple(NStride, DoStride, HoStride, WoStride, KStride));
// Split N is not needed.
return N;
}
}
else if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNHWK>)
public:
__host__ __device__ constexpr TransformConvBwdDataToGemm_v1() {}
template <typename TransformConvBwdDataToGemm_v1Base>
__host__ __device__ TransformConvBwdDataToGemm_v1(
const TransformConvBwdDataToGemm_v1Base& transform_conv_bwd_data_to_gemm_base)
: N_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.N_)},
Di_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.Di_)},
Hi_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.Hi_)},
Wi_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.Wi_)},
Do_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.Do_)},
Ho_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.Ho_)},
Wo_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.Wo_)},
Z_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.Z_)},
Y_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.Y_)},
X_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.X_)},
K_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.K_)},
C_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.C_)},
DiStride_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.DiStride_)},
HiStride_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.HiStride_)},
WiStride_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.WiStride_)},
DoStride_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.DoStride_)},
HoStride_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.HoStride_)},
WoStride_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.WoStride_)},
CStrideTensorB_{
static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.CStrideTensorB_)},
CStrideTensorC_{
static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.CStrideTensorC_)},
KStrideTensorA_{
static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.KStrideTensorA_)},
KStrideTensorB_{
static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.KStrideTensorB_)},
NStrideTensorA_{
static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.NStrideTensorA_)},
NStrideTensorC_{
static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.NStrideTensorC_)},
ConvStrideD_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ConvStrideD_)},
ConvStrideH_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ConvStrideH_)},
ConvStrideW_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ConvStrideW_)},
ConvDilationD_{
static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ConvDilationD_)},
ConvDilationH_{
static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ConvDilationH_)},
ConvDilationW_{
static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ConvDilationW_)},
InLeftPadD_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.InLeftPadD_)},
InLeftPadH_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.InLeftPadH_)},
InLeftPadW_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.InLeftPadW_)},
InRightPadD_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.InRightPadD_)},
InRightPadH_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.InRightPadH_)},
InRightPadW_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.InRightPadW_)},
IdxZTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.IdxZTilde_)},
IdxYTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.IdxYTilde_)},
IdxXTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.IdxXTilde_)},
GcdStrideDilationD_{
static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.GcdStrideDilationD_)},
GcdStrideDilationH_{
static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.GcdStrideDilationH_)},
GcdStrideDilationW_{
static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.GcdStrideDilationW_)},
ZTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ZTilde_)},
YTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.YTilde_)},
XTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.XTilde_)},
DTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.DTilde_)},
HTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.HTilde_)},
WTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.WTilde_)},
ZDot_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ZDot_)},
YDot_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.YDot_)},
XDot_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.XDot_)}
{
// assume packed
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
}
template <typename ConvDimsType, typename ConvSpatialDimsType>
__host__ __device__
TransformConvBwdDataToGemm_v1(const ConvDimsType& a_g_n_k_wos_lengths,
const ConvDimsType& a_g_n_k_wos_strides,
const ConvDimsType& b_g_k_c_xs_lengths,
const ConvDimsType& b_g_k_c_xs_strides,
const ConvDimsType& c_g_n_c_wis_lengths,
const ConvDimsType& c_g_n_c_wis_strides,
const ConvSpatialDimsType& conv_filter_strides,
const ConvSpatialDimsType& conv_filter_dilations,
const ConvSpatialDimsType& input_left_pads,
const ConvSpatialDimsType& input_right_pads,
const ConvSpatialDimsType& tildes)
: Hi_{c_g_n_c_wis_lengths[HIdx]},
Wi_{c_g_n_c_wis_lengths[WIdx]},
Ho_{a_g_n_k_wos_lengths[HIdx]},
Wo_{a_g_n_k_wos_lengths[WIdx]},
Y_{b_g_k_c_xs_lengths[YIdx]},
X_{b_g_k_c_xs_lengths[XIdx]},
K_{a_g_n_k_wos_lengths[I2]},
C_{b_g_k_c_xs_lengths[I2]},
HiStride_{c_g_n_c_wis_strides[HIdx]},
WiStride_{c_g_n_c_wis_strides[WIdx]},
HoStride_{a_g_n_k_wos_strides[HIdx]},
WoStride_{a_g_n_k_wos_strides[WIdx]},
CStrideTensorB_{b_g_k_c_xs_strides[I2]},
CStrideTensorC_{c_g_n_c_wis_strides[I2]},
KStrideTensorA_{a_g_n_k_wos_strides[I2]},
KStrideTensorB_{b_g_k_c_xs_strides[I1]},
NStrideTensorA_{a_g_n_k_wos_strides[I1]},
NStrideTensorC_{c_g_n_c_wis_strides[I1]},
ConvStrideH_{conv_filter_strides[HIdx - NonSpatialDimsNum]},
ConvStrideW_{conv_filter_strides[WIdx - NonSpatialDimsNum]},
ConvDilationH_{conv_filter_dilations[HIdx - NonSpatialDimsNum]},
ConvDilationW_{conv_filter_dilations[WIdx - NonSpatialDimsNum]},
InLeftPadH_{input_left_pads[HIdx - NonSpatialDimsNum]},
InLeftPadW_{input_left_pads[WIdx - NonSpatialDimsNum]},
InRightPadH_{input_right_pads[HIdx - NonSpatialDimsNum]},
InRightPadW_{input_right_pads[WIdx - NonSpatialDimsNum]},
IdxYTilde_{tildes[YIdx - NonSpatialDimsNum]},
IdxXTilde_{tildes[XIdx - NonSpatialDimsNum]}
{
static_assert(is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
is_same_v<ConvSpatialDimsType, ck::Array<IndexType, NDimSpatial>>);
static_assert(is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
is_same_v<ConvDimsType, ck::Array<IndexType, NDimSpatial + I3>>);
if constexpr(SplitN)
{
return make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
N_ = GetSplitedNSize(
a_g_n_k_wos_lengths, a_g_n_k_wos_strides, c_g_n_c_wis_lengths, c_g_n_c_wis_strides);
}
else
{
return make_naive_tensor_descriptor_packed(make_tuple(N, Ho, Wo, K));
N_ = c_g_n_c_wis_lengths[I1];
}
}
else if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNDHWK>)
{
// assume packed
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
if constexpr(NDimSpatial == 3)
{
return make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K));
Di_ = c_g_n_c_wis_lengths[DIdx];
Do_ = a_g_n_k_wos_lengths[DIdx];
Z_ = b_g_k_c_xs_lengths[ZIdx];
DiStride_ = c_g_n_c_wis_strides[DIdx];
DoStride_ = a_g_n_k_wos_strides[DIdx];
ConvStrideD_ = conv_filter_strides[DIdx - NonSpatialDimsNum];
ConvDilationD_ = conv_filter_dilations[DIdx - NonSpatialDimsNum];
InLeftPadD_ = input_left_pads[DIdx - NonSpatialDimsNum];
InRightPadD_ = input_right_pads[DIdx - NonSpatialDimsNum];
IdxZTilde_ = tildes[ZIdx - NonSpatialDimsNum];
GcdStrideDilationD_ = math::gcd(ConvStrideD_, ConvDilationD_);
ZTilde_ = ConvStrideD_ / GcdStrideDilationD_;
DTilde_ = Do_ + math::integer_divide_ceil(ConvDilationD_ * (Z_ - I1), ConvStrideD_);
ZDot_ = math::integer_divide_ceil(Z_, ZTilde_);
}
else
{
return make_naive_tensor_descriptor_packed(make_tuple(N, Do, Ho, Wo, K));
Di_ = Do_ = Z_ = ZTilde_ = ConvStrideD_ = DTilde_ = ZDot_ = 1;
InLeftPadD_ = InRightPadD_ = DiStride_ = DoStride_ = IdxZTilde_ = 0;
}
}
else
{
throw std::runtime_error("wrong! unsupported layout: " + ALayout::name());
}
}
template <typename BLayout>
constexpr auto make_wei_grid_desc(
const index_t K, const index_t Z, const index_t Y, const index_t X, const index_t C)
{
GcdStrideDilationH_ = math::gcd(ConvStrideH_, ConvDilationH_);
GcdStrideDilationW_ = math::gcd(ConvStrideW_, ConvDilationW_);
if constexpr(is_same_v<BLayout, tensor_layout::convolution::GKYXC>)
{
return make_naive_tensor_descriptor_packed(make_tuple(K, Y, X, C));
}
else if constexpr(is_same_v<BLayout, tensor_layout::convolution::GKZYXC>)
{
return make_naive_tensor_descriptor_packed(make_tuple(K, Z, Y, X, C));
}
else
{
throw std::runtime_error("wrong! unsupported layout: " + BLayout::name());
}
}
template <index_t NDimSpatial, typename CLayout>
constexpr auto make_in_grid_desc(const index_t N,
const index_t Di,
const index_t Hi,
const index_t Wi,
const index_t C,
const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_strides)
{
YTilde_ = ConvStrideH_ / GcdStrideDilationH_;
XTilde_ = ConvStrideW_ / GcdStrideDilationW_;
if constexpr(is_same_v<CLayout, tensor_layout::convolution::GNHWC> ||
is_same_v<CLayout, tensor_layout::convolution::NHWGC> ||
is_same_v<CLayout, tensor_layout::convolution::G_NHW_C>)
{
return make_naive_tensor_descriptor(make_tuple(N, Hi, Wi, C),
make_tuple(in_g_n_c_wis_strides[1],
in_g_n_c_wis_strides[3],
in_g_n_c_wis_strides[4],
in_g_n_c_wis_strides[2]));
HTilde_ = Ho_ + math::integer_divide_ceil(ConvDilationH_ * (Y_ - I1), ConvStrideH_);
WTilde_ = Wo_ + math::integer_divide_ceil(ConvDilationW_ * (X_ - I1), ConvStrideW_);
YDot_ = math::integer_divide_ceil(Y_, YTilde_);
XDot_ = math::integer_divide_ceil(X_, XTilde_);
}
else if constexpr(is_same_v<CLayout, tensor_layout::convolution::GNDHWC> ||
is_same_v<CLayout, tensor_layout::convolution::NDHWGC>)
#if 0 // At now not supported to split tensor
__host__ bool AreDescriptorsSmallerThan2GB() const
{
return make_naive_tensor_descriptor(make_tuple(N, Di, Hi, Wi, C),
make_tuple(in_g_n_c_wis_strides[1],
in_g_n_c_wis_strides[3],
in_g_n_c_wis_strides[4],
in_g_n_c_wis_strides[5],
in_g_n_c_wis_strides[2]));
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
const long_index_t in_desc_space_size =
I1 + (N_ - I1) * NStrideTensorC_ + (Di_ - I1) * DiStride_ + (Hi_ - I1) * HiStride_ +
(Wi_ - I1) * WiStride_ + (C_ - I1) * CStrideTensorC_;
const long_index_t out_desc_space_size =
I1 + (N_ - I1) * NStrideTensorA_ + (Do_ - I1) * DoStride_ + (Ho_ - I1) * HoStride_ +
(Wo_ - I1) * WoStride_ + (K_ - I1) * KStrideTensorA_;
bool is_a_descriptor_smaller_than_2GB = (out_desc_space_size * sizeof(ADataType)) <= TwoGB;
bool is_c_descriptor_smaller_than_2GB = (in_desc_space_size * sizeof(CDataType)) <= TwoGB;
return is_a_descriptor_smaller_than_2GB && is_c_descriptor_smaller_than_2GB;
}
else
__host__ auto SplitConvProblem(const ADataType* a_grid_ptr_base,
CDataType* c_grid_ptr_base) const
{
throw std::runtime_error("wrong! unsupported layout: " + CLayout::name());
}
}
// Create copies
auto conv_to_gemm_transformer_left = *this;
auto conv_to_gemm_transformer_right = *this;
IndexType a_right_offset = 0;
IndexType c_right_offset = 0;
// Calculate real filter size
const IndexType z_eff = (Z_ - 1) * ConvDilationD_ + 1;
const IndexType y_eff = (Y_ - 1) * ConvDilationH_ + 1;
const IndexType x_eff = (X_ - 1) * ConvDilationW_ + 1;
// Calculate start position in input for right tensor
const IndexType di_right_transformer_start_idx = (Do_ / 2) * ConvStrideD_;
const IndexType hi_right_transformer_start_idx = (Ho_ / 2) * ConvStrideH_;
const IndexType wi_right_transformer_start_idx = (Wo_ / 2) * ConvStrideW_;
// Calculate last position in input for left tensor
const IndexType di_left_transformer_end_idx = (Do_ / 2 - 1) * ConvStrideD_ + z_eff;
const IndexType hi_left_transformer_end_idx = (Ho_ / 2 - 1) * ConvStrideH_ + y_eff;
const IndexType wi_left_transformer_end_idx = (Wo_ / 2 - 1) * ConvStrideW_ + x_eff;
// Allow to split if whole left padding will be in left tensor and right padding in right
// tensor
const bool is_possible_to_split_d = Do_ != 1 &&
di_right_transformer_start_idx > InLeftPadD_ &&
di_left_transformer_end_idx <= (InLeftPadD_ + Di_);
const bool is_possible_to_split_h = Ho_ != 1 &&
hi_right_transformer_start_idx > InLeftPadH_ &&
hi_left_transformer_end_idx <= (InLeftPadH_ + Hi_);
const bool is_possible_to_split_w = Wo_ != 1 &&
wi_right_transformer_start_idx > InLeftPadW_ &&
wi_left_transformer_end_idx <= (InLeftPadW_ + Wi_);
if(is_possible_to_split_d)
{
// Apply new sizes
// Split output on half
conv_to_gemm_transformer_left.Do_ = Do_ / 2;
conv_to_gemm_transformer_right.Do_ = Do_ - Do_ / 2;
// Assign left padding to left convolution
conv_to_gemm_transformer_left.InLeftPadD_ = InLeftPadD_;
conv_to_gemm_transformer_right.InLeftPadD_ = 0;
// Assign right padding to right convolution
conv_to_gemm_transformer_left.InRightPadD_ = 0;
conv_to_gemm_transformer_right.InRightPadD_ = InRightPadD_;
// Calculate new input size
conv_to_gemm_transformer_left.Di_ = di_left_transformer_end_idx - InLeftPadD_;
conv_to_gemm_transformer_right.Di_ =
math::min(Di_ - (di_right_transformer_start_idx - InLeftPadD_),
(conv_to_gemm_transformer_right.Do_ - 1) * ConvStrideD_ + z_eff);
;
// Calcualte offsets
a_right_offset = (Do_ / 2) * DoStride_;
c_right_offset = ((Do_ / 2) * ConvStrideD_ - InLeftPadD_) * DiStride_;
}
else if(is_possible_to_split_h)
{
conv_to_gemm_transformer_left.Ho_ = Ho_ / 2;
conv_to_gemm_transformer_right.Ho_ = Ho_ - Ho_ / 2;
} // namespace
conv_to_gemm_transformer_left.InLeftPadH_ = InLeftPadH_;
conv_to_gemm_transformer_right.InLeftPadH_ = 0;
template <
index_t NDimSpatial,
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization ConvBwdDataSpecialization,
index_t AK1,
index_t BK1,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
bool DoPadGemmM,
bool DoPadGemmN>
struct TransformConvBwdDataToGemm_v1
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
conv_to_gemm_transformer_left.InRightPadH_ = 0;
conv_to_gemm_transformer_right.InRightPadH_ = InRightPadH_;
static constexpr auto NonSpatialDimsNum = Number<3>{};
conv_to_gemm_transformer_left.Hi_ = hi_left_transformer_end_idx - InLeftPadH_;
conv_to_gemm_transformer_right.Hi_ =
math::min(Hi_ - (hi_right_transformer_start_idx - InLeftPadH_),
(conv_to_gemm_transformer_right.Ho_ - 1) * ConvStrideH_ + y_eff);
a_right_offset = (Ho_ / 2) * HoStride_;
c_right_offset = ((Ho_ / 2) * ConvStrideH_ - InLeftPadH_) * HiStride_;
}
else if(is_possible_to_split_w)
{
conv_to_gemm_transformer_left.Wo_ = Wo_ / 2;
conv_to_gemm_transformer_right.Wo_ = Wo_ - Wo_ / 2;
static constexpr auto DIdx = Number<NonSpatialDimsNum>{};
static constexpr auto HIdx =
NDimSpatial == 2 ? Number<NonSpatialDimsNum>{} : Number<NonSpatialDimsNum + 1>{};
static constexpr auto WIdx =
NDimSpatial == 2 ? Number<NonSpatialDimsNum + 1>{} : Number<NonSpatialDimsNum + 2>{};
conv_to_gemm_transformer_left.InLeftPadW_ = InLeftPadW_;
conv_to_gemm_transformer_right.InLeftPadW_ = 0;
static constexpr auto ZIdx = Number<NonSpatialDimsNum>{};
static constexpr auto YIdx =
NDimSpatial == 2 ? Number<NonSpatialDimsNum>{} : Number<NonSpatialDimsNum + 1>{};
static constexpr auto XIdx =
NDimSpatial == 2 ? Number<NonSpatialDimsNum + 1>{} : Number<NonSpatialDimsNum + 2>{};
conv_to_gemm_transformer_left.InRightPadW_ = 0;
conv_to_gemm_transformer_right.InRightPadW_ = InRightPadW_;
template <typename ALayout,
typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) &&
(is_same_v<ALayout, tensor_layout::convolution::GNHWK> ||
is_same_v<ALayout, tensor_layout::convolution::GNDHWK> ||
is_same_v<ALayout, tensor_layout::convolution::NHWGK> ||
is_same_v<ALayout, tensor_layout::convolution::NDHWGK>),
bool>::type = false>
static auto MakeADescriptor_AK0_M_AK1(
const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial + 3>& wei_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& /* wei_g_k_c_xs_strides */,
const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& /* in_g_n_c_wis_strides */,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& /* input_right_pads */,
const std::array<index_t, NDimSpatial>& tildes)
conv_to_gemm_transformer_left.Wi_ = wi_left_transformer_end_idx - InLeftPadW_;
conv_to_gemm_transformer_right.Wi_ =
math::min(Wi_ - (wi_right_transformer_start_idx - InLeftPadW_),
(conv_to_gemm_transformer_right.Wo_ - 1) * ConvStrideW_ + x_eff);
a_right_offset = (Wo_ / 2) * WoStride_;
c_right_offset = ((Wo_ / 2) * ConvStrideW_ - InLeftPadW_) * WiStride_;
}
// Return left transform, right transformer, right offset to Input and right offset to
// Output
return ck::make_tuple(conv_to_gemm_transformer_left,
conv_to_gemm_transformer_right,
a_grid_ptr_base + a_right_offset,
c_grid_ptr_base + c_right_offset);
}
__host__ auto SplitConvProblem(const ADataType* a_grid_ptr_base,
CDataType* c_grid_ptr_base) const
{
index_t i_ztilde = tildes[ZIdx - NonSpatialDimsNum];
index_t i_ytilde = tildes[YIdx - NonSpatialDimsNum];
index_t i_xtilde = tildes[XIdx - NonSpatialDimsNum];
// Create copies
auto conv_to_gemm_transformer_left = *this;
auto conv_to_gemm_transformer_right = *this;
IndexType a_right_offset = 0;
IndexType c_right_offset = 0;
// Calculate start position in input for right tensor
const IndexType do_right_transformer_start_idx = math::integer_divide_ceil((Di_ / 2) + InLeftPadD_ - ((Z_ - 1) * ConvDilationD_), ConvStrideD_);
const IndexType ho_right_transformer_start_idx = math::integer_divide_ceil((Hi_ / 2) + InLeftPadH_ - ((Y_ - 1) * ConvDilationH_), ConvStrideH_);
const IndexType wo_right_transformer_start_idx = math::integer_divide_ceil((Wi_ / 2) + InLeftPadW_ - ((X_ - 1) * ConvDilationW_), ConvStrideW_);
// Calculate last position in input for left tensor
const IndexType do_left_transformer_end_idx = math::integer_divide_ceil((Di_ / 2 - 1) + InLeftPadD_, ConvStrideD_);
const IndexType ho_left_transformer_end_idx = math::integer_divide_ceil((Hi_ / 2 - 1) + InLeftPadH_, ConvStrideH_);
const IndexType wo_left_transformer_end_idx = math::integer_divide_ceil((Wi_ / 2 - 1) + InLeftPadW_, ConvStrideW_);
if(Di_!=1)
{
// Apply new sizes
// Split output on half
conv_to_gemm_transformer_left.Di_ = Di_ / 2;
conv_to_gemm_transformer_right.Di_ = Di_ - Di_ / 2;
// Assign left padding to left convolution
conv_to_gemm_transformer_left.InLeftPadD_ = InLeftPadD_;
conv_to_gemm_transformer_right.InLeftPadD_ = 0;
// // Assign right padding to right convolution
conv_to_gemm_transformer_left.InRightPadD_ = 0;
conv_to_gemm_transformer_right.InRightPadD_ = InRightPadD_;
// Calculate new input size
conv_to_gemm_transformer_left.Do_ = do_left_transformer_end_idx;
conv_to_gemm_transformer_right.Do_ = Do_ - do_right_transformer_start_idx;
;
// Calcualte offsets
a_right_offset = do_right_transformer_start_idx * DoStride_;
c_right_offset = (Di_ / 2) * DiStride_;
}
else if(Hi_!=1)
{
// Apply new sizes
// Split output on half
conv_to_gemm_transformer_left.Hi_ = Hi_ / 2;
conv_to_gemm_transformer_right.Hi_ = Hi_ - Hi_ / 2;
// Assign left padding to left convolution
conv_to_gemm_transformer_left.InLeftPadH_ = InLeftPadH_;
conv_to_gemm_transformer_right.InLeftPadH_ = 0;
// // Assign right padding to right convolution
conv_to_gemm_transformer_left.InRightPadH_ = 0;
conv_to_gemm_transformer_right.InRightPadH_ = InRightPadH_;
// Calculate new input size
conv_to_gemm_transformer_left.Ho_ = ho_left_transformer_end_idx ;
conv_to_gemm_transformer_right.Ho_ = Ho_ - ho_right_transformer_start_idx ;
;
// Calcualte offsets
a_right_offset = ho_right_transformer_start_idx * HoStride_;
c_right_offset = (Hi_ / 2) * HiStride_;
}
else if(Wi_!=1)
{
// Apply new sizes
// Split output on half
conv_to_gemm_transformer_left.Wi_ = Wi_ / 2;
conv_to_gemm_transformer_right.Wi_ = Wi_ - Wi_ / 2;
// Assign left padding to left convolution
conv_to_gemm_transformer_left.InLeftPadW_ = InLeftPadW_;
conv_to_gemm_transformer_right.InLeftPadW_ = 0;
// Assign right padding to right convolution
conv_to_gemm_transformer_left.InRightPadW_ = 0;
conv_to_gemm_transformer_right.InRightPadW_ = InRightPadW_;
// Calculate new input size
conv_to_gemm_transformer_left.Wo_ = wo_left_transformer_end_idx;
conv_to_gemm_transformer_right.Wo_ = Wo_ - wo_right_transformer_start_idx;
;
// Calcualte offsets
a_right_offset = wo_right_transformer_start_idx * WoStride_;
c_right_offset = (Wi_ / 2) * WiStride_;
}
// Return left transform, right transformer, right offset to Input and right offset to
// Output
return ck::make_tuple(conv_to_gemm_transformer_left,
conv_to_gemm_transformer_right,
a_grid_ptr_base + a_right_offset,
c_grid_ptr_base + c_right_offset);
}
#endif
const index_t N = in_g_n_c_wis_lengths[1];
const index_t K = wei_g_k_c_xs_lengths[1];
__host__ __device__ auto MakeOutGridDesc() const
{
if constexpr(is_same_v<ALayout, tensor_layout::convolution::NHWGK>)
{
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{
const index_t Di = NDimSpatial == 3 ? in_g_n_c_wis_lengths[DIdx] : 1;
const index_t Hi = in_g_n_c_wis_lengths[HIdx];
const index_t Wi = in_g_n_c_wis_lengths[WIdx];
return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_),
make_tuple(WoStride_, KStrideTensorA_));
}
else
{
return make_naive_tensor_descriptor(
make_tuple(N_, Ho_, Wo_, K_),
make_tuple(NStrideTensorA_, HoStride_, WoStride_, KStrideTensorA_));
}
}
else if constexpr(is_same_v<ALayout, tensor_layout::convolution::NDHWGK>)
{
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{
const index_t Do = NDimSpatial == 3 ? out_g_n_k_wos_lengths[DIdx] : 1;
const index_t Ho = out_g_n_k_wos_lengths[HIdx];
const index_t Wo = out_g_n_k_wos_lengths[WIdx];
return make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, K_),
make_tuple(WoStride_, KStrideTensorA_));
}
else
{
return make_naive_tensor_descriptor(
make_tuple(N_, Do_, Ho_, Wo_, K_),
make_tuple(NStrideTensorA_, DoStride_, HoStride_, WoStride_, KStrideTensorA_));
}
}
else if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNHWK>)
{
// assume packed
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{
return make_naive_tensor_descriptor_packed(make_tuple(N_ * Ho_ * Wo_, K_));
}
else
{
return make_naive_tensor_descriptor_packed(make_tuple(N_, Ho_, Wo_, K_));
}
}
else if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNDHWK>)
{
// assume packed
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{
return make_naive_tensor_descriptor_packed(make_tuple(N_ * Do_ * Ho_ * Wo_, K_));
}
else
{
return make_naive_tensor_descriptor_packed(make_tuple(N_, Do_, Ho_, Wo_, K_));
}
}
else
{
throw std::runtime_error("wrong! unsupported layout: " + ALayout::name());
}
}
const index_t Z = NDimSpatial == 3 ? wei_g_k_c_xs_lengths[ZIdx] : 1;
const index_t Y = wei_g_k_c_xs_lengths[YIdx];
const index_t X = wei_g_k_c_xs_lengths[XIdx];
__host__ __device__ auto MakeWeiGridDesc() const
{
const index_t InLeftPadD = input_left_pads[DIdx - NonSpatialDimsNum];
const index_t InLeftPadH = input_left_pads[HIdx - NonSpatialDimsNum];
const index_t InLeftPadW = input_left_pads[WIdx - NonSpatialDimsNum];
if constexpr(is_same_v<BLayout, tensor_layout::convolution::GKYXC>)
{
return make_naive_tensor_descriptor_packed(make_tuple(K_, Y_, X_, C_));
}
else if constexpr(is_same_v<BLayout, tensor_layout::convolution::GKZYXC>)
{
return make_naive_tensor_descriptor_packed(make_tuple(K_, Z_, Y_, X_, C_));
}
else
{
throw std::runtime_error("wrong! unsupported layout: " + BLayout::name());
}
}
const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum];
const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum];
const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum];
__host__ __device__ auto MakeInGridDesc() const
{
const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum];
const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum];
const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum];
if constexpr(is_same_v<CLayout, tensor_layout::convolution::GNHWC> ||
is_same_v<CLayout, tensor_layout::convolution::NHWGC> ||
is_same_v<CLayout, tensor_layout::convolution::G_NHW_C>)
{
return make_naive_tensor_descriptor(
make_tuple(N_, Hi_, Wi_, C_),
make_tuple(NStrideTensorC_, HiStride_, WiStride_, CStrideTensorC_));
}
else if constexpr(is_same_v<CLayout, tensor_layout::convolution::GNDHWC> ||
is_same_v<CLayout, tensor_layout::convolution::NDHWGC>)
{
return make_naive_tensor_descriptor(
make_tuple(N_, Di_, Hi_, Wi_, C_),
make_tuple(NStrideTensorC_, DiStride_, HiStride_, WiStride_, CStrideTensorC_));
}
else
{
throw std::runtime_error("wrong! unsupported layout: " + CLayout::name());
}
}
template <
typename ALayout_ = ALayout,
typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) &&
(is_same_v<ALayout_, tensor_layout::convolution::GNHWK> ||
is_same_v<ALayout_, tensor_layout::convolution::GNDHWK> ||
is_same_v<ALayout_, tensor_layout::convolution::NHWGK> ||
is_same_v<ALayout_, tensor_layout::convolution::NDHWGK>),
bool>::type = false>
__host__ __device__ auto MakeADescriptor_AK0_M_AK1() const
{
// n_do_ho_wo_k for 3d or n_ho_wo_k for 2d
const auto out_grid_desc =
make_out_grid_desc<NDimSpatial, ALayout, ConvBwdDataSpecialization>(
N, Do, Ho, Wo, K, out_g_n_k_wos_strides);
const auto out_grid_desc = MakeOutGridDesc();
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{
const index_t AK0 = math::integer_divide_ceil(K, AK1);
const index_t AK0 = math::integer_divide_ceil(K_, AK1);
// A: output tensor
const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor(
out_grid_desc,
make_tuple(make_pass_through_transform(N * Do * Ho * Wo),
make_tuple(make_pass_through_transform(N_ * Do_ * Ho_ * Wo_),
make_unmerge_transform(make_tuple(AK0, AK1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
......@@ -266,82 +635,63 @@ struct TransformConvBwdDataToGemm_v1
}
else
{
const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto ZTilde = ConvStrideD / GcdStrideDilationD;
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;
const auto ZDot = math::integer_divide_ceil(Z, ZTilde);
const auto YDot = math::integer_divide_ceil(Y, YTilde);
const auto XDot = math::integer_divide_ceil(X, XTilde);
const auto DTilde =
Do + math::integer_divide_ceil(ConvDilationD * (Z - I1), ConvStrideD);
const auto HTilde =
Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
const auto WTilde =
Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// only work on HTilde and WTilde that contribute to non-padding area of input tensor
const auto IDTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadD - ConvDilationD * (ZTilde - I1)), ConvStrideD);
math::max(I0, InLeftPadD_ - ConvDilationD_ * (ZTilde_ - I1)), ConvStrideD_);
const auto IHTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
math::max(I0, InLeftPadH_ - ConvDilationH_ * (YTilde_ - I1)), ConvStrideH_);
const auto IWTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
math::max(I0, InLeftPadW_ - ConvDilationW_ * (XTilde_ - I1)), ConvStrideW_);
const auto IDTildeSliceEnd = math::min(
DTilde, math::integer_divide_ceil(InLeftPadD + Di - I1, ConvStrideD) + I1);
DTilde_, math::integer_divide_ceil(InLeftPadD_ + Di_ - I1, ConvStrideD_) + I1);
const auto IHTildeSliceEnd = math::min(
HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
HTilde_, math::integer_divide_ceil(InLeftPadH_ + Hi_ - I1, ConvStrideH_) + I1);
const auto IWTildeSliceEnd = math::min(
WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
WTilde_, math::integer_divide_ceil(InLeftPadW_ + Wi_ - I1, ConvStrideW_) + I1);
const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin;
const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
// GemmK is different for each GEMM
const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde);
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
const auto ZDotSlice = math::integer_divide_ceil(Z_ - IdxZTilde_, ZTilde_);
const auto YDotSlice = math::integer_divide_ceil(Y_ - IdxYTilde_, YTilde_);
const auto XDotSlice = math::integer_divide_ceil(X_ - IdxXTilde_, XTilde_);
if constexpr(NDimSpatial == 2)
{
// A: output tensor
const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
out_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Ho, I0, I0),
make_pad_transform(Wo, I0, I0),
make_pass_through_transform(K)),
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Ho_, I0, I0),
make_pad_transform(Wo_, I0, I0),
make_pass_through_transform(K_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
out_n_hop_wop_k_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(YDot, HTilde),
make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, WTilde),
make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
make_pass_through_transform(K)),
make_pass_through_transform(N_),
make_embed_transform(make_tuple(YDot_, HTilde_),
make_tuple(-ConvDilationH_ / GcdStrideDilationH_, I1)),
make_embed_transform(make_tuple(XDot_, WTilde_),
make_tuple(-ConvDilationW_ / GcdStrideDilationW_, I1)),
make_pass_through_transform(K_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc =
transform_tensor_descriptor(
out_n_ydot_htilde_xdot_wtilde_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(K)),
make_tuple(make_pass_through_transform(N_),
make_slice_transform(YDot_, I0, YDotSlice),
make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice),
make_slice_transform(XDot_, I0, XDotSlice),
make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(K_)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
......@@ -357,8 +707,8 @@ struct TransformConvBwdDataToGemm_v1
const auto out_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor(
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K)),
make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice))),
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K_)),
make_merge_transform(make_tuple(N_, HTildeSlice, WTildeSlice))),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
......@@ -385,11 +735,11 @@ struct TransformConvBwdDataToGemm_v1
// A: output tensor
const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
out_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Do, I0, I0),
make_pad_transform(Ho, I0, I0),
make_pad_transform(Wo, I0, I0),
make_pass_through_transform(K)),
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Do_, I0, I0),
make_pad_transform(Ho_, I0, I0),
make_pad_transform(Wo_, I0, I0),
make_pass_through_transform(K_)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
......@@ -398,17 +748,17 @@ struct TransformConvBwdDataToGemm_v1
const auto out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc =
transform_tensor_descriptor(
out_n_hop_wop_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_tuple(make_pass_through_transform(N_),
make_embed_transform(
make_tuple(ZDot, DTilde),
make_tuple(-ConvDilationD / GcdStrideDilationD, I1)),
make_tuple(ZDot_, DTilde_),
make_tuple(-ConvDilationD_ / GcdStrideDilationD_, I1)),
make_embed_transform(
make_tuple(YDot, HTilde),
make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
make_tuple(YDot_, HTilde_),
make_tuple(-ConvDilationH_ / GcdStrideDilationH_, I1)),
make_embed_transform(
make_tuple(XDot, WTilde),
make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
make_pass_through_transform(K)),
make_tuple(XDot_, WTilde_),
make_tuple(-ConvDilationW_ / GcdStrideDilationW_, I1)),
make_pass_through_transform(K_)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
......@@ -424,14 +774,15 @@ struct TransformConvBwdDataToGemm_v1
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc =
transform_tensor_descriptor(
out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_slice_transform(ZDot, I0, ZDotSlice),
make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(K)),
make_tuple(
make_pass_through_transform(N_),
make_slice_transform(ZDot_, I0, ZDotSlice),
make_slice_transform(DTilde_, IDTildeSliceBegin, DTildeSlice),
make_slice_transform(YDot_, I0, YDotSlice),
make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice),
make_slice_transform(XDot_, I0, XDotSlice),
make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(K_)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
......@@ -452,8 +803,9 @@ struct TransformConvBwdDataToGemm_v1
const auto out_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor(
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc,
make_tuple(
make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K)),
make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice))),
make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K_)),
make_merge_transform(
make_tuple(N_, DTildeSlice, HTildeSlice, WTildeSlice))),
make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
......@@ -482,66 +834,31 @@ struct TransformConvBwdDataToGemm_v1
}
}
template <typename BLayout,
template <typename BLayout_ = BLayout,
typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) &&
(is_same_v<BLayout, tensor_layout::convolution::GKYXC> ||
is_same_v<BLayout, tensor_layout::convolution::GKZYXC>),
(is_same_v<BLayout_, tensor_layout::convolution::GKYXC> ||
is_same_v<BLayout_, tensor_layout::convolution::GKZYXC>),
bool>::type = false>
static auto MakeBDescriptor_BK0_N_BK1(
const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& /* out_g_n_k_wos_strides */,
const std::array<index_t, NDimSpatial + 3>& wei_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& /* wei_g_k_c_xs_strides */,
const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& /* in_g_n_c_wis_strides */,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& /* input_left_pads */,
const std::array<index_t, NDimSpatial>& /* input_right_pads */,
const std::array<index_t, NDimSpatial>& tildes)
__host__ __device__ auto MakeBDescriptor_BK0_N_BK1() const
{
index_t i_ztilde = tildes[ZIdx - NonSpatialDimsNum];
index_t i_ytilde = tildes[YIdx - NonSpatialDimsNum];
index_t i_xtilde = tildes[XIdx - NonSpatialDimsNum];
const index_t N = in_g_n_c_wis_lengths[1];
const index_t K = wei_g_k_c_xs_lengths[1];
const index_t C = wei_g_k_c_xs_lengths[2];
const index_t Do = NDimSpatial == 3 ? out_g_n_k_wos_lengths[DIdx] : 1;
const index_t Ho = out_g_n_k_wos_lengths[HIdx];
const index_t Wo = out_g_n_k_wos_lengths[WIdx];
const index_t Z = NDimSpatial == 3 ? wei_g_k_c_xs_lengths[ZIdx] : 1;
const index_t Y = wei_g_k_c_xs_lengths[YIdx];
const index_t X = wei_g_k_c_xs_lengths[XIdx];
const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum];
const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum];
const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum];
const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum];
const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum];
const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum];
// assume packed
// k_y_x_c for 2d or k_z_y_x_c for 3d
const auto wei_grid_desc = make_wei_grid_desc<BLayout>(K, Z, Y, X, C);
const auto wei_grid_desc = MakeWeiGridDesc();
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{
const index_t BK0 = math::integer_divide_ceil(K, BK1);
const index_t BK0 = math::integer_divide_ceil(K_, BK1);
// B: weight tensor
const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc =
transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)),
transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K_, C_)),
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(C)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, C), make_tuple(I0, I1));
make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, C_), make_tuple(I0, I1));
const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
......@@ -553,22 +870,10 @@ struct TransformConvBwdDataToGemm_v1
}
else
{
const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto ZTilde = ConvStrideD / GcdStrideDilationD;
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;
const auto ZDot = math::integer_divide_ceil(Z, ZTilde);
const auto YDot = math::integer_divide_ceil(Y, YTilde);
const auto XDot = math::integer_divide_ceil(X, XTilde);
// GemmK is different for each GEMM
const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde);
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
const auto ZDotSlice = math::integer_divide_ceil(Z_ - IdxZTilde_, ZTilde_);
const auto YDotSlice = math::integer_divide_ceil(Y_ - IdxYTilde_, YTilde_);
const auto XDotSlice = math::integer_divide_ceil(X_ - IdxXTilde_, XTilde_);
// B weight tensor
if constexpr(NDimSpatial == 2)
......@@ -576,23 +881,23 @@ struct TransformConvBwdDataToGemm_v1
const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
wei_grid_desc,
make_tuple(
make_pass_through_transform(K),
make_embed_transform(make_tuple(YDot, YTilde),
make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, XTilde),
make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
make_pass_through_transform(C)),
make_pass_through_transform(K_),
make_embed_transform(make_tuple(YDot_, YTilde_),
make_tuple(ConvStrideH_ / GcdStrideDilationH_, I1)),
make_embed_transform(make_tuple(XDot_, XTilde_),
make_tuple(ConvStrideW_ / GcdStrideDilationW_, I1)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto wei_k_ydotslice_xdotslice_c_grid_desc = transform_tensor_descriptor(
wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
make_tuple(make_pass_through_transform(K),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_freeze_transform(i_ytilde),
make_freeze_transform(i_xtilde),
make_pass_through_transform(C)),
make_tuple(make_pass_through_transform(K_),
make_slice_transform(YDot_, I0, YDotSlice),
make_slice_transform(XDot_, I0, XDotSlice),
make_freeze_transform(IdxYTilde_),
make_freeze_transform(IdxXTilde_),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<3>{},
......@@ -608,8 +913,8 @@ struct TransformConvBwdDataToGemm_v1
const auto wei_gemmk_gemmnraw_grid_desc = transform_tensor_descriptor(
wei_k_ydotslice_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K)),
make_pass_through_transform(C)),
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K_)),
make_pass_through_transform(C_)),
make_tuple(Sequence<1, 2, 0>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
......@@ -636,15 +941,17 @@ struct TransformConvBwdDataToGemm_v1
const auto wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc =
transform_tensor_descriptor(
wei_grid_desc,
make_tuple(
make_pass_through_transform(K),
make_embed_transform(make_tuple(ZDot, ZTilde),
make_tuple(ConvStrideD / GcdStrideDilationD, I1)),
make_embed_transform(make_tuple(YDot, YTilde),
make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, XTilde),
make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
make_pass_through_transform(C)),
make_tuple(make_pass_through_transform(K_),
make_embed_transform(
make_tuple(ZDot_, ZTilde_),
make_tuple(ConvStrideD_ / GcdStrideDilationD_, I1)),
make_embed_transform(
make_tuple(YDot_, YTilde_),
make_tuple(ConvStrideH_ / GcdStrideDilationH_, I1)),
make_embed_transform(
make_tuple(XDot_, XTilde_),
make_tuple(ConvStrideW_ / GcdStrideDilationW_, I1)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
......@@ -659,14 +966,14 @@ struct TransformConvBwdDataToGemm_v1
const auto wei_gemmk_zdotslice_ydotslice_xdotslice_c_grid_desc =
transform_tensor_descriptor(
wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc,
make_tuple(make_pass_through_transform(K),
make_slice_transform(ZDot, I0, ZDotSlice),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_freeze_transform(i_ztilde),
make_freeze_transform(i_ytilde),
make_freeze_transform(i_xtilde),
make_pass_through_transform(C)),
make_tuple(make_pass_through_transform(K_),
make_slice_transform(ZDot_, I0, ZDotSlice),
make_slice_transform(YDot_, I0, YDotSlice),
make_slice_transform(XDot_, I0, XDotSlice),
make_freeze_transform(IdxZTilde_),
make_freeze_transform(IdxYTilde_),
make_freeze_transform(IdxXTilde_),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<3>{},
......@@ -686,8 +993,9 @@ struct TransformConvBwdDataToGemm_v1
const auto wei_gemmk_gemmnraw_grid_desc = transform_tensor_descriptor(
wei_gemmk_zdotslice_ydotslice_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K)),
make_pass_through_transform(C)),
make_tuple(
make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K_)),
make_pass_through_transform(C_)),
make_tuple(Sequence<1, 2, 3, 0>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
......@@ -716,66 +1024,20 @@ struct TransformConvBwdDataToGemm_v1
}
}
template <typename CLayout,
typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) &&
(is_same_v<CLayout, tensor_layout::convolution::GNHWC> ||
is_same_v<CLayout, tensor_layout::convolution::GNDHWC> ||
is_same_v<CLayout, tensor_layout::convolution::NHWGC> ||
is_same_v<CLayout, tensor_layout::convolution::NDHWGC> ||
is_same_v<CLayout, tensor_layout::convolution::G_NHW_C>),
bool>::type = false>
static auto
MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& /* out_g_n_k_wos_strides */,
const std::array<index_t, NDimSpatial + 3>& wei_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& /* wei_g_k_c_xs_strides */,
const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const std::array<index_t, NDimSpatial>& tildes)
template <
typename CLayout_ = CLayout,
typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) &&
(is_same_v<CLayout_, tensor_layout::convolution::GNHWC> ||
is_same_v<CLayout_, tensor_layout::convolution::GNDHWC> ||
is_same_v<CLayout_, tensor_layout::convolution::NHWGC> ||
is_same_v<CLayout_, tensor_layout::convolution::NDHWGC> ||
is_same_v<CLayout_, tensor_layout::convolution::G_NHW_C>),
bool>::type = false>
__host__ __device__ auto MakeCDescriptor_M_N() const
{
index_t i_ztilde = tildes[ZIdx - NonSpatialDimsNum];
index_t i_ytilde = tildes[YIdx - NonSpatialDimsNum];
index_t i_xtilde = tildes[XIdx - NonSpatialDimsNum];
const index_t N = in_g_n_c_wis_lengths[1];
const index_t C = wei_g_k_c_xs_lengths[2];
const index_t Di = NDimSpatial == 3 ? in_g_n_c_wis_lengths[DIdx] : 1;
const index_t Hi = in_g_n_c_wis_lengths[HIdx];
const index_t Wi = in_g_n_c_wis_lengths[WIdx];
const index_t Do = NDimSpatial == 3 ? out_g_n_k_wos_lengths[DIdx] : 1;
const index_t Ho = out_g_n_k_wos_lengths[HIdx];
const index_t Wo = out_g_n_k_wos_lengths[WIdx];
const index_t Z = NDimSpatial == 3 ? wei_g_k_c_xs_lengths[ZIdx] : 1;
const index_t Y = wei_g_k_c_xs_lengths[YIdx];
const index_t X = wei_g_k_c_xs_lengths[XIdx];
const index_t InLeftPadD = input_left_pads[DIdx - NonSpatialDimsNum];
const index_t InLeftPadH = input_left_pads[HIdx - NonSpatialDimsNum];
const index_t InLeftPadW = input_left_pads[WIdx - NonSpatialDimsNum];
const index_t InRightPadD = input_right_pads[DIdx - NonSpatialDimsNum];
const index_t InRightPadH = input_right_pads[HIdx - NonSpatialDimsNum];
const index_t InRightPadW = input_right_pads[WIdx - NonSpatialDimsNum];
const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum];
const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum];
const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum];
const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum];
const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum];
const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum];
// assume strided
// n_hi_wi_c for 2d n_di_hi_wi_c for 3d
const auto in_grid_desc =
make_in_grid_desc<NDimSpatial, CLayout>(N, Di, Hi, Wi, C, in_g_n_c_wis_strides);
const auto in_grid_desc = MakeInGridDesc();
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
......@@ -787,10 +1049,10 @@ struct TransformConvBwdDataToGemm_v1
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
make_pass_through_transform(C)),
make_pass_through_transform(N_),
make_embed_transform(make_tuple(I1, Ho_), make_tuple(I1, ConvStrideH_)),
make_embed_transform(make_tuple(I1, Wo_), make_tuple(I1, ConvStrideW_)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
......@@ -798,8 +1060,8 @@ struct TransformConvBwdDataToGemm_v1
in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_freeze_transform(I0),
make_freeze_transform(I0),
make_merge_transform(make_tuple(N, Ho, Wo)),
make_pass_through_transform(C)),
make_merge_transform(make_tuple(N_, Ho_, Wo_)),
make_pass_through_transform(C_)),
make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}),
make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
......@@ -818,11 +1080,11 @@ struct TransformConvBwdDataToGemm_v1
const auto in_n_x_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(I1, Do), make_tuple(I1, ConvStrideD)),
make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
make_pass_through_transform(C)),
make_pass_through_transform(N_),
make_embed_transform(make_tuple(I1, Do_), make_tuple(I1, ConvStrideD_)),
make_embed_transform(make_tuple(I1, Ho_), make_tuple(I1, ConvStrideH_)),
make_embed_transform(make_tuple(I1, Wo_), make_tuple(I1, ConvStrideW_)),
make_pass_through_transform(C_)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{},
......@@ -836,8 +1098,8 @@ struct TransformConvBwdDataToGemm_v1
make_tuple(make_freeze_transform(I0),
make_freeze_transform(I0),
make_freeze_transform(I0),
make_merge_transform(make_tuple(N, Do, Ho, Wo)),
make_pass_through_transform(C)),
make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)),
make_pass_through_transform(C_)),
make_tuple(Sequence<1>{},
Sequence<3>{},
Sequence<5>{},
......@@ -861,36 +1123,21 @@ struct TransformConvBwdDataToGemm_v1
}
else
{
const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto ZTilde = ConvStrideD / GcdStrideDilationD;
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;
const auto DTilde =
Do + math::integer_divide_ceil(ConvDilationD * (Z - I1), ConvStrideD);
const auto HTilde =
Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
const auto WTilde =
Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// only work on DTilde, HTilde and WTilde that contribute to
// non-padding area of input tensor
const auto IDTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadD - ConvDilationD * (ZTilde - I1)), ConvStrideD);
math::max(I0, InLeftPadD_ - ConvDilationD_ * (ZTilde_ - I1)), ConvStrideD_);
const auto IHTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
math::max(I0, InLeftPadH_ - ConvDilationH_ * (YTilde_ - I1)), ConvStrideH_);
const auto IWTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
math::max(I0, InLeftPadW_ - ConvDilationW_ * (XTilde_ - I1)), ConvStrideW_);
const auto IDTildeSliceEnd = math::min(
DTilde, math::integer_divide_ceil(InLeftPadD + Di - I1, ConvStrideD) + I1);
DTilde_, math::integer_divide_ceil(InLeftPadD_ + Di_ - I1, ConvStrideD_) + I1);
const auto IHTildeSliceEnd = math::min(
HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
HTilde_, math::integer_divide_ceil(InLeftPadH_ + Hi_ - I1, ConvStrideH_) + I1);
const auto IWTildeSliceEnd = math::min(
WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
WTilde_, math::integer_divide_ceil(InLeftPadW_ + Wi_ - I1, ConvStrideW_) + I1);
const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin;
const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
......@@ -901,34 +1148,34 @@ struct TransformConvBwdDataToGemm_v1
{
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc =
transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(YTilde, HTilde),
make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(XTilde, WTilde),
make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(YTilde_, HTilde_),
make_tuple(ConvDilationH_, ConvStrideH_)),
make_embed_transform(make_tuple(XTilde_, WTilde_),
make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor(
in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_freeze_transform(i_ytilde),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_freeze_transform(i_xtilde),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(C)),
make_tuple(make_pass_through_transform(N_),
make_freeze_transform(IdxYTilde_),
make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice),
make_freeze_transform(IdxXTilde_),
make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
......@@ -944,8 +1191,8 @@ struct TransformConvBwdDataToGemm_v1
const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor(
in_n_htildeslice_wtildeslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
make_pass_through_transform(C)),
make_tuple(make_merge_transform(make_tuple(N_, HTildeSlice, WTildeSlice)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
......@@ -961,11 +1208,11 @@ struct TransformConvBwdDataToGemm_v1
{
const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Di, InLeftPadD, InRightPadD),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Di_, InLeftPadD_, InRightPadD_),
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(C_)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
......@@ -974,14 +1221,14 @@ struct TransformConvBwdDataToGemm_v1
const auto in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc =
transform_tensor_descriptor(
in_n_dip_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(ZTilde, DTilde),
make_tuple(ConvDilationD, ConvStrideD)),
make_embed_transform(make_tuple(YTilde, HTilde),
make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(XTilde, WTilde),
make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(ZTilde_, DTilde_),
make_tuple(ConvDilationD_, ConvStrideD_)),
make_embed_transform(make_tuple(YTilde_, HTilde_),
make_tuple(ConvDilationH_, ConvStrideH_)),
make_embed_transform(make_tuple(XTilde_, WTilde_),
make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
......@@ -996,14 +1243,14 @@ struct TransformConvBwdDataToGemm_v1
const auto in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc =
transform_tensor_descriptor(
in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_freeze_transform(i_ztilde),
make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice),
make_freeze_transform(i_ytilde),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_freeze_transform(i_xtilde),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(C)),
make_tuple(make_pass_through_transform(N_),
make_freeze_transform(IdxZTilde_),
make_slice_transform(DTilde_, IDTildeSliceBegin, DTildeSlice),
make_freeze_transform(IdxYTilde_),
make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice),
make_freeze_transform(IdxXTilde_),
make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
......@@ -1024,8 +1271,8 @@ struct TransformConvBwdDataToGemm_v1
const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor(
in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc,
make_tuple(
make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)),
make_pass_through_transform(C)),
make_merge_transform(make_tuple(N_, DTildeSlice, HTildeSlice, WTildeSlice)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
......@@ -1044,84 +1291,41 @@ struct TransformConvBwdDataToGemm_v1
}
// for input bias
template <typename CLayout,
template <typename CLayout_ = CLayout,
typename std::enable_if<NDimSpatial == 2 &&
(is_same_v<CLayout, tensor_layout::convolution::GC> ||
is_same_v<CLayout, tensor_layout::convolution::G_C>),
(is_same_v<CLayout_, tensor_layout::convolution::GC> ||
is_same_v<CLayout_, tensor_layout::convolution::G_C>),
bool>::type = false>
static auto
MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& /* out_g_n_k_wos_strides */,
const std::array<index_t, NDimSpatial + 3>& wei_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& /* wei_g_k_c_xs_strides */,
const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& /* in_g_n_c_wis_strides */,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& /* input_right_pads */,
const std::array<index_t, NDimSpatial>& /* tildes */)
__host__ __device__ auto MakeCDescriptor_M_N() const
{
const index_t N = in_g_n_c_wis_lengths[1];
const index_t C = wei_g_k_c_xs_lengths[2];
const index_t Hi = in_g_n_c_wis_lengths[3];
const index_t Wi = in_g_n_c_wis_lengths[4];
const index_t Ho = out_g_n_k_wos_lengths[3];
const index_t Wo = out_g_n_k_wos_lengths[4];
const index_t Y = wei_g_k_c_xs_lengths[3];
const index_t X = wei_g_k_c_xs_lengths[4];
const index_t InLeftPadH = input_left_pads[0];
const index_t InLeftPadW = input_left_pads[1];
const index_t ConvStrideH = conv_filter_strides[0];
const index_t ConvStrideW = conv_filter_strides[1];
const index_t ConvDilationH = conv_filter_dilations[0];
const index_t ConvDilationW = conv_filter_dilations[1];
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{
const auto in_gemmm_gemmn_grid_desc =
make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, C), make_tuple(I0, I1));
make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, C_), make_tuple(I0, I1));
return in_gemmm_gemmn_grid_desc;
}
else
{
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;
const auto HTilde =
Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
const auto WTilde =
Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// only work on HTilde and WTilde that contribute to non-padding area of input tensor
const auto IHTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
math::max(I0, InLeftPadH_ - ConvDilationH_ * (YTilde_ - I1)), ConvStrideH_);
const auto IWTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
math::max(I0, InLeftPadW_ - ConvDilationW_ * (XTilde_ - I1)), ConvStrideW_);
const auto IHTildeSliceEnd = math::min(
HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
HTilde_, math::integer_divide_ceil(InLeftPadH_ + Hi_ - I1, ConvStrideH_) + I1);
const auto IWTildeSliceEnd = math::min(
WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
WTilde_, math::integer_divide_ceil(InLeftPadW_ + Wi_ - I1, ConvStrideW_) + I1);
const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
// bias tensor
const auto in_gemmmraw_gemmnraw_grid_desc = make_naive_tensor_descriptor(
make_tuple(N * HTildeSlice * WTildeSlice, C), make_tuple(I0, I1));
make_tuple(N_ * HTildeSlice * WTildeSlice, C_), make_tuple(I0, I1));
const auto in_gemmm_gemmn_grid_desc = ck::tensor_operation::device::PadTensorDescriptor(
in_gemmmraw_gemmnraw_grid_desc,
......@@ -1131,6 +1335,25 @@ struct TransformConvBwdDataToGemm_v1
return in_gemmm_gemmn_grid_desc;
}
}
IndexType N_;
IndexType Di_, Hi_, Wi_;
IndexType Do_, Ho_, Wo_;
IndexType Z_, Y_, X_;
IndexType K_, C_;
IndexType DiStride_, HiStride_, WiStride_;
IndexType DoStride_, HoStride_, WoStride_;
IndexType CStrideTensorB_, CStrideTensorC_, KStrideTensorA_, KStrideTensorB_;
IndexType NStrideTensorA_, NStrideTensorC_;
IndexType ConvStrideD_, ConvStrideH_, ConvStrideW_;
IndexType ConvDilationD_, ConvDilationH_, ConvDilationW_;
IndexType InLeftPadD_, InLeftPadH_, InLeftPadW_;
IndexType InRightPadD_, InRightPadH_, InRightPadW_;
IndexType IdxZTilde_, IdxYTilde_, IdxXTilde_;
IndexType GcdStrideDilationD_, GcdStrideDilationH_, GcdStrideDilationW_;
IndexType ZTilde_, YTilde_, XTilde_;
IndexType DTilde_, HTilde_, WTilde_;
IndexType ZDot_, YDot_, XDot_;
};
} // namespace tensor_operation
......
......@@ -549,8 +549,10 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, f8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, f8_fnuz_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, bf8_fnuz_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, fp8_storage_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
"wrong! not implemented");
......@@ -843,8 +845,8 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
#else
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>(
src_wave_buffer_resource, src_thread_addr_offset, 0);
vector_t tmp{amd_buffer_load_impl<scalar_t, vector_size, coherence>(
src_wave_buffer_resource, src_thread_addr_offset, 0)};
return src_thread_element_valid ? tmp : vector_t(0);
#endif
}
......@@ -873,8 +875,8 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>(
src_wave_buffer_resource, src_thread_addr_offset, 0);
vector_t tmp{amd_buffer_load_impl<scalar_t, vector_size, coherence>(
src_wave_buffer_resource, src_thread_addr_offset, 0)};
return src_thread_element_valid ? tmp : vector_t(customized_value);
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/random_gen.hpp"
#include "ck/utility/type.hpp"
#ifdef CK_USE_FNUZ_FP8
#define CK_USE_FNUZ_FP8 1
#else
#define CK_USE_FNUZ_FP8 0
#endif
#ifdef CK_USE_OCP_FP8
#define CK_USE_OCP_FP8 1
#else
#define CK_USE_OCP_FP8 0
#endif
namespace ck {
using f8_fnuz_t = _BitInt(8);
using bf8_fnuz_t = unsigned _BitInt(8);
#if(defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx1200__) || \
defined(__gfx1201__)) && \
__HIP_DEVICE_COMPILE__
#define CK_FP8_CVT_FAST_PATH 1
#else
#define CK_FP8_CVT_FAST_PATH 0
#endif
#if(defined(__gfx1200__) || defined(__gfx1201__)) && __HIP_DEVICE_COMPILE__
#define CK_OCP_FP8_CVT_FAST_PATH 1
#else
#define CK_OCP_FP8_CVT_FAST_PATH 0
#endif
typedef unsigned char fp8_storage_t;
/**
* \brief Describes FP8 interpretation
*/
enum class ck_fp8_interpretation_t
{
CK_E4M3_OCP = 0, // OCP E4M3
CK_E5M2_OCP = 1, // OCP E5M2
CK_E4M3_FNUZ = 2, // FP8
CK_E5M2_FNUZ = 3, // BF8
};
/**
* \brief Describes saturation behavior
*/
enum class ck_saturation_t
{
CK_NOSAT = 0, // No saturation - replace with NaN or Inf
CK_SATFINITE = 1, // Saturate to finite
};
namespace fp8_impl {
typedef fp8_storage_t fp8x2_storage_t __attribute__((ext_vector_type(2)));
typedef float float2_t __attribute__((ext_vector_type(2)));
__host__ __device__ static inline constexpr bool fnuz_f8_is_nan(f8_fnuz_t a)
{
return static_cast<unsigned char>(a) == 0x80;
}
__host__ __device__ static inline constexpr bool fnuz_bf8_is_nan(bf8_fnuz_t a)
{
return static_cast<unsigned char>(a) == 0x80;
}
__host__ __device__ static inline constexpr bool ocp_f8_is_nan(fp8_storage_t a)
{
return (a & 0x7f) == 0x7f;
}
__host__ __device__ static inline constexpr bool ocp_bf8_is_nan(fp8_storage_t a)
{
return (a & 0x7f) > 0x7c;
}
// The conversion function is from rocblas
// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_hip_f8_impl.h#L220
// This has been modified to handle double types as well
template <typename T, int wm, int we, bool is_fnuz, bool clip = false>
__host__ __device__ static inline T cast_from_f8(fp8_storage_t x)
{
constexpr bool is_half = __hip_internal::is_same<T, _Float16>::value;
constexpr bool is_float = __hip_internal::is_same<T, float>::value;
constexpr bool is_double = __hip_internal::is_same<T, double>::value;
static_assert(is_half || is_float || is_double, "only half, float and double are supported");
constexpr int weo = is_half ? 5 : (is_float ? 8 : 11);
constexpr int wmo = is_half ? 10 : (is_float ? 23 : 52);
T fInf, fNegInf, fNaN, fNeg0, fmax, fmin;
if constexpr(is_half)
{
const unsigned short int ihInf = 0x7C00;
const unsigned short int ihNegInf = 0xFC00;
const unsigned short int ihNaN = 0x7C01;
const unsigned short int ihNeg0 = 0x8000;
/* Max number in e5m2 57344*/
const unsigned short int ifmax = 0x7B00;
const unsigned short int ifmin = 0xFB00;
fInf = bit_cast<_Float16>(ihInf);
fNegInf = bit_cast<_Float16>(ihNegInf);
fNaN = bit_cast<_Float16>(ihNaN);
fNeg0 = bit_cast<_Float16>(ihNeg0);
fmax = bit_cast<_Float16>(ifmax);
fmin = bit_cast<_Float16>(ifmin);
}
else if constexpr(is_float)
{
const unsigned int ifInf = 0x7F800000;
const unsigned int ifNegInf = 0xFF800000;
const unsigned int ifNaN = 0x7F800001;
const unsigned int ifNeg0 = 0x80000000;
/* Max number in e5m2 57344*/
const unsigned int ifmax = 0x47600000;
const unsigned int ifmin = 0xC7600000;
fInf = bit_cast<float>(ifInf);
fNegInf = bit_cast<float>(ifNegInf);
fNaN = bit_cast<float>(ifNaN);
fNeg0 = bit_cast<float>(ifNeg0);
fmax = bit_cast<float>(ifmax);
fmin = bit_cast<float>(ifmin);
}
else if constexpr(is_double)
{
const unsigned long long ifInf = 0x7FF0000000000000ull;
const unsigned long long ifNegInf = 0xFFF0000000000000ull;
const unsigned long long ifNaN = 0x7FF0000000000001ull;
const unsigned long long ifNeg0 = 0x8000000000000000ull;
/* Max number in e5m2 57344*/
const unsigned long long ifmax = 0x40EC000000000000ull;
const unsigned long long ifmin = 0xC0EC000000000000ull;
fInf = bit_cast<double>(ifInf);
fNegInf = bit_cast<double>(ifNegInf);
fNaN = bit_cast<double>(ifNaN);
fNeg0 = bit_cast<double>(ifNeg0);
fmax = bit_cast<double>(ifmax);
fmin = bit_cast<double>(ifmin);
}
if(x == 0)
{
return 0;
}
unsigned long long sign = x >> 7;
unsigned long long mantissa = x & ((1 << wm) - 1);
int exponent = (x & 0x7F) >> wm;
if constexpr(is_fnuz)
{
if(x == 0x80)
{
return fNaN;
}
}
else
{
if(x == 0x80)
{
return fNeg0;
}
if constexpr(we == 4)
{ // e4m3
if((x & 0x7F) == 0x7F)
{
return fNaN;
}
}
else if((x & 0x7C) == 0x7C)
{ // e5m2
if((x & 0x3) == 0)
{
if constexpr(clip)
{
return sign ? fmin : fmax;
}
return sign ? fNegInf : fInf;
}
return fNaN;
}
}
typename __hip_internal::conditional<
sizeof(T) == 2,
unsigned short int,
typename __hip_internal::conditional<sizeof(T) == 4, unsigned int, unsigned long long>::
type>::type retval;
if constexpr(we == 5 && is_half && !is_fnuz)
{
retval = x << 8;
return bit_cast<T>(retval);
}
const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (is_fnuz ? 1 : 0);
// subnormal input
if(exponent == 0)
{
#if defined(__HIP_DEVICE_COMPILE__) && __HIP_DEVICE_COMPILE__
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
int sh = 1 + __clz(mantissa) - (32 - wm);
#else
int sh = 1 + __builtin_clz(mantissa) - (32 - wm);
#endif
mantissa <<= sh;
exponent += 1 - sh;
mantissa &= ((1ull << wm) - 1);
}
exponent += exp_low_cutoff - 1;
mantissa <<= wmo - wm;
// subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
if(exponent <= 0)
{
mantissa |= 1 << wmo;
mantissa >>= 1 - exponent;
exponent = 0;
}
if constexpr(sizeof(T) == 2)
retval = (sign << 15) | (exponent << 10) | mantissa;
else if constexpr(sizeof(T) == 4)
retval = (sign << 31) | (exponent << 23) | mantissa;
else
retval = (sign << 63) | (static_cast<unsigned long long>(exponent) << 52) | mantissa;
return bit_cast<T>(retval);
}
#if CK_FP8_CVT_FAST_PATH
template <ck_fp8_interpretation_t interpret>
static __device__ float cast_to_f32_from_f8(fp8_storage_t v)
{
union
{
unsigned int i32val;
unsigned char i8val[4];
} val;
val.i8val[0] = v;
static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ ||
interpret == ck_fp8_interpretation_t::CK_E4M3_OCP ||
interpret == ck_fp8_interpretation_t::CK_E5M2_FNUZ ||
interpret == ck_fp8_interpretation_t::CK_E5M2_OCP,
"Only FNUZ and OCP interpretations are supported");
if constexpr((interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) ||
(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP))
{
return __builtin_amdgcn_cvt_f32_fp8(val.i32val, 0);
}
else
{
return __builtin_amdgcn_cvt_f32_bf8(val.i32val, 0);
}
}
template <ck_fp8_interpretation_t interpret>
static __device__ float2_t cast_to_f32x2_from_f8x2(fp8x2_storage_t v)
{
const auto i16val = bit_cast<uint16_t>(v);
static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ ||
interpret == ck_fp8_interpretation_t::CK_E4M3_OCP ||
interpret == ck_fp8_interpretation_t::CK_E5M2_FNUZ ||
interpret == ck_fp8_interpretation_t::CK_E5M2_OCP,
"Only FNUZ and OCP interpretations are supported");
if constexpr((interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) ||
(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP))
{
return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, false);
}
else
{
return __builtin_amdgcn_cvt_pk_f32_bf8(i16val, false);
}
}
#endif
} // namespace fp8_impl
struct f8_ocp_t
{
using data_type = fp8_storage_t;
data_type data;
static constexpr ck_saturation_t default_saturation = ck_saturation_t::CK_SATFINITE;
static constexpr ck_fp8_interpretation_t default_interpret =
ck_fp8_interpretation_t::CK_E4M3_OCP;
static constexpr unsigned int we = 4; // exponent width
static constexpr unsigned int wm = 3; // mantissa width
__host__ __device__ constexpr bool operator==(const f8_ocp_t& other) const
{
return (data == other.data) && (fp8_impl::ocp_f8_is_nan(data) == false); // NaN != NaN
}
#if CK_USE_OCP_FP8
__host__ __device__ explicit operator float() const
#else
__host__ explicit operator float() const
#endif
{
#if CK_OCP_FP8_CVT_FAST_PATH
return fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data);
#else
return fp8_impl::cast_from_f8<float, wm, we, false>(
this->data); // XXX: clip==false must be consistent with operator _Float16
#endif
}
#if CK_USE_OCP_FP8
__host__ __device__ explicit operator _Float16() const
#else
__host__ explicit operator _Float16() const
#endif
{
#if CK_OCP_FP8_CVT_FAST_PATH
return static_cast<_Float16>(fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data));
#else
return fp8_impl::cast_from_f8<_Float16, wm, we, false>(
this->data); // XXX: clip==false must be consistent with operator float
#endif
}
};
struct bf8_ocp_t
{
using data_type = fp8_storage_t;
data_type data;
static constexpr ck_saturation_t default_saturation = ck_saturation_t::CK_SATFINITE;
static constexpr ck_fp8_interpretation_t default_interpret =
ck_fp8_interpretation_t::CK_E5M2_OCP;
static constexpr unsigned int we = 5; // exponent width
static constexpr unsigned int wm = 2; // mantissa width
__host__ __device__ constexpr bool operator==(const bf8_ocp_t& other) const
{
return (data == other.data) && (fp8_impl::ocp_bf8_is_nan(data) == false); // NaN != NaN
}
#if CK_USE_OCP_FP8
__host__ __device__ explicit operator float() const
#else
__host__ explicit operator float() const
#endif
{
#if defined(__gfx1200__) || defined(__gfx1201__)
return fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data);
#else
return fp8_impl::cast_from_f8<float, wm, we, false>(
this->data); // XXX: clip==false must be consistent with operator _Float16
#endif
}
#if CK_USE_OCP_FP8
__host__ __device__ explicit operator _Float16() const
#else
__host__ explicit operator _Float16() const
#endif
{
#if defined(__gfx1200__) || defined(__gfx1201__)
return static_cast<_Float16>(fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data));
#else
return fp8_impl::cast_from_f8<_Float16, wm, we, false>(
this->data); // XXX: clip==false must be consistent with operator float
#endif
}
};
template <typename T>
__host__ __device__ static inline constexpr bool fp8_is_nan(T);
template <>
__host__ __device__ inline constexpr bool fp8_is_nan(f8_ocp_t a)
{
return fp8_impl::ocp_f8_is_nan(a.data);
}
template <>
__host__ __device__ inline constexpr bool fp8_is_nan(bf8_ocp_t a)
{
return fp8_impl::ocp_bf8_is_nan(a.data);
}
template <>
__host__ __device__ inline constexpr bool fp8_is_nan(f8_fnuz_t a)
{
return fp8_impl::fnuz_f8_is_nan(a);
}
template <>
__host__ __device__ inline constexpr bool fp8_is_nan(bf8_fnuz_t a)
{
return fp8_impl::fnuz_bf8_is_nan(a);
}
template <typename T,
std::enable_if_t<std::is_same_v<T, bf8_ocp_t> || std::is_same_v<T, f8_ocp_t> ||
std::is_same_v<T, bf8_fnuz_t> || std::is_same_v<T, f8_fnuz_t>,
bool> = true>
__host__ __device__ static inline constexpr bool fp8_is_inf(T)
{
return false;
}
template <>
__host__ __device__ inline constexpr bool fp8_is_inf(bf8_ocp_t a)
{
return (a.data & 0x7f) == 0x7c;
}
namespace fp8_impl {
// Assertions to check for supported conversion types
#define __assert_ocp_support(interp) \
{ \
if(interp != ck_fp8_interpretation_t::CK_E4M3_OCP && \
interp != ck_fp8_interpretation_t::CK_E5M2_OCP) \
{ \
__hip_assert(false && "type is unsupported by current target device"); \
} \
}
#define __assert_fnuz_support(interp) \
{ \
if(interp != ck_fp8_interpretation_t::CK_E4M3_FNUZ && \
interp != ck_fp8_interpretation_t::CK_E5M2_FNUZ) \
{ \
__hip_assert(false && "type is unsupported by current target device"); \
} \
}
__host__ __device__ static inline void
__is_interpret_supported([[maybe_unused]] ck_fp8_interpretation_t interp)
{
#if defined(__HIP_DEVICE_COMPILE__) && __HIP_DEVICE_COMPILE__
#if CK_USE_OCP_FP8
__assert_ocp_support(interp);
#endif
#if CK_USE_FNUZ_FP8
__assert_fnuz_support(interp);
#endif
#endif
}
#if CK_FP8_CVT_FAST_PATH
// The conversion function is from rocblas
// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_float8.h#L79
template <ck_fp8_interpretation_t interpret, bool saturate, bool stochastic_rounding = false>
static __device__ fp8_storage_t cast_to_f8_from_f32(float v, unsigned int rng = 0)
{
fp8_storage_t i8data;
union
{
float fval;
unsigned int i32val;
unsigned char i8val[4]; // NOTE: not endian independent
} val;
unsigned int ival = 0;
val.fval = v;
if constexpr(saturate)
{
if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ)
{
if((val.i32val & 0x7F800000) != 0x7F800000)
{ /// propagate NAN/INF, no clipping
val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
}
}
else if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
{ // OCP type
if((val.i32val & 0x7F800000) != 0x7F800000)
{ /// propagate NAN/INF, no clipping
val.fval = __builtin_amdgcn_fmed3f(val.fval, 448.0, -448.0);
}
}
else
{
if((val.i32val & 0x7F800000) != 0x7F800000)
{ /// propagate NAN/INF, no clipping
val.fval = __builtin_amdgcn_fmed3f(val.fval, 57344.0, -57344.0);
}
}
}
if constexpr(stochastic_rounding)
{
ival = (interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) ||
(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
? __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0)
: __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
val.i32val = ival;
i8data = val.i8val[0]; // little endian
}
else
{ // RNE CVT
ival = (interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) ||
(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
? __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false)
: __builtin_amdgcn_cvt_pk_bf8_f32(val.fval,
val.fval,
ival,
false); // false -> WORD0
val.i32val = ival;
i8data = val.i8val[0];
}
return i8data;
}
#endif // CK_FP8_CVT_FAST_PATH
// The conversion function is from rocblas
// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_hip_f8_impl.h#L39
// This has been modified to add double types conversion as well
template <typename T, int wm, int we, bool is_fnuz, bool clip = false, bool stoch = false>
__host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rng = 0)
{
constexpr bool is_half = __hip_internal::is_same<T, _Float16>::value;
constexpr bool is_float = __hip_internal::is_same<T, float>::value;
constexpr bool is_double = __hip_internal::is_same<T, double>::value;
static_assert(is_half || is_float || is_double,
"Only half, float and double can be cast to f8");
constexpr int mfmt = (sizeof(T) == 8) ? 52 : ((sizeof(T) == 4) ? 23 : 10);
using T_bitwise = typename __hip_internal::conditional<
sizeof(T) == 2,
unsigned short int,
typename __hip_internal::conditional<sizeof(T) == 4, unsigned int, unsigned long long>::
type>::type;
T_bitwise x_bitwise = bit_cast<T_bitwise>(_x);
unsigned long long x{x_bitwise};
unsigned long long head, mantissa;
int exponent, bias;
unsigned int sign;
unsigned long long fInf, mask;
if constexpr(sizeof(T) == 8)
{
head = x & 0xFFF0000000000000ull;
mantissa = x & 0xFFFFFFFFFFFFFull;
exponent = (head >> 52) & 0x7FF;
sign = head >> 63;
bias = 1023;
fInf = 0x7FF0000000000000ull;
mask = 0x7FFFFFFFFFFFFFFFull;
}
else if constexpr(sizeof(T) == 4)
{
head = x & 0xFF800000;
mantissa = x & 0x7FFFFF;
exponent = (head >> 23) & 0xFF;
sign = head >> 31;
bias = 127;
fInf = 0x7F800000;
mask = 0x7FFFFFFF;
}
else
{
head = x & 0xFC00;
mantissa = x & 0x3FF;
exponent = (head >> 10) & 0x1F;
sign = head >> 15;
bias = 15;
fInf = 0x7C00;
mask = 0x7FFF;
}
unsigned int signed_inf = 0;
unsigned int nan = 0;
if constexpr(is_fnuz)
{
signed_inf = clip ? ((sign << 7) + 0x7f) : 0x80;
nan = 0x80;
}
else
{
if constexpr(we == 4)
{ // e4m3
signed_inf = (sign << 7) + (clip ? 0x7e : 0x7f);
}
else
{ // e5m2
signed_inf = (sign << 7) + (clip ? 0x7b : 0x7c);
}
nan = (sign << 7) + 0x7f;
}
// Max values
unsigned long long ifmax = 0;
if constexpr(sizeof(T) == 8)
{
if constexpr(we == 5)
{ // 57344
ifmax = 0x40EC000000000000ull;
}
else
{
if constexpr(is_fnuz)
{ // 240
ifmax = 0x406E000000000000ull;
}
else
{ // 448
ifmax = 0x407C000000000000ull;
}
}
}
else if(sizeof(T) == 4)
{
if constexpr(we == 5)
{
ifmax = 0x47600000;
}
else
{
if constexpr(is_fnuz)
{
ifmax = 0x43700000;
}
else
{
ifmax = 0x43E00000;
}
}
}
else
{
if constexpr(we == 5)
{
ifmax = 0x7B00;
}
else
{
if constexpr(is_fnuz)
{
ifmax = 0x5B80;
}
else
{
ifmax = 0x5F00;
}
}
}
// Deal with inf and NaNs
if((x & fInf) == fInf)
{
if constexpr(is_fnuz)
return signed_inf;
return mantissa != 0 ? nan : signed_inf;
}
if((x & mask) > ifmax)
{
return signed_inf;
}
if(x == 0)
{
return 0;
}
// First need to check if it is normal or denorm as there is a difference of
// implicit 1 Then need to adjust the exponent to align with the F8 exponent,
// in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng
// to mantissa and truncate. And for RNE, no need to add rng. Then probably
// need to check whether there is carry and adjust exponent and mantissa again
// For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent
// bits
const int f8_bias = (1 << (we - 1)) - 1 + (is_fnuz ? 1 : 0);
const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
// f8_exponent is the converted f8 exponent with bias encoding
// exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
// the difference needs to be adjusted and mantissa shifted
int act_exponent, f8_exponent, exponent_diff;
if(exponent == 0)
{ // fp32/fp16 is in denormal.
/* fp32 denormal is below 2^-127 so it is usually not a concern here, we
mostly concern fp16 here. In this case, f8 is usually in denormal. But there
could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has
exponent bias 16. It means that there are some numbers in fp16 denormal but they
are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
(NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */
act_exponent = exponent - bias + 1;
exponent_diff = f8_denormal_act_exponent -
act_exponent; // actual exponent is exponent-bias+1 as it is denormal
}
else
{ // fp32/fp16 is normal with implicit 1
act_exponent = exponent - bias;
if(act_exponent <= f8_denormal_act_exponent)
{
/* This is the case where fp32/fp16 is normal but it is in f8 denormal
range. For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16
actual exponent is -7, it is actually larger due to the implicit 1,
Therefore it needs to be adjust to -6 and mantissa shift right by 1.
So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
exponent_diff = f8_denormal_act_exponent - act_exponent;
}
else
{ // both fp32/fp16 and f8 are in normal range
exponent_diff = 0; // exponent_diff=0 does not mean there is no difference
// for this case, act_exponent could be larger. Just
// that it does not need shift mantissa
}
mantissa += (1ull << mfmt); // Add the implicit 1 into mantissa
}
bool midpoint = (mantissa & ((1ull << (mfmt - wm + exponent_diff)) - 1)) ==
(1ull << (mfmt - wm + exponent_diff - 1));
/* This part is a bit tricky. The judgment of whether it is a tie needs to be
done before we shift right as shift right could rip off some residual part and
make something not midpoint look like midpoint. For example, the fp16 number
0x1002 (0 00100 0000000010), it is larger than midpoint, but after shift right
by 4 bits, it would look like midpoint.
*/
if(exponent_diff > 0)
mantissa >>= exponent_diff;
else if(exponent_diff == -1)
mantissa <<= -exponent_diff;
bool implicit_one = mantissa & (1ull << mfmt);
// if there is no implicit 1, it means the f8 is denormal and need to adjust
// to denorm exponent
f8_exponent =
(act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1);
// Now we have the exponent and mantissa adjusted
unsigned long long drop_mask = (1ull << (mfmt - wm)) - 1;
bool odd =
mantissa & (1ull << (mfmt - wm)); // if the least significant bit that is not truncated is 1
mantissa +=
(stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1ull) : mantissa)) & drop_mask;
// Now we deal with overflow
if(f8_exponent == 0)
{
if((1ull << mfmt) & mantissa)
{
f8_exponent = 1; // denormal overflow to become normal, promote exponent
}
}
else
{
if((1ull << (mfmt + 1)) & mantissa)
{
mantissa >>= 1;
f8_exponent++;
}
}
mantissa >>= (mfmt - wm);
// above range: quantize to maximum possible float of the same sign
const int max_exp = (1 << we) - 1;
if(f8_exponent > max_exp)
{
if constexpr(clip)
{
mantissa = (1 << wm) - 1;
f8_exponent = max_exp;
}
else
{
return signed_inf;
}
}
if(f8_exponent == 0 && mantissa == 0)
return is_fnuz ? 0 : (sign << 7);
mantissa &= (1 << wm) - 1;
return (sign << 7) | (f8_exponent << wm) | mantissa;
}
/**
* \brief convert float to @p fp8_storage_t
*
* \tparam interp interpretation of fp8
* \tparam sat saturation of fp8
* \param f float number
* \return fp8_storage_t
*/
template <ck_fp8_interpretation_t interp,
ck_saturation_t sat = ck_saturation_t::CK_SATFINITE,
bool stochastic_rounding = false>
#if CK_FP8_CVT_FAST_PATH
__host__ __device__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
{
__is_interpret_supported(interp);
uint32_t rng = 0;
if constexpr(stochastic_rounding)
{
constexpr int seed = 1254739;
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
}
return cast_to_f8_from_f32<interp, sat == ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
f, rng);
#else
#if CK_USE_OCP_FP8
__host__ __device__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
{
#else
__host__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
{
#endif
uint32_t rng = 0;
if constexpr(stochastic_rounding)
{
constexpr int seed = 1254739;
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
}
if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_FNUZ)
{
return cast_to_f8<float,
3,
4,
true,
sat == ck_saturation_t::CK_SATFINITE,
stochastic_rounding>(f, rng);
}
else if constexpr(interp == ck_fp8_interpretation_t::CK_E5M2_FNUZ)
{
return cast_to_f8<float,
2,
5,
true,
sat == ck_saturation_t::CK_SATFINITE,
stochastic_rounding>(f, rng);
}
else if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_OCP)
{
return cast_to_f8<float,
3,
4,
false,
sat == ck_saturation_t::CK_SATFINITE,
stochastic_rounding>(f, rng);
}
else if constexpr(interp == ck_fp8_interpretation_t::CK_E5M2_OCP)
{
return cast_to_f8<float,
2,
5,
false,
sat == ck_saturation_t::CK_SATFINITE,
stochastic_rounding>(f, rng);
}
else
{
__hip_assert(false && "FP8 type is not supported by current target device");
return 0;
}
#endif // CK_FP8_CVT_FAST_PATH
}
/**
* \brief convert _Float16 to @p fp8_storage_t
*
* \tparam sat saturation of fp8
* \tparam interp interpretation of fp8
* \tparam stochastic_rounding switch between RNE and SR
* \param x _Float16 value
* \return fp8_storage_t
*/
template <ck_fp8_interpretation_t interp,
ck_saturation_t sat = ck_saturation_t::CK_SATFINITE,
bool stochastic_rounding = false>
#if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8
__host__ __device__ static inline fp8_storage_t cvt_half_t_to_fp8(const _Float16 x)
#else
__host__ static inline fp8_storage_t cvt_half_t_to_fp8(const _Float16 x)
#endif
{
return cvt_float_to_fp8<interp, sat, stochastic_rounding>(static_cast<float>(x));
}
} // namespace fp8_impl
// Declare a template function for fp8 conversion using RNE
template <typename Y, typename X>
__host__ __device__ constexpr Y f8_convert_rne(X x);
// convert fp32 to fp8 with rounding to nearest even
template <>
inline __host__ __device__ f8_ocp_t f8_convert_rne<f8_ocp_t, float>(float x)
{
return f8_ocp_t{
fp8_impl::cvt_float_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation>(x)};
}
// convert fp32 to bf8 with rounding to nearest even
template <>
inline __host__ __device__ bf8_ocp_t f8_convert_rne<bf8_ocp_t, float>(float x)
{
return bf8_ocp_t{
fp8_impl::cvt_float_to_fp8<bf8_ocp_t::default_interpret, bf8_ocp_t::default_saturation>(x)};
}
// convert _Float16 to fp8 with rounding to nearest even
template <>
inline __host__ __device__ f8_ocp_t f8_convert_rne<f8_ocp_t, _Float16>(_Float16 x)
{
return f8_ocp_t{
fp8_impl::cvt_half_t_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation>(x)};
}
template <>
inline __host__ __device__ bf8_ocp_t f8_convert_rne<bf8_ocp_t, _Float16>(_Float16 x)
{
return bf8_ocp_t{
fp8_impl::cvt_half_t_to_fp8<bf8_ocp_t::default_interpret, bf8_ocp_t::default_saturation>(
x)};
}
// Declare a template function for fp8 conversion using RNE
template <typename Y, typename X>
__host__ __device__ constexpr Y f8_convert_sr(X x);
// convert fp32 to fp8 with stochastic rounding
template <>
inline __host__ __device__ f8_ocp_t f8_convert_sr<f8_ocp_t, float>(float x)
{
return f8_ocp_t{
fp8_impl::cvt_float_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation, true>(
x)};
}
// convert fp32 to bf8 with stochastic rounding
template <>
inline __host__ __device__ bf8_ocp_t f8_convert_sr<bf8_ocp_t, float>(float x)
{
return bf8_ocp_t{fp8_impl::cvt_float_to_fp8<bf8_ocp_t::default_interpret,
bf8_ocp_t::default_saturation,
true>(x)};
}
// convert _Float16 to fp8 with stochastic rounding
template <>
inline __host__ __device__ f8_ocp_t f8_convert_sr<f8_ocp_t, _Float16>(_Float16 x)
{
return f8_ocp_t{fp8_impl::cvt_half_t_to_fp8<f8_ocp_t::default_interpret,
f8_ocp_t::default_saturation,
true>(x)};
}
// convert _Float16 to bf8 with stochastic rounding
template <>
inline __host__ __device__ bf8_ocp_t f8_convert_sr<bf8_ocp_t, _Float16>(_Float16 x)
{
return bf8_ocp_t{fp8_impl::cvt_half_t_to_fp8<bf8_ocp_t::default_interpret,
bf8_ocp_t::default_saturation,
true>(x)};
}
#if CK_USE_OCP_FP8
using f8_t = f8_ocp_t;
using bf8_t = bf8_ocp_t;
#define CK_FP8_TYPE_FNUZ 0
#define CK_FP8_TYPE_OCP 1
#else
using f8_t = f8_fnuz_t;
using bf8_t = bf8_fnuz_t;
#define CK_FP8_TYPE_FNUZ 1
#define CK_FP8_TYPE_OCP 0
#endif
} // namespace ck
......@@ -4,7 +4,7 @@
#pragma once
namespace ck {
// Define the common macro for gfx94x models
// Define the common macro for MI300 models
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#define __gfx94__
#endif
......
......@@ -3,6 +3,7 @@
#pragma once
#include "ck/utility/amd_ck_fp8.hpp"
#include "ck/utility/statically_indexed_array.hpp"
namespace ck {
......@@ -10,8 +11,6 @@ namespace ck {
using bhalf_t = ushort;
using half_t = _Float16;
using int4_t = _BitInt(4);
using f8_t = _BitInt(8);
using bf8_t = unsigned _BitInt(8);
inline constexpr auto next_pow2(uint32_t x)
{
......@@ -19,14 +18,15 @@ inline constexpr auto next_pow2(uint32_t x)
return x > 1u ? (1u << (32u - __builtin_clz(x - 1u))) : x;
}
// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_t, bf8_t, bool
// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_fnuz_t, bf8_fnuz_t,
// native types: bool
template <typename T>
inline constexpr bool is_native_type()
{
return is_same<T, double>::value || is_same<T, float>::value || is_same<T, half_t>::value ||
is_same<T, bhalf_t>::value || is_same<T, int32_t>::value || is_same<T, int8_t>::value ||
is_same<T, uint8_t>::value || is_same<T, f8_t>::value || is_same<T, bf8_t>::value ||
is_same<T, bool>::value;
is_same<T, uint8_t>::value || is_same<T, f8_fnuz_t>::value ||
is_same<T, bf8_fnuz_t>::value || is_same<T, bool>::value;
}
// vector_type
......@@ -166,16 +166,30 @@ struct scalar_type<int4_t>
#endif
template <>
struct scalar_type<f8_t>
struct scalar_type<f8_fnuz_t>
{
using type = f8_t;
using type = f8_fnuz_t;
static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<bf8_t>
struct scalar_type<bf8_fnuz_t>
{
using type = bf8_t;
using type = bf8_fnuz_t;
static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<f8_ocp_t>
{
using type = f8_ocp_t::data_type;
static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<bf8_ocp_t>
{
using type = bf8_ocp_t::data_type;
static constexpr index_t vector_size = 1;
};
......@@ -1010,60 +1024,203 @@ struct vector_type<T, 256, typename std::enable_if_t<is_native_type<T>()>>
}
};
template <typename T, index_t N, typename Enable = void>
struct non_native_vector_base;
template <typename T>
struct nnvb_data_t_selector
{
using type = unsigned _BitInt(8 * sizeof(T));
};
template <>
struct nnvb_data_t_selector<f8_ocp_t>
{
using type = f8_ocp_t::data_type;
};
template <>
struct nnvb_data_t_selector<bf8_ocp_t>
{
using type = bf8_ocp_t::data_type;
};
template <typename T, index_t N>
struct non_native_vector_base<
T,
N,
std::enable_if_t<sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4 || sizeof(T) == 8>>
{
using data_t = typename nnvb_data_t_selector<T>::type; // select data_t based on the size of T
static_assert(sizeof(T) == sizeof(data_t), "non_native_vector_base storage size mismatch");
using data_v = data_t __attribute__((ext_vector_type(N)));
using type = non_native_vector_base<T, N>;
union alignas(next_pow2(N * sizeof(T)))
{
data_v dN; // storage vector;
StaticallyIndexedArray<data_t, N> dxN;
StaticallyIndexedArray<T, N> dTxN;
StaticallyIndexedArray<data_v, 1> dNx1;
} data_;
__host__ __device__ constexpr non_native_vector_base(data_t a) : data_{data_v(a)} {}
__host__ __device__ constexpr non_native_vector_base(T f)
: non_native_vector_base(bit_cast<data_t>(f))
{
}
__host__ __device__ constexpr non_native_vector_base() : non_native_vector_base(T{}){};
__host__ __device__ constexpr non_native_vector_base(data_v v) : data_{v} {}
__host__ __device__ constexpr operator data_v() const { return data_.dN; }
__host__ __device__ constexpr operator data_t() const
{
if constexpr(N == 1)
{
return data_.dxN[Number<0>{}];
}
else
{
return data_.dxN; // XXX this should cause an error
}
}
__host__ __device__ constexpr operator T() const
{
if constexpr(N == 1)
{
return data_.dTxN[Number<0>{}];
}
else
{
return data_.dTxN; // XXX this should cause an error
}
}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same_v<X, data_t> || is_same_v<X, T> || is_same_v<X, data_v>,
"Something went wrong, please check src and dst types.");
if constexpr(is_same_v<X, data_t>)
{
return data_.dxN;
}
else if constexpr(is_same_v<X, T>)
{
return data_.dTxN;
}
else if constexpr(is_same_v<X, data_v>)
{
return data_.dNx1;
}
else
{
return err;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same_v<X, data_t> || is_same_v<X, T> || is_same_v<X, data_v>,
"Something went wrong, please check src and dst types.");
if constexpr(is_same_v<X, data_t>)
{
return data_.dxN;
}
else if constexpr(is_same_v<X, T>)
{
return data_.dTxN;
}
else if constexpr(is_same_v<X, data_v>)
{
return data_.dNx1;
}
else
{
return err;
}
}
};
template <typename T, index_t N>
struct non_native_vector_base
struct scalar_type<non_native_vector_base<T, N>>;
template <index_t N>
struct scalar_type<non_native_vector_base<f8_ocp_t, N>>
{
using type = non_native_vector_base<T, N>;
using type = typename non_native_vector_base<f8_ocp_t, N>::data_t;
static constexpr index_t vector_size = N;
};
__host__ __device__ non_native_vector_base() = default;
__host__ __device__ non_native_vector_base(const type&) = default;
__host__ __device__ non_native_vector_base(type&&) = default;
__host__ __device__ ~non_native_vector_base() = default;
template <index_t N>
struct scalar_type<non_native_vector_base<bf8_ocp_t, N>>
{
using type = typename non_native_vector_base<bf8_ocp_t, N>::data_t;
T d[N];
static constexpr index_t vector_size = N;
};
// non-native vector_type implementation
template <typename T>
struct vector_type<T, 1, typename std::enable_if_t<!is_native_type<T>()>>
{
using d1_t = T;
using type = d1_t;
using d1_t = T;
using d1_nnv_t = non_native_vector_base<T, 1>;
using type = d1_nnv_t;
union alignas(next_pow2(1 * sizeof(T)))
{
d1_t d1_;
StaticallyIndexedArray<d1_t, 1> d1x1_;
d1_nnv_t d1_nnv_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{}} {}
__host__ __device__ constexpr vector_type() : data_{d1_t{}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value,
static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value,
"Something went wrong, please check src and dst types.");
return data_.d1x1_;
if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
{
return data_.d1x1_;
}
else
{
return err;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value,
static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value,
"Something went wrong, please check src and dst types.");
return data_.d1x1_;
if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
{
return data_.d1x1_;
}
else
{
return err;
}
}
};
template <typename T>
struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>>
{
using d1_t = T;
using d2_t = non_native_vector_base<T, 2>;
using d1_t = T;
using d1_nnv_t = non_native_vector_base<T, 1>;
using d2_t = non_native_vector_base<T, 2>;
using type = d2_t;
......@@ -1081,10 +1238,11 @@ struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>>
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value,
static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value ||
is_same<X, d2_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
{
return data_.d1x2_;
}
......@@ -1101,10 +1259,11 @@ struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>>
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value,
static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value ||
is_same<X, d2_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
{
return data_.d1x2_;
}
......@@ -1122,9 +1281,10 @@ struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>>
template <typename T>
struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>>
{
using d1_t = T;
using d2_t = non_native_vector_base<T, 2>;
using d4_t = non_native_vector_base<T, 4>;
using d1_t = T;
using d1_nnv_t = non_native_vector_base<T, 1>;
using d2_t = non_native_vector_base<T, 2>;
using d4_t = non_native_vector_base<T, 4>;
using type = d4_t;
......@@ -1143,10 +1303,11 @@ struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>>
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value,
static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value ||
is_same<X, d2_t>::value || is_same<X, d4_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
{
return data_.d1x4_;
}
......@@ -1167,10 +1328,11 @@ struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>>
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value,
static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value ||
is_same<X, d2_t>::value || is_same<X, d4_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
{
return data_.d1x4_;
}
......@@ -1192,10 +1354,11 @@ struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>>
template <typename T>
struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>>
{
using d1_t = T;
using d2_t = non_native_vector_base<T, 2>;
using d4_t = non_native_vector_base<T, 4>;
using d8_t = non_native_vector_base<T, 8>;
using d1_t = T;
using d1_nnv_t = non_native_vector_base<T, 1>;
using d2_t = non_native_vector_base<T, 2>;
using d4_t = non_native_vector_base<T, 4>;
using d8_t = non_native_vector_base<T, 8>;
using type = d8_t;
......@@ -1215,11 +1378,12 @@ struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>>
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value,
static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value ||
is_same<X, d2_t>::value || is_same<X, d4_t>::value ||
is_same<X, d8_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
{
return data_.d1x8_;
}
......@@ -1244,11 +1408,12 @@ struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>>
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value,
static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value ||
is_same<X, d2_t>::value || is_same<X, d4_t>::value ||
is_same<X, d8_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
{
return data_.d1x8_;
}
......@@ -1274,11 +1439,12 @@ struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>>
template <typename T>
struct vector_type<T, 16, typename std::enable_if_t<!is_native_type<T>()>>
{
using d1_t = T;
using d2_t = non_native_vector_base<T, 2>;
using d4_t = non_native_vector_base<T, 4>;
using d8_t = non_native_vector_base<T, 8>;
using d16_t = non_native_vector_base<T, 16>;
using d1_t = T;
using d1_nnv_t = non_native_vector_base<T, 1>;
using d2_t = non_native_vector_base<T, 2>;
using d4_t = non_native_vector_base<T, 4>;
using d8_t = non_native_vector_base<T, 8>;
using d16_t = non_native_vector_base<T, 16>;
using type = d16_t;
......@@ -1299,12 +1465,12 @@ struct vector_type<T, 16, typename std::enable_if_t<!is_native_type<T>()>>
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value,
static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value ||
is_same<X, d2_t>::value || is_same<X, d4_t>::value ||
is_same<X, d8_t>::value || is_same<X, d16_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
{
return data_.d1x16_;
}
......@@ -1333,12 +1499,12 @@ struct vector_type<T, 16, typename std::enable_if_t<!is_native_type<T>()>>
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value,
static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value ||
is_same<X, d2_t>::value || is_same<X, d4_t>::value ||
is_same<X, d8_t>::value || is_same<X, d16_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
{
return data_.d1x16_;
}
......@@ -1632,20 +1798,70 @@ using int8x32_t = typename vector_type<int8_t, 32>::type;
using int8x64_t = typename vector_type<int8_t, 64>::type;
// f8
using f8x2_t = typename vector_type<f8_t, 2>::type;
using f8x4_t = typename vector_type<f8_t, 4>::type;
using f8x8_t = typename vector_type<f8_t, 8>::type;
using f8x16_t = typename vector_type<f8_t, 16>::type;
using f8x32_t = typename vector_type<f8_t, 32>::type;
using f8x64_t = typename vector_type<f8_t, 64>::type;
using f8x2_fnuz_t = typename vector_type<f8_fnuz_t, 2>::type;
using f8x4_fnuz_t = typename vector_type<f8_fnuz_t, 4>::type;
using f8x8_fnuz_t = typename vector_type<f8_fnuz_t, 8>::type;
using f8x16_fnuz_t = typename vector_type<f8_fnuz_t, 16>::type;
using f8x32_fnuz_t = typename vector_type<f8_fnuz_t, 32>::type;
using f8x64_fnuz_t = typename vector_type<f8_fnuz_t, 64>::type;
// bf8
using bf8x2_t = typename vector_type<bf8_t, 2>::type;
using bf8x4_t = typename vector_type<bf8_t, 4>::type;
using bf8x8_t = typename vector_type<bf8_t, 8>::type;
using bf8x16_t = typename vector_type<bf8_t, 16>::type;
using bf8x32_t = typename vector_type<bf8_t, 32>::type;
using bf8x64_t = typename vector_type<bf8_t, 64>::type;
using bf8x2_fnuz_t = typename vector_type<bf8_fnuz_t, 2>::type;
using bf8x4_fnuz_t = typename vector_type<bf8_fnuz_t, 4>::type;
using bf8x8_fnuz_t = typename vector_type<bf8_fnuz_t, 8>::type;
using bf8x16_fnuz_t = typename vector_type<bf8_fnuz_t, 16>::type;
using bf8x32_fnuz_t = typename vector_type<bf8_fnuz_t, 32>::type;
using bf8x64_fnuz_t = typename vector_type<bf8_fnuz_t, 64>::type;
// f8
using f8x2_ocp_t = typename vector_type<f8_ocp_t, 2>::type;
using f8x4_ocp_t = typename vector_type<f8_ocp_t, 4>::type;
using f8x8_ocp_t = typename vector_type<f8_ocp_t, 8>::type;
using f8x16_ocp_t = typename vector_type<f8_ocp_t, 16>::type;
using f8x32_ocp_t = typename vector_type<f8_ocp_t, 32>::type;
using f8x64_ocp_t = typename vector_type<f8_ocp_t, 64>::type;
// bf8
using bf8x2_ocp_t = typename vector_type<bf8_ocp_t, 2>::type;
using bf8x4_ocp_t = typename vector_type<bf8_ocp_t, 4>::type;
using bf8x8_ocp_t = typename vector_type<bf8_ocp_t, 8>::type;
using bf8x16_ocp_t = typename vector_type<bf8_ocp_t, 16>::type;
using bf8x32_ocp_t = typename vector_type<bf8_ocp_t, 32>::type;
using bf8x64_ocp_t = typename vector_type<bf8_ocp_t, 64>::type;
#if CK_FP8_TYPE_OCP
// f8
using f8x2_t = f8x2_ocp_t;
using f8x4_t = f8x4_ocp_t;
using f8x8_t = f8x8_ocp_t;
using f8x16_t = f8x16_ocp_t;
using f8x32_t = f8x32_ocp_t;
using f8x64_t = f8x64_ocp_t;
// bf8
using bf8x2_t = bf8x2_ocp_t;
using bf8x4_t = bf8x4_ocp_t;
using bf8x8_t = bf8x8_ocp_t;
using bf8x16_t = bf8x16_ocp_t;
using bf8x32_t = bf8x32_ocp_t;
using bf8x64_t = bf8x64_ocp_t;
#elif CK_FP8_TYPE_FNUZ
// f8
using f8x2_t = f8x2_fnuz_t;
using f8x4_t = f8x4_fnuz_t;
using f8x8_t = f8x8_fnuz_t;
using f8x16_t = f8x16_fnuz_t;
using f8x32_t = f8x32_fnuz_t;
using f8x64_t = f8x64_fnuz_t;
// bf8
using bf8x2_t = bf8x2_fnuz_t;
using bf8x4_t = bf8x4_fnuz_t;
using bf8x8_t = bf8x8_fnuz_t;
using bf8x16_t = bf8x16_fnuz_t;
using bf8x32_t = bf8x32_fnuz_t;
using bf8x64_t = bf8x64_fnuz_t;
#endif
// u8
using uint8x2_t = typename vector_type<uint8_t, 2>::type;
......@@ -1702,7 +1918,7 @@ struct NumericLimits<int4_t>
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
struct NumericLimits<f8_t>
struct NumericLimits<f8_fnuz_t>
{
// negative zero nan mode with exp bias = 8
static constexpr uint8_t binary_min = 0x08; // 0b00001000
......@@ -1715,17 +1931,17 @@ struct NumericLimits<f8_t>
// static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111
// static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=0
__host__ __device__ static constexpr f8_t Min() { return f8_t(binary_min); }
__host__ __device__ static constexpr f8_fnuz_t Min() { return f8_fnuz_t(binary_min); }
__host__ __device__ static constexpr f8_t Max() { return f8_t(binary_max); }
__host__ __device__ static constexpr f8_fnuz_t Max() { return f8_fnuz_t(binary_max); }
__host__ __device__ static constexpr f8_t Lowest() { return f8_t(binary_lowest); }
__host__ __device__ static constexpr f8_fnuz_t Lowest() { return f8_fnuz_t(binary_lowest); }
__host__ __device__ static constexpr f8_t QuietNaN() { return f8_t(binary_qnan); }
__host__ __device__ static constexpr f8_fnuz_t QuietNaN() { return f8_fnuz_t(binary_qnan); }
};
template <>
struct NumericLimits<bf8_t>
struct NumericLimits<bf8_fnuz_t>
{
// negative zero nan mode with exp bias = 16
static constexpr uint8_t binary_min = 0x04; // 0b00000100
......@@ -1738,13 +1954,59 @@ struct NumericLimits<bf8_t>
// static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011
// static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=
__host__ __device__ static constexpr bf8_t Min() { return bf8_t(binary_min); }
__host__ __device__ static constexpr bf8_fnuz_t Min() { return bf8_fnuz_t(binary_min); }
__host__ __device__ static constexpr bf8_t Max() { return bf8_t(binary_max); }
__host__ __device__ static constexpr bf8_fnuz_t Max() { return bf8_fnuz_t(binary_max); }
__host__ __device__ static constexpr bf8_t Lowest() { return bf8_t(binary_lowest); }
__host__ __device__ static constexpr bf8_fnuz_t Lowest() { return bf8_fnuz_t(binary_lowest); }
__host__ __device__ static constexpr bf8_t QuietNaN() { return bf8_t(binary_qnan); }
__host__ __device__ static constexpr bf8_fnuz_t QuietNaN() { return bf8_fnuz_t(binary_qnan); }
};
template <>
struct NumericLimits<f8_ocp_t>
{
static constexpr uint8_t binary_min = 0x08; // 0b00001000 = 2^-6
static constexpr uint8_t binary_max = 0x7E; // 0b01111110 = 448
static constexpr uint8_t binary_lowest = 0xFE; // 0b11111110 = -448
static constexpr uint8_t binary_qnan = 0x7F; // 0b01111111
__host__ __device__ static constexpr f8_ocp_t Min() { return bit_cast<f8_ocp_t>(binary_min); }
__host__ __device__ static constexpr f8_ocp_t Max() { return bit_cast<f8_ocp_t>(binary_max); }
__host__ __device__ static constexpr f8_ocp_t Lowest()
{
return bit_cast<f8_ocp_t>(binary_lowest);
}
__host__ __device__ static constexpr f8_ocp_t QuietNaN()
{
return bit_cast<f8_ocp_t>(binary_qnan);
}
};
template <>
struct NumericLimits<bf8_ocp_t>
{
static constexpr uint8_t binary_min = 0x04; // 0b00000100 = 2^-14
static constexpr uint8_t binary_max = 0x7B; // 0b01111011 = 57344
static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011 = -57344
static constexpr uint8_t binary_qnan = 0x7D; // 0b01111101
__host__ __device__ static constexpr bf8_ocp_t Min() { return bit_cast<bf8_ocp_t>(binary_min); }
__host__ __device__ static constexpr bf8_ocp_t Max() { return bit_cast<bf8_ocp_t>(binary_max); }
__host__ __device__ static constexpr bf8_ocp_t Lowest()
{
return bit_cast<bf8_ocp_t>(binary_lowest);
}
__host__ __device__ static constexpr bf8_ocp_t QuietNaN()
{
return bit_cast<bf8_ocp_t>(binary_qnan);
}
};
template <typename T>
......@@ -1787,7 +2049,7 @@ struct NumericUtils<half_t>
};
template <>
struct NumericUtils<f8_t>
struct NumericUtils<f8_fnuz_t>
{
static constexpr int exp = 4;
static constexpr int mant = 3;
......@@ -1796,13 +2058,28 @@ struct NumericUtils<f8_t>
};
template <>
struct NumericUtils<bf8_t>
struct NumericUtils<bf8_fnuz_t>
{
static constexpr int exp = 5;
static constexpr int mant = 2;
static constexpr int bias = 16; // negative zero nan mode
// static constexpr int bias = 15; // ieee mode
};
template <>
struct NumericUtils<f8_ocp_t>
{
static constexpr int exp = 4;
static constexpr int mant = 3;
static constexpr int bias = 7;
};
template <>
struct NumericUtils<bf8_ocp_t>
{
static constexpr int exp = 5;
static constexpr int mant = 2;
static constexpr int bias = 15;
};
template <>
struct NumericUtils<bhalf_t>
......
......@@ -5,7 +5,6 @@
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
namespace ck {
......
......@@ -80,7 +80,7 @@ static inline __host__ bool isnan(half_t x)
return (xx & 0x7FFF) > 0x7C00;
};
static inline __host__ bool isnan(f8_t x) { return (x & 0x80); };
static inline __host__ bool isnan(f8_t x) { return ck::fp8_is_nan(x); };
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
static inline __host__ bool isnan(int4_t x)
......@@ -531,7 +531,7 @@ static inline __device__ bool isnan(half_t x)
return (xx & 0x7FFF) > 0x7C00;
};
static inline __device__ bool isnan(f8_t x) { return (x & 0x80); };
static inline __device__ bool isnan(f8_t x) { return ck::fp8_is_nan(x); };
static inline __device__ half_t sqrt(half_t x)
{
......@@ -611,7 +611,7 @@ inline __device__ int8_t neg<int8_t>(int8_t x)
template <>
inline __device__ half_t neg<half_t>(half_t x)
{
return __hneg(x);
return __hneg(static_cast<__half>(x));
};
template <typename T>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
namespace ck {
// Pseudo random number generator
......@@ -23,7 +25,7 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
}
// version for fp16
template <typename T, uint32_t seed_t, std::enable_if_t<std::is_same<half_t, T>{}, bool> = false>
template <typename T, uint32_t seed_t, std::enable_if_t<std::is_same<_Float16, T>{}, bool> = false>
__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
{
uint16_t x = *(reinterpret_cast<uint16_t*>(&val));
......@@ -38,9 +40,10 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
}
// return 0 if data is not fp16 or fp32
template <typename T,
uint32_t seed_t,
std::enable_if_t<!(std::is_same<float, T>{} || std::is_same<half_t, T>{}), bool> = false>
template <
typename T,
uint32_t seed_t,
std::enable_if_t<!(std::is_same<float, T>{} || std::is_same<_Float16, T>{}), bool> = false>
__host__ __device__ uint32_t prand_generator(int id, T val, uint32_t seed = seed_t)
{
std::ignore = id;
......
......@@ -9,7 +9,7 @@
#include "ck/utility/array.hpp"
namespace ck {
// Define the common macro for gfx94x models
// Define the common macro for MI300 models
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#define __gfx94__
#endif
......@@ -100,6 +100,18 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
return type_convert<bhalf_t>(x_fp32);
}
template <>
inline __host__ __device__ constexpr f8_ocp_t type_convert<f8_ocp_t, int>(int x)
{
return f8_ocp_t{type_convert<f8_ocp_t::data_type>(x)};
}
template <>
inline __host__ __device__ constexpr bf8_ocp_t type_convert<bf8_ocp_t, int>(int x)
{
return bf8_ocp_t{type_convert<bf8_ocp_t::data_type>(x)};
}
// Convert X to Y
template <typename Y, typename X>
__host__ __device__ constexpr Y type_convert_sp(X x)
......@@ -163,7 +175,7 @@ __host__ __device__ constexpr Y f8_convert_sr(X x);
// convert fp32 to fp8 with stochastic rounding
template <>
inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
inline __host__ __device__ f8_fnuz_t f8_convert_sr<f8_fnuz_t, float>(float x)
{
constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
......@@ -189,33 +201,35 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
return utils::
cast_to_f8<float, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x,
rng);
cast_to_f8<float, f8_fnuz_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#endif
}
// convert fp16 to fp8 with stochastic rounding
template <>
inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
inline __host__ __device__ f8_fnuz_t f8_convert_sr<f8_fnuz_t, half_t>(half_t x)
{
#if defined(__gfx94__)
// convert to float and use native converion
return f8_convert_sr<f8_t>(type_convert<float>(x));
return f8_convert_sr<f8_fnuz_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 1254739;
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
return utils::
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
return utils::cast_to_f8<half_t,
f8_fnuz_t,
negative_zero_nan,
clip,
(rm == f8_rounding_mode::stochastic)>(x, rng);
#endif
}
// convert fp32 to bf8 with stochastic rounding
template <>
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
inline __host__ __device__ bf8_fnuz_t f8_convert_sr<bf8_fnuz_t, float>(float x)
{
constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
......@@ -240,28 +254,32 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
return utils::
cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
return utils::cast_to_f8<float,
bf8_fnuz_t,
negative_zero_nan,
clip,
(rm == f8_rounding_mode::stochastic)>(x, rng);
#endif
}
// convert fp16 to bf8 with stochastic rounding
template <>
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
inline __host__ __device__ bf8_fnuz_t f8_convert_sr<bf8_fnuz_t, half_t>(half_t x)
{
#if defined(__gfx94__)
// convert to float and use native converion
return f8_convert_sr<bf8_t>(type_convert<float>(x));
return f8_convert_sr<bf8_fnuz_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 1254739;
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
return utils::
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
return utils::cast_to_f8<half_t,
bf8_fnuz_t,
negative_zero_nan,
clip,
(rm == f8_rounding_mode::stochastic)>(x, rng);
#endif
}
......@@ -271,7 +289,7 @@ __host__ __device__ constexpr Y f8_convert_rne(X x);
// convert fp32 to fp8 with rounding to nearest even
template <>
inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
inline __host__ __device__ f8_fnuz_t f8_convert_rne<f8_fnuz_t, float>(float x)
{
#if defined(__gfx94__)
union
......@@ -296,32 +314,34 @@ inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return utils::
cast_to_f8<float, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x,
rng);
cast_to_f8<float, f8_fnuz_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#endif
}
// convert fp16 to fp8 with rounding to nearest even
template <>
inline __host__ __device__ f8_t f8_convert_rne<f8_t, half_t>(half_t x)
inline __host__ __device__ f8_fnuz_t f8_convert_rne<f8_fnuz_t, half_t>(half_t x)
{
#if defined(__gfx94__)
// convert to float and use native converion
return f8_convert_rne<f8_t>(type_convert<float>(x));
return f8_convert_rne<f8_fnuz_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return utils::
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
return utils::cast_to_f8<half_t,
f8_fnuz_t,
negative_zero_nan,
clip,
(rm == f8_rounding_mode::stochastic)>(x, rng);
#endif
}
// convert fp32 to bf8 with rounding to nearest even
template <>
inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, float>(float x)
inline __host__ __device__ bf8_fnuz_t f8_convert_rne<bf8_fnuz_t, float>(float x)
{
#if defined(__gfx94__)
union
......@@ -345,44 +365,59 @@ inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, float>(float x)
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return utils::
cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
return utils::cast_to_f8<float,
bf8_fnuz_t,
negative_zero_nan,
clip,
(rm == f8_rounding_mode::stochastic)>(x, rng);
#endif
}
// convert fp16 to bf8 with rounding to nearest even
template <>
inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, half_t>(half_t x)
inline __host__ __device__ bf8_fnuz_t f8_convert_rne<bf8_fnuz_t, half_t>(half_t x)
{
#if defined(__gfx94__)
// convert to float and use native converion
return f8_convert_rne<bf8_t>(type_convert<float>(x));
return f8_convert_rne<bf8_fnuz_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return utils::
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
return utils::cast_to_f8<half_t,
bf8_fnuz_t,
negative_zero_nan,
clip,
(rm == f8_rounding_mode::stochastic)>(x, rng);
#endif
}
// convert fp32 to fp8
template <>
inline __host__ __device__ f8_fnuz_t type_convert<f8_fnuz_t, float>(float x)
{
#if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<f8_fnuz_t>(x);
#else
return f8_convert_rne<f8_fnuz_t>(x);
#endif
}
// convert fp32 to fp8
template <>
inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
inline __host__ __device__ f8_ocp_t type_convert<f8_ocp_t, float>(float x)
{
#if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<f8_t>(x);
return f8_convert_sr<f8_ocp_t>(x);
#else
return f8_convert_rne<f8_t>(x);
return f8_convert_rne<f8_ocp_t>(x);
#endif
}
// convert fp8 to fp32
template <>
inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
inline __host__ __device__ float type_convert<float, f8_fnuz_t>(f8_fnuz_t x)
{
#if defined(__gfx94__)
float fval;
......@@ -392,30 +427,44 @@ inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
return fval;
#else
constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<f8_t, float, negative_zero_nan>(x);
return utils::cast_from_f8<f8_fnuz_t, float, negative_zero_nan>(x);
#endif
}
template <>
inline __host__ __device__ float2_t type_convert<float2_t, f8x2_t>(f8x2_t x)
inline __host__ __device__ float2_t type_convert<float2_t, f8x2_fnuz_t>(f8x2_fnuz_t x)
{
#if defined(__gfx94__)
const auto i16val = bit_cast<uint16_t>(x);
return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, 0);
#else
constexpr bool negative_zero_nan = true;
const auto f8x2_v = vector_type<f8_t, 2>(x);
const auto f8x2_v = vector_type<f8_fnuz_t, 2>(x);
vector_type<float, 2> f32x2_v;
f32x2_v.template AsType<float>()(Number<0>{}) =
utils::cast_from_f8<f8_t, float, negative_zero_nan>(
f8x2_v.template AsType<f8_t>()[Number<0>{}]);
utils::cast_from_f8<f8_fnuz_t, float, negative_zero_nan>(
f8x2_v.template AsType<f8_fnuz_t>()[Number<0>{}]);
f32x2_v.template AsType<float>()(Number<1>{}) =
utils::cast_from_f8<f8_t, float, negative_zero_nan>(
f8x2_v.template AsType<f8_t>()[Number<1>{}]);
utils::cast_from_f8<f8_fnuz_t, float, negative_zero_nan>(
f8x2_v.template AsType<f8_fnuz_t>()[Number<1>{}]);
return f32x2_v.template AsType<float2_t>()[Number<0>{}];
#endif
}
template <>
inline __host__ __device__ float2_t type_convert<float2_t, f8x2_ocp_t>(f8x2_ocp_t x)
{
#if CK_OCP_FP8_CVT_FAST_PATH
return fp8_impl::cast_to_f32x2_from_f8x2<f8_ocp_t::default_interpret>(
x.AsType<fp8_impl::fp8x2_storage_t>()[Number<0>{}]);
#else
return float2_t{fp8_impl::cast_from_f8<float, f8_ocp_t::wm, f8_ocp_t::we, false>(
x.AsType<fp8_storage_t>()[Number<0>{}]),
fp8_impl::cast_from_f8<float, f8_ocp_t::wm, f8_ocp_t::we, false>(
x.AsType<fp8_storage_t>()[Number<1>{}])};
#endif
}
template <>
inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x)
{
......@@ -428,42 +477,64 @@ inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x)
// convert fp16 to fp8
template <>
inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
inline __host__ __device__ f8_fnuz_t type_convert<f8_fnuz_t, half_t>(half_t x)
{
#if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<f8_t>(x);
return f8_convert_sr<f8_fnuz_t>(x);
#else
return f8_convert_rne<f8_t>(x);
return f8_convert_rne<f8_fnuz_t>(x);
#endif
}
// convert fp16 to fp8
template <>
inline __host__ __device__ f8_ocp_t type_convert<f8_ocp_t, half_t>(half_t x)
{
#if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<f8_ocp_t>(x);
#else
return f8_convert_rne<f8_ocp_t>(x);
#endif
}
// convert fp8 to fp16
template <>
inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
inline __host__ __device__ half_t type_convert<half_t, f8_fnuz_t>(f8_fnuz_t x)
{
#if defined(__gfx94__)
// use native conversion to float and convert to fp16
return type_convert<half_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<f8_t, half_t, negative_zero_nan>(x);
return utils::cast_from_f8<f8_fnuz_t, half_t, negative_zero_nan>(x);
#endif
}
// convert fp32 to bf8
template <>
inline __host__ __device__ bf8_fnuz_t type_convert<bf8_fnuz_t, float>(float x)
{
#if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<bf8_fnuz_t>(x);
#else
return f8_convert_rne<bf8_fnuz_t>(x);
#endif
}
// convert fp32 to bf8
template <>
inline __host__ __device__ bf8_t type_convert<bf8_t, float>(float x)
inline __host__ __device__ bf8_ocp_t type_convert<bf8_ocp_t, float>(float x)
{
#if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<bf8_t>(x);
return f8_convert_sr<bf8_ocp_t>(x);
#else
return f8_convert_rne<bf8_t>(x);
return f8_convert_rne<bf8_ocp_t>(x);
#endif
}
// convert bf8 to fp32
template <>
inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x)
inline __host__ __device__ float type_convert<float, bf8_fnuz_t>(bf8_fnuz_t x)
{
#if defined(__gfx94__)
float fval;
......@@ -473,31 +544,42 @@ inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x)
return fval;
#else
constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<bf8_t, float, negative_zero_nan>(x);
return utils::cast_from_f8<bf8_fnuz_t, float, negative_zero_nan>(x);
#endif
}
// convert fp16 to bf8
template <>
inline __host__ __device__ bf8_fnuz_t type_convert<bf8_fnuz_t, half_t>(half_t x)
{
#if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<bf8_fnuz_t>(x);
#else
return f8_convert_rne<bf8_fnuz_t>(x);
#endif
}
// convert fp16 to bf8
template <>
inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
inline __host__ __device__ bf8_ocp_t type_convert<bf8_ocp_t, half_t>(half_t x)
{
#if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<bf8_t>(x);
return f8_convert_sr<bf8_ocp_t>(x);
#else
return f8_convert_rne<bf8_t>(x);
return f8_convert_rne<bf8_ocp_t>(x);
#endif
}
// convert bf8 to fp16
template <>
inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
inline __host__ __device__ half_t type_convert<half_t, bf8_fnuz_t>(bf8_fnuz_t x)
{
#if defined(__gfx94__)
// use native conversion to float and convert to fp16
return type_convert<half_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<bf8_t, half_t, negative_zero_nan>(x);
return utils::cast_from_f8<bf8_fnuz_t, half_t, negative_zero_nan>(x);
#endif
}
......
# ck_tile
[Back to the main page](../../README.md)
# Composable Kernel Tile
## concept
`ck_tile` provides a programming model with templated abstractions to enable users to implement performance-critical kernels for machine learning workloads. introduces following basic concepts to help users building your own operator
- tensor coordinate transformation, this is the core concept of layout/index transform abstraction in both compiler time and run time.
......@@ -44,5 +45,8 @@ our implementation of different device operators.
**[ops/epilogue]**
epilogue part of our kernel. We may extend this epilogue part to let users to build their own cutomized epilogues.
**[ref]**
reference implementation of cpu or gpu. This folder is supposed to include a specific header on demand.
## examples
currently we put all ck_tile related example under [/example/ck_tile](/example/ck_tile/) folder. Please check each example's subfolder.
......@@ -52,7 +52,9 @@
#include "ck_tile/core/tensor/tile_elementwise.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/tensor/tile_window_utils.hpp"
#include "ck_tile/core/tensor/update_tile.hpp"
#include "ck_tile/core/utility/amd_address_space.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/functional_with_tuple.hpp"
......@@ -62,6 +64,7 @@
#include "ck_tile/core/utility/philox_rand.hpp"
#include "ck_tile/core/utility/random.hpp"
#include "ck_tile/core/utility/reduce_operator.hpp"
#include "ck_tile/core/utility/static_counter.hpp"
#include "ck_tile/core/utility/to_sequence.hpp"
#include "ck_tile/core/utility/transpose_vectors.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
......
......@@ -621,6 +621,65 @@ CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0)
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
CK_TILE_DEVICE void lds_load_fence(index_t cnt = 0)
{
asm volatile("s_waitcnt lgkmcnt(%0)" : : "n"(cnt) : "memory");
}
template <typename scalar_type, index_t N, bool pre_nop = false>
struct buffer_atomic_add_if;
template <bool pre_nop>
struct buffer_atomic_add_if<bf16_t, 2, pre_nop>
{
template <typename T>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t flag = 1)
{
static_assert(sizeof(T) == 4);
auto save_exec = __builtin_amdgcn_read_exec();
using mbuf_t = float;
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"global_atomic_pk_add_bf16 %0, %1, %2 offset:%3\n"
"s_mov_b64 exec %5"
:
: "v"(v_offset),
"v"(bit_cast<mbuf_t>(value)),
"s"(res.xy),
"n"(i_offset),
"v"(flag),
"s"(save_exec)
: "memory");
}
};
template <typename scalar_type, index_t N, bool pre_nop = false>
struct buffer_atomic_add;
template <bool pre_nop>
struct buffer_atomic_add<bf16_t, 2, pre_nop>
{
template <typename T>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t /*flag = 1*/)
{
static_assert(sizeof(T) == 4);
using mbuf_t = float;
asm volatile("global_atomic_pk_add_bf16 %0, %1, %2 offset:%3"
:
: "v"(v_offset), "v"(bit_cast<mbuf_t>(value)), "s"(res.xy), "n"(i_offset)
: "memory");
}
};
namespace impl {
// below type indicate the data type used for buffer load inline asm
// clang-format off
......@@ -810,6 +869,11 @@ CK_TILE_DEVICE void buffer_store_fence(index_t cnt = 0)
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
CK_TILE_DEVICE auto async_load_fence_raw(index_t cnt = 0)
{
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
// buffer load i8
CK_TILE_DEVICE_EXTERN int8_t
llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc,
......@@ -2378,6 +2442,45 @@ CK_TILE_DEVICE void amd_buffer_atomic_add(const thread_buffer<T, N>& src_thread_
#endif
}
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE void amd_buffer_atomic_add_raw(const thread_buffer<T, N>& src_thread_data,
T* p_dst_wave,
const index_t dst_thread_element_offset,
const index_t dst_linear_element_offset,
const bool dst_thread_element_valid,
const index_t dst_element_space_size,
bool_constant<pre_nop> = {})
{
const int32x4_t dst_wave_buffer_resource =
make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
index_t dst_linear_addr_offset = dst_linear_element_offset * sizeof(T);
if constexpr(oob_conditional_check)
{
buffer_atomic_add_if<T, N, pre_nop>{}(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
0,
dst_linear_addr_offset,
dst_thread_element_valid);
}
else
{
buffer_atomic_add<T, N, pre_nop>{}(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
0,
dst_linear_addr_offset,
1);
}
}
// buffer_atomic_max requires:
// 1) p_dst_wave must point to global memory
// 2) p_dst_wave must be a wavewise pointer.
......
......@@ -73,6 +73,24 @@ CK_TILE_DEVICE void block_sync_lds()
#endif
}
CK_TILE_DEVICE void block_sync_load_raw(index_t cnt = 0)
{
#ifdef __gfx12__
asm volatile("s_wait_loadcnt %0 \n"
"s_barrier_signal -1 \n"
"s_barrier_wait -1"
:
: "n"(cnt)
: "memory");
#else
asm volatile("s_waitcnt vmcnt(%0) \n"
"s_barrier"
:
: "n"(cnt)
: "memory");
#endif
}
CK_TILE_DEVICE void block_sync_lds_direct_load()
{
asm volatile("\
......
......@@ -102,4 +102,28 @@ CK_TILE_DEVICE T warp_shuffle(const T& v_local, uint32_t src_lane)
#endif
}
template <typename T>
CK_TILE_DEVICE auto flag_to_exec(const T& v_flag)
{
static_assert(sizeof(T) == 4);
// per-thread v_flag store into 2x sgpr
uint32x2_t exec_flag;
asm volatile("v_cmp_ge_u32 %[s_exec_flag], %[v_flag], 1"
: [s_exec_flag] "=s"(exec_flag)
: [v_flag] "v"(v_flag));
return exec_flag;
}
template <typename X, typename Y>
CK_TILE_DEVICE auto cmp_lt_to_exec(const X& x, const Y& y)
{
static_assert(sizeof(X) == 4 && sizeof(Y) == 4);
// per-thread cmp store into 2x sgpr
uint32x2_t exec_flag;
asm volatile("v_cmp_lt_u32 %[s_exec_flag], %[v_x], %[v_y]"
: [s_exec_flag] "=s"(exec_flag)
: [v_x] "v"(x), [v_y] "v"(y));
return exec_flag;
}
} // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment