Commit de53e421 authored by letaoqin's avatar letaoqin
Browse files

load d0 to lds

parent ec2ad713
...@@ -1229,7 +1229,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1229,7 +1229,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
__device__ static void __device__ static void
Run(const InputDataType* __restrict__ p_q_grid, Run(const InputDataType* __restrict__ p_q_grid,
const InputDataType* __restrict__ p_k_grid, const InputDataType* __restrict__ p_k_grid,
const D0DataType* __restrict__ p_d_grid, const D0DataType* __restrict__ p_d0_grid,
ZDataType* __restrict__ p_z_grid, ZDataType* __restrict__ p_z_grid,
const InputDataType* __restrict__ p_v_grid, const InputDataType* __restrict__ p_v_grid,
const InputDataType* __restrict__ p_y_grid, const InputDataType* __restrict__ p_y_grid,
...@@ -1262,8 +1262,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1262,8 +1262,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const index_t raw_n_padded, const index_t raw_n_padded,
const index_t block_idx_n) const index_t block_idx_n)
{ {
ignore = d0_grid_desc_m0_n0_m1_m2_n1_m3;
ignore = p_d_grid;
const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop); const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop);
const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout); const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
const ushort p_dropout_in_16bits = const ushort p_dropout_in_16bits =
...@@ -1940,6 +1938,29 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1940,6 +1938,29 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
block_sync_lds(); // wait for lds read in gemm0 blockwise gemm block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
// add bias // add bias
if constexpr(!is_same<D0DataType, void>::value)
{
const auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<GemmDataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
D0Loader::d0_block_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
static_for<0, D0M1, 1>{}([&](auto) {
d0_block_copy_global_to_lds.RunRead(d0_grid_desc_m0_n0_m1_m2_n1_m3,
d0_grid_buf);
d0_block_copy_global_to_lds.MoveSrcSliceWindow(
d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
d0_block_copy_global_to_lds.RunWrite(D0Loader::d0_block_desc_m0_n0_m1_m2_n1_m3,
d0_block_buf);
});
d0_block_copy_global_to_lds.MoveSrcSliceWindow(
d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(1, 0, -D0M1.value, 0, 0, 0));
}
// P_i: = softmax(scalar * S_i:) // P_i: = softmax(scalar * S_i:)
// scaling is already performed in the preceding statements with s_element_op // scaling is already performed in the preceding statements with s_element_op
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment