Commit d0c65caa authored by guangzlu's avatar guangzlu
Browse files

added switch for lse storing in attn fwd

parent 54dfedcd
...@@ -32,7 +32,8 @@ template <typename GridwiseGemm, ...@@ -32,7 +32,8 @@ template <typename GridwiseGemm,
typename B1ElementwiseOperation, typename B1ElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
bool HasMainKBlockLoop, bool HasMainKBlockLoop,
bool IsDropout> bool IsDropout,
bool IsLseStoring>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
...@@ -97,18 +98,16 @@ __global__ void ...@@ -97,18 +98,16 @@ __global__ void
const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>( const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetLSEBasePtr(g_idx))); arg_ptr[group_id].compute_base_ptr_of_batch_.GetLSEBasePtr(g_idx)));
// unsigned short* p_z_grid_in = // GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
// (arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
// : arg_ptr[group_id].p_z_grid_ + z_batch_offset);
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset, arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset, arg_ptr[group_id].p_b_grid_ + b_batch_offset,
arg_ptr[group_id].p_b1_grid_ + b1_batch_offset, arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
arg_ptr[group_id].p_c_grid_ + c_batch_offset, arg_ptr[group_id].p_c_grid_ + c_batch_offset,
arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
: arg_ptr[group_id].p_z_grid_ + z_batch_offset, : arg_ptr[group_id].p_z_grid_ + z_batch_offset,
arg_ptr[group_id].p_lse_grid_ + lse_batch_offset, arg_ptr[group_id].p_lse_grid_ == nullptr ? nullptr
: arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
// arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
p_shared, p_shared,
a_element_op, a_element_op,
b_element_op, b_element_op,
...@@ -589,6 +588,11 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle ...@@ -589,6 +588,11 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
const auto p_z_grid = static_cast<ZDataType*>(p_z_vec[i]); const auto p_z_grid = static_cast<ZDataType*>(p_z_vec[i]);
const auto p_lse_grid = static_cast<LSEDataType*>(p_lse_vec[i]); const auto p_lse_grid = static_cast<LSEDataType*>(p_lse_vec[i]);
if(p_lse_grid == nullptr)
{
is_lse_storing_ = false;
}
const auto& problem_desc = problem_desc_vec[i]; const auto& problem_desc = problem_desc_vec[i];
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
...@@ -724,6 +728,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle ...@@ -724,6 +728,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
unsigned long long offset_; unsigned long long offset_;
GemmAccDataType p_dropout_rescale_; GemmAccDataType p_dropout_rescale_;
bool is_dropout_; bool is_dropout_;
bool is_lse_storing_ = true;
}; };
// Invoker // Invoker
...@@ -756,37 +762,39 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle ...@@ -756,37 +762,39 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
float ave_time = 0; float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) { auto launch_kernel =
const auto kernel = [&](auto has_main_k_block_loop_, auto is_dropout_, auto is_lse_storing_) {
kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2<GridwiseGemm, const auto kernel =
GemmAccDataType, kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2<GridwiseGemm,
GroupKernelArg, GemmAccDataType,
AElementwiseOperation, GroupKernelArg,
BElementwiseOperation, AElementwiseOperation,
AccElementwiseOperation, BElementwiseOperation,
B1ElementwiseOperation, AccElementwiseOperation,
CElementwiseOperation, B1ElementwiseOperation,
has_main_k_block_loop_, CElementwiseOperation,
is_dropout_>; has_main_k_block_loop_,
is_dropout_,
return launch_and_time_kernel( is_lse_storing_>;
stream_config,
kernel, return launch_and_time_kernel(
dim3(arg.grid_size_), stream_config,
dim3(BlockSize), kernel,
0, dim3(arg.grid_size_),
cast_pointer_to_constant_address_space(arg.p_workspace_), dim3(BlockSize),
arg.group_count_, 0,
arg.a_element_op_, cast_pointer_to_constant_address_space(arg.p_workspace_),
arg.b_element_op_, arg.group_count_,
arg.acc_element_op_, arg.a_element_op_,
arg.b1_element_op_, arg.b_element_op_,
arg.c_element_op_, arg.acc_element_op_,
arg.p_dropout_in_16bits_, arg.b1_element_op_,
arg.p_dropout_rescale_, arg.c_element_op_,
arg.seed_, arg.p_dropout_in_16bits_,
arg.offset_); arg.p_dropout_rescale_,
}; arg.seed_,
arg.offset_);
};
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
// to concern Gemm0's loop // to concern Gemm0's loop
...@@ -794,26 +802,66 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle ...@@ -794,26 +802,66 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
{ {
if(arg.is_dropout_) if(arg.is_dropout_)
{ {
ave_time = launch_kernel(integral_constant<bool, true>{}, if(arg.is_lse_storing_)
integral_constant<bool, true>{}); {
ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
} }
else else
{ {
ave_time = launch_kernel(integral_constant<bool, true>{}, if(arg.is_lse_storing_)
integral_constant<bool, false>{}); {
ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
} }
} }
else if(!some_has_main_k_block_loop) else if(!some_has_main_k_block_loop)
{ {
if(arg.is_dropout_) if(arg.is_dropout_)
{ {
ave_time = launch_kernel(integral_constant<bool, false>{}, if(arg.is_lse_storing_)
integral_constant<bool, true>{}); {
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
} }
else else
{ {
ave_time = launch_kernel(integral_constant<bool, false>{}, if(arg.is_lse_storing_)
integral_constant<bool, false>{}); {
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
} }
} }
else else
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment