Commit 77df3ccb authored by letaoqin's avatar letaoqin
Browse files

format

parent 48f98948
...@@ -119,7 +119,8 @@ __global__ void ...@@ -119,7 +119,8 @@ __global__ void
const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded; const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded;
const D0DataType* tmp_p_d0_grid = nullptr; const D0DataType* tmp_p_d0_grid = nullptr;
if constexpr(!is_same<D0DataType,void>::value){ if constexpr(!is_same<D0DataType, void>::value)
{
const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx))); static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx)));
tmp_p_d0_grid = p_d0_grid + d0_batch_offset; tmp_p_d0_grid = p_d0_grid + d0_batch_offset;
...@@ -186,7 +187,7 @@ __global__ void ...@@ -186,7 +187,7 @@ __global__ void
c_element_op, c_element_op,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
d0_grid_desc_m0_n0_m1_m2_n1_m3 d0_grid_desc_m0_n0_m1_m2_n1_m3,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
......
...@@ -1191,7 +1191,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1191,7 +1191,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
typename Block2CTileMap, typename Block2CTileMap,
typename C0MatrixMask, typename C0MatrixMask,
typename YGradGridDesc_M0_O_M1> typename YGradGridDesc_M0_O_M1>
__device__ static void Run(const InputDataType* __restrict__ p_q_grid, __device__ static void
Run(const InputDataType* __restrict__ p_q_grid,
const InputDataType* __restrict__ p_k_grid, const InputDataType* __restrict__ p_k_grid,
const D0DataType* __restrict__ p_d_grid, const D0DataType* __restrict__ p_d_grid,
ZDataType* __restrict__ p_z_grid, ZDataType* __restrict__ p_z_grid,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment