Commit b097be17 authored by root

merge changes for upstream/latest update

parents 8a891bbd a49115b9
#pragma once
#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "thread_group_tensor_slice_transfer_v7.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "gridwise_gemm_pipeline_v1.hpp"
namespace ck {
// input : A[AK0, M, AK1]
// input : B[BK0, N, BK1]
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
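// Illustrative reference semantics (comment-only sketch, not part of the kernel). With the K
// dimensions flattened, k = (ak0, ak1) on the A side and k = (bk0, bk1) on the B side:
//   C[m, n] = sum_k a_op(A[k, m]) * b_op(B[k, n])
//   E[m, n] = cde_op(C[m, n], D0[m, n], D1[m, n], ...)
// e.g. a hypothetical bias + residual fusion would use cde_op(c, d0, d1) = c + d0 + d1.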
template <typename FloatAB,
typename FloatGemmAcc,
typename FloatCShuffle,
typename DsDataType,
typename FloatE,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename EGridDesc_M_N,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1Value,
index_t BK1Value,
index_t MPerXdl,
index_t NPerXdl,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool AThreadTransferSrcResetCoordinateAfterRun,
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BThreadTransferSrcResetCoordinateAfterRun,
index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched>
struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
{
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
// K1 should be Number<...>
static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
static constexpr auto AK1 = Number<AK1Value>{};
static constexpr auto BK1 = Number<BK1Value>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
// A matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(AK0, Number<MPerBlock>{}, AK1),
make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
}
__host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(BK0, Number<NPerBlock>{}, BK1),
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
}
__host__ __device__ static constexpr auto
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
{
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
I1,
Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));
return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
}
// ck::Tuple<const D0DataType*, const D1DataType*, ...>
static constexpr auto MakeDsGridPointer()
{
return generate_tuple(
[&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
return static_cast<const DDataType*>(nullptr);
},
Number<NumDTensor>{});
}
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1, BK1);
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
// LDS allocation for C shuffle
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
constexpr auto c_block_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
sizeof(FloatAB),
c_block_size * sizeof(FloatCShuffle));
}
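// Worked example of the computation above (illustrative only; the numbers assume a hypothetical
// configuration: FloatAB = half_t (2 bytes), FloatCShuffle = float (4 bytes), KPerBlock = 32,
// AK1 = BK1 = 8, MPerBlock = NPerBlock = 128, no LDS padding, and a 64 x 64 CShuffle tile,
// i.e. CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl = CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl = 64):
//   a_block elements = AK0 * MPerBlock * AK1 = 4 * 128 * 8 = 4096 (same for b_block)
//   A + B LDS bytes  = (4096 + 4096) * 2     = 16384
//   CShuffle bytes   = (1 * 64 * 1 * 64) * 4 = 16384
//   GetSharedMemoryNumberOfByte() returns max(16384, 16384) = 16384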
// block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
template <typename Block2ETileMap>
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const EGridDesc_M_N& e_grid_desc_m_n,
const Block2ETileMap& block_2_etile_map)
{
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");
const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1);
const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1);
const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1)))
return false;
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
return false;
// check gridwise gemm pipeline
const auto num_k_loop = K / KPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
{
return false;
}
if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n))
{
return false;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / KPerBlock;
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
}
__host__ __device__ static constexpr auto
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n)
{
const auto M = e_grid_desc_m_n.GetLength(I0);
const auto N = e_grid_desc_m_n.GetLength(I1);
const auto MBlock = M / MPerBlock;
const auto NBlock = N / NPerBlock;
const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
e_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
return e_grid_desc_mblock_mperblock_nblock_nperblock;
}
// return block_id to E matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto
MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
{
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>(
e_grid_desc_m_n);
}
using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
using DefaultBlock2ETileMap =
remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
using DsGridPointer = decltype(MakeDsGridPointer());
template <bool HasMainKBlockLoop, typename Block2ETileMap>
__device__ static void
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
DsGridPointer p_ds_grid,
FloatE* __restrict__ p_e_grid,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const StaticallyIndexedArray<EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
NumDTensor>&
ds_grid_desc_mblock_mperblock_nblock_nperblock, // FIXME: Ds desc may be of different
// type from E
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
e_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2ETileMap& block_2_etile_map)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
const auto ds_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ds_grid[i],
ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
},
Number<NumDTensor>{});
auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// divide block work by [M, N]
const auto block_work_idx =
block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_etile_map.ValidCTileIndex(
block_work_idx,
make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
return;
}
// HACK: this forces m/n_block_data_idx_on_grid into SGPRs
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1, BK1);
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// A matrix blockwise copy
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
NumGemmKPrefetchStage>(
a_grid_desc_ak0_m_ak1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// B matrix blockwise copy
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
NumGemmKPrefetchStage>(
b_grid_desc_bk0_n_bk1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[K0PerBlock, MPerBlock] is in LDS
// b_mtx[K0PerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
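// Per-element view of the blockwise GEMM above (comment-only sketch; k runs over the KPerBlock
// elements currently resident in LDS):
//   c_thread[m, n] += sum_k a_block[k, m] * b_block[k, n]
// where each thread accumulates only the (m, n) entries assigned to it by the xdlops mapping.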
// sanity check
constexpr index_t KPack = math::max(
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
FloatAB,
FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack,
LoopSched>();
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
// gridwise GEMM pipeline
const auto gridwise_gemm_pipeline =
GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopSched>();
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock);
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_bk0_n_bk1,
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_slice_copy_step,
blockwise_gemm,
c_thread_buf,
num_k_block_main_loop);
// shuffle C and write out
{
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
"wrong!");
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
// TODO: hacky, fix it!
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
// TODO: hacky, fix it!
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatCShuffle*>(p_shared),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
M1, // M1 = MWave
M2, // M2 * M3 * M4 = MPerXdl
M3,
M4)),
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
N1, // N1 = NWave
N2))), // N2 = NPerXdl
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<FloatGemmAcc,
FloatCShuffle,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};
// tuple of references to C/Ds tensor descriptors
const auto c_ds_desc_refs = concat_tuple_of_reference(
tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
generate_tie(
[&](auto i) -> const auto& // return type should be reference
{ return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
Number<NumDTensor>{}));
// tuple of references to C/Ds buffers
const auto c_ds_buf_refs = concat_tuple_of_reference(
tie(c_shuffle_block_buf),
generate_tie(
[&](auto i) -> const auto& // return type should be reference
{ return ds_grid_buf[i]; },
Number<NumDTensor>{}));
// tuple of starting index of C/Ds blockwise copy
const auto idx_c_ds_block_begin = container_concat(
make_tuple(make_multi_index(0, 0, 0, 0)),
generate_tuple(
[&](auto) {
return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
},
Number<NumDTensor>{}));
// blockwise copy C/D/E between LDS and global
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
ThisThreadBlock,
decltype(container_concat(make_tuple(FloatCShuffle{}), DsDataType{})),
Tuple<FloatE>,
decltype(c_ds_desc_refs),
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
CDEElementwiseOperation,
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
// support arbitrary type
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CDEShuffleBlockTransferScalarPerVector_NPerBlock,
sequence_merge_t<
Sequence<true>,
uniform_sequence_gen_t<NumDTensor,
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
{c_ds_desc_refs,
idx_c_ds_block_begin,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
cde_element_op};
// space filling curve for threadwise C in VGPR before shuffle
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
1,
1,
M2,
1,
M4,
1>>{};
// space filling curve for shuffled blockwise C/D/E
constexpr auto sfc_cde_block =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
block_sync_lds();
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
c_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_shuffle_block_buf);
// make sure it's safe to read from LDS
block_sync_lds();
// each block copy its data from LDS to global
cde_block_copy_lds_and_global.Run(
c_ds_desc_refs,
c_ds_buf_refs,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
tie(e_grid_buf));
if constexpr(access_id < num_access - 1)
{
constexpr auto cde_lds_and_global_step =
sfc_cde_block.GetForwardStep(access_id);
// move on Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
cde_block_copy_lds_and_global.MoveSrcSliceWindow(
c_ds_desc_refs, i + I1, cde_lds_and_global_step);
});
// move on E
cde_block_copy_lds_and_global.MoveDstSliceWindow(
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
I0,
cde_lds_and_global_step);
}
});
}
}
};
} // namespace ck
@@ -21,7 +21,7 @@ template <typename GridwiseGemm,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename DxsInElementwiseOperation,
-typename DxsAccElementwiseOperation,
+typename DxsReduceAccElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -41,7 +41,7 @@ __global__ void
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const DxsInElementwiseOperation dxs_in_element_op,
-const DxsAccElementwiseOperation dxs_out_element_op,
+const DxsReduceAccElementwiseOperation dxs_out_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
@@ -96,7 +96,7 @@ template <typename FloatAB,
typename CElementwiseOperation,
typename DxsReduceOperation,
typename DxsInElementwiseOperation,
-typename DxsAccElementwiseOperation,
+typename DxsReduceAccElementwiseOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename DGlobalMemoryDataOperation,
typename AGridDesc_AK0_M_AK1,
@@ -329,7 +329,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const DxsInElementwiseOperation& dxs_in_element_op,
-const DxsAccElementwiseOperation& dxs_out_element_op,
+const DxsReduceAccElementwiseOperation& dxs_out_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
@@ -816,7 +816,8 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
false>;
// Global write Gemm shuffle + reduction
-const auto d_identityVal = DReduceOperation::GetIdentityValue();
+const auto d_identityVal =
+DReduceOperation::template GetIdentityValue<FloatReduceAcc>();
static_for<0, mreduce_per_thread, 1>{}(
[&](auto I) { d_thread_buf(I) = d_identityVal; });
...
@@ -791,8 +791,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock =
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
+void* p_shared = static_cast<void*>(p_shared_block);
auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-static_cast<FloatC*>(p_shared_block),
+static_cast<FloatC*>(p_shared),
c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
static_assert(M1 == MWave, "");
...
@@ -249,7 +249,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
}();
using BlockwiseGemm =
-BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<ThisThreadBlock,
+BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
FloatAcc,
decltype(a_k0_m_k1_block_desc),
@@ -453,7 +453,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
// sanity check
auto blockwise_gemm =
-BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<ThisThreadBlock,
+BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
FloatAcc,
decltype(a_k0_m_k1_block_desc),
...
@@ -37,7 +37,7 @@ __global__ void kernel_buffer_set_value(const Grid1dBufferDescType grid_1d_buffe
{
-using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<DataType, DataType>;
+using PassThroughOp = tensor_operation::element_wise::PassThrough;
constexpr auto I0 = Number<0>{};
...
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2022 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef GRIDWISE_SOFTMAX_HPP
#define GRIDWISE_SOFTMAX_HPP
#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp"
#include "reduction_functions_blockwise.hpp"
#include "reduction_functions_threadwise.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "element_wise_operation.hpp"
namespace ck {
template <typename GridwiseReduction,
typename InDataType,
typename OutDataType,
typename AccDataType,
typename GridDesc_M_K>
__global__ void kernel_softmax(const GridDesc_M_K in_grid_desc_m_k,
const GridDesc_M_K out_grid_desc_m_k,
index_t block_group_size,
index_t num_k_block_tile_iteration,
AccDataType alpha,
const InDataType* const __restrict__ p_in_value_global,
AccDataType beta,
OutDataType* const __restrict__ p_out_value_global)
{
GridwiseReduction::Run(in_grid_desc_m_k,
out_grid_desc_m_k,
block_group_size,
num_k_block_tile_iteration,
alpha,
p_in_value_global,
beta,
p_out_value_global);
};
template <typename InDataType,
typename OutDataType,
typename AccDataType,
typename GridDesc_M_K,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t InSrcVectorDim,
index_t InSrcVectorSize,
index_t OutDstVectorSize>
struct GridwiseSoftmax_mk_to_mk
{
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
(KThreadSliceSize % OutDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using BlockwiseMaxReduce = PartitionedBlockwiseReduction<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
reduce::Max,
false>; // PropagateNan
using ThreadwiseMaxReduce = ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
reduce::Max,
false>; // PropagateNan
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
__device__ static void Run(const GridDesc_M_K& in_grid_desc_m_k,
const GridDesc_M_K& out_grid_desc_m_k,
index_t block_group_size,
index_t num_k_block_tile_iteration,
AccDataType alpha,
const InDataType* const __restrict__ p_in_value_global,
AccDataType beta,
OutDataType* const __restrict__ p_out_value_global)
{
// LDS
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_value_global, out_grid_desc_m_k.GetElementSpaceSize());
auto reduce_work_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
out_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> max_value_buf;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
max_value_buf(I) = reduce::Max::template GetIdentityValue<AccDataType>();
});
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
});
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
const index_t blkgroup_id = block_global_id / block_group_size;
const index_t block_local_id = block_global_id % block_group_size;
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
GridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_local_id * reduceSizePerBlock +
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<OutDataType,
AccDataType,
GridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
out_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_local_id * reduceSizePerBlock +
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_dst_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
OutDataType,
decltype(thread_buffer_desc),
GridDesc_M_K,
PassThroughOp,
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
OutDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
out_grid_desc_m_k,
make_multi_index(
blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_local_id * reduceSizePerBlock + thread_k_cluster_id * KThreadSliceSize),
PassThroughOp{});
constexpr auto in_thread_copy_fwd_step = make_multi_index(0, K_BlockTileSize);
constexpr auto in_thread_copy_bwd_step = make_multi_index(0, -K_BlockTileSize);
///
/// max(x)
///
const auto in_global_val_buf_oob_non_zero = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(),
reduce::Max::template GetIdentityValue<InDataType>());
index_t reducedTiles = 0;
do
{
threadwise_src_load.Run(in_grid_desc_m_k,
in_global_val_buf_oob_non_zero,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_buf);
ThreadwiseMaxReduce::Reduce(in_thread_buf, max_value_buf);
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step);
reducedTiles++;
} while(reducedTiles < num_k_block_tile_iteration);
static_for<0, MThreadSliceSize, 1>{}(
[&](auto I) { BlockwiseMaxReduce::Reduce(reduce_work_buf, max_value_buf(I)); });
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_bwd_step);
///
/// sum(exp(x - max(x)))
///
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
});
// Normally, 0 is an adequate invalid-element value because 0 contributes nothing to the
// accumulated result. In stable softmax, however, every value (valid or not) has value_max
// subtracted from it, so invalid elements become non-zero and would slip through and
// contribute to the accumulated result.
//
// The trick here is to exploit the fact that most math functions (add, sub, exp, ...)
// propagate NaNs. By initializing invalid elements with NaN, a value that started out
// invalid remains NaN after any number of math manipulations and can still be identified
// as invalid. The accumulation step then discards these NaN values, so elements that
// originally failed the bounds check are ignored no matter how many operations have been
// applied to them.
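// Comment-only sketch of the effect (the accumulator shown is hypothetical; the real filtering
// lives in detail::AccumulateWithNanIgnore):
//   acc = 0;
//   for each loaded x:            // out-of-bounds lanes were filled with QuietNaN
//       y = exp(x - x_max);       // NaN in -> NaN out, so OOB lanes stay NaN
//       if(!isnan(y)) acc += y;   // NaN results are skipped instead of polluting the sum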
const auto in_global_val_buf_oob_nan =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(),
NumericLimits<InDataType>::QuietNaN());
using BlockwiseSumReduce = PartitionedBlockwiseReduction<
AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
reduce::Add,
false, // ignored
detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>;
using ThreadwiseSumReduce =
ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
reduce::Add,
false, // ignored
detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>;
reducedTiles = 0;
do
{
threadwise_src_load.Run(in_grid_desc_m_k,
in_global_val_buf_oob_nan,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_buf);
// do element-wise pre-reduction operation
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
in_thread_buf(Number<offset>{}) =
math::exp(in_thread_buf(Number<offset>{}) - max_value_buf(iM));
});
});
ThreadwiseSumReduce::Reduce(in_thread_buf, accu_value_buf);
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_bwd_step);
reducedTiles++;
} while(reducedTiles < num_k_block_tile_iteration);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
BlockwiseSumReduce::Reduce(reduce_work_buf, accu_value_buf(I));
// block_sync_lds();
});
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step);
///
/// softmax
///
reducedTiles = 0;
if(float_equal_zero{}(beta))
{
do
{
threadwise_src_load.Run(in_grid_desc_m_k,
in_global_val_buf_oob_nan,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// out = alpha * exp(x - max(x)) / sum(exp(x - max(x)))
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset =
thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
out_thread_buf(Number<offset>{}) =
alpha * math::exp(in_thread_buf(Number<offset>{}) - max_value_buf(iM)) /
accu_value_buf(iM);
});
});
threadwise_dst_store.Run(thread_buffer_desc,
make_tuple(I0, I0),
out_thread_buf,
out_grid_desc_m_k,
out_global_val_buf);
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step);
threadwise_dst_store.MoveDstSliceWindow(out_grid_desc_m_k, in_thread_copy_fwd_step);
reducedTiles++;
} while(reducedTiles < num_k_block_tile_iteration);
}
else
{
do
{
threadwise_src_load.Run(in_grid_desc_m_k,
in_global_val_buf_oob_nan,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_buf);
threadwise_dst_load.Run(out_grid_desc_m_k,
out_global_val_buf,
thread_buffer_desc,
make_tuple(I0, I0),
out_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// out = alpha * exp(x - max(x)) / sum(exp(x - max(x))) + beta * prior_out
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset =
thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
out_thread_buf(Number<offset>{}) =
alpha * math::exp(in_thread_buf(Number<offset>{}) - max_value_buf(iM)) /
accu_value_buf(iM) +
beta * out_thread_buf(Number<offset>{});
});
});
threadwise_dst_store.Run(thread_buffer_desc,
make_tuple(I0, I0),
out_thread_buf,
out_grid_desc_m_k,
out_global_val_buf);
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step);
threadwise_dst_store.MoveDstSliceWindow(out_grid_desc_m_k, in_thread_copy_fwd_step);
threadwise_dst_load.MoveSrcSliceWindow(out_grid_desc_m_k, in_thread_copy_fwd_step);
reducedTiles++;
} while(reducedTiles < num_k_block_tile_iteration);
}
}
};
} // namespace ck
#endif // GRIDWISE_SOFTMAX_HPP
#pragma once
#include "cluster_descriptor.hpp"
#include "data_type.hpp"
#include "element_wise_operation.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
namespace ck {
template <typename GridwiseUEltwise,
typename ADataType,
typename BDataType,
typename GridDesc_M0,
typename ElementwiseFunctor>
__global__ void kernel_unary_elementwise_1d(const ADataType* __restrict__ p_a_global,
BDataType* __restrict__ p_b_global,
const GridDesc_M0 a_grid_desc_m0,
const GridDesc_M0 b_grid_desc_m0,
const ElementwiseFunctor functor)
{
GridwiseUEltwise::Run(p_a_global, p_b_global, a_grid_desc_m0, b_grid_desc_m0, functor);
}
template <typename ADataType,
typename BDataType,
typename GridDesc_M0,
typename ElementwiseFunctor,
index_t ScalarPerVector>
struct GridwiseUnaryElementwise_1D
{
static constexpr auto I0 = Number<0>{};
static constexpr auto thread_desc_m0 =
make_naive_tensor_descriptor_packed(make_tuple(Number<ScalarPerVector>{}));
using PassThrough = tensor_operation::element_wise::PassThrough;
static __device__ auto CalculateElementwiseIndex()
{
const index_t global_thread_id = get_thread_global_1d_id();
return make_multi_index(global_thread_id * ScalarPerVector);
}
__host__ __device__ static constexpr bool CheckValidity(const GridDesc_M0 a_grid_desc_m0,
const GridDesc_M0 b_grid_desc_m0)
{
return a_grid_desc_m0.GetLength(I0) == b_grid_desc_m0.GetLength(I0);
}
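// Note (assumption): CalculateGridSize below divides by 256 * ScalarPerVector, i.e. it presumes
// the kernel is launched with 256 threads per block. For example, tensor_size = 1'000'000 and
// ScalarPerVector = 4 give grid_size = ceil(1'000'000 / 1024) = 977 blocks.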
__host__ __device__ static constexpr index_t CalculateGridSize(const index_t tensor_size)
{
const index_t grid_size = math::integer_divide_ceil(tensor_size, 256 * ScalarPerVector);
return grid_size;
}
__device__ static void Run(const ADataType* __restrict__ p_a_global,
BDataType* __restrict__ p_b_global,
const GridDesc_M0 a_grid_desc_m0,
const GridDesc_M0 b_grid_desc_m0,
const ElementwiseFunctor functor)
{
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_global, a_grid_desc_m0.GetElementSpaceSize());
auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_global, b_grid_desc_m0.GetElementSpaceSize());
StaticBuffer<AddressSpaceEnum::Vgpr, ADataType, ScalarPerVector, true> a_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, BDataType, ScalarPerVector, true> b_thread_buf;
const auto thread_store_global_offset = CalculateElementwiseIndex();
auto a_global_load =
ThreadwiseTensorSliceTransfer_v2<ADataType,
ADataType,
GridDesc_M0,
decltype(thread_desc_m0),
Sequence<ScalarPerVector>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
ScalarPerVector,
1, // SrcScalarStrideInVector
false>{a_grid_desc_m0, thread_store_global_offset};
auto b_global_write =
ThreadwiseTensorSliceTransfer_v1r3<BDataType,
BDataType,
decltype(thread_desc_m0),
GridDesc_M0,
PassThrough,
Sequence<ScalarPerVector>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // DstVectorDim
ScalarPerVector,
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
false>{
b_grid_desc_m0, thread_store_global_offset, PassThrough{}};
const index_t blockSize = get_block_size();
const index_t blockPerGrid = get_grid_size();
const auto m0 = b_grid_desc_m0.GetLength(I0);
const index_t loop_step = blockPerGrid * blockSize * ScalarPerVector;
const auto loop_step_index = make_multi_index(loop_step);
index_t num_iter = m0 / (loop_step);
do
{
// read and process ScalarPerVector elements
a_global_load.Run(
a_grid_desc_m0, a_global_buf, thread_desc_m0, make_tuple(I0), a_thread_buf);
static_for<0, ScalarPerVector, 1>{}([&](auto m) {
constexpr auto offset = thread_desc_m0.CalculateOffset(make_tuple(m));
functor(b_thread_buf(Number<offset>{}), a_thread_buf(Number<offset>{}));
});
b_global_write.Run(thread_desc_m0,
make_tuple(I0), // SrcSliceOriginIdx
b_thread_buf,
b_grid_desc_m0,
b_global_buf);
a_global_load.MoveSrcSliceWindow(a_grid_desc_m0, loop_step_index);
b_global_write.MoveDstSliceWindow(b_grid_desc_m0, loop_step_index);
} while(--num_iter);
}
};
} // namespace ck
@@ -39,7 +39,9 @@ template <typename AccDataType,
typename SrcThreadDesc_M_K,
typename DstThreadDesc_M,
typename OpReduce,
-bool PropagateNan>
+bool PropagateNan,
+typename Accumulation =
+detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>>
struct ThreadwiseReduction
{
static constexpr auto src_thread_desc_m_k = SrcThreadDesc_M_K{};
@@ -51,8 +53,6 @@ struct ThreadwiseReduction
static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!");
-using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>;
template <typename SrcBufferType, typename DstBufferType>
__device__ static void Reduce(const SrcBufferType& src_buf, DstBufferType& dst_buf)
{
@@ -73,12 +73,15 @@ struct ThreadwiseReduction
// 2) DstDesc is known at compile-time
// 3) SrcBuffer is static buffer
// 4) DstBuffer is static buffer
-template <typename AccDataType,
-typename IndexDataType,
-typename SrcThreadDesc_M_K,
-typename DstThreadDesc_M,
-typename OpReduce,
-bool PropagateNan>
+template <
+typename AccDataType,
+typename IndexDataType,
+typename SrcThreadDesc_M_K,
+typename DstThreadDesc_M,
+typename OpReduce,
+bool PropagateNan,
+typename Accumulation =
+detail::AccumulateWithIndexAndNanCheck<PropagateNan, OpReduce, AccDataType, IndexDataType>>
struct ThreadwiseReductionWithIndex
{
static constexpr auto src_thread_desc_m_k = SrcThreadDesc_M_K{};
@@ -90,9 +93,6 @@ struct ThreadwiseReductionWithIndex
static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!");
-using Accumulation =
-detail::AccumulateWithIndexAndNanCheck<PropagateNan, OpReduce, AccDataType, IndexDataType>;
template <typename SrcValueBufferType,
typename SrcIndexBufferType,
typename DstValueBufferType,
...
#pragma once
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_space_filling_curve.hpp"
namespace ck {
// Thread-level multi-source, multi-destination tensor slice data movement
// Assume:
// 1. All sources and destinations are DynamicBuffer
// 2. Same VectorDim and ScalarPerVector for all sources and destinations
// 3. DstInMemOps are per destination tensor
// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor
// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor
// 6. Does not need to know src_descs and dst_descs at compile-time
// 7. Does not need to know src_slice_origins and dst_slice_origins at compile-time
//
// Does the following things to avoid the scratch memory issue:
// 1. Use StaticallyIndexedArray or vector_type instead of a C array for the thread buffer
// 2. Pass tensor descriptors by reference (or tuple of references)
// 3. Does not keep a reference to the tensor descriptor
// 4. Does not construct a new tensor coordinate when Run() is called
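// Illustrative instantiation sketch (comment only; every concrete type below is hypothetical):
//   using Transfer = ThreadwiseTensorSliceTransfer_v7<
//       Tuple<half_t, half_t>,                    // SrcDatas: two source tensors
//       Tuple<float>,                             // DstDatas: one destination tensor
//       Tuple<const SrcDesc0&, const SrcDesc1&>,  // SrcDescs
//       Tuple<const DstDesc0&>,                   // DstDescs
//       AddOp,                                    // element_op(dst, src0, src1)
//       Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>, // DstInMemOps
//       Sequence<8>,                              // SliceLengths
//       Sequence<0>,                              // DimAccessOrder
//       0, 4,                                     // VectorDim, ScalarPerVector
//       Sequence<true, true>,                     // SrcResetCoordinateAfterRunFlags
//       Sequence<false>>;                         // DstResetCoordinateAfterRunFlags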
template <typename SrcDatas,
typename DstDatas,
typename SrcDescs,
typename DstDescs,
typename ElementwiseOperation,
typename DstInMemOps, // Sequence<InMemoryDataOperationEnum ...>
typename SliceLengths,
typename DimAccessOrder,
index_t VectorDim,
index_t ScalarPerVector,
typename SrcResetCoordinateAfterRunFlags, // Sequence<bool ...>
typename DstResetCoordinateAfterRunFlags> // Sequence<bool ...>
struct ThreadwiseTensorSliceTransfer_v7
{
static constexpr auto I0 = Number<0>{};
static constexpr index_t nDim = SliceLengths::Size();
static constexpr index_t nSrc = SrcDescs::Size();
static constexpr index_t nDst = DstDescs::Size();
using Index = MultiIndex<nDim>;
// return a tuple of coordinates for a tuple of tensors
template <typename Descs,
typename Indices,
enable_if_t<Descs::Size() == Indices::Size(), bool> = false>
static constexpr auto MakeCoordinates(const Descs& descs, const Indices& indices)
{
return generate_tuple([&](auto i) { return make_tensor_coordinate(descs[i], indices[i]); },
Number<Descs::Size()>{});
}
using SrcCoords = decltype(MakeCoordinates(SrcDescs{}, StaticallyIndexedArray<Index, nSrc>{}));
using DstCoords = decltype(MakeCoordinates(DstDescs{}, StaticallyIndexedArray<Index, nDst>{}));
// scalar per access on each dim
// FIXME: don't use lambda_scalar_per_access
static constexpr auto scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<VectorDim, ScalarPerVector>{}, Number<nDim>{});
using SpaceFillingCurve =
SpaceFillingCurve<SliceLengths, DimAccessOrder, remove_cv_t<decltype(scalar_per_access)>>;
__device__ constexpr ThreadwiseTensorSliceTransfer_v7(
const SrcDescs& src_descs,
const StaticallyIndexedArray<Index, nSrc>& src_slice_origins,
const DstDescs& dst_descs,
const StaticallyIndexedArray<Index, nDst>& dst_slice_origins,
const ElementwiseOperation& element_op)
: src_coords_(MakeCoordinates(src_descs, src_slice_origins)),
dst_coords_(MakeCoordinates(dst_descs, dst_slice_origins)),
element_op_(element_op)
{
static_assert(SliceLengths::At(Number<VectorDim>{}) % ScalarPerVector == 0,
"wrong! cannot evenly divide");
}
template <typename Indices, enable_if_t<SrcDescs::Size() == Indices::Size(), bool> = false>
__device__ void SetSrcSliceOrigins(const SrcDescs& src_descs,
const Indices& src_slice_origin_idxs)
{
static_for<0, nSrc, 1>{}([&](auto i) {
src_coords_(i) = make_tensor_coordinate(src_descs[i], src_slice_origin_idxs[i]);
});
}
template <typename Indices, enable_if_t<DstDescs::Size() == Indices::Size(), bool> = false>
__device__ void SetDstSliceOrigins(const DstDescs& dst_descs,
const Indices& dst_slice_origin_idxs)
{
static_for<0, nDst, 1>{}([&](auto i) {
dst_coords_(i) = make_tensor_coordinate(dst_descs[i], dst_slice_origin_idxs[i]);
});
}
// SrcDescs: Tuple<const SrcDesc0&, const SrcDesc1&, ...>
// SrcBuffers: Tuple<const SrcBuffer0&, const SrcBuffer1&, ...>
// DstDescs: Tuple<const DstDesc0&, const DstDesc1&, ...>
// DstBuffers: Tuple<const DstBuffer0&, const DstBuffer1&, ...>
template <typename SrcBuffers,
typename DstBuffers,
enable_if_t<SrcDescs::Size() == SrcBuffers::Size() &&
DstDescs::Size() == DstBuffers::Size(),
bool> = false>
__device__ void Run(const SrcDescs& src_descs,
const SrcBuffers& src_bufs,
const DstDescs& dst_descs,
DstBuffers dst_bufs)
{
auto generate_vectors = [&](auto data_types) {
constexpr index_t num = data_types.Size();
return generate_tuple(
[&](auto i) {
using DataType = remove_cvref_t<decltype(data_types[i])>;
return vector_type_maker_t<DataType, ScalarPerVector>{};
},
Number<num>{});
};
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
// loop over space-filling curve
static_for<0, num_access, 1>{}([&](auto iAccess) {
auto src_vectors = generate_vectors(SrcDatas{});
auto dst_vectors = generate_vectors(DstDatas{});
// copy data from src_bufs into src_vectors
static_for<0, nSrc, 1>{}([&](auto i) {
using src_vector_t = typename remove_cvref_t<decltype(src_vectors[i])>::type;
const bool is_src_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_descs[i],
src_coords_[i]);
src_vectors(i).template AsType<src_vector_t>()(I0) =
src_bufs[i].template Get<src_vector_t>(src_coords_[i].GetOffset(),
is_src_valid);
});
// apply pointwise function
static_for<0, ScalarPerVector, 1>{}([&](auto i) {
// get reference to src data
const auto src_data_refs = generate_tie(
// return type should be lvalue
[&](auto iSrc) -> const auto& {
using SrcData = remove_cvref_t<tuple_element_t<iSrc.value, SrcDatas>>;
return src_vectors[iSrc].template AsType<SrcData>()[i];
},
Number<nSrc>{});
// get reference to dst data
auto dst_data_refs = generate_tie(
// return type should be lvalue
[&](auto iDst) -> auto& {
using DstData = remove_cvref_t<tuple_element_t<iDst.value, DstDatas>>;
return dst_vectors(iDst).template AsType<DstData>()(i);
},
Number<nDst>{});
// apply pointwise function
// pointwise function signature:
// element_op_(dst_data_refs[I0],
// dst_data_refs[I1],
// ...,
// src_data_refs[I0],
// src_data_refs[I1],
// ...)
unpack2(element_op_, dst_data_refs, src_data_refs);
});
// copy data from buf_vectors into dst_bufs
static_for<0, nDst, 1>{}([&](auto i) {
using dst_vector_t = typename remove_cvref_t<decltype(dst_vectors[i])>::type;
const bool is_dst_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i],
dst_coords_[i]);
constexpr InMemoryDataOperationEnum DstInMemOp =
static_cast<InMemoryDataOperationEnum>(DstInMemOps::At(i.value));
dst_bufs(i).template Update<DstInMemOp, dst_vector_t>(
dst_coords_[i].GetOffset(),
is_dst_valid,
dst_vectors[i].template AsType<dst_vector_t>()[I0]);
});
// move coordinate
if constexpr(iAccess.value != num_access - 1)
{
constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(iAccess);
static_for<0, nSrc, 1>{}([&](auto i) {
move_tensor_coordinate(src_descs[i],
src_coords_(i),
make_tensor_coordinate_step(src_descs[i], forward_step));
});
static_for<0, nDst, 1>{}([&](auto i) {
move_tensor_coordinate(dst_descs[i],
dst_coords_(i),
make_tensor_coordinate_step(dst_descs[i], forward_step));
});
}
});
// move coordinate back to slice origin (or not)
static_for<0, nSrc, 1>{}([&](auto i) {
if constexpr(SrcResetCoordinateAfterRunFlags::At(i))
{
const auto src_reset_step =
make_tensor_coordinate_step(src_descs[i], GetCoordinateResetStep());
move_tensor_coordinate(src_descs[i], src_coords_(i), src_reset_step);
}
});
static_for<0, nDst, 1>{}([&](auto i) {
if constexpr(DstResetCoordinateAfterRunFlags::At(i))
{
const auto dst_reset_step =
make_tensor_coordinate_step(dst_descs[i], GetCoordinateResetStep());
move_tensor_coordinate(dst_descs[i], dst_coords_(i), dst_reset_step);
}
});
}
__device__ static constexpr auto GetCoordinateResetStep()
{
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
if constexpr(num_access == 0)
{
return typename SpaceFillingCurve::Index{};
}
else
{
constexpr auto reset_step =
SpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
return reset_step;
}
}
// src_slice_origin_step_idx needs to be known at compile-time, for performance reasons
template <index_t ISrc>
__device__ void MoveSrcSliceWindow(const SrcDescs& src_descs,
Number<ISrc> iSrc,
const Index& src_slice_origin_step_idx)
{
// if the src coord was not reset by Run(), then the step needs to be adjusted here
const auto adjusted_step_idx = SrcResetCoordinateAfterRunFlags::At(iSrc)
? src_slice_origin_step_idx
: src_slice_origin_step_idx + GetCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(src_descs[iSrc], adjusted_step_idx);
move_tensor_coordinate(src_descs[iSrc], src_coords_(iSrc), adjusted_step);
}
// dst_slice_origin_step_idx needs to be known at compile-time, for performance reasons
template <index_t IDst>
__device__ void MoveDstSliceWindow(const DstDescs& dst_descs,
Number<IDst> iDst,
const Index& dst_slice_origin_step_idx)
{
// if the dst coord was not reset by Run(), then the step needs to be adjusted here
const auto adjusted_step_idx = DstResetCoordinateAfterRunFlags::At(iDst)
? dst_slice_origin_step_idx
: dst_slice_origin_step_idx + GetCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(dst_descs[iDst], adjusted_step_idx);
move_tensor_coordinate(dst_descs[iDst], dst_coords_(iDst), adjusted_step);
}
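// Hedged usage sketch (not part of this file): the step passed in is typically a
// compile-time constant advance along K per main-loop iteration, e.g.
//   transfer.MoveSrcSliceWindow(src_descs, I0, make_multi_index(KPerBlock, 0, 0));
// The names are illustrative; the step's rank must match the descriptor at I0.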
private:
SrcCoords src_coords_;
DstCoords dst_coords_;
const ElementwiseOperation element_op_;
};
} // namespace ck
@@ -6,6 +6,8 @@ namespace ck {
template <typename T>
union BufferResource
{
__device__ constexpr BufferResource() : content{} {}
// 128 bit SGPRs to supply buffer resource in buffer instructions
// https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
int32x4_t content;
...
#pragma once
#include "statically_indexed_array.hpp"
namespace ck {
@@ -1000,6 +1001,11 @@ struct NumericLimits
__host__ __device__ static constexpr T Max() { return std::numeric_limits<T>::max(); }
__host__ __device__ static constexpr T Lowest() { return std::numeric_limits<T>::lowest(); }
__host__ __device__ static constexpr T QuietNaN()
{
return std::numeric_limits<T>::quiet_NaN();
}
};
template <>
@@ -1008,12 +1014,15 @@ struct NumericLimits<half_t>
static constexpr unsigned short binary_min = 0x0400;
static constexpr unsigned short binary_max = 0x7BFF;
static constexpr unsigned short binary_lowest = 0xFBFF;
static constexpr unsigned short binary_qnan = 0x7FFF;
__host__ __device__ static constexpr half_t Min() { return bit_cast<half_t>(binary_min); }
__host__ __device__ static constexpr half_t Max() { return bit_cast<half_t>(binary_max); }
__host__ __device__ static constexpr half_t Lowest() { return bit_cast<half_t>(binary_lowest); }
__host__ __device__ static constexpr half_t QuietNaN() { return bit_cast<half_t>(binary_qnan); }
};
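// Hedged sanity check (not part of the original change): 0x7FFF is a quiet-NaN
// encoding for fp16 -- all-ones exponent bits (0x7C00) plus a non-zero mantissa.
static_assert((NumericLimits<half_t>::binary_qnan & 0x7C00) == 0x7C00 &&
(NumericLimits<half_t>::binary_qnan & 0x03FF) != 0,
"binary_qnan must encode a NaN");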
} // namespace ck
#pragma once
namespace ck {
@@ -10,4 +9,3 @@ template <bool B, typename T = void>
using enable_if_t = typename std::enable_if<B, T>::type;
} // namespace ck
@@ -142,6 +142,22 @@ __host__ __device__ constexpr auto min(X x, Ys... ys)
return min(x, min(ys...));
}
// disallow implicit type casting
template <typename T>
__device__ T exp(T x);
template <>
__device__ float exp<float>(float x)
{
return __expf(x);
}
template <>
__device__ double exp<double>(double x)
{
return ::exp(x); // call the global device exp(double), not ck::exp, to avoid self-recursion
}
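// A minimal standalone sketch (not part of the change) of the same pattern: the
// primary template is declared but never defined, so a call that would need an
// implicit conversion (e.g. int -> float) fails at link time instead of silently
// converting. The name log_example is an assumption for illustration only.
template <typename T>
__device__ T log_example(T x);
template <>
__device__ float log_example<float>(float x)
{
return __logf(x);
}
template <>
__device__ double log_example<double>(double x)
{
return ::log(x);
}
// log_example(1) would deduce T = int and fail to link, because no
// specialization for int is defined.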
// greatest common divisor, aka highest common factor
__host__ __device__ constexpr index_t gcd(index_t x, index_t y)
{
...
@@ -35,9 +35,27 @@
namespace ck {
namespace detail {
// Checks for NaN; guarantees that NaNs are NOT propagated to the result (i.e., NaNs are ignored)
template <typename ReduceOperation, typename AccDataType>
struct AccumulateWithNanIgnore
{
__device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
{
if(!isnan(currVal))
{
ReduceOperation{}(accuVal, currVal);
}
};
};
template <bool PropagateNan, typename ReduceOperation, typename AccDataType>
struct AccumulateWithNanCheck;
// Does not check for NaN; does not guarantee NaNs are propagated to the result
// e.g., given that max(a, b) = a > b ? a : b
// then max(NaN, 1) returns 1
// max(1, NaN) returns NaN
// since any comparison involving NaNs returns false
template <typename ReduceOperation, typename AccDataType>
struct AccumulateWithNanCheck<false, ReduceOperation, AccDataType>
{
@@ -48,6 +66,7 @@ struct AccumulateWithNanCheck<false, ReduceOperation, AccDataType>
};
};
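// Hedged standalone illustration (not part of the change) of the comparison
// semantics described above for the non-propagating case: with
// max(a, b) = a > b ? a : b, a NaN on the left is dropped while a NaN on the
// right is kept, because any comparison involving NaN evaluates to false.
// __builtin_nanf is assumed to be available under the hip/clang toolchain.
inline void nan_comparison_example()
{
const float nan_val = __builtin_nanf("");
auto max_like = [](float a, float b) { return a > b ? a : b; };
const float kept = max_like(nan_val, 1.0f); // 1.0f: (nan > 1) is false
const float kept_nan = max_like(1.0f, nan_val); // NaN:  (1 > nan) is false
(void)kept;
(void)kept_nan;
}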
// Checks for NaN; guarantees NaNs are propagated to the result
template <typename ReduceOperation, typename AccDataType>
struct AccumulateWithNanCheck<true, ReduceOperation, AccDataType>
{
...
@@ -28,6 +28,7 @@
#include "config.hpp"
#include "data_type.hpp"
#include "type.hpp"
namespace ck {
@@ -54,64 +55,92 @@ namespace reduce {
// accumulated index also need be
// changed.
struct Add
{
template <typename T>
__host__ __device__ static constexpr T GetIdentityValue()
{
return type_convert<T>(0.0f);
};
__host__ __device__ static constexpr bool
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
{
return operation == InMemoryDataOperationEnum::AtomicAdd ||
operation == InMemoryDataOperationEnum::Set;
};
template <typename T>
__host__ __device__ inline constexpr void operator()(T& a, T b) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, int32_t>::value,
"The data type is not supported by the Add accumulator!");
a = a + b;
}
};
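// Hedged usage note (not part of the change): GetIdentityValue is now a member
// template rather than relying on a class-level data type, so dependent callers
// spell it with the template keyword, e.g.
//   AccDataType identity = ReduceOperation::template GetIdentityValue<AccDataType>();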
struct Mul
{
template <typename T>
__host__ __device__ static constexpr T GetIdentityValue()
{
return type_convert<T>(1.0f);
};
__host__ __device__ static constexpr bool
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
{
return operation == InMemoryDataOperationEnum::Set;
};
template <typename T>
__host__ __device__ inline constexpr void operator()(T& a, T b) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, int32_t>::value,
"The data type is not supported by the Mul accumulator!");
a = a * b;
}
};
struct Max
{
template <typename T>
__host__ __device__ static constexpr T GetIdentityValue()
{
return NumericLimits<T>::Lowest();
};
__host__ __device__ static constexpr bool
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
{
// ToChange: atomic_max to be added
return operation == InMemoryDataOperationEnum::Set;
};
template <typename T>
__host__ __device__ inline constexpr void operator()(T& a, T b) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"The data type is not supported by the Max accumulator!");
if(a < b)
a = b;
}
template <typename T>
__host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"The data type is not supported by the Max accumulator!");
if(a < b)
{
a = b;
@@ -120,28 +149,41 @@ struct Max
}
};
struct Min
{
template <typename T>
__host__ __device__ static constexpr T GetIdentityValue()
{
return NumericLimits<T>::Max();
};
__host__ __device__ static constexpr bool
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
{
// ToChange: atomic_min to be added
return operation == InMemoryDataOperationEnum::Set;
};
template <typename T>
__host__ __device__ inline constexpr void operator()(T& a, T b) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"The data type is not supported by the Min accumulator!");
if(a > b)
a = b;
}
template <typename T>
__host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"The data type is not supported by the Min accumulator!");
if(a > b)
{
a = b;
@@ -150,28 +192,41 @@ struct Min
}
};
struct AMax
{
template <typename T>
__host__ __device__ static constexpr T GetIdentityValue()
{
return type_convert<T>(0.0f);
};
__host__ __device__ static constexpr bool
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
{
// ToChange: atomic_max to be added
return operation == InMemoryDataOperationEnum::Set;
};
template <typename T>
__host__ __device__ inline constexpr void operator()(T& a, T b) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"The data type is not supported by the AMax accumulator!");
if(a < b)
a = b;
}
template <typename T>
__host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"The data type is not supported by the AMax accumulator!");
if(a < b)
{
a = b;
@@ -181,7 +236,7 @@ struct AMax
};
template <typename T>
constexpr T GetIdentityValueForInMemoryDataOperation(InMemoryDataOperationEnum operation)
{
T result = ck::type_convert<T>(0.0f);
@@ -191,6 +246,44 @@ T GetIdentityValueueForInMemoryDataOperation(InMemoryDataOperationEnum operation
return (result);
};
template <InMemoryDataOperationEnum Operation, typename DataType>
struct InMemoryDataOperatonSupportedOnDataType
{
static constexpr bool value = false;
};
template <typename DataType>
struct InMemoryDataOperatonSupportedOnDataType<InMemoryDataOperationEnum::AtomicAdd, DataType>
{
static constexpr bool value =
is_same<DataType, float>::value || is_same<DataType, double>::value;
};
template <typename DataType>
struct InMemoryDataOperatonSupportedOnDataType<InMemoryDataOperationEnum::AtomicMax, DataType>
{
static constexpr bool value =
is_same<DataType, float>::value || is_same<DataType, double>::value;
};
template <typename DataType>
struct InMemoryDataOperatonSupportedOnDataType<InMemoryDataOperationEnum::Set, DataType>
{
static constexpr bool value =
is_same<DataType, float>::value || is_same<DataType, double>::value ||
is_same<DataType, half_t>::value || is_same<DataType, bhalf_t>::value ||
is_same<DataType, int8_t>::value || is_same<DataType, int32_t>::value;
};
template <typename DataType>
struct InMemoryDataOperatonSupportedOnDataType<InMemoryDataOperationEnum::Add, DataType>
{
static constexpr bool value =
is_same<DataType, float>::value || is_same<DataType, double>::value ||
is_same<DataType, half_t>::value || is_same<DataType, int8_t>::value ||
is_same<DataType, int32_t>::value;
};
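// Hedged usage sketch (not part of the change): the traits above can gate an
// output path at compile time, e.g. an atomically accumulated destination must
// be float or double.
static_assert(InMemoryDataOperatonSupportedOnDataType<InMemoryDataOperationEnum::AtomicAdd,
float>::value,
"AtomicAdd output requires a float or double data type");
static_assert(!InMemoryDataOperatonSupportedOnDataType<InMemoryDataOperationEnum::AtomicAdd,
half_t>::value,
"AtomicAdd is not expected to support half_t");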
}; // end of namespace reduce
} // end of namespace ck
...
#pragma once
#include "integral_constant.hpp"
#include "type.hpp"
@@ -241,7 +240,13 @@ struct arithmetic_sequence_gen
}
};
using type0 = typename sequence_gen<(IEnd - IBegin) / Increment, F>::type;
using type1 = Sequence<>;
static constexpr bool kHasContent =
(Increment > 0 && IBegin < IEnd) || (Increment < 0 && IBegin > IEnd);
using type = typename conditional<kHasContent, type0, type1>::type;
};
// uniform sequence
@@ -882,5 +887,10 @@ __host__ __device__ constexpr bool sequence_all_of(Seq, F f)
return flag;
}
template <typename Sx, typename Sy>
using sequence_merge_t = typename sequence_merge<Sx, Sy>::type;
template <index_t NSize, index_t I>
using uniform_sequence_gen_t = typename uniform_sequence_gen<NSize, I>::type;
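// Hedged illustration (not part of the change), assuming sequence_merge
// concatenates its operands and uniform_sequence_gen repeats a single value:
static_assert(is_same<sequence_merge_t<Sequence<0, 1>, Sequence<2, 3>>, Sequence<0, 1, 2, 3>>::value,
"sequence_merge_t concatenates the two sequences");
static_assert(is_same<uniform_sequence_gen_t<3, 7>, Sequence<7, 7, 7>>::value,
"uniform_sequence_gen_t repeats the value NSize times");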
} // namespace ck
#pragma once
#include "integral_constant.hpp"
#include "sequence.hpp"
@@ -17,14 +16,18 @@ struct TupleElementKey
};
template <typename Key, typename Data>
struct TupleElementKeyData
{
#if 0 // workaround compiler complaint about implicitly-deleted default constructor
__host__ __device__ constexpr TupleElementKeyData() = default;
#else
__host__ __device__ constexpr TupleElementKeyData() : mData{} {}
#endif
template <typename T,
typename enable_if<!is_same<remove_cvref_t<T>, TupleElementKeyData>::value,
bool>::type = false>
__host__ __device__ constexpr TupleElementKeyData(T&& v) : mData(std::forward<T>(v))
{
}
@@ -32,20 +35,21 @@ struct TupleElement
};
template <typename Key, typename Data>
__host__ __device__ constexpr const Data&
get_tuple_element_data(const TupleElementKeyData<Key, Data>& x)
{
return static_cast<const Data&>(x.mData);
}
template <typename Key, typename Data>
__host__ __device__ constexpr Data& get_tuple_element_data(TupleElementKeyData<Key, Data>& x)
{
return x.mData;
}
// TODO: not sure the use of reference is correct
template <typename Key, typename Data>
__host__ __device__ constexpr Data&& get_tuple_element_data(TupleElementKeyData<Key, Data>&& x)
{
return static_cast<Data&&>(x.mData);
}
@@ -54,7 +58,7 @@ template <typename Indices, typename... Xs>
struct TupleImpl;
template <index_t... Is, typename... Xs>
struct TupleImpl<Sequence<Is...>, Xs...> : TupleElementKeyData<TupleElementKey<Is>, Xs>...
{
__host__ __device__ constexpr TupleImpl() = default;
@@ -63,13 +67,13 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>
!is_same<remove_cvref_t<Y>, TupleImpl>::value,
bool>::type = false>
__host__ __device__ constexpr TupleImpl(Y&& y)
: TupleElementKeyData<TupleElementKey<Is>, Xs>(std::forward<Y>(y))...
{
}
template <typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
__host__ __device__ constexpr TupleImpl(Ys&&... ys)
: TupleElementKeyData<TupleElementKey<Is>, Xs>(std::forward<Ys>(ys))...
{
static_assert(sizeof...(Is) == sizeof...(Xs) && sizeof...(Is) == sizeof...(Ys),
"wrong! inconsistent size");
@@ -78,15 +82,15 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>
__host__ __device__ static constexpr index_t Size() { return sizeof...(Xs); }
template <index_t I>
__host__ __device__ constexpr const auto& GetElementDataByKey(TupleElementKey<I>) const
{
return get_tuple_element_data<TupleElementKey<I>>(*this);
}
template <index_t I>
__host__ __device__ constexpr auto& GetElementDataByKey(TupleElementKey<I>)
{
return get_tuple_element_data<TupleElementKey<I>>(*this);
}
};
@@ -121,7 +125,7 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
__host__ __device__ constexpr const auto& At(Number<I>) const
{
static_assert(I < base::Size(), "wrong! out of range");
return base::GetElementDataByKey(detail::TupleElementKey<I>{});
}
// write access
@@ -129,7 +133,7 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
__host__ __device__ constexpr auto& At(Number<I>)
{
static_assert(I < base::Size(), "wrong! out of range");
return base::GetElementDataByKey(detail::TupleElementKey<I>{});
}
// read access
@@ -159,6 +163,31 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
};
template <>
struct Tuple<>
{
__host__ __device__ constexpr Tuple() = default;
__host__ __device__ static constexpr index_t Size() { return 0; }
template <typename T>
__host__ __device__ constexpr auto operator=(const T&)
{
return *this;
}
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
};
template <index_t I, typename TTuple>
struct tuple_element
{
using type = decltype(TTuple{}.At(Number<I>{}));
};
template <index_t I, typename TTuple>
using tuple_element_t = typename tuple_element<I, TTuple>::type;
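// Hedged illustration (not part of the change): tuple_element deduces the type
// returned by Tuple::At, which is reference-qualified, so callers typically strip
// cv/ref qualifiers (as done with remove_cvref_t elsewhere in this change):
static_assert(is_same<remove_cvref_t<tuple_element_t<1, Tuple<float, int>>>, int>::value,
"element 1 of Tuple<float, int> is an int");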
template <typename... Xs>
__host__ __device__ constexpr auto make_tuple(Xs&&... xs)
{
@@ -173,4 +202,3 @@ constexpr Tuple<Args&...> tie(Args&... args) noexcept
}
} // namespace ck
#pragma once
#include "functional4.hpp"
#include "tuple.hpp"
@@ -20,6 +19,17 @@ __host__ __device__ constexpr auto generate_tie(F&& f, Number<N>)
typename arithmetic_sequence_gen<0, N, 1>::type{});
}
// tx and ty are tuples of references; the return type will also be a tuple of references (not rvalues)
template <typename... X, typename... Y>
__host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple<X&...>& tx,
const Tuple<Y&...>& ty)
{
return unpack2(
[&](auto&&... zs) { return Tuple<decltype(zs)...>{std::forward<decltype(zs)>(zs)...}; },
tx,
ty);
}
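// Hedged usage sketch (not part of the change): the concatenated tuple keeps
// reference semantics, so writes through the result reach the original objects.
inline void concat_tuple_of_reference_example()
{
float x = 0.f, y = 1.f;
int z = 0;
auto refs = concat_tuple_of_reference(tie(x, y), tie(z)); // Tuple<float&, float&, int&>
refs.At(Number<2>{}) = 42;                                // z is now 42
}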
namespace detail {
template <typename F, typename X, index_t... Is>
@@ -66,4 +76,3 @@ __host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y,
}
} // namespace ck
@@ -174,15 +174,18 @@ struct ReductionHost
const InDataType* in_data,
float beta,
OutDataType* out_data,
IndexDataType* out_indices,
InElementwiseOperation in_elementwise_op,
AccElementwiseOperation acc_elementwise_op)
{
if constexpr(OutputIndex)
{
RunImpl_with_index(
alpha, in_data, beta, out_data, out_indices, in_elementwise_op, acc_elementwise_op);
}
else
{
RunImpl_no_index(alpha, in_data, beta, out_data, in_elementwise_op, acc_elementwise_op);
};
};
@@ -190,7 +193,9 @@ struct ReductionHost
const InDataType* in_data,
float beta,
OutDataType* out_data,
IndexDataType* out_indices,
InElementwiseOperation in_elementwise_op,
AccElementwiseOperation acc_elementwise_op)
{
using ck::float_equal_one;
using ck::float_equal_zero;
@@ -200,12 +205,10 @@ struct ReductionHost
ReduceOperation,
AccDataType,
IndexDataType>;
if constexpr(NumInvariantDim == 0)
{
AccDataType accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();
IndexDataType accuIndex = 0;
for(std::size_t i = 0; i < reduce_dim_indexes.size(); i++)
@@ -236,7 +239,7 @@ struct ReductionHost
else
{
auto thread_reduce_func = [&](auto invariant_index) {
AccDataType accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();
IndexDataType accuIndex = 0;
auto offset_invariant =
@@ -297,7 +300,12 @@ struct ReductionHost
};
};
void RunImpl_no_index(float alpha,
const InDataType* in_data,
float beta,
OutDataType* out_data,
InElementwiseOperation in_elementwise_op,
AccElementwiseOperation acc_elementwise_op)
{
using ck::float_equal_one;
using ck::float_equal_zero;
@@ -306,12 +314,9 @@ struct ReductionHost
using Accumulation =
ck::detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
if constexpr(NumInvariantDim == 0)
{
AccDataType accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();
for(const auto& reduce_index : reduce_dim_indexes)
{
@@ -338,7 +343,7 @@ struct ReductionHost
else
{
auto thread_reduce_func = [&](auto invariant_index) {
AccDataType accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();
auto offset_invariant =
get_offset_from_index<NumInvariantDim>(invariantStrides, invariant_index);
...
@@ -107,6 +107,11 @@ struct HostTensorDescriptor
return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
}
std::size_t GetOffsetFromMultiIndex(std::vector<std::size_t> iss) const
{
return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
}
friend std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc);
private:
@@ -212,6 +217,54 @@ struct Tensor
Tensor(const HostTensorDescriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpace()) {}
Tensor(const Tensor& other) : mDesc(other.mDesc), mData(other.mData) {}
template <typename F>
void ForEach_impl(F&& f, std::vector<size_t>& idx, size_t rank)
{
if(rank == mDesc.GetNumOfDimension())
{
f(*this, idx);
return;
}
// else
for(size_t i = 0; i < mDesc.GetLengths()[rank]; i++)
{
idx[rank] = i;
ForEach_impl(std::forward<F>(f), idx, rank + 1);
}
}
template <typename F>
void ForEach(F&& f)
{
std::vector<size_t> idx(mDesc.GetNumOfDimension(), 0);
ForEach_impl(std::forward<F>(f), idx, size_t(0));
}
template <typename F>
void ForEach_impl(const F&& f, std::vector<size_t>& idx, size_t rank) const
{
if(rank == mDesc.GetNumOfDimension())
{
f(*this, idx);
return;
}
// else
for(size_t i = 0; i < mDesc.GetLengths()[rank]; i++)
{
idx[rank] = i;
ForEach_impl(std::forward<const F>(f), idx, rank + 1);
}
}
template <typename F>
void ForEach(const F&& f) const
{
std::vector<size_t> idx(mDesc.GetNumOfDimension(), 0);
ForEach_impl(std::forward<const F>(f), idx, size_t(0));
}
template <typename G>
void GenerateTensorValue(G g, std::size_t num_thread = 1)
{
@@ -272,6 +325,16 @@ struct Tensor
return mData[mDesc.GetOffsetFromMultiIndex(is...)];
}
T& operator()(std::vector<std::size_t> idx)
{
return mData[mDesc.GetOffsetFromMultiIndex(idx)];
}
const T& operator()(std::vector<std::size_t> idx) const
{
return mData[mDesc.GetOffsetFromMultiIndex(idx)];
}
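// Hedged usage sketch (not part of the change), assuming the descriptor-based
// constructor above: ForEach visits every multi-index, and the std::vector
// overload of operator() resolves each visited index to its flat offset.
//   Tensor<float> t(HostTensorDescriptor(std::vector<std::size_t>{2, 3}));
//   t.ForEach([](auto& self, auto& idx) { self(idx) = idx[0] * 10 + idx[1]; });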
typename std::vector<T>::iterator begin() { return mData.begin(); }
typename std::vector<T>::iterator end() { return mData.end(); }
@@ -285,7 +348,8 @@ struct Tensor
};
template <typename X>
HostTensorDescriptor::HostTensorDescriptor(const std::vector<X>& lens)
: mLens(lens.begin(), lens.end())
{
this->CalculateStrides();
}
@@ -293,7 +357,7 @@ HostTensorDescriptor::HostTensorDescriptor(const std::vector<X>& lens) : mLens(l
template <typename X, typename Y>
HostTensorDescriptor::HostTensorDescriptor(const std::vector<X>& lens,
const std::vector<Y>& strides)
: mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end())
{
}
...