Commit 7d6a8ec7 authored by guangzlu's avatar guangzlu
Browse files

added dropout to fwd_v2 and bwd_qoop

parent 6d63c311
...@@ -132,6 +132,9 @@ __global__ void ...@@ -132,6 +132,9 @@ __global__ void
arg_ptr[group_id].c0_matrix_mask_, arg_ptr[group_id].c0_matrix_mask_,
p_dropout, p_dropout,
ph, ph,
arg_ptr[group_id].z_random_matrix_offset_ +
g_idx * arg_ptr[group_id].raw_m_padded_ * arg_ptr[group_id].raw_n_padded_,
arg_ptr[group_id].raw_n_padded_,
i); i);
} }
} }
...@@ -165,6 +168,9 @@ __global__ void ...@@ -165,6 +168,9 @@ __global__ void
arg_ptr[group_id].c0_matrix_mask_, arg_ptr[group_id].c0_matrix_mask_,
p_dropout, p_dropout,
ph, ph,
arg_ptr[group_id].z_random_matrix_offset_ +
g_idx * arg_ptr[group_id].raw_m_padded_ * arg_ptr[group_id].raw_n_padded_,
arg_ptr[group_id].raw_n_padded_,
0); 0);
} }
#else #else
...@@ -654,7 +660,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -654,7 +660,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock_; y_grid_desc_mblock_mperblock_oblock_operblock_;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_; c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
LSEGridDesc_M lse_grid_desc_m_; LSEGridDesc_M lse_grid_desc_m_;
KGridDesc_N_K k_grid_desc_n_k_; KGridDesc_N_K k_grid_desc_n_k_;
...@@ -667,6 +673,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -667,6 +673,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
// check C0 masking and padding // check C0 masking and padding
C0MatrixMask c0_matrix_mask_; C0MatrixMask c0_matrix_mask_;
index_t block_start_, block_end_; index_t block_start_, block_end_;
index_t z_random_matrix_offset_;
index_t raw_m_padded_, raw_n_padded_;
}; };
struct GroupDeviceArg struct GroupDeviceArg
...@@ -740,6 +749,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -740,6 +749,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
} }
grid_size_ = 0; grid_size_ = 0;
index_t z_random_matrix_offset = 0;
for(index_t i = 0; i < group_count_; i++) for(index_t i = 0; i < group_count_; i++)
{ {
const auto p_a_grid = static_cast<const InputDataType*>(p_As[i]); const auto p_a_grid = static_cast<const InputDataType*>(p_As[i]);
...@@ -787,7 +799,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -787,7 +799,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
problem_desc.c_gs_ms_gemm1ns_lengths, problem_desc.c_gs_ms_gemm1ns_strides); problem_desc.c_gs_ms_gemm1ns_lengths, problem_desc.c_gs_ms_gemm1ns_strides);
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock; y_grid_desc_mblock_mperblock_oblock_operblock;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5; c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
const index_t BlockStart = grid_size_; const index_t BlockStart = grid_size_;
const auto block_2_ctile_map = Block2CTileMap(k_grid_desc_n_k, BlockStart); const auto block_2_ctile_map = Block2CTileMap(k_grid_desc_n_k, BlockStart);
...@@ -802,7 +814,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -802,7 +814,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
} }
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 = c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5( GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(
z_grid_desc_m_n); z_grid_desc_m_n);
const index_t batch_count = c_grid_desc_g_m_n.GetLength(I0); const index_t batch_count = c_grid_desc_g_m_n.GetLength(I0);
...@@ -836,6 +848,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -836,6 +848,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
"match that in template argument"); "match that in template argument");
} }
const auto raw_m_padded = GridwiseGemm::GetPaddedSize(
problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1]);
const auto raw_n_padded = GridwiseGemm::GetPaddedSize(
problem_desc.b_gs_ns_ks_lengths[NumDimG + NumDimN - 1]);
group_kernel_args_.push_back({p_a_grid, group_kernel_args_.push_back({p_a_grid,
p_b_grid, p_b_grid,
p_z_grid, p_z_grid,
...@@ -861,7 +878,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -861,7 +878,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
compute_base_ptr_of_batch, compute_base_ptr_of_batch,
c0_matrix_mask, c0_matrix_mask,
BlockStart, BlockStart,
BlockEnd}); BlockEnd,
z_random_matrix_offset,
raw_m_padded,
raw_n_padded});
z_random_matrix_offset =
z_random_matrix_offset + raw_m_padded * raw_n_padded * batch_count;
group_device_args_.push_back( group_device_args_.push_back(
{{problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1], {{problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
......
...@@ -132,6 +132,9 @@ __global__ void ...@@ -132,6 +132,9 @@ __global__ void
arg_ptr[group_id].c0_matrix_mask_, arg_ptr[group_id].c0_matrix_mask_,
p_dropout, p_dropout,
ph, ph,
arg_ptr[group_id].z_random_matrix_offset_ +
g_idx * arg_ptr[group_id].raw_m_padded_ * arg_ptr[group_id].raw_n_padded_,
arg_ptr[group_id].raw_n_padded_,
i); i);
} }
} }
...@@ -165,6 +168,9 @@ __global__ void ...@@ -165,6 +168,9 @@ __global__ void
arg_ptr[group_id].c0_matrix_mask_, arg_ptr[group_id].c0_matrix_mask_,
p_dropout, p_dropout,
ph, ph,
arg_ptr[group_id].z_random_matrix_offset_ +
g_idx * arg_ptr[group_id].raw_m_padded_ * arg_ptr[group_id].raw_n_padded_,
arg_ptr[group_id].raw_n_padded_,
0); 0);
} }
#else #else
...@@ -662,7 +668,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -662,7 +668,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock_; y_grid_desc_mblock_mperblock_oblock_operblock_;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_; c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
LSEGridDesc_M lse_grid_desc_m_; LSEGridDesc_M lse_grid_desc_m_;
KGridDesc_N_K k_grid_desc_n_k_; KGridDesc_N_K k_grid_desc_n_k_;
...@@ -675,6 +681,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -675,6 +681,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
// check C0 masking and padding // check C0 masking and padding
C0MatrixMask c0_matrix_mask_; C0MatrixMask c0_matrix_mask_;
index_t block_start_, block_end_; index_t block_start_, block_end_;
index_t z_random_matrix_offset_;
index_t raw_m_padded_, raw_n_padded_;
}; };
struct GroupDeviceArg struct GroupDeviceArg
...@@ -748,6 +757,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -748,6 +757,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
} }
grid_size_ = 0; grid_size_ = 0;
index_t z_random_matrix_offset = 0;
for(index_t i = 0; i < group_count_; i++) for(index_t i = 0; i < group_count_; i++)
{ {
const auto p_a_grid = static_cast<const InputDataType*>(p_As[i]); const auto p_a_grid = static_cast<const InputDataType*>(p_As[i]);
...@@ -795,7 +807,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -795,7 +807,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
problem_desc.c_gs_ms_gemm1ns_lengths, problem_desc.c_gs_ms_gemm1ns_strides); problem_desc.c_gs_ms_gemm1ns_lengths, problem_desc.c_gs_ms_gemm1ns_strides);
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock; y_grid_desc_mblock_mperblock_oblock_operblock;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5; c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
const index_t BlockStart = grid_size_; const index_t BlockStart = grid_size_;
const auto block_2_ctile_map = Block2CTileMap(k_grid_desc_n_k, BlockStart); const auto block_2_ctile_map = Block2CTileMap(k_grid_desc_n_k, BlockStart);
...@@ -810,7 +822,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -810,7 +822,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
} }
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 = c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5( GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(
z_grid_desc_m_n); z_grid_desc_m_n);
const index_t batch_count = c_grid_desc_g_m_n.GetLength(I0); const index_t batch_count = c_grid_desc_g_m_n.GetLength(I0);
...@@ -844,6 +856,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -844,6 +856,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
"match that in template argument"); "match that in template argument");
} }
const auto raw_m_padded = GridwiseGemm::GetPaddedSize(
problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1]);
const auto raw_n_padded = GridwiseGemm::GetPaddedSize(
problem_desc.b_gs_ns_ks_lengths[NumDimG + NumDimN - 1]);
group_kernel_args_.push_back({p_a_grid, group_kernel_args_.push_back({p_a_grid,
p_b_grid, p_b_grid,
p_z_grid, p_z_grid,
...@@ -869,7 +886,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -869,7 +886,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
compute_base_ptr_of_batch, compute_base_ptr_of_batch,
c0_matrix_mask, c0_matrix_mask,
BlockStart, BlockStart,
BlockEnd}); BlockEnd,
z_random_matrix_offset,
raw_m_padded,
raw_n_padded});
z_random_matrix_offset =
z_random_matrix_offset + raw_m_padded * raw_n_padded * batch_count;
group_device_args_.push_back( group_device_args_.push_back(
{{problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1], {{problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
......
...@@ -132,6 +132,9 @@ __global__ void ...@@ -132,6 +132,9 @@ __global__ void
p_dropout_in_16bits, p_dropout_in_16bits,
p_dropout_rescale, p_dropout_rescale,
ph, ph,
arg_ptr[group_id].z_random_matrix_offset_ +
g_idx * arg_ptr[group_id].raw_m_padded_ * arg_ptr[group_id].raw_n_padded_,
arg_ptr[group_id].raw_n_padded_,
i); i);
} }
} }
...@@ -165,6 +168,9 @@ __global__ void ...@@ -165,6 +168,9 @@ __global__ void
p_dropout_in_16bits, p_dropout_in_16bits,
p_dropout_rescale, p_dropout_rescale,
ph, ph,
arg_ptr[group_id].z_random_matrix_offset_ +
g_idx * arg_ptr[group_id].raw_m_padded_ * arg_ptr[group_id].raw_n_padded_,
arg_ptr[group_id].raw_n_padded_,
0); 0);
} }
#else #else
...@@ -567,6 +573,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle ...@@ -567,6 +573,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
Block2CTileMap block_2_ctile_map_; Block2CTileMap block_2_ctile_map_;
index_t block_start_, block_end_; index_t block_start_, block_end_;
index_t z_random_matrix_offset_;
index_t raw_m_padded_, raw_n_padded_;
}; };
struct GroupDeviceArg struct GroupDeviceArg
...@@ -626,6 +635,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle ...@@ -626,6 +635,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
grid_size_ = 0; grid_size_ = 0;
index_t z_random_matrix_offset = 0;
for(std::size_t i = 0; i < group_count_; i++) for(std::size_t i = 0; i < group_count_; i++)
{ {
const auto p_a_grid = static_cast<const ADataType*>(p_a_vec[i]); const auto p_a_grid = static_cast<const ADataType*>(p_a_vec[i]);
...@@ -712,6 +723,11 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle ...@@ -712,6 +723,11 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
"match that in template argument"); "match that in template argument");
} }
const auto raw_m_padded = GridwiseGemm::GetPaddedSize(
problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1]);
const auto raw_n_padded = GridwiseGemm::GetPaddedSize(
problem_desc.b0_gs_ns_ks_lengths[NumDimG + NumDimN - 1]);
group_kernel_args_.push_back({p_a_grid, group_kernel_args_.push_back({p_a_grid,
p_b_grid, p_b_grid,
p_b1_grid, p_b1_grid,
...@@ -730,7 +746,13 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle ...@@ -730,7 +746,13 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
c0_matrix_mask, c0_matrix_mask,
block_2_ctile_map, block_2_ctile_map,
BlockStart, BlockStart,
BlockEnd}); BlockEnd,
z_random_matrix_offset,
raw_m_padded,
raw_n_padded});
z_random_matrix_offset =
z_random_matrix_offset + raw_m_padded * raw_n_padded * batch_count;
group_device_args_.push_back( group_device_args_.push_back(
{{problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1], {{problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment