Commit d8998cbb authored by letaoqin's avatar letaoqin
Browse files

Add check code for vector load

parent 13a0c55d
...@@ -119,6 +119,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadA ...@@ -119,6 +119,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadA
8, 8,
8, 8,
true, true,
4,
S<16, 16, 1>, // B1BlockTransfer S<16, 16, 1>, // B1BlockTransfer
S<0, 2, 1>, S<0, 2, 1>,
S<0, 2, 1>, S<0, 2, 1>,
......
...@@ -192,7 +192,7 @@ int run(int argc, char* argv[]) ...@@ -192,7 +192,7 @@ int run(int argc, char* argv[])
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount; std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2 + size_t(M) * N) * BatchCount;
std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O + sizeof(Acc0BiasDataType) * M * N) * sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O + sizeof(Acc0BiasDataType) * M * N) *
BatchCount; BatchCount;
......
...@@ -5,7 +5,7 @@ int run(int argc, char* argv[]) ...@@ -5,7 +5,7 @@ int run(int argc, char* argv[])
{ {
bool do_verification = true; bool do_verification = true;
int init_method = 1; int init_method = 1;
bool time_kernel = false; bool time_kernel = true;
bool input_permute = false; bool input_permute = false;
bool output_permute = true; bool output_permute = true;
...@@ -129,7 +129,7 @@ int run(int argc, char* argv[]) ...@@ -129,7 +129,7 @@ int run(int argc, char* argv[])
Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
int Batch = G0 * G1; int Batch = G0 * G1;
flop += (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * Batch; flop += (size_t(M) * N * K * 2 + size_t(M) * N * O * 2 + size_t(M) * N) * Batch;
num_byte += num_byte +=
(sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + sizeof(B1DataType) * N * O + (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + sizeof(B1DataType) * N * O +
sizeof(CDataType) * M * O + sizeof(Acc0BiasDataType) * M * N) * sizeof(CDataType) * M * O + sizeof(Acc0BiasDataType) * M * N) *
......
...@@ -742,12 +742,13 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl ...@@ -742,12 +742,13 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl
if constexpr(!is_same<D0DataType, void>::value) if constexpr(!is_same<D0DataType, void>::value)
{ {
if(arg.d0_n_length_stride_[1] == 1 && if(arg.d0_n_length_stride_[1] == 1)
arg.d0_n_length_stride_[0] % Acc0BiasTransferSrcScalarPerVector != 0)
{ {
return false; if(!(arg.d0_n_length_stride_[0] % Acc0BiasTransferSrcScalarPerVector == 0 ||
Transform::matrix_padder.PadN))
return false;
} }
if(arg.d0_n_length_stride_[1] != 1 && Acc0BiasTransferSrcScalarPerVector != 1) else if(Acc0BiasTransferSrcScalarPerVector != 1)
{ {
return false; return false;
} }
......
...@@ -1026,12 +1026,13 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -1026,12 +1026,13 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
if constexpr(!is_same<D0DataType, void>::value) if constexpr(!is_same<D0DataType, void>::value)
{ {
if(arg.d0_n_length_stride_[1] == 1 && if(arg.d0_n_length_stride_[1] == 1)
arg.d0_n_length_stride_[0] % Acc0BiasTransferSrcScalarPerVector != 0)
{ {
return false; if(!(arg.d0_n_length_stride_[0] % Acc0BiasTransferSrcScalarPerVector == 0 ||
Transform::matrix_padder.PadN))
return false;
} }
if(arg.d0_n_length_stride_[1] != 1 && Acc0BiasTransferSrcScalarPerVector != 1) else if(Acc0BiasTransferSrcScalarPerVector != 1)
{ {
return false; return false;
} }
......
...@@ -180,6 +180,7 @@ template <index_t NumDimG, ...@@ -180,6 +180,7 @@ template <index_t NumDimG,
index_t BBlockTransferSrcScalarPerVector, index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1, index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN, bool BBlockLdsExtraN,
index_t Acc0BiasTransferSrcScalarPerVector,
typename B1BlockTransferThreadClusterLengths_BK0_N_BK1, typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
typename B1BlockTransferThreadClusterArrangeOrder, typename B1BlockTransferThreadClusterArrangeOrder,
typename B1BlockTransferSrcAccessOrder, typename B1BlockTransferSrcAccessOrder,
...@@ -429,7 +430,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl ...@@ -429,7 +430,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl
BBlockTransferDstScalarPerVector_BK1, BBlockTransferDstScalarPerVector_BK1,
true, true,
BBlockLdsExtraN, BBlockLdsExtraN,
4, Acc0BiasTransferSrcScalarPerVector,
B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterLengths_BK0_N_BK1,
B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferThreadClusterArrangeOrder,
B1BlockTransferSrcAccessOrder, B1BlockTransferSrcAccessOrder,
...@@ -493,6 +494,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl ...@@ -493,6 +494,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl
// for gridwise gemm check // for gridwise gemm check
C1GridDesc_M_N c1_grid_desc_m_n_; C1GridDesc_M_N c1_grid_desc_m_n_;
// raw data
std::vector<ck::index_t> d0_n_length_stride_;
}; };
// Argument // Argument
...@@ -625,6 +629,10 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl ...@@ -625,6 +629,10 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl
BlockStart, BlockStart,
BlockEnd}); BlockEnd});
std::vector<ck::index_t> d0_n_length_stride;
d0_n_length_stride.push_back(tmp_d0_gs_ms_ns_lengths[NumDimG + NumDimM]);
d0_n_length_stride.push_back(tmp_d0_gs_ms_ns_strides[NumDimG + NumDimM]);
group_device_args_.push_back( group_device_args_.push_back(
{{problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1], {{problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
problem_desc.b0_gs_ns_ks_lengths[NumDimG + NumDimN - 1], problem_desc.b0_gs_ns_ks_lengths[NumDimG + NumDimN - 1],
...@@ -638,7 +646,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl ...@@ -638,7 +646,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl
problem_desc.b1_gs_os_ns_strides[NumDimG + NumDimO + NumDimN - 1]}, problem_desc.b1_gs_os_ns_strides[NumDimG + NumDimO + NumDimN - 1]},
{problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM - 1], {problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM - 1],
problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM + NumDimO - 1]}, problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM + NumDimO - 1]},
c_grid_desc_m_n}); c_grid_desc_m_n,
d0_n_length_stride});
} }
} }
...@@ -774,6 +783,24 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl ...@@ -774,6 +783,24 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl
return false; return false;
} }
if constexpr(!is_same<D0DataType, void>::value)
{
if(device_arg.d0_n_length_stride_[1] == 1)
{
if(!(device_arg.d0_n_length_stride_[0] % Acc0BiasTransferSrcScalarPerVector ==
0 ||
Transform::matrix_padder.PadN))
{
return false;
}
}
else if(Acc0BiasTransferSrcScalarPerVector != 1)
{
return false;
}
}
// Check if having main loop // Check if having main loop
const auto K = kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * const auto K = kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) *
kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
......
...@@ -1102,13 +1102,14 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -1102,13 +1102,14 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
if constexpr(!is_same<D0DataType, void>::value) if constexpr(!is_same<D0DataType, void>::value)
{ {
if(device_arg.d0_n_length_stride_[1] == 1 && if(device_arg.d0_n_length_stride_[1] == 1)
device_arg.d0_n_length_stride_[0] % Acc0BiasTransferSrcScalarPerVector != 0)
{ {
return false; if(!(device_arg.d0_n_length_stride_[0] % Acc0BiasTransferSrcScalarPerVector ==
0 ||
Transform::matrix_padder.PadN))
return false;
} }
if(device_arg.d0_n_length_stride_[1] != 1 && else if(Acc0BiasTransferSrcScalarPerVector != 1)
Acc0BiasTransferSrcScalarPerVector != 1)
{ {
return false; return false;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment