Commit 092a88bf authored by guangzlu
Browse files

Port the modified z offset computation into the forward (fwd) kernel

parent adb8aaaa
......@@ -80,8 +80,8 @@ __global__ void
const GemmAccDataType p_dropout_rescale,
const unsigned long long seed,
const unsigned long long offset,
const index_t MRaw,
const index_t NRaw)
const index_t raw_m_padded,
const index_t raw_n_padded)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
......@@ -102,8 +102,10 @@ __global__ void
const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx)));
const index_t global_thread_id = get_thread_global_1d_id();
ck::philox ph(seed, global_thread_id, offset);
// const index_t global_thread_id = get_thread_global_1d_id();
ck::philox ph(seed, 0, offset);
const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded;
if constexpr(Deterministic)
{
......@@ -133,9 +135,8 @@ __global__ void
p_dropout_in_16bits,
p_dropout_rescale,
ph,
g_idx,
MRaw,
NRaw,
z_random_matrix_offset,
raw_n_padded,
i);
}
}
......@@ -165,9 +166,8 @@ __global__ void
p_dropout_in_16bits,
p_dropout_rescale,
ph,
g_idx,
MRaw,
NRaw,
z_random_matrix_offset,
raw_n_padded,
0);
}
#else
......
......@@ -463,9 +463,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
const ushort p_dropout_in_16bits,
FloatGemmAcc p_dropout_rescale,
ck::philox& ph,
const index_t g_idx,
const index_t MRaw,
const index_t NRaw,
const index_t z_random_matrix_offset,
const index_t raw_n_padded,
const index_t block_idx_m)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
......@@ -1197,8 +1196,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
auto global_elem_id =
MRaw * NRaw * g_idx + m_global * NRaw + n_global; // unique element global 1d id
auto global_elem_id = z_random_matrix_offset + m_global * raw_n_padded +
n_global; // unique element global 1d id
// index_t id_step = Acc0TileIterator::GetNumOfAccess() / n0.value;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment