Commit 5608328c authored by Jing Zhang
Browse files

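Fixed vector_load check: the stride/length-divisibility validation is now skipped when the corresponding ScalarPerVector is 1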

parent 77bdb740
...@@ -101,7 +101,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceContractionMultiple ...@@ -101,7 +101,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceContractionMultiple
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
2, 2,
8, 1,
8, 8,
1, 1,
S<4, 64, 1>, S<4, 64, 1>,
...@@ -130,7 +130,7 @@ int main(int argc, char* argv[]) ...@@ -130,7 +130,7 @@ int main(int argc, char* argv[])
std::vector<ck::index_t> a0_ms_ks_strides{128 * 32 * 64, 32 * 64, 64, 1}; std::vector<ck::index_t> a0_ms_ks_strides{128 * 32 * 64, 32 * 64, 64, 1};
// A1[M1, K1] -> A1[M0, M1, K0, K1] // A1[M1, K1] -> A1[M0, M1, K0, K1]
std::vector<ck::index_t> a1_ms_ks_lengths{30, 128, 32, 64}; std::vector<ck::index_t> a1_ms_ks_lengths{30, 128, 32, 64};
std::vector<ck::index_t> a1_ms_ks_strides{0, 64, 0, 1}; std::vector<ck::index_t> a1_ms_ks_strides{0, 64, 1, 0};
// B[N0, N1, K0, K1] // B[N0, N1, K0, K1]
std::vector<ck::index_t> b_ns_ks_lengths{32, 64, 32, 64}; std::vector<ck::index_t> b_ns_ks_lengths{32, 64, 32, 64};
std::vector<ck::index_t> b_ns_ks_strides{64 * 32 * 64, 32 * 64, 64, 1}; std::vector<ck::index_t> b_ns_ks_strides{64 * 32 * 64, 32 * 64, 64, 1};
......
...@@ -649,7 +649,8 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle ...@@ -649,7 +649,8 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
{ {
if(!(arg.a_mz_stride_[i] == 1 && arg.as_grid_desc_ak0_m_ak1_[i].GetLength(I1) % if(!(arg.a_mz_stride_[i] == 1 && arg.as_grid_desc_ak0_m_ak1_[i].GetLength(I1) %
ABlockTransferSrcScalarPerVector == ABlockTransferSrcScalarPerVector ==
0)) 0) &&
ABlockTransferSrcScalarPerVector != 1)
{ {
all_valid = false; all_valid = false;
} }
...@@ -658,7 +659,8 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle ...@@ -658,7 +659,8 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
{ {
if(!(arg.a_kz_stride_[i] == 1 && arg.as_grid_desc_ak0_m_ak1_[i].GetLength(I2) % if(!(arg.a_kz_stride_[i] == 1 && arg.as_grid_desc_ak0_m_ak1_[i].GetLength(I2) %
ABlockTransferSrcScalarPerVector == ABlockTransferSrcScalarPerVector ==
0)) 0) &&
ABlockTransferSrcScalarPerVector != 1)
{ {
all_valid = false; all_valid = false;
} }
...@@ -671,7 +673,8 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle ...@@ -671,7 +673,8 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
{ {
if(!(arg.b_nz_stride_[i] == 1 && arg.bs_grid_desc_bk0_n_bk1_[i].GetLength(I1) % if(!(arg.b_nz_stride_[i] == 1 && arg.bs_grid_desc_bk0_n_bk1_[i].GetLength(I1) %
BBlockTransferSrcScalarPerVector == BBlockTransferSrcScalarPerVector ==
0)) 0) &&
BBlockTransferSrcScalarPerVector != 1)
{ {
all_valid = false; all_valid = false;
} }
...@@ -680,7 +683,8 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle ...@@ -680,7 +683,8 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
{ {
if(!(arg.b_kz_stride_[i] == 1 && arg.bs_grid_desc_bk0_n_bk1_[i].GetLength(I2) % if(!(arg.b_kz_stride_[i] == 1 && arg.bs_grid_desc_bk0_n_bk1_[i].GetLength(I2) %
BBlockTransferSrcScalarPerVector == BBlockTransferSrcScalarPerVector ==
0)) 0) &&
BBlockTransferSrcScalarPerVector != 1)
{ {
all_valid = false; all_valid = false;
} }
...@@ -692,7 +696,8 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle ...@@ -692,7 +696,8 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
if(!(arg.ds_nz_stride_[i] == 1 && if(!(arg.ds_nz_stride_[i] == 1 &&
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[i].GetLength(I3) % arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[i].GetLength(I3) %
CDEBlockTransferScalarPerVector_NPerBlock == CDEBlockTransferScalarPerVector_NPerBlock ==
0)) 0) &&
CDEBlockTransferScalarPerVector_NPerBlock != 1)
{ {
all_valid = false; all_valid = false;
} }
...@@ -702,7 +707,8 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle ...@@ -702,7 +707,8 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
if(!(arg.e_nz_stride_ == 1 && if(!(arg.e_nz_stride_ == 1 &&
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetLength(I3) % arg.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetLength(I3) %
CDEBlockTransferScalarPerVector_NPerBlock == CDEBlockTransferScalarPerVector_NPerBlock ==
0)) 0) &&
CDEBlockTransferScalarPerVector_NPerBlock != 1)
{ {
all_valid = false; all_valid = false;
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment