Commit 0ed92c0c authored by ltqin's avatar ltqin
Browse files

split save z matrix for spill register issue

parent a634c224
...@@ -31,6 +31,7 @@ Kernel outputs: ...@@ -31,6 +31,7 @@ Kernel outputs:
#include <numeric> #include <numeric>
#include <initializer_list> #include <initializer_list>
#include <cstdlib> #include <cstdlib>
#include <fstream>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
......
...@@ -17,7 +17,7 @@ struct BlockwiseDropout ...@@ -17,7 +17,7 @@ struct BlockwiseDropout
static constexpr index_t KRepeat = ThreadSliceDesc_M_K{}.GetLength(I1); static constexpr index_t KRepeat = ThreadSliceDesc_M_K{}.GetLength(I1);
template <typename CThreadBuffer, bool using_sign_bit = false> template <typename CThreadBuffer, bool using_sign_bit = false>
__host__ __device__ void ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox ph) __host__ __device__ void ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox& ph)
{ {
auto execute_dropout = [&](bool keep, DataType val) { auto execute_dropout = [&](bool keep, DataType val) {
...@@ -52,7 +52,7 @@ struct BlockwiseDropout ...@@ -52,7 +52,7 @@ struct BlockwiseDropout
template <typename CThreadBuffer, typename ZThreadBuffer, bool using_sign_bit = false> template <typename CThreadBuffer, typename ZThreadBuffer, bool using_sign_bit = false>
__host__ __device__ void __host__ __device__ void
ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox ph, ZThreadBuffer& z_thread_buf) ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox& ph, ZThreadBuffer& z_thread_buf)
{ {
auto execute_dropout = [&](bool keep, DataType val) { auto execute_dropout = [&](bool keep, DataType val) {
...@@ -86,6 +86,42 @@ struct BlockwiseDropout ...@@ -86,6 +86,42 @@ struct BlockwiseDropout
}); });
} }
template <typename CThreadBuffer,
typename ZThreadBuffer,
bool using_sign_bit,
typename N0,
typename Offset>
__host__ __device__ void
ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox& ph, ZThreadBuffer& z_thread_buf)
{
auto execute_dropout = [&](bool keep, DataType val) {
if constexpr(using_sign_bit)
return keep ? val : -val;
else
return keep ? val * p_dropout_rescale : float(0);
};
constexpr int tmp_size = MRepeat * KRepeat / N0{}.value;
int philox_calls = tmp_size / 8;
ushort tmp[tmp_size];
for(int i = 0; i < philox_calls; i++)
{
ph.get_random_8x16((tmp + i * 8));
}
block_sync_lds();
constexpr auto iOffset = Number<tmp_size>{} * Offset{};
static_for<0, tmp_size, 1>{}([&](auto i) {
in_thread_buf(i + iOffset) =
execute_dropout(tmp[i.value] <= p_dropout_16bits, in_thread_buf(i + iOffset));
z_thread_buf(i) = tmp[i.value];
});
}
ushort p_dropout_16bits; ushort p_dropout_16bits;
DataType p_dropout_rescale; DataType p_dropout_rescale;
}; };
......
...@@ -1455,7 +1455,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -1455,7 +1455,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
I1, // NBlockID I1, // NBlockID
m0, // MRepeat m0, // MRepeat
n0, // NRepeat I1, // NRepeat
m1, // MWaveId m1, // MWaveId
n1, // NWaveId n1, // NWaveId
m2, // MPerXdl m2, // MPerXdl
...@@ -1491,7 +1491,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -1491,7 +1491,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
Sequence<I1, // MBlockId Sequence<I1, // MBlockId
I1, // NBlockID I1, // NBlockID
m0, // MRepeat m0, // MRepeat
n0, // NRepeat I1, // NRepeat
m1, // MWaveId m1, // MWaveId
n1, // NWaveId n1, // NWaveId
m2, // MPerXdl m2, // MPerXdl
...@@ -1856,19 +1856,31 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -1856,19 +1856,31 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
if(p_z_grid) if(p_z_grid)
{ {
// P_dropped // P_dropped
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), static_for<0, n0, 1>{}([&](auto i) {
decltype(z_tenor_buffer), blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf),
true>( decltype(z_tenor_buffer),
s_slash_p_thread_buf, ph, z_tenor_buffer); true,
decltype(n0),
z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, decltype(i)>(
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), s_slash_p_thread_buf, ph, z_tenor_buffer);
z_tenor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_thread_copy_vgpr_to_global.Run(
z_grid_buf); z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_buf);
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 0, 0, 1, 0, 0, 0, 0, 0, 0));
});
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 0, 0, -n0.value, 0, 0, 0, 0, 0, 0));
} }
else else
{ {
ignore = z_grid_buf;
// P_dropped // P_dropped
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>( blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>(
s_slash_p_thread_buf, ph); s_slash_p_thread_buf, ph);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment