Commit 8bf23425 authored by rocking
Browse files

calculate max count for tail block

parent cb17765e
......@@ -56,15 +56,15 @@ using BElementOp = PassThrough;
using CDEElementOp = AddReluAdd;
using HElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDLayernorm_Xdl_CShuffle
//######| ALayout| BLayout| DsLayout| HLayout| AData| BData| AccData| CShuffle| DsData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle|
//######| | | | | Type| Type| Type| DataType| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector|
//######| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _M_N|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | LayernormThreadClusterSize_M_N, LayernormThreadSliceSize_M_N
< ALayout, BLayout, DsLayout, HLayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, GammaDataType, BetaDataType, HDataType, AElementOp, BElementOp, CDEElementOp, HElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4, S<8, 32>, S<1, 8>, 1, 8, 8, 8, 8, 1>;
//######| ALayout| BLayout| DsLayout| HLayout| AData| BData| AccData| CShuffle| DsData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle| Layernorm| Layernorm| Layernorm| Layernorm| Layernorm| Layernorm| Layernorm|
//######| | | | | Type| Type| Type| DataType| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| ThreadClusterSize| ThreadSliceSize| ESrcHDst| ESrc| HDst| GammaSrc| BetaSrc|
//######| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _M_N| _M_N| _M_N| VectorDim| VectorSize| VectorSize| VectorSize| VectorSize|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, HLayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, GammaDataType, BetaDataType, HDataType, AElementOp, BElementOp, CDEElementOp, HElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4, S<8, 32>, S<1, 8>, 1, 8, 8, 8, 8>;
// clang-format on
auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
......@@ -149,11 +149,11 @@ int main()
ck::index_t N = 1024;
ck::index_t K = 1024;
ck::index_t StrideA = 1024;
ck::index_t StrideB = 1024;
ck::index_t StrideA = K;
ck::index_t StrideB = K;
ck::index_t StrideD0 = 0;
ck::index_t StrideD1 = 1024;
ck::index_t StrideH = 1024;
ck::index_t StrideD1 = N;
ck::index_t StrideH = N;
float epsilon = 1e-5;
......@@ -253,7 +253,7 @@ int main()
e_device_buf.FromDevice(e_m_n.mData.data());
h_device_buf.FromDevice(h_m_n.mData.data());
pass &= ck::utils::check_err(e_m_n, e_m_n_host);
pass &= ck::utils::check_err(e_m_n, e_m_n_host, "Error: Incorrect results e_m_n");
pass &=
ck::utils::check_err(h_m_n, h_m_n_host, "Error: Incorrect results h_m_n", 1e-2, 1e-2);
}
......
......@@ -59,7 +59,8 @@ __global__ void
e_grid_desc_mblock_mperblock_nblock_nperblock,
const MeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock
mean_var_count_grid_desc_mblock_mperblock_nblock,
const Block2ETileMap block_2_etile_map)
const Block2ETileMap block_2_etile_map,
index_t NRaw)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemmWelford::GetSharedMemoryNumberOfByte()];
......@@ -81,7 +82,8 @@ __global__ void
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
mean_var_count_grid_desc_mblock_mperblock_nblock,
block_2_etile_map);
block_2_etile_map,
NRaw);
#else
ignore = p_a_grid;
ignore = p_b_grid;
......@@ -99,6 +101,7 @@ __global__ void
ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = mean_var_count_grid_desc_mblock_mperblock_nblock;
ignore = block_2_etile_map;
ignore = NRaw;
#endif
}
......@@ -225,7 +228,6 @@ template <typename ALayout,
index_t LayernormHDstVectorSize,
index_t LayernormGammaSrcVectorSize,
index_t LayernormBetaSrcVectorSize,
index_t LayernormMeanVarSrcDstVectorSize,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
{
......@@ -329,7 +331,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
}();
return PadTensorDescriptor(
grid_desc_m_n, make_tuple(MPerBlock, NPerBlock), Sequence<true, true>{});
grid_desc_m_n, make_tuple(MPerBlock, NBlock), Sequence<true, false>{});
}
template <typename LayOut>
......@@ -487,8 +489,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
LayernormESrcVectorSize,
LayernormHDstVectorSize,
LayernormGammaSrcVectorSize,
LayernormBetaSrcVectorSize,
LayernormMeanVarSrcDstVectorSize>;
LayernormBetaSrcVectorSize>;
// Argument
struct Argument : public BaseArgument
......@@ -732,7 +733,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.mean_var_count_grid_desc_mblock_mperblock_nblock_,
arg.block_2_etile_map_);
arg.block_2_etile_map_,
arg.NRaw_);
grid_size = math::integer_divide_ceil(M, LayernormBlockTileSize_M_N::At(0));
......
......@@ -240,7 +240,6 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
Number<NumDTensor>{});
}
// TODO - MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
template <typename GridDescriptor_M_N>
__host__ __device__ static constexpr auto
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(const GridDescriptor_M_N& grid_desc_m_n)
......@@ -381,7 +380,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
e_grid_desc_mblock_mperblock_nblock_nperblock,
const MeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock&
mean_var_count_grid_desc_mblock_mperblock_nblock,
const Block2ETileMap& block_2_etile_map)
const Block2ETileMap& block_2_etile_map,
index_t NRaw)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
......@@ -879,9 +879,38 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
Array<mean_var_vgpr_type, num_shuffleM> var_thread_bufs;
Array<welford_count_vgpr_type, num_shuffleM> welford_count_thread_bufs;
int max_count = PostShuffleThreadSliceSize_N * num_shuffleN;
const auto nblock = mean_var_count_grid_desc_mblock_mperblock_nblock.GetLength(I2);
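// Tail block (last block along N): recompute max_count from the NRaw columns left over after the preceding nblock - 1 full blocks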
if(block_work_idx[I1] % nblock == nblock - 1)
{
constexpr index_t NPerShuffleBlock =
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl;
int NPerBlockTail = NRaw - NPerBlock * (nblock - 1);
int thread_max_len =
PostShuffleThreadSliceSize_N * (post_shuffle_thread_cluster_idx[I1] + 1);
int shuffle_step = 0;
while(thread_max_len <= NPerBlockTail && shuffle_step < num_shuffleN)
{
++shuffle_step;
thread_max_len += NPerShuffleBlock;
}
int delta = 0;
if(thread_max_len - NPerBlockTail > PostShuffleThreadSliceSize_N)
delta = 0;
else if(NPerBlockTail > thread_max_len)
delta = PostShuffleThreadSliceSize_N;
else
delta = PostShuffleThreadSliceSize_N - thread_max_len + NPerBlockTail;
max_count = shuffle_step * PostShuffleThreadSliceSize_N + delta;
}
static_for<0, num_shuffleM, 1>{}([&](auto i) {
// TODO - padding
threadwise_welfords(i).max_count_ = PostShuffleThreadSliceSize_N * num_shuffleN;
threadwise_welfords(i).max_count_ = max_count;
mean_thread_bufs(i) = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
thread_welford_dst_desc_m.GetElementSpaceSize());
......
......@@ -39,8 +39,7 @@ template <typename EDataType,
index_t ESrcVectorSize,
index_t HDstVectorSize,
index_t GammaSrcVectorSize,
index_t BetaSrcVectorSize,
index_t MeanVarSrcDstVectorSize>
index_t BetaSrcVectorSize>
struct GridwiseWelfordSecondHalfLayernorm2d
{
// TODO - Support ESrcHDstVectorDim == 0
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment