Commit 9679ba63 authored by letaoqin

add verification of d0 vector load

parent 5d6bfabb
@@ -358,6 +358,15 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
B1Spec,
CSpec>;
using RawTransform = TransformBatchedContractionContractionToBatchedGemmGemm<
Sequence<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO>,
Sequence<MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock>,
GemmSpecialization::Default,
ASpec,
BSpec,
B1Spec,
CSpec>;
static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
const std::vector<index_t>& a_gs_ms_ks_strides_vec)
{
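The new RawTransform alias differs from the padded Transform only in fixing GemmSpecialization::Default, so the descriptor it builds reports unpadded extents. A minimal standalone sketch of the padded-vs-raw distinction, with round_up as a hypothetical helper rather than the CK API:

```cpp
#include <cstdint>
#include <iostream>

// Hypothetical helper (not the CK implementation): a padding GemmSpecialization
// rounds an extent up to the next multiple of the block tile, while
// GemmSpecialization::Default keeps the raw length. RawTransform exists so the
// raw (unpadded) N of the D0 bias tensor can be read back for the vector-load check.
int64_t round_up(int64_t raw, int64_t per_block)
{
    return (raw + per_block - 1) / per_block * per_block;
}

int main()
{
    const int64_t raw_n = 1000, n_per_block = 128;
    std::cout << "padded N: " << round_up(raw_n, n_per_block) << '\n'; // 1024
    std::cout << "raw N:    " << raw_n << '\n';                        // 1000
}
```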
@@ -563,8 +572,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
true,
Acc0BiasTransferSrcScalarPerVector,
BBlockLdsExtraN,
Acc0BiasTransferSrcScalarPerVector,
B1BlockTransferThreadClusterLengths_BK0_N_BK1,
B1BlockTransferThreadClusterArrangeOrder,
B1BlockTransferSrcAccessOrder,
@@ -639,6 +648,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
// for gridwise gemm check
CGridDesc_M_N c_grid_desc_m_n_;
// raw (unpadded) N extent of the D0 bias tensor, checked against the vector load width
int raw_d0_n_;
};
// Argument
@@ -820,6 +832,11 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
z_random_matrix_offset =
z_random_matrix_offset + raw_m_padded * raw_n_padded * batch_count;
auto raw_d0_m_n = NumD0Tensor == 0
? RawTransform::MakeCGridDescriptor_M_N({}, {})
: RawTransform::MakeCGridDescriptor_M_N(
problem_desc.acc0_biases_gs_ms_ns_lengths[0],
problem_desc.acc0_biases_gs_ms_ns_strides[0]);
group_device_args_.push_back(
{{problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
problem_desc.b0_gs_ns_ks_lengths[NumDimG + NumDimN - 1],
@@ -833,7 +850,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
problem_desc.b1_gs_os_ns_strides[NumDimG + NumDimO + NumDimN - 1]},
{problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM - 1],
problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM + NumDimO - 1]},
c_grid_desc_m_n});
c_grid_desc_m_n,
NumD0Tensor == 0 ? 0 : raw_d0_m_n.GetLength(I1)});
}
is_dropout_ = p_dropout > 0.0; //
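The argument constructor now records the raw N length of the first D0 bias tensor, or 0 when no bias is present so that the later divisibility check passes trivially. A self-contained sketch of that capture logic, using hypothetical stand-in types for the CK descriptors:

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for an M x N grid descriptor.
struct DescMN
{
    std::array<int64_t, 2> lengths; // {M, N}
    int64_t GetLength(int dim) const { return lengths[dim]; }
};

// Mirrors the ternary above: with no D0 tensor, store 0 (0 % v == 0, so the
// support check is trivially satisfied); otherwise store the raw N length
// (dimension index 1) of the first bias descriptor.
int64_t capture_raw_d0_n(int num_d0_tensors, const std::vector<DescMN>& d0_descs)
{
    return num_d0_tensors == 0 ? 0 : d0_descs[0].GetLength(1);
}

int main()
{
    const DescMN d0{{128, 1000}};                  // M = 128, N = 1000
    assert(capture_raw_d0_n(0, {}) == 0);          // no bias tensor present
    assert(capture_raw_d0_n(1, {d0}) == 1000);     // raw N of the first bias
}
```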
@@ -1048,6 +1066,12 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
const index_t c_gemm1n = device_arg.c_grid_desc_m_n_.GetLength(I1);
const index_t a_m = kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
const index_t b1_gemm1n = kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1);
// the raw D0 N extent must be a multiple of the vector load width
if(device_arg.raw_d0_n_ % Acc0BiasTransferSrcScalarPerVector != 0)
{
return false;
}
if(!(c_m == a_m && c_gemm1n == b1_gemm1n))
{
return false;
......
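IsSupportedArgument now rejects configurations where the raw D0 N extent is not a multiple of the vector load width, since the last vector in each row would otherwise straddle the row boundary. A standalone sketch of the check, with a hypothetical free function in place of the member logic:

```cpp
#include <cassert>
#include <cstdint>

// A vectorized load of vector_size elements along N is only safe when the raw
// (unpadded) N extent divides evenly; a partial tail vector would read beyond
// the row.
bool is_d0_vector_load_supported(int64_t raw_d0_n, int64_t vector_size)
{
    return raw_d0_n % vector_size == 0;
}

int main()
{
    assert(is_d0_vector_load_supported(1024, 4));  // 256 full vectors per row
    assert(!is_d0_vector_load_supported(1022, 4)); // 2-element tail would misalign
    assert(is_d0_vector_load_supported(0, 4));     // no D0 tensor: trivially ok
}
```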
@@ -95,6 +95,10 @@ template <typename FloatAB,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
{
static_assert(D0BlockTransferSrcScalarPerVector == 1 ||
D0BlockTransferSrcScalarPerVector == 2 ||
D0BlockTransferSrcScalarPerVector == 4,
"D0BlockTransferSrcScalarPerVector must be 1 or 2 or 4");
using DDataType = FloatAB;
static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported");
......
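The gridwise kernel additionally rejects unsupported vector widths at compile time, so a bad instantiation fails to build rather than misbehaving at run time. A minimal sketch of the same static_assert pattern, using a hypothetical template in place of the full kernel:

```cpp
// Hypothetical reduced kernel: only vector widths of 1, 2, or 4 are accepted,
// matching the widths a dword/dwordx2/dwordx4-style load can service.
template <int D0VectorSize>
struct GridwiseSketch
{
    static_assert(D0VectorSize == 1 || D0VectorSize == 2 || D0VectorSize == 4,
                  "D0VectorSize must be 1, 2, or 4");
};

int main()
{
    GridwiseSketch<4> ok; // compiles: 4 is an allowed width
    (void)ok;
    // GridwiseSketch<3> bad; // would fail to compile with the message above
}
```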