Commit 092a88bf authored by guangzlu's avatar guangzlu
Browse files

Move the z random-matrix offset computation into the forward kernel

parent adb8aaaa
...@@ -80,8 +80,8 @@ __global__ void ...@@ -80,8 +80,8 @@ __global__ void
const GemmAccDataType p_dropout_rescale, const GemmAccDataType p_dropout_rescale,
const unsigned long long seed, const unsigned long long seed,
const unsigned long long offset, const unsigned long long offset,
const index_t MRaw, const index_t raw_m_padded,
const index_t NRaw) const index_t raw_n_padded)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
...@@ -102,8 +102,10 @@ __global__ void ...@@ -102,8 +102,10 @@ __global__ void
const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx))); static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx)));
const index_t global_thread_id = get_thread_global_1d_id(); // const index_t global_thread_id = get_thread_global_1d_id();
ck::philox ph(seed, global_thread_id, offset); ck::philox ph(seed, 0, offset);
const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded;
if constexpr(Deterministic) if constexpr(Deterministic)
{ {
...@@ -133,9 +135,8 @@ __global__ void ...@@ -133,9 +135,8 @@ __global__ void
p_dropout_in_16bits, p_dropout_in_16bits,
p_dropout_rescale, p_dropout_rescale,
ph, ph,
g_idx, z_random_matrix_offset,
MRaw, raw_n_padded,
NRaw,
i); i);
} }
} }
...@@ -165,9 +166,8 @@ __global__ void ...@@ -165,9 +166,8 @@ __global__ void
p_dropout_in_16bits, p_dropout_in_16bits,
p_dropout_rescale, p_dropout_rescale,
ph, ph,
g_idx, z_random_matrix_offset,
MRaw, raw_n_padded,
NRaw,
0); 0);
} }
#else #else
......
...@@ -463,9 +463,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -463,9 +463,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
const ushort p_dropout_in_16bits, const ushort p_dropout_in_16bits,
FloatGemmAcc p_dropout_rescale, FloatGemmAcc p_dropout_rescale,
ck::philox& ph, ck::philox& ph,
const index_t g_idx, const index_t z_random_matrix_offset,
const index_t MRaw, const index_t raw_n_padded,
const index_t NRaw,
const index_t block_idx_m) const index_t block_idx_m)
{ {
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
...@@ -1197,8 +1196,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -1197,8 +1196,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
auto m_global = m_local + m_block_data_idx_on_grid; auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid; auto n_global = n_local + n_block_data_idx_on_grid;
auto global_elem_id = auto global_elem_id = z_random_matrix_offset + m_global * raw_n_padded +
MRaw * NRaw * g_idx + m_global * NRaw + n_global; // unique element global 1d id n_global; // unique element global 1d id
// index_t id_step = Acc0TileIterator::GetNumOfAccess() / n0.value; // index_t id_step = Acc0TileIterator::GetNumOfAccess() / n0.value;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment