"...composable_kernel_rocm.git" did not exist on "421996707e28caec9ce3702e3f9b451bf5d4c969"
Commit 8bf23425 authored by rocking

calculate max count for tail block

parent cb17765e
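This commit makes the GEMM + Welford first-half kernel aware of the raw (unpadded) N extent: with the MNKPadding specialization the last N tile can contain padded columns, and each thread's Welford accumulator must only count the columns that really exist, otherwise the mean and variance fed into the layernorm are skewed. The kernel therefore caps each thread's element count (max_count_) at the number of valid columns in the tail block. The standalone sketch below shows why that cap matters; it is not the CK implementation, and Welford, thread_slice, and valid_count are illustrative names.

#include <cstdio>

// Textbook Welford update; the count is capped by folding in only the valid columns.
struct Welford
{
    float mean = 0.f;
    float m2   = 0.f; // sum of squared deviations; variance = m2 / count
    int count  = 0;

    void push(float x)
    {
        ++count;
        const float delta = x - mean;
        mean += delta / count;
        m2 += delta * (x - mean);
    }
};

int main()
{
    // 8-wide per-thread slice, but only 5 columns hold real data (tail block of a padded N).
    const float thread_slice[8] = {1, 2, 3, 4, 5, /*padding*/ 0, 0, 0};
    const int valid_count       = 5; // what the commit computes as max_count for the tail block

    Welford w;
    for(int i = 0; i < valid_count; ++i) // fold in the valid columns only
        w.push(thread_slice[i]);

    std::printf("mean=%g var=%g\n", w.mean, w.m2 / w.count);
}

Counting all 8 slice elements would give a mean of 1.875 because of the three zero padding columns; capping the count at 5 recovers the true mean 3 and variance 2.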
@@ -56,15 +56,15 @@ using BElementOp = PassThrough;
 using CDEElementOp = AddReluAdd;
 using HElementOp = PassThrough;
-static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
 // clang-format off
 using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDLayernorm_Xdl_CShuffle
-//######| ALayout| BLayout| DsLayout| HLayout| AData| BData| AccData| CShuffle| DsData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle|
-//######| | | | | Type| Type| Type| DataType| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector|
-//######| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _M_N|
-//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | LayernormThreadClusterSize_M_N, LayernormThreadSliceSize_M_N
-< ALayout, BLayout, DsLayout, HLayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, GammaDataType, BetaDataType, HDataType, AElementOp, BElementOp, CDEElementOp, HElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4, S<8, 32>, S<1, 8>, 1, 8, 8, 8, 8, 1>;
+//######| ALayout| BLayout| DsLayout| HLayout| AData| BData| AccData| CShuffle| DsData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle| Layernorm| Layernorm| Layernorm| Layernorm| Layernorm| Layernorm| Layernorm|
+//######| | | | | Type| Type| Type| DataType| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| ThreadClusterSize| ThreadSliceSize| ESrcHDst| ESrc| HDst| GammaSrc| BetaSrc|
+//######| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _M_N| _M_N| _M_N| VectorDim| VectorSize| VectorSize| VectorSize| VectorSize|
+//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+< ALayout, BLayout, DsLayout, HLayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, GammaDataType, BetaDataType, HDataType, AElementOp, BElementOp, CDEElementOp, HElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4, S<8, 32>, S<1, 8>, 1, 8, 8, 8, 8>;
 // clang-format on
 auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
@@ -149,11 +149,11 @@ int main()
    ck::index_t N = 1024;
    ck::index_t K = 1024;
-   ck::index_t StrideA = 1024;
-   ck::index_t StrideB = 1024;
+   ck::index_t StrideA = K;
+   ck::index_t StrideB = K;
    ck::index_t StrideD0 = 0;
-   ck::index_t StrideD1 = 1024;
-   ck::index_t StrideH = 1024;
+   ck::index_t StrideD1 = N;
+   ck::index_t StrideH = N;
    float epsilon = 1e-5;
@@ -253,7 +253,7 @@ int main()
    e_device_buf.FromDevice(e_m_n.mData.data());
    h_device_buf.FromDevice(h_m_n.mData.data());
-   pass &= ck::utils::check_err(e_m_n, e_m_n_host);
+   pass &= ck::utils::check_err(e_m_n, e_m_n_host, "Error: Incorrect results e_m_n");
    pass &=
        ck::utils::check_err(h_m_n, h_m_n_host, "Error: Incorrect results h_m_n", 1e-2, 1e-2);
 }
......
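A note on the example changes above: switching the instance from GemmSpecialization::Default to MNKPadding lets the same instance handle M/N/K that are not multiples of the block tile, which is also why the strides are now written in terms of K and N instead of the hard-coded 1024. The sketch below shows the tile arithmetic that produces the tail block the rest of this commit deals with; the values (NRaw = 1000, NPerBlock = 128) are assumptions for illustration, and integer_divide_ceil is written out here rather than taken from ck::math.

#include <cstdio>

// Round-up division, as used for tile counts (stand-in for ck::math::integer_divide_ceil).
int integer_divide_ceil(int x, int y) { return (x + y - 1) / y; }

int main()
{
    const int NRaw      = 1000; // raw N, not a multiple of the tile
    const int NPerBlock = 128;  // N extent of one block tile

    const int nblock        = integer_divide_ceil(NRaw, NPerBlock); // 8 tiles along N
    const int NPad          = nblock * NPerBlock;                   // 1024 padded columns
    const int NPerBlockTail = NRaw - NPerBlock * (nblock - 1);      // 104 valid columns in the last tile

    std::printf("nblock=%d NPad=%d tail_valid=%d padded=%d\n",
                nblock, NPad, NPerBlockTail, NPad - NRaw);
}

Only the last of the 8 tiles is partially filled; its 24 padded columns are exactly what the new max_count logic excludes from the Welford statistics.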
@@ -59,7 +59,8 @@ __global__ void
            e_grid_desc_mblock_mperblock_nblock_nperblock,
        const MeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock
            mean_var_count_grid_desc_mblock_mperblock_nblock,
-       const Block2ETileMap block_2_etile_map)
+       const Block2ETileMap block_2_etile_map,
+       index_t NRaw)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
    __shared__ char p_shared[GridwiseGemmWelford::GetSharedMemoryNumberOfByte()];
@@ -81,7 +82,8 @@ __global__ void
        ds_grid_desc_mblock_mperblock_nblock_nperblock,
        e_grid_desc_mblock_mperblock_nblock_nperblock,
        mean_var_count_grid_desc_mblock_mperblock_nblock,
-       block_2_etile_map);
+       block_2_etile_map,
+       NRaw);
 #else
    ignore = p_a_grid;
    ignore = p_b_grid;
@@ -99,6 +101,7 @@ __global__ void
    ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
    ignore = mean_var_count_grid_desc_mblock_mperblock_nblock;
    ignore = block_2_etile_map;
+   ignore = NRaw;
 #endif
 }
@@ -225,7 +228,6 @@ template <typename ALayout,
          index_t LayernormHDstVectorSize,
          index_t LayernormGammaSrcVectorSize,
          index_t LayernormBetaSrcVectorSize,
-         index_t LayernormMeanVarSrcDstVectorSize,
          LoopScheduler LoopSched = make_default_loop_scheduler()>
 struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
 {
@@ -329,7 +331,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
        }();
        return PadTensorDescriptor(
-           grid_desc_m_n, make_tuple(MPerBlock, NPerBlock), Sequence<true, true>{});
+           grid_desc_m_n, make_tuple(MPerBlock, NBlock), Sequence<true, false>{});
    }
    template <typename LayOut>
@@ -487,8 +489,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
        LayernormESrcVectorSize,
        LayernormHDstVectorSize,
        LayernormGammaSrcVectorSize,
-       LayernormBetaSrcVectorSize,
-       LayernormMeanVarSrcDstVectorSize>;
+       LayernormBetaSrcVectorSize>;
    // Argument
    struct Argument : public BaseArgument
@@ -732,7 +733,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
            arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
            arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
            arg.mean_var_count_grid_desc_mblock_mperblock_nblock_,
-           arg.block_2_etile_map_);
+           arg.block_2_etile_map_,
+           arg.NRaw_);
        grid_size = math::integer_divide_ceil(M, LayernormBlockTileSize_M_N::At(0));
......
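The device-op changes above are mostly plumbing: the kernel signature gains an index_t NRaw parameter, the host-side Argument's NRaw_ is forwarded at launch, and the branch for unsupported architectures marks it ignored. A minimal sketch of that pattern follows; the kernel, argument, and launcher names here are made up for illustration and are not CK's API.

#include <hip/hip_runtime.h>

// Hypothetical kernel (not CK's): it now receives the raw (unpadded) N
// alongside its other arguments.
__global__ void kernel_gemm_welford_sketch(const float* p_a, const float* p_b, float* p_e, int NRaw)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
    // Real code would clamp each thread's Welford count using NRaw here.
    (void)p_a; (void)p_b; (void)p_e; (void)NRaw;
#else
    // Mirrors the `ignore = NRaw;` added for unsupported targets.
    (void)p_a; (void)p_b; (void)p_e; (void)NRaw;
#endif
}

// Host-side argument carries the raw extent next to everything else the kernel needs.
struct ArgumentSketch
{
    const float* p_a_grid;
    const float* p_b_grid;
    float* p_e_grid;
    int NRaw_;
};

void RunSketch(const ArgumentSketch& arg, hipStream_t stream)
{
    hipLaunchKernelGGL(kernel_gemm_welford_sketch, dim3(1), dim3(256), 0, stream,
                       arg.p_a_grid, arg.p_b_grid, arg.p_e_grid, arg.NRaw_);
}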
@@ -240,7 +240,6 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
            Number<NumDTensor>{});
    }
-   // TODO - MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
    template <typename GridDescriptor_M_N>
    __host__ __device__ static constexpr auto
    MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(const GridDescriptor_M_N& grid_desc_m_n)
@@ -381,7 +380,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
            e_grid_desc_mblock_mperblock_nblock_nperblock,
        const MeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock&
            mean_var_count_grid_desc_mblock_mperblock_nblock,
-       const Block2ETileMap& block_2_etile_map)
+       const Block2ETileMap& block_2_etile_map,
+       index_t NRaw)
    {
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
@@ -879,9 +879,38 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
        Array<mean_var_vgpr_type, num_shuffleM> var_thread_bufs;
        Array<welford_count_vgpr_type, num_shuffleM> welford_count_thread_bufs;
+       int max_count = PostShuffleThreadSliceSize_N * num_shuffleN;
+       const auto nblock = mean_var_count_grid_desc_mblock_mperblock_nblock.GetLength(I2);
+       // tail block
+       if(block_work_idx[I1] % nblock == nblock - 1)
+       {
+           constexpr index_t NPerShuffleBlock =
+               CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl;
+           int NPerBlockTail = NRaw - NPerBlock * (nblock - 1);
+           int thread_max_len =
+               PostShuffleThreadSliceSize_N * (post_shuffle_thread_cluster_idx[I1] + 1);
+           int shuffle_step = 0;
+           while(thread_max_len <= NPerBlockTail && shuffle_step < num_shuffleN)
+           {
+               ++shuffle_step;
+               thread_max_len += NPerShuffleBlock;
+           }
+           int delta = 0;
+           if(thread_max_len - NPerBlockTail > PostShuffleThreadSliceSize_N)
+               delta = 0;
+           else if(NPerBlockTail > thread_max_len)
+               delta = PostShuffleThreadSliceSize_N;
+           else
+               delta = PostShuffleThreadSliceSize_N - thread_max_len + NPerBlockTail;
+           max_count = shuffle_step * PostShuffleThreadSliceSize_N + delta;
+       }
        static_for<0, num_shuffleM, 1>{}([&](auto i) {
-           // TODO - padding
-           threadwise_welfords(i).max_count_ = PostShuffleThreadSliceSize_N * num_shuffleN;
+           threadwise_welfords(i).max_count_ = max_count;
            mean_thread_bufs(i) = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
                thread_welford_dst_desc_m.GetElementSpaceSize());
......
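To make the tail-block max_count logic above concrete, here is a host-side walk-through with assumed tile parameters (PostShuffleThreadSliceSize_N = 8, num_shuffleN = 4, NPerShuffleBlock = 32, NPerBlock = 128, NRaw = 1000, nblock = 8; chosen to be mutually consistent, not taken from the example instance). It reruns the shuffle_step/delta computation for each thread column of the post-shuffle cluster and prints the resulting max_count.

#include <cstdio>

int main()
{
    // Assumed, mutually consistent tile configuration for the sketch.
    const int PostShuffleThreadSliceSize_N = 8;   // elements per thread per shuffle along N
    const int num_shuffleN                 = 4;   // shuffle iterations along N
    const int NPerShuffleBlock             = 32;  // N covered by one shuffle iteration
    const int NPerBlock                    = 128; // num_shuffleN * NPerShuffleBlock
    const int NRaw                         = 1000;
    const int nblock                       = (NRaw + NPerBlock - 1) / NPerBlock; // 8

    const int NPerBlockTail = NRaw - NPerBlock * (nblock - 1); // 104 valid columns in the tail block

    const int cluster_size_n = NPerShuffleBlock / PostShuffleThreadSliceSize_N; // 4 thread columns
    for(int cluster_idx_n = 0; cluster_idx_n < cluster_size_n; ++cluster_idx_n)
    {
        // Same control flow as the diff: count the shuffle iterations whose slice ends
        // inside the valid tail, then add the leftover (possibly partial) slice as delta.
        int thread_max_len = PostShuffleThreadSliceSize_N * (cluster_idx_n + 1);
        int shuffle_step   = 0;
        while(thread_max_len <= NPerBlockTail && shuffle_step < num_shuffleN)
        {
            ++shuffle_step;
            thread_max_len += NPerShuffleBlock;
        }

        int delta = 0;
        if(thread_max_len - NPerBlockTail > PostShuffleThreadSliceSize_N)
            delta = 0;
        else if(NPerBlockTail > thread_max_len)
            delta = PostShuffleThreadSliceSize_N;
        else
            delta = PostShuffleThreadSliceSize_N - thread_max_len + NPerBlockTail;

        const int max_count = shuffle_step * PostShuffleThreadSliceSize_N + delta;
        std::printf("thread column %d: max_count = %d\n", cluster_idx_n, max_count);
    }
}

With these numbers the four thread columns get max_count = 32, 24, 24, 24, which sums to the 104 valid tail columns, so every real element is counted exactly once and no padded column contributes to the mean and variance.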
@@ -39,8 +39,7 @@ template <typename EDataType,
          index_t ESrcVectorSize,
          index_t HDstVectorSize,
          index_t GammaSrcVectorSize,
-         index_t BetaSrcVectorSize,
-         index_t MeanVarSrcDstVectorSize>
+         index_t BetaSrcVectorSize>
 struct GridwiseWelfordSecondHalfLayernorm2d
 {
    // TODO - Support ESrcHDstVectorDim == 0
......