"benchmark/git@developer.sourcefind.cn:change/sglang.git" did not exist on "40e53d65cbb8b609a6ff8e977d2318044d0f0ee0"
Commit 0413755a authored by guangzlu's avatar guangzlu
Browse files

modified offset to bwd pt6

parent 092a88bf
......@@ -86,8 +86,8 @@ __global__ void
const float p_drop,
const unsigned long long seed,
const unsigned long long offset,
const index_t MRaw,
const index_t NRaw)
const index_t raw_m_padded,
const index_t raw_n_padded)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
......@@ -110,10 +110,11 @@ __global__ void
const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx)));
const index_t global_thread_id = get_thread_global_1d_id();
ck::philox ph(seed, global_thread_id, offset);
ck::philox ph(seed, 0, offset);
ZDataType* z_matrix_ptr = (p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset);
const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded;
if constexpr(Deterministic)
{
for(index_t i = 0; i < nblock; i++)
......@@ -146,9 +147,8 @@ __global__ void
c0_matrix_mask,
p_drop,
ph,
g_idx,
MRaw,
NRaw,
z_random_matrix_offset,
raw_n_padded,
i);
}
}
......@@ -181,9 +181,8 @@ __global__ void
c0_matrix_mask,
p_drop,
ph,
g_idx,
MRaw,
NRaw,
z_random_matrix_offset,
raw_n_padded,
0);
}
#else
......
......@@ -1254,9 +1254,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
const C0MatrixMask& c0_matrix_mask,
const float p_drop,
ck::philox& ph,
const index_t g_idx,
const index_t MRaw,
const index_t NRaw,
const index_t z_random_matrix_offset,
const index_t raw_n_padded,
const index_t block_idx_n)
{
const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop);
......@@ -1959,16 +1958,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
auto global_elem_id_raw =
MRaw * NRaw * g_idx + m_global * NRaw + n_global; // unique element global 1d id
auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded +
n_global; // unique element global 1d id
auto global_elem_id =
(global_elem_id_raw % M4) * NRaw + int(global_elem_id_raw / M4) * M4;
(global_elem_id_raw % M4) * raw_n_padded + int(global_elem_id_raw / M4) * M4;
blockwise_dropout.template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf),
decltype(z_tenor_buffer),
true>(
s_slash_p_thread_buf, ph, global_elem_id, z_tenor_buffer, NRaw);
s_slash_p_thread_buf, ph, global_elem_id, z_tenor_buffer, raw_n_padded);
z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
......@@ -1986,16 +1985,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
auto global_elem_id_raw =
MRaw * NRaw * g_idx + m_global * NRaw + n_global; // unique element global 1d id
auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded +
n_global; // unique element global 1d id
auto global_elem_id =
(global_elem_id_raw % M4) * NRaw + int(global_elem_id_raw / M4) * M4;
(global_elem_id_raw % M4) * raw_n_padded + int(global_elem_id_raw / M4) * M4;
// P_dropped
blockwise_dropout
.template ApplyDropoutAttnBwd<decltype(s_slash_p_thread_buf), true>(
s_slash_p_thread_buf, ph, global_elem_id, NRaw);
s_slash_p_thread_buf, ph, global_elem_id, raw_n_padded);
}
block_sync_lds(); // wait for gemm1 LDS read
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment