Commit 2723b268 authored by guangzlu

Fixed bugs and standardized the code.

parent e38d2a5d
@@ -819,6 +819,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
             c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_ =
                 GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(z_grid_desc_m_n_);
             // Print();
+            m_raw_padded_ = GridwiseGemm::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[0]);
+            n_raw_padded_ = GridwiseGemm::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[1]);
         }

         void Print() const
@@ -906,6 +909,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
         float p_drop_;
         unsigned long long seed_;
         unsigned long long offset_;
+        index_t m_raw_padded_;
+        index_t n_raw_padded_;
     };

     // Invoker
@@ -988,8 +994,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
                     arg.p_drop_,
                     arg.seed_,
                     arg.offset_,
-                    arg.raw_lengths_mz_nz_kz_gemm1nz_[0],
-                    arg.raw_lengths_mz_nz_kz_gemm1nz_[1]);
+                    arg.m_raw_padded_,
+                    arg.n_raw_padded_);
             };

             ave_time = launch_kernel(integral_constant<bool, false>{});
...
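Worth spelling out what this trio of hunks (and the matching V2 and forward hunks below) changes: the Argument constructor now precomputes MFMA-padded M/N once on the host via GridwiseGemm::GetPaddedSize, and the invoker forwards those padded values to the kernel in place of the raw problem lengths. A minimal sketch of the effect, assuming a hypothetical group size of 4 (the value that was previously hardcoded in the gridwise kernels); get_padded_size here is a stand-in, not the CK function:

```cpp
// Minimal sketch, not CK code: what the kernel now receives for a non-aligned shape.
#include <cassert>

using index_t = int;

// Stand-in for GridwiseGemm::GetPaddedSize with an assumed group_size of 4.
index_t get_padded_size(index_t size, index_t group_size = 4)
{
    // Round size up to the next multiple of group_size, in pure integer math.
    return (size + group_size - 1) / group_size * group_size;
}

int main()
{
    index_t raw_lengths_mz_nz[2] = {30, 70};                      // example raw M/N
    index_t m_raw_padded = get_padded_size(raw_lengths_mz_nz[0]); // 32, not 30
    index_t n_raw_padded = get_padded_size(raw_lengths_mz_nz[1]); // 72, not 70
    assert(m_raw_padded == 32 && n_raw_padded == 72);
}
```

Computing the padded sizes once in the constructor keeps the per-launch lambda free of repeated arithmetic and guarantees the forward and backward kernels see the same padded extents.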
@@ -832,6 +832,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_ =
                 GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(z_grid_desc_m_n_);
             // Print();
+            m_raw_padded_ = GridwiseGemm::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[0]);
+            n_raw_padded_ = GridwiseGemm::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[1]);
         }

         void Print() const
@@ -919,6 +922,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         float p_drop_;
         unsigned long long seed_;
        unsigned long long offset_;
+        index_t m_raw_padded_;
+        index_t n_raw_padded_;
     };

     // Invoker
@@ -1005,8 +1011,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
                     arg.p_drop_,
                     arg.seed_,
                     arg.offset_,
-                    arg.raw_lengths_mz_nz_kz_gemm1nz_[0],
-                    arg.raw_lengths_mz_nz_kz_gemm1nz_[1]);
+                    arg.m_raw_padded_,
+                    arg.n_raw_padded_);
             };

         // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
...
@@ -648,6 +648,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
             z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ =
                 GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(z_grid_desc_m_n_);
+            m_raw_padded_ = GridwiseGemm::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[0]);
+            n_raw_padded_ = GridwiseGemm::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[1]);

             if(p_lse_grid == nullptr)
             {
                 is_lse_storing_ = false;
@@ -728,6 +731,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
         bool is_dropout_;
         bool is_lse_storing_ = true;
+        index_t m_raw_padded_;
+        index_t n_raw_padded_;
     };

     // Invoker
@@ -813,8 +819,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
                     arg.p_dropout_rescale_,
                     arg.seed_,
                     arg.offset_,
-                    arg.raw_lengths_mz_nz_kz_gemm1nz_[0],
-                    arg.raw_lengths_mz_nz_kz_gemm1nz_[1]);
+                    arg.m_raw_padded_,
+                    arg.n_raw_padded_);
             };

         // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
...
@@ -133,6 +133,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
             make_tuple(Sequence<0, 2, 4, 6, 7, 8>{}, Sequence<1, 3, 5, 9>{}));
     }

+    __host__ __device__ static constexpr auto GetPaddedSize(const index_t size)
+    {
+        constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
+        constexpr auto M5   = mfma.group_size;
+        return index_t(ceil(float(size) / M5) * M5);
+    }
+
     __device__ static auto GetGemm0WaveIdx()
     {
         const index_t thread_id = get_thread_local_1d_id();
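GetPaddedSize rounds a raw length up to the next multiple of the selected MFMA instruction's group_size, via float ceil. The same helper is added to the V2 backward and the forward gridwise kernels below (the forward variant selects the MFMA by FloatGemm and names the constant N5). A hedged sanity check that a pure-integer round-up agrees with the committed formula, which it does for any realistic sequence length, i.e. anything far below float's 2^24 exact-integer limit; group_size = 4 is an assumed example value, the real one comes from MfmaSelector:

```cpp
// Sketch comparing the committed float-ceil round-up with an integer equivalent.
// group_size = 4 is an assumption; the real value comes from MfmaSelector.
#include <cassert>
#include <cmath>

using index_t = int;

int main()
{
    constexpr index_t group_size = 4;
    for(index_t size = 1; size <= 4096; ++size)
    {
        index_t by_ceil = index_t(std::ceil(float(size) / group_size) * group_size);
        index_t by_int  = (size + group_size - 1) / group_size * group_size;
        assert(by_ceil == by_int); // the two round-up forms coincide
    }
}
```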
@@ -1956,12 +1963,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
                         MRaw * NRaw * g_idx + m_global * NRaw + n_global; // unique element global 1d id
                     auto global_elem_id =
-                        (global_elem_id_raw % 4) * MRaw + int(global_elem_id_raw / 4) * 4;
+                        (global_elem_id_raw % M4) * NRaw + int(global_elem_id_raw / M4) * M4;

                     blockwise_dropout.template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf),
                                                                         decltype(z_tenor_buffer),
                                                                         true>(
-                        s_slash_p_thread_buf, ph, global_elem_id, z_tenor_buffer, MRaw);
+                        s_slash_p_thread_buf, ph, global_elem_id, z_tenor_buffer, NRaw);

                     z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
                                                      make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
@@ -1983,12 +1990,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
                         MRaw * NRaw * g_idx + m_global * NRaw + n_global; // unique element global 1d id
                     auto global_elem_id =
-                        (global_elem_id_raw % 4) * MRaw + int(global_elem_id_raw / 4) * 4;
+                        (global_elem_id_raw % M4) * NRaw + int(global_elem_id_raw / M4) * M4;

                     // P_dropped
                     blockwise_dropout
                         .template ApplyDropoutAttnBwd<decltype(s_slash_p_thread_buf), true>(
-                            s_slash_p_thread_buf, ph, global_elem_id, MRaw);
+                            s_slash_p_thread_buf, ph, global_elem_id, NRaw);
                 }

                 block_sync_lds(); // wait for gemm1 LDS read
...
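The indexing fix in the two hunks above (mirrored in V2 below): global_elem_id_raw is laid out row-major with stride NRaw (m_global * NRaw + n_global), but the old remap strode by MRaw and hardcoded the group size as 4, which goes wrong whenever MRaw != NRaw or the MFMA group size differs; the dropout calls likewise now receive NRaw as the stride. A small sketch of what the corrected remap does, with hypothetical M4 = 4 and NRaw = 8:

```cpp
// Sketch of the corrected element-id remap, with assumed M4 = 4 and NRaw = 8.
// Old form: (raw % 4) * MRaw + (raw / 4) * 4  -- wrong stride when MRaw != NRaw.
#include <cstdio>

int main()
{
    const int M4 = 4, NRaw = 8; // assumed example values
    for(int raw = 0; raw < 8; ++raw)
    {
        int remapped = (raw % M4) * NRaw + (raw / M4) * M4;
        std::printf("raw %d -> %2d (row %d, col %d)\n",
                    raw, remapped, remapped / NRaw, remapped % NRaw);
    }
}
```

Each M4-aligned block of consecutive raw ids lands at the same M4-aligned column offset in M4 consecutive rows of the padded N-stride grid, which appears to be how the dropout PRNG offsets are kept consistent with the Z-tensor layout.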
@@ -147,6 +147,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             make_tuple(Sequence<0, 2, 4, 6, 7, 8>{}, Sequence<1, 3, 5, 9>{}));
     }

+    __host__ __device__ static constexpr auto GetPaddedSize(const index_t size)
+    {
+        constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
+        constexpr auto M5   = mfma.group_size;
+        return index_t(ceil(float(size) / M5) * M5);
+    }
+
     __device__ static auto GetGemm0WaveIdx()
     {
         const index_t thread_id = get_thread_local_1d_id();
@@ -1872,12 +1879,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
                         MRaw * NRaw * g_idx + m_global * NRaw + n_global; // unique element global 1d id
                     auto global_elem_id =
-                        (global_elem_id_raw % 4) * MRaw + int(global_elem_id_raw / 4) * 4;
+                        (global_elem_id_raw % M4) * NRaw + int(global_elem_id_raw / M4) * M4;

                     blockwise_dropout.template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf),
                                                                         decltype(z_tenor_buffer),
                                                                         true>(
-                        s_slash_p_thread_buf, ph, global_elem_id, z_tenor_buffer, MRaw);
+                        s_slash_p_thread_buf, ph, global_elem_id, z_tenor_buffer, NRaw);

                     z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
                                                      make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
@@ -1899,11 +1906,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
                         MRaw * NRaw * g_idx + m_global * NRaw + n_global; // unique element global 1d id
                     auto global_elem_id =
-                        (global_elem_id_raw % 4) * MRaw + int(global_elem_id_raw / 4) * 4;
+                        (global_elem_id_raw % M4) * NRaw + int(global_elem_id_raw / M4) * M4;

                     // P_dropped
                     blockwise_dropout
                         .template ApplyDropoutAttnBwd<decltype(s_slash_p_thread_buf), true>(
-                            s_slash_p_thread_buf, ph, global_elem_id, MRaw);
+                            s_slash_p_thread_buf, ph, global_elem_id, NRaw);
                 }

                 block_sync_lds(); // wait for gemm1 LDS read
...
@@ -143,6 +143,13 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
             make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{}));
     }

+    __host__ __device__ static constexpr auto GetPaddedSize(const index_t size)
+    {
+        constexpr auto mfma = MfmaSelector<FloatGemm, MPerXdl, NPerXdl>::selected_mfma;
+        constexpr auto N5   = mfma.group_size;
+        return index_t(ceil(float(size) / N5) * N5);
+    }
+
     __device__ static auto GetGemm0WaveIdx()
     {
         const index_t thread_id = get_thread_local_1d_id();
...