Commit 6e676287 authored by ltqin

rewrite acc_elementwise_op parameter

parent f99f419d
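
This commit moves the dropout compensation factor out of the host-side argument setup and into the gridwise kernel: instead of mutating the user's acc element-wise op with `Append(1 / p_dropout)`, the kernel reads the original scale through the new `Value()` accessor and builds a combined `Scale` op itself. Below is a minimal host-side sketch of the idea; `ScaleOp` and the concrete numbers are illustrative stand-ins, not the library types.

```cpp
#include <cmath>
#include <cstdio>

// Illustrative stand-in for tensor_operation::element_wise::Scale.
struct ScaleOp
{
    explicit ScaleOp(float scale = 1.f) : scale_(scale) {}

    void operator()(float& y, float x) const { y = scale_ * x; }

    // Old approach: mutate the op in place (what Append() did on the host).
    void Append(float scale) { scale_ = scale_ * scale; }

    // New approach: expose the scale so a combined op can be built where it is needed.
    float Value() const { return scale_; }

    float scale_;
};

int main()
{
    const float p_drop     = 0.2f;              // dropout probability
    const float p_dropout  = 1.f - p_drop;      // keep probability
    const float rp_dropout = 1.f / p_dropout;

    // User-supplied softmax scale (e.g. 1/sqrt(head_dim)); now left untouched.
    const ScaleOp s_element_op(1.f / std::sqrt(64.f));

    // What this commit does inside the kernel: derive the combined scale from
    // Value() instead of having the host call Append(rp_dropout).
    const ScaleOp scale_rp_dropout(s_element_op.Value() * rp_dropout);

    float y = 0.f;
    scale_rp_dropout(y, 8.f);
    std::printf("combined scale = %f, y = %f\n", scale_rp_dropout.Value(), y);
    return 0;
}
```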
@@ -51,7 +51,7 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v2(
+        kernel_batched_multihead_attention_backward_xdl_cshuffle_v2(
             const DataType* __restrict__ p_a_grid,
             const DataType* __restrict__ p_b_grid,
             ZDataType* __restrict__ p_z_grid,
@@ -782,10 +782,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
                 y_grid_desc_m_o_);
         }
-        p_dropout_ = 1.f - p_drop;
-        is_dropout_ = p_drop > 0.0f;
-        float rp_dropout_ = 1.f / p_dropout_;
-        acc_element_op_.Append(rp_dropout_);
+        p_dropout_ = 1.f - p_drop;
+        is_dropout_ = p_drop > 0.0f;
         seed_   = std::get<0>(seeds);
         offset_ = std::get<1>(seeds);
@@ -905,7 +903,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
         float ave_time = 0;
         auto launch_kernel = [&](auto has_main_k_block_loop_) {
-            const auto kernel = kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v2<
+            const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v2<
                 GridwiseGemm,
                 DataType,
                 ZDataType,
......
@@ -95,7 +95,7 @@ struct Scale
         y = scale_ * x;
     };
     __host__ __device__ void Append(float scale) { scale_ = scale_ * scale; }
+    __host__ __device__ auto Value() const { return scale_; }
     float scale_;
 };
......
@@ -1175,6 +1175,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
     {
         const ushort p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
         const FloatGemmAcc rp_dropout = 1.0f / p_dropout;
+        const tensor_operation::element_wise::Scale scale_rp_dropout(s_element_op.Value() *
+                                                                      rp_dropout);
         const auto q_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_q_grid, q_grid_desc_k0_m_k1.GetElementSpaceSize());
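
The new `scale_rp_dropout` folds the caller's scale and the inverted-dropout factor `1/p_dropout` into a single multiplier that the output blockwise copies below apply on store. A small illustrative check (not library code) that this matches what the old host-side `Append()` path produced; the head size is an assumed example value:

```cpp
#include <cassert>
#include <cmath>

int main()
{
    const float p_drop        = 0.25f;
    const float p_dropout     = 1.f - p_drop;             // keep probability
    const float rp_dropout    = 1.0f / p_dropout;
    const float softmax_scale = 1.0f / std::sqrt(128.0f); // assumed head size

    // New path: combined value computed in the kernel from Value() * rp_dropout.
    const float combined = softmax_scale * rp_dropout;

    // Old path: host-side Append(rp_dropout) multiplied the stored scale in place.
    float appended = softmax_scale;
    appended *= rp_dropout;

    assert(std::fabs(combined - appended) < 1e-7f);
    return 0;
}
```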
@@ -1604,9 +1606,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         auto kgrad_thread_copy_vgpr_to_global = typename Gemm2::template CBlockwiseCopy<
             decltype(kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4),
-            decltype(s_element_op)>(kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
-                                    kgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4,
-                                    s_element_op);
+            decltype(scale_rp_dropout)>(kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
+                                        kgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4,
+                                        scale_rp_dropout);

         //
         // set up Y dot dY
@@ -1749,8 +1751,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         const index_t num_gemm1_k_block_outer_loop = k_grid_desc_k0_n_k1.GetLength(I1) / NPerBlock;
         constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock;
-        const index_t K = k_grid_desc_k0_n_k1.GetLength(I0) * k_grid_desc_k0_n_k1.GetLength(I2);
-        const float scalar = 1.0f / std::sqrt(K);

         // Initialize dQ
         qgrad_thread_buf.Clear();
@@ -1831,14 +1831,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
                     }
                     else
                     {
-                        s_slash_p_thread_buf(i) = scalar * s_slash_p_thread_buf[i];
+                        s_element_op(s_slash_p_thread_buf(i), s_slash_p_thread_buf[i]);
                     }
                 });
             }
             else
             {
-                static_for<0, s_slash_p_thread_buf.Size(), 1>{}(
-                    [&](auto i) { s_slash_p_thread_buf(i) = scalar * s_slash_p_thread_buf[i]; });
+                static_for<0, s_slash_p_thread_buf.Size(), 1>{}([&](auto i) {
+                    s_element_op(s_slash_p_thread_buf(i), s_slash_p_thread_buf[i]);
+                });
             }

             block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
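
Here the hard-coded `scalar = 1/sqrt(K)` multiply over the S/P thread buffer is replaced by applying the caller's `s_element_op` element-wise, so the softmax scaling comes from the argument rather than being recomputed in the kernel. A sketch of the application pattern, with a plain loop and a lambda standing in for `ck::static_for` and the VGPR-backed thread buffer (the scale value is an assumed example):

```cpp
#include <array>
#include <cstdio>

int main()
{
    std::array<float, 4> s_slash_p_buf{1.f, 2.f, 3.f, 4.f};

    // Stand-in for the Scale element-wise op carrying the softmax scale.
    const float scale = 0.125f; // e.g. 1/sqrt(64), assumed value
    auto s_element_op = [scale](float& y, float x) { y = scale * x; };

    // In-place scaling of every element, matching
    //   s_element_op(s_slash_p_thread_buf(i), s_slash_p_thread_buf[i]);
    for(auto& v : s_slash_p_buf)
        s_element_op(v, v);

    for(float v : s_slash_p_buf)
        std::printf("%f\n", v);
    return 0;
}
```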
@@ -2230,7 +2231,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
                     n_thread_data_on_block_idx[I2],
                     n_thread_data_on_block_idx[I3],
                     n_thread_data_on_block_idx[I4]),
-                s_element_op};
+                scale_rp_dropout};

         // shuffle: blockwise copy C from LDS to global
         auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
......
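
Both output copies above now take `scale_rp_dropout` as their destination element-wise op, so the op is applied to every value as it is written to global memory and the gradient tiles leave the kernel already rescaled. A simplified sketch of that "apply the element-op on store" pattern; the pointer-based copy routine and the numbers are stand-ins for the tiled `ThreadGroupTensorSliceTransfer` machinery:

```cpp
#include <cstddef>
#include <cstdio>

// Apply the destination element-wise op to each value as it is stored.
template <typename ElementOp>
void copy_with_element_op(const float* src, float* dst, std::size_t n, const ElementOp& op)
{
    for(std::size_t i = 0; i < n; ++i)
        op(dst[i], src[i]); // dst[i] = op(src[i])
}

int main()
{
    const float src[4] = {1.f, 2.f, 3.f, 4.f};
    float dst[4]       = {};

    // Assumed example: softmax scale * 1/p_dropout.
    const float combined_scale = 0.125f * (1.f / 0.8f);
    auto scale_rp_dropout = [combined_scale](float& y, float x) { y = combined_scale * x; };

    copy_with_element_op(src, dst, 4, scale_rp_dropout);

    for(float v : dst)
        std::printf("%f\n", v);
    return 0;
}
```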