Commit 0ed92c0c authored by ltqin's avatar ltqin
Browse files

split save z matrix for spill register issue

parent a634c224
......@@ -31,6 +31,7 @@ Kernel outputs:
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <fstream>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
......
......@@ -17,7 +17,7 @@ struct BlockwiseDropout
static constexpr index_t KRepeat = ThreadSliceDesc_M_K{}.GetLength(I1);
template <typename CThreadBuffer, bool using_sign_bit = false>
__host__ __device__ void ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox ph)
__host__ __device__ void ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox& ph)
{
auto execute_dropout = [&](bool keep, DataType val) {
......@@ -52,7 +52,7 @@ struct BlockwiseDropout
template <typename CThreadBuffer, typename ZThreadBuffer, bool using_sign_bit = false>
__host__ __device__ void
ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox ph, ZThreadBuffer& z_thread_buf)
ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox& ph, ZThreadBuffer& z_thread_buf)
{
auto execute_dropout = [&](bool keep, DataType val) {
......@@ -86,6 +86,42 @@ struct BlockwiseDropout
});
}
template <typename CThreadBuffer,
typename ZThreadBuffer,
bool using_sign_bit,
typename N0,
typename Offset>
__host__ __device__ void
ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox& ph, ZThreadBuffer& z_thread_buf)
{
auto execute_dropout = [&](bool keep, DataType val) {
if constexpr(using_sign_bit)
return keep ? val : -val;
else
return keep ? val * p_dropout_rescale : float(0);
};
constexpr int tmp_size = MRepeat * KRepeat / N0{}.value;
int philox_calls = tmp_size / 8;
ushort tmp[tmp_size];
for(int i = 0; i < philox_calls; i++)
{
ph.get_random_8x16((tmp + i * 8));
}
block_sync_lds();
constexpr auto iOffset = Number<tmp_size>{} * Offset{};
static_for<0, tmp_size, 1>{}([&](auto i) {
in_thread_buf(i + iOffset) =
execute_dropout(tmp[i.value] <= p_dropout_16bits, in_thread_buf(i + iOffset));
z_thread_buf(i) = tmp[i.value];
});
}
ushort p_dropout_16bits;
DataType p_dropout_rescale;
};
......
......@@ -1455,7 +1455,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
n0, // NRepeat
I1, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
......@@ -1491,7 +1491,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
Sequence<I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
n0, // NRepeat
I1, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
......@@ -1856,19 +1856,31 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
if(p_z_grid)
{
// P_dropped
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf),
decltype(z_tenor_buffer),
true>(
s_slash_p_thread_buf, ph, z_tenor_buffer);
z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_buf);
static_for<0, n0, 1>{}([&](auto i) {
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf),
decltype(z_tenor_buffer),
true,
decltype(n0),
decltype(i)>(
s_slash_p_thread_buf, ph, z_tenor_buffer);
z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_buf);
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 0, 0, 1, 0, 0, 0, 0, 0, 0));
});
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 0, 0, -n0.value, 0, 0, 0, 0, 0, 0));
}
else
{
ignore = z_grid_buf;
// P_dropped
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>(
s_slash_p_thread_buf, ph);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment