"...composable_kernel_rocm.git" did not exist on "8ce410341083f3786b8001069205fbbee8c68e63"
Commit 043c8ff3 authored by guangzlu's avatar guangzlu
Browse files

Changed some formatting

parent 2299a4f1
......@@ -100,15 +100,12 @@ __global__ void
const index_t global_thread_id = get_thread_global_1d_id();
ck::philox ph(seed, global_thread_id, offset);
unsigned short* p_z_grid_in = //
p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset;
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
p_b1_grid + b1_batch_offset,
p_c_grid + c_batch_offset,
p_z_grid_in,
nullptr ? nullptr : p_z_grid + z_batch_offset,
p_lse_grid + lse_batch_offset,
p_shared,
a_element_op,
......
......@@ -97,18 +97,17 @@ __global__ void
const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetLSEBasePtr(g_idx)));
unsigned short* p_z_grid_in = //
(arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
: arg_ptr[group_id].p_z_grid_ + z_batch_offset);
// unsigned short* p_z_grid_in = //
// (arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
// : arg_ptr[group_id].p_z_grid_ + z_batch_offset);
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset,
arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
arg_ptr[group_id].p_c_grid_ + c_batch_offset,
p_z_grid_in,
// arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
// : arg_ptr[group_id].p_z_grid_ + z_batch_offset,
arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
: arg_ptr[group_id].p_z_grid_ + z_batch_offset,
arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
p_shared,
a_element_op,
......@@ -120,7 +119,7 @@ __global__ void
arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_, ////////
arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].block_2_ctile_map_,
arg_ptr[group_id].c0_matrix_mask_,
......
......@@ -840,7 +840,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
// gemm1 K loop
index_t gemm1_k_block_outer_index = 0;
///////////////////=>z for dropout
// z is the random-number matrix used to verify dropout
//
// z vgpr copy to global
//
......@@ -905,8 +905,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
0),
tensor_operation::element_wise::PassThrough{}};
///////////////////=>z for dropout
do
{
auto n_block_data_idx_on_grid =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment