Commit 3a9e6db3 authored by Astha Rai's avatar Astha Rai
Browse files

fixed isSupportedArgument

parent c2487eaa
......@@ -246,43 +246,73 @@ struct DeviceElementwise
if(pArg == nullptr)
return false;
std::cout << "made it here" << std::endl;
std::cout << "lengths back: " << pArg->lengths_.back() << std::endl;
if(pArg->lengths_.back() % MPerThread != 0)
return false;
std::cout << "lengths back: " << pArg->lengths_.back() << std::endl;
auto IsScalarPerVectorValid = [&](const std::array<index_t, NumDim>& lengths,
const std::array<index_t, NumDim>& strides,
index_t scalarPerVector) {
std::cout << "scalarPerVector: " << scalarPerVector << std::endl;
std::cout << "stride back: " << strides.back() << std::endl;
std::cout << "ISPVV Check 1 starting" << std::endl;
if(strides.back() == 1 && lengths.back() % scalarPerVector == 0){
return true; }
std::cout << "Check 1 failed " << std::endl;
std::cout << "ISPVV Check 2 starting" << std::endl;
if(strides.back() != 1 && scalarPerVector == MPerThread){
return true; }
index_t scalarPerVector,
index_t vectorDim) {
std::cout << "scalarPerVector: " << scalarPerVector << std::endl;
std::cout << "stride back: " << strides.back() << std::endl;
std::cout << "len back: " << lengths.back() << std::endl;
std::cout << "NumDim-1: " << NumDim - 1 << std::endl;
std::cout << "stride[nd-1]: " << strides[NumDim - 1] << std::endl;
std::cout << "NumDim_m-1: " << NumDim_m - 1 << std::endl;
std::cout << std::endl;
std::cout << "ISPVV Check 1 starting" << std::endl;
if(strides[vectorDim] == 1 && (lengths[vectorDim] % scalarPerVector == 0 || lengths[vectorDim]%scalarPerVector == lengths[vectorDim]))
{
std::cout << "Check 1 passed" << std::endl;
return true;
}
std::cout << "Check 1 failed " << std::endl;
std::cout << "ISPVV Check 2 starting" << std::endl;
std::cout << "strides[vectorDim]: " << strides[vectorDim] << std::endl;
if(strides[vectorDim] != 1 && scalarPerVector == strides[vectorDim])
{
std::cout << "Check 2 passed " << std::endl;
return true;
}
std::cout << "Check 2 failed" << std::endl;
return false;
};
/**auto IsOutScalarPerVectorValid =
[&](const std::array<index_t, NumDim>& lengths,
const std::array<index_t, NumDim>& strides,
index_t scalarPerVector) {
std::cout << "ISPVV Check 1 starting" << std::endl;
if(strides.back() != 1 && lengths.back() % scalarPerVector == strides[NumDim - 1])
{
std::cout << "Check 1 passed " << std::endl;
return true;
}
std::cout << "Check 1 failed" << std::endl;
};**/
bool valid = true;
static_for<0, NumInput, 1>{}([&](auto I) {
std::cout<< "running: " << I << std::endl;
std::cout << "running: " << I << std::endl;
if(!IsScalarPerVectorValid(
pArg->lengths_, pArg->inStridesArray_[I.value], InScalarPerVectorSeq::At(I)))
pArg->lengths_, pArg->inStridesArray_[I.value], InScalarPerVectorSeq::At(I), NumDim_m - 1))
valid = false;
});
std::cout << "valid after loop through input: " << valid << std::endl;
static_for<0, NumOutput, 1>{}([&](auto I) {
std::cout << "running 2: " << I << std::endl;
std::cout << "running 2: " << I << std::endl;
if(!IsScalarPerVectorValid(
pArg->lengths_, pArg->outStridesArray_[I.value], OutScalarPerVectorSeq::At(I)))
pArg->lengths_, pArg->outStridesArray_[I.value], OutScalarPerVectorSeq::At(I), NumDim - 1))
valid = false;
});
std::cout << "valid after loop through output: " << valid << std::endl;
return valid;
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment