Commit 77212cc1 authored by Shucai Xiao
Browse files

code backup.

parent 900bad8b
...@@ -820,7 +820,7 @@ struct gather ...@@ -820,7 +820,7 @@ struct gather
struct dot struct dot
{ {
float alpha = 1.0; float alpha = 1.0;
float beta = 0.0; float beta = 1.0;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
...@@ -839,7 +839,7 @@ struct dot ...@@ -839,7 +839,7 @@ struct dot
// according to the specification of the numpy.matmul() // according to the specification of the numpy.matmul()
// inputs with the shape dims more than 2 are acceptable // inputs with the shape dims more than 2 are acceptable
// as long as dim values are the same in the two inputs // as long as dim values are the same in the two inputs
if(!std::equal(a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2)) if(!std::equal(a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2, b.lens().rend()))
{ {
MIGRAPHX_THROW("DOT: dim values mismatch"); MIGRAPHX_THROW("DOT: dim values mismatch");
} }
...@@ -847,10 +847,20 @@ struct dot ...@@ -847,10 +847,20 @@ struct dot
std::size_t dim_0 = a.lens().size() - 2; std::size_t dim_0 = a.lens().size() - 2;
std::size_t dim_1 = a.lens().size() - 1; std::size_t dim_1 = a.lens().size() - 1;
if(a.lens()[dim_1] != b.lens()[dim_0]) if(a.lens()[dim_1] != b.lens()[dim_0])
MIGRAPHX_THROW("Inner dimensions do not match: {" + to_string_range(a.lens()) + {
MIGRAPHX_THROW("DOT: inner dimensions do not match: {" + to_string_range(a.lens()) +
"} x {" + to_string_range(b.lens()) + "}"); "} x {" + to_string_range(b.lens()) + "}");
}
auto out_lens = a.lens(); auto out_lens = a.lens();
out_lens[dim_1] = b.lens()[dim_1]; out_lens[dim_1] = b.lens()[dim_1];
if (inputs.size() == 3 && out_lens != inputs.at(2).lens())
{
MIGRAPHX_THROW("DOT: dimension mismatch, operand C: {" + to_string_range(c_lens) +
"}, cannot add to operand A * B: {" + to_string_range(out_lens) +
"}");
}
return {t, out_lens}; return {t, out_lens};
} }
}; };
......
...@@ -55,7 +55,13 @@ void migemm_impl(tensor_view<T> cmat, ...@@ -55,7 +55,13 @@ void migemm_impl(tensor_view<T> cmat,
visit_mat(amat, [&](const auto& a) { visit_mat(amat, [&](const auto& a) {
visit_mat(bmat, [&](const auto& b) { visit_mat(bmat, [&](const auto& b) {
auto c = make_mat(cmat); auto c = make_mat(cmat);
c = (a * b) * alpha + beta * c; c = beta * c;
// This is a simple optimization that avoids
// computing A * B when alpha is 0.0
if(alpha != 0.0)
{
c = c + alpha * a * b;
}
}); });
}); });
} }
...@@ -95,8 +101,7 @@ void migemm_impl( ...@@ -95,8 +101,7 @@ void migemm_impl(
{ {
auto lens = amat.get_shape().lens(); auto lens = amat.get_shape().lens();
bool batch_mul = bool batch_mul =
std::accumulate(lens.begin(), lens.end(), std::size_t{1}, std::multiplies<std::size_t>()) == std::accumulate(lens.rbegin() + 2, lens.rend(), std::size_t{1}, std::multiplies<std::size_t>()) == 1;
(*lens.rbegin()) * (*(lens.rbegin() + 1));
if(batch_mul) if(batch_mul)
{ {
migemm_impl(cmat, amat, bmat, alpha, beta, is_fast_gemm_type<T>{}); migemm_impl(cmat, amat, bmat, alpha, beta, is_fast_gemm_type<T>{});
......
...@@ -369,12 +369,43 @@ struct cpu_gemm ...@@ -369,12 +369,43 @@ struct cpu_gemm
{ {
op::dot op; op::dot op;
std::string name() const { return "cpu::dot"; } std::string name() const { return "cpu::dot"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); } shape compute_shape(const std::vector<shape>& inputs) const
{
if(inputs.size() == 3)
{
auto c_shape = inputs.at(2);
check_shapes{{c_shape}}.not_broadcasted();
}
return op.compute_shape(inputs);
}
argument compute(context&, const shape& output_shape, std::vector<argument> args) const argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
// With 3 inputs the result is alpha * A * B + beta * C, where
// A and B are matrices and C is broadcastable to A * B
if(args.size() == 3)
{
// no need to consider the value of args[2]
if(op.beta == 0.0f)
{
result.visit([&](auto output) { std::fill(output.begin(), output.end(), 0); });
}
else
{
visit_all(result, args[2])([&](auto output, auto input) {
std::copy(input.begin(), input.end(), output.begin());
});
}
migemm(result, args[0], args[1], op.alpha, op.beta); migemm(result, args[0], args[1], op.alpha, op.beta);
return result;
}
// 2 input arguments
migemm(result, args[0], args[1], op.alpha, 0.0f);
return result; return result;
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment