Unverified Commit 9f8ab221 authored by zjing14's avatar zjing14 Committed by GitHub
Browse files

Merge branch 'develop' into add_int8_wmma_example_instance

parents 755ace59 b4fc4d0b
......@@ -495,6 +495,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
FloatAB,
FloatAB,
FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
......
......@@ -494,6 +494,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
TileMathThreadGroupSize,
ABDataType,
ABDataType,
FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
......
......@@ -139,7 +139,8 @@ __host__ __device__ constexpr auto make_merge_transform_v4_no_carry(const LowLen
}
template <typename GridwiseGemm,
typename FloatAB,
typename FloatA,
typename FloatB,
typename FloatC,
typename AGridDesc_B_K0_M_K1,
typename BGridDesc_B_K0_N_K1,
......@@ -153,8 +154,8 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_xdlops_bwd_weight(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
kernel_gemm_xdlops_bwd_weight(const FloatA* __restrict__ p_a_grid,
const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc,
const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc,
......@@ -181,21 +182,22 @@ __global__ void
c_element_op,
c_block_cluster_adaptor);
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = a_b_k0_m_k1_grid_desc;
ignore = b_b_k0_n_k1_grid_desc;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
ignore = c_block_cluster_adaptor;
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = a_b_k0_m_k1_grid_desc;
ignore = b_b_k0_n_k1_grid_desc;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
ignore = c_block_cluster_adaptor;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
template <index_t BlockSize,
typename FloatAB,
typename FloatA,
typename FloatB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
......@@ -242,7 +244,9 @@ template <index_t BlockSize,
bool ABlockLdsExtraM1Wrw = false,
bool BBlockLdsExtraN1Wrw = false,
index_t NumGemmKPrefetchStage = 1,
PipelineVersion PipelineVer = PipelineVersion::v1>
PipelineVersion PipelineVer = PipelineVersion::v1,
typename ComputeTypeA = FloatA,
typename ComputeTypeB = ComputeTypeA>
struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
{
static constexpr auto I0 = Number<0>{};
......@@ -265,11 +269,16 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
// denorm test fix, required to work around fp16 mfma issue
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction
// when mfma if fixed, remove this section and update
// FloatABAdjusted -> FloatAB throughout this file
// FloatAAdjusted -> ComputeTypeA, FloatBAdjusted -> ComputeTypeB,
// throughout this file
#if CK_WORKAROUND_DENORM_FIX
using FloatABAdjusted = conditional_t<is_same_v<FloatAB, ck::half_t>, ck::bhalf_t, FloatAB>;
using FloatAAdjusted =
conditional_t<is_same_v<ComputeTypeA, ck::half_t>, ck::bhalf_t, ComputeTypeA>;
using FloatBAdjusted =
conditional_t<is_same_v<ComputeTypeB, ck::half_t>, ck::bhalf_t, ComputeTypeB>;
#else
using FloatABAdjusted = FloatAB;
using FloatAAdjusted = ComputeTypeA;
using FloatBAdjusted = ComputeTypeB;
#endif
// M0/M1/M1Padding
......@@ -506,7 +515,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
constexpr auto c_block_size =
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize();
return math::max((a_block_space_size + b_block_space_size) * sizeof(FloatAB),
return math::max((a_block_space_size * sizeof(FloatAAdjusted) +
b_block_space_size * sizeof(FloatBAdjusted)),
c_block_size * sizeof(FloatC));
}
......@@ -610,8 +620,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1));
template <bool HasMainKBlockLoop>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
__device__ static void Run(const FloatA* __restrict__ p_a_grid,
const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
void* __restrict__ p_shared,
const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc,
......@@ -673,8 +683,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
Sequence<1, K0PerBlock, MPerBlock, K1>,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatABAdjusted,
FloatA,
FloatAAdjusted,
decltype(a_b_k0_m_k1_grid_desc),
decltype(a_b_k0_m_k1_block_desc),
ABlockTransferSrcAccessOrder,
......@@ -703,8 +713,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
Sequence<1, K0PerBlock, NPerBlock, K1>,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatABAdjusted,
FloatB,
FloatBAdjusted,
decltype(b_b_k0_n_k1_grid_desc),
decltype(b_b_k0_n_k1_block_desc),
BBlockTransferSrcAccessOrder,
......@@ -733,11 +743,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
// sanity check
constexpr index_t KPack =
math::max(K1, MfmaSelector<FloatABAdjusted, MPerXDL, NPerXDL>::selected_mfma.k_per_blk);
math::max(K1,
MfmaSelector<FloatAAdjusted, MPerXDL, NPerXDL, FloatBAdjusted>::selected_mfma
.k_per_blk);
auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatABAdjusted,
FloatAAdjusted,
FloatBAdjusted,
FloatAcc,
decltype(a_k0_m_k1_block_desc),
decltype(b_k0_n_k1_block_desc),
......@@ -757,10 +770,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatABAdjusted*>(p_shared), a_k0_m_k1_block_desc.GetElementSpaceSize());
static_cast<FloatAAdjusted*>(p_shared), a_k0_m_k1_block_desc.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatABAdjusted*>(p_shared) + a_block_space_size,
static_cast<FloatBAdjusted*>(p_shared) + a_block_space_size,
b_k0_n_k1_block_desc.GetElementSpaceSize());
// gridwise GEMM pipeline
......
......@@ -490,6 +490,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_block_desc_k0_m_k1),
......
......@@ -175,7 +175,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
return math::integer_divide_ceil(N, NPerBlock) * NPerBlock;
}
__host__ static auto CalculateK0(index_t K) { return math::integer_divide_floor(K, K1Value); }
__host__ static auto CalculateK0(index_t K) { return math::integer_divide_ceil(K, K1Value); }
// Argument
struct Problem
......@@ -369,9 +369,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
"Invalid tuning param!");
// check gridwise gemm pipeline
const index_t K0 = problem.K / K1Value;
const auto num_k_loop = K0 / K0PerBlock;
const auto num_k_loop = math::integer_divide_ceil(problem.K0, K0PerBlock);
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
{
return false;
......@@ -426,6 +424,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
using BlockwiseGemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatABAdjusted,
FloatABAdjusted,
FloatAcc,
decltype(a_block_desc_k0_m_k1),
......@@ -571,6 +570,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
FloatABAdjusted,
FloatABAdjusted,
FloatAcc,
decltype(a_block_desc_k0_m_k1),
decltype(b_block_desc_k0_n_k1),
......@@ -945,7 +945,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
}
}();
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)
{
return transform_tensor_descriptor(c_grid_desc_m_n,
make_tuple(make_right_pad_transform(M, MPad - M),
......@@ -1026,8 +1027,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
}
// check gridwise gemm pipeline
const index_t K0 = problem.K / K1;
const auto num_k_loop = K0 / K0PerBlock;
const auto num_k_loop = math::integer_divide_ceil(problem.K0, K0PerBlock);
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
{
......
......@@ -22,13 +22,19 @@ namespace ck {
template <typename GridwiseGemm,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename Block2CTileMap>
typename Block2CTileMap,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_xdlops_v2r4r2_simplified(typename GridwiseGemm::Argument karg,
const Block2CTileMap& b2c_map)
const Block2CTileMap& b2c_map,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
......@@ -37,10 +43,13 @@ __global__ void
__shared__ uint8_t p_shared[shared_size];
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
karg, static_cast<void*>(p_shared), b2c_map);
karg, static_cast<void*>(p_shared), b2c_map, a_element_op, b_element_op, c_element_op);
#else
ignore = karg;
ignore = b2c_map;
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
......@@ -577,7 +586,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
typename Block2CTileMap>
__device__ static void Run(const Argument& karg,
void* __restrict__ p_shared_block,
const Block2CTileMap& block_2_ctile_map)
const Block2CTileMap& block_2_ctile_map,
const AElementwiseOperation a_element_op = AElementwiseOperation{},
const BElementwiseOperation b_element_op = BElementwiseOperation{},
const CElementwiseOperation c_element_op = CElementwiseOperation{})
{
const FloatA* p_a_grid = karg.p_a_grid;
const FloatB* p_b_grid = karg.p_b_grid;
......@@ -590,9 +602,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
const AElementwiseOperation a_element_op = AElementwiseOperation{};
const BElementwiseOperation b_element_op = BElementwiseOperation{};
const CElementwiseOperation c_element_op = CElementwiseOperation{};
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize());
......@@ -761,7 +770,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
ComputeType,
ComputeType, // ComputeType A
ComputeType, // ComputeType B
FloatAcc,
decltype(a_k0_m_k1_block_desc),
decltype(b_k0_n_k1_block_desc),
......
......@@ -451,6 +451,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_block_desc_ak0_m_ak1),
......
......@@ -471,6 +471,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_block_desc_k0_m_k1),
......
......@@ -489,6 +489,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_block_desc_k0_m_k1),
......
......@@ -16,6 +16,36 @@
namespace ck {
template <typename InputGridDesc,
typename InputDataType,
typename OutputGridDesc,
typename OutputDataType,
typename Block2ETileMap,
typename GridwiseTensorRearrangeKernel>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_tensor_rearrange(const InputGridDesc in_grid_desc,
const InputDataType* __restrict__ p_in_global,
const OutputGridDesc out_grid_desc,
OutputDataType* __restrict__ p_out_global,
const Block2ETileMap block_2_tile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx1030__) || defined(__gfx1100__) || \
defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx941__) || defined(__gfx942__))
GridwiseTensorRearrangeKernel::Run(
in_grid_desc, p_in_global, out_grid_desc, p_out_global, block_2_tile_map);
#else
ignore = in_grid_desc;
ignore = p_in_global;
ignore = out_grid_desc;
ignore = p_out_global;
ignore = block_2_tile_map;
#endif
}
template <typename InputGridDesc,
typename InputDataType,
typename OutputGridDesc,
......@@ -25,8 +55,9 @@ template <typename InputGridDesc,
index_t KPerBlock,
typename ThreadClusterLengths,
index_t ScalarPerVector,
InMemoryDataOperationEnum DstInMemOp,
typename Block2ETileMap>
struct GridwiseImageToColumn
struct GridwiseTensorRearrange
{
static constexpr auto I0 = Number<0>{};
......@@ -55,27 +86,27 @@ struct GridwiseImageToColumn
auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc.GetElementSpaceSize());
auto copy_global_to_global = ThreadGroupTensorSliceTransfer_v7<
ThisThreadBlock,
Tuple<InputDataType>,
Tuple<OutputDataType>,
decltype(tie(in_grid_desc)),
decltype(tie(out_grid_desc)),
tensor_operation::element_wise::PassThrough,
Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
Sequence<MPerBlock, KPerBlock>,
ThreadClusterLengths,
Sequence<0, 1>,
Sequence<0, 1>,
I1,
ScalarPerVector,
Sequence<true>,
Sequence<true>>{
in_grid_desc,
make_tuple(make_multi_index(m_block_data_idx_on_grid, k_block_data_idx_on_grid)),
out_grid_desc,
make_tuple(make_multi_index(m_block_data_idx_on_grid, k_block_data_idx_on_grid)),
tensor_operation::element_wise::PassThrough{}};
auto copy_global_to_global =
ThreadGroupTensorSliceTransfer_v7<ThisThreadBlock,
Tuple<InputDataType>,
Tuple<OutputDataType>,
decltype(tie(in_grid_desc)),
decltype(tie(out_grid_desc)),
tensor_operation::element_wise::PassThrough,
Sequence<static_cast<index_t>(DstInMemOp)>,
Sequence<MPerBlock, KPerBlock>,
ThreadClusterLengths,
Sequence<0, 1>,
Sequence<0, 1>,
I1,
ScalarPerVector,
Sequence<true>,
Sequence<true>>{
in_grid_desc,
make_tuple(make_multi_index(m_block_data_idx_on_grid, k_block_data_idx_on_grid)),
out_grid_desc,
make_tuple(make_multi_index(m_block_data_idx_on_grid, k_block_data_idx_on_grid)),
tensor_operation::element_wise::PassThrough{}};
copy_global_to_global.Run(
tie(in_grid_desc), tie(in_global_buf), tie(out_grid_desc), tie(out_global_buf));
......
......@@ -18,9 +18,11 @@ template <typename XDataType,
typename GammaDataType,
typename BetaDataType,
typename YDataType,
typename SaveMeanInvStdDataType,
typename ComputeDataType,
typename YElementwiseOperation,
typename GridDesc_M_K,
typename GridDesc_M,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
......@@ -34,6 +36,7 @@ template <typename XDataType,
index_t BetaSrcVectorSize,
index_t YDstVectorDim,
index_t YDstVectorSize,
index_t SaveMeanInvStdDstVectorSize,
bool SweepOnce>
struct GridwiseNormalizationNaiveVariance_mk_to_mk
{
......@@ -45,6 +48,10 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
(YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0,
"Invalid thread slice sizes and/or save mean and inverse std vector sizes "
"configuration, please check!");
static_assert(XSrcVectorSize == YDstVectorSize);
static_assert(XSrcVectorSize == GammaSrcVectorSize);
static_assert(XSrcVectorSize == BetaSrcVectorSize);
......@@ -66,6 +73,10 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{}));
using ThreadBufferLengths_M = Sequence<MThreadSliceSize>;
static constexpr auto thread_buffer_desc_m =
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{})));
using ThreadReduceDstDesc_M =
......@@ -84,6 +95,8 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
reduce::Add,
true>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
......@@ -98,12 +111,16 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
const GridDesc_M_K& gamma_grid_desc_m_k,
const GridDesc_M_K& beta_grid_desc_m_k,
const GridDesc_M_K& y_grid_desc_m_k,
const GridDesc_M& save_mean_grid_desc_m,
const GridDesc_M& save_inv_std_grid_desc_m,
index_t num_k_block_tile_iteration,
ComputeDataType epsilon,
const XDataType* const __restrict__ p_x_global,
const GammaDataType* const __restrict__ p_gamma_global,
const BetaDataType* const __restrict__ p_beta_global,
YDataType* const __restrict__ p_y_global,
SaveMeanInvStdDataType* const __restrict__ p_save_mean_global,
SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global,
const YElementwiseOperation y_elementwise_op)
{
// LDS
......@@ -115,6 +132,12 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_y_global, y_grid_desc_m_k.GetElementSpaceSize());
auto save_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_save_mean_global, save_mean_grid_desc_m.GetElementSpaceSize());
auto save_inv_std_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_save_inv_std_global, save_inv_std_grid_desc_m.GetElementSpaceSize());
auto x_thread_buf = generate_tuple(
[&](auto) {
return StaticBuffer<AddressSpaceEnum::Vgpr,
......@@ -152,6 +175,8 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
mean_square_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>&
var_thread_buf = mean_square_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>&
inv_std_thread_buf = mean_square_thread_buf;
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
......@@ -228,6 +253,42 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
thread_k_cluster_id * YDstVectorSize),
y_elementwise_op);
auto threadwise_mean_store =
ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
SaveMeanInvStdDataType,
decltype(thread_buffer_desc_m),
GridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
SaveMeanInvStdDstVectorSize, // ScalarPerVector
InMemoryDataOperationEnum::Set,
1,
true>(
save_mean_grid_desc_m,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
auto threadwise_inv_std_store =
ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
SaveMeanInvStdDataType,
decltype(thread_buffer_desc_m),
GridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
SaveMeanInvStdDstVectorSize, // ScalarPerVector
InMemoryDataOperationEnum::Set,
1,
true>(
save_inv_std_grid_desc_m,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize);
constexpr auto thread_copy_bwd_step_m_k =
make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize);
......@@ -243,7 +304,8 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
// E(x), E[x^2], var(x)
// FIXME: Should not hack the transform from deviceOP
int reduce_length = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0];
ComputeDataType reduce_length = type_convert<ComputeDataType>(
x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0]);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
mean_thread_buf(I) = reduce::Add::template GetIdentityValue<ComputeDataType>();
......@@ -302,10 +364,34 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
// var(x) = E[x^2] - E[x]^2
var_thread_buf(I) =
mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I));
inv_std_thread_buf(I) = type_convert<ComputeDataType>(1.0f) /
ck::math::sqrt(var_thread_buf(I) + epsilon);
});
// save mean and inverse std for backward (optional)
if(thread_k_cluster_id == 0)
{
if(p_save_mean_global != nullptr)
{
threadwise_mean_store.Run(thread_buffer_desc_m,
make_tuple(I0),
mean_thread_buf,
save_mean_grid_desc_m,
save_mean_global_val_buf);
}
if(p_save_inv_std_global != nullptr)
{
threadwise_inv_std_store.Run(thread_buffer_desc_m,
make_tuple(I0),
inv_std_thread_buf,
save_inv_std_grid_desc_m,
save_inv_std_global_val_buf);
}
}
// normalization
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
constexpr auto offset_m_k =
......@@ -314,7 +400,7 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
// normalize
y_thread_buf(iK0)(Number<offset_m_k>{}) =
(x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
divisor;
inv_std_thread_buf(iM);
// gamma & beta
y_thread_buf(iK0)(Number<offset_m_k>{}) =
......@@ -404,8 +490,30 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
// var(x) = E[x^2] - E[x]^2
var_thread_buf(I) =
mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I));
inv_std_thread_buf(I) = 1 / ck::math::sqrt(var_thread_buf(I) + epsilon);
});
if(thread_k_cluster_id == 0)
{
if(p_save_mean_global != nullptr)
{
threadwise_mean_store.Run(thread_buffer_desc_m,
make_tuple(I0),
mean_thread_buf,
save_mean_grid_desc_m,
save_mean_global_val_buf);
}
if(p_save_inv_std_global != nullptr)
{
threadwise_inv_std_store.Run(thread_buffer_desc_m,
make_tuple(I0),
inv_std_thread_buf,
save_inv_std_grid_desc_m,
save_inv_std_global_val_buf);
}
}
auto thread_copy_tail_m_k =
(num_k_block_tile_iteration - 1) * ThreadBufferNumber * thread_copy_fwd_step_m_k;
......@@ -437,7 +545,6 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
});
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
constexpr auto offset_m_k =
......@@ -446,7 +553,7 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
// normalize
y_thread_buf(iK0)(Number<offset_m_k>{}) =
(x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
divisor;
inv_std_thread_buf(iM);
// gamma
y_thread_buf(iK0)(Number<offset_m_k>{}) =
......
......@@ -12,31 +12,42 @@ template <typename GridwiseReduction,
typename GammaDataType,
typename BetaDataType,
typename YDataType,
typename SaveMeanInvStdDataType,
typename ComputeDataType,
typename YElementwiseOperation,
typename GridDesc_M_K>
__global__ void kernel_normalization(const GridDesc_M_K x_grid_desc_m_k,
const GridDesc_M_K gamma_grid_desc_m_k,
const GridDesc_M_K beta_grid_desc_m_k,
const GridDesc_M_K y_grid_desc_m_k,
index_t num_k_block_tile_iteration,
ComputeDataType epsilon,
const XDataType* const __restrict__ p_x_global,
const GammaDataType* const __restrict__ p_gamma_global,
const BetaDataType* const __restrict__ p_beta_global,
YDataType* const __restrict__ p_y_global,
const YElementwiseOperation y_elementwise_op)
typename GridDesc_M_K,
typename GridDesc_M>
__global__ void
kernel_normalization(const GridDesc_M_K x_grid_desc_m_k,
const GridDesc_M_K gamma_grid_desc_m_k,
const GridDesc_M_K beta_grid_desc_m_k,
const GridDesc_M_K y_grid_desc_m_k,
const GridDesc_M save_mean_grid_desc_m,
const GridDesc_M save_inv_std_grid_desc_m,
index_t num_k_block_tile_iteration,
ComputeDataType epsilon,
const XDataType* const __restrict__ p_x_global,
const GammaDataType* const __restrict__ p_gamma_global,
const BetaDataType* const __restrict__ p_beta_global,
YDataType* const __restrict__ p_y_global,
SaveMeanInvStdDataType* const __restrict__ p_save_mean_global,
SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global,
const YElementwiseOperation y_elementwise_op)
{
GridwiseReduction::Run(x_grid_desc_m_k,
gamma_grid_desc_m_k,
beta_grid_desc_m_k,
y_grid_desc_m_k,
save_mean_grid_desc_m,
save_inv_std_grid_desc_m,
num_k_block_tile_iteration,
epsilon,
p_x_global,
p_gamma_global,
p_beta_global,
p_y_global,
p_save_mean_global,
p_save_inv_std_global,
y_elementwise_op);
};
......@@ -44,9 +55,11 @@ template <typename XDataType,
typename GammaDataType,
typename BetaDataType,
typename YDataType,
typename SaveMeanInvStdDataType,
typename ComputeDataType,
typename YElementwiseOperation,
typename GridDesc_M_K,
typename GridDesc_M,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
......@@ -60,6 +73,7 @@ template <typename XDataType,
index_t BetaSrcVectorSize,
index_t YDstVectorDim,
index_t YDstVectorSize,
index_t SaveMeanInvStdDstVectorSize,
bool UseWelford>
auto NormalizationKernelSelector(bool isSweepOnce)
{
......@@ -68,9 +82,11 @@ auto NormalizationKernelSelector(bool isSweepOnce)
GammaDataType,
BetaDataType,
YDataType,
SaveMeanInvStdDataType,
ComputeDataType,
YElementwiseOperation,
GridDesc_M_K,
GridDesc_M,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
......@@ -84,15 +100,18 @@ auto NormalizationKernelSelector(bool isSweepOnce)
BetaSrcVectorSize,
YDstVectorDim,
YDstVectorSize,
SaveMeanInvStdDstVectorSize,
false>;
using GridwiseNormalizationSweepOnceNaive =
GridwiseNormalizationNaiveVariance_mk_to_mk<XDataType,
GammaDataType,
BetaDataType,
YDataType,
SaveMeanInvStdDataType,
ComputeDataType,
YElementwiseOperation,
GridDesc_M_K,
GridDesc_M,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
......@@ -106,15 +125,18 @@ auto NormalizationKernelSelector(bool isSweepOnce)
BetaSrcVectorSize,
YDstVectorDim,
YDstVectorSize,
SaveMeanInvStdDstVectorSize,
true>;
using GridwiseNormalizationGenericWelford =
GridwiseNormalizationWelfordVariance_mk_to_mk<XDataType,
GammaDataType,
BetaDataType,
YDataType,
SaveMeanInvStdDataType,
ComputeDataType,
YElementwiseOperation,
GridDesc_M_K,
GridDesc_M,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
......@@ -128,15 +150,18 @@ auto NormalizationKernelSelector(bool isSweepOnce)
BetaSrcVectorSize,
YDstVectorDim,
YDstVectorSize,
SaveMeanInvStdDstVectorSize,
false>;
using GridwiseNormalizationSweepOnceWelford =
GridwiseNormalizationWelfordVariance_mk_to_mk<XDataType,
GammaDataType,
BetaDataType,
YDataType,
SaveMeanInvStdDataType,
ComputeDataType,
YElementwiseOperation,
GridDesc_M_K,
GridDesc_M,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
......@@ -150,6 +175,7 @@ auto NormalizationKernelSelector(bool isSweepOnce)
BetaSrcVectorSize,
YDstVectorDim,
YDstVectorSize,
SaveMeanInvStdDstVectorSize,
true>;
if constexpr(UseWelford)
......@@ -159,17 +185,21 @@ auto NormalizationKernelSelector(bool isSweepOnce)
GammaDataType,
BetaDataType,
YDataType,
SaveMeanInvStdDataType,
ComputeDataType,
YElementwiseOperation,
GridDesc_M_K>
GridDesc_M_K,
GridDesc_M>
: kernel_normalization<GridwiseNormalizationGenericWelford,
XDataType,
GammaDataType,
BetaDataType,
YDataType,
SaveMeanInvStdDataType,
ComputeDataType,
YElementwiseOperation,
GridDesc_M_K>;
GridDesc_M_K,
GridDesc_M>;
}
else
{
......@@ -178,17 +208,21 @@ auto NormalizationKernelSelector(bool isSweepOnce)
GammaDataType,
BetaDataType,
YDataType,
SaveMeanInvStdDataType,
ComputeDataType,
YElementwiseOperation,
GridDesc_M_K>
GridDesc_M_K,
GridDesc_M>
: kernel_normalization<GridwiseNormalizationGenericNaive,
XDataType,
GammaDataType,
BetaDataType,
YDataType,
SaveMeanInvStdDataType,
ComputeDataType,
YElementwiseOperation,
GridDesc_M_K>;
GridDesc_M_K,
GridDesc_M>;
}
}
......
......@@ -17,11 +17,13 @@ template <typename MeanVarDataType,
typename GammaDataType,
typename BetaDataType,
typename YDataType,
typename SaveMeanInvStdDataType,
typename ComputeDataType,
typename YElementwiseOperation,
typename MeanVarGridDesc_M_KBlock,
typename CountGridDesc_M_KBlock,
typename XYGammaBetaGridDesc_M_K,
typename SaveMeanInvStdGridDesc_M,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
......@@ -34,7 +36,8 @@ template <typename MeanVarDataType,
index_t BetaSrcVectorDim,
index_t BetaSrcVectorSize,
index_t YDstVectorDim,
index_t YDstVectorSize>
index_t YDstVectorSize,
index_t SaveMeanInvStdDstVectorSize>
struct GridwiseNormalizationSplitK2nd
{
static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
......@@ -45,6 +48,10 @@ struct GridwiseNormalizationSplitK2nd
(YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0,
"Invalid thread slice sizes and/or save mean and inverse std vector sizes "
"configuration, please check!");
static_assert(XSrcVectorSize == YDstVectorSize);
static_assert(XSrcVectorSize == GammaSrcVectorSize);
static_assert(XSrcVectorSize == BetaSrcVectorSize);
......@@ -69,6 +76,10 @@ struct GridwiseNormalizationSplitK2nd
static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{}));
using ThreadBufferLengths_M = Sequence<MThreadSliceSize>;
static constexpr auto thread_buffer_desc_m =
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
static constexpr auto thread_buffer_desc_m_1 =
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, I1));
......@@ -99,6 +110,8 @@ struct GridwiseNormalizationSplitK2nd
const XYGammaBetaGridDesc_M_K& gamma_grid_desc_m_k,
const XYGammaBetaGridDesc_M_K& beta_grid_desc_m_k,
const XYGammaBetaGridDesc_M_K& y_grid_desc_m_k,
const SaveMeanInvStdGridDesc_M& save_mean_grid_desc_m,
const SaveMeanInvStdGridDesc_M& save_inv_std_grid_desc_m,
index_t num_k_mean_var_count_iteration,
index_t num_k_block_tile_iteration,
index_t k_grid_size,
......@@ -110,6 +123,8 @@ struct GridwiseNormalizationSplitK2nd
const GammaDataType* const __restrict__ p_gamma_global,
const BetaDataType* const __restrict__ p_beta_global,
YDataType* const __restrict__ p_y_global,
SaveMeanInvStdDataType* const __restrict__ p_save_mean_global,
SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global,
const YElementwiseOperation y_elementwise_op)
{
// Thread/Block id
......@@ -145,6 +160,12 @@ struct GridwiseNormalizationSplitK2nd
auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_y_global, y_grid_desc_m_k.GetElementSpaceSize());
auto save_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_save_mean_global, save_mean_grid_desc_m.GetElementSpaceSize());
auto save_inv_std_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_save_inv_std_global, save_inv_std_grid_desc_m.GetElementSpaceSize());
// VGPR
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
in_mean_thread_buf;
......@@ -158,6 +179,7 @@ struct GridwiseNormalizationSplitK2nd
var_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
welford_count_thread_buf;
auto& inv_std_thread_buf = var_thread_buf;
auto x_thread_buf = generate_tuple(
[&](auto) {
......@@ -283,6 +305,42 @@ struct GridwiseNormalizationSplitK2nd
thread_k_cluster_id * YDstVectorSize),
y_elementwise_op);
auto threadwise_mean_store =
ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
SaveMeanInvStdDataType,
decltype(thread_buffer_desc_m),
SaveMeanInvStdGridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
SaveMeanInvStdDstVectorSize, // ScalarPerVector
InMemoryDataOperationEnum::Set,
1,
true>(
save_mean_grid_desc_m,
make_multi_index(block_m_cluster_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
auto threadwise_inv_std_store =
ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
SaveMeanInvStdDataType,
decltype(thread_buffer_desc_m),
SaveMeanInvStdGridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
SaveMeanInvStdDstVectorSize, // ScalarPerVector
InMemoryDataOperationEnum::Set,
1,
true>(
save_inv_std_grid_desc_m,
make_multi_index(block_m_cluster_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
// step1: Merge mean and variance
constexpr auto mean_var_count_thread_copy_step_I0_k =
make_multi_index(I0, KThreadClusterSize);
......@@ -332,9 +390,33 @@ struct GridwiseNormalizationSplitK2nd
BlockwiseWelford::Run(
mean_thread_buf(I), var_thread_buf(I), welford_count_thread_buf(I));
inv_std_thread_buf(I) =
type_convert<ComputeDataType>(1.0f) / ck::math::sqrt(var_thread_buf(I) + epsilon);
});
// step2: normalization
// step2: save mean and inverse std for backward (optional)
if(block_k_cluster_id == 0 && thread_k_cluster_id == 0)
{
if(p_save_mean_global != nullptr)
{
threadwise_mean_store.Run(thread_buffer_desc_m,
make_tuple(I0),
mean_thread_buf,
save_mean_grid_desc_m,
save_mean_global_val_buf);
}
if(p_save_inv_std_global != nullptr)
{
threadwise_inv_std_store.Run(thread_buffer_desc_m,
make_tuple(I0),
inv_std_thread_buf,
save_inv_std_grid_desc_m,
save_inv_std_global_val_buf);
}
}
// step3: normalization
constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize);
for(index_t k = 0; k < num_k_block_tile_iteration; ++k)
......@@ -360,7 +442,6 @@ struct GridwiseNormalizationSplitK2nd
});
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
constexpr auto offset_m_k =
......@@ -369,7 +450,7 @@ struct GridwiseNormalizationSplitK2nd
// normalize
y_thread_buf(iK0)(Number<offset_m_k>{}) =
(x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
divisor;
inv_std_thread_buf(iM);
// gamma
y_thread_buf(iK0)(Number<offset_m_k>{}) =
......
......@@ -16,9 +16,11 @@ template <typename XDataType,
typename GammaDataType,
typename BetaDataType,
typename YDataType,
typename SaveMeanInvStdDataType,
typename ComputeDataType,
typename YElementwiseOperation,
typename GridDesc_M_K,
typename GridDesc_M,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
......@@ -32,6 +34,7 @@ template <typename XDataType,
index_t BetaSrcVectorSize,
index_t YDstVectorDim,
index_t YDstVectorSize,
index_t SaveMeanInvStdDstVectorSize,
bool SweepOnce>
struct GridwiseNormalizationWelfordVariance_mk_to_mk
{
......@@ -43,6 +46,10 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
(YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0,
"Invalid thread slice sizes and/or save mean and inverse std vector sizes "
"configuration, please check!");
static_assert(XSrcVectorSize == YDstVectorSize);
static_assert(XSrcVectorSize == GammaSrcVectorSize);
static_assert(XSrcVectorSize == BetaSrcVectorSize);
......@@ -64,6 +71,10 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{}));
using ThreadBufferLengths_M = Sequence<MThreadSliceSize>;
static constexpr auto thread_buffer_desc_m =
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{})));
using ThreadReduceDstDesc_M =
......@@ -77,6 +88,8 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
......@@ -114,17 +127,18 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
const GridDesc_M_K& gamma_grid_desc_m_k,
const GridDesc_M_K& beta_grid_desc_m_k,
const GridDesc_M_K& y_grid_desc_m_k,
const GridDesc_M& save_mean_grid_desc_m,
const GridDesc_M& save_inv_std_grid_desc_m,
index_t num_k_block_tile_iteration,
ComputeDataType epsilon,
const XDataType* const __restrict__ p_x_global,
const GammaDataType* const __restrict__ p_gamma_global,
const BetaDataType* const __restrict__ p_beta_global,
YDataType* const __restrict__ p_y_global,
SaveMeanInvStdDataType* const __restrict__ p_save_mean_global,
SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global,
const YElementwiseOperation y_elementwise_op)
{
auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_y_global, y_grid_desc_m_k.GetElementSpaceSize());
auto x_thread_buf = generate_tuple(
[&](auto) {
return StaticBuffer<AddressSpaceEnum::Vgpr,
......@@ -150,6 +164,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
var_thread_buf;
auto& inv_std_thread_buf = var_thread_buf;
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
......@@ -226,6 +241,42 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
thread_k_cluster_id * YDstVectorSize),
y_elementwise_op);
auto threadwise_mean_store =
ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
SaveMeanInvStdDataType,
decltype(thread_buffer_desc_m),
GridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
SaveMeanInvStdDstVectorSize, // ScalarPerVector
InMemoryDataOperationEnum::Set,
1,
true>(
save_mean_grid_desc_m,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
auto threadwise_inv_std_store =
ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
SaveMeanInvStdDataType,
decltype(thread_buffer_desc_m),
GridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
SaveMeanInvStdDstVectorSize, // ScalarPerVector
InMemoryDataOperationEnum::Set,
1,
true>(
save_inv_std_grid_desc_m,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize);
constexpr auto thread_copy_bwd_step_m_k =
make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize);
......@@ -239,6 +290,15 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize());
auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_y_global, y_grid_desc_m_k.GetElementSpaceSize());
auto save_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_save_mean_global, save_mean_grid_desc_m.GetElementSpaceSize());
auto save_inv_std_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_save_inv_std_global, save_inv_std_grid_desc_m.GetElementSpaceSize());
auto threadwise_welford = ThreadwiseWelford();
threadwise_welford.max_count_ = GetKPerThread(x_grid_desc_m_k, thread_k_cluster_id);
......@@ -279,10 +339,33 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
int count = threadwise_welford.cur_count_;
BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count);
inv_std_thread_buf(I) = type_convert<ComputeDataType>(1.0f) /
ck::math::sqrt(var_thread_buf(I) + epsilon);
});
// save mean and inverse std for backward (optional)
if(thread_k_cluster_id == 0)
{
if(p_save_mean_global != nullptr)
{
threadwise_mean_store.Run(thread_buffer_desc_m,
make_tuple(I0),
mean_thread_buf,
save_mean_grid_desc_m,
save_mean_global_val_buf);
}
if(p_save_inv_std_global != nullptr)
{
threadwise_inv_std_store.Run(thread_buffer_desc_m,
make_tuple(I0),
inv_std_thread_buf,
save_inv_std_grid_desc_m,
save_inv_std_global_val_buf);
}
}
// normalization
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
constexpr auto offset_m_k =
......@@ -291,7 +374,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
// normalize
y_thread_buf(iK0)(Number<offset_m_k>{}) =
(x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
divisor;
inv_std_thread_buf(iM);
// gamma & beta
y_thread_buf(iK0)(Number<offset_m_k>{}) =
......@@ -360,8 +443,29 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
int count = threadwise_welford.cur_count_;
BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count);
inv_std_thread_buf(I) = 1 / ck::math::sqrt(var_thread_buf(I) + epsilon);
});
if(thread_k_cluster_id == 0)
{
if(p_save_mean_global != nullptr)
{
threadwise_mean_store.Run(thread_buffer_desc_m,
make_tuple(I0),
mean_thread_buf,
save_mean_grid_desc_m,
save_mean_global_val_buf);
}
if(p_save_inv_std_global != nullptr)
{
threadwise_inv_std_store.Run(thread_buffer_desc_m,
make_tuple(I0),
inv_std_thread_buf,
save_inv_std_grid_desc_m,
save_inv_std_global_val_buf);
}
}
auto thread_copy_tail_m_k =
(num_k_block_tile_iteration - 1) * ThreadBufferNumber * thread_copy_fwd_step_m_k;
......@@ -393,7 +497,6 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
});
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
constexpr auto offset_m_k =
......@@ -402,7 +505,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
// normalize
y_thread_buf(iK0)(Number<offset_m_k>{}) =
(x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
divisor;
inv_std_thread_buf(iM);
// gamma
y_thread_buf(iK0)(Number<offset_m_k>{}) =
......
......@@ -9,6 +9,7 @@
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor/static_tensor.hpp"
#include "ck/utility/is_detected.hpp"
namespace ck {
......@@ -211,10 +212,44 @@ struct ThreadwiseTensorSliceTransfer_v3r1
auto src_vector_container = src_vector_type{
src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid)};
using dst_vector_type = vector_type_maker_t<DstData, SrcScalarPerVector>;
using dst_vector_t = typename dst_vector_type::type;
dst_vector_type op_r_v;
constexpr auto get_elem_op_vec_len = []() {
if constexpr(is_detected<is_pack8_invocable_t, decltype(src_element_op_)>::value)
{
if constexpr(decltype(src_element_op_)::is_pack8_invocable)
return math::min(8, SrcScalarPerVector);
}
if constexpr(is_detected<is_pack4_invocable_t, decltype(src_element_op_)>::value)
{
if constexpr(decltype(src_element_op_)::is_pack4_invocable)
return math::min(4, SrcScalarPerVector);
}
if constexpr(is_detected<is_pack2_invocable_t, decltype(src_element_op_)>::value)
{
if constexpr(decltype(src_element_op_)::is_pack2_invocable)
return math::min(2, SrcScalarPerVector);
}
return 1;
};
constexpr index_t elem_op_vec_len = get_elem_op_vec_len();
using src_elem_op_vec_t = typename vector_type<SrcData, elem_op_vec_len>::type;
using dst_elem_op_vec_t = typename vector_type<DstData, elem_op_vec_len>::type;
static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto idx) {
// apply the src elementwise op and convert to DstData under the hood if needed
src_element_op_(op_r_v.template AsType<dst_elem_op_vec_t>()(idx),
src_vector_container.template AsType<src_elem_op_vec_t>()[idx]);
});
// copy data from src_vector_container into src_thread_scratch_
src_thread_scratch_tuple_(thread_scratch_id)
.template SetAsType<src_vector_t>(
src_data_idx_seq, src_vector_container.template AsType<src_vector_t>()[I0]);
.template SetAsType<dst_vector_t>(src_data_idx_seq,
op_r_v.template AsType<dst_vector_t>()[I0]);
constexpr auto move_on_dim = [&]() constexpr
{
......@@ -267,19 +302,15 @@ struct ThreadwiseTensorSliceTransfer_v3r1
{
#if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE
static_ford<SliceLengths>{}([&](auto idx) {
// convert from SrcData to DstData here
dst_thread_scratch_(idx) =
type_convert<DstData>(src_thread_scratch_tuple_[thread_scratch_id][idx]);
dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx];
});
#else
// sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_
// TODO make this logic more generic for more sub-dword datatype
if constexpr(SrcVectorDim != DstVectorDim &&
((is_same<half_t, remove_cvref_t<SrcData>>::value &&
is_same<half_t, remove_cvref_t<DstData>>::value &&
((is_same<half_t, remove_cvref_t<DstData>>::value &&
SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) ||
(is_same<int8_t, remove_cvref_t<SrcData>>::value &&
is_same<int8_t, remove_cvref_t<DstData>>::value &&
(is_same<int8_t, remove_cvref_t<DstData>>::value &&
SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
{
// each transpose does
......@@ -313,7 +344,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
constexpr auto data_idx_seq = generate_sequence_v2(
[&](auto i) { return Number<data_idx[i]>{}; }, Number<nDim>{});
using src_vector_t = vector_type_maker_t<SrcData, SrcScalarPerVector>;
using src_vector_t = vector_type_maker_t<DstData, SrcScalarPerVector>;
using dst_vector_t = vector_type_maker_t<DstData, DstScalarPerVector>;
// get DstScalarPerVector # of read-only references to src vectors from
......@@ -336,17 +367,16 @@ struct ThreadwiseTensorSliceTransfer_v3r1
Number<num_dst_vector>{});
// do data transpose
transpose_vectors<SrcData, DstScalarPerVector, SrcScalarPerVector>{}(
transpose_vectors<DstData, DstScalarPerVector, SrcScalarPerVector>{}(
src_vector_refs, dst_vector_refs);
});
}
static_ford<SliceLengths>{}([&](auto idx) {
// apply the src elementwise op and convert to DstData under the hood if needed
DstData dst_v;
src_element_op_(dst_v, src_thread_scratch_tuple_[thread_scratch_id][idx]);
dst_thread_scratch_(idx) = dst_v;
});
else
{
static_ford<SliceLengths>{}([&](auto idx) {
dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx];
});
}
#endif
}
......@@ -761,11 +791,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1
static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){};
static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){};
using SrcThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
SrcData,
SrcScalarPerVector,
decltype(src_thread_scratch_desc_),
true>;
using SrcThreadScratch =
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
DstData, // apply data_convert with SrcThreadScratch
SrcScalarPerVector,
decltype(src_thread_scratch_desc_),
true>;
using DstThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
DstData,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_space_filling_curve.hpp"
#include "ck/utility/is_detected.hpp"
namespace ck {
// Thread-level multi-source, multi-destination tensor slice data movement
// Assume:
// 1. All sources and destinations are DynamicBuffer
// 2. Same VectorDim and ScalerPerVector for all sources and destinations
// 3. DstInMemOps are per destination tensor
// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor
// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor
// 6. Does not need to know src_descs and dst_descs at compile-time
// 7. Does not need to know src_slice_origins and dst_slice_origins at compile-time,
//
// Does following things to avoid scratch memory issue
// 1. Use StaticallyIndexedArray or vector_type instead of C array for thread buffer
// 2. Pass tensor descritpors by reference (or tuple of references)
// 3. Does not keep reference to tensor descriptor
// 4. Does not construct new tensor coordinate when call Run()
template <typename SrcDatas,
typename DstDatas,
typename SrcDescs,
typename DstDescs,
typename ElementwiseOperation,
typename DstInMemOps, // Sequence<InMemoryDataOperationEnum ...>
typename SliceLengths,
typename SrcDimAccessOrder,
typename DstDimAccessOrder,
index_t SrcVectorDim,
index_t DstVectorDim,
index_t SrcScalarPerVector,
index_t DstScalarPerVector,
typename SrcResetCoordinateAfterRunFlags, // Sequence<bool ...>
typename DstResetCoordinateAfterRunFlags> // Sequence<bool ...>
struct ThreadwiseTensorSliceTransfer_v7r2
{
static constexpr auto I0 = Number<0>{};
static constexpr index_t nDim = SliceLengths::Size();
static constexpr index_t nSrc = SrcDescs::Size();
static constexpr index_t nDst = DstDescs::Size();
using Index = MultiIndex<nDim>;
// return a tuple of coordiantes for a tuple of tensor
template <typename Descs,
typename Indices,
enable_if_t<Descs::Size() == Indices::Size(), bool> = false>
static constexpr auto MakeCoordinates(const Descs& descs, const Indices& indices)
{
return generate_tuple([&](auto i) { return make_tensor_coordinate(descs[i], indices[i]); },
Number<Descs::Size()>{});
}
using SrcCoords = decltype(MakeCoordinates(SrcDescs{}, StaticallyIndexedArray<Index, nSrc>{}));
using DstCoords = decltype(MakeCoordinates(DstDescs{}, StaticallyIndexedArray<Index, nDst>{}));
// scalar per access on each dim
// FIXME: don't use lambda_scalar_per_access
static constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
using SrcSpaceFillingCurve = SpaceFillingCurve<SliceLengths,
SrcDimAccessOrder,
remove_cv_t<decltype(src_scalar_per_access)>>;
static constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
using DstSpaceFillingCurve = SpaceFillingCurve<SliceLengths,
DstDimAccessOrder,
remove_cv_t<decltype(dst_scalar_per_access)>>;
__device__ constexpr ThreadwiseTensorSliceTransfer_v7r2(
const SrcDescs& src_descs,
const StaticallyIndexedArray<Index, nSrc>& src_slice_origins,
const DstDescs& dst_descs,
const StaticallyIndexedArray<Index, nDst>& dst_slice_origins,
const ElementwiseOperation& element_op)
: src_coords_(MakeCoordinates(src_descs, src_slice_origins)),
dst_coords_(MakeCoordinates(dst_descs, dst_slice_origins)),
element_op_(element_op)
{
static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
"wrong! cannot evenly divide");
static_assert(SliceLengths::At(Number<DstVectorDim>{}) % DstScalarPerVector == 0,
"wrong! cannot evenly divide");
}
template <typename Indices, enable_if_t<SrcDescs::Size() == Indices::Size(), bool> = false>
__device__ void SetSrcSliceOrigins(const SrcDescs& src_descs,
const Indices& src_slice_origin_idxs)
{
static_for<0, nSrc, 1>{}([&](auto i) {
src_coords_(i) = make_tensor_coordinate(src_descs[i], src_slice_origin_idxs[i]);
});
}
template <typename Indices, enable_if_t<DstDescs::Size() == Indices::Size(), bool> = false>
__device__ void SetDstSliceOrigins(const DstDescs& dst_descs,
const Indices& dst_slice_origin_idxs)
{
static_for<0, nDst, 1>{}([&](auto i) {
dst_coords_(i) = make_tensor_coordinate(dst_descs[i], dst_slice_origin_idxs[i]);
});
}
template <typename DataTypes, index_t ScalarPerVector>
__device__ static auto generate_vectors()
{
auto data_types = DataTypes{};
constexpr index_t num = data_types.Size();
return generate_tuple(
[&](auto i) {
using DataType = remove_cvref_t<decltype(data_types[i])>;
return vector_type_maker_t<DataType, ScalarPerVector>{};
},
Number<num>{});
}
// SrcDescs: Tuple<const SrcDesc0&, const SrcDesc1&, ...>
// SrcBuffers: Tuple<const SrcBuffer0&, const SrcBuffer1&, ...>
template <typename SrcBuffers,
enable_if_t<SrcDescs::Size() == SrcBuffers::Size(), bool> = false>
__device__ void RunRead(const SrcDescs& src_descs, const SrcBuffers& src_bufs)
{
// loop over space-filling curve
static_for<0, num_access, 1>{}([&](auto iAccess) {
auto src_vectors = generate_vectors<SrcDatas, SrcScalarPerVector>();
auto dst_vectors = generate_vectors<DstDatas, DstScalarPerVector>();
// copy data from src_bufs into src_vectors
static_for<0, nSrc, 1>{}([&](auto i) {
using src_vector_t = typename remove_cvref_t<decltype(src_vectors[i])>::type;
const bool is_src_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_descs[i],
src_coords_[i]);
src_vectors(i).template AsType<src_vector_t>()(I0) =
src_bufs[i].template Get<src_vector_t>(src_coords_[i].GetOffset(),
is_src_valid);
});
constexpr auto get_elem_op_vec_len = []() {
if constexpr(is_detected<is_pack8_invocable_t, decltype(element_op_)>::value)
{
if constexpr(decltype(element_op_)::is_pack8_invocable)
return math::min(8, SrcScalarPerVector);
}
if constexpr(is_detected<is_pack4_invocable_t, decltype(element_op_)>::value)
{
if constexpr(decltype(element_op_)::is_pack4_invocable)
return math::min(4, SrcScalarPerVector);
}
if constexpr(is_detected<is_pack2_invocable_t, decltype(element_op_)>::value)
{
if constexpr(decltype(element_op_)::is_pack2_invocable)
return math::min(2, SrcScalarPerVector);
}
return 1;
};
constexpr index_t elem_op_vec_len = get_elem_op_vec_len();
// apply pointwise function
static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto i) {
// get reference to src data
const auto src_data_refs = generate_tie(
// return type should be lvalue
[&](auto iSrc) -> const auto& {
using SrcData = remove_cvref_t<tuple_element_t<iSrc.value, SrcDatas>>;
using elem_op_vec_t = typename vector_type<SrcData, elem_op_vec_len>::type;
return src_vectors[iSrc].template AsType<elem_op_vec_t>()[i];
},
Number<nSrc>{});
// get reference to dst data
auto dst_data_refs = generate_tie(
// return type should be lvalue
[&](auto iDst) -> auto& {
using DstData = remove_cvref_t<tuple_element_t<iDst.value, DstDatas>>;
using elem_op_vec_t = typename vector_type<DstData, elem_op_vec_len>::type;
return dst_vectors(iDst).template AsType<elem_op_vec_t>()(i);
},
Number<nDst>{});
// apply pointwise function
// pointwise function signature:
// element_op_(dst_data_refs[I0],
// dst_data_refs[I1],
// ...,
// src_data_refs[I0],
// src_data_refs[I1],
// ...)
unpack2(element_op_, dst_data_refs, src_data_refs);
});
dst_vectors_tuple_(iAccess) = dst_vectors;
// move coordinate
if constexpr(iAccess.value != num_access - 1)
{
constexpr auto forward_step = SrcSpaceFillingCurve::GetForwardStep(iAccess);
static_for<0, nSrc, 1>{}([&](auto i) {
move_tensor_coordinate(src_descs[i],
src_coords_(i),
make_tensor_coordinate_step(src_descs[i], forward_step));
});
}
});
// move coordinate back to slice origin (or not)
static_for<0, nSrc, 1>{}([&](auto i) {
if constexpr(SrcResetCoordinateAfterRunFlags::At(i))
{
const auto src_reset_step =
make_tensor_coordinate_step(src_descs[i], GetSrcCoordinateResetStep());
move_tensor_coordinate(src_descs[i], src_coords_(i), src_reset_step);
}
});
}
// DstDescs: Tuple<const DstDesc0&, const DstDesc1&, ...>
// DstBuffers: Tuple<const DstBuffer0&, const DstBuffer1&, ...>
template <typename DstBuffers,
enable_if_t<DstDescs::Size() == DstBuffers::Size(), bool> = false>
__device__ void RunWrite(const DstDescs& dst_descs, DstBuffers dst_bufs)
{
// loop over space-filling curve
static_for<0, num_access, 1>{}([&](auto iAccess) {
auto dst_vectors = dst_vectors_tuple_[iAccess];
// copy data from buf_vectors into dst_bufs
static_for<0, nDst, 1>{}([&](auto i) {
using dst_vector_t = typename remove_cvref_t<decltype(dst_vectors[i])>::type;
const bool is_dst_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i],
dst_coords_[i]);
constexpr InMemoryDataOperationEnum DstInMemOp =
static_cast<InMemoryDataOperationEnum>(DstInMemOps::At(i.value));
dst_bufs(i).template Update<DstInMemOp, dst_vector_t>(
dst_coords_[i].GetOffset(),
is_dst_valid,
dst_vectors[i].template AsType<dst_vector_t>()[I0]);
});
// move coordinate
if constexpr(iAccess.value != num_access - 1)
{
constexpr auto forward_step = DstSpaceFillingCurve::GetForwardStep(iAccess);
static_for<0, nDst, 1>{}([&](auto i) {
move_tensor_coordinate(dst_descs[i],
dst_coords_(i),
make_tensor_coordinate_step(dst_descs[i], forward_step));
});
}
});
static_for<0, nDst, 1>{}([&](auto i) {
if constexpr(DstResetCoordinateAfterRunFlags::At(i))
{
const auto dst_reset_step =
make_tensor_coordinate_step(dst_descs[i], GetDstCoordinateResetStep());
move_tensor_coordinate(dst_descs[i], dst_coords_(i), dst_reset_step);
}
});
}
// SrcDescs: Tuple<const SrcDesc0&, const SrcDesc1&, ...>
// SrcBuffers: Tuple<const SrcBuffer0&, const SrcBuffer1&, ...>
// DstDescs: Tuple<const DstDesc0&, const DstDesc1&, ...>
// DstBuffers: Tuple<const DstBuffer0&, const DstBuffer1&, ...>
template <typename SrcBuffers,
typename DstBuffers,
enable_if_t<SrcDescs::Size() == SrcBuffers::Size() &&
DstDescs::Size() == DstBuffers::Size(),
bool> = false>
__device__ void Run(const SrcDescs& src_descs,
const SrcBuffers& src_bufs,
const DstDescs& dst_descs,
DstBuffers dst_bufs)
{
RunRead(src_descs, src_bufs);
RunWrite(dst_descs, dst_bufs);
}
__device__ static constexpr auto GetSrcCoordinateResetStep()
{
if constexpr(num_access == 0)
{
return typename SrcSpaceFillingCurve::Index{};
}
else
{
return SrcSpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
}
}
__device__ static constexpr auto GetDstCoordinateResetStep()
{
if constexpr(num_access == 0)
{
return typename DstSpaceFillingCurve::Index{};
}
else
{
return DstSpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
}
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
template <index_t ISrc>
__device__ void MoveSrcSliceWindow(const SrcDescs& src_descs,
Number<ISrc> iSrc,
const Index& src_slice_origin_step_idx)
{
// if src coord was not reset by RunRead(), then need to adjust the step here
const auto adjusted_step_idx =
SrcResetCoordinateAfterRunFlags::At(iSrc)
? src_slice_origin_step_idx
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(src_descs[iSrc], adjusted_step_idx);
move_tensor_coordinate(src_descs[iSrc], src_coords_(iSrc), adjusted_step);
}
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
template <index_t IDst>
__device__ void MoveDstSliceWindow(const DstDescs& dst_descs,
Number<IDst> iDst,
const Index& dst_slice_origin_step_idx)
{
// if dst coord was not reset by Run(), then need to adjust the step here
const auto adjusted_step_idx =
DstResetCoordinateAfterRunFlags::At(iDst)
? dst_slice_origin_step_idx
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(dst_descs[iDst], adjusted_step_idx);
move_tensor_coordinate(dst_descs[iDst], dst_coords_(iDst), adjusted_step);
}
private:
using SrcVectorsType = decltype(generate_vectors<SrcDatas, SrcScalarPerVector>());
using DstVectorsType = decltype(generate_vectors<DstDatas, DstScalarPerVector>());
static constexpr auto num_access = SrcSpaceFillingCurve::GetNumOfAccess();
StaticallyIndexedArray<DstVectorsType, num_access> dst_vectors_tuple_;
SrcCoords src_coords_;
DstCoords dst_coords_;
const ElementwiseOperation element_op_;
};
} // namespace ck
......@@ -31,7 +31,13 @@ enum struct MfmaInstr
mfma_i32_16x16x32i8,
mfma_f64_16x16x4f64,
mfma_f32_32x32x16f8f8,
mfma_f32_16x16x32f8f8
mfma_f32_16x16x32f8f8,
mfma_f32_32x32x16bf8bf8,
mfma_f32_16x16x32bf8bf8,
mfma_f32_32x32x16f8bf8,
mfma_f32_16x16x32f8bf8,
mfma_f32_32x32x16bf8f8,
mfma_f32_16x16x32bf8f8
};
template <MfmaInstr instr>
......@@ -456,7 +462,6 @@ struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
}
};
#if defined CK_ENABLE_FP8
template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16f8f8>
{
......@@ -500,12 +505,149 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8>
intrin_mfma_f32_16x16x32f8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
#endif
template <typename base_type, index_t MPerXdlops, index_t NPerXdlops>
template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16bf8bf8>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32;
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_32x32x16bf8bf8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
template <>
struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf8bf8>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16;
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_16x16x32bf8bf8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16f8bf8>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32;
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_32x32x16f8bf8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
template <>
struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8bf8>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16;
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_16x16x32f8bf8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16bf8f8>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32;
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_32x32x16bf8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
template <>
struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf8f8>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16;
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_16x16x32bf8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
template <typename base_type,
index_t MPerXdlops,
index_t NPerXdlops,
typename additional_type = base_type>
struct MfmaSelector
{
template <typename base_type_, index_t MPerXdlops_, index_t NPerXdlops_>
template <typename base_type_,
index_t MPerXdlops_,
index_t NPerXdlops_,
typename additional_type_ = base_type_>
static constexpr auto GetMfma();
template <>
......@@ -642,7 +784,6 @@ struct MfmaSelector
}
#endif
#if defined CK_ENABLE_FP8
template <>
static constexpr auto GetMfma<f8_t, 32, 32>()
{
......@@ -654,9 +795,45 @@ struct MfmaSelector
{
return MfmaInstr::mfma_f32_16x16x32f8f8;
}
#endif
static constexpr auto selected_mfma = mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops>()>{};
template <>
static constexpr auto GetMfma<bf8_t, 32, 32>()
{
return MfmaInstr::mfma_f32_32x32x16bf8bf8;
}
template <>
static constexpr auto GetMfma<bf8_t, 16, 16>()
{
return MfmaInstr::mfma_f32_16x16x32bf8bf8;
}
template <>
static constexpr auto GetMfma<f8_t, 32, 32, bf8_t>()
{
return MfmaInstr::mfma_f32_32x32x16f8bf8;
}
template <>
static constexpr auto GetMfma<f8_t, 16, 16, bf8_t>()
{
return MfmaInstr::mfma_f32_16x16x32f8bf8;
}
template <>
static constexpr auto GetMfma<bf8_t, 32, 32, f8_t>()
{
return MfmaInstr::mfma_f32_32x32x16bf8f8;
}
template <>
static constexpr auto GetMfma<bf8_t, 16, 16, f8_t>()
{
return MfmaInstr::mfma_f32_16x16x32bf8f8;
}
static constexpr auto selected_mfma =
mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops, additional_type>()>{};
__host__ __device__ constexpr MfmaSelector()
{
......@@ -703,7 +880,8 @@ template <typename base_type,
index_t MPerXdlops,
index_t NPerXdlops,
index_t KPack,
bool TransposeC = false>
typename additional_type = base_type,
bool TransposeC = false>
struct XdlopsGemm
{
static constexpr auto I0 = Number<0>{};
......@@ -854,14 +1032,14 @@ struct XdlopsGemm
template <class FloatA, class FloatB, class FloatC>
__device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
{
static_assert(is_same<base_type, double>::value || is_same<base_type, float>::value ||
is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value ||
is_same<base_type, int8_t>::value
#if defined CK_ENABLE_FP8
|| is_same<base_type, f8_t>::value
#endif
,
"base base_type must be double, float, half, bfloat16, and int8_t!");
static_assert(
is_same<base_type, double>::value || is_same<base_type, float>::value ||
is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value ||
is_same<base_type, int8_t>::value || is_same<base_type, f8_t>::value ||
is_same<base_type, bf8_t>::value ||
(is_same<base_type, f8_t>::value && is_same<additional_type, bf8_t>::value) ||
(is_same<base_type, bf8_t>::value && is_same<additional_type, f8_t>::value),
"base base_type must be double, float, half, bfloat16, int8_t, f8_t or bf8_t!");
static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
if constexpr(!TransposeC)
......@@ -957,7 +1135,7 @@ struct XdlopsGemm
return TransposeC ? CIndex4D{blk_td, I0, blk_id, I0} : CIndex4D{I0, blk_id, I0, blk_td};
}
static constexpr auto mfma = MfmaSelector<base_type, MPerXdlops, NPerXdlops>{};
static constexpr auto mfma = MfmaSelector<base_type, MPerXdlops, NPerXdlops, additional_type>{};
static constexpr auto mfma_instr = mfma.selected_mfma;
......
......@@ -20,348 +20,13 @@ struct TransformConvFwdToGemm
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
template <typename ALayout,
typename std::enable_if<NDimSpatial == 1 &&
is_same_v<ALayout, tensor_layout::convolution::GNWC>,
bool>::type = false>
static auto
MakeADescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& /* a_g_n_c_wis_strides */,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads)
{
const index_t N = a_g_n_c_wis_lengths[1];
const index_t C = a_g_n_c_wis_lengths[2];
const index_t Wi = a_g_n_c_wis_lengths[3];
const index_t Wo = c_g_n_k_wos_lengths[3];
const index_t ConvStrideW = conv_filter_strides[0];
if constexpr(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
const index_t NWo =
N * ck::accumulate_n<index_t>(
c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
const auto in_gemmm_gemmk_desc =
make_naive_tensor_descriptor_packed(make_tuple(NWo, C));
return in_gemmm_gemmk_desc;
}
else if constexpr(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter1x1Pad0)
{
const auto in_n_wi_c_desc = make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
const auto in_n_wo_c_desc = transform_tensor_descriptor(
in_n_wi_c_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_gemmm_gemmk_desc = transform_tensor_descriptor(
in_n_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Wo)), make_pass_through_transform(C)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemmm_gemmk_desc;
}
else
{
const index_t X = b_g_k_c_xs_lengths[3];
const index_t ConvDilationW = conv_filter_dilations[0];
const index_t InLeftPadW = input_left_pads[0];
const index_t InRightPadW = input_right_pads[0];
const auto in_n_wi_c_desc = make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
const auto in_n_wip_c_desc = transform_tensor_descriptor(
in_n_wi_c_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_n_x_wo_c_desc = transform_tensor_descriptor(
in_n_wip_c_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
const auto in_gemmm_gemmk_desc =
transform_tensor_descriptor(in_n_x_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Wo)),
make_merge_transform(make_tuple(X, C))),
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemmm_gemmk_desc;
}
}
template <typename ALayout,
typename std::enable_if<NDimSpatial == 2 &&
is_same_v<ALayout, tensor_layout::convolution::GNHWC>,
bool>::type = false>
static auto
MakeADescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& /* a_g_n_c_wis_strides */,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads)
{
const index_t N = a_g_n_c_wis_lengths[1];
const index_t C = a_g_n_c_wis_lengths[2];
const index_t Hi = a_g_n_c_wis_lengths[3];
const index_t Wi = a_g_n_c_wis_lengths[4];
const index_t Ho = c_g_n_k_wos_lengths[3];
const index_t Wo = c_g_n_k_wos_lengths[4];
const index_t ConvStrideH = conv_filter_strides[0];
const index_t ConvStrideW = conv_filter_strides[1];
if constexpr(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
const index_t NHoWo =
N * ck::accumulate_n<index_t>(
c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
const auto in_gemmm_gemmk_desc =
make_naive_tensor_descriptor_packed(make_tuple(NHoWo, C));
return in_gemmm_gemmk_desc;
}
else if constexpr(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter1x1Pad0)
{
const auto in_n_hi_wi_c_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
const auto in_n_ho_wo_c_desc = transform_tensor_descriptor(
in_n_hi_wi_c_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_gemmm_gemmk_desc =
transform_tensor_descriptor(in_n_ho_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemmm_gemmk_desc;
}
else
{
const index_t Y = b_g_k_c_xs_lengths[3];
const index_t X = b_g_k_c_xs_lengths[4];
const index_t ConvDilationH = conv_filter_dilations[0];
const index_t ConvDilationW = conv_filter_dilations[1];
const index_t InLeftPadH = input_left_pads[0];
const index_t InLeftPadW = input_left_pads[1];
const index_t InRightPadH = input_right_pads[0];
const index_t InRightPadW = input_right_pads[1];
const auto in_n_hi_wi_c_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
in_n_hi_wi_c_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor(
in_n_hip_wip_c_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemmm_gemmk_desc =
transform_tensor_descriptor(in_n_y_ho_x_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
make_merge_transform(make_tuple(Y, X, C))),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemmm_gemmk_desc;
}
}
template <typename ALayout,
typename std::enable_if<NDimSpatial == 3 &&
is_same_v<ALayout, tensor_layout::convolution::GNDHWC>,
bool>::type = false>
static auto
MakeADescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& /* a_g_n_c_wis_strides */,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads)
{
const index_t N = a_g_n_c_wis_lengths[1];
const index_t C = a_g_n_c_wis_lengths[2];
const index_t Di = a_g_n_c_wis_lengths[3];
const index_t Hi = a_g_n_c_wis_lengths[4];
const index_t Wi = a_g_n_c_wis_lengths[5];
const index_t Do = c_g_n_k_wos_lengths[3];
const index_t Ho = c_g_n_k_wos_lengths[4];
const index_t Wo = c_g_n_k_wos_lengths[5];
const index_t ConvStrideD = conv_filter_strides[0];
const index_t ConvStrideH = conv_filter_strides[1];
const index_t ConvStrideW = conv_filter_strides[2];
if constexpr(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
const index_t NDoHoWo =
N * ck::accumulate_n<index_t>(
c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
const auto in_gemmm_gemmk_desc =
make_naive_tensor_descriptor_packed(make_tuple(NDoHoWo, C));
return in_gemmm_gemmk_desc;
}
else if constexpr(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter1x1Pad0)
{
const auto in_n_di_hi_wi_c_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)),
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_gemmm_gemmk_desc = transform_tensor_descriptor(
in_n_do_ho_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemmm_gemmk_desc;
}
else
{
const index_t Z = b_g_k_c_xs_lengths[3];
const index_t Y = b_g_k_c_xs_lengths[4];
const index_t X = b_g_k_c_xs_lengths[5];
const index_t ConvDilationD = conv_filter_dilations[0];
const index_t ConvDilationH = conv_filter_dilations[1];
const index_t ConvDilationW = conv_filter_dilations[2];
const index_t InLeftPadD = input_left_pads[0];
const index_t InLeftPadH = input_left_pads[1];
const index_t InLeftPadW = input_left_pads[2];
const index_t InRightPadD = input_right_pads[0];
const index_t InRightPadH = input_right_pads[1];
const index_t InRightPadW = input_right_pads[2];
const auto in_n_di_hi_wi_c_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Di, InLeftPadD, InRightPadD),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor(
in_n_hip_wip_c_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{},
Sequence<1, 2>{},
Sequence<3, 4>{},
Sequence<5, 6>{},
Sequence<7>{}));
const auto in_gemmm_gemmk_desc = transform_tensor_descriptor(
in_n_z_do_y_ho_x_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
make_merge_transform(make_tuple(Z, Y, X, C))),
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemmm_gemmk_desc;
}
}
// TODO: implement ck::tensor_layout::convolution that describe packed/strided dimemsion as
// properties
template <typename ALayout,
typename std::enable_if<NDimSpatial == 1 &&
(is_same_v<ALayout, tensor_layout::convolution::G_NW_C> ||
is_same_v<ALayout, tensor_layout::convolution::NWGC>),
is_same_v<ALayout, tensor_layout::convolution::NWGC> ||
is_same_v<ALayout, tensor_layout::convolution::GNWC>),
bool>::type = false>
static auto
MakeADescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
......@@ -473,7 +138,8 @@ struct TransformConvFwdToGemm
template <typename ALayout,
typename std::enable_if<
NDimSpatial == 2 && (is_same_v<ALayout, tensor_layout::convolution::G_NHW_C> ||
is_same_v<ALayout, tensor_layout::convolution::NHWGC>),
is_same_v<ALayout, tensor_layout::convolution::NHWGC> ||
is_same_v<ALayout, tensor_layout::convolution::GNHWC>),
bool>::type = false>
static auto
MakeADescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
......@@ -601,7 +267,8 @@ struct TransformConvFwdToGemm
template <typename ALayout,
typename std::enable_if<
NDimSpatial == 3 && (is_same_v<ALayout, tensor_layout::convolution::G_NDHW_C> ||
is_same_v<ALayout, tensor_layout::convolution::NDHWGC>),
is_same_v<ALayout, tensor_layout::convolution::NDHWGC> ||
is_same_v<ALayout, tensor_layout::convolution::GNDHWC>),
bool>::type = false>
static auto
MakeADescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
......
......@@ -299,584 +299,255 @@ enum struct AmdBufferCoherenceEnum
GLC_SLC = 3,
};
template <typename T,
index_t N,
AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
__device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset)
template <index_t N, AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
__device__ typename vector_type<int8_t, N>::type
amd_buffer_load_impl_raw(int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset)
{
static_assert(
(is_same<T, double>::value && (N == 1 || N == 2 || N == 4)) ||
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
"wrong! not implemented");
static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64,
"wrong! not implemented");
if constexpr(is_same<T, double>::value)
if constexpr(N == 1)
{
// use fp32 load to mimic fp64 load
if constexpr(N == 1)
{
const float2_t tmp =
llvm_amdgcn_raw_buffer_load_fp32x2(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
return bit_cast<double>(tmp);
}
else if constexpr(N == 2)
{
const float4_t tmp =
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
return bit_cast<double2_t>(tmp);
}
else if constexpr(N == 4)
{
const float4_t f32_0 =
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
const float4_t f32_1 =
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(float),
static_cast<index_t>(coherence));
vector_type<double, 4> tmp;
tmp.AsType<double2_t>()(Number<0>{}) = bit_cast<double2_t>(f32_0);
tmp.AsType<double2_t>()(Number<1>{}) = bit_cast<double2_t>(f32_1);
return tmp.AsType<double4_t>()(Number<0>{});
}
return llvm_amdgcn_raw_buffer_load_i8(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(is_same<T, float>::value)
else if constexpr(N == 2)
{
if constexpr(N == 1)
{
return llvm_amdgcn_raw_buffer_load_fp32(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 2)
{
return llvm_amdgcn_raw_buffer_load_fp32x2(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 4)
{
return llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
int16_t tmp = llvm_amdgcn_raw_buffer_load_i16(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 8)
{
vector_type<float, 8> tmp;
tmp.AsType<float4_t>()(Number<0>{}) =
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
tmp.AsType<float4_t>()(Number<1>{}) =
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(float),
static_cast<index_t>(coherence));
return tmp.AsType<float8_t>()(Number<0>{});
}
return bit_cast<int8x2_t>(tmp);
}
else if constexpr(is_same<T, half_t>::value)
else if constexpr(N == 4)
{
if constexpr(N == 1)
{
return llvm_amdgcn_raw_buffer_load_fp16(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 2)
{
return llvm_amdgcn_raw_buffer_load_fp16x2(src_wave_buffer_resource,
int32_t tmp = llvm_amdgcn_raw_buffer_load_i32(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 4)
{
return llvm_amdgcn_raw_buffer_load_fp16x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 8)
{
// use fp32 load to mimic fp16 load
float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
return bit_cast<half8_t>(tmp);
}
}
else if constexpr(is_same<T, bhalf_t>::value)
{
if constexpr(N == 1)
{
return llvm_amdgcn_raw_buffer_load_i16(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 2)
{
return llvm_amdgcn_raw_buffer_load_i16x2(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 4)
{
return llvm_amdgcn_raw_buffer_load_i16x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 8)
{
int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
return bit_cast<bhalf8_t>(tmp);
}
return bit_cast<int8x4_t>(tmp);
}
else if constexpr(is_same<T, int32_t>::value)
else if constexpr(N == 8)
{
if constexpr(N == 1)
{
return llvm_amdgcn_raw_buffer_load_i32(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 2)
{
return llvm_amdgcn_raw_buffer_load_i32x2(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 4)
{
return llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 8)
{
vector_type<int32_t, 8> tmp;
tmp.AsType<int32x4_t>()(Number<0>{}) =
llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
tmp.AsType<int32x4_t>()(Number<1>{}) =
llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(int32_t),
static_cast<index_t>(coherence));
return tmp.AsType<int32x8_t>()(Number<0>{});
}
}
else if constexpr(is_same<T, int8_t>::value)
{
if constexpr(N == 1)
{
return llvm_amdgcn_raw_buffer_load_i8(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 2)
{
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
return llvm_amdgcn_raw_buffer_load_i8x2(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
#else
int16_t tmp = llvm_amdgcn_raw_buffer_load_i16(src_wave_buffer_resource,
int32x2_t tmp = llvm_amdgcn_raw_buffer_load_i32x2(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
return bit_cast<int8x2_t>(tmp);
#endif
}
else if constexpr(N == 4)
{
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
return llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
#else
int32_t tmp = llvm_amdgcn_raw_buffer_load_i32(src_wave_buffer_resource,
return bit_cast<int8x8_t>(tmp);
}
else if constexpr(N == 16)
{
int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
return bit_cast<int8x16_t>(tmp);
}
else if constexpr(N == 32)
{
int32x4_t tmp0 = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
int32x4_t tmp1 =
llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(int32_t),
static_cast<index_t>(coherence));
vector_type<int32_t, 8> tmp;
return bit_cast<int8x4_t>(tmp);
#endif
}
else if constexpr(N == 8)
{
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
vector_type<int8_t, 8> tmp;
tmp.AsType<int8x4_t>()(Number<0>{}) =
llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
tmp.AsType<int8x4_t>()(Number<1>{}) =
llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(int8_t),
static_cast<index_t>(coherence));
return tmp.AsType<int8x8_t>()(Number<0>{});
#else
int32x2_t tmp = llvm_amdgcn_raw_buffer_load_i32x2(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
tmp.AsType<int32x4_t>()(Number<0>{}) = tmp0;
tmp.AsType<int32x4_t>()(Number<1>{}) = tmp1;
return bit_cast<int8x8_t>(tmp);
#endif
}
else if constexpr(N == 16)
{
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
vector_type<int8_t, 16> tmp;
tmp.AsType<int8x4_t>()(Number<0>{}) =
llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
tmp.AsType<int8x4_t>()(Number<1>{}) =
llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(int8_t),
static_cast<index_t>(coherence));
tmp.AsType<int8x4_t>()(Number<2>{}) =
llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 8 * sizeof(int8_t),
static_cast<index_t>(coherence));
tmp.AsType<int8x4_t>()(Number<3>{}) =
llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 12 * sizeof(int8_t),
static_cast<index_t>(coherence));
return tmp.AsType<int8x16_t>()(Number<0>{});
#else
int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
return bit_cast<int8x32_t>(tmp);
}
else if constexpr(N == 64)
{
int32x4_t tmp0 = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
int32x4_t tmp1 =
llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(int32_t),
static_cast<index_t>(coherence));
int32x4_t tmp2 =
llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 8 * sizeof(int32_t),
static_cast<index_t>(coherence));
int32x4_t tmp3 =
llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 12 * sizeof(int32_t),
static_cast<index_t>(coherence));
return bit_cast<int8x16_t>(tmp);
#endif
}
vector_type<int32_t, 16> tmp;
tmp.AsType<int32x4_t>()(Number<0>{}) = tmp0;
tmp.AsType<int32x4_t>()(Number<1>{}) = tmp1;
tmp.AsType<int32x4_t>()(Number<2>{}) = tmp2;
tmp.AsType<int32x4_t>()(Number<3>{}) = tmp3;
return bit_cast<int8x64_t>(tmp);
}
}
template <typename T,
index_t N,
AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
__device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src_thread_data,
int32x4_t dst_wave_buffer_resource,
index_t dst_thread_addr_offset,
index_t dst_wave_addr_offset)
__device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset)
{
static_assert(
(is_same<T, double>::value && (N == 1 || N == 2)) ||
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)) ||
(is_same<T, double>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, f8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
"wrong! not implemented");
if constexpr(is_same<T, double>::value)
using r_t = typename vector_type<T, N>::type;
auto raw_data = amd_buffer_load_impl_raw<sizeof(T) * N, coherence>(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset);
return bit_cast<r_t>(raw_data);
}
template <index_t N, AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
__device__ void
amd_buffer_store_impl_raw(const typename vector_type<int8_t, N>::type src_thread_data,
int32x4_t dst_wave_buffer_resource,
index_t dst_thread_addr_offset,
index_t dst_wave_addr_offset)
{
static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64,
"wrong! not implemented");
if constexpr(N == 1)
{
// use fp32 store to mimic fp64 store
if constexpr(N == 1)
{
llvm_amdgcn_raw_buffer_store_fp32x2(bit_cast<float2_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 2)
{
llvm_amdgcn_raw_buffer_store_fp32x4(bit_cast<float4_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
llvm_amdgcn_raw_buffer_store_i8(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(is_same<T, float>::value)
else if constexpr(N == 2)
{
if constexpr(N == 1)
{
llvm_amdgcn_raw_buffer_store_fp32(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 2)
{
llvm_amdgcn_raw_buffer_store_fp32x2(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 4)
{
llvm_amdgcn_raw_buffer_store_fp32x4(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 8)
{
vector_type<float, 8> tmp{src_thread_data};
llvm_amdgcn_raw_buffer_store_fp32x4(tmp.AsType<float4_t>()[Number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
llvm_amdgcn_raw_buffer_store_fp32x4(tmp.AsType<float4_t>()[Number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 4 * sizeof(float),
static_cast<index_t>(coherence));
}
llvm_amdgcn_raw_buffer_store_i16(bit_cast<int16_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(is_same<T, half_t>::value)
else if constexpr(N == 4)
{
if constexpr(N == 1)
{
llvm_amdgcn_raw_buffer_store_fp16(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 2)
{
llvm_amdgcn_raw_buffer_store_fp16x2(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 4)
{
llvm_amdgcn_raw_buffer_store_fp16x4(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 8)
{
#if 0
vector_type<half_t, 8> tmp{src_thread_data};
llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 4 * sizeof(half_t),
static_cast<index_t>(coherence));
#else
llvm_amdgcn_raw_buffer_store_fp32x4(bit_cast<float4_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
#endif
}
llvm_amdgcn_raw_buffer_store_i32(bit_cast<int32_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(is_same<T, bhalf_t>::value)
else if constexpr(N == 8)
{
if constexpr(N == 1)
{
llvm_amdgcn_raw_buffer_store_i16(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 2)
{
llvm_amdgcn_raw_buffer_store_i16x2(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 4)
{
llvm_amdgcn_raw_buffer_store_i16x4(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 8)
{
vector_type<bhalf_t, 8> tmp{src_thread_data};
llvm_amdgcn_raw_buffer_store_i16x4(tmp.AsType<bhalf4_t>()[Number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
llvm_amdgcn_raw_buffer_store_i16x4(tmp.AsType<bhalf4_t>()[Number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 4 * sizeof(bhalf_t),
static_cast<index_t>(coherence));
}
llvm_amdgcn_raw_buffer_store_i32x2(bit_cast<int32x2_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(is_same<T, int32_t>::value)
else if constexpr(N == 16)
{
if constexpr(N == 1)
{
llvm_amdgcn_raw_buffer_store_i32(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 2)
{
llvm_amdgcn_raw_buffer_store_i32x2(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 4)
{
llvm_amdgcn_raw_buffer_store_i32x4(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
llvm_amdgcn_raw_buffer_store_i32x4(bit_cast<int32x4_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(is_same<T, int8_t>::value)
else if constexpr(N == 32)
{
if constexpr(N == 1)
{
llvm_amdgcn_raw_buffer_store_i8(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 2)
{
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
llvm_amdgcn_raw_buffer_store_i8x2(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
#else
llvm_amdgcn_raw_buffer_store_i16(bit_cast<int16_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
#endif
}
else if constexpr(N == 4)
{
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
llvm_amdgcn_raw_buffer_store_i8x4(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
#else
llvm_amdgcn_raw_buffer_store_i32(bit_cast<int32_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
#endif
}
else if constexpr(N == 8)
{
llvm_amdgcn_raw_buffer_store_i32x2(bit_cast<int32x2_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 16)
{
llvm_amdgcn_raw_buffer_store_i32x4(bit_cast<int32x4_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
vector_type<int32_t, 8> tmp{bit_cast<int32x8_t>(src_thread_data)};
llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType<int32x4_t>()[Number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType<int32x4_t>()[Number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + sizeof(int32_t) * 4,
static_cast<index_t>(coherence));
}
else if constexpr(N == 64)
{
vector_type<int32_t, 16> tmp{bit_cast<int32x16_t>(src_thread_data)};
llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType<int32x4_t>()[Number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType<int32x4_t>()[Number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + sizeof(int32_t) * 4,
static_cast<index_t>(coherence));
llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType<int32x4_t>()[Number<2>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + sizeof(int32_t) * 8,
static_cast<index_t>(coherence));
llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType<int32x4_t>()[Number<3>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + sizeof(int32_t) * 12,
static_cast<index_t>(coherence));
}
}
template <typename T,
index_t N,
AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
__device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src_thread_data,
int32x4_t dst_wave_buffer_resource,
index_t dst_thread_addr_offset,
index_t dst_wave_addr_offset)
{
static_assert(
(is_same<T, double>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, f8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
"wrong! not implemented");
using r_t = typename vector_type<int8_t, sizeof(T) * N>::type;
amd_buffer_store_impl_raw<sizeof(T) * N, coherence>(bit_cast<r_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset);
}
template <typename T, index_t N>
......@@ -1127,38 +798,14 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000;
#if defined CK_ENABLE_FP8
if constexpr(is_same<scalar_t, f8_t>::value)
{
auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>(
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
return bit_cast<vector_t>(tmp);
}
else
{
#endif
return amd_buffer_load_impl<scalar_t, vector_size, coherence>(
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
#if defined CK_ENABLE_FP8
}
#endif
return amd_buffer_load_impl<scalar_t, vector_size, coherence>(
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
#else
#if defined CK_ENABLE_FP8
if constexpr(is_same<scalar_t, f8_t>::value)
{
auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>(
src_wave_buffer_resource, src_thread_addr_offset, 0);
return src_thread_element_valid ? bit_cast<vector_t>(tmp) : vector_t(0);
}
else
{
#endif
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>(
src_wave_buffer_resource, src_thread_addr_offset, 0);
return src_thread_element_valid ? tmp : vector_t(0);
#if defined CK_ENABLE_FP8
}
#endif
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>(
src_wave_buffer_resource, src_thread_addr_offset, 0);
return src_thread_element_valid ? tmp : vector_t(0);
#endif
}
......@@ -1216,41 +863,13 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
#if defined CK_ENABLE_FP8
if constexpr(is_same<scalar_t, f8_t>::value)
{
auto tmp =
bit_cast<typename vector_type_maker<int8_t, vector_size>::type::type>(src_thread_data);
amd_buffer_store_impl<int8_t, vector_size, coherence>(
tmp, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
}
else
{
#endif
amd_buffer_store_impl<scalar_t, vector_size, coherence>(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
#if defined CK_ENABLE_FP8
}
#endif
amd_buffer_store_impl<scalar_t, vector_size, coherence>(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
#else
if(dst_thread_element_valid)
{
#if defined CK_ENABLE_FP8
if constexpr(is_same<scalar_t, f8_t>::value)
{
auto tmp = bit_cast<typename vector_type_maker<int8_t, vector_size>::type::type>(
src_thread_data);
amd_buffer_store_impl<int8_t, vector_size, coherence>(
tmp, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
}
else
{
#endif
amd_buffer_store_impl<scalar_t, vector_size, coherence>(
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
#if defined CK_ENABLE_FP8
}
#endif
amd_buffer_store_impl<scalar_t, vector_size, coherence>(
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
}
#endif
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_AMD_XDLOPS_HPP
#define CK_AMD_XDLOPS_HPP
#include "data_type.hpp"
#pragma once
namespace ck {
......@@ -355,7 +352,6 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
}
};
#if defined CK_ENABLE_FP8
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16f8f8;
......@@ -418,6 +414,194 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16>
#endif
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16bf8bf8;
template <>
struct intrin_mfma_f32_32x32x16bf8bf8<32, 32>
{
template <class FloatC>
__device__ static void Run(const bf8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
reg_c.template AsType<float16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
reg_c.template AsType<float16_t>()[Number<0>{}],
0,
0,
0);
#else
vector_type<bf8_t, 8> reg_a_v(reg_a);
vector_type<bf8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[Number<k>{}]);
intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c);
});
#endif
} // namespace ck
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x32bf8bf8;
template <>
struct intrin_mfma_f32_16x16x32bf8bf8<16, 16>
{
template <class FloatC>
__device__ static void Run(const bf8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
reg_c.template AsType<float4_t>()[Number<0>{}],
0,
0,
0);
#else
vector_type<bf8_t, 8> reg_a_v(reg_a);
vector_type<bf8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[Number<k>{}]);
intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c);
});
#endif
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16f8bf8;
template <>
struct intrin_mfma_f32_32x32x16f8bf8<32, 32>
{
template <class FloatC>
__device__ static void Run(const f8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
reg_c.template AsType<float16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
reg_c.template AsType<float16_t>()[Number<0>{}],
0,
0,
0);
#else
vector_type<f8_t, 8> reg_a_v(reg_a);
vector_type<bf8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]);
float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[Number<k>{}]);
intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c);
});
#endif
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x32f8bf8;
template <>
struct intrin_mfma_f32_16x16x32f8bf8<16, 16>
{
template <class FloatC>
__device__ static void Run(const f8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
reg_c.template AsType<float4_t>()[Number<0>{}],
0,
0,
0);
#else
vector_type<f8_t, 8> reg_a_v(reg_a);
vector_type<bf8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]);
float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[Number<k>{}]);
intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c);
});
#endif
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16bf8f8;
template <>
struct intrin_mfma_f32_32x32x16bf8f8<32, 32>
{
template <class FloatC>
__device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
reg_c.template AsType<float16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
reg_c.template AsType<float16_t>()[Number<0>{}],
0,
0,
0);
#else
vector_type<bf8_t, 8> reg_a_v(reg_a);
vector_type<f8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<f8_t>()[Number<k>{}]);
intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c);
});
#endif
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x32bf8f8;
template <>
struct intrin_mfma_f32_16x16x32bf8f8<16, 16>
{
template <class FloatC>
__device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
reg_c.template AsType<float4_t>()[Number<0>{}],
0,
0,
0);
#else
vector_type<bf8_t, 8> reg_a_v(reg_a);
vector_type<f8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<f8_t>()[Number<k>{}]);
intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c);
});
#endif
}
};
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment