Commit 0413755a authored by guangzlu's avatar guangzlu
Browse files

modified offset to bwd pt6

parent 092a88bf
...@@ -86,8 +86,8 @@ __global__ void ...@@ -86,8 +86,8 @@ __global__ void
const float p_drop, const float p_drop,
const unsigned long long seed, const unsigned long long seed,
const unsigned long long offset, const unsigned long long offset,
const index_t MRaw, const index_t raw_m_padded,
const index_t NRaw) const index_t raw_n_padded)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
...@@ -110,10 +110,11 @@ __global__ void ...@@ -110,10 +110,11 @@ __global__ void
const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx))); static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx)));
const index_t global_thread_id = get_thread_global_1d_id(); ck::philox ph(seed, 0, offset);
ck::philox ph(seed, global_thread_id, offset);
ZDataType* z_matrix_ptr = (p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset); ZDataType* z_matrix_ptr = (p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset);
const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded;
if constexpr(Deterministic) if constexpr(Deterministic)
{ {
for(index_t i = 0; i < nblock; i++) for(index_t i = 0; i < nblock; i++)
...@@ -146,9 +147,8 @@ __global__ void ...@@ -146,9 +147,8 @@ __global__ void
c0_matrix_mask, c0_matrix_mask,
p_drop, p_drop,
ph, ph,
g_idx, z_random_matrix_offset,
MRaw, raw_n_padded,
NRaw,
i); i);
} }
} }
...@@ -181,9 +181,8 @@ __global__ void ...@@ -181,9 +181,8 @@ __global__ void
c0_matrix_mask, c0_matrix_mask,
p_drop, p_drop,
ph, ph,
g_idx, z_random_matrix_offset,
MRaw, raw_n_padded,
NRaw,
0); 0);
} }
#else #else
......
...@@ -1254,9 +1254,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -1254,9 +1254,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
const C0MatrixMask& c0_matrix_mask, const C0MatrixMask& c0_matrix_mask,
const float p_drop, const float p_drop,
ck::philox& ph, ck::philox& ph,
const index_t g_idx, const index_t z_random_matrix_offset,
const index_t MRaw, const index_t raw_n_padded,
const index_t NRaw,
const index_t block_idx_n) const index_t block_idx_n)
{ {
const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop); const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop);
...@@ -1959,16 +1958,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -1959,16 +1958,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
auto m_global = m_local + m_block_data_idx_on_grid; auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid; auto n_global = n_local + n_block_data_idx_on_grid;
auto global_elem_id_raw = auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded +
MRaw * NRaw * g_idx + m_global * NRaw + n_global; // unique element global 1d id n_global; // unique element global 1d id
auto global_elem_id = auto global_elem_id =
(global_elem_id_raw % M4) * NRaw + int(global_elem_id_raw / M4) * M4; (global_elem_id_raw % M4) * raw_n_padded + int(global_elem_id_raw / M4) * M4;
blockwise_dropout.template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf), blockwise_dropout.template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf),
decltype(z_tenor_buffer), decltype(z_tenor_buffer),
true>( true>(
s_slash_p_thread_buf, ph, global_elem_id, z_tenor_buffer, NRaw); s_slash_p_thread_buf, ph, global_elem_id, z_tenor_buffer, raw_n_padded);
z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
...@@ -1986,16 +1985,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -1986,16 +1985,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
auto m_global = m_local + m_block_data_idx_on_grid; auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid; auto n_global = n_local + n_block_data_idx_on_grid;
auto global_elem_id_raw = auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded +
MRaw * NRaw * g_idx + m_global * NRaw + n_global; // unique element global 1d id n_global; // unique element global 1d id
auto global_elem_id = auto global_elem_id =
(global_elem_id_raw % M4) * NRaw + int(global_elem_id_raw / M4) * M4; (global_elem_id_raw % M4) * raw_n_padded + int(global_elem_id_raw / M4) * M4;
// P_dropped // P_dropped
blockwise_dropout blockwise_dropout
.template ApplyDropoutAttnBwd<decltype(s_slash_p_thread_buf), true>( .template ApplyDropoutAttnBwd<decltype(s_slash_p_thread_buf), true>(
s_slash_p_thread_buf, ph, global_elem_id, NRaw); s_slash_p_thread_buf, ph, global_elem_id, raw_n_padded);
} }
block_sync_lds(); // wait for gemm1 LDS read block_sync_lds(); // wait for gemm1 LDS read
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment