Commit 7d6a8ec7 authored by guangzlu's avatar guangzlu
Browse files

added dropout to fwd_v2 and bwd_qoop

parent 6d63c311
......@@ -132,6 +132,9 @@ __global__ void
arg_ptr[group_id].c0_matrix_mask_,
p_dropout,
ph,
arg_ptr[group_id].z_random_matrix_offset_ +
g_idx * arg_ptr[group_id].raw_m_padded_ * arg_ptr[group_id].raw_n_padded_,
arg_ptr[group_id].raw_n_padded_,
i);
}
}
......@@ -165,6 +168,9 @@ __global__ void
arg_ptr[group_id].c0_matrix_mask_,
p_dropout,
ph,
arg_ptr[group_id].z_random_matrix_offset_ +
g_idx * arg_ptr[group_id].raw_m_padded_ * arg_ptr[group_id].raw_n_padded_,
arg_ptr[group_id].raw_n_padded_,
0);
}
#else
......@@ -654,7 +660,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock_;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
LSEGridDesc_M lse_grid_desc_m_;
KGridDesc_N_K k_grid_desc_n_k_;
......@@ -667,6 +673,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
// check C0 masking and padding
C0MatrixMask c0_matrix_mask_;
index_t block_start_, block_end_;
index_t z_random_matrix_offset_;
index_t raw_m_padded_, raw_n_padded_;
};
struct GroupDeviceArg
......@@ -740,6 +749,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
}
grid_size_ = 0;
index_t z_random_matrix_offset = 0;
for(index_t i = 0; i < group_count_; i++)
{
const auto p_a_grid = static_cast<const InputDataType*>(p_As[i]);
......@@ -787,7 +799,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
problem_desc.c_gs_ms_gemm1ns_lengths, problem_desc.c_gs_ms_gemm1ns_strides);
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
const index_t BlockStart = grid_size_;
const auto block_2_ctile_map = Block2CTileMap(k_grid_desc_n_k, BlockStart);
......@@ -802,7 +814,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
}
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(
z_grid_desc_m_n);
const index_t batch_count = c_grid_desc_g_m_n.GetLength(I0);
......@@ -836,6 +848,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
"match that in template argument");
}
const auto raw_m_padded = GridwiseGemm::GetPaddedSize(
problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1]);
const auto raw_n_padded = GridwiseGemm::GetPaddedSize(
problem_desc.b_gs_ns_ks_lengths[NumDimG + NumDimN - 1]);
group_kernel_args_.push_back({p_a_grid,
p_b_grid,
p_z_grid,
......@@ -861,7 +878,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
compute_base_ptr_of_batch,
c0_matrix_mask,
BlockStart,
BlockEnd});
BlockEnd,
z_random_matrix_offset,
raw_m_padded,
raw_n_padded});
z_random_matrix_offset =
z_random_matrix_offset + raw_m_padded * raw_n_padded * batch_count;
group_device_args_.push_back(
{{problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
......
......@@ -132,6 +132,9 @@ __global__ void
arg_ptr[group_id].c0_matrix_mask_,
p_dropout,
ph,
arg_ptr[group_id].z_random_matrix_offset_ +
g_idx * arg_ptr[group_id].raw_m_padded_ * arg_ptr[group_id].raw_n_padded_,
arg_ptr[group_id].raw_n_padded_,
i);
}
}
......@@ -165,6 +168,9 @@ __global__ void
arg_ptr[group_id].c0_matrix_mask_,
p_dropout,
ph,
arg_ptr[group_id].z_random_matrix_offset_ +
g_idx * arg_ptr[group_id].raw_m_padded_ * arg_ptr[group_id].raw_n_padded_,
arg_ptr[group_id].raw_n_padded_,
0);
}
#else
......@@ -662,7 +668,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock_;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
LSEGridDesc_M lse_grid_desc_m_;
KGridDesc_N_K k_grid_desc_n_k_;
......@@ -675,6 +681,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
// check C0 masking and padding
C0MatrixMask c0_matrix_mask_;
index_t block_start_, block_end_;
index_t z_random_matrix_offset_;
index_t raw_m_padded_, raw_n_padded_;
};
struct GroupDeviceArg
......@@ -748,6 +757,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
}
grid_size_ = 0;
index_t z_random_matrix_offset = 0;
for(index_t i = 0; i < group_count_; i++)
{
const auto p_a_grid = static_cast<const InputDataType*>(p_As[i]);
......@@ -795,7 +807,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
problem_desc.c_gs_ms_gemm1ns_lengths, problem_desc.c_gs_ms_gemm1ns_strides);
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
const index_t BlockStart = grid_size_;
const auto block_2_ctile_map = Block2CTileMap(k_grid_desc_n_k, BlockStart);
......@@ -810,7 +822,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
}
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(
z_grid_desc_m_n);
const index_t batch_count = c_grid_desc_g_m_n.GetLength(I0);
......@@ -844,6 +856,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
"match that in template argument");
}
const auto raw_m_padded = GridwiseGemm::GetPaddedSize(
problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1]);
const auto raw_n_padded = GridwiseGemm::GetPaddedSize(
problem_desc.b_gs_ns_ks_lengths[NumDimG + NumDimN - 1]);
group_kernel_args_.push_back({p_a_grid,
p_b_grid,
p_z_grid,
......@@ -869,7 +886,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
compute_base_ptr_of_batch,
c0_matrix_mask,
BlockStart,
BlockEnd});
BlockEnd,
z_random_matrix_offset,
raw_m_padded,
raw_n_padded});
z_random_matrix_offset =
z_random_matrix_offset + raw_m_padded * raw_n_padded * batch_count;
group_device_args_.push_back(
{{problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
......
......@@ -132,6 +132,9 @@ __global__ void
p_dropout_in_16bits,
p_dropout_rescale,
ph,
arg_ptr[group_id].z_random_matrix_offset_ +
g_idx * arg_ptr[group_id].raw_m_padded_ * arg_ptr[group_id].raw_n_padded_,
arg_ptr[group_id].raw_n_padded_,
i);
}
}
......@@ -165,6 +168,9 @@ __global__ void
p_dropout_in_16bits,
p_dropout_rescale,
ph,
arg_ptr[group_id].z_random_matrix_offset_ +
g_idx * arg_ptr[group_id].raw_m_padded_ * arg_ptr[group_id].raw_n_padded_,
arg_ptr[group_id].raw_n_padded_,
0);
}
#else
......@@ -567,6 +573,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
Block2CTileMap block_2_ctile_map_;
index_t block_start_, block_end_;
index_t z_random_matrix_offset_;
index_t raw_m_padded_, raw_n_padded_;
};
struct GroupDeviceArg
......@@ -626,6 +635,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
grid_size_ = 0;
index_t z_random_matrix_offset = 0;
for(std::size_t i = 0; i < group_count_; i++)
{
const auto p_a_grid = static_cast<const ADataType*>(p_a_vec[i]);
......@@ -712,6 +723,11 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
"match that in template argument");
}
const auto raw_m_padded = GridwiseGemm::GetPaddedSize(
problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1]);
const auto raw_n_padded = GridwiseGemm::GetPaddedSize(
problem_desc.b0_gs_ns_ks_lengths[NumDimG + NumDimN - 1]);
group_kernel_args_.push_back({p_a_grid,
p_b_grid,
p_b1_grid,
......@@ -730,7 +746,13 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
c0_matrix_mask,
block_2_ctile_map,
BlockStart,
BlockEnd});
BlockEnd,
z_random_matrix_offset,
raw_m_padded,
raw_n_padded});
z_random_matrix_offset =
z_random_matrix_offset + raw_m_padded * raw_n_padded * batch_count;
group_device_args_.push_back(
{{problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment