"src/targets/gpu/device/sqdiff.cpp" did not exist on "96358e41cc883791c8d3ad50280bea4871a18000"
Commit 5b4bb22c authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Extend the GEMM (dot) implementation to support a third input argument C, so the operation computes alpha * A * B + beta * C.

parent 5ded35b0
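
With this change, migraphx::op::dot accepts an optional third input C and computes alpha * A * B + beta * C over a stack of matrices (the op's default beta also moves from 0.0 to 1.0, as the first hunk shows). A minimal usage sketch, modeled on the test case added below; the header paths and the helper name make_gemm_3args are assumptions, not part of the commit:

    // Sketch: build a program that uses the three-argument dot (GEMM with a C input).
    // Shapes and alpha/beta mirror the new test case; everything else is illustrative.
    #include <migraphx/program.hpp>
    #include <migraphx/operators.hpp>

    migraphx::program make_gemm_3args()
    {
        migraphx::program p;
        migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 2, 3}};
        migraphx::shape b_shape{migraphx::shape::float_type, {2, 3, 3, 2}};
        migraphx::shape c_shape{migraphx::shape::float_type, {2, 3, 2, 2}};
        auto a = p.add_parameter("a", a_shape);
        auto b = p.add_parameter("b", b_shape);
        auto c = p.add_parameter("c", c_shape);
        float alpha = 0.35;
        float beta  = 0.41;
        // result = alpha * a * b + beta * c, batched over the leading {2, 3} dims
        p.add_instruction(migraphx::op::dot{alpha, beta}, a, b, c);
        return p;
    }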
@@ -810,7 +810,7 @@ struct gather
 struct dot
 {
     float alpha = 1.0;
-    float beta = 0.0;
+    float beta = 1.0;
     template <class Self, class F>
     static auto reflect(Self& self, F f)
@@ -821,7 +821,7 @@ struct dot
     std::string name() const { return "dot"; }
     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(2).same_type();
+        check_shapes{{inputs[0], inputs[1]}, *this}.has(2).same_type();
         const shape& a = inputs.at(0);
         const shape& b = inputs.at(1);
         auto t = a.type();
@@ -831,14 +831,40 @@ struct dot
         // as long as dim values are the same in the two inputs
         if(!std::equal(a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2))
         {
-            MIGRAPHX_THROW("DOT: dim values mismatch");
+            MIGRAPHX_THROW("DOT: number of matrices in stack are different in A and B");
         }
+        if (inputs.size() == 3)
+        {
+            check_shapes{{inputs[0], inputs[2]}, *this}.has(2).same_type();
+            const shape& c = inputs.at(2);
+            if(!std::equal(a.lens().rbegin() + 2, a.lens().rend(), c.lens().rbegin() + 2))
+            {
+                MIGRAPHX_THROW("DOT: number of matrices in stack are different in A and C");
+            }
+        }
         std::size_t dim_0 = a.lens().size() - 2;
         std::size_t dim_1 = a.lens().size() - 1;
         if(a.lens()[dim_1] != b.lens()[dim_0])
-            MIGRAPHX_THROW("Inner dimensions do not match: {" + to_string_range(a.lens()) +
+            MIGRAPHX_THROW("DOT : inner dimensions do not match: {" + to_string_range(a.lens()) +
                            "} x {" + to_string_range(b.lens()) + "}");
+        if (inputs.size() == 3)
+        {
+            const shape& c = inputs.at(2);
+            if (a.lens()[dim_0] != c.lens()[dim_0])
+            {
+                MIGRAPHX_THROW("DOT : matrix size does not match: A: {" + to_string_range(a.lens()) +
+                               "}, C: {" + to_string_range(c.lens()) + "}");
+            }
+            if (b.lens()[dim_1] != c.lens()[dim_1])
+            {
+                MIGRAPHX_THROW("DOT : matrix size does not match: B: {" + to_string_range(b.lens()) +
+                               "}, C: {" + to_string_range(c.lens()) + "}");
+            }
+        }
         auto out_lens = a.lens();
         out_lens[dim_1] = b.lens()[dim_1];
         return {t, out_lens};
......
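
For reference, the shape rules compute_shape now enforces: A and B (and C, when present) must carry the same leading stack dims, A's last dim must equal B's second-to-last dim, and C must match the M x N output. A standalone sketch of the same checks on plain dimension vectors (not MIGraphX types; the helper name is made up):

    // Standalone sketch of the dot shape rules, including the optional C input.
    #include <algorithm>
    #include <cassert>
    #include <cstddef>
    #include <vector>

    std::vector<std::size_t> dot_output_dims(const std::vector<std::size_t>& a,
                                             const std::vector<std::size_t>& b,
                                             const std::vector<std::size_t>* c = nullptr)
    {
        std::size_t nd = a.size(); // assume nd >= 2 and b, c have the same rank
        // leading (stack) dims of A and B must match
        assert(std::equal(a.begin(), a.end() - 2, b.begin()));
        // inner dims must match: A is ... x M x K, B is ... x K x N
        assert(a[nd - 1] == b[nd - 2]);
        auto out    = a;
        out[nd - 1] = b[nd - 1]; // output is ... x M x N
        if(c != nullptr)
        {
            assert(std::equal(a.begin(), a.end() - 2, c->begin())); // same stack dims
            assert((*c)[nd - 2] == a[nd - 2]);                      // C rows == M
            assert((*c)[nd - 1] == b[nd - 1]);                      // C cols == N
        }
        return out;
    }
    // e.g. {2, 3, 2, 3} x {2, 3, 3, 2} with C = {2, 3, 2, 2} gives {2, 3, 2, 2}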
@@ -55,7 +55,15 @@ void migemm_impl(tensor_view<T> cmat,
     visit_mat(amat, [&](const auto& a) {
         visit_mat(bmat, [&](const auto& b) {
             auto c = make_mat(cmat);
-            c = (a * b) * alpha + beta * c;
+            if (beta != 0.0)
+            {
+                c = beta * c;
+            }
+            if (alpha != 0.0)
+            {
+                c = c + alpha * a * b;
+            }
         });
     });
 }
......
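
The split into separate beta and alpha updates follows the usual BLAS convention: when beta == 0 the existing contents of C are not read at all (so an uninitialized output buffer cannot inject garbage into the result), and when alpha == 0 the matrix product is skipped. A minimal scalar reference of that contract, assuming row-major M x K, K x N and M x N buffers (ref_gemm is a made-up name):

    // Minimal row-major reference GEMM honoring the beta == 0 / alpha == 0 guards.
    #include <cstddef>
    #include <vector>

    void ref_gemm(std::size_t m, std::size_t n, std::size_t k,
                  float alpha, const std::vector<float>& a, const std::vector<float>& b,
                  float beta, std::vector<float>& c)
    {
        for(std::size_t i = 0; i < m; i++)
        {
            for(std::size_t j = 0; j < n; j++)
            {
                // beta == 0 means "do not read C", not "add 0 * C"
                float acc = (beta == 0.0f) ? 0.0f : beta * c[i * n + j];
                if(alpha != 0.0f)
                {
                    for(std::size_t p = 0; p < k; p++)
                        acc += alpha * a[i * k + p] * b[p * n + j];
                }
                c[i * n + j] = acc;
            }
        }
    }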
@@ -374,6 +374,22 @@ struct cpu_gemm
     argument compute(context&, const shape& output_shape, std::vector<argument> args) const
     {
         argument result{output_shape};
+        // if there is a C input
+        if (args.size() == 3)
+        {
+            result.visit([&](auto output) {
+                args[2].visit([&](auto input) {
+                    std::copy(input.begin(), input.end(), output.begin());
+                });
+            });
+        }
+        else
+        {
+            result.visit([&](auto output) {
+                std::fill(output.begin(), output.end(), 0);
+            });
+        }
         migemm(result, args[0], args[1], op.alpha, op.beta);
         return result;
     }
......
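
Because result is a buffer separate from the C argument, the CPU path seeds it before the GEMM: with a copy of C when three inputs are given, with zeros otherwise, so the single migemm(result, A, B, alpha, beta) call scales whatever is already in the output. The same pattern on plain buffers, reusing ref_gemm from the sketch above (gemm_with_c is a made-up name):

    // Seed the output with C (or zeros), then accumulate alpha*A*B + beta*out in place.
    #include <algorithm>
    #include <cstddef>
    #include <vector>

    // defined in the earlier sketch
    void ref_gemm(std::size_t m, std::size_t n, std::size_t k,
                  float alpha, const std::vector<float>& a, const std::vector<float>& b,
                  float beta, std::vector<float>& c);

    std::vector<float> gemm_with_c(std::size_t m, std::size_t n, std::size_t k,
                                   float alpha, const std::vector<float>& a,
                                   const std::vector<float>& b, float beta,
                                   const std::vector<float>* c)
    {
        std::vector<float> out(m * n, 0.0f);
        if(c != nullptr)
            std::copy(c->begin(), c->end(), out.begin()); // out starts as C
        ref_gemm(m, n, k, alpha, a, b, beta, out);        // out = alpha*A*B + beta*out
        return out;
    }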
@@ -90,15 +90,12 @@ rocblas_half to_rocblas_type(half x) { return reinterpret_cast<const rocblas_hal
 shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
 {
     check_shapes{inputs, *this}.has(3);
-    return op.compute_shape({inputs.at(0), inputs.at(1)});
+    return op.compute_shape(inputs);
 }
 argument miopen_gemm::compute(context& ctx,
                               const shape& output_shape,
                               const std::vector<argument>& args) const
 {
-    float alpha = 1.0f;
-    float beta = 0.0f;
     bool transa = args[0].get_shape().transposed();
     bool transb = args[1].get_shape().transposed();
     std::size_t n_dims = args[0].get_shape().lens().size();
@@ -113,9 +110,19 @@ argument miopen_gemm::compute(context& ctx,
     rocblas_int k = args[0].get_shape().lens()[dim_1];
     auto batch_num = std::accumulate(
         out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
+    bool is_3inputs = (args.size() == 4);
+    output_shape.visit_type([&](auto as) {
+        auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
+        if (is_3inputs)
+            hipMemcpy(to_pointer(args[3]), to_pointer(args[2]), output_shape.bytes(), hipMemcpyDeviceToDevice);
+        else
+            hipMemset(to_pointer(args[2]), 0, output_shape.bytes());
+    });
     output_shape.visit_type([&](auto as) {
-        auto alpha_r = to_rocblas_type(as(alpha));
-        auto beta_r = to_rocblas_type(as(beta));
+        auto alpha_r = to_rocblas_type(as(op.alpha));
+        auto beta_r = to_rocblas_type(as(op.beta));
         auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
         generic_rocblas_batched_gemm(as,
                                      ctx.get_stream().get_rocblas(),
@@ -132,14 +139,14 @@ argument miopen_gemm::compute(context& ctx,
                                      lda,
                                      m * k,
                                      &beta_r,
-                                     to_pointer(args[2]),
+                                     is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
                                      ldc,
                                      m * n,
                                      batch_num);
     });
-    return args[2];
+    return (is_3inputs ? args[3] : args[2]);
 }
 } // namespace gpu
......
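
The GPU path does the same seeding on the device: with a C input the commit copies it into the output allocation via a device-to-device hipMemcpy, otherwise the output is zero-filled with hipMemset, and the rocBLAS call then applies op.beta to that buffer. A HIP-only sketch of the seeding step, assuming raw float device pointers (the function name and the omitted error handling are not from the commit):

    // Sketch: prepare the GPU output buffer before a batched GEMM that applies beta.
    #include <hip/hip_runtime.h>
    #include <cstddef>

    void seed_gemm_output(float* out, const float* c, std::size_t bytes)
    {
        if(c != nullptr)
            hipMemcpy(out, c, bytes, hipMemcpyDeviceToDevice); // out starts as a copy of C
        else
            hipMemset(out, 0, bytes); // out starts as zeros, so beta * out == 0
        // ...then launch the batched GEMM with (alpha, beta) writing into `out`.
    }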
@@ -1112,6 +1112,53 @@ TEST_CASE(gemm_mutli_dim1_2_3)
     EXPECT(migraphx::verify_range(m, m_res));
 }
+TEST_CASE(gemm_mutli_3args)
+{
+    migraphx::program p;
+    std::vector<float> m1 = {
+        1.23636469, -0.47041261, -0.14375651, -0.48371852, 1.16479301, -0.89361055,
+        -0.18569086, 1.10700457, -1.02632638, 0.82277012, 0.33525769, 0.52825145,
+        -1.00141689, 0.45510090, -0.02675039, -0.60454439, 0.38551153, -0.01658514,
+        0.93059292, -0.54595188, -0.04911005, -0.91397221, -0.83127477, -1.57685603,
+        -1.36200452, 2.25822236, -1.23416970, 0.12312496, 0.76232760, -0.83594234,
+        1.67418145, -0.19412936, 1.05261378, 0.66246074, -1.15233398, 0.16429736};
+    migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3, 2, 3}};
+    std::vector<float> m2 = {
+        -0.87300530, -0.07112838, 0.19196860, -1.04986840, 1.20348200, 0.31966893,
+        1.04805440, -2.04777729, -0.67906052, -1.17250760, 0.34305044, -1.01957785,
+        -1.12694862, 0.18431338, -1.63712290, 0.27566931, -1.11282021, 1.41738919,
+        0.47871283, -1.01980420, 1.00212436, -0.78740444, -1.65636133, 1.51466547,
+        -0.12470397, 0.70404393, -0.15244797, 0.74288871, 0.07339926, -1.45811623,
+        0.27185845, 0.08804596, 0.99061977, -1.61752428, 0.29191159, 0.87271953};
+    migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 2}};
+    std::vector<float> m3 = {-1.07692443, 0.85223457, -0.37266530, 2.31511577, 0.04227017,
+                             1.13229428, -0.52769242, 0.27307182, -0.47779843, -0.08023168,
+                             -0.22862823, 0.81489871, 1.13139581, 1.13860467, 0.24309065,
+                             0.26533729, 0.49106772, -1.18860493, 0.27842449, 1.03568141,
+                             0.49759611, 0.10021662, 0.00592602, 0.90862000};
+    migraphx::shape m3_shape{migraphx::shape::float_type, {2, 3, 2, 2}};
+    auto l1 = p.add_literal(migraphx::literal{m1_shape, m1});
+    auto l2 = p.add_literal(migraphx::literal{m2_shape, m2});
+    auto l3 = p.add_literal(migraphx::literal{m3_shape, m3});
+    float alpha = 0.35;
+    float beta = 0.41;
+    p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3);
+    p.compile(migraphx::cpu::target{});
+    auto result = p.eval({});
+    std::vector<float> m;
+    result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
+    std::vector<float> m_res = {-0.91147203, 0.47540785, -0.30313587, 0.43325099, -0.43711586,
+                                0.50928632, 0.06919868, -0.80382802, -0.05125718, -0.06685650,
+                                -0.06972163, 0.32407764, 0.45677396, 0.25909489, 0.56911252,
+                                -0.17183724, 0.10858734, 0.39406289, 0.04662959, 1.07979824,
+                                0.40355016, 0.52410648, -0.31728447, 1.09550845};
+    EXPECT(migraphx::verify_range(m, m_res));
+}
 TEST_CASE(maxpool_test)
 {
     migraphx::program p;
......
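
The expected m_res values can be reproduced with a plain reference computation: the leading {2, 3} dims flatten to 6 independent 2x3 * 3x2 products, each scaled by alpha = 0.35 and combined with beta = 0.41 times the matching 2x2 slice of C. A standalone sketch of that reference (row-major flattening; batched_gemm_ref is a made-up name):

    // Reference for the 3-argument test: out = 0.35 * A * B + 0.41 * C per stacked matrix.
    #include <cstddef>
    #include <vector>

    std::vector<float> batched_gemm_ref(const std::vector<float>& a, // batch x 2 x 3
                                        const std::vector<float>& b, // batch x 3 x 2
                                        const std::vector<float>& c, // batch x 2 x 2
                                        std::size_t batch, float alpha, float beta)
    {
        const std::size_t m = 2, k = 3, n = 2;
        std::vector<float> out(batch * m * n);
        for(std::size_t bi = 0; bi < batch; bi++)
            for(std::size_t i = 0; i < m; i++)
                for(std::size_t j = 0; j < n; j++)
                {
                    float acc = 0.0f;
                    for(std::size_t p = 0; p < k; p++)
                        acc += a[bi * m * k + i * k + p] * b[bi * k * n + p * n + j];
                    out[bi * m * n + i * n + j] =
                        alpha * acc + beta * c[bi * m * n + i * n + j];
                }
        return out;
    }
    // For the test above, batch = 6 (the {2, 3} leading dims flattened).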
@@ -882,6 +882,26 @@ struct gemm_mutli_dim_2_3
     }
 };
+struct gemm_mutli_3args
+{
+    migraphx::program create_program() const
+    {
+        migraphx::program p;
+        migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3, 2, 3}};
+        migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 2}};
+        migraphx::shape m3_shape{migraphx::shape::float_type, {2, 3, 2, 2}};
+        auto l1 = p.add_parameter("1", m1_shape);
+        auto l2 = p.add_parameter("2", m2_shape);
+        auto l3 = p.add_parameter("3", m3_shape);
+        float alpha = 0.35;
+        float beta = 0.41;
+        p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3);
+        return p;
+    }
+};
 struct test_contiguous
 {
     migraphx::program create_program() const
@@ -3016,6 +3036,7 @@ int main()
     verify_program<test_gemm_transposeab>();
     verify_program<gemm_mutli_dim_2>();
     verify_program<gemm_mutli_dim_2_3>();
+    verify_program<gemm_mutli_3args>();
     verify_program<test_contiguous>();
     verify_program<test_eliminate_contiguous>();
     verify_program<test_transpose>();
......