"docs/source/nas/execution_engine.rst" did not exist on "0247be5e6071f224690729bd3b33a7d0675e0c71"
Commit aaad44e6 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

fix review comments

parents 1b895f4a 291762b7
...@@ -167,39 +167,27 @@ rb_type<T>* to_rocblas_type(T* x) ...@@ -167,39 +167,27 @@ rb_type<T>* to_rocblas_type(T* x)
rocblas_half to_rocblas_type(half x) { return reinterpret_cast<const rocblas_half&>(x); } rocblas_half to_rocblas_type(half x) { return reinterpret_cast<const rocblas_half&>(x); }
shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const void miopen_gemm::batch_not_transposed(const std::vector<std::size_t>& strides) const
{ {
std::vector<shape> input_shapes(inputs.begin(), inputs.begin() + inputs.size() - 1); if(strides.size() <= 2)
check_shapes{input_shapes}.not_broadcasted(); return;
auto a_strides = inputs[0].strides(); auto dim_0 = strides.size() - 2;
if(a_strides.size() > 2) auto matrix_size = std::max(strides[dim_0], strides[dim_0 + 1]);
std::vector<std::size_t> batch(strides.begin(), strides.begin() + dim_0);
if(std::adjacent_find(batch.begin(), batch.end(), [&](auto i, auto j) {
return (i < j or i < matrix_size or j < matrix_size);
}) != batch.end())
{ {
auto dim_1 = a_strides.size() - 1; MIGRAPHX_THROW("DOT: batch size {" + to_string_range(strides) + "} is transposed!");
auto dim_0 = dim_1 - 1;
auto matrix_size = std::max(a_strides[dim_0], a_strides[1]);
if(std::adjacent_find(a_strides.begin(), a_strides.begin() + dim_0, [&](auto i, auto j) {
return (i < j or i < matrix_size or j < matrix_size);
}) != a_strides.begin() + dim_0)
{
MIGRAPHX_THROW("DOT: batch size of a {" + to_string_range(a_strides) +
"} is transposed!");
}
} }
}
auto b_strides = inputs[1].strides(); shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
if(b_strides.size() > 2) {
{ std::vector<shape> input_shapes(inputs.begin(), inputs.begin() + inputs.size() - 1);
auto dim_1 = b_strides.size() - 1; check_shapes{input_shapes}.not_broadcasted();
auto dim_0 = dim_1 - 1; batch_not_transposed(inputs[0].strides());
auto matrix_size = std::max(b_strides[dim_0], b_strides[1]); batch_not_transposed(inputs[1].strides());
if(std::adjacent_find(b_strides.begin(), b_strides.begin() + dim_0, [&](auto i, auto j) {
return (i < j or i < matrix_size or j < matrix_size);
}) != b_strides.begin() + dim_0)
{
MIGRAPHX_THROW("DOT: batch size of b {" + to_string_range(b_strides) +
"} is transposed!");
}
}
return op.compute_shape(input_shapes); return op.compute_shape(input_shapes);
} }
......
...@@ -24,6 +24,7 @@ struct miopen_gemm ...@@ -24,6 +24,7 @@ struct miopen_gemm
shape compute_shape(const std::vector<shape>& inputs) const; shape compute_shape(const std::vector<shape>& inputs) const;
argument argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const; compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
void batch_not_transposed(const std::vector<std::size_t>& strides) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{ {
return shapes.size() - 1; return shapes.size() - 1;
......
...@@ -445,6 +445,21 @@ TEST_CASE(reshape_test) ...@@ -445,6 +445,21 @@ TEST_CASE(reshape_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(reshape_non_standard)
{
migraphx::program p;
migraphx::op::reshape op;
std::vector<int64_t> reshape_dims{4, 3, 2};
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4}};
auto x = p.add_parameter("x", s);
auto tran_x = p.add_instruction(migraphx::op::transpose{{0, 2, 1}}, x);
auto cont_x = p.add_instruction(migraphx::op::contiguous{}, tran_x);
p.add_instruction(migraphx::op::reshape{{4, 3, 2}}, cont_x);
auto prog = migraphx::parse_onnx("reshape_non_standard.onnx");
EXPECT(p == prog);
}
TEST_CASE(shape_test) TEST_CASE(shape_test)
{ {
migraphx::program p; migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment