"test/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "7d0f33699c5ebd60de028ee95c0faf287f764510"
Commit 9679ba63 authored by letaoqin's avatar letaoqin
Browse files

add verify d0 vector load

parent 5d6bfabb
...@@ -358,6 +358,15 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2 ...@@ -358,6 +358,15 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
B1Spec, B1Spec,
CSpec>; CSpec>;
using RawTransform = TransformBatchedContractionContractionToBatchedGemmGemm<
Sequence<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO>,
Sequence<MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock>,
GemmSpecialization::Default,
ASpec,
BSpec,
B1Spec,
CSpec>;
static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths_vec, static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
const std::vector<index_t>& a_gs_ms_ks_strides_vec) const std::vector<index_t>& a_gs_ms_ks_strides_vec)
{ {
...@@ -563,8 +572,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2 ...@@ -563,8 +572,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
BBlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1, BBlockTransferDstScalarPerVector_BK1,
true, true,
Acc0BiasTransferSrcScalarPerVector,
BBlockLdsExtraN, BBlockLdsExtraN,
Acc0BiasTransferSrcScalarPerVector,
B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterLengths_BK0_N_BK1,
B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferThreadClusterArrangeOrder,
B1BlockTransferSrcAccessOrder, B1BlockTransferSrcAccessOrder,
...@@ -639,6 +648,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2 ...@@ -639,6 +648,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
// for gridwise gemm check // for gridwise gemm check
CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_M_N c_grid_desc_m_n_;
// raw data
int raw_d0_n_;
}; };
// Argument // Argument
...@@ -820,6 +832,11 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2 ...@@ -820,6 +832,11 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
z_random_matrix_offset = z_random_matrix_offset =
z_random_matrix_offset + raw_m_padded * raw_n_padded * batch_count; z_random_matrix_offset + raw_m_padded * raw_n_padded * batch_count;
auto raw_d0_m_n = NumD0Tensor == 0
? RawTransform::MakeCGridDescriptor_M_N({}, {})
: RawTransform::MakeCGridDescriptor_M_N(
problem_desc.acc0_biases_gs_ms_ns_lengths[0],
problem_desc.acc0_biases_gs_ms_ns_strides[0]);
group_device_args_.push_back( group_device_args_.push_back(
{{problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1], {{problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
problem_desc.b0_gs_ns_ks_lengths[NumDimG + NumDimN - 1], problem_desc.b0_gs_ns_ks_lengths[NumDimG + NumDimN - 1],
...@@ -833,7 +850,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2 ...@@ -833,7 +850,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
problem_desc.b1_gs_os_ns_strides[NumDimG + NumDimO + NumDimN - 1]}, problem_desc.b1_gs_os_ns_strides[NumDimG + NumDimO + NumDimN - 1]},
{problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM - 1], {problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM - 1],
problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM + NumDimO - 1]}, problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM + NumDimO - 1]},
c_grid_desc_m_n}); c_grid_desc_m_n,
NumD0Tensor == 0 ? 0 : raw_d0_m_n.GetLength(I1)});
} }
is_dropout_ = p_dropout > 0.0; // is_dropout_ = p_dropout > 0.0; //
...@@ -1048,6 +1066,12 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2 ...@@ -1048,6 +1066,12 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
const index_t c_gemm1n = device_arg.c_grid_desc_m_n_.GetLength(I1); const index_t c_gemm1n = device_arg.c_grid_desc_m_n_.GetLength(I1);
const index_t a_m = kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I1); const index_t a_m = kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
const index_t b1_gemm1n = kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1); const index_t b1_gemm1n = kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1);
if(device_arg.raw_d0_n_ % Acc0BiasTransferSrcScalarPerVector != 0)
{
return false;
}
if(!(c_m == a_m && c_gemm1n == b1_gemm1n)) if(!(c_m == a_m && c_gemm1n == b1_gemm1n))
{ {
return false; return false;
......
...@@ -95,6 +95,10 @@ template <typename FloatAB, ...@@ -95,6 +95,10 @@ template <typename FloatAB,
PipelineVersion PipelineVer = PipelineVersion::v1> PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2 struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
{ {
static_assert(D0BlockTransferSrcScalarPerVector == 1 ||
D0BlockTransferSrcScalarPerVector == 2 ||
D0BlockTransferSrcScalarPerVector == 4,
"D0BlockTransferSrcScalarPerVector must be 1 or 2 or 4");
using DDataType = FloatAB; using DDataType = FloatAB;
static_assert(LoopSched == LoopScheduler::Default, static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported"); "Non-default loop scheduler is currently not supported");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment