Commit d0c65caa authored by guangzlu
Browse files

Added a switch for LSE storing in the attention forward pass

parent 54dfedcd
......@@ -32,7 +32,8 @@ template <typename GridwiseGemm,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
bool HasMainKBlockLoop,
bool IsDropout>
bool IsDropout,
bool IsLseStoring>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
......@@ -97,18 +98,16 @@ __global__ void
const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetLSEBasePtr(g_idx)));
// unsigned short* p_z_grid_in = //
// (arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
// : arg_ptr[group_id].p_z_grid_ + z_batch_offset);
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset,
arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
arg_ptr[group_id].p_c_grid_ + c_batch_offset,
arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
: arg_ptr[group_id].p_z_grid_ + z_batch_offset,
arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
arg_ptr[group_id].p_lse_grid_ == nullptr ? nullptr
: arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
// arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
p_shared,
a_element_op,
b_element_op,
......@@ -589,6 +588,11 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
const auto p_z_grid = static_cast<ZDataType*>(p_z_vec[i]);
const auto p_lse_grid = static_cast<LSEDataType*>(p_lse_vec[i]);
// A null LSE pointer for the current group clears is_lse_storing_. The flag is a single
// member that defaults to true and is never set back to true here, so one group without
// an LSE buffer selects the non-storing kernel variant for the whole launch.
// NOTE(review): groups that did supply an LSE buffer then get no LSE output — confirm
// that mixing null and non-null LSE pointers across groups is meant to behave this way.
if(p_lse_grid == nullptr)
{
is_lse_storing_ = false;
}
const auto& problem_desc = problem_desc_vec[i];
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
......@@ -724,6 +728,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
unsigned long long offset_;
GemmAccDataType p_dropout_rescale_;
bool is_dropout_;
bool is_lse_storing_ = true;
};
// Invoker
......@@ -756,7 +762,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
auto launch_kernel =
[&](auto has_main_k_block_loop_, auto is_dropout_, auto is_lse_storing_) {
const auto kernel =
kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2<GridwiseGemm,
GemmAccDataType,
......@@ -767,7 +774,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
B1ElementwiseOperation,
CElementwiseOperation,
has_main_k_block_loop_,
is_dropout_>;
is_dropout_,
is_lse_storing_>;
return launch_and_time_kernel(
stream_config,
......@@ -793,29 +801,69 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
// Runtime-to-compile-time dispatch: three runtime booleans are turned into
// integral_constant<bool, ...> tags so launch_kernel can use them as kernel template
// arguments. Argument order is (has_main_k_block_loop, is_dropout, is_lse_storing).
//
// First branch: has_main_k_block_loop = true.
// NOTE(review): presumably all_has_main_k_block_loop means every group needs the main
// K block loop — confirm against where it is computed.
if(all_has_main_k_block_loop)
{
if(arg.is_dropout_)
{
if(arg.is_lse_storing_)
{
// main loop, dropout, LSE stored
ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else
{
// main loop, dropout, LSE not stored
ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
}
else
{
if(arg.is_lse_storing_)
{
// main loop, no dropout, LSE stored
ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
// main loop, no dropout, LSE not stored
ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
}
// Second branch: has_main_k_block_loop = false.
// NOTE(review): presumably taken when no group needs the main K block loop — confirm.
else if(!some_has_main_k_block_loop)
{
if(arg.is_dropout_)
{
if(arg.is_lse_storing_)
{
// no main loop, dropout, LSE stored
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else
{
// no main loop, dropout, LSE not stored
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
}
else
{
if(arg.is_lse_storing_)
{
// no main loop, no dropout, LSE stored
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
// no main loop, no dropout, LSE not stored
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
}
else
{
throw std::runtime_error("wrong! all gemm problems have to simultaneously meet "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment