Commit c43eba64 authored by Shucai Xiao
Browse files

Changes to support the seq2seq translation example: allow dot/gemm inputs with leading singleton dimensions (e.g. {1, 1, 3, 5} x {1, 1, 5, 6}).

parent af90a792
......@@ -843,10 +843,31 @@ struct dot
const shape& b = inputs.at(1);
auto t = a.type();
if(a.lens()[1] != b.lens()[0])
// Changed to support cases like {1, 1, 3, 5} x {1, 1, 5, 6},
// which NumPy can handle: as long as every dimension before the
// last two is 1, the two matrices can be multiplied.
if (std::any_of(a.lens().rbegin() + 2, a.lens().rend(), [](auto i) {
return (i != 1);
}))
{
MIGRAPHX_THROW("DOT: first matrix, dimensions before matrix dims must be 1");
}
if (std::any_of(b.lens().rbegin() + 2, b.lens().rend(), [](auto i) {
return (i != 1);
}))
{
MIGRAPHX_THROW("DOT: second matrix, dimensions before matrix dims must be 1");
}
std::size_t n_dims = a.lens().size();
if(a.lens()[n_dims - 1] != b.lens()[n_dims - 2])
MIGRAPHX_THROW("Inner dimensions do not match: {" + to_string_range(a.lens()) +
"} x {" + to_string_range(b.lens()) + "}");
return {t, {a.lens()[0], b.lens()[1]}};
auto out_lens = a.lens();
out_lens[n_dims - 1] = b.lens()[n_dims - 1];
return {t, out_lens};
}
};
......
......@@ -470,17 +470,6 @@ struct onnx_parser
transb = parse_value(attributes.at("transB")).at<bool>();
}
// beginning or end of both args have dimension 1, need to squeeze
// before calling gemm, then doing unsqueeze after getting results
std::size_t num_squeeze = args[0]->get_shape().lens().size();
if(num_squeeze > 2)
{
std::vector<int64_t> vec_axises(num_squeeze - 2);
std::iota(vec_axises.begin(), vec_axises.end(), 0);
args[0] = prog.add_instruction(op::squeeze{vec_axises}, args[0]);
args[1] = prog.add_instruction(op::squeeze{vec_axises}, args[1]);
}
std::vector<int64_t> perm = {1, 0};
auto l1 = (transa) ? prog.add_instruction(op::transpose{perm}, args[0]) : args[0];
auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1];
......@@ -489,13 +478,6 @@ struct onnx_parser
if(beta != 0.f)
{
auto l3 = prog.add_instruction(op::dot{alpha}, l1, l2);
if(num_squeeze > 2)
{
std::vector<int64_t> vec_axises(num_squeeze - 2);
std::iota(vec_axises.begin(), vec_axises.end(), 0);
l3 = prog.add_instruction(op::unsqueeze{vec_axises}, l3);
}
auto l4 = args[2];
if(l4->get_shape().scalar()) // ignore args[2] (no C value added to alpha*A*B)
return l3;
......@@ -510,12 +492,6 @@ struct onnx_parser
}
auto dot_res = prog.add_instruction(op::dot{alpha, beta}, l1, l2);
if(num_squeeze > 2)
{
std::vector<int64_t> vec_axises(num_squeeze - 2);
std::iota(vec_axises.begin(), vec_axises.end(), 0);
dot_res = prog.add_instruction(op::unsqueeze{vec_axises}, dot_res);
}
return dot_res;
}
......
......@@ -14,10 +14,13 @@ template <class T>
static auto make_mat(tensor_view<T> x)
{
const auto& s = x.get_shape();
assert(s.lens().size() == 2);
//assert(s.lens().size() == 2);
std::size_t n_dims = s.lens().size();
std::size_t dim_0 = n_dims - 2;
std::size_t dim_1 = n_dims - 1;
if(s.transposed())
return matrix<T>{x.data(), s.lens()[1], s.lens()[0], s.strides()[1]};
return matrix<T>{x.data(), s.lens()[0], s.lens()[1], s.strides()[0]};
return matrix<T>{x.data(), s.lens()[dim_1], s.lens()[dim_0], s.strides()[dim_1]};
return matrix<T>{x.data(), s.lens()[dim_0], s.lens()[dim_1], s.strides()[dim_0]};
}
template <class T, class F>
......@@ -64,13 +67,16 @@ void migemm_impl(tensor_view<T> cmat,
float beta,
std::false_type)
{
auto m = cmat.get_shape().lens()[0];
auto n = cmat.get_shape().lens()[1];
auto k = amat.get_shape().lens()[1];
std::size_t n_dims = cmat.get_shape().lens().size();
std::size_t dim_0 = n_dims - 2;
std::size_t dim_1 = n_dims - 1;
auto m = cmat.get_shape().lens()[dim_0];
auto n = cmat.get_shape().lens()[dim_1];
auto k = amat.get_shape().lens()[dim_1];
assert(amat.get_shape().lens()[1] == bmat.get_shape().lens()[0]);
assert(m == amat.get_shape().lens()[0]);
assert(n == bmat.get_shape().lens()[1]);
assert(amat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_0]);
assert(m == amat.get_shape().lens()[dim_0]);
assert(n == bmat.get_shape().lens()[dim_1]);
dfor(m, n)([&](auto ii, auto jj) {
double s = cmat(ii, jj) * beta;
......
......@@ -80,12 +80,15 @@ argument miopen_gemm::compute(context& ctx,
float beta = 0.0f;
bool transa = args[0].get_shape().transposed();
bool transb = args[1].get_shape().transposed();
rocblas_int lda = args[0].get_shape().strides()[transa ? 1 : 0];
rocblas_int ldb = args[1].get_shape().strides()[transb ? 1 : 0];
rocblas_int ldc = args[2].get_shape().strides()[0];
rocblas_int m = output_shape.lens()[0];
rocblas_int n = output_shape.lens()[1];
rocblas_int k = args[0].get_shape().lens()[1];
std::size_t n_dims = args[0].get_shape().lens().size();
std::size_t dim_0 = n_dims - 2;
std::size_t dim_1 = n_dims - 1;
rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
rocblas_int ldc = args[2].get_shape().strides()[dim_0];
rocblas_int m = output_shape.lens()[dim_0];
rocblas_int n = output_shape.lens()[dim_1];
rocblas_int k = args[0].get_shape().lens()[dim_1];
output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(alpha));
auto beta_r = to_rocblas_type(as(beta));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment