Commit caf2b2ed authored by Anthony Chang's avatar Anthony Chang
Browse files

host tensor gen: diagonal pattern in lowest two-dimensions only

parent b790e44b
......@@ -152,7 +152,7 @@ struct GeneratorTensor_Sequential
}
};
template <typename T>
template <typename T, size_t NumEffectiveDim = 2>
struct GeneratorTensor_Diagonal
{
T value{1};
......@@ -161,9 +161,10 @@ struct GeneratorTensor_Diagonal
T operator()(Ts... Xs) const
{
std::array<ck::index_t, sizeof...(Ts)> dims = {{static_cast<ck::index_t>(Xs)...}};
size_t start_dim = dims.size() - NumEffectiveDim;
bool pred = true;
for (size_t i = 1; i < dims.size(); i++) {
pred &= (dims[0] == dims[i]);
for (size_t i = start_dim + 1; i < dims.size(); i++) {
pred &= (dims[start_dim] == dims[i]);
}
return pred ? value : T{0};
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment