Commit 38e9963b authored by Shucai Xiao's avatar Shucai Xiao
Browse files

revert unnecessary changes.

parent 45f9b527
......@@ -74,39 +74,21 @@ void migemm_impl(tensor_view<T> cmat,
float beta,
std::false_type)
{
auto a_lens = amat.get_shape().lens();
auto b_lens = bmat.get_shape().lens();
auto c_lens = cmat.get_shape().lens();
std::size_t nc_dims = c_lens.size();
std::size_t na_dims = a_lens.size();
std::size_t nb_dims = b_lens.size();
auto k = a_lens[na_dims - 1];
assert(a_lens[na_dims - 1] == b_lens[nb_dims - 2]);
assert(c_lens[nc_dims - 2] == a_lens[na_dims - 2]);
assert(c_lens[nc_dims - 1] == b_lens[nb_dims - 1]);
std::size_t n_dims = cmat.get_shape().lens().size();
std::size_t dim_0 = n_dims - 2;
std::size_t dim_1 = n_dims - 1;
auto k = amat.get_shape().lens()[dim_1];
std::size_t a_len_diff = nc_dims - na_dims;
std::size_t b_len_diff = nc_dims - nb_dims;
std::vector<std::size_t> a_idx(na_dims);
std::vector<std::size_t> b_idx(nb_dims);
assert(amat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_0]);
assert(cmat.get_shape().lens()[dim_0] == amat.get_shape().lens()[dim_0]);
assert(cmat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_1]);
shape_for_each(cmat.get_shape(), [&](const auto& c_idx) {
std::transform(c_idx.begin() + a_len_diff,
c_idx.end(),
a_lens.begin(),
a_idx.begin(),
[&](auto i, auto j) { return (j == 1) ? 0 : i; });
std::transform(c_idx.begin() + b_len_diff,
c_idx.end(),
b_lens.begin(),
b_idx.begin(),
[&](auto i, auto j) { return (j == 1) ? 0 : i; });
double s = 0.0;
auto a_idx = c_idx;
auto b_idx = c_idx;
double s = 0.0;
dfor(k)([&](auto kk) {
a_idx[na_dims - 1] = b_idx[nb_dims - 2] = kk;
a_idx[dim_1] = b_idx[dim_0] = kk;
s += amat(a_idx.begin(), a_idx.end()) * bmat(b_idx.begin(), b_idx.end());
});
cmat(c_idx.begin(), c_idx.end()) = alpha * s + cmat(c_idx.begin(), c_idx.end()) * beta;
......
......@@ -417,50 +417,6 @@ struct cpu_gemm
return op.compute_shape(inputs);
}
void fill_result(argument& result, argument& c) const
{
auto out_lens = result.get_shape().lens();
auto c_lens = c.get_shape().lens();
if(out_lens == c_lens)
{
visit_all(result, c)([&](auto output, auto input) {
std::copy(input.begin(), input.end(), output.begin());
});
}
// need broadcast
else if(c.single())
{
visit_all(result, c)([&](auto output, auto input) {
std::fill(output.begin(), output.end(), input.front());
});
}
// must be c_lens[0] == output_lens[1]
else if(c_lens.size() == 1 || (c_lens.size() == 2 && (c_lens[1] == out_lens[1])))
{
std::size_t m = out_lens[0];
std::size_t n = out_lens[1];
visit_all(result, c)([&](auto output, auto input) {
for(std::size_t i = 0; i < m; i++)
{
std::copy(input.begin(), input.end(), output.begin() + i * n);
}
});
}
// c_lens.size() == 2 and c_lens[0] == out_lens[0]
else
{
std::size_t m = out_lens[0];
std::size_t n = out_lens[1];
visit_all(result, c)([&](auto output, auto input) {
for(std::size_t i = 0; i < m; i++)
{
std::fill(output.begin() + i * n, output.begin() + ((i + 1) * n), input[i]);
}
});
}
}
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment