Unverified Commit 7829d729 authored by rocking5566, committed by GitHub

Gemm layernorm welford (#413)



* Add device op of gemm layernorm

* [What] Rename F to H
[Why] F and G are reserved for the welford tensors

* Add gridwise gemm + welford

* Extract template parameter

* Rename kernel. Prepare to add second half kernel

* Extract var

* Add second kernel for gemm+layernorm

* Move to the gemm_layernorm folder

* Rename F and G to mean and var

* Do not use snake-curved mapping; it makes determining padding for welford difficult

* Rewrite the device interface and rename some var

* Add welford count

* Update interface

* Sync code, prepare to test on MI200

* Clean the code

* Implement layernorm

* Add comment to mention hipFree

* Write out the e for debug.
This could be removed and h used instead

* 1. Allocate mean, var and count via SetWorkSpacePointer.
2. Add GetWorkSpaceSize to calculate the space size

* Add gemm layernorm host code

* use reference layernorm

* Fix bug of blockwise welford for first kernel

* Fix bug of mean var padding for layernorm

* Use sgpr for shuffleM_index

* padding for GemmMeanVarCountGridDescriptor_M_NBlock

* Add layout parameter

* Check argument for gemm

* calculate max count for tail block

* Share E and H memory in device op

* Hard code the vector dim

* Refine the MakeDescriptor

* 1. Remove E parameter, because E is inside of device op
2. Check vector size

* [What] Rename MakeMeanVarDescriptor_M_N
[Why] Prepare to add count version of make descriptor

* Use 1D global memory for count

* Prevent redundant IO

* Update parameter

* Add pipeline v1/v2 selector

* Rename the example name

* Add base class for gemm layernorm

* Refine naming to distinguish naive and welford

* Add comment to explain in detail

* We don't need to pad in N dimension in gemm for mean/var/count. Set NPerTile 1

* Rewrite the 2nd kernel, use multiple blocks along the N dimension in the layernorm kernel

* Share the vector size

* Refine var name

* [What] Force LayernormThreadSliceSize_N = vector size.
[Why] Memory coalescing

* Add comment

* Extract divisor out of the loop in reference layernorm

* Pad different size for E and H in layernorm kernel according to different block tile

* Refine naming

* Refine naming

* Prevent implicit cast

* [What] use ck::math::sqrt instead of __builtin_amdgcn_sqrtf
[Why] __builtin_amdgcn_sqrtf only supports float; double would cause casting

* Cast only constant

* Change of post shuffle thread descriptor

* Add EMeanVarDataType parameter.

* Merge the mean and var threadwise copy

* Add missing index

* Fix Typo

* Sync the variable with previous if

* 1. Declare e inside the host_gemm_layernorm()
2. Prevent implicit cast in reference code
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
parent 919aeb1f
add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_fp16 gemm_bias_relu_add_layernorm_xdl_fp16.cpp)
add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_welford_fp16 gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp)
add_example_executable(example_gemm_layernorm_xdl_fp16 gemm_layernorm_xdl_fp16.cpp)
add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_naive_fp16 gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp)
add_example_executable(example_gemm_xdl_layernorm_single_kernel_fp16 gemm_xdl_layernorm_single_kernel_fp16.cpp)
add_example_executable(example_gemm_layernorm_xdl_naive_fp16 gemm_layernorm_xdl_naive_fp16.cpp)
add_example_executable(example_gemm_xdl_layernorm_naive_single_kernel_fp16 gemm_xdl_layernorm_naive_single_kernel_fp16.cpp)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp"
#include "ck/library/utility/check_err.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd;
// DataType
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using D0DataType = F16;
using D1DataType = F16;
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
using EMeanVarDataType = F16;
using GammaDataType = F16;
using BetaDataType = F16;
using HDataType = F16;
// Layout
using ALayout = Row;
using BLayout = Col;
using D0Layout = Row;
using D1Layout = Row;
using DsLayout = ck::Tuple<D0Layout, D1Layout>;
using HLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = AddReluAdd;
using HElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDLayernorm_Xdl_CShuffle
//######| ALayout| BLayout| DsLayout| HLayout| AData| BData| AccData| CShuffle| DsData| EMeanVarData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle| Layernorm| Layernorm|
//######| | | | | Type| Type| Type| DataType| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ThreadClusterLengths| ScalarPerVector| ThreadClusterLengths| ThreadSliceSize|
//######| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _M_N| _M_N| _M|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, HLayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EMeanVarDataType, GammaDataType, BetaDataType, HDataType, AElementOp, BElementOp, CDEElementOp, HElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 8, S<8, 32>, 8>;
// clang-format on
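// Rough reading of the tuning row above (sketch, not exhaustive): a 256-thread block computes a
// 256x128 output tile in 32-wide K steps (AK1 = BK1 = 8) on 32x32 xdlops; the trailing
// S<32, 8>/8 and S<8, 32>/8 parameters configure the post-shuffle store and the layernorm
// thread layout.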
auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
return HostTensorDescriptor(std::vector<std::size_t>({len}),
std::vector<std::size_t>({stride}));
};
auto f_host_tensor_descriptor2d =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({1, stride}));
}
};
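// CPU reference for verification: run a plain GEMM into c_m_n, apply cde_element_op with the
// bias and residual tensors to produce e_m_n, then layernorm e_m_n along N into h_m_n.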
void host_gemm_layernorm(Tensor<HDataType>& h_m_n,
const Tensor<ADataType>& a_m_k,
const Tensor<BDataType>& b_k_n,
const Tensor<D0DataType>& bias_n,
const Tensor<D1DataType>& d1_m_n,
const Tensor<GammaDataType>& gamma_n,
const Tensor<BetaDataType>& beta_n,
AElementOp a_element_op,
BElementOp b_element_op,
CDEElementOp cde_element_op,
int M,
int N,
AccDataType epsilon = 1e-5)
{
using ReferenceGemm = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
AccDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
using ReferenceLayernorm = ck::tensor_operation::host::ReferenceLayernorm<EMeanVarDataType,
GammaDataType,
BetaDataType,
HDataType,
AccDataType,
HElementOp,
2,
1>;
Tensor<EMeanVarDataType> e_m_n(HostTensorDescriptor{M, N});
Tensor<AccDataType> c_m_n(HostTensorDescriptor{M, N});
auto ref_gemm = ReferenceGemm{};
auto ref_gemm_invoker = ref_gemm.MakeInvoker();
auto ref_gemm_argument =
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{});
ref_gemm_invoker.Run(ref_gemm_argument);
for(int n = 0; n < N; ++n)
{
AccDataType bias = static_cast<AccDataType>(bias_n(n));
for(int m = 0; m < M; ++m)
{
AccDataType e = static_cast<AccDataType>(e_m_n(m, n));
AccDataType d1 = static_cast<AccDataType>(d1_m_n(m, n));
cde_element_op(e, c_m_n(m, n), bias, d1);
e_m_n(m, n) = static_cast<EMeanVarDataType>(e);
}
}
ReferenceLayernorm ref_layernorm;
auto ref_layernorm_invoker = ref_layernorm.MakeInvoker();
auto ref_layernorm_argument = ref_layernorm.MakeArgument(
e_m_n, gamma_n, beta_n, h_m_n, HElementOp{}, {M, N}, {1}, epsilon);
ref_layernorm_invoker.Run(ref_layernorm_argument);
}
int main()
{
bool do_verification = true;
// GEMM shape
ck::index_t M = 1024;
ck::index_t N = 1024;
ck::index_t K = 1024;
ck::index_t StrideA = K;
ck::index_t StrideB = K;
ck::index_t StrideD0 = 0;
ck::index_t StrideD1 = N;
ck::index_t StrideH = N;
float epsilon = 1e-5;
Tensor<ADataType> a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{}));
Tensor<D0DataType> d0_n(f_host_tensor_descriptor1d(N, 1));
Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor2d(M, N, StrideD1, D1Layout{}));
Tensor<GammaDataType> gamma_n(f_host_tensor_descriptor1d(N, 1));
Tensor<BetaDataType> beta_n(f_host_tensor_descriptor1d(N, 1));
Tensor<HDataType> h_m_n(f_host_tensor_descriptor2d(M, N, StrideH, HLayout{}));
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-1, 1});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-1, 1});
d0_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-1, 1});
d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{-1, 1});
gamma_n.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{-1, 1});
beta_n.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{-1, 1});
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_n.mDesc.GetElementSpaceSize());
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize());
DeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_n.mDesc.GetElementSpaceSize());
DeviceMem beta_device_buf(sizeof(BetaDataType) * beta_n.mDesc.GetElementSpaceSize());
DeviceMem h_device_buf(sizeof(HDataType) * h_m_n.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
d0_device_buf.ToDevice(d0_n.mData.data());
d1_device_buf.ToDevice(d1_m_n.mData.data());
gamma_device_buf.ToDevice(gamma_n.mData.data());
beta_device_buf.ToDevice(beta_n.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
auto h_element_op = HElementOp{};
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument =
device_op.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
{d0_device_buf.GetDeviceBuffer(), d1_device_buf.GetDeviceBuffer()},
gamma_device_buf.GetDeviceBuffer(),
beta_device_buf.GetDeviceBuffer(),
h_device_buf.GetDeviceBuffer(),
M,
N,
K,
StrideA,
StrideB,
{StrideD0, StrideD1},
StrideH,
epsilon,
a_element_op,
b_element_op,
cde_element_op,
h_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error("wrong! this device_op instance does not support this problem");
}
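// The device op stages the intermediate welford mean/var/count (and, if needed, E) in a
// workspace buffer; query its size and attach device memory before running the invoker.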
size_t workspace_sz = device_op.GetWorkSpaceSize(&argument);
DeviceMem workspace_dev(workspace_sz);
device_op.SetWorkSpacePointer(&argument, workspace_dev.GetDeviceBuffer());
invoker.Run(argument, StreamConfig{nullptr, false});
bool pass = true;
if(do_verification)
{
Tensor<HDataType> h_m_n_host(HostTensorDescriptor{M, N});
host_gemm_layernorm(h_m_n_host,
a_m_k,
b_k_n,
d0_n,
d1_m_n,
gamma_n,
beta_n,
a_element_op,
b_element_op,
cde_element_op,
M,
N,
epsilon);
h_device_buf.FromDevice(h_m_n.mData.data());
pass &=
ck::utils::check_err(h_m_n, h_m_n_host, "Error: Incorrect results h_m_n", 1e-2, 1e-2);
}
return pass ? 0 : 1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include "device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// GEMM:
// input : A[M, K]
// input : B[N, K]
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// output : H[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// H = layernorm(E)
// Assume:
// D0, D1, ... and E have the same layout
// Calculate mean & variance along N dimension in layernorm(E)
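// For intuition only, a scalar sketch of one output row (reference pseudocode, not part of the
// interface):
//   for n in [0, N):
//       acc = 0;
//       for k in [0, K): acc += a_op(A[m, k]) * b_op(B[n, k]);
//       E[m, n] = cde_op(acc, D0[m, n], D1[m, n], ...);
//   mean = sum_n(E[m, n]) / N;  var = sum_n((E[m, n] - mean)^2) / N;
//   H[m, n] = gamma[n] * (E[m, n] - mean) / sqrt(var + epsilon) + beta[n];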
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename HLayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename GammaDataType,
typename BetaDataType,
typename HDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename HElementwiseOperation>
struct DeviceGemmMultipleDLayernorm : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
const void* p_gamma,
const void* p_beta,
void* p_h,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
std::array<index_t, NumDTensor> StrideDs,
index_t StrideH,
double epsilon,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op,
HElementwiseOperation h_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; // struct DeviceGemmMultipleDLayernorm
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <map>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
template <typename GridwiseGemmWelford,
typename ABDataType,
typename DsPointer,
typename EMeanVarDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename MeanVarGridDescriptor_MBlock_MPerBlock_NBlock,
typename CountGridDescriptor_MBlock_MPerBlock_NBlock,
typename Block2ETileMap,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_multiple_d_welford_first_half_xdl_cshuffle(
const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid,
EMeanVarDataType* __restrict__ p_e_grid,
EMeanVarDataType* __restrict__ p_welford_mean_grid,
EMeanVarDataType* __restrict__ p_welford_var_grid,
int32_t* __restrict__ p_welford_count_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock,
const MeanVarGridDescriptor_MBlock_MPerBlock_NBlock
mean_var_grid_desc_mblock_mperblock_nblock,
const CountGridDescriptor_MBlock_MPerBlock_NBlock
count_grid_desc_mblock_mperblock_nblock,
const Block2ETileMap block_2_etile_map,
index_t NRaw)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemmWelford::GetSharedMemoryNumberOfByte()];
GridwiseGemmWelford::template Run<HasMainKBlockLoop>(
p_a_grid,
p_b_grid,
p_ds_grid,
p_e_grid,
p_welford_mean_grid,
p_welford_var_grid,
p_welford_count_grid,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
mean_var_grid_desc_mblock_mperblock_nblock,
count_grid_desc_mblock_mperblock_nblock,
block_2_etile_map,
NRaw);
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_ds_grid;
ignore = p_e_grid;
ignore = p_welford_mean_grid;
ignore = p_welford_var_grid;
ignore = p_welford_count_grid;
ignore = a_element_op;
ignore = b_element_op;
ignore = cde_element_op;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = mean_var_grid_desc_mblock_mperblock_nblock;
ignore = count_grid_desc_mblock_mperblock_nblock;
ignore = block_2_etile_map;
ignore = NRaw;
#endif
}
template <typename GridwiseWelfordLayernorm,
typename EMeanVarDataType,
typename HDataType,
typename GammaDataType,
typename BetaDataType,
typename ComputeDataType,
typename EHGridDesc_M_N,
typename LayernormMeanVarGridDesc_M_NBlock,
typename LayernormCountGridDesc_M_NBlock,
typename GammaBetaGridDesc_N,
typename HElementwiseOperation>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_welford_layernorm2d_second_half(
const EMeanVarDataType* __restrict__ p_e_grid,
const EMeanVarDataType* __restrict__ p_in_welford_mean_grid,
const EMeanVarDataType* __restrict__ p_in_welford_var_grid,
const int32_t* __restrict__ p_in_welford_count_grid,
const GammaDataType* __restrict__ p_gamma_grid,
const BetaDataType* __restrict__ p_beta_grid,
HDataType* __restrict__ p_h_grid,
const EHGridDesc_M_N e_grid_desc_m_n,
const EHGridDesc_M_N h_grid_desc_m_n,
const LayernormMeanVarGridDesc_M_NBlock mean_var_grid_desc_m_nblock,
const LayernormCountGridDesc_M_NBlock count_grid_desc_m_nblock,
const GammaBetaGridDesc_N gamma_grid_desc_n,
const GammaBetaGridDesc_N beta_grid_desc_n,
index_t numMeanVarCountBlockTileIteration_N,
index_t NBlockClusterLength,
ComputeDataType epsilon,
HElementwiseOperation h_element_op)
{
GridwiseWelfordLayernorm::Run(p_e_grid,
p_in_welford_mean_grid,
p_in_welford_var_grid,
p_in_welford_count_grid,
p_gamma_grid,
p_beta_grid,
p_h_grid,
e_grid_desc_m_n,
h_grid_desc_m_n,
mean_var_grid_desc_m_nblock,
count_grid_desc_m_nblock,
gamma_grid_desc_n,
beta_grid_desc_n,
numMeanVarCountBlockTileIteration_N,
NBlockClusterLength,
epsilon,
h_element_op);
}
} // namespace ck
namespace ck {
namespace tensor_operation {
namespace device {
// GEMM:
// input : A[M, K]
// input : B[N, K]
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// output : H[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// H = layernorm(E)
// Assume:
// D0, D1, ... and E have the same layout
// Calculate mean & variance along N dimension in layernorm(E)
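// Implementation note: the work is split across two kernels. The first (gemm + welford first
// half) produces E and per-[M, NBlock] partial mean/var/count; the second (welford second half
// + layernorm) merges the partials along NBlock and normalizes E into H.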
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename HLayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EMeanVarDataType,
typename GammaDataType,
typename BetaDataType,
typename HDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename HElementwiseOperation,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
index_t AK1,
index_t BK1,
index_t MPerXDL,
index_t NPerXDL,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename PostShuffleThreadClusterSize_M_N,
index_t PostShuffleScalarPerVector,
typename LayernormThreadClusterSize_M_N,
index_t LayernormThreadSliceSize_M,
LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v1>
struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
: public DeviceGemmMultipleDLayernorm<ALayout,
BLayout,
DsLayout,
HLayout,
ADataType,
BDataType,
DsDataType,
GammaDataType,
BetaDataType,
HDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
HElementwiseOperation>
{
// EDataType, MeanDataType and VarDataType must be the same.
// e.g. with M, N, K = [1, 1, 1], var = 0, so the layernorm divisor = 1 / sqrt(var + 1e-5) = 316.227783.
// If (x - mean) were kept in a different (lower-precision) type and came out non-zero,
// (x - mean) * divisor * gamma might become very large, even though it should be 0 in this case.
using DeviceOp = DeviceGemmMultipleDLayernorm_Xdl_CShuffle;
using ELayout = HLayout;
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr index_t LayernormHDstVectorSize = PostShuffleScalarPerVector;
static constexpr index_t LayernormGammaSrcVectorSize = PostShuffleScalarPerVector;
static constexpr index_t LayernormBetaSrcVectorSize = PostShuffleScalarPerVector;
static constexpr index_t LayernormESrcVectorSize = PostShuffleScalarPerVector;
static constexpr index_t LayernormThreadSliceSize_N = PostShuffleScalarPerVector;
using LayernormBlockTileSize_M_N =
Sequence<LayernormThreadClusterSize_M_N::At(0) * LayernormThreadSliceSize_M,
LayernormThreadClusterSize_M_N::At(1) * LayernormThreadSliceSize_N>;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto matrix_padder = MatrixPadder<GemmSpec, index_t, index_t, index_t>{
GemmMPerBlock, GemmNPerBlock, GemmKPerBlock};
static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
{
const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(StrideA, I1));
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(I1, StrideA));
}
}();
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
}
static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
{
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(I1, StrideB));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(StrideB, I1));
}
}();
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
}
template <typename DoPads, index_t MPerTile, index_t NPerTile>
static auto MakeEHGridDescriptor_M_N(index_t M, index_t N, index_t Stride)
{
// Only support row major for E and H
const auto grid_desc_m_n =
make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(Stride, I1));
return PadTensorDescriptor(grid_desc_m_n, make_tuple(MPerTile, NPerTile), DoPads{});
}
static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
const std::array<index_t, NumDTensor>& NRaws,
const std::array<index_t, NumDTensor>& DsStride)
{
return generate_tuple(
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
static_assert(is_same<tensor_layout::gemm::RowMajor, DLayout>::value);
return DeviceOp::
MakeEHGridDescriptor_M_N<Sequence<true, true>, GemmMPerBlock, GemmNPerBlock>(
MRaws[i], NRaws[i], DsStride[i]);
},
Number<NumDTensor>{});
}
template <typename DoPads, index_t MPerTile, index_t NPerTile>
static auto MakeMeanVarDescriptor_M_N(index_t M, index_t N)
{
const auto grid_desc_m_n =
make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(N, I1));
return PadTensorDescriptor(grid_desc_m_n, make_tuple(MPerTile, NPerTile), DoPads{});
}
template <typename DoPads, index_t MPerTile, index_t NPerTile>
static auto MakeCountDescriptor_M_N(index_t M, index_t N)
{
// We will broadcast [N] to [M, N] in this descriptor
// Hence, 1st stride is 0
const auto grid_desc_m_n =
make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, I1));
return PadTensorDescriptor(grid_desc_m_n, make_tuple(MPerTile, NPerTile), DoPads{});
}
template <index_t XPerTile>
static auto MakeDescriptor_X(index_t X)
{
const auto grid_desc_x = make_naive_tensor_descriptor_packed(make_tuple(X));
return PadTensorDescriptor(grid_desc_x, make_tuple(XPerTile), Sequence<true>{});
}
using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
// We have to separate the mean/var descriptors for gemm and layernorm because of their
// different grid layouts (different padding)
using GemmMeanVarGridDesc_M_NBlock = decltype(
MakeMeanVarDescriptor_M_N<Sequence<true, false>, GemmMPerBlock, GemmNPerBlock>(1, 1));
using GemmCountGridDesc_M_NBlock = decltype(
MakeCountDescriptor_M_N<Sequence<true, false>, GemmMPerBlock, GemmNPerBlock>(1, 1));
using LayernormMeanVarGridDesc_M_NBlock =
decltype(MakeMeanVarDescriptor_M_N<Sequence<true, true>,
LayernormBlockTileSize_M_N::At(0),
LayernormBlockTileSize_M_N::At(1)>(1, 1));
using LayernormCountGridDesc_M_NBlock =
decltype(MakeCountDescriptor_M_N<Sequence<true, true>,
LayernormBlockTileSize_M_N::At(0),
LayernormBlockTileSize_M_N::At(1)>(1, 1));
using GammaBetaGridDesc_N = decltype(MakeDescriptor_X<LayernormBlockTileSize_M_N::At(1)>(1));
using EHGridDesc_M_N = decltype(MakeEHGridDescriptor_M_N<Sequence<true, true>, 1, 1>(1, 1, 1));
using GridwiseGemmWelford = GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype
AccDataType,
CShuffleDataType,
DsDataType,
EMeanVarDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
InMemoryDataOperationEnum::Set,
AGridDesc_M_K,
BGridDesc_N_K,
DsGridDesc_M_N,
EHGridDesc_M_N,
GemmMeanVarGridDesc_M_NBlock,
GemmCountGridDesc_M_NBlock,
NumGemmKPrefetchStage,
BlockSize,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
PostShuffleThreadClusterSize_M_N,
PostShuffleScalarPerVector,
LoopSched,
PipelineVer>;
using Block2ETileMap = typename GridwiseGemmWelford::DefaultBlock2ETileMap;
using GridwiseWelfordLayernorm =
GridwiseWelfordSecondHalfLayernorm2d<EMeanVarDataType,
HDataType,
GammaDataType,
BetaDataType,
AccDataType,
EHGridDesc_M_N,
LayernormMeanVarGridDesc_M_NBlock,
LayernormCountGridDesc_M_NBlock,
GammaBetaGridDesc_N,
HElementwiseOperation,
BlockSize,
LayernormThreadClusterSize_M_N::At(I0),
LayernormThreadClusterSize_M_N::At(I1),
LayernormThreadSliceSize_M,
LayernormThreadSliceSize_N,
LayernormESrcVectorSize,
LayernormHDstVectorSize,
LayernormGammaSrcVectorSize,
LayernormBetaSrcVectorSize>;
// Argument
struct Argument : public BaseArgument
{
Argument(const void* p_a_grid,
const void* p_b_grid,
std::array<const void*, NumDTensor> p_ds_grid,
const void* p_gamma_grid,
const void* p_beta_grid,
void* p_h_grid,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
std::array<index_t, NumDTensor> StrideDs,
index_t StrideH,
double epsilon,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op,
HElementwiseOperation h_element_op)
: p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
p_ds_grid_{},
p_workspace_e_grid_{nullptr},
p_workspace_mean_{nullptr},
p_workspace_var_{nullptr},
p_workspace_count_{nullptr},
p_gamma_grid_{static_cast<const GammaDataType*>(p_gamma_grid)},
p_beta_grid_{static_cast<const BetaDataType*>(p_beta_grid)},
p_h_grid_{static_cast<HDataType*>(p_h_grid)},
a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(MRaw, KRaw, StrideA)},
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)},
ds_grid_desc_m_n_{},
gemm_e_grid_desc_m_n_{
DeviceOp::MakeEHGridDescriptor_M_N<Sequence<true, true>,
GemmMPerBlock,
GemmNPerBlock>(MRaw, NRaw, StrideH)},
layernorm_e_grid_desc_m_n_{
DeviceOp::MakeEHGridDescriptor_M_N<Sequence<true, true>,
LayernormBlockTileSize_M_N::At(0),
LayernormBlockTileSize_M_N::At(1)>(
MRaw, NRaw, StrideH)},
gemm_mean_var_grid_desc_m_nblock_{},
gemm_count_grid_desc_m_nblock_{},
layernorm_mean_var_grid_desc_m_nblock_{},
layernorm_count_grid_desc_m_nblock_{},
gamma_grid_desc_n_{
DeviceOp::MakeDescriptor_X<LayernormBlockTileSize_M_N::At(1)>(NRaw)},
beta_grid_desc_n_{
DeviceOp::MakeDescriptor_X<LayernormBlockTileSize_M_N::At(1)>(NRaw)},
h_grid_desc_m_n_{
DeviceOp::MakeEHGridDescriptor_M_N<Sequence<true, true>,
LayernormBlockTileSize_M_N::At(0),
LayernormBlockTileSize_M_N::At(1)>(
MRaw, NRaw, StrideH)},
a_grid_desc_ak0_m_ak1_{
GridwiseGemmWelford::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
b_grid_desc_bk0_n_bk1_{
GridwiseGemmWelford::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
block_2_etile_map_{
GridwiseGemmWelford::MakeDefaultBlock2ETileMap(gemm_e_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op},
h_element_op_{h_element_op},
MRaw_{MRaw},
NRaw_{NRaw},
KRaw_{KRaw},
gemm_nblock_{math::integer_divide_ceil(NRaw, GemmNPerBlock)},
epsilon_{static_cast<AccDataType>(epsilon)}
{
// We don't need to pad in N dimension in gemm for mean/var/count. Set NPerTile 1.
gemm_mean_var_grid_desc_m_nblock_ =
DeviceOp::MakeMeanVarDescriptor_M_N<Sequence<true, false>, GemmMPerBlock, 1>(
MRaw, gemm_nblock_);
gemm_count_grid_desc_m_nblock_ =
DeviceOp::MakeCountDescriptor_M_N<Sequence<true, false>, GemmMPerBlock, 1>(
MRaw, gemm_nblock_);
layernorm_mean_var_grid_desc_m_nblock_ =
DeviceOp::MakeMeanVarDescriptor_M_N<Sequence<true, true>,
LayernormBlockTileSize_M_N::At(0),
LayernormBlockTileSize_M_N::At(1)>(
MRaw, gemm_nblock_);
layernorm_count_grid_desc_m_nblock_ =
DeviceOp::MakeCountDescriptor_M_N<Sequence<true, true>,
LayernormBlockTileSize_M_N::At(0),
LayernormBlockTileSize_M_N::At(1)>(MRaw,
gemm_nblock_);
// populate pointer, desc for Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
// D pointer
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
// D desc
ds_grid_desc_m_n_(i) =
DeviceOp::MakeEHGridDescriptor_M_N<Sequence<true, true>,
GemmMPerBlock,
GemmNPerBlock>(MRaw, NRaw, StrideDs[i]);
});
// populate desc for Ds/E/mean/var/count
if(GridwiseGemmWelford::CheckValidity(a_grid_desc_m_k_,
b_grid_desc_n_k_,
ds_grid_desc_m_n_,
gemm_e_grid_desc_m_n_,
block_2_etile_map_))
{
ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemmWelford::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n_);
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemmWelford::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
gemm_e_grid_desc_m_n_);
gemm_mean_var_grid_desc_mblock_mperblock_nblock_ =
GridwiseGemmWelford::MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(
gemm_mean_var_grid_desc_m_nblock_);
gemm_count_grid_desc_mblock_mperblock_nblock_ =
GridwiseGemmWelford::MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(
gemm_count_grid_desc_m_nblock_);
}
}
void Print() const
{
std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl;
std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl;
static_for<0, NumDTensor, 1>{}(
[&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; });
std::cout << "E[M, N]: " << gemm_e_grid_desc_m_n_ << std::endl;
std::cout << "H[M, N]: " << h_grid_desc_m_n_ << std::endl;
}
// private:
// pointers
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
typename GridwiseGemmWelford::DsGridPointer p_ds_grid_;
void* p_workspace_e_grid_;
void* p_workspace_mean_;
void* p_workspace_var_;
void* p_workspace_count_;
const GammaDataType* p_gamma_grid_;
const BetaDataType* p_beta_grid_;
HDataType* p_h_grid_;
// tensor descriptors for problem definition
AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_N_K b_grid_desc_n_k_;
DsGridDesc_M_N ds_grid_desc_m_n_;
EHGridDesc_M_N gemm_e_grid_desc_m_n_;
EHGridDesc_M_N layernorm_e_grid_desc_m_n_;
GemmMeanVarGridDesc_M_NBlock gemm_mean_var_grid_desc_m_nblock_;
GemmCountGridDesc_M_NBlock gemm_count_grid_desc_m_nblock_;
LayernormMeanVarGridDesc_M_NBlock layernorm_mean_var_grid_desc_m_nblock_;
LayernormCountGridDesc_M_NBlock layernorm_count_grid_desc_m_nblock_;
GammaBetaGridDesc_N gamma_grid_desc_n_;
GammaBetaGridDesc_N beta_grid_desc_n_;
EHGridDesc_M_N h_grid_desc_m_n_;
// tensor descriptors for block/thread-wise copy
typename GridwiseGemmWelford::DefaultAGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
typename GridwiseGemmWelford::DefaultBGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
typename GridwiseGemmWelford::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemmWelford::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemmWelford::MeanVarGridDescriptor_MBlock_MPerBlock_NBlock
gemm_mean_var_grid_desc_mblock_mperblock_nblock_;
typename GridwiseGemmWelford::CountGridDescriptor_MBlock_MPerBlock_NBlock
gemm_count_grid_desc_mblock_mperblock_nblock_;
// block-to-e-tile map
Block2ETileMap block_2_etile_map_;
// element-wise op
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
HElementwiseOperation h_element_op_;
index_t MRaw_;
index_t NRaw_;
index_t KRaw_;
index_t gemm_nblock_;
AccDataType epsilon_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
float avg_time = 0;
if(!GridwiseGemmWelford::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_,
arg.ds_grid_desc_m_n_,
arg.gemm_e_grid_desc_m_n_,
arg.block_2_etile_map_))
{
throw std::runtime_error("wrong! GridwiseGemmWelford has invalid setting");
}
index_t grid_size = arg.block_2_etile_map_.CalculateGridSize(arg.gemm_e_grid_desc_m_n_);
const auto M = arg.h_grid_desc_m_n_.GetLength(I0);
const auto N = arg.h_grid_desc_m_n_.GetLength(I1);
const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
const auto kernel_gemm_welford =
kernel_gemm_multiple_d_welford_first_half_xdl_cshuffle<
GridwiseGemmWelford,
ADataType, // TODO: distinguish A/B datatype
typename GridwiseGemmWelford::DsGridPointer,
EMeanVarDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
typename GridwiseGemmWelford::DefaultAGridDesc_AK0_M_AK1,
typename GridwiseGemmWelford::DefaultBGridDesc_BK0_N_BK1,
typename GridwiseGemmWelford::
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemmWelford::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemmWelford::MeanVarGridDescriptor_MBlock_MPerBlock_NBlock,
typename GridwiseGemmWelford::CountGridDescriptor_MBlock_MPerBlock_NBlock,
typename GridwiseGemmWelford::DefaultBlock2ETileMap,
has_main_loop>;
const auto kernel_welford_layernorm =
kernel_welford_layernorm2d_second_half<GridwiseWelfordLayernorm,
EMeanVarDataType,
HDataType,
GammaDataType,
BetaDataType,
AccDataType,
EHGridDesc_M_N,
LayernormMeanVarGridDesc_M_NBlock,
LayernormCountGridDesc_M_NBlock,
GammaBetaGridDesc_N,
HElementwiseOperation>;
avg_time +=
launch_and_time_kernel(stream_config,
kernel_gemm_welford,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_ds_grid_,
static_cast<EMeanVarDataType*>(arg.p_workspace_e_grid_),
static_cast<EMeanVarDataType*>(arg.p_workspace_mean_),
static_cast<EMeanVarDataType*>(arg.p_workspace_var_),
static_cast<int32_t*>(arg.p_workspace_count_),
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.gemm_mean_var_grid_desc_mblock_mperblock_nblock_,
arg.gemm_count_grid_desc_mblock_mperblock_nblock_,
arg.block_2_etile_map_,
arg.NRaw_);
index_t MBlockClusterLength =
math::integer_divide_ceil(M, LayernormBlockTileSize_M_N::At(0));
index_t NBlockClusterLength =
math::integer_divide_ceil(N, LayernormBlockTileSize_M_N::At(1));
grid_size = MBlockClusterLength * NBlockClusterLength;
index_t numMeanVarCountBlockTileIteration_N = math::integer_divide_ceil(
arg.gemm_nblock_, LayernormThreadClusterSize_M_N::At(I1));
avg_time += launch_and_time_kernel(
stream_config,
kernel_welford_layernorm,
dim3(grid_size),
dim3(BlockSize),
0,
static_cast<EMeanVarDataType*>(arg.p_workspace_e_grid_),
static_cast<const EMeanVarDataType*>(arg.p_workspace_mean_),
static_cast<const EMeanVarDataType*>(arg.p_workspace_var_),
static_cast<const int32_t*>(arg.p_workspace_count_),
arg.p_gamma_grid_,
arg.p_beta_grid_,
arg.p_h_grid_,
arg.layernorm_e_grid_desc_m_n_,
arg.h_grid_desc_m_n_,
arg.layernorm_mean_var_grid_desc_m_nblock_,
arg.layernorm_count_grid_desc_m_nblock_,
arg.gamma_grid_desc_n_,
arg.beta_grid_desc_n_,
numMeanVarCountBlockTileIteration_N,
NBlockClusterLength,
arg.epsilon_,
arg.h_element_op_);
return avg_time;
};
if(GridwiseGemmWelford::CalculateHasMainKBlockLoop(K))
{
return launch_kernel(integral_constant<bool, true>{});
}
else
{
return launch_kernel(integral_constant<bool, false>{});
}
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
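// Workspace layout, as carved out by SetWorkSpacePointer below: welford mean, then welford
// variance, then welford count, each offset rounded up to a 64-byte boundary, followed by a
// separate E buffer only when EMeanVarDataType differs from HDataType (otherwise E is written
// straight into the H buffer).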
size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
{
const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);
size_t workspace_size = 0;
int gemm_welford_size = pArg_->MRaw_ * pArg_->gemm_nblock_;
// workspace for welford intermediate mean
workspace_size += gemm_welford_size * sizeof(EMeanVarDataType) + 64;
// workspace for welford intermediate variance
workspace_size += gemm_welford_size * sizeof(EMeanVarDataType) + 64;
// workspace for welford intermediate count
workspace_size += pArg_->gemm_nblock_ * sizeof(int32_t) + 64;
if constexpr(!is_same_v<EMeanVarDataType, HDataType>)
workspace_size += pArg_->MRaw_ * pArg_->NRaw_ * sizeof(EMeanVarDataType);
return (workspace_size);
};
void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override
{
Argument* pArg_ = dynamic_cast<Argument*>(pArg);
pArg_->p_workspace_ = p_workspace;
int gemm_welford_size = pArg_->MRaw_ * pArg_->gemm_nblock_;
// setup buffer used for intermediate welford mean
pArg_->p_workspace_mean_ = static_cast<char*>(pArg_->p_workspace_);
index_t mean_space_sz = gemm_welford_size * sizeof(EMeanVarDataType);
mean_space_sz = math::integer_least_multiple(mean_space_sz, 64);
// setup buffer used for intermediate welford variance
pArg_->p_workspace_var_ = reinterpret_cast<char*>(pArg_->p_workspace_mean_) + mean_space_sz;
index_t variance_space_sz = gemm_welford_size * sizeof(EMeanVarDataType);
variance_space_sz = math::integer_least_multiple(variance_space_sz, 64);
// setup buffer used for intermediate welford count
pArg_->p_workspace_count_ =
reinterpret_cast<char*>(pArg_->p_workspace_var_) + variance_space_sz;
index_t count_space_sz = gemm_welford_size * sizeof(int32_t);
count_space_sz = math::integer_least_multiple(count_space_sz, 64);
if constexpr(!is_same_v<EMeanVarDataType, HDataType>)
pArg_->p_workspace_e_grid_ =
reinterpret_cast<char*>(pArg_->p_workspace_count_) + count_space_sz;
else
pArg_->p_workspace_e_grid_ = static_cast<void*>(pArg_->p_h_grid_);
};
static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
{
return false;
}
// check vector load/store
{
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
// check vector load of A
if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
{
if(arg.KRaw_ % ABlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
{
// FIXME: not rigorous
if(arg.MRaw_ % ABlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else
{
return false;
}
// check vector load of B
if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
{
if(arg.KRaw_ % BBlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
{
// FIXME: not rigorous
if(arg.NRaw_ % BBlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else
{
return false;
}
// check vector load of Ds
// only support RowMajor for now
bool all_valid = true;
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
if constexpr(!is_same_v<DLayout, Row>)
{
all_valid = false;
}
});
if(!all_valid)
{
return false;
}
// check vector store of E
// E and H only support RowMajor for now
if constexpr(is_same_v<ELayout, Row> && is_same_v<HLayout, Row>)
{
if(arg.NRaw_ % PostShuffleScalarPerVector != 0 ||
arg.NRaw_ % LayernormGammaSrcVectorSize != 0 ||
arg.NRaw_ % LayernormBetaSrcVectorSize != 0 ||
arg.NRaw_ % LayernormHDstVectorSize != 0)
{
return false;
}
}
else
{
return false;
}
}
return true;
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
const void* p_gamma,
const void* p_beta,
void* p_h,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
std::array<index_t, NumDTensor> StrideDs,
index_t StrideH,
double epsilon,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op,
HElementwiseOperation h_element_op)
{
return Argument{p_a,
p_b,
p_ds,
p_gamma,
p_beta,
p_h,
MRaw,
NRaw,
KRaw,
StrideA,
StrideB,
StrideDs,
StrideH,
epsilon,
a_element_op,
b_element_op,
cde_element_op,
h_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
const void* p_gamma,
const void* p_beta,
void* p_h,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
std::array<index_t, NumDTensor> StrideDs,
index_t StrideH,
double epsilon,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op,
HElementwiseOperation h_element_op) override
{
return std::make_unique<Argument>(p_a,
p_b,
p_ds,
p_gamma,
p_beta,
p_h,
MRaw,
NRaw,
KRaw,
StrideA,
StrideB,
StrideDs,
StrideH,
epsilon,
a_element_op,
b_element_op,
cde_element_op,
h_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
std::map<LoopScheduler, std::string> LoopSchedToString{
{LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
{PipelineVersion::v2, "v2"}};
// clang-format off
str << "DeviceGemmMultipleDLayernorm_Xdl_CShuffle"
<< "<"
<< BlockSize << ", "
<< GemmMPerBlock << ", "
<< GemmNPerBlock << ", "
<< GemmKPerBlock << ", "
<< AK1 << ", "
<< BK1 << ", "
<< getGemmSpecializationString(GemmSpec)
<< ">"
<< " LoopScheduler: "
<< LoopSchedToString[LoopSched] << ", "
<< "PipelineVersion: "
<< PipelineVersionToString[PipelineVer];
// clang-format on
return str.str();
}
}; // struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
namespace ck {
// GEMM:
// input : A[M, K]
// input : B[N, K]
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// output : F[M, N0], where N0 is number of blocks along N dimension
// output : G[M, N0], where N0 is number of blocks along N dimension
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// F, G = welford(E)
// Assume:
// D0, D1, ... and E have the same layout
// Calculate mean & variance along N dimension for E
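// The welford pass uses the standard streaming update (sketch, per element x):
//   count += 1;
//   delta  = x - mean;
//   mean  += delta / count;
//   m2    += delta * (x - mean);   // running sum of squared deviations; variance = m2 / count
// so the partial (mean, variance, count) triples produced here can be merged exactly later.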
template <typename ABDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EMeanVarDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
typename AGridDesc_M_K,
typename BGridDesc_N_K,
typename DsGridDesc_M_N,
typename EGridDesc_M_N,
typename MeanVarGridDesc_M_NBlock,
typename CountGridDesc_M_NBlock,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1Value,
index_t BK1Value,
index_t MPerXdl,
index_t NPerXdl,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool AThreadTransferSrcResetCoordinateAfterRun,
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BThreadTransferSrcResetCoordinateAfterRun,
index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename PostShuffleThreadClusterSize_M_N,
index_t PostShuffleScalarPerVector,
LoopScheduler LoopSched,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
{
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
// K1 should be Number<...>
static constexpr auto AK1 = Number<AK1Value>{};
static constexpr auto BK1 = Number<BK1Value>{};
static constexpr auto AK0PerBlock = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0PerBlock = Number<KPerBlock / BK1Value>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
// A matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(AK0PerBlock, Number<MPerBlock>{}, AK1),
make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
}
__host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(BK0PerBlock, Number<NPerBlock>{}, BK1),
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
}
__host__ __device__ static constexpr auto
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
{
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
I1,
Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));
return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
}
// ck::Tuple<const D0DataType*, const D1DataType*, ...>
static constexpr auto MakeDsGridPointer()
{
return generate_tuple(
[&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
return static_cast<const DDataType*>(nullptr);
},
Number<NumDTensor>{});
}
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1, BK1);
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
// LDS allocation for C shuffle
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
constexpr auto c_block_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
sizeof(ABDataType),
c_block_size * sizeof(CShuffleDataType));
}
// A desc for source in blockwise copy
__host__ __device__ static constexpr auto
MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
{
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
const auto AK0 = K / AK1;
return transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
// B desc for source in blockwise copy
__host__ __device__ static constexpr auto
MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
{
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);
const auto BK0 = K / BK1;
return transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
// E desc for destination in blockwise copy
template <typename EGridDescriptor_M_N>
__host__ __device__ static constexpr auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
const EGridDescriptor_M_N& e_grid_desc_m_n)
{
const auto M = e_grid_desc_m_n.GetLength(I0);
const auto N = e_grid_desc_m_n.GetLength(I1);
const auto MBlock = M / MPerBlock;
const auto NBlock = N / NPerBlock;
const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
e_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
return e_grid_desc_mblock_mperblock_nblock_nperblock;
}
// Ds desc for source in blockwise copy
template <typename DsGridDescriptor_M_N>
__host__ __device__ static constexpr auto
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
const DsGridDescriptor_M_N& ds_grid_desc_m_n)
{
return generate_tuple(
[&](auto i) {
return MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]);
},
Number<NumDTensor>{});
}
template <typename GridDescriptor_M_N>
__host__ __device__ static constexpr auto
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(const GridDescriptor_M_N& grid_desc_m_n)
{
const auto M = grid_desc_m_n.GetLength(I0);
const auto NBlock = grid_desc_m_n.GetLength(I1);
const auto MBlock = M / MPerBlock;
const auto grid_desc_mblock_mperblock_nblock = transform_tensor_descriptor(
grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_pass_through_transform(NBlock)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}));
return grid_desc_mblock_mperblock_nblock;
}
// return block_id to E matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto
MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
{
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>(
e_grid_desc_m_n);
}
// block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
template <typename Block2ETileMap>
__host__ __device__ static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k,
const BGridDesc_N_K& b_grid_desc_n_k,
const DsGridDesc_M_N& ds_grid_desc_m_n,
const EGridDesc_M_N& e_grid_desc_m_n,
const Block2ETileMap& block_2_etile_map)
{
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
// check consistency of desc
if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1)))
{
return false;
}
bool valid = true;
static_for<0, NumDTensor, 1>{}([&](auto i) {
valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) &&
N == ds_grid_desc_m_n[i].GetLength(I1));
});
if(!valid)
{
return false;
}
// check tile size
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
{
return false;
}
// check gridwise gemm pipeline
const auto num_k_loop = K / KPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
{
return false;
}
// check block-to-E-tile
if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n))
{
return false;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
// check tensor size: cannot be larger than 2GB each
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
if(!(a_grid_desc_m_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB &&
b_grid_desc_n_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB &&
e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EMeanVarDataType) <= TwoGB))
{
return false;
}
return true;
}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / KPerBlock;
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
}
using DefaultAGridDesc_AK0_M_AK1 =
remove_cvref_t<decltype(MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
using DefaultBGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
using MeanVarGridDescriptor_MBlock_MPerBlock_NBlock = remove_cvref_t<decltype(
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(MeanVarGridDesc_M_NBlock{}))>;
using CountGridDescriptor_MBlock_MPerBlock_NBlock = remove_cvref_t<decltype(
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(CountGridDesc_M_NBlock{}))>;
using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
using DefaultBlock2ETileMap =
remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
using DsGridPointer = decltype(MakeDsGridPointer());
template <bool HasMainKBlockLoop,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename Block2ETileMap>
__device__ static void
Run(const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid,
DsGridPointer p_ds_grid,
EMeanVarDataType* __restrict__ p_e_grid,
EMeanVarDataType* __restrict__ p_welford_mean_grid,
EMeanVarDataType* __restrict__ p_welford_var_grid,
int32_t* __restrict__ p_welford_count,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
e_grid_desc_mblock_mperblock_nblock_nperblock,
const MeanVarGridDescriptor_MBlock_MPerBlock_NBlock&
mean_var_grid_desc_mblock_mperblock_nblock,
const CountGridDescriptor_MBlock_MPerBlock_NBlock& count_grid_desc_mblock_mperblock_nblock,
const Block2ETileMap& block_2_etile_map,
index_t NRaw)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
const auto ds_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ds_grid[i],
ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
},
Number<NumDTensor>{});
auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
auto mean_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_mean_grid, mean_var_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
auto var_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_var_grid, mean_var_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
auto welford_count_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_count, count_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
// divide block work by [M, N]
const auto block_work_idx =
block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_etile_map.ValidCTileIndex(
block_work_idx,
make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
return;
}
// HACK: this forces m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1, BK1);
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// A matrix blockwise copy
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0PerBlock, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABDataType,
ABDataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
NumGemmKPrefetchStage>(
a_grid_desc_ak0_m_ak1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// B matrix blockwise copy
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0PerBlock, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
ABDataType,
ABDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
NumGemmKPrefetchStage>(
b_grid_desc_bk0_n_bk1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[K0PerBlock, MPerBlock] is in LDS
// b_mtx[K0PerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
constexpr index_t KPack =
math::max(math::lcm(AK1, BK1),
MfmaSelector<ABDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
ABDataType,
AccDataType,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack,
LoopSched>();
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ABDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ABDataType*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
// gridwise GEMM pipeline
const auto gridwise_gemm_pipeline =
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>();
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock);
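// Note: for the default A descriptor, GetLength(I0) * GetLength(I2) == AK0 * AK1 == K, so
// this is simply K / KPerBlock (illustrative: K = 4096, KPerBlock = 32 gives 128 iterations).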
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_bk0_n_bk1,
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_slice_copy_step,
blockwise_gemm,
c_thread_buf,
num_k_block_main_loop);
// shuffle C, Welford and write out
{
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
"wrong!");
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
// TODO: hacky, fix it!
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
// TODO: hacky, fix it!
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<CShuffleDataType*>(p_shared),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
M1, // M1 = MWave
M2, // M2 * M3 * M4 = MPerXdl
M3,
M4)),
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
N1, // N1 = NWave
N2))), // N2 = NPerXdl
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
1,
1,
M2,
1,
M4,
1>,
false>{};
// space filling curve for shuffled blockwise C in global mem
constexpr auto sfc_der_global =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
false>{};
// LDS c_shuffle_block_desc_mperblock_nperblock
constexpr auto c_shuffle_block_desc_mperblock_nperblock = transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(
make_freeze_transform(I0),
make_pass_through_transform(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1)),
make_freeze_transform(I0),
make_pass_through_transform(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I3))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<>{}, Sequence<1>{}));
static_assert(PostShuffleThreadClusterSize_M_N::At(I0) *
PostShuffleThreadClusterSize_M_N::At(I1) ==
BlockSize,
"wrong!");
static_assert((CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) %
PostShuffleThreadClusterSize_M_N::At(I0) ==
0 &&
(CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) %
PostShuffleThreadClusterSize_M_N::At(I1) ==
0,
"wrong!");
constexpr index_t PostShuffleThreadSliceSize_M =
(CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) /
PostShuffleThreadClusterSize_M_N::At(I0);
constexpr index_t PostShuffleThreadSliceSize_N =
(CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) /
PostShuffleThreadClusterSize_M_N::At(I1);
constexpr auto PostShuffleThreadSliceSize_M_N =
Sequence<PostShuffleThreadSliceSize_M, PostShuffleThreadSliceSize_N>{};
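// Illustrative example for the slice sizes above (assumed tuning): with a 64 x 64 shuffle
// tile (CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl = 64, likewise for N) and
// PostShuffleThreadClusterSize_M_N = <32, 8> (BlockSize = 256), each thread owns a
// 2 x 8 slice of the shuffled tile per iteration.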
// VGPR post_shuffle_thread_desc_m_n
constexpr auto post_shuffle_thread_desc_m_n = make_naive_tensor_descriptor_packed(
make_tuple(Number<PostShuffleThreadSliceSize_M>{},
Number<PostShuffleThreadSliceSize_N>{}));
auto e_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
post_shuffle_thread_desc_m_n.GetElementSpaceSize());
// To apply D0, D1, ... and Welford.
// threadwise copy from LDS to VGPR
constexpr auto post_shuffle_thread_cluster_desc =
make_cluster_descriptor(PostShuffleThreadClusterSize_M_N{}, Sequence<0, 1>{});
const auto post_shuffle_thread_cluster_idx =
post_shuffle_thread_cluster_desc.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id()));
const auto post_shuffle_thread_data_idx_begin =
post_shuffle_thread_cluster_idx * PostShuffleThreadSliceSize_M_N;
// To apply D0, D1, ... and Welford.
// Copy c shuffle from LDS back to VGPR
auto post_shuffle_thread_copy_lds_to_vgpr =
ThreadwiseTensorSliceTransfer_v2<CShuffleDataType,
AccDataType,
decltype(c_shuffle_block_desc_mperblock_nperblock),
decltype(post_shuffle_thread_desc_m_n),
decltype(PostShuffleThreadSliceSize_M_N),
Sequence<0, 1>,
1,
PostShuffleScalarPerVector,
1,
true>{c_shuffle_block_desc_mperblock_nperblock,
post_shuffle_thread_data_idx_begin};
// D0, D1, ..., Dn
constexpr auto post_shuffle_thread_desc_I1_mperblock_I1_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<PostShuffleThreadSliceSize_M>{},
I1,
Number<PostShuffleThreadSliceSize_N>{}));
// FIXME: Decrease usage of VGPR
// Apply the pointwise lambda to multiple sources (global memory and LDS) and write the result into VGPR
auto ds_thread_buf = generate_tuple(
[&](auto) {
return make_static_buffer<AddressSpaceEnum::Vgpr, CShuffleDataType>(
post_shuffle_thread_desc_I1_mperblock_I1_nperblock.GetElementSpaceSize());
},
Number<NumDTensor>{});
// Copy D0, D1, ..., Dn from global to VGPR
auto ds_thread_copy_global_to_vgpr = generate_tuple(
[&](auto I) {
using DDataType = remove_cvref_t<tuple_element_t<I.value, DsDataType>>;
return ThreadwiseTensorSliceTransfer_v2<
DDataType,
AccDataType,
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock[I]),
decltype(post_shuffle_thread_desc_I1_mperblock_I1_nperblock),
Sequence<I1,
PostShuffleThreadSliceSize_M,
I1,
PostShuffleThreadSliceSize_N>,
Sequence<0, 1, 2, 3>,
3,
PostShuffleScalarPerVector,
1,
true>(
ds_grid_desc_mblock_mperblock_nblock_nperblock[I],
make_multi_index(
I0,
m_block_data_idx_on_grid + post_shuffle_thread_data_idx_begin[I0],
I0,
n_block_data_idx_on_grid + post_shuffle_thread_data_idx_begin[I1]));
},
Number<NumDTensor>{});
auto e_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
AccDataType,
EMeanVarDataType,
decltype(post_shuffle_thread_desc_I1_mperblock_I1_nperblock),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
tensor_operation::element_wise::PassThrough,
Sequence<I1,
PostShuffleThreadSliceSize_M,
I1,
PostShuffleThreadSliceSize_N>, // SliceLengths
Sequence<0, 1, 2, 3>, // DimAccessOrder
3, // DstVectorDim
PostShuffleScalarPerVector,
InMemoryDataOperationEnum::Set,
1,
true>{
e_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(I0,
m_block_data_idx_on_grid + post_shuffle_thread_data_idx_begin[I0],
I0,
n_block_data_idx_on_grid + post_shuffle_thread_data_idx_begin[I1]),
tensor_operation::element_wise::PassThrough{}};
// Welford
constexpr auto thread_welford_src_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<PostShuffleThreadSliceSize_M>{},
Number<PostShuffleThreadSliceSize_N>{}));
constexpr auto thread_welford_dst_desc_m = make_naive_tensor_descriptor_packed(
make_tuple(Number<PostShuffleThreadSliceSize_M>{}));
using ThreadwiseWelford = ThreadwiseWelford<AccDataType,
decltype(thread_welford_src_desc_m_k),
decltype(thread_welford_dst_desc_m)>;
using BlockwiseWelford = BlockwiseWelford<AccDataType,
BlockSize,
PostShuffleThreadClusterSize_M_N,
Sequence<0, 1>,
false>;
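// Welford recap (standard online formulation, for reference): a running (mean, M2, count)
// triple is updated per element as
//   count += 1; delta = x - mean; mean += delta / count; M2 += delta * (x - mean);
// so partial results from different threads/blocks can later be combined without a second
// pass over the data. The threadwise/blockwise helpers above build on this idea.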
constexpr int num_shuffleM =
MPerBlock / (CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl);
constexpr int num_shuffleN =
NPerBlock / (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl);
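// Illustrative example for num_shuffleM/N above (assumed tuning): MPerBlock = 256,
// NPerBlock = 128 and a 64 x 64 shuffle tile give num_shuffleM = 4 and num_shuffleN = 2,
// i.e. four independent Welford accumulators per thread, each fed by two N-iterations of
// the shuffle loop below.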
using mean_var_vgpr_type =
decltype(make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
thread_welford_dst_desc_m.GetElementSpaceSize()));
using welford_count_vgpr_type =
decltype(make_static_buffer<AddressSpaceEnum::Vgpr, int32_t>(
thread_welford_dst_desc_m.GetElementSpaceSize()));
Array<ThreadwiseWelford, num_shuffleM> threadwise_welfords;
Array<mean_var_vgpr_type, num_shuffleM> mean_thread_bufs;
Array<mean_var_vgpr_type, num_shuffleM> var_thread_bufs;
Array<welford_count_vgpr_type, num_shuffleM> welford_count_thread_bufs;
int max_count = PostShuffleThreadSliceSize_N * num_shuffleN;
const auto nblock = mean_var_grid_desc_mblock_mperblock_nblock.GetLength(I2);
// tail block
if(block_work_idx[I1] % nblock == nblock - 1)
{
constexpr index_t NPerShuffleBlock =
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl;
int NPerBlockTail = NRaw - NPerBlock * (nblock - 1);
int thread_max_len =
PostShuffleThreadSliceSize_N * (post_shuffle_thread_cluster_idx[I1] + 1);
int shuffle_step = 0;
while(thread_max_len <= NPerBlockTail && shuffle_step < num_shuffleN)
{
++shuffle_step;
thread_max_len += NPerShuffleBlock;
}
int delta = 0;
if(thread_max_len - NPerBlockTail > PostShuffleThreadSliceSize_N)
delta = 0;
else if(NPerBlockTail > thread_max_len)
delta = PostShuffleThreadSliceSize_N;
else
delta = PostShuffleThreadSliceSize_N - thread_max_len + NPerBlockTail;
max_count = shuffle_step * PostShuffleThreadSliceSize_N + delta;
}
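// Worked example for the tail-block max_count logic above (assumed numbers): NRaw = 300,
// NPerBlock = 128, nblock = 3, so the tail block holds NPerBlockTail = 300 - 2 * 128 = 44
// valid columns. With PostShuffleThreadSliceSize_N = 4 and NPerShuffleBlock = 64, the
// thread with n-cluster index 10 covers columns [40, 44) of the first shuffle and
// [104, 108) of the second: the while loop advances once (thread_max_len = 108), delta = 0
// and max_count = 4. If NPerBlockTail were 42 instead, the loop would not advance
// (thread_max_len = 44), delta = 4 - 44 + 42 = 2 and max_count = 2.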
static_for<0, num_shuffleM, 1>{}([&](auto i) {
threadwise_welfords(i).max_count_ = max_count;
mean_thread_bufs(i) = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
thread_welford_dst_desc_m.GetElementSpaceSize());
var_thread_bufs(i) = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
thread_welford_dst_desc_m.GetElementSpaceSize());
welford_count_thread_bufs(i) = make_static_buffer<AddressSpaceEnum::Vgpr, int32_t>(
thread_welford_dst_desc_m.GetElementSpaceSize());
static_for<0, PostShuffleThreadSliceSize_M, 1>{}([&](auto j) {
mean_thread_bufs(i)(j) = type_convert<AccDataType>(0.0f);
var_thread_bufs(i)(j) = type_convert<AccDataType>(0.0f);
welford_count_thread_bufs(i)(j) = 0;
});
});
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
static_assert(num_access == sfc_der_global.GetNumOfAccess(), "wrong!");
int shuffleM_index = __builtin_amdgcn_readfirstlane(0);
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
block_sync_lds();
// each thread shuffle data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
c_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_shuffle_block_buf);
// make sure it's safe to read from LDS
block_sync_lds();
// Get shuffle data from LDS to VGPR
post_shuffle_thread_copy_lds_to_vgpr.Run(c_shuffle_block_desc_mperblock_nperblock,
c_shuffle_block_buf,
post_shuffle_thread_desc_m_n,
make_tuple(I0, I0),
e_thread_buf);
// Global read D0, D1, ...
static_for<0, NumDTensor, 1>{}([&](auto Id) {
auto& d_thread_copy_global_to_vgpr = ds_thread_copy_global_to_vgpr(Id);
d_thread_copy_global_to_vgpr.Run(
ds_grid_desc_mblock_mperblock_nblock_nperblock[Id],
ds_grid_buf[Id],
post_shuffle_thread_desc_I1_mperblock_I1_nperblock,
make_tuple(I0, I0, I0, I0),
ds_thread_buf(Id));
if constexpr(access_id < num_access - 1)
{
// move on D0, D1, ...
constexpr auto de_global_step = sfc_der_global.GetForwardStep(access_id);
d_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
ds_grid_desc_mblock_mperblock_nblock_nperblock[Id], de_global_step);
}
});
// cde_element_op(e, c, d0, d1, ...);
static_for<0, post_shuffle_thread_desc_m_n.GetElementSize(), 1>{}([&](auto i) {
const auto c_ds_src_data_refs = concat_tuple_of_reference(
tie(e_thread_buf[i]),
generate_tie(
[&](auto Id) -> const auto& { return ds_thread_buf[Id][i]; },
Number<NumDTensor>{}));
auto e_dst_data_refs = tie(e_thread_buf(i));
unpack2(cde_element_op, e_dst_data_refs, c_ds_src_data_refs);
});
// Global write E
e_thread_copy_vgpr_to_global.Run(post_shuffle_thread_desc_I1_mperblock_I1_nperblock,
make_tuple(I0, I0, I0, I0),
e_thread_buf,
e_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_buf);
if constexpr(access_id < num_access - 1)
{
// move on E
constexpr auto de_global_step = sfc_der_global.GetForwardStep(access_id);
e_thread_copy_vgpr_to_global.MoveDstSliceWindow(
e_grid_desc_mblock_mperblock_nblock_nperblock, de_global_step);
}
// Threadwise welford
auto& threadwise_welford = threadwise_welfords(shuffleM_index);
auto& mean_thread_buf = mean_thread_bufs(shuffleM_index);
auto& var_thread_buf = var_thread_bufs(shuffleM_index);
threadwise_welford.Run(e_thread_buf, mean_thread_buf, var_thread_buf);
if constexpr(access_id < num_access - 1)
{
constexpr auto de_global_step = sfc_der_global.GetForwardStep(access_id);
constexpr int shuffleMInc =
de_global_step[I1] /
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1);
shuffleM_index = __builtin_amdgcn_readfirstlane(shuffleM_index + shuffleMInc);
}
}); // copy c, d, e + welford
// Blockwise welford and write out
static_for<0, num_shuffleM, 1>{}([&](auto i) {
auto& mean_thread_buf = mean_thread_bufs(i);
auto& var_thread_buf = var_thread_bufs(i);
auto& count_thread_buf = welford_count_thread_bufs(i);
static_for<0, PostShuffleThreadSliceSize_M, 1>{}([&](auto j) {
block_sync_lds();
count_thread_buf(j) = threadwise_welfords(i).cur_count_;
BlockwiseWelford::Run(
mean_thread_buf(j), var_thread_buf(j), count_thread_buf(j));
});
if(post_shuffle_thread_cluster_idx[I1] == 0)
{
constexpr auto thread_welford_desc_I_m_I = make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<PostShuffleThreadSliceSize_M>{}, I1));
constexpr int shuffleMPerBlock =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1);
auto mean_var_count_thread_copy_index = make_multi_index(
block_work_idx[I0], // mblock
shuffleMPerBlock * i + post_shuffle_thread_data_idx_begin[I0], // mperblock
block_work_idx[I1]); // nblock
auto mean_var_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
AccDataType,
EMeanVarDataType,
decltype(thread_welford_desc_I_m_I),
decltype(mean_var_grid_desc_mblock_mperblock_nblock),
tensor_operation::element_wise::PassThrough,
Sequence<1, PostShuffleThreadSliceSize_M, 1>,
Sequence<0, 1, 2>,
1,
1,
InMemoryDataOperationEnum::Set,
1,
true>{mean_var_grid_desc_mblock_mperblock_nblock,
mean_var_count_thread_copy_index,
tensor_operation::element_wise::PassThrough{}};
mean_var_thread_copy_vgpr_to_global.Run(
thread_welford_desc_I_m_I,
make_tuple(I0, I0, I0),
mean_thread_buf,
mean_var_grid_desc_mblock_mperblock_nblock,
mean_grid_buf); // write mean
mean_var_thread_copy_vgpr_to_global.Run(
thread_welford_desc_I_m_I,
make_tuple(I0, I0, I0),
var_thread_buf,
mean_var_grid_desc_mblock_mperblock_nblock,
var_grid_buf); // write variance
// Count has stride [0, 1] along (M, NBlock), so only the first row, count[0, 0:nblock],
// needs to be written.
if(i == 0 && block_work_idx[I0] == 0 &&
post_shuffle_thread_cluster_idx[I0] == 0)
{
auto count_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
int32_t,
int32_t,
decltype(thread_welford_desc_I_m_I),
decltype(count_grid_desc_mblock_mperblock_nblock),
tensor_operation::element_wise::PassThrough,
Sequence<1, PostShuffleThreadSliceSize_M, 1>,
Sequence<0, 1, 2>,
1,
1,
InMemoryDataOperationEnum::Set,
1,
false>{count_grid_desc_mblock_mperblock_nblock,
mean_var_count_thread_copy_index,
tensor_operation::element_wise::PassThrough{}};
count_thread_copy_vgpr_to_global.Run(
thread_welford_desc_I_m_I,
make_tuple(I0, I0, I0),
count_thread_buf,
count_grid_desc_mblock_mperblock_nblock,
welford_count_grid_buf); // write count
}
}
});
} // shuffle C + Ds + welford + write out
} // run
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
namespace ck {
template <typename EMeanVarDataType,
typename HDataType,
typename GammaDataType,
typename BetaDataType,
typename ComputeDataType,
typename EHGridDesc_M_N,
typename MeanVarGridDesc_M_NBlock,
typename CountGridDesc_M_NBlock,
typename GammaBetaGridDesc_N,
typename HElementwiseOperation,
index_t BlockSize,
index_t MThreadClusterSize,
index_t NThreadClusterSize,
index_t MThreadSliceSize,
index_t NThreadSliceSize,
index_t ESrcVectorSize,
index_t HDstVectorSize,
index_t GammaSrcVectorSize,
index_t BetaSrcVectorSize>
struct GridwiseWelfordSecondHalfLayernorm2d
{
static_assert(NThreadSliceSize % ESrcVectorSize == 0 &&
NThreadSliceSize % GammaSrcVectorSize == 0 &&
NThreadSliceSize % BetaSrcVectorSize == 0,
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static_assert(NThreadSliceSize % HDstVectorSize == 0,
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
using ThreadClusterLengths_M_N = Sequence<MThreadClusterSize, NThreadClusterSize>;
using ThreadBufferDimAccessOrder = Sequence<0, 1>;
using ThreadClusterArrangeOrder = Sequence<0, 1>;
static constexpr auto thread_cluster_desc_m_n =
make_cluster_descriptor(ThreadClusterLengths_M_N{}, ThreadClusterArrangeOrder{});
using ThreadBufferLengths_M_N = Sequence<MThreadSliceSize, NThreadSliceSize>;
static constexpr auto thread_buffer_desc_m_n = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<NThreadSliceSize>{}));
using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
static constexpr auto thread_buffer_desc_m_1 =
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
using ThreadBufferLengths_N = Sequence<NThreadSliceSize>;
static constexpr auto thread_buffer_desc_n =
make_naive_tensor_descriptor_packed(make_tuple(Number<NThreadSliceSize>{}));
using ThreadWelfordSrcDesc_M_1 = decltype(thread_buffer_desc_m_1);
using ThreadWelfordDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using ThreadwiseWelford =
ThreadwiseWelfordMerge<ComputeDataType, ThreadWelfordSrcDesc_M_1, ThreadWelfordDstDesc_M>;
using BlockwiseWelford = BlockwiseWelford<ComputeDataType,
BlockSize,
ThreadClusterLengths_M_N,
ThreadClusterArrangeOrder>;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t N_BlockTileSize = NThreadClusterSize * NThreadSliceSize;
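// Illustrative example (assumed tuning): MThreadClusterSize = 8, NThreadClusterSize = 32,
// MThreadSliceSize = 1 and NThreadSliceSize = 8 (BlockSize = 256) give an 8 x 256 block
// tile, i.e. each workgroup normalizes 8 rows over a 256-column window per iteration.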
__device__ static void Run(const EMeanVarDataType* __restrict__ p_e_grid,
const EMeanVarDataType* __restrict__ p_in_welford_mean_grid,
const EMeanVarDataType* __restrict__ p_in_welford_var_grid,
const int32_t* __restrict__ p_in_welford_count_grid,
const GammaDataType* __restrict__ p_gamma_grid,
const BetaDataType* __restrict__ p_beta_grid,
HDataType* __restrict__ p_h_grid,
const EHGridDesc_M_N& e_grid_desc_m_n,
const EHGridDesc_M_N& h_grid_desc_m_n,
const MeanVarGridDesc_M_NBlock& mean_var_grid_desc_m_nblock,
const CountGridDesc_M_NBlock& count_grid_desc_m_nblock,
const GammaBetaGridDesc_N& gamma_grid_desc_n,
const GammaBetaGridDesc_N& beta_grid_desc_n,
index_t numMeanVarCountBlockTileIteration_N,
index_t NBlockClusterLength,
ComputeDataType epsilon,
HElementwiseOperation h_element_op)
{
// Thread/Block id
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
const auto block_work_idx = make_tuple(block_global_id / NBlockClusterLength,
block_global_id % NBlockClusterLength);
const auto thread_cluster_idx =
thread_cluster_desc_m_n.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_n_cluster_id = thread_cluster_idx[I1];
// Global Memory
const auto e_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_e_grid, e_grid_desc_m_n.GetElementSpaceSize());
const auto welford_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_welford_mean_grid, mean_var_grid_desc_m_nblock.GetElementSpaceSize());
const auto welford_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_welford_var_grid, mean_var_grid_desc_m_nblock.GetElementSpaceSize());
const auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_welford_count_grid, count_grid_desc_m_nblock.GetElementSpaceSize());
const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_gamma_grid, gamma_grid_desc_n.GetElementSpaceSize());
const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_beta_grid, beta_grid_desc_n.GetElementSpaceSize());
auto h_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_h_grid, h_grid_desc_m_n.GetElementSpaceSize());
// VGPR
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
in_welford_mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
in_welford_var_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
in_welford_count_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
welford_mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
welford_var_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
welford_count_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr,
ComputeDataType,
MThreadSliceSize * NThreadSliceSize,
true>
e_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr,
ComputeDataType,
MThreadSliceSize * NThreadSliceSize,
true>
gamma_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr,
ComputeDataType,
MThreadSliceSize * NThreadSliceSize,
true>
beta_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr,
ComputeDataType,
MThreadSliceSize * NThreadSliceSize,
true>
h_thread_buf;
// IO
auto threadwise_mean_load_m_nblock =
ThreadwiseTensorSliceTransfer_v2<EMeanVarDataType,
ComputeDataType,
MeanVarGridDesc_M_NBlock,
decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1,
ThreadBufferDimAccessOrder,
1,
1,
1,
true>(
mean_var_grid_desc_m_nblock,
make_multi_index(block_work_idx[I0] * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_n_cluster_id));
auto threadwise_var_load_m_nblock =
ThreadwiseTensorSliceTransfer_v2<EMeanVarDataType,
ComputeDataType,
MeanVarGridDesc_M_NBlock,
decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1,
ThreadBufferDimAccessOrder,
1,
1,
1,
true>(
mean_var_grid_desc_m_nblock,
make_multi_index(block_work_idx[I0] * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_n_cluster_id));
auto threadwise_count_load_m_nblock =
ThreadwiseTensorSliceTransfer_v2<int32_t,
int32_t,
CountGridDesc_M_NBlock,
decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1,
ThreadBufferDimAccessOrder,
1,
1,
1,
true>(
count_grid_desc_m_nblock,
make_multi_index(block_work_idx[I0] * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_n_cluster_id));
auto threadwise_e_load_m_n =
ThreadwiseTensorSliceTransfer_v2<EMeanVarDataType,
ComputeDataType,
decltype(e_grid_desc_m_n),
decltype(thread_buffer_desc_m_n),
ThreadBufferLengths_M_N,
ThreadBufferDimAccessOrder,
1, // SrcVectorDim
ESrcVectorSize,
1,
true>(
e_grid_desc_m_n,
make_multi_index(
block_work_idx[I0] * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_work_idx[I1] * N_BlockTileSize + thread_n_cluster_id * NThreadSliceSize));
auto threadwise_gamma_load_n =
ThreadwiseTensorSliceTransfer_v2<GammaDataType,
ComputeDataType,
decltype(gamma_grid_desc_n),
decltype(thread_buffer_desc_n),
ThreadBufferLengths_N,
Sequence<0>, // DimAccessOrder,
0, // SrcVectorDim,
GammaSrcVectorSize,
1,
true>(
gamma_grid_desc_n,
make_multi_index(block_work_idx[I1] * N_BlockTileSize +
thread_n_cluster_id * NThreadSliceSize));
auto threadwise_beta_load_n =
ThreadwiseTensorSliceTransfer_v2<BetaDataType,
ComputeDataType,
decltype(beta_grid_desc_n),
decltype(thread_buffer_desc_n),
ThreadBufferLengths_N,
Sequence<0>, // DimAccessOrder,
0, // SrcVectorDim,
BetaSrcVectorSize,
1,
true>(
beta_grid_desc_n,
make_multi_index(block_work_idx[I1] * N_BlockTileSize +
thread_n_cluster_id * NThreadSliceSize));
auto threadwise_h_store_m_n =
ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
HDataType,
decltype(thread_buffer_desc_m_n),
decltype(h_grid_desc_m_n),
HElementwiseOperation,
ThreadBufferLengths_M_N,
ThreadBufferDimAccessOrder,
1, // DstVectorDim
HDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
h_grid_desc_m_n,
make_multi_index(
block_work_idx[I0] * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_work_idx[I1] * N_BlockTileSize + thread_n_cluster_id * NThreadSliceSize),
h_element_op);
// Step 1: merge the per-N-tile mean, variance and count partials
constexpr auto mean_var_count_thread_copy_step_I0_n =
make_multi_index(I0, NThreadClusterSize);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
welford_mean_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
welford_var_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
welford_count_thread_buf(I) = 0;
});
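// Merge recap (standard parallel combination, for reference): two partial results
// (mean_a, M2_a, n_a) and (mean_b, M2_b, n_b) over disjoint column ranges combine as
//   n = n_a + n_b; delta = mean_b - mean_a;
//   mean = mean_a + delta * n_b / n;
//   M2   = M2_a + M2_b + delta * delta * n_a * n_b / n;
// with variance = M2 / n once everything is merged. The ThreadwiseWelford /
// BlockwiseWelford calls below combine the per-N-tile partials along these lines.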
for(index_t n = 0; n < numMeanVarCountBlockTileIteration_N; ++n)
{
threadwise_mean_load_m_nblock.Run(mean_var_grid_desc_m_nblock,
welford_mean_global_val_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
in_welford_mean_thread_buf);
threadwise_var_load_m_nblock.Run(mean_var_grid_desc_m_nblock,
welford_var_global_val_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
in_welford_var_thread_buf);
threadwise_count_load_m_nblock.Run(count_grid_desc_m_nblock,
welford_count_global_val_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
in_welford_count_thread_buf);
ThreadwiseWelford::Run(in_welford_mean_thread_buf,
in_welford_var_thread_buf,
in_welford_count_thread_buf,
welford_mean_thread_buf,
welford_var_thread_buf,
welford_count_thread_buf);
threadwise_mean_load_m_nblock.MoveSrcSliceWindow(mean_var_grid_desc_m_nblock,
mean_var_count_thread_copy_step_I0_n);
threadwise_var_load_m_nblock.MoveSrcSliceWindow(mean_var_grid_desc_m_nblock,
mean_var_count_thread_copy_step_I0_n);
threadwise_count_load_m_nblock.MoveSrcSliceWindow(count_grid_desc_m_nblock,
mean_var_count_thread_copy_step_I0_n);
}
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
BlockwiseWelford::Run(
welford_mean_thread_buf(I), welford_var_thread_buf(I), welford_count_thread_buf(I));
});
// Step 2: normalization
// h[m, n] = [(e[m, n] - mean[m]) / sqrt(var[m] + eps)] * gamma[n] + beta[n]
threadwise_e_load_m_n.Run(e_grid_desc_m_n,
e_global_val_buf,
thread_buffer_desc_m_n,
make_tuple(I0, I0),
e_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto m) {
auto divisor = 1 / ck::math::sqrt(welford_var_thread_buf(m) + epsilon);
static_for<0, NThreadSliceSize, 1>{}([&](auto n) {
constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n));
h_thread_buf(Number<m_n>{}) =
(e_thread_buf(Number<m_n>{}) - welford_mean_thread_buf(m)) * divisor;
});
});
threadwise_gamma_load_n.Run(gamma_grid_desc_n,
gamma_global_val_buf,
thread_buffer_desc_n,
make_tuple(I0),
gamma_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto m) {
static_for<0, NThreadSliceSize, 1>{}([&](auto n) {
constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n));
h_thread_buf(Number<m_n>{}) = h_thread_buf(Number<m_n>{}) * gamma_thread_buf(n);
});
});
threadwise_beta_load_n.Run(beta_grid_desc_n,
beta_global_val_buf,
thread_buffer_desc_n,
make_tuple(I0),
beta_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto m) {
static_for<0, NThreadSliceSize, 1>{}([&](auto n) {
constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n));
h_thread_buf(Number<m_n>{}) = h_thread_buf(Number<m_n>{}) + beta_thread_buf(n);
});
});
threadwise_h_store_m_n.Run(thread_buffer_desc_m_n,
make_tuple(I0, I0),
h_thread_buf,
h_grid_desc_m_n,
h_global_val_buf);
} // run
};
} // namespace ck
...
@@ -434,7 +434,7 @@ struct GridwiseElementwiseLayernormWelfordVariance_mk_to_mk
});
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-   auto divisor = 1 / __builtin_amdgcn_sqrtf(var_thread_buf(iM) + epsilon);
+   auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) {
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
constexpr auto offset_m_k =
...
...
@@ -319,7 +319,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
});
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-   auto divisor = 1 / __builtin_amdgcn_sqrtf(var_thread_buf(iM) + epsilon);
+   auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) {
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
constexpr auto offset_m_k =
...
...
@@ -90,10 +90,13 @@ struct ReferenceLayernorm : public device::BaseOperator
for(int m = 0; m < M; ++m)
{
+   AccDataType divisor =
+       static_cast<AccDataType>(1) / ck::math::sqrt(var(m) + arg.epsilon_);
for(int n = 0; n < N; ++n)
{
auto x_val = ck::type_convert<AccDataType>(arg.x_m_n_(m, n));
-   auto y_val = (x_val - mean(m)) / sqrt(var(m) + arg.epsilon_);
+   auto y_val = (x_val - mean(m)) * divisor;
y_val = (y_val * arg.gamma_n_(n)) + arg.beta_n_(n);
arg.acc_elementwise_op_(y_val, y_val);
arg.y_m_n_(m, n) = ck::type_convert<YDataType>(y_val);
...