"...composable_kernel_rocm.git" did not exist on "85fc91c3218c1d85169ed1fe95eef7b07942e648"
Commit 77df3ccb authored by letaoqin's avatar letaoqin
Browse files

format

parent 48f98948
...@@ -119,12 +119,13 @@ __global__ void ...@@ -119,12 +119,13 @@ __global__ void
const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded; const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded;
const D0DataType* tmp_p_d0_grid = nullptr; const D0DataType* tmp_p_d0_grid = nullptr;
if constexpr(!is_same<D0DataType,void>::value){ if constexpr(!is_same<D0DataType, void>::value)
{
const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx))); static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx)));
tmp_p_d0_grid = p_d0_grid + d0_batch_offset; tmp_p_d0_grid = p_d0_grid + d0_batch_offset;
} }
if constexpr(Deterministic) if constexpr(Deterministic)
{ {
for(index_t i = 0; i < nblock; i++) for(index_t i = 0; i < nblock; i++)
...@@ -186,7 +187,7 @@ __global__ void ...@@ -186,7 +187,7 @@ __global__ void
c_element_op, c_element_op,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
d0_grid_desc_m0_n0_m1_m2_n1_m3 d0_grid_desc_m0_n0_m1_m2_n1_m3,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
......
...@@ -1191,43 +1191,44 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1191,43 +1191,44 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
typename Block2CTileMap, typename Block2CTileMap,
typename C0MatrixMask, typename C0MatrixMask,
typename YGradGridDesc_M0_O_M1> typename YGradGridDesc_M0_O_M1>
__device__ static void Run(const InputDataType* __restrict__ p_q_grid, __device__ static void
const InputDataType* __restrict__ p_k_grid, Run(const InputDataType* __restrict__ p_q_grid,
const D0DataType* __restrict__ p_d_grid, const InputDataType* __restrict__ p_k_grid,
ZDataType* __restrict__ p_z_grid, const D0DataType* __restrict__ p_d_grid,
const InputDataType* __restrict__ p_v_grid, ZDataType* __restrict__ p_z_grid,
const InputDataType* __restrict__ p_y_grid, const InputDataType* __restrict__ p_v_grid,
const FloatLSE* __restrict__ p_lse_grid, const InputDataType* __restrict__ p_y_grid,
const InputDataType* __restrict__ p_ygrad_grid, const FloatLSE* __restrict__ p_lse_grid,
OutputDataType* __restrict__ p_qgrad_grid, const InputDataType* __restrict__ p_ygrad_grid,
OutputDataType* __restrict__ p_kgrad_grid, OutputDataType* __restrict__ p_qgrad_grid,
OutputDataType* __restrict__ p_vgrad_grid, OutputDataType* __restrict__ p_kgrad_grid,
void* __restrict__ p_shared, OutputDataType* __restrict__ p_vgrad_grid,
const AElementwiseOperation& a_element_op, void* __restrict__ p_shared,
const BElementwiseOperation& b_element_op, const AElementwiseOperation& a_element_op,
const SElementwiseOperation& s_element_op, const BElementwiseOperation& b_element_op,
const B1ElementwiseOperation& b1_element_op, const SElementwiseOperation& s_element_op,
const CElementwiseOperation& c_element_op, const B1ElementwiseOperation& b1_element_op,
const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1, const CElementwiseOperation& c_element_op,
const KGridDesc_K0_N_K1& k_grid_desc_k0_n_k1, const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1,
const D0GridDescriptor_M0_N0_M1_M2_N1_M3& d0_grid_desc_m0_n0_m1_m2_n1_m3, const KGridDesc_K0_N_K1& k_grid_desc_k0_n_k1,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3& const D0GridDescriptor_M0_N0_M1_M2_N1_M3& d0_grid_desc_m0_n0_m1_m2_n1_m3,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3&
const VGridDesc_N0_O_N1& v_grid_desc_n0_o_n1, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock& const VGridDesc_N0_O_N1& v_grid_desc_n0_o_n1,
y_grid_desc_mblock_mperblock_oblock_operblock, const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock&
const LSEGridDesc_M& lse_grid_desc_m, y_grid_desc_mblock_mperblock_oblock_operblock,
const YGradGridDesc_M0_O_M1& ygrad_grid_desc_m0_o_m1, const LSEGridDesc_M& lse_grid_desc_m,
const Block2CTileMap& block_2_ctile_map, const YGradGridDesc_M0_O_M1& ygrad_grid_desc_m0_o_m1,
const C0MatrixMask& c0_matrix_mask, const Block2CTileMap& block_2_ctile_map,
const float p_drop, const C0MatrixMask& c0_matrix_mask,
ck::philox& ph, const float p_drop,
const index_t z_random_matrix_offset, ck::philox& ph,
const index_t raw_n_padded, const index_t z_random_matrix_offset,
const index_t block_idx_n) const index_t raw_n_padded,
const index_t block_idx_n)
{ {
ignore = d0_grid_desc_m0_n0_m1_m2_n1_m3; ignore = d0_grid_desc_m0_n0_m1_m2_n1_m3;
ignore = p_d_grid; ignore = p_d_grid;
const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop); const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop);
const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout); const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
const ushort p_dropout_in_16bits = const ushort p_dropout_in_16bits =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment