Commit 45dddfa7 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX...

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into fp32_fp16_convert
parents 5b2aeb80 4a3e493c
...@@ -19,7 +19,7 @@ namespace op { ...@@ -19,7 +19,7 @@ namespace op {
struct dot struct dot
{ {
float alpha = 1.0; float alpha = 1.0;
float beta = 0.0; float beta = 1.0;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
...@@ -30,26 +30,41 @@ struct dot ...@@ -30,26 +30,41 @@ struct dot
std::string name() const { return "dot"; } std::string name() const { return "dot"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(2).same_type(); check_shapes{inputs, *this}.same_type();
const shape& a = inputs.at(0); const shape& a = inputs.at(0);
const shape& b = inputs.at(1); const shape& b = inputs.at(1);
auto t = a.type(); auto t = a.type();
// according to the specification of the numpy.matmul() if(!std::all_of(inputs.begin(), inputs.end(), [](auto s) { return s.lens().size() >= 2; }))
// inputs with the shape dims more than 2 are acceptable
// as long as dim values are the same in the two inputs
if(!std::equal(a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2))
{ {
MIGRAPHX_THROW("DOT: dim values mismatch"); MIGRAPHX_THROW("DOT: dot only accept 2 or more dims operands");
}
// only handle the case that the batch size of a and b are the same
if(!std::equal(
a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2, b.lens().rend()))
{
MIGRAPHX_THROW("DOT: batch size of A and B mismatch: {" + to_string_range(a.lens()) +
"} x {" + to_string_range(b.lens()) + "}");
} }
std::size_t dim_0 = a.lens().size() - 2; std::size_t dim_0 = a.lens().size() - 2;
std::size_t dim_1 = a.lens().size() - 1; std::size_t dim_1 = a.lens().size() - 1;
if(a.lens()[dim_1] != b.lens()[dim_0]) if(a.lens()[dim_1] != b.lens()[dim_0])
MIGRAPHX_THROW("Inner dimensions do not match: {" + to_string_range(a.lens()) + {
MIGRAPHX_THROW("DOT: inner dimensions do not match: {" + to_string_range(a.lens()) +
"} x {" + to_string_range(b.lens()) + "}"); "} x {" + to_string_range(b.lens()) + "}");
}
auto out_lens = a.lens(); auto out_lens = a.lens();
out_lens[dim_1] = b.lens()[dim_1]; out_lens[dim_1] = b.lens()[dim_1];
if(inputs.size() == 3 && out_lens != inputs.at(2).lens())
{
MIGRAPHX_THROW("DOT: dimension mismatch, operand C: {" +
to_string_range(inputs.at(2).lens()) +
"}, cannot add to operand A * B: {" + to_string_range(out_lens) + "}");
}
return {t, out_lens}; return {t, out_lens};
} }
}; };
......
...@@ -55,7 +55,15 @@ struct squeeze ...@@ -55,7 +55,15 @@ struct squeeze
} }
} }
} }
return shape{type, new_lens};
if(new_lens.empty())
{
return shape{type};
}
else
{
return shape{type, new_lens};
}
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
......
...@@ -36,7 +36,6 @@ struct onnx_parser ...@@ -36,7 +36,6 @@ struct onnx_parser
onnx_parser() onnx_parser()
{ {
add_generic_op("MatMul", op::dot{});
add_generic_op("Relu", op::relu{}); add_generic_op("Relu", op::relu{});
add_generic_op("Sigmoid", op::sigmoid{}); add_generic_op("Sigmoid", op::sigmoid{});
add_generic_op("Abs", op::abs{}); add_generic_op("Abs", op::abs{});
...@@ -77,6 +76,7 @@ struct onnx_parser ...@@ -77,6 +76,7 @@ struct onnx_parser
add_mem_op("Reshape", &onnx_parser::parse_reshape); add_mem_op("Reshape", &onnx_parser::parse_reshape);
add_mem_op("Flatten", &onnx_parser::parse_flatten); add_mem_op("Flatten", &onnx_parser::parse_flatten);
add_mem_op("Gemm", &onnx_parser::parse_gemm); add_mem_op("Gemm", &onnx_parser::parse_gemm);
add_mem_op("MatMul", &onnx_parser::parse_matmul);
add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm); add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm);
add_mem_op("Softmax", &onnx_parser::parse_softmax); add_mem_op("Softmax", &onnx_parser::parse_softmax);
add_mem_op("LogSoftmax", &onnx_parser::parse_logsoftmax); add_mem_op("LogSoftmax", &onnx_parser::parse_logsoftmax);
...@@ -154,42 +154,48 @@ struct onnx_parser ...@@ -154,42 +154,48 @@ struct onnx_parser
}); });
} }
// Returns the broadcast dimension list of two dimension lists.
//
// The shorter list is right-aligned against the longer one; every aligned
// position takes the larger of the two values, and the unmatched leading
// dimensions of the longer list are carried over unchanged. No check is made
// that the aligned dimensions are actually compatible.
//
// Examples:
//   (3,2,4,5) with (2,1,1)  ->  (3,2,4,5)
//   (3,2,1,5) with (2,7,5)  ->  (3,2,7,5)
std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
                                                  std::vector<std::size_t> s1)
{
    // Pick out which argument has more dimensions without reordering them.
    const bool first_is_longer = s0.size() > s1.size();
    const std::vector<std::size_t>& longer  = first_is_longer ? s0 : s1;
    const std::vector<std::size_t>& shorter = first_is_longer ? s1 : s0;

    // Number of leading dimensions that exist only in the longer list.
    const std::size_t lead = longer.size() - shorter.size();

    std::vector<std::size_t> result(longer);
    for(std::size_t i = 0; i < shorter.size(); ++i)
    {
        result[lead + i] = std::max(shorter[i], longer[lead + i]);
    }
    return result;
}
template <class T> template <class T>
instruction_ref add_broadcastable_binary_op(instruction_ref arg0, instruction_ref arg1, T x) instruction_ref add_broadcastable_binary_op(instruction_ref arg0, instruction_ref arg1, T x)
{ {
if(arg0->get_shape().lens() != arg1->get_shape().lens()) if(arg0->get_shape().lens() != arg1->get_shape().lens())
{ {
// Example:
// s0 = (3,2,4,5) and s1 = (2,1,1)
//
// In this case we need to broadcast (:,1,1) portion of
// s1 plus broadcast the 1st dimension of s1
// giving output_lens = (3,2,4,5)
//
// Another example:
// s0 = (3,2,1,5) and s1 = (2,7,5)
// In this case we need to broadcast the (:,:,1:,:) axis
// of s0 plus the 1st dimension of s1 giving
// output_lens = (3,2,7,5)
//
// Get lengths for both arguments // Get lengths for both arguments
const std::vector<std::size_t>* s0 = &arg0->get_shape().lens(); auto s0 = arg0->get_shape().lens();
const std::vector<std::size_t>* s1 = &arg1->get_shape().lens(); auto s1 = arg1->get_shape().lens();
auto out_lens = compute_broadcasted_lens(s0, s1);
// Make sure s0 is the smaller size auto l0 = prog.add_instruction(op::multibroadcast{out_lens}, arg0);
if(s0->size() > s1->size()) auto l1 = prog.add_instruction(op::multibroadcast{out_lens}, arg1);
std::swap(s0, s1);
std::vector<std::size_t> output_lens(*s1);
auto offset = s1->size() - s0->size();
std::transform(s0->begin(),
s0->end(),
s1->begin() + offset,
output_lens.begin() + offset,
[](auto a, auto b) { return std::max(a, b); });
auto l0 = prog.add_instruction(op::multibroadcast{output_lens}, arg0);
auto l1 = prog.add_instruction(op::multibroadcast{output_lens}, arg1);
return prog.add_instruction(x, l0, l1); return prog.add_instruction(x, l0, l1);
} }
else else
...@@ -495,25 +501,86 @@ struct onnx_parser ...@@ -495,25 +501,86 @@ struct onnx_parser
auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1]; auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1];
if(args.size() == 3) if(args.size() == 3)
{ {
if(beta != 0.f) if(beta != 0.f && args[2]->get_shape().elements() > 0)
{ {
auto l3 = prog.add_instruction(op::dot{alpha}, l1, l2); auto out_lens = l1->get_shape().lens();
auto l4 = args[2]; out_lens.back() = l2->get_shape().lens().back();
if(l4->get_shape().scalar()) // ignore args[2] (no C value added to alpha*A*B) auto l3 = args[2];
return l3; auto l3_lens = l3->get_shape().lens();
if(beta != 1.f) if(!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
{ {
auto beta_val = prog.add_literal(beta); l3 = prog.add_instruction(op::multibroadcast{out_lens}, args[2]);
auto l5 = prog.add_instruction(op::scalar{args[2]->get_shape()}, beta_val);
l4 = prog.add_instruction(op::mul{}, args[2], l5);
} }
return add_broadcastable_binary_op(l3, l4, op::add{}); return prog.add_instruction(op::dot{alpha, beta}, l1, l2, l3);
} }
} }
return prog.add_instruction(op::dot{alpha, beta}, l1, l2); return prog.add_instruction(op::dot{alpha, beta}, l1, l2);
} }
// Builds the instructions for a "MatMul" node from its two operands.
//
// - A 1-D first operand gets a unit axis prepended (unsqueeze on axis 0); a 1-D
//   second operand gets a unit axis appended (unsqueeze on axis 1).
// - When the leading (all but the last two) dimensions of the operands differ,
//   they are broadcast against each other with compute_broadcasted_lens and each
//   operand whose dimensions change is wrapped in a multibroadcast.
// - The product is emitted as op::dot{1.0f, 0.0f}; any unit axis added above is
//   then squeezed back out of the result.
//
// NOTE(review): the code steps two positions in from the end of each dimension
// list, so it assumes both operands have at least one dimension - confirm that
// 0-D inputs cannot reach this function.
instruction_ref
parse_matmul(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{
    auto lhs      = args[0];
    auto rhs      = args[1];
    auto lhs_lens = lhs->get_shape().lens();
    auto rhs_lens = rhs->get_shape().lens();

    // Vector on the left: treat it as a single row.
    const bool lhs_was_vector = (lhs_lens.size() == 1);
    if(lhs_was_vector)
    {
        lhs_lens.insert(lhs_lens.begin(), 1);
        lhs = prog.add_instruction(op::unsqueeze{{0}}, args[0]);
    }

    // Vector on the right: treat it as a single column.
    const bool rhs_was_vector = (rhs_lens.size() == 1);
    if(rhs_was_vector)
    {
        rhs_lens.push_back(1);
        rhs = prog.add_instruction(op::unsqueeze{{1}}, args[1]);
    }

    auto lhs_operand = lhs;
    auto rhs_operand = rhs;

    // Compare everything except the trailing two (matrix) dimensions.
    const bool same_leading_dims = std::equal(
        lhs_lens.rbegin() + 2, lhs_lens.rend(), rhs_lens.rbegin() + 2, rhs_lens.rend());
    if(!same_leading_dims)
    {
        // Iterators to the start of the trailing two dimensions of each operand.
        const auto lhs_matrix_begin = lhs_lens.end() - 2;
        const auto rhs_matrix_begin = rhs_lens.end() - 2;

        const auto leading_lens = compute_broadcasted_lens(
            std::vector<std::size_t>(lhs_lens.begin(), lhs_matrix_begin),
            std::vector<std::size_t>(rhs_lens.begin(), rhs_matrix_begin));

        // Broadcast leading dimensions followed by each operand's own matrix dims.
        auto lhs_target = leading_lens;
        lhs_target.insert(lhs_target.end(), lhs_matrix_begin, lhs_lens.end());
        auto rhs_target = leading_lens;
        rhs_target.insert(rhs_target.end(), rhs_matrix_begin, rhs_lens.end());

        // Only insert a multibroadcast where the dimensions actually change.
        if(lhs_lens != lhs_target)
        {
            lhs_operand = prog.add_instruction(op::multibroadcast{lhs_target}, lhs);
        }
        if(rhs_lens != rhs_target)
        {
            rhs_operand = prog.add_instruction(op::multibroadcast{rhs_target}, rhs);
        }
    }

    auto result  = prog.add_instruction(op::dot{1.0f, 0.0f}, lhs_operand, rhs_operand);
    int64_t rank = static_cast<int64_t>(result->get_shape().lens().size());

    // Remove the row axis that was added for a 1-D first operand.
    if(lhs_was_vector)
    {
        result = prog.add_instruction(op::squeeze{{rank - 2}}, result);
        --rank;
    }
    // Remove the column axis that was added for a 1-D second operand.
    if(rhs_was_vector)
    {
        result = prog.add_instruction(op::squeeze{{rank - 1}}, result);
    }
    return result;
}
instruction_ref instruction_ref
parse_batchnorm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_batchnorm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{ {
......
...@@ -55,7 +55,13 @@ void migemm_impl(tensor_view<T> cmat, ...@@ -55,7 +55,13 @@ void migemm_impl(tensor_view<T> cmat,
visit_mat(amat, [&](const auto& a) { visit_mat(amat, [&](const auto& a) {
visit_mat(bmat, [&](const auto& b) { visit_mat(bmat, [&](const auto& b) {
auto c = make_mat(cmat); auto c = make_mat(cmat);
c = (a * b) * alpha + beta * c; c = beta * c;
// This is a simple optimization to avoid
// computing A * B when alpha is 0.0
if(alpha != 0.0)
{
c = c + alpha * a * b;
}
}); });
}); });
} }
...@@ -95,8 +101,8 @@ void migemm_impl( ...@@ -95,8 +101,8 @@ void migemm_impl(
{ {
auto lens = amat.get_shape().lens(); auto lens = amat.get_shape().lens();
bool batch_mul = bool batch_mul =
std::accumulate(lens.begin(), lens.end(), std::size_t{1}, std::multiplies<std::size_t>()) == std::accumulate(
(*lens.rbegin()) * (*(lens.rbegin() + 1)); lens.rbegin() + 2, lens.rend(), std::size_t{1}, std::multiplies<std::size_t>()) == 1;
if(batch_mul) if(batch_mul)
{ {
migemm_impl(cmat, amat, bmat, alpha, beta, is_fast_gemm_type<T>{}); migemm_impl(cmat, amat, bmat, alpha, beta, is_fast_gemm_type<T>{});
......
...@@ -369,12 +369,43 @@ struct cpu_gemm ...@@ -369,12 +369,43 @@ struct cpu_gemm
{ {
op::dot op; op::dot op;
std::string name() const { return "cpu::dot"; } std::string name() const { return "cpu::dot"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); } shape compute_shape(const std::vector<shape>& inputs) const
{
if(inputs.size() == 3)
{
auto c_shape = inputs.at(2);
check_shapes{{c_shape}}.not_broadcasted();
}
return op.compute_shape(inputs);
}
argument compute(context&, const shape& output_shape, std::vector<argument> args) const argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
migemm(result, args[0], args[1], op.alpha, op.beta); // 3 inputs, it is alpha * A * B + beta * C, then
// A and B are matrices, and C is broadcastable to A * B
if(args.size() == 3)
{
// no need to consider the value of args[2]
if(op.beta == 0.0f)
{
result.visit([&](auto output) { std::fill(output.begin(), output.end(), 0); });
}
else
{
visit_all(result, args[2])([&](auto output, auto input) {
std::copy(input.begin(), input.end(), output.begin());
});
}
migemm(result, args[0], args[1], op.alpha, op.beta);
return result;
}
// 2 input arguments
migemm(result, args[0], args[1], op.alpha, 0.0f);
return result; return result;
} }
}; };
......
#include <migraphx/gpu/gemm.hpp> #include <migraphx/gpu/gemm.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device/add.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
template <class... Ts> template <class... Ts>
void generic_rocblas_batched_gemm(shape::as<float>, Ts&&... xs) rocblas_status generic_rocblas_scal(shape::as<float>, Ts&&... xs)
{ {
rocblas_sgemm_strided_batched(std::forward<Ts>(xs)...); return rocblas_sscal(std::forward<Ts>(xs)...);
} }
template <class... Ts> template <class... Ts>
void generic_rocblas_batched_gemm(shape::as<double>, Ts&&... xs) rocblas_status generic_rocblas_scal(shape::as<double>, Ts&&... xs)
{ {
rocblas_dgemm_strided_batched(std::forward<Ts>(xs)...); return rocblas_dscal(std::forward<Ts>(xs)...);
}
// Catch-all generic_rocblas_scal overload, templated on the element type of the
// shape::as tag. It ignores all forwarded arguments and only reports, through
// MIGRAPHX_THROW, that the element type is unsupported.
template <class Elem, class... Rest>
rocblas_status generic_rocblas_scal(shape::as<Elem>, Rest&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_SCAL: type unsupported by rocblas");
}
// generic_rocblas_axpy: overloads selected by the element type of the leading
// shape::as tag. The tag itself is unused; every remaining argument is
// perfectly forwarded to the matching rocBLAS axpy entry point and its status
// is returned.

// half -> rocblas_haxpy
template <class... Args>
rocblas_status generic_rocblas_axpy(shape::as<half>, Args&&... args)
{
    return rocblas_haxpy(std::forward<Args>(args)...);
}

// float -> rocblas_saxpy
template <class... Args>
rocblas_status generic_rocblas_axpy(shape::as<float>, Args&&... args)
{
    return rocblas_saxpy(std::forward<Args>(args)...);
}

// double -> rocblas_daxpy
template <class... Args>
rocblas_status generic_rocblas_axpy(shape::as<double>, Args&&... args)
{
    return rocblas_daxpy(std::forward<Args>(args)...);
}

// Any other element type: no rocBLAS call is made, an error is raised instead.
template <class Elem, class... Args>
rocblas_status generic_rocblas_axpy(shape::as<Elem>, Args&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_AXPY: type unsupported by rocblas");
}
template <class... Ts>
rocblas_status generic_rocblas_dot(shape::as<float>, Ts&&... xs)
{
return rocblas_sdot(std::forward<Ts>(xs)...);
} }
template <class... Ts> template <class... Ts>
void generic_rocblas_batched_gemm(shape::as<half>, Ts&&... xs) rocblas_status generic_rocblas_dot(shape::as<double>, Ts&&... xs)
{ {
rocblas_hgemm_strided_batched(std::forward<Ts>(xs)...); return rocblas_ddot(std::forward<Ts>(xs)...);
} }
template <class T, class... Ts> template <class T, class... Ts>
void generic_rocblas_batched_gemm(shape::as<T>, Ts&&...) rocblas_status generic_rocblas_dot(shape::as<T>, Ts&&...)
{
MIGRAPHX_THROW("GENERIC_ROCBLAS_DOT: type unsupported by rocblas");
}
// generic_rocblas_gemv: overloads selected by the element type of the leading
// shape::as tag. The tag is used only for overload selection; the remaining
// arguments are perfectly forwarded to the rocBLAS gemv routine for that type
// and its rocblas_status is returned.

// float -> rocblas_sgemv
template <class... Ts>
rocblas_status generic_rocblas_gemv(shape::as<float>, Ts&&... xs)
{
    return rocblas_sgemv(std::forward<Ts>(xs)...);
}

// double -> rocblas_dgemv
template <class... Ts>
rocblas_status generic_rocblas_gemv(shape::as<double>, Ts&&... xs)
{
    return rocblas_dgemv(std::forward<Ts>(xs)...);
}

// Any other element type: no rocBLAS call is made, an error is raised instead.
// The message prefix names this function (GEMV), in line with the sibling
// wrappers whose messages mirror their own names.
template <class T, class... Ts>
rocblas_status generic_rocblas_gemv(shape::as<T>, Ts&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_GEMV: type unsupported by rocblas");
}
// generic_rocblas_batched_gemm: typed overloads selected by the element type of
// the leading shape::as tag. The tag is unused; all other arguments are
// perfectly forwarded to the matching rocBLAS strided-batched gemm routine and
// its status is returned.

// float -> rocblas_sgemm_strided_batched
template <class... Args>
rocblas_status generic_rocblas_batched_gemm(shape::as<float>, Args&&... args)
{
    return rocblas_sgemm_strided_batched(std::forward<Args>(args)...);
}

// double -> rocblas_dgemm_strided_batched
template <class... Args>
rocblas_status generic_rocblas_batched_gemm(shape::as<double>, Args&&... args)
{
    return rocblas_dgemm_strided_batched(std::forward<Args>(args)...);
}

// half -> rocblas_hgemm_strided_batched
template <class... Args>
rocblas_status generic_rocblas_batched_gemm(shape::as<half>, Args&&... args)
{
    return rocblas_hgemm_strided_batched(std::forward<Args>(args)...);
}
template <class T, class... Ts>
rocblas_status generic_rocblas_batched_gemm(shape::as<T>, Ts&&...)
{ {
MIGRAPHX_THROW("GENERIC_ROCBLAS_BATCHED_GEMM: type unsupported by rocblas"); MIGRAPHX_THROW("GENERIC_ROCBLAS_BATCHED_GEMM: type unsupported by rocblas");
} }
template <class... Ts> template <class... Ts>
void generic_rocblas_gemm(shape::as<float>, Ts&&... xs) rocblas_status generic_rocblas_gemm(shape::as<float>, Ts&&... xs)
{ {
rocblas_sgemm(std::forward<Ts>(xs)...); return rocblas_sgemm(std::forward<Ts>(xs)...);
} }
template <class... Ts> template <class... Ts>
void generic_rocblas_gemm(shape::as<double>, Ts&&... xs) rocblas_status generic_rocblas_gemm(shape::as<double>, Ts&&... xs)
{ {
rocblas_dgemm(std::forward<Ts>(xs)...); return rocblas_dgemm(std::forward<Ts>(xs)...);
} }
template <class... Ts> template <class... Ts>
void generic_rocblas_gemm(shape::as<half>, Ts&&... xs) rocblas_status generic_rocblas_gemm(shape::as<half>, Ts&&... xs)
{ {
rocblas_hgemm(std::forward<Ts>(xs)...); return rocblas_hgemm(std::forward<Ts>(xs)...);
} }
template <class T, class... Ts> template <class T, class... Ts>
void generic_rocblas_gemm(shape::as<T>, Ts&&...) rocblas_status generic_rocblas_gemm(shape::as<T>, Ts&&...)
{ {
MIGRAPHX_THROW("GENERIC_ROCBLAS_GEMM: type unsupported by rocblas"); MIGRAPHX_THROW("GENERIC_ROCBLAS_GEMM: type unsupported by rocblas");
} }
...@@ -90,56 +169,94 @@ rocblas_half to_rocblas_type(half x) { return reinterpret_cast<const rocblas_hal ...@@ -90,56 +169,94 @@ rocblas_half to_rocblas_type(half x) { return reinterpret_cast<const rocblas_hal
shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
{ {
check_shapes{inputs, *this}.has(3); std::vector<shape> input_shapes(inputs.begin(), inputs.begin() + inputs.size() - 1);
return op.compute_shape({inputs.at(0), inputs.at(1)}); check_shapes{input_shapes}.not_broadcasted();
return op.compute_shape(input_shapes);
} }
argument miopen_gemm::compute(context& ctx, argument miopen_gemm::compute(context& ctx,
const shape& output_shape, const shape& output_shape,
const std::vector<argument>& args) const const std::vector<argument>& args) const
{ {
float alpha = 1.0f; bool is_3inputs = (args.size() == 4);
float beta = 0.0f; float beta = 0.0f;
bool transa = args[0].get_shape().transposed(); if(is_3inputs)
bool transb = args[1].get_shape().transposed(); {
std::size_t n_dims = args[0].get_shape().lens().size(); output_shape.visit_type([&](auto as) {
std::size_t dim_0 = n_dims - 2; auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
std::size_t dim_1 = n_dims - 1; hipMemcpyAsync(to_pointer(args[3]),
rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0]; to_pointer(args[2]),
rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0]; output_shape.bytes(),
rocblas_int ldc = args[2].get_shape().strides()[dim_0]; hipMemcpyDeviceToDevice,
auto out_lens = output_shape.lens(); ctx.get_stream().get());
rocblas_int m = out_lens[dim_0]; });
rocblas_int n = out_lens[dim_1]; beta = op.beta;
rocblas_int k = args[0].get_shape().lens()[dim_1]; }
auto batch_num = std::accumulate(
out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>()); auto a_lens = args[0].get_shape().lens();
auto b_lens = args[1].get_shape().lens();
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(alpha)); auto n_dim = output_shape.lens().size();
auto beta_r = to_rocblas_type(as(beta)); auto dim_1 = n_dim - 1;
auto dim_0 = n_dim - 2;
auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = to_rocblas_type(as(beta));
bool transa = args[0].get_shape().transposed();
bool transb = args[1].get_shape().transposed();
rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
rocblas_int ldc = args[2].get_shape().strides()[dim_0];
auto out_lens = output_shape.lens();
rocblas_int m = out_lens[dim_0];
rocblas_int n = out_lens[dim_1];
rocblas_int k = args[0].get_shape().lens()[dim_1];
auto num_matrices = std::accumulate(
out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); }; auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
generic_rocblas_batched_gemm(as, if(num_matrices == 1)
ctx.get_stream().get_rocblas(), {
transb ? rocblas_operation_transpose : rocblas_operation_none, generic_rocblas_gemm(as,
transa ? rocblas_operation_transpose : rocblas_operation_none, ctx.get_stream().get_rocblas(),
n, transb ? rocblas_operation_transpose : rocblas_operation_none,
m, transa ? rocblas_operation_transpose : rocblas_operation_none,
k, n,
&alpha_r, m,
to_pointer(args[1]), k,
ldb, &alpha_r,
k * n, to_pointer(args[1]),
to_pointer(args[0]), ldb,
lda, to_pointer(args[0]),
m * k, lda,
&beta_r, &beta_r,
to_pointer(args[2]), (is_3inputs ? to_pointer(args[3]) : to_pointer(args[2])),
ldc, ldc);
m * n, }
batch_num); else
{
generic_rocblas_batched_gemm(
as,
ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
&alpha_r,
to_pointer(args[1]),
ldb,
k * n,
to_pointer(args[0]),
lda,
m * k,
&beta_r,
(is_3inputs ? to_pointer(args[3]) : to_pointer(args[2])),
ldc,
m * n,
num_matrices);
}
}); });
return args[2]; return (is_3inputs ? args[3] : args[2]);
} }
} // namespace gpu } // namespace gpu
......
This diff is collapsed.
...@@ -876,242 +876,6 @@ TEST_CASE(reshape_test) ...@@ -876,242 +876,6 @@ TEST_CASE(reshape_test)
} }
} }
template <class T>
void gemm_test()
{
migraphx::program p;
std::vector<T> a = {-0.00925222, 0.56250403, 0.70107397, 0.75402161, -0.505885,
1.33628943, -0.11413, -0.31270559, 1.59336732, -0.19361027,
-0.91620867, 0.40108416, -0.06969921, 0.68483471, -0.39906632,
-1.66423624, 0.69040076, -1.31490171, -0.11282616, -0.79391814};
std::vector<float> b = {6.09568541e-01,
-6.10527007e-01,
3.66646462e-01,
1.18951101e-01,
5.58777432e-01,
-3.21296298e-01,
-5.95997198e-01,
-5.01425721e-01,
-2.84606807e-01,
-5.73673557e-01,
-8.99430260e-01,
-4.25103093e-01,
1.53027987e+00,
-3.81407415e-04,
-3.29650255e-01};
std::vector<float> c = {-1.56327541e+00,
-7.09570140e-01,
-5.37424982e-01,
-2.22994831e-01,
-2.15586437e+00,
2.09177941e-03,
-1.47279677e+00,
2.02627040e-01,
-6.04527691e-01,
-1.29885596e+00,
2.16294914e+00,
-1.48101497e-01};
migraphx::shape a_shape{migraphx::shape::get_type<T>{}, {4, 5}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape b_shape{migraphx::shape::get_type<T>{}, {5, 3}};
auto bl = p.add_literal(migraphx::literal{b_shape, b});
p.add_instruction(migraphx::op::dot{}, al, bl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<T> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(c, results_vector));
}
TEST_CASE_REGISTER(gemm_test<float>)
TEST_CASE_REGISTER(gemm_test<double>)
template <class T>
void gemm_test_ex()
{
migraphx::program p;
std::vector<T> a = {-0.00925222, 0.56250403, 0.70107397, 0.75402161, -0.505885,
1.33628943, -0.11413, -0.31270559, 1.59336732, -0.19361027,
-0.91620867, 0.40108416, -0.06969921, 0.68483471, -0.39906632,
-1.66423624, 0.69040076, -1.31490171, -0.11282616, -0.79391814};
std::vector<float> b = {6.09568541e-01,
-6.10527007e-01,
3.66646462e-01,
1.18951101e-01,
5.58777432e-01,
-3.21296298e-01,
-5.95997198e-01,
-5.01425721e-01,
-2.84606807e-01,
-5.73673557e-01,
-8.99430260e-01,
-4.25103093e-01,
1.53027987e+00,
-3.81407415e-04,
-3.29650255e-01};
std::vector<float> c = {-1.56327541e+00,
-7.09570140e-01,
-5.37424982e-01,
-2.22994831e-01,
-2.15586437e+00,
2.09177941e-03,
-1.47279677e+00,
2.02627040e-01,
-6.04527691e-01,
-1.29885596e+00,
2.16294914e+00,
-1.48101497e-01};
migraphx::shape a_shape{migraphx::shape::get_type<T>{}, {1, 1, 4, 5}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape b_shape{migraphx::shape::get_type<T>{}, {1, 1, 5, 3}};
auto bl = p.add_literal(migraphx::literal{b_shape, b});
p.add_instruction(migraphx::op::dot{}, al, bl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<T> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(c, results_vector));
}
TEST_CASE_REGISTER(gemm_test_ex<float>)
TEST_CASE_REGISTER(gemm_test_ex<double>)
TEST_CASE(gemm_mutli_dim_2)
{
migraphx::program p;
std::vector<float> m1 = {-0.76234141,
0.01368910,
-0.86343423,
-0.99465282,
0.76133268,
0.96507140,
-0.55893585,
0.02625652,
0.75171776,
0.23112578,
0.25624787,
-1.50442161};
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}};
std::vector<float> m2 = {-0.15933632, -0.69594712, -0.06198966, -1.23905184, -0.83672704,
-1.06971832, -0.12272917, 1.07094116, -0.08346820, 1.16820693,
-0.95700874, 0.24059691, 0.43326023, 0.78305235, -0.53506601,
-0.69359678, -0.26334436, 1.56292796, -0.33629175, -1.72693469,
0.41435494, 1.52136843, -0.40699791, -1.59839430};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 4}};
auto l1 = p.add_literal(migraphx::literal{m1_shape, m1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, m2});
p.add_instruction(migraphx::op::dot{}, l1, l2);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
std::vector<float> m_res = {0.18208394,
-0.49276402,
0.87189133,
0.75150114,
-0.55909610,
1.00521735,
-0.95536130,
2.27996211,
0.06239879,
0.74700068,
-0.01570983,
-0.85920856,
-0.59070835,
-1.70729902,
0.40245487,
1.80182751};
EXPECT(migraphx::verify_range(m, m_res));
}
TEST_CASE(gemm_mutli_dim_2_3)
{
migraphx::program p;
std::vector<float> m1 = {
-1.93300070, 0.33902698, -0.45173527, -0.72283069, -0.17177134, 1.62199882,
0.87052847, 0.14989811, -0.88969184, -0.18131398, 0.72654339, -0.57123693,
0.03852506, -0.72332085, -1.81844083, -0.33465167, -0.71400352, 0.36883161,
0.08698452, 0.94974586, 0.40087323, -0.05448534, 0.03220677, -1.22494296,
0.97938472, -1.43714454, -0.80430904, -0.08098728, 0.31520301, 0.49642169,
-1.63471091, 0.34390096, 2.81292176, -0.22666528, 1.54559556, -1.51075762};
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3, 2, 3}};
std::vector<float> m2 = {
-0.33170529, 2.26325120, -0.50639461, 0.64802947, 0.44748888, 0.33768068,
-0.53621075, 0.34341460, 0.58742520, -1.13995790, -0.99322535, 0.35447353,
0.01977110, -0.10155016, -1.02288245, -0.16575791, -1.47870374, 0.29300008,
-0.39112198, 1.42303608, -0.02853060, 1.52610164, 0.53540909, 0.75618998,
-0.26877787, -1.90886366, 0.30622790, 0.59794535, 1.29795331, -0.37805803,
-1.58167176, -1.26966832, 0.27435891, 0.89430347, 0.22854926, -0.50317658};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 2}};
auto l1 = p.add_literal(migraphx::literal{m1_shape, m1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, m2});
p.add_instruction(migraphx::op::dot{}, l1, l2);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
std::vector<float> m_res = {0.26735861, -4.30770895, 1.05257728, -1.19954265, 0.50493170,
-0.18729756, 1.09137941, -1.09298312, 3.42956915, -0.41681939,
0.17833257, 0.26040336, 0.15351280, 1.87632715, -0.63545406,
-0.95467340, -1.74728628, -2.42477030, 0.76262372, 0.15539164,
3.32281958, 0.96769613, 0.43727545, 2.43019906};
EXPECT(migraphx::verify_range(m, m_res));
}
TEST_CASE(gemm_mutli_dim1_2_3)
{
migraphx::program p;
std::vector<float> m1 = {
1.23636469, -0.47041261, -0.14375651, -0.48371852, 1.16479301, -0.89361055,
-0.18569086, 1.10700457, -1.02632638, 0.82277012, 0.33525769, 0.52825145,
-1.00141689, 0.45510090, -0.02675039, -0.60454439, 0.38551153, -0.01658514,
0.93059292, -0.54595188, -0.04911005, -0.91397221, -0.83127477, -1.57685603,
-1.36200452, 2.25822236, -1.23416970, 0.12312496, 0.76232760, -0.83594234,
1.67418145, -0.19412936, 1.05261378, 0.66246074, -1.15233398, 0.16429736};
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3, 2, 3}};
std::vector<float> m2 = {
-0.87300530, -0.07112838, 0.19196860, -1.04986840, 1.20348200, 0.31966893,
1.04805440, -2.04777729, -0.67906052, -1.17250760, 0.34305044, -1.01957785,
-1.12694862, 0.18431338, -1.63712290, 0.27566931, -1.11282021, 1.41738919,
0.47871283, -1.01980420, 1.00212436, -0.78740444, -1.65636133, 1.51466547,
-0.12470397, 0.70404393, -0.15244797, 0.74288871, 0.07339926, -1.45811623,
0.27185845, 0.08804596, 0.99061977, -1.61752428, 0.29191159, 0.87271953};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 2}};
std::vector<float> m3 = {-1.07692443, 0.85223457, -0.37266530, 2.31511577, 0.04227017,
1.13229428, -0.52769242, 0.27307182, -0.47779843, -0.08023168,
-0.22862823, 0.81489871, 1.13139581, 1.13860467, 0.24309065,
0.26533729, 0.49106772, -1.18860493, 0.27842449, 1.03568141,
0.49759611, 0.10021662, 0.00592602, 0.90862000};
migraphx::shape m3_shape{migraphx::shape::float_type, {2, 3, 2, 2}};
auto l1 = p.add_literal(migraphx::literal{m1_shape, m1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, m2});
auto l3 = p.add_literal(migraphx::literal{m3_shape, m3});
float alpha = 0.35;
float beta = 0.41;
auto m12_alpha = p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2);
auto l_beta = p.add_literal(beta);
auto b_beta = p.add_instruction(migraphx::op::scalar{m12_alpha->get_shape()}, l_beta);
auto m3_beta = p.add_instruction(migraphx::op::mul{}, b_beta, l3);
p.add_instruction(migraphx::op::add{}, m3_beta, m12_alpha);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
std::vector<float> m_res = {-0.91147203, 0.47540785, -0.30313587, 0.43325099, -0.43711586,
0.50928632, 0.06919868, -0.80382802, -0.05125718, -0.06685650,
-0.06972163, 0.32407764, 0.45677396, 0.25909489, 0.56911252,
-0.17183724, 0.10858734, 0.39406289, 0.04662959, 1.07979824,
0.40355016, 0.52410648, -0.31728447, 1.09550845};
EXPECT(migraphx::verify_range(m, m_res));
}
TEST_CASE(maxpool_test) TEST_CASE(maxpool_test)
{ {
migraphx::program p; migraphx::program p;
......
...@@ -870,7 +870,7 @@ struct test_gemm_transposeab : verify_program<test_gemm_transposeab> ...@@ -870,7 +870,7 @@ struct test_gemm_transposeab : verify_program<test_gemm_transposeab>
} }
}; };
struct gemm_mutli_dim_2 struct gemm_multi_dim_2 : verify_program<gemm_multi_dim_2>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -886,7 +886,127 @@ struct gemm_mutli_dim_2 ...@@ -886,7 +886,127 @@ struct gemm_mutli_dim_2
} }
}; };
struct gemm_mutli_dim_2_3 struct gemm_2args_mm_1 : verify_program<gemm_2args_mm_1>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 4}};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
auto bl2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4}}, l2);
p.add_instruction(migraphx::op::dot{}, l1, bl2);
return p;
}
};
// verify_program case: float operand "1" of shape {2, 2, 3} multiplied by float
// operand "2" of shape {3, 4}, where the second operand is multibroadcast to
// {2, 3, 4} before being fed to a default-constructed dot.
struct gemm_2args_mm_2 : verify_program<gemm_2args_mm_2>
{
    migraphx::program create_program() const
    {
        migraphx::program prog;

        migraphx::shape lhs_shape{migraphx::shape::float_type, {2, 2, 3}};
        migraphx::shape rhs_shape{migraphx::shape::float_type, {3, 4}};

        auto lhs = prog.add_parameter("1", lhs_shape);
        auto rhs = prog.add_parameter("2", rhs_shape);

        // Give the 2-D right operand the same leading dimension as the left one.
        auto rhs_batched = prog.add_instruction(migraphx::op::multibroadcast{{2, 3, 4}}, rhs);
        prog.add_instruction(migraphx::op::dot{}, lhs, rhs_batched);

        return prog;
    }
};
// Verification program: a {1, 2, 3} operand is multibroadcast to {3, 2, 3} and
// then dotted with a {3, 3, 4} operand.
struct gemm_2args_mm_3 : verify_program<gemm_2args_mm_3>
{
    migraphx::program create_program() const
    {
        const migraphx::shape lhs_shape{migraphx::shape::float_type, {1, 2, 3}};
        const migraphx::shape rhs_shape{migraphx::shape::float_type, {3, 3, 4}};

        migraphx::program prog;
        // Instruction order is kept: parameter "1", its broadcast, then parameter "2".
        auto lhs       = prog.add_parameter("1", lhs_shape);
        auto lhs_bcast = prog.add_instruction(migraphx::op::multibroadcast{{3, 2, 3}}, lhs);
        auto rhs       = prog.add_parameter("2", rhs_shape);
        prog.add_instruction(migraphx::op::dot{}, lhs_bcast, rhs);
        return prog;
    }
};
// Verification program: a 2-D {2, 3} operand is multibroadcast to {3, 2, 3} and
// then dotted with a {3, 3, 4} operand.
struct gemm_2args_mm_4 : verify_program<gemm_2args_mm_4>
{
    migraphx::program create_program() const
    {
        const migraphx::shape lhs_shape{migraphx::shape::float_type, {2, 3}};
        const migraphx::shape rhs_shape{migraphx::shape::float_type, {3, 3, 4}};

        migraphx::program prog;
        // Instruction order is kept: parameter "1", its broadcast, then parameter "2".
        auto lhs       = prog.add_parameter("1", lhs_shape);
        auto lhs_bcast = prog.add_instruction(migraphx::op::multibroadcast{{3, 2, 3}}, lhs);
        auto rhs       = prog.add_parameter("2", rhs_shape);
        prog.add_instruction(migraphx::op::dot{}, lhs_bcast, rhs);
        return prog;
    }
};
// Verification program: a {2, 1, 2, 3} operand is multibroadcast to {2, 3, 2, 3}
// and then dotted with a {2, 3, 3, 4} operand.
struct gemm_2args_mm_5 : verify_program<gemm_2args_mm_5>
{
    migraphx::program create_program() const
    {
        const migraphx::shape lhs_shape{migraphx::shape::float_type, {2, 1, 2, 3}};
        const migraphx::shape rhs_shape{migraphx::shape::float_type, {2, 3, 3, 4}};

        migraphx::program prog;
        // Instruction order is kept: parameter "1", its broadcast, then parameter "2".
        auto lhs       = prog.add_parameter("1", lhs_shape);
        auto lhs_bcast = prog.add_instruction(migraphx::op::multibroadcast{{2, 3, 2, 3}}, lhs);
        auto rhs       = prog.add_parameter("2", rhs_shape);
        prog.add_instruction(migraphx::op::dot{}, lhs_bcast, rhs);
        return prog;
    }
};
// Verification program: both operands are multibroadcast before the dot --
// {2, 1, 2, 3} -> {2, 3, 2, 3} and {1, 3, 3, 4} -> {2, 3, 3, 4}.
struct gemm_2args_mm_6 : verify_program<gemm_2args_mm_6>
{
    migraphx::program create_program() const
    {
        const migraphx::shape lhs_shape{migraphx::shape::float_type, {2, 1, 2, 3}};
        const migraphx::shape rhs_shape{migraphx::shape::float_type, {1, 3, 3, 4}};

        migraphx::program prog;
        // Each parameter is immediately followed by its own broadcast instruction.
        auto lhs       = prog.add_parameter("1", lhs_shape);
        auto lhs_bcast = prog.add_instruction(migraphx::op::multibroadcast{{2, 3, 2, 3}}, lhs);
        auto rhs       = prog.add_parameter("2", rhs_shape);
        auto rhs_bcast = prog.add_instruction(migraphx::op::multibroadcast{{2, 3, 3, 4}}, rhs);
        prog.add_instruction(migraphx::op::dot{}, lhs_bcast, rhs_bcast);
        return prog;
    }
};
// Verification program: a 2-D {2, 3} operand is multibroadcast to {2, 3, 2, 3}
// and then dotted with a {2, 3, 3, 4} operand.
struct gemm_2args_mm_7 : verify_program<gemm_2args_mm_7>
{
    migraphx::program create_program() const
    {
        const migraphx::shape lhs_shape{migraphx::shape::float_type, {2, 3}};
        const migraphx::shape rhs_shape{migraphx::shape::float_type, {2, 3, 3, 4}};

        migraphx::program prog;
        // Instruction order is kept: parameter "1", its broadcast, then parameter "2".
        auto lhs       = prog.add_parameter("1", lhs_shape);
        auto lhs_bcast = prog.add_instruction(migraphx::op::multibroadcast{{2, 3, 2, 3}}, lhs);
        auto rhs       = prog.add_parameter("2", rhs_shape);
        prog.add_instruction(migraphx::op::dot{}, lhs_bcast, rhs);
        return prog;
    }
};
struct gemm_multi_dim_2_3 : verify_program<gemm_multi_dim_2_3>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -902,6 +1022,180 @@ struct gemm_mutli_dim_2_3 ...@@ -902,6 +1022,180 @@ struct gemm_mutli_dim_2_3
} }
}; };
// Verification program for two 1-D {8} operands: the first is unsqueezed on
// axis 0, the second on axis 1, they are dotted with alpha = 0.23, and axis 0 of
// the result is squeezed twice.
struct gemm_2args_vv : verify_program<gemm_2args_vv>
{
    migraphx::program create_program() const
    {
        const migraphx::shape lhs_shape{migraphx::shape::float_type, {8}};
        const migraphx::shape rhs_shape{migraphx::shape::float_type, {8}};
        const float alpha = 0.23f;

        migraphx::program prog;
        auto lhs      = prog.add_parameter("1", lhs_shape);
        auto lhs_2d   = prog.add_instruction(migraphx::op::unsqueeze{{0}}, lhs);
        auto rhs      = prog.add_parameter("2", rhs_shape);
        auto rhs_2d   = prog.add_instruction(migraphx::op::unsqueeze{{1}}, rhs);
        auto product  = prog.add_instruction(migraphx::op::dot{alpha}, lhs_2d, rhs_2d);
        // Two successive squeezes on axis 0 strip both axes added above.
        auto squeezed = prog.add_instruction(migraphx::op::squeeze{{0}}, product);
        prog.add_instruction(migraphx::op::squeeze{{0}}, squeezed);
        return prog;
    }
};
// Verification program: a {3, 5} operand dotted with a 1-D {5} operand that is
// unsqueezed on axis 1 first. The dot result is returned without a squeeze.
struct gemm_2args_mv : verify_program<gemm_2args_mv>
{
    migraphx::program create_program() const
    {
        const migraphx::shape mat_shape{migraphx::shape::float_type, {3, 5}};
        const migraphx::shape vec_shape{migraphx::shape::float_type, {5}};

        migraphx::program prog;
        auto mat    = prog.add_parameter("1", mat_shape);
        auto vec    = prog.add_parameter("2", vec_shape);
        auto vec_2d = prog.add_instruction(migraphx::op::unsqueeze{{1}}, vec);
        prog.add_instruction(migraphx::op::dot{}, mat, vec_2d);
        return prog;
    }
};
// Verification program: a {2, 3, 3, 5} operand dotted with a 1-D {5} operand
// that is unsqueezed on axis 1 and then multibroadcast to {2, 3, 5, 1}.
struct gemm_2args_bmv : verify_program<gemm_2args_bmv>
{
    migraphx::program create_program() const
    {
        const migraphx::shape mat_shape{migraphx::shape::float_type, {2, 3, 3, 5}};
        const migraphx::shape vec_shape{migraphx::shape::float_type, {5}};

        migraphx::program prog;
        auto mat       = prog.add_parameter("1", mat_shape);
        auto vec       = prog.add_parameter("2", vec_shape);
        auto vec_2d    = prog.add_instruction(migraphx::op::unsqueeze{{1}}, vec);
        auto vec_bcast = prog.add_instruction(migraphx::op::multibroadcast{{2, 3, 5, 1}}, vec_2d);
        prog.add_instruction(migraphx::op::dot{}, mat, vec_bcast);
        return prog;
    }
};
// Verification program: a 1-D {5} operand is unsqueezed on axis 0, dotted with a
// {5, 4} operand, and axis 0 of the result is squeezed away.
struct gemm_2args_vm : verify_program<gemm_2args_vm>
{
    migraphx::program create_program() const
    {
        const migraphx::shape vec_shape{migraphx::shape::float_type, {5}};
        const migraphx::shape mat_shape{migraphx::shape::float_type, {5, 4}};

        migraphx::program prog;
        // Instruction order is kept: parameter "1", its unsqueeze, then parameter "2".
        auto vec     = prog.add_parameter("1", vec_shape);
        auto vec_2d  = prog.add_instruction(migraphx::op::unsqueeze{{0}}, vec);
        auto mat     = prog.add_parameter("2", mat_shape);
        auto product = prog.add_instruction(migraphx::op::dot{}, vec_2d, mat);
        prog.add_instruction(migraphx::op::squeeze{{0}}, product);
        return prog;
    }
};
// Verification program: a 1-D {5} operand is unsqueezed on axis 0, multibroadcast
// to {2, 2, 1, 5}, dotted with a {2, 2, 5, 4} operand, and axis 2 of the result
// is squeezed away.
struct gemm_2args_vbm : verify_program<gemm_2args_vbm>
{
    migraphx::program create_program() const
    {
        const migraphx::shape vec_shape{migraphx::shape::float_type, {5}};
        const migraphx::shape mat_shape{migraphx::shape::float_type, {2, 2, 5, 4}};

        migraphx::program prog;
        // Instruction order is kept: parameter "1", unsqueeze, broadcast, then parameter "2".
        auto vec       = prog.add_parameter("1", vec_shape);
        auto vec_2d    = prog.add_instruction(migraphx::op::unsqueeze{{0}}, vec);
        auto vec_bcast = prog.add_instruction(migraphx::op::multibroadcast{{2, 2, 1, 5}}, vec_2d);
        auto mat       = prog.add_parameter("2", mat_shape);
        auto product   = prog.add_instruction(migraphx::op::dot{}, vec_bcast, mat);
        prog.add_instruction(migraphx::op::squeeze{{2}}, product);
        return prog;
    }
};
// Verification program: three-input dot on 4-D operands ({2, 3, 2, 3} x
// {2, 3, 3, 2} with a {2, 3, 2, 2} third input), alpha = 0.35 and beta = 0.41.
struct gemm_multi_3args : verify_program<gemm_multi_3args>
{
    migraphx::program create_program() const
    {
        const migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 2, 3}};
        const migraphx::shape b_shape{migraphx::shape::float_type, {2, 3, 3, 2}};
        const migraphx::shape c_shape{migraphx::shape::float_type, {2, 3, 2, 2}};
        // Literals stay double-typed as in the original so the converted float values match.
        const float alpha = 0.35;
        const float beta  = 0.41;

        migraphx::program prog;
        auto a = prog.add_parameter("1", a_shape);
        auto b = prog.add_parameter("2", b_shape);
        auto c = prog.add_parameter("3", c_shape);
        prog.add_instruction(migraphx::op::dot{alpha, beta}, a, b, c);
        return prog;
    }
};
// Verification program: three-input dot on 2-D operands ({2, 3} x {3, 5} with a
// {2, 5} third input), alpha = 0.35 and beta = 0.41.
struct gemm_multi_3args_c25 : verify_program<gemm_multi_3args_c25>
{
    migraphx::program create_program() const
    {
        const migraphx::shape a_shape{migraphx::shape::float_type, {2, 3}};
        const migraphx::shape b_shape{migraphx::shape::float_type, {3, 5}};
        const migraphx::shape c_shape{migraphx::shape::float_type, {2, 5}};
        // Literals stay double-typed as in the original so the converted float values match.
        const float alpha = 0.35;
        const float beta  = 0.41;

        migraphx::program prog;
        auto a = prog.add_parameter("1", a_shape);
        auto b = prog.add_parameter("2", b_shape);
        auto c = prog.add_parameter("3", c_shape);
        prog.add_instruction(migraphx::op::dot{alpha, beta}, a, b, c);
        return prog;
    }
};
// Verification program: three-input dot ({1, 2, 3} x {1, 3, 4} with a {1, 2, 4}
// third input) using alpha = 1 and beta = 0.
struct gemm_multi_3args_beta0 : verify_program<gemm_multi_3args_beta0>
{
    migraphx::program create_program() const
    {
        const migraphx::shape a_shape{migraphx::shape::float_type, {1, 2, 3}};
        const migraphx::shape b_shape{migraphx::shape::float_type, {1, 3, 4}};
        const migraphx::shape c_shape{migraphx::shape::float_type, {1, 2, 4}};
        const float alpha = 1.0f;
        const float beta  = 0.0f;

        migraphx::program prog;
        auto a = prog.add_parameter("1", a_shape);
        auto b = prog.add_parameter("2", b_shape);
        auto c = prog.add_parameter("3", c_shape);
        prog.add_instruction(migraphx::op::dot{alpha, beta}, a, b, c);
        return prog;
    }
};
// Verification program: three-input dot ({1, 2, 3} x {1, 3, 4} with a {1, 2, 4}
// third input) using alpha = 0 and beta = 1.
struct gemm_multi_3args_alpha0 : verify_program<gemm_multi_3args_alpha0>
{
    migraphx::program create_program() const
    {
        const migraphx::shape a_shape{migraphx::shape::float_type, {1, 2, 3}};
        const migraphx::shape b_shape{migraphx::shape::float_type, {1, 3, 4}};
        const migraphx::shape c_shape{migraphx::shape::float_type, {1, 2, 4}};
        const float alpha = 0.0f;
        const float beta  = 1.0f;

        migraphx::program prog;
        auto a = prog.add_parameter("1", a_shape);
        auto b = prog.add_parameter("2", b_shape);
        auto c = prog.add_parameter("3", c_shape);
        prog.add_instruction(migraphx::op::dot{alpha, beta}, a, b, c);
        return prog;
    }
};
struct test_contiguous : verify_program<test_contiguous> struct test_contiguous : verify_program<test_contiguous>
{ {
migraphx::program create_program() const migraphx::program create_program() const
......
matmul-example:{

1
2y"MatMul test_matmulZ
1



Z
2





b
y





B
\ No newline at end of file
matmul-example:_

1
2y"MatMul test_matmulZ
1



Z
2

b
y


B
\ No newline at end of file
matmul-example:W

1
2y"MatMul test_matmulZ
1


Z
2

b
y

B
\ No newline at end of file
matmul-example:_

1
2y"MatMul test_matmulZ
1

Z
2



b
y


B
\ No newline at end of file
matmul-example:W

1
2y"MatMul test_matmulZ
1

Z
2


b
y

B
\ No newline at end of file
matmul-example:S

1
2y"MatMul test_matmulZ
1

Z
2

b
y

B
\ No newline at end of file
...@@ -566,7 +566,8 @@ TEST_CASE(gemm_test) ...@@ -566,7 +566,8 @@ TEST_CASE(gemm_test)
auto t0 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l0); auto t0 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l0);
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1); auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
auto alpha = 2.f; auto alpha = 2.f;
p.add_instruction(migraphx::op::dot{alpha}, t0, t1); auto beta = 2.0f;
p.add_instruction(migraphx::op::dot{alpha, beta}, t0, t1);
auto prog = migraphx::parse_onnx("gemm_test.onnx"); auto prog = migraphx::parse_onnx("gemm_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -575,20 +576,121 @@ TEST_CASE(gemm_test) ...@@ -575,20 +576,121 @@ TEST_CASE(gemm_test)
TEST_CASE(gemm_ex) TEST_CASE(gemm_ex)
{ {
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 6}}); auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 6}});
auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 7}}); auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 7}});
auto l2 = p.add_parameter("3", migraphx::shape{migraphx::shape::float_type, {1, 1, 6, 7}}); auto l2 = p.add_parameter("3", migraphx::shape{migraphx::shape::float_type, {1, 1, 6, 7}});
auto t0 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l0); auto t0 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l0);
auto alpha = 0.5f; auto alpha = 0.5f;
auto res_ab = p.add_instruction(migraphx::op::dot{alpha}, t0, l1); auto beta = 0.8f;
p.add_instruction(migraphx::op::dot{alpha, beta}, t0, l1, l2);
auto prog = migraphx::parse_onnx("gemm_test_ex.onnx");
auto beta = 0.8f; EXPECT(p == prog);
auto l_beta = p.add_literal(beta); }
auto brcst_beta = p.add_instruction(migraphx::op::scalar{l2->get_shape()}, l_beta);
auto res_c = p.add_instruction(migraphx::op::mul{}, l2, brcst_beta);
p.add_instruction(migraphx::op::add{}, res_ab, res_c);
auto prog = migraphx::parse_onnx("gemm_test_ex.onnx"); TEST_CASE(gemm_ex_brcst)
{
migraphx::program p;
auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 6}});
auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 7}});
auto l2 = p.add_parameter("3", migraphx::shape{migraphx::shape::float_type, {1, 1, 6, 1}});
auto t0 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l0);
std::vector<std::size_t> out_lens{1, 1, 6, 7};
auto t2 = p.add_instruction(migraphx::op::multibroadcast{out_lens}, l2);
auto alpha = 0.5f;
auto beta = 0.8f;
p.add_instruction(migraphx::op::dot{alpha, beta}, t0, l1, t2);
auto prog = migraphx::parse_onnx("gemm_test_ex1.onnx");
EXPECT(p == prog);
}
// Builds the expected graph for "matmul_vv.onnx" (two {7} vectors: unsqueeze
// each, dot with alpha = 1 / beta = 0, squeeze axis 0 twice) and compares it
// with the parsed program.
TEST_CASE(matmul_vv)
{
    const migraphx::shape lhs_shape{migraphx::shape::float_type, {7}};
    const migraphx::shape rhs_shape{migraphx::shape::float_type, {7}};

    migraphx::program p;
    auto lhs      = p.add_parameter("1", lhs_shape);
    auto rhs      = p.add_parameter("2", rhs_shape);
    auto lhs_2d   = p.add_instruction(migraphx::op::unsqueeze{{0}}, lhs);
    auto rhs_2d   = p.add_instruction(migraphx::op::unsqueeze{{1}}, rhs);
    auto product  = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, lhs_2d, rhs_2d);
    auto squeezed = p.add_instruction(migraphx::op::squeeze{{0}}, product);
    p.add_instruction(migraphx::op::squeeze{{0}}, squeezed);

    auto prog = migraphx::parse_onnx("matmul_vv.onnx");
    EXPECT(p == prog);
}
// Builds the expected graph for "matmul_vm.onnx" ({7} vector times {7, 8}
// matrix: unsqueeze axis 0, dot with alpha = 1 / beta = 0, squeeze axis 0) and
// compares it with the parsed program.
TEST_CASE(matmul_vm)
{
    const migraphx::shape vec_shape{migraphx::shape::float_type, {7}};
    const migraphx::shape mat_shape{migraphx::shape::float_type, {7, 8}};

    migraphx::program p;
    auto vec     = p.add_parameter("1", vec_shape);
    auto mat     = p.add_parameter("2", mat_shape);
    auto vec_2d  = p.add_instruction(migraphx::op::unsqueeze{{0}}, vec);
    auto product = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, vec_2d, mat);
    p.add_instruction(migraphx::op::squeeze{{0}}, product);

    auto prog = migraphx::parse_onnx("matmul_vm.onnx");
    EXPECT(p == prog);
}
// Builds the expected graph for "matmul_vbm.onnx" and compares it with the
// parsed program: a {7} vector is unsqueezed on axis 0, multibroadcast to
// {5, 1, 7}, dotted (alpha = 1, beta = 0) with a {5, 7, 8} operand, and axis 1
// of the result is squeezed away.
//
// The leftover "ONNX_TEST" / "After Dot" debug prints were removed: they only
// polluted the test runner's stdout and forced a flush on each call.
TEST_CASE(matmul_vbm)
{
    migraphx::program p;
    auto l0   = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {7}});
    auto l1   = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {5, 7, 8}});
    auto sl0  = p.add_instruction(migraphx::op::unsqueeze{{0}}, l0);
    auto bsl0 = p.add_instruction(migraphx::op::multibroadcast{{5, 1, 7}}, sl0);
    auto res  = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, bsl0, l1);
    p.add_instruction(migraphx::op::squeeze{{1}}, res);

    auto prog = migraphx::parse_onnx("matmul_vbm.onnx");
    EXPECT(p == prog);
}
// Builds the expected graph for "matmul_mv.onnx" ({6, 7} matrix times {7}
// vector: unsqueeze axis 1, dot with alpha = 1 / beta = 0, squeeze axis 1) and
// compares it with the parsed program.
TEST_CASE(matmul_mv)
{
    const migraphx::shape mat_shape{migraphx::shape::float_type, {6, 7}};
    const migraphx::shape vec_shape{migraphx::shape::float_type, {7}};

    migraphx::program p;
    auto mat     = p.add_parameter("1", mat_shape);
    auto vec     = p.add_parameter("2", vec_shape);
    auto vec_2d  = p.add_instruction(migraphx::op::unsqueeze{{1}}, vec);
    auto product = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, mat, vec_2d);
    p.add_instruction(migraphx::op::squeeze{{1}}, product);

    auto prog = migraphx::parse_onnx("matmul_mv.onnx");
    EXPECT(p == prog);
}
// Builds the expected graph for "matmul_bmv.onnx" ({3, 6, 7} operand times {7}
// vector: unsqueeze axis 1, multibroadcast to {3, 7, 1}, dot with alpha = 1 /
// beta = 0, squeeze axis 2) and compares it with the parsed program.
TEST_CASE(matmul_bmv)
{
    const migraphx::shape mat_shape{migraphx::shape::float_type, {3, 6, 7}};
    const migraphx::shape vec_shape{migraphx::shape::float_type, {7}};

    migraphx::program p;
    auto mat       = p.add_parameter("1", mat_shape);
    auto vec       = p.add_parameter("2", vec_shape);
    auto vec_2d    = p.add_instruction(migraphx::op::unsqueeze{{1}}, vec);
    auto vec_bcast = p.add_instruction(migraphx::op::multibroadcast{{3, 7, 1}}, vec_2d);
    auto product   = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, mat, vec_bcast);
    p.add_instruction(migraphx::op::squeeze{{2}}, product);

    auto prog = migraphx::parse_onnx("matmul_bmv.onnx");
    EXPECT(p == prog);
}
TEST_CASE(matmul_bmbm)
{
migraphx::program p;
auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 6, 7}});
auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {5, 2, 1, 7, 8}});
auto bl0 = p.add_instruction(migraphx::op::multibroadcast{{5, 2, 3, 6, 7}}, l0);
auto bl1 = p.add_instruction(migraphx::op::multibroadcast{{5, 2, 3, 7, 8}}, l1);
p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, bl0, bl1);
auto prog = migraphx::parse_onnx("matmul_bmbm.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
......
...@@ -371,21 +371,71 @@ TEST_CASE(logsoftmax) ...@@ -371,21 +371,71 @@ TEST_CASE(logsoftmax)
} }
} }
TEST_CASE(dot) // 2 inputs arguments
TEST_CASE(matmul)
{ {
{ {
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}}; migraphx::shape s_m1{migraphx::shape::float_type, {5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}}; migraphx::shape s_m2{migraphx::shape::float_type, {5}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 2}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 4}};
expect_shape( expect_shape(
migraphx::shape{migraphx::shape::float_type, {4, 8}}, migraphx::op::dot{}, s_m1, s_m2); migraphx::shape{migraphx::shape::float_type, {1, 4}}, migraphx::op::dot{}, s_m1, s_m2);
} }
{ {
migraphx::shape s_m1{migraphx::shape::float_type, {4, 6}}; migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}}; migraphx::shape s_m2{migraphx::shape::float_type, {4, 4}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {6, 5, 4}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2); throws_shape(migraphx::op::dot{}, s_m1, s_m2);
} }
{
migraphx::shape s_m1{migraphx::shape::float_type, {6, 1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {6, 5, 4}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {6, 1, 4}},
migraphx::op::dot{},
s_m1,
s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 6, 1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 6, 5, 4}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 6, 1, 4}},
migraphx::op::dot{},
s_m1,
s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {4, 8}}, migraphx::op::dot{}, s_m1, s_m2);
}
{ {
migraphx::shape s_m1{migraphx::shape::float_type, {1, 1}}; migraphx::shape s_m1{migraphx::shape::float_type, {1, 1}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 1}}; migraphx::shape s_m2{migraphx::shape::float_type, {1, 1}};
...@@ -403,45 +453,104 @@ TEST_CASE(dot) ...@@ -403,45 +453,104 @@ TEST_CASE(dot)
} }
{ {
migraphx::shape s_m1{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {2, 3, 5, 7}}; migraphx::shape s_m2{migraphx::shape::float_type, {1, 1, 5, 7}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 7}}, throws_shape(migraphx::op::dot{}, s_m1, s_m2);
migraphx::op::dot{},
s_m1,
s_m2);
} }
{ {
migraphx::shape s_m1{migraphx::shape::float_type, {1, 1, 4, 5}}; migraphx::shape s_m1{migraphx::shape::float_type, {1, 1, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 1, 5, 7}}; migraphx::shape s_m2{migraphx::shape::float_type, {1, 2, 5, 7}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1, 4, 7}}, throws_shape(migraphx::op::dot{}, s_m1, s_m2);
}
}
// 3 input arguments
TEST_CASE(gemm)
{
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {1}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2, s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {1, 1}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2, s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {8}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2, s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {4, 1}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2, s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 6}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {4, 8}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2, s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {4}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2, s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {4, 8}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 8}},
migraphx::op::dot{}, migraphx::op::dot{},
s_m1, s_m1,
s_m2); s_m2,
s_m3);
} }
{ {
migraphx::shape s_m1{migraphx::shape::float_type, {3, 1, 4, 6}}; migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {3, 1, 5, 7}}; migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2); migraphx::shape s_m3{migraphx::shape::float_type, {1, 4, 8}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 4, 8}},
migraphx::op::dot{},
s_m1,
s_m2,
s_m3);
} }
{ {
migraphx::shape s_m1{migraphx::shape::float_type, {2, 2, 4, 5}}; migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 6}};
migraphx::shape s_m2{migraphx::shape::float_type, {3, 2, 5, 7}}; migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2); migraphx::shape s_m3{migraphx::shape::float_type, {1, 4, 8}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2, s_m3);
} }
{ {
migraphx::shape s_m1{migraphx::shape::float_type, {1, 1, 4, 5}}; migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 2, 5, 7}}; migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2); migraphx::shape s_m3{migraphx::shape::float_type, {4, 8}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2, s_m3);
} }
{ {
migraphx::shape s_m1{migraphx::shape::float_type, {1, 2, 4, 5}}; migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {2, 1, 5, 7}}; migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2); migraphx::shape s_m3{migraphx::shape::float_type};
throws_shape(migraphx::op::dot{}, s_m1, s_m2, s_m3);
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment