Commit d8998cbb authored by letaoqin

add check code for vector load

parent 13a0c55d
@@ -119,6 +119,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadA
         8,
         8,
         true,
+        4,
         S<16, 16, 1>, // B1BlockTransfer
         S<0, 2, 1>,
         S<0, 2, 1>,
......
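Note: the new `4,` here is presumably the value for the Acc0BiasTransferSrcScalarPerVector template parameter that this commit adds to DeviceGroupedMultiheadAttentionForward_Xdl further down; it slots in after BBlockLdsExtraN and before the B1BlockTransfer cluster lengths.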
@@ -192,7 +192,7 @@ int run(int argc, char* argv[])
     float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
-    std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
+    std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2 + size_t(M) * N) * BatchCount;
     std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
                              sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O + sizeof(Acc0BiasDataType) * M * N) *
                             BatchCount;
......
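The extra `size_t(M) * N` term presumably accounts for the elementwise Acc0/D0 bias add applied to the M x N attention-score matrix, one FLOP per element per batch; the grouped example below gets the same adjustment. A minimal sketch of that arithmetic under this reading (the function name is illustrative, not from the commit):

#include <cstddef>

std::size_t attention_flop(std::size_t M, std::size_t N, std::size_t K, std::size_t O,
                           std::size_t batch_count)
{
    const std::size_t gemm0_flop = 2 * M * N * K; // Q * K^T
    const std::size_t gemm1_flop = 2 * M * N * O; // softmax(S) * V
    const std::size_t bias_flop  = M * N;         // elementwise bias add on the M x N score matrix
    return (gemm0_flop + gemm1_flop + bias_flop) * batch_count;
}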
@@ -5,7 +5,7 @@ int run(int argc, char* argv[])
 {
     bool do_verification = true;
     int init_method = 1;
-    bool time_kernel = false;
+    bool time_kernel = true;
     bool input_permute = false;
     bool output_permute = true;
@@ -129,7 +129,7 @@ int run(int argc, char* argv[])
         Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
         int Batch = G0 * G1;
-        flop += (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * Batch;
+        flop += (size_t(M) * N * K * 2 + size_t(M) * N * O * 2 + size_t(M) * N) * Batch;
         num_byte +=
             (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + sizeof(B1DataType) * N * O +
              sizeof(CDataType) * M * O + sizeof(Acc0BiasDataType) * M * N) *
......
@@ -742,12 +742,13 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl
         if constexpr(!is_same<D0DataType, void>::value)
         {
-            if(arg.d0_n_length_stride_[1] == 1 &&
-               arg.d0_n_length_stride_[0] % Acc0BiasTransferSrcScalarPerVector != 0)
+            if(arg.d0_n_length_stride_[1] == 1)
             {
+                if(!(arg.d0_n_length_stride_[0] % Acc0BiasTransferSrcScalarPerVector == 0 ||
+                     Transform::matrix_padder.PadN))
                     return false;
             }
-            if(arg.d0_n_length_stride_[1] != 1 && Acc0BiasTransferSrcScalarPerVector != 1)
+            else if(Acc0BiasTransferSrcScalarPerVector != 1)
             {
                 return false;
             }
......
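The rewritten check separates the two D0 layouts along N: when the N stride is 1, a vector load of width Acc0BiasTransferSrcScalarPerVector is accepted only if that width divides the raw N length or N padding is enabled; when N is strided, only scalar access (width 1) is accepted. A standalone sketch of that rule, as a hypothetical helper rather than the library's API:

#include <cstdint>

bool d0_vector_load_supported(std::int64_t n_length,          // d0_n_length_stride_[0]
                              std::int64_t n_stride,          // d0_n_length_stride_[1]
                              std::int64_t scalar_per_vector, // Acc0BiasTransferSrcScalarPerVector
                              bool pad_n)                     // Transform::matrix_padder.PadN
{
    if(n_stride == 1)
    {
        // Contiguous along N: the vector width must divide the raw N length,
        // unless N is padded up to a suitable multiple anyway.
        return n_length % scalar_per_vector == 0 || pad_n;
    }
    // Strided along N: only scalar (width-1) loads are valid.
    return scalar_per_vector == 1;
}

The same predicate is repeated in the CShuffle_V2 batched variant below and in both grouped device ops.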
@@ -1026,12 +1026,13 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
         if constexpr(!is_same<D0DataType, void>::value)
         {
-            if(arg.d0_n_length_stride_[1] == 1 &&
-               arg.d0_n_length_stride_[0] % Acc0BiasTransferSrcScalarPerVector != 0)
+            if(arg.d0_n_length_stride_[1] == 1)
             {
+                if(!(arg.d0_n_length_stride_[0] % Acc0BiasTransferSrcScalarPerVector == 0 ||
+                     Transform::matrix_padder.PadN))
                     return false;
             }
-            if(arg.d0_n_length_stride_[1] != 1 && Acc0BiasTransferSrcScalarPerVector != 1)
+            else if(Acc0BiasTransferSrcScalarPerVector != 1)
             {
                 return false;
             }
......
@@ -180,6 +180,7 @@ template <index_t NumDimG,
           index_t BBlockTransferSrcScalarPerVector,
           index_t BBlockTransferDstScalarPerVector_BK1,
           bool BBlockLdsExtraN,
+          index_t Acc0BiasTransferSrcScalarPerVector,
           typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
           typename B1BlockTransferThreadClusterArrangeOrder,
           typename B1BlockTransferSrcAccessOrder,
@@ -429,7 +430,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl
         BBlockTransferDstScalarPerVector_BK1,
         true,
         BBlockLdsExtraN,
-        4,
+        Acc0BiasTransferSrcScalarPerVector,
         B1BlockTransferThreadClusterLengths_BK0_N_BK1,
         B1BlockTransferThreadClusterArrangeOrder,
         B1BlockTransferSrcAccessOrder,
@@ -493,6 +494,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl
         // for gridwise gemm check
         C1GridDesc_M_N c1_grid_desc_m_n_;
+
+        // raw data
+        std::vector<ck::index_t> d0_n_length_stride_;
     };
     // Argument
@@ -625,6 +629,10 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl
                                            BlockStart,
                                            BlockEnd});
+            std::vector<ck::index_t> d0_n_length_stride;
+            d0_n_length_stride.push_back(tmp_d0_gs_ms_ns_lengths[NumDimG + NumDimM]);
+            d0_n_length_stride.push_back(tmp_d0_gs_ms_ns_strides[NumDimG + NumDimM]);
+
             group_device_args_.push_back(
                 {{problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
                   problem_desc.b0_gs_ns_ks_lengths[NumDimG + NumDimN - 1],
@@ -638,7 +646,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl
                   problem_desc.b1_gs_os_ns_strides[NumDimG + NumDimO + NumDimN - 1]},
                  {problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM - 1],
                   problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM + NumDimO - 1]},
-                 c_grid_desc_m_n});
+                 c_grid_desc_m_n,
+                 d0_n_length_stride});
         }
     }
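For reference, a small sketch (hypothetical helper and sizes, not from the commit) of which descriptor entries the new d0_n_length_stride records, assuming the usual [G..., M..., N...] ordering of the D0 bias descriptor, so index NumDimG + NumDimM is the first N dimension:

#include <cstdint>
#include <vector>

// Returns {N length, N stride} the way the new Argument code indexes the descriptor.
std::vector<std::int64_t> d0_n_length_stride_for(const std::vector<std::int64_t>& lengths,
                                                 const std::vector<std::int64_t>& strides,
                                                 int NumDimG, int NumDimM)
{
    return {lengths[NumDimG + NumDimM], strides[NumDimG + NumDimM]};
}

// Example: a row-major [G0, G1, M, N] = [4, 8, 256, 128] bias with NumDimG = 2 and
// NumDimM = 1 yields {128, 1}, i.e. N is contiguous and vector loads are possible.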
@@ -774,6 +783,24 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl
                 return false;
             }
+
+            if constexpr(!is_same<D0DataType, void>::value)
+            {
+                if(device_arg.d0_n_length_stride_[1] == 1)
+                {
+                    if(!(device_arg.d0_n_length_stride_[0] % Acc0BiasTransferSrcScalarPerVector ==
+                             0 ||
+                         Transform::matrix_padder.PadN))
+                    {
+                        return false;
+                    }
+                }
+                else if(Acc0BiasTransferSrcScalarPerVector != 1)
+                {
+                    return false;
+                }
+            }
+
             // Check if having main loop
             const auto K = kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) *
                            kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
......
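Taken together, the grouped device op now exposes Acc0BiasTransferSrcScalarPerVector (previously hard-coded to 4), records the D0 bias N length and stride per group, and rejects arguments the chosen vector width cannot load. A hypothetical caller-side helper (not part of the commit) for deciding which pre-instantiated vector width to dispatch to:

#include <cstdint>

std::int64_t pick_acc0_bias_scalar_per_vector(std::int64_t n_length,
                                              std::int64_t n_stride,
                                              std::int64_t max_width = 4)
{
    if(n_stride != 1)
        return 1; // non-contiguous N: scalar loads only
    for(std::int64_t w = max_width; w > 1; w /= 2)
        if(n_length % w == 0)
            return w; // widest width (assumed a power of two) that divides N
    return 1;
}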
@@ -1102,13 +1102,14 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
             if constexpr(!is_same<D0DataType, void>::value)
             {
-                if(device_arg.d0_n_length_stride_[1] == 1 &&
-                   device_arg.d0_n_length_stride_[0] % Acc0BiasTransferSrcScalarPerVector != 0)
+                if(device_arg.d0_n_length_stride_[1] == 1)
                 {
+                    if(!(device_arg.d0_n_length_stride_[0] % Acc0BiasTransferSrcScalarPerVector ==
+                             0 ||
+                         Transform::matrix_padder.PadN))
                         return false;
                 }
-                if(device_arg.d0_n_length_stride_[1] != 1 &&
-                   Acc0BiasTransferSrcScalarPerVector != 1)
+                else if(Acc0BiasTransferSrcScalarPerVector != 1)
                 {
                     return false;
                 }
......