Commit 5a3904c7 authored by fsx950223's avatar fsx950223
Browse files

fix group api

parent e1287b9a
...@@ -35,6 +35,7 @@ template <typename GridwiseGemm, ...@@ -35,6 +35,7 @@ template <typename GridwiseGemm,
typename B1ElementwiseOperation, typename B1ElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
bool HasMainKBlockLoop, bool HasMainKBlockLoop,
bool IsDropout,
bool Deterministic> bool Deterministic>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
...@@ -104,7 +105,7 @@ __global__ void ...@@ -104,7 +105,7 @@ __global__ void
{ {
for(index_t i = 0; i < num_blocks_per_batch; i++) for(index_t i = 0; i < num_blocks_per_batch; i++)
{ {
GridwiseGemm::template Run<HasMainKBlockLoop>( GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset, arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset, arg_ptr[group_id].p_b_grid_ + b_batch_offset,
z_matrix_ptr, z_matrix_ptr,
...@@ -140,7 +141,7 @@ __global__ void ...@@ -140,7 +141,7 @@ __global__ void
} }
else else
{ {
GridwiseGemm::template Run<HasMainKBlockLoop>( GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset, arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset, arg_ptr[group_id].p_b_grid_ + b_batch_offset,
z_matrix_ptr, z_matrix_ptr,
...@@ -960,7 +961,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -960,7 +961,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
float ave_time = 0; float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) { auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
const auto kernel = const auto kernel =
kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_v1< kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_v1<
GridwiseGemm, GridwiseGemm,
...@@ -971,6 +972,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -971,6 +972,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
B1ElementwiseOperation, B1ElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
has_main_k_block_loop_, has_main_k_block_loop_,
is_dropout_,
Deterministic>; Deterministic>;
return launch_and_time_kernel( return launch_and_time_kernel(
...@@ -995,11 +997,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -995,11 +997,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// to concern Gemm0's loop // to concern Gemm0's loop
if(all_has_main_k_block_loop) if(all_has_main_k_block_loop)
{ {
ave_time = launch_kernel(integral_constant<bool, true>{}); if(arg.p_dropout_ > 0.0)
ave_time = launch_kernel(integral_constant<bool, true>{}, integral_constant<bool, true>{});
else
ave_time = launch_kernel(integral_constant<bool, true>{}, integral_constant<bool, false>{});
} }
else if(!some_has_main_k_block_loop) else if(!some_has_main_k_block_loop)
{ {
ave_time = launch_kernel(integral_constant<bool, false>{}); if(arg.p_dropout_ > 0.0)
ave_time = launch_kernel(integral_constant<bool, false>{}, integral_constant<bool, true>{});
else
ave_time = launch_kernel(integral_constant<bool, false>{}, integral_constant<bool, false>{});
} }
else else
{ {
......
...@@ -35,6 +35,7 @@ template <typename GridwiseGemm, ...@@ -35,6 +35,7 @@ template <typename GridwiseGemm,
typename B1ElementwiseOperation, typename B1ElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
bool HasMainKBlockLoop, bool HasMainKBlockLoop,
bool IsDropout,
bool Deterministic> bool Deterministic>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
...@@ -104,7 +105,7 @@ __global__ void ...@@ -104,7 +105,7 @@ __global__ void
{ {
for(index_t i = 0; i < num_blocks_per_batch; i++) for(index_t i = 0; i < num_blocks_per_batch; i++)
{ {
GridwiseGemm::template Run<HasMainKBlockLoop>( GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset, arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset, arg_ptr[group_id].p_b_grid_ + b_batch_offset,
z_matrix_ptr, z_matrix_ptr,
...@@ -140,7 +141,7 @@ __global__ void ...@@ -140,7 +141,7 @@ __global__ void
} }
else else
{ {
GridwiseGemm::template Run<HasMainKBlockLoop>( GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset, arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset, arg_ptr[group_id].p_b_grid_ + b_batch_offset,
z_matrix_ptr, z_matrix_ptr,
...@@ -967,7 +968,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -967,7 +968,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
float ave_time = 0; float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) { auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
const auto kernel = const auto kernel =
kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_v2< kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_v2<
GridwiseGemm, GridwiseGemm,
...@@ -978,6 +979,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -978,6 +979,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
B1ElementwiseOperation, B1ElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
has_main_k_block_loop_, has_main_k_block_loop_,
is_dropout_,
Deterministic>; Deterministic>;
return launch_and_time_kernel( return launch_and_time_kernel(
...@@ -1002,11 +1004,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1002,11 +1004,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// to concern Gemm0's loop // to concern Gemm0's loop
if(all_has_main_k_block_loop) if(all_has_main_k_block_loop)
{ {
ave_time = launch_kernel(integral_constant<bool, true>{}); if(arg.p_dropout_ > 0.0)
ave_time = launch_kernel(integral_constant<bool, true>{}, integral_constant<bool, true>{});
else
ave_time = launch_kernel(integral_constant<bool, true>{}, integral_constant<bool, false>{});
} }
else if(!some_has_main_k_block_loop) else if(!some_has_main_k_block_loop)
{ {
ave_time = launch_kernel(integral_constant<bool, false>{}); if(arg.p_dropout_ > 0.0)
ave_time = launch_kernel(integral_constant<bool, false>{}, integral_constant<bool, true>{});
else
ave_time = launch_kernel(integral_constant<bool, false>{}, integral_constant<bool, false>{});
} }
else else
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment