Commit 6e676287 authored by ltqin

rewrite acc_elementwise_op parameter

parent f99f419d
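This commit moves the dropout rescale out of the device-side argument setup and into the gridwise kernel. Before, the device op compounded the factor into its accumulator elementwise op via acc_element_op_.Append(rp_dropout_), while the gridwise kernel separately recomputed the softmax scale as 1/sqrt(K). Now Scale exposes a read-only Value() accessor, the kernel applies the caller-supplied s_element_op to the S values in s_slash_p_thread_buf, and it builds a single fused op, Scale(s_element_op.Value() * rp_dropout), for the gradient output copies. The kernel entry point is also renamed from kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v2 to kernel_batched_multihead_attention_backward_xdl_cshuffle_v2 to match what it computes.

A minimal host-only sketch of the new fusion; the Scale stand-in mirrors the diff, while head_dim, the keep-probability, and everything else the CK kernel plumbing would supply are illustrative assumptions:

#include <cassert>
#include <cmath>

// Stand-in for ck::tensor_operation::element_wise::Scale after this commit:
// Append() is removed; Value() exposes the factor so a caller can build a
// fused op. Host-only here; the __host__ __device__ qualifiers are dropped.
struct Scale
{
    explicit Scale(float scale) : scale_(scale) {}

    // y = scale_ * x, applied per element during a copy.
    void operator()(float& y, const float& x) const { y = scale_ * x; }

    float Value() const { return scale_; }

    float scale_;
};

int main()
{
    const int head_dim    = 64;   // plays the role of K in the deleted code
    const float p_dropout = 0.9f; // keep-probability (1 - p_drop), assumed

    // Caller-side softmax scaling op, passed into the kernel as s_element_op.
    const Scale s_element_op(1.0f / std::sqrt(static_cast<float>(head_dim)));

    // Kernel-side fusion, as the inserted scale_rp_dropout lines below do.
    const float rp_dropout = 1.0f / p_dropout;
    const Scale scale_rp_dropout(s_element_op.Value() * rp_dropout);

    float y = 0.0f;
    scale_rp_dropout(y, 8.0f); // y = 8 * rp_dropout / sqrt(head_dim)
    assert(y > 0.0f);
    return 0;
}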
@@ -51,7 +51,7 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v2(
+        kernel_batched_multihead_attention_backward_xdl_cshuffle_v2(
             const DataType* __restrict__ p_a_grid,
             const DataType* __restrict__ p_b_grid,
             ZDataType* __restrict__ p_z_grid,
@@ -784,8 +784,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
             p_dropout_  = 1.f - p_drop;
             is_dropout_ = p_drop > 0.0f;
-            float rp_dropout_ = 1.f / p_dropout_;
-            acc_element_op_.Append(rp_dropout_);
             seed_   = std::get<0>(seeds);
             offset_ = std::get<1>(seeds);
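With Append removed, the argument constructor no longer folds the dropout reciprocal into acc_element_op_; the gridwise kernel now derives the same rp_dropout factor from p_dropout itself, as the scale_rp_dropout hunk further down shows.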
@@ -905,7 +903,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
         float ave_time = 0;
         auto launch_kernel = [&](auto has_main_k_block_loop_) {
-            const auto kernel = kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v2<
+            const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v2<
                 GridwiseGemm,
                 DataType,
                 ZDataType,
...
@@ -95,7 +95,7 @@ struct Scale
         y = scale_ * x;
     };
-    __host__ __device__ void Append(float scale) { scale_ = scale_ * scale; }
+    __host__ __device__ auto Value() const { return scale_; }
     float scale_;
 };
...
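Swapping the mutating Append for a const Value() accessor leaves Scale immutable after construction: rather than compounding factors into a shared op, each consumer reads the base factor and builds its own fused Scale, which is exactly what the gridwise kernel does next.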
@@ -1175,6 +1175,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
     {
         const ushort p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
         const FloatGemmAcc rp_dropout    = 1.0f / p_dropout;
+        const tensor_operation::element_wise::Scale scale_rp_dropout(s_element_op.Value() *
+                                                                     rp_dropout);
         const auto q_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_q_grid, q_grid_desc_k0_m_k1.GetElementSpaceSize());
@@ -1604,9 +1606,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         auto kgrad_thread_copy_vgpr_to_global = typename Gemm2::template CBlockwiseCopy<
             decltype(kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4),
-            decltype(s_element_op)>(kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
-                                    kgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4,
-                                    s_element_op);
+            decltype(scale_rp_dropout)>(kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
+                                        kgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4,
+                                        scale_rp_dropout);
         //
         // set up Y dot dY
@@ -1749,8 +1751,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         const index_t num_gemm1_k_block_outer_loop = k_grid_desc_k0_n_k1.GetLength(I1) / NPerBlock;
         constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock;
-        const index_t K = k_grid_desc_k0_n_k1.GetLength(I0) * k_grid_desc_k0_n_k1.GetLength(I2);
-        const float scalar = 1.0f / std::sqrt(K);
         // Initialize dQ
         qgrad_thread_buf.Clear();
@@ -1831,14 +1831,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
                     }
                     else
                     {
-                        s_slash_p_thread_buf(i) = scalar * s_slash_p_thread_buf[i];
+                        s_element_op(s_slash_p_thread_buf(i), s_slash_p_thread_buf[i]);
                     }
                 });
             }
             else
             {
-                static_for<0, s_slash_p_thread_buf.Size(), 1>{}(
-                    [&](auto i) { s_slash_p_thread_buf(i) = scalar * s_slash_p_thread_buf[i]; });
+                static_for<0, s_slash_p_thread_buf.Size(), 1>{}([&](auto i) {
+                    s_element_op(s_slash_p_thread_buf(i), s_slash_p_thread_buf[i]);
+                });
             }
             block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
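Both the dropout and the non-dropout branches now push the buffer through s_element_op instead of multiplying by the locally computed scalar = 1.0f / std::sqrt(K) deleted above, so the softmax scaling is defined in one place, by the op the caller passes in.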
@@ -2230,7 +2231,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
                 n_thread_data_on_block_idx[I2],
                 n_thread_data_on_block_idx[I3],
                 n_thread_data_on_block_idx[I4]),
-            s_element_op};
+            scale_rp_dropout};
             // shuffle: blockwise copy C from LDS to global
             auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
...
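The C-shuffle output path receives the same substitution as the kgrad copy earlier: the final copy to global memory applies the fused scale_rp_dropout rather than the bare s_element_op, so the dropout rescale rides along with the existing scaling during the write-out.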