Commit 5a3904c7 authored by fsx950223's avatar fsx950223
Browse files

fix group api

parent e1287b9a
......@@ -35,6 +35,7 @@ template <typename GridwiseGemm,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
bool HasMainKBlockLoop,
bool IsDropout,
bool Deterministic>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
......@@ -104,7 +105,7 @@ __global__ void
{
for(index_t i = 0; i < num_blocks_per_batch; i++)
{
GridwiseGemm::template Run<HasMainKBlockLoop>(
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset,
z_matrix_ptr,
......@@ -140,7 +141,7 @@ __global__ void
}
else
{
GridwiseGemm::template Run<HasMainKBlockLoop>(
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset,
z_matrix_ptr,
......@@ -960,7 +961,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) {
auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
const auto kernel =
kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_v1<
GridwiseGemm,
......@@ -971,6 +972,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
B1ElementwiseOperation,
CElementwiseOperation,
has_main_k_block_loop_,
is_dropout_,
Deterministic>;
return launch_and_time_kernel(
......@@ -995,11 +997,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// to concern Gemm0's loop
if(all_has_main_k_block_loop)
{
ave_time = launch_kernel(integral_constant<bool, true>{});
if(arg.p_dropout_ > 0.0)
ave_time = launch_kernel(integral_constant<bool, true>{}, integral_constant<bool, true>{});
else
ave_time = launch_kernel(integral_constant<bool, true>{}, integral_constant<bool, false>{});
}
else if(!some_has_main_k_block_loop)
{
ave_time = launch_kernel(integral_constant<bool, false>{});
if(arg.p_dropout_ > 0.0)
ave_time = launch_kernel(integral_constant<bool, false>{}, integral_constant<bool, true>{});
else
ave_time = launch_kernel(integral_constant<bool, false>{}, integral_constant<bool, false>{});
}
else
{
......
......@@ -35,6 +35,7 @@ template <typename GridwiseGemm,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
bool HasMainKBlockLoop,
bool IsDropout,
bool Deterministic>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
......@@ -104,7 +105,7 @@ __global__ void
{
for(index_t i = 0; i < num_blocks_per_batch; i++)
{
GridwiseGemm::template Run<HasMainKBlockLoop>(
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset,
z_matrix_ptr,
......@@ -140,7 +141,7 @@ __global__ void
}
else
{
GridwiseGemm::template Run<HasMainKBlockLoop>(
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset,
z_matrix_ptr,
......@@ -967,7 +968,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) {
auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
const auto kernel =
kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_v2<
GridwiseGemm,
......@@ -978,6 +979,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
B1ElementwiseOperation,
CElementwiseOperation,
has_main_k_block_loop_,
is_dropout_,
Deterministic>;
return launch_and_time_kernel(
......@@ -1002,11 +1004,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// to concern Gemm0's loop
if(all_has_main_k_block_loop)
{
ave_time = launch_kernel(integral_constant<bool, true>{});
if(arg.p_dropout_ > 0.0)
ave_time = launch_kernel(integral_constant<bool, true>{}, integral_constant<bool, true>{});
else
ave_time = launch_kernel(integral_constant<bool, true>{}, integral_constant<bool, false>{});
}
else if(!some_has_main_k_block_loop)
{
ave_time = launch_kernel(integral_constant<bool, false>{});
if(arg.p_dropout_ > 0.0)
ave_time = launch_kernel(integral_constant<bool, false>{}, integral_constant<bool, true>{});
else
ave_time = launch_kernel(integral_constant<bool, false>{}, integral_constant<bool, false>{});
}
else
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment