Commit e21c1785 authored by Astha Rai's avatar Astha Rai
Browse files

changed isSupportedArgument for 2D

parent 64026bc3
...@@ -250,26 +250,35 @@ struct DeviceElementwise ...@@ -250,26 +250,35 @@ struct DeviceElementwise
if(pArg->lengths_.back() % MPerThread != 0) if(pArg->lengths_.back() % MPerThread != 0)
return false; return false;
std::cout << "lengths back: " << pArg->lengths_.back() << std::endl;
auto IsScalarPerVectorValid = [&](const std::array<index_t, NumDim>& lengths, auto IsScalarPerVectorValid = [&](const std::array<index_t, NumDim>& lengths,
const std::array<index_t, NumDim>& strides, const std::array<index_t, NumDim>& strides,
index_t scalarPerVector) { index_t scalarPerVector) {
if(strides.back() == 1 && lengths.back() % scalarPerVector == 0) std::cout << "scalarPerVector: " << scalarPerVector << std::endl;
return true; std::cout << "stride back: " << strides.back() << std::endl;
std::cout << "ISPVV Check 1 starting" << std::endl;
if(strides.back() == 1 && lengths.back() % scalarPerVector == 0){
return true; }
std::cout << "Check 1 failed " << std::endl;
if(strides.back() != 1 && scalarPerVector == 1) std::cout << "ISPVV Check 2 starting" << std::endl;
return true; if(strides.back() != 1 && scalarPerVector == MPerThread){
return true; }
return false; return false;
}; };
bool valid = true; bool valid = true;
static_for<0, NumInput, 1>{}([&](auto I) { static_for<0, NumInput, 1>{}([&](auto I) {
std::cout<< "running: " << I << std::endl;
if(!IsScalarPerVectorValid( if(!IsScalarPerVectorValid(
pArg->lengths_, pArg->inStridesArray_[I.value], InScalarPerVectorSeq::At(I))) pArg->lengths_, pArg->inStridesArray_[I.value], InScalarPerVectorSeq::At(I)))
valid = false; valid = false;
}); });
static_for<0, NumOutput, 1>{}([&](auto I) { static_for<0, NumOutput, 1>{}([&](auto I) {
std::cout << "running 2: " << I << std::endl;
if(!IsScalarPerVectorValid( if(!IsScalarPerVectorValid(
pArg->lengths_, pArg->outStridesArray_[I.value], OutScalarPerVectorSeq::At(I))) pArg->lengths_, pArg->outStridesArray_[I.value], OutScalarPerVectorSeq::At(I)))
valid = false; valid = false;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment