Commit 48f98948 authored by letaoqin's avatar letaoqin
Browse files

add code to device

parent 79cf90f2
...@@ -104,8 +104,6 @@ __global__ void ...@@ -104,8 +104,6 @@ __global__ void
static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx))); static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetBBasePtr(g_idx))); static_cast<long_index_t>(compute_base_ptr_of_batch.GetBBasePtr(g_idx)));
const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx)));
const long_index_t z_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t z_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetZBasePtr(g_idx))); static_cast<long_index_t>(compute_base_ptr_of_batch.GetZBasePtr(g_idx)));
const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
...@@ -120,9 +118,13 @@ __global__ void ...@@ -120,9 +118,13 @@ __global__ void
const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded; const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded;
ignore = p_d0_grid; const D0DataType* tmp_p_d0_grid = nullptr;
ignore = d0_grid_desc_m0_n0_m1_m2_n1_m3; if constexpr(!is_same<D0DataType,void>::value){
ignore = d0_batch_offset; const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx)));
tmp_p_d0_grid = p_d0_grid + d0_batch_offset;
}
if constexpr(Deterministic) if constexpr(Deterministic)
{ {
for(index_t i = 0; i < nblock; i++) for(index_t i = 0; i < nblock; i++)
...@@ -130,6 +132,7 @@ __global__ void ...@@ -130,6 +132,7 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>( GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
p_a_grid + a_batch_offset, p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset, p_b_grid + b_batch_offset,
tmp_p_d0_grid,
z_matrix_ptr, z_matrix_ptr,
p_b1_grid + b1_batch_offset, p_b1_grid + b1_batch_offset,
p_c_grid + c_batch_offset, p_c_grid + c_batch_offset,
...@@ -146,6 +149,7 @@ __global__ void ...@@ -146,6 +149,7 @@ __global__ void
c_element_op, c_element_op,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
d0_grid_desc_m0_n0_m1_m2_n1_m3,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
...@@ -165,6 +169,7 @@ __global__ void ...@@ -165,6 +169,7 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>( GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
p_a_grid + a_batch_offset, p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset, p_b_grid + b_batch_offset,
tmp_p_d0_grid,
z_matrix_ptr, z_matrix_ptr,
p_b1_grid + b1_batch_offset, p_b1_grid + b1_batch_offset,
p_c_grid + c_batch_offset, p_c_grid + c_batch_offset,
...@@ -181,6 +186,7 @@ __global__ void ...@@ -181,6 +186,7 @@ __global__ void
c_element_op, c_element_op,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
d0_grid_desc_m0_n0_m1_m2_n1_m3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
......
...@@ -1193,6 +1193,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1193,6 +1193,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
typename YGradGridDesc_M0_O_M1> typename YGradGridDesc_M0_O_M1>
__device__ static void Run(const InputDataType* __restrict__ p_q_grid, __device__ static void Run(const InputDataType* __restrict__ p_q_grid,
const InputDataType* __restrict__ p_k_grid, const InputDataType* __restrict__ p_k_grid,
const D0DataType* __restrict__ p_d_grid,
ZDataType* __restrict__ p_z_grid, ZDataType* __restrict__ p_z_grid,
const InputDataType* __restrict__ p_v_grid, const InputDataType* __restrict__ p_v_grid,
const InputDataType* __restrict__ p_y_grid, const InputDataType* __restrict__ p_y_grid,
...@@ -1209,6 +1210,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1209,6 +1210,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const CElementwiseOperation& c_element_op, const CElementwiseOperation& c_element_op,
const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1, const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1,
const KGridDesc_K0_N_K1& k_grid_desc_k0_n_k1, const KGridDesc_K0_N_K1& k_grid_desc_k0_n_k1,
const D0GridDescriptor_M0_N0_M1_M2_N1_M3& d0_grid_desc_m0_n0_m1_m2_n1_m3,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3& const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
const VGridDesc_N0_O_N1& v_grid_desc_n0_o_n1, const VGridDesc_N0_O_N1& v_grid_desc_n0_o_n1,
...@@ -1224,6 +1226,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1224,6 +1226,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const index_t raw_n_padded, const index_t raw_n_padded,
const index_t block_idx_n) const index_t block_idx_n)
{ {
ignore = d0_grid_desc_m0_n0_m1_m2_n1_m3;
ignore = p_d_grid;
const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop); const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop);
const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout); const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
const ushort p_dropout_in_16bits = const ushort p_dropout_in_16bits =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment