Commit 77df3ccb authored by letaoqin's avatar letaoqin
Browse files

format

parent 48f98948
......@@ -119,7 +119,8 @@ __global__ void
const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded;
const D0DataType* tmp_p_d0_grid = nullptr;
if constexpr(!is_same<D0DataType,void>::value){
if constexpr(!is_same<D0DataType, void>::value)
{
const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx)));
tmp_p_d0_grid = p_d0_grid + d0_batch_offset;
......@@ -186,7 +187,7 @@ __global__ void
c_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
d0_grid_desc_m0_n0_m1_m2_n1_m3
d0_grid_desc_m0_n0_m1_m2_n1_m3,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
......
......@@ -1191,7 +1191,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
typename Block2CTileMap,
typename C0MatrixMask,
typename YGradGridDesc_M0_O_M1>
__device__ static void Run(const InputDataType* __restrict__ p_q_grid,
__device__ static void
Run(const InputDataType* __restrict__ p_q_grid,
const InputDataType* __restrict__ p_k_grid,
const D0DataType* __restrict__ p_d_grid,
ZDataType* __restrict__ p_z_grid,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment