Unverified Commit 1329b9be authored by shivadbhavsar's avatar shivadbhavsar Committed by GitHub
Browse files

fix stable diffusion decoder non standard shape issue (#1594)

parent e3fb3a0d
...@@ -48,7 +48,7 @@ struct contiguous ...@@ -48,7 +48,7 @@ struct contiguous
{ {
check_shapes{inputs, *this, true}.has(1); check_shapes{inputs, *this, true}.has(1);
auto s0 = inputs.front(); auto s0 = inputs.front();
if(s0.dynamic() or s0.standard()) if(s0.dynamic())
{ {
return s0; return s0;
} }
......
...@@ -53,8 +53,8 @@ struct parse_reshape : op_parser<parse_reshape> ...@@ -53,8 +53,8 @@ struct parse_reshape : op_parser<parse_reshape>
s.visit([&](auto v) { copy(v, std::back_inserter(dims)); }); s.visit([&](auto v) { copy(v, std::back_inserter(dims)); });
} }
return info.add_instruction(make_op("reshape", {{"dims", dims}}), auto cont = info.add_instruction(make_op("contiguous"), args[0]);
info.make_contiguous(args[0])); return info.add_instruction(make_op("reshape", {{"dims", dims}}), cont);
} }
}; };
......
...@@ -5116,8 +5116,10 @@ TEST_CASE(reshape_test) ...@@ -5116,8 +5116,10 @@ TEST_CASE(reshape_test)
migraphx::literal{migraphx::shape{migraphx::shape::int64_type, {2}}, reshape_dims}); migraphx::literal{migraphx::shape{migraphx::shape::int64_type, {2}}, reshape_dims});
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 2, 3}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 2, 3}});
op.dims = reshape_dims; op.dims = reshape_dims;
mm->add_instruction(op, l0); auto c0 = mm->add_instruction(migraphx::make_op("contiguous"), l0);
mm->add_instruction(op, l0); mm->add_instruction(op, c0);
auto c1 = mm->add_instruction(migraphx::make_op("contiguous"), l0);
mm->add_instruction(op, c1);
auto prog = optimize_onnx("reshape_test.onnx"); auto prog = optimize_onnx("reshape_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
......
...@@ -431,11 +431,18 @@ TEST_CASE(contiguous_dyn_shape) ...@@ -431,11 +431,18 @@ TEST_CASE(contiguous_dyn_shape)
TEST_CASE(contiguous_shape_scalar) TEST_CASE(contiguous_shape_scalar)
{ {
migraphx::shape output{migraphx::shape::float_type}; migraphx::shape output{migraphx::shape::float_type, {1}};
migraphx::shape input{migraphx::shape::float_type}; migraphx::shape input{migraphx::shape::float_type};
expect_shape(output, migraphx::make_op("contiguous"), input); expect_shape(output, migraphx::make_op("contiguous"), input);
} }
TEST_CASE(contiguous_shape_singleton_dim)
{
migraphx::shape output{migraphx::shape::float_type, {5, 1, 8}, {8, 8, 1}};
migraphx::shape input{migraphx::shape::float_type, {5, 1, 8}, {8, 4, 1}};
expect_shape(output, migraphx::make_op("contiguous"), input);
}
TEST_CASE(deconvolution_shape) TEST_CASE(deconvolution_shape)
{ {
migraphx::shape input{migraphx::shape::float_type, {4, 4, 1, 1}}; migraphx::shape input{migraphx::shape::float_type, {4, 4, 1, 1}};
......
...@@ -85,6 +85,15 @@ TEST_CASE(test_shape_standard) ...@@ -85,6 +85,15 @@ TEST_CASE(test_shape_standard)
EXPECT(not s.broadcasted()); EXPECT(not s.broadcasted());
} }
TEST_CASE(test_shape_standard_singleton_dim)
{
migraphx::shape s{migraphx::shape::float_type, {5, 1, 8}, {8, 4, 1}};
EXPECT(s.standard());
EXPECT(s.packed());
EXPECT(not s.transposed());
EXPECT(not s.broadcasted());
}
TEST_CASE(test_shape_min_max_opt) TEST_CASE(test_shape_min_max_opt)
{ {
migraphx::shape s{migraphx::shape::float_type, {2, 2, 3}, {6, 3, 1}}; migraphx::shape s{migraphx::shape::float_type, {2, 2, 3}, {6, 3, 1}};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment