"lightx2v/git@developer.sourcefind.cn:Wenxuan/LightX2V.git" did not exist on "7fde70631ce0b7f67cb2476b44934cde93a2944d"
Commit 0999070d authored by danyao12's avatar danyao12
Browse files

Merge branch 'attn-train-develop-qloop' into mha-train-develop

parents 5ba30232 68e3bb6d
...@@ -1047,7 +1047,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1 ...@@ -1047,7 +1047,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
#if 0 #if DEBUG_LOG
arg.Print(); arg.Print();
#endif #endif
......
...@@ -1048,7 +1048,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2 ...@@ -1048,7 +1048,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
#if 0 #if DEBUG_LOG
arg.Print(); arg.Print();
#endif #endif
......
...@@ -1041,7 +1041,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Phased_Xdl_CShuffle_V1 ...@@ -1041,7 +1041,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Phased_Xdl_CShuffle_V1
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
#if 0 #if DEBUG_LOG
arg.Print(); arg.Print();
#endif #endif
......
...@@ -48,6 +48,7 @@ template <typename GridwiseGemm, ...@@ -48,6 +48,7 @@ template <typename GridwiseGemm,
typename ComputeBasePtrOfStridedBatch, typename ComputeBasePtrOfStridedBatch,
typename C0MatrixMask, typename C0MatrixMask,
bool HasMainKBlockLoop, bool HasMainKBlockLoop,
bool IsDropout,
bool Deterministic> bool Deterministic>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
...@@ -120,7 +121,7 @@ __global__ void ...@@ -120,7 +121,7 @@ __global__ void
{ {
for(index_t i = 0; i < nblock; i++) for(index_t i = 0; i < nblock; i++)
{ {
GridwiseGemm::template Run<HasMainKBlockLoop>( GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
p_a_grid + a_batch_offset, p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset, p_b_grid + b_batch_offset,
z_matrix_ptr, z_matrix_ptr,
...@@ -155,7 +156,7 @@ __global__ void ...@@ -155,7 +156,7 @@ __global__ void
} }
else else
{ {
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset, GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset, p_b_grid + b_batch_offset,
z_matrix_ptr, z_matrix_ptr,
p_b1_grid + b1_batch_offset, p_b1_grid + b1_batch_offset,
...@@ -933,7 +934,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -933,7 +934,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
float ave_time = 0; float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) { auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
const auto kernel = const auto kernel =
kernel_batched_multihead_attention_backward_qloop_xdl_cshuffle_v1< kernel_batched_multihead_attention_backward_qloop_xdl_cshuffle_v1<
GridwiseGemm, GridwiseGemm,
...@@ -957,6 +958,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -957,6 +958,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
ComputeBasePtrOfStridedBatch, ComputeBasePtrOfStridedBatch,
C0MatrixMask, C0MatrixMask,
has_main_k_block_loop_, has_main_k_block_loop_,
is_dropout_,
Deterministic>; Deterministic>;
return launch_and_time_kernel( return launch_and_time_kernel(
...@@ -998,9 +1000,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -998,9 +1000,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
arg.m_raw_padded_, arg.m_raw_padded_,
arg.n_raw_padded_); arg.n_raw_padded_);
}; };
if(arg.p_drop_ > 0.0){
ave_time = launch_kernel(integral_constant<bool, false>{}); ave_time = launch_kernel(integral_constant<bool, false>{}, integral_constant<bool, true>{});
}else{
ave_time = launch_kernel(integral_constant<bool, false>{}, integral_constant<bool, false>{});
}
return ave_time; return ave_time;
} }
...@@ -1020,6 +1024,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1020,6 +1024,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
#if DEBUG_LOG
arg.Print();
#endif
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
......
...@@ -47,6 +47,7 @@ template <typename GridwiseGemm, ...@@ -47,6 +47,7 @@ template <typename GridwiseGemm,
typename ComputeBasePtrOfStridedBatch, typename ComputeBasePtrOfStridedBatch,
typename C0MatrixMask, typename C0MatrixMask,
bool HasMainKBlockLoop, bool HasMainKBlockLoop,
bool IsDropout,
bool Deterministic> bool Deterministic>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
...@@ -119,7 +120,7 @@ __global__ void ...@@ -119,7 +120,7 @@ __global__ void
{ {
for(index_t i = 0; i < nblock; i++) for(index_t i = 0; i < nblock; i++)
{ {
GridwiseGemm::template Run<HasMainKBlockLoop>( GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
p_a_grid + a_batch_offset, p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset, p_b_grid + b_batch_offset,
z_matrix_ptr, z_matrix_ptr,
...@@ -154,7 +155,7 @@ __global__ void ...@@ -154,7 +155,7 @@ __global__ void
} }
else else
{ {
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset, GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset, p_b_grid + b_batch_offset,
z_matrix_ptr, z_matrix_ptr,
p_b1_grid + b1_batch_offset, p_b1_grid + b1_batch_offset,
...@@ -950,7 +951,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -950,7 +951,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
float ave_time = 0; float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) { auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
const auto kernel = const auto kernel =
kernel_batched_multihead_attention_backward_qloop_xdl_cshuffle_v2< kernel_batched_multihead_attention_backward_qloop_xdl_cshuffle_v2<
GridwiseGemm, GridwiseGemm,
...@@ -974,6 +975,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -974,6 +975,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
ComputeBasePtrOfStridedBatch, ComputeBasePtrOfStridedBatch,
C0MatrixMask, C0MatrixMask,
has_main_k_block_loop_, has_main_k_block_loop_,
is_dropout_,
Deterministic>; Deterministic>;
return launch_and_time_kernel( return launch_and_time_kernel(
...@@ -1021,11 +1023,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1021,11 +1023,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{ {
ave_time = launch_kernel(integral_constant<bool, true>{}); if(arg.p_drop_ > 0.0)
ave_time = launch_kernel(integral_constant<bool, true>{}, integral_constant<bool, true>{});
else
ave_time = launch_kernel(integral_constant<bool, true>{}, integral_constant<bool, false>{});
} }
else else
{ {
ave_time = launch_kernel(integral_constant<bool, false>{}); if(arg.p_drop_ > 0.0)
ave_time = launch_kernel(integral_constant<bool, false>{}, integral_constant<bool, true>{});
else
ave_time = launch_kernel(integral_constant<bool, false>{}, integral_constant<bool, false>{});
} }
return ave_time; return ave_time;
...@@ -1047,6 +1055,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1047,6 +1055,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
#if DEBUG_LOG
arg.Print();
#endif
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
......
...@@ -35,6 +35,7 @@ template <typename GridwiseGemm, ...@@ -35,6 +35,7 @@ template <typename GridwiseGemm,
typename B1ElementwiseOperation, typename B1ElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
bool HasMainKBlockLoop, bool HasMainKBlockLoop,
bool IsDropout,
bool Deterministic> bool Deterministic>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
...@@ -105,7 +106,7 @@ __global__ void ...@@ -105,7 +106,7 @@ __global__ void
{ {
for(index_t i = 0; i < num_blocks_per_batch; i++) for(index_t i = 0; i < num_blocks_per_batch; i++)
{ {
GridwiseGemm::template Run<HasMainKBlockLoop>( GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset, arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset, arg_ptr[group_id].p_b_grid_ + b_batch_offset,
z_matrix_ptr, z_matrix_ptr,
...@@ -141,7 +142,7 @@ __global__ void ...@@ -141,7 +142,7 @@ __global__ void
} }
else else
{ {
GridwiseGemm::template Run<HasMainKBlockLoop>( GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset, arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset, arg_ptr[group_id].p_b_grid_ + b_batch_offset,
z_matrix_ptr, z_matrix_ptr,
...@@ -961,7 +962,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -961,7 +962,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
float ave_time = 0; float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) { auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
const auto kernel = const auto kernel =
kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_v1< kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_v1<
GridwiseGemm, GridwiseGemm,
...@@ -972,6 +973,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -972,6 +973,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
B1ElementwiseOperation, B1ElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
has_main_k_block_loop_, has_main_k_block_loop_,
is_dropout_,
Deterministic>; Deterministic>;
return launch_and_time_kernel( return launch_and_time_kernel(
...@@ -996,11 +998,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -996,11 +998,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// to concern Gemm0's loop // to concern Gemm0's loop
if(all_has_main_k_block_loop) if(all_has_main_k_block_loop)
{ {
ave_time = launch_kernel(integral_constant<bool, true>{}); if(arg.p_dropout_ > 0.0)
ave_time = launch_kernel(integral_constant<bool, true>{}, integral_constant<bool, true>{});
else
ave_time = launch_kernel(integral_constant<bool, true>{}, integral_constant<bool, false>{});
} }
else if(!some_has_main_k_block_loop) else if(!some_has_main_k_block_loop)
{ {
ave_time = launch_kernel(integral_constant<bool, false>{}); if(arg.p_dropout_ > 0.0)
ave_time = launch_kernel(integral_constant<bool, false>{}, integral_constant<bool, true>{});
else
ave_time = launch_kernel(integral_constant<bool, false>{}, integral_constant<bool, false>{});
} }
else else
{ {
......
...@@ -35,6 +35,7 @@ template <typename GridwiseGemm, ...@@ -35,6 +35,7 @@ template <typename GridwiseGemm,
typename B1ElementwiseOperation, typename B1ElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
bool HasMainKBlockLoop, bool HasMainKBlockLoop,
bool IsDropout,
bool Deterministic> bool Deterministic>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
...@@ -105,7 +106,7 @@ __global__ void ...@@ -105,7 +106,7 @@ __global__ void
{ {
for(index_t i = 0; i < num_blocks_per_batch; i++) for(index_t i = 0; i < num_blocks_per_batch; i++)
{ {
GridwiseGemm::template Run<HasMainKBlockLoop>( GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset, arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset, arg_ptr[group_id].p_b_grid_ + b_batch_offset,
z_matrix_ptr, z_matrix_ptr,
...@@ -141,7 +142,7 @@ __global__ void ...@@ -141,7 +142,7 @@ __global__ void
} }
else else
{ {
GridwiseGemm::template Run<HasMainKBlockLoop>( GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset, arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset, arg_ptr[group_id].p_b_grid_ + b_batch_offset,
z_matrix_ptr, z_matrix_ptr,
...@@ -968,7 +969,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -968,7 +969,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
float ave_time = 0; float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) { auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
const auto kernel = const auto kernel =
kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_v2< kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_v2<
GridwiseGemm, GridwiseGemm,
...@@ -979,6 +980,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -979,6 +980,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
B1ElementwiseOperation, B1ElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
has_main_k_block_loop_, has_main_k_block_loop_,
is_dropout_,
Deterministic>; Deterministic>;
return launch_and_time_kernel( return launch_and_time_kernel(
...@@ -1003,11 +1005,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1003,11 +1005,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// to concern Gemm0's loop // to concern Gemm0's loop
if(all_has_main_k_block_loop) if(all_has_main_k_block_loop)
{ {
ave_time = launch_kernel(integral_constant<bool, true>{}); if(arg.p_dropout_ > 0.0)
ave_time = launch_kernel(integral_constant<bool, true>{}, integral_constant<bool, true>{});
else
ave_time = launch_kernel(integral_constant<bool, true>{}, integral_constant<bool, false>{});
} }
else if(!some_has_main_k_block_loop) else if(!some_has_main_k_block_loop)
{ {
ave_time = launch_kernel(integral_constant<bool, false>{}); if(arg.p_dropout_ > 0.0)
ave_time = launch_kernel(integral_constant<bool, false>{}, integral_constant<bool, true>{});
else
ave_time = launch_kernel(integral_constant<bool, false>{}, integral_constant<bool, false>{});
} }
else else
{ {
......
...@@ -1222,6 +1222,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1222,6 +1222,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
} }
template <bool HasMainKBlockLoop, template <bool HasMainKBlockLoop,
bool IsDropout,
typename Block2CTileMap, typename Block2CTileMap,
typename C0MatrixMask, typename C0MatrixMask,
typename YGradGridDesc_O0_M_O1> typename YGradGridDesc_O0_M_O1>
...@@ -1947,6 +1948,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1947,6 +1948,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
constexpr auto position_offset = M3 * M4; constexpr auto position_offset = M3 * M4;
// save z to global // save z to global
if constexpr(IsDropout){
if(p_z_grid) if(p_z_grid)
{ {
...@@ -1996,7 +1998,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1996,7 +1998,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
true>( true>(
s_slash_p_thread_buf, ph, global_elem_id, raw_n_padded); s_slash_p_thread_buf, ph, global_elem_id, raw_n_padded);
} }
}
block_sync_lds(); // wait for gemm1 LDS read block_sync_lds(); // wait for gemm1 LDS read
// dS = P * (dP - Y_dot_dY) // dS = P * (dP - Y_dot_dY)
......
...@@ -1154,6 +1154,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1154,6 +1154,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
} }
template <bool HasMainKBlockLoop, template <bool HasMainKBlockLoop,
bool IsDropout,
typename Block2CTileMap, typename Block2CTileMap,
typename C0MatrixMask, typename C0MatrixMask,
typename YGradGridDesc_M0_O_M1> typename YGradGridDesc_M0_O_M1>
...@@ -1863,6 +1864,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1863,6 +1864,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
constexpr auto position_offset = M3 * M4; constexpr auto position_offset = M3 * M4;
// save z to global // save z to global
if constexpr(IsDropout){
if(p_z_grid) if(p_z_grid)
{ {
...@@ -1911,7 +1913,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1911,7 +1913,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
true>( true>(
s_slash_p_thread_buf, ph, global_elem_id, raw_n_padded); s_slash_p_thread_buf, ph, global_elem_id, raw_n_padded);
} }
}
block_sync_lds(); // wait for gemm1 LDS read block_sync_lds(); // wait for gemm1 LDS read
// gemm dV // gemm dV
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment