Commit 5608328c authored by Jing Zhang's avatar Jing Zhang
Browse files

fixed vector_load check

parent 77bdb740
......@@ -101,7 +101,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceContractionMultiple
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
1,
8,
1,
S<4, 64, 1>,
......@@ -130,7 +130,7 @@ int main(int argc, char* argv[])
std::vector<ck::index_t> a0_ms_ks_strides{128 * 32 * 64, 32 * 64, 64, 1};
// A1[M1, K1] -> A1[M0, M1, K0, K1]
std::vector<ck::index_t> a1_ms_ks_lengths{30, 128, 32, 64};
std::vector<ck::index_t> a1_ms_ks_strides{0, 64, 0, 1};
std::vector<ck::index_t> a1_ms_ks_strides{0, 64, 1, 0};
// B[N0, N1, K0, K1]
std::vector<ck::index_t> b_ns_ks_lengths{32, 64, 32, 64};
std::vector<ck::index_t> b_ns_ks_strides{64 * 32 * 64, 32 * 64, 64, 1};
......
......@@ -649,7 +649,8 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
{
if(!(arg.a_mz_stride_[i] == 1 && arg.as_grid_desc_ak0_m_ak1_[i].GetLength(I1) %
ABlockTransferSrcScalarPerVector ==
0))
0) &&
ABlockTransferSrcScalarPerVector != 1)
{
all_valid = false;
}
......@@ -658,7 +659,8 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
{
if(!(arg.a_kz_stride_[i] == 1 && arg.as_grid_desc_ak0_m_ak1_[i].GetLength(I2) %
ABlockTransferSrcScalarPerVector ==
0))
0) &&
ABlockTransferSrcScalarPerVector != 1)
{
all_valid = false;
}
......@@ -671,7 +673,8 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
{
if(!(arg.b_nz_stride_[i] == 1 && arg.bs_grid_desc_bk0_n_bk1_[i].GetLength(I1) %
BBlockTransferSrcScalarPerVector ==
0))
0) &&
BBlockTransferSrcScalarPerVector != 1)
{
all_valid = false;
}
......@@ -680,7 +683,8 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
{
if(!(arg.b_kz_stride_[i] == 1 && arg.bs_grid_desc_bk0_n_bk1_[i].GetLength(I2) %
BBlockTransferSrcScalarPerVector ==
0))
0) &&
BBlockTransferSrcScalarPerVector != 1)
{
all_valid = false;
}
......@@ -692,7 +696,8 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
if(!(arg.ds_nz_stride_[i] == 1 &&
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[i].GetLength(I3) %
CDEBlockTransferScalarPerVector_NPerBlock ==
0))
0) &&
CDEBlockTransferScalarPerVector_NPerBlock != 1)
{
all_valid = false;
}
......@@ -702,7 +707,8 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
if(!(arg.e_nz_stride_ == 1 &&
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetLength(I3) %
CDEBlockTransferScalarPerVector_NPerBlock ==
0))
0) &&
CDEBlockTransferScalarPerVector_NPerBlock != 1)
{
all_valid = false;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment