Unverified Commit 1329b9be authored by shivadbhavsar's avatar shivadbhavsar Committed by GitHub
Browse files

fix stable diffusion decoder non standard shape issue (#1594)

parent e3fb3a0d
......@@ -48,7 +48,7 @@ struct contiguous
{
check_shapes{inputs, *this, true}.has(1);
auto s0 = inputs.front();
if(s0.dynamic() or s0.standard())
if(s0.dynamic())
{
return s0;
}
......
......@@ -53,8 +53,8 @@ struct parse_reshape : op_parser<parse_reshape>
s.visit([&](auto v) { copy(v, std::back_inserter(dims)); });
}
return info.add_instruction(make_op("reshape", {{"dims", dims}}),
info.make_contiguous(args[0]));
auto cont = info.add_instruction(make_op("contiguous"), args[0]);
return info.add_instruction(make_op("reshape", {{"dims", dims}}), cont);
}
};
......
......@@ -5116,8 +5116,10 @@ TEST_CASE(reshape_test)
migraphx::literal{migraphx::shape{migraphx::shape::int64_type, {2}}, reshape_dims});
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 2, 3}});
op.dims = reshape_dims;
mm->add_instruction(op, l0);
mm->add_instruction(op, l0);
auto c0 = mm->add_instruction(migraphx::make_op("contiguous"), l0);
mm->add_instruction(op, c0);
auto c1 = mm->add_instruction(migraphx::make_op("contiguous"), l0);
mm->add_instruction(op, c1);
auto prog = optimize_onnx("reshape_test.onnx");
EXPECT(p == prog);
......
......@@ -431,11 +431,18 @@ TEST_CASE(contiguous_dyn_shape)
TEST_CASE(contiguous_shape_scalar)
{
migraphx::shape output{migraphx::shape::float_type};
migraphx::shape output{migraphx::shape::float_type, {1}};
migraphx::shape input{migraphx::shape::float_type};
expect_shape(output, migraphx::make_op("contiguous"), input);
}
TEST_CASE(contiguous_shape_singleton_dim)
{
migraphx::shape output{migraphx::shape::float_type, {5, 1, 8}, {8, 8, 1}};
migraphx::shape input{migraphx::shape::float_type, {5, 1, 8}, {8, 4, 1}};
expect_shape(output, migraphx::make_op("contiguous"), input);
}
TEST_CASE(deconvolution_shape)
{
migraphx::shape input{migraphx::shape::float_type, {4, 4, 1, 1}};
......
......@@ -85,6 +85,15 @@ TEST_CASE(test_shape_standard)
EXPECT(not s.broadcasted());
}
TEST_CASE(test_shape_standard_singleton_dim)
{
migraphx::shape s{migraphx::shape::float_type, {5, 1, 8}, {8, 4, 1}};
EXPECT(s.standard());
EXPECT(s.packed());
EXPECT(not s.transposed());
EXPECT(not s.broadcasted());
}
TEST_CASE(test_shape_min_max_opt)
{
migraphx::shape s{migraphx::shape::float_type, {2, 2, 3}, {6, 3, 1}};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment