Commit caf2b2ed authored by Anthony Chang's avatar Anthony Chang
Browse files

host tensor gen: diagonal pattern in lowest two-dimensions only

parent b790e44b
...@@ -152,7 +152,7 @@ struct GeneratorTensor_Sequential ...@@ -152,7 +152,7 @@ struct GeneratorTensor_Sequential
} }
}; };
template <typename T> template <typename T, size_t NumEffectiveDim = 2>
struct GeneratorTensor_Diagonal struct GeneratorTensor_Diagonal
{ {
T value{1}; T value{1};
...@@ -161,9 +161,10 @@ struct GeneratorTensor_Diagonal ...@@ -161,9 +161,10 @@ struct GeneratorTensor_Diagonal
T operator()(Ts... Xs) const T operator()(Ts... Xs) const
{ {
std::array<ck::index_t, sizeof...(Ts)> dims = {{static_cast<ck::index_t>(Xs)...}}; std::array<ck::index_t, sizeof...(Ts)> dims = {{static_cast<ck::index_t>(Xs)...}};
size_t start_dim = dims.size() - NumEffectiveDim;
bool pred = true; bool pred = true;
for (size_t i = 1; i < dims.size(); i++) { for (size_t i = start_dim + 1; i < dims.size(); i++) {
pred &= (dims[0] == dims[i]); pred &= (dims[start_dim] == dims[i]);
} }
return pred ? value : T{0}; return pred ? value : T{0};
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment