Unverified Commit b1506c73 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Fix bug when concatting with the vectorization axis (#1653)

parent 7d26eb9d
......@@ -78,7 +78,9 @@ struct concat_compiler : compiler<concat_compiler>
options.params = "-Wno-float-equal";
options.kernel_name = v.get("kernel", "concat_kernel");
auto axis = find_fast_axis(options.inputs);
auto vec = vectorize::elements(ctx, axis, options.inputs);
vectorize vec{};
if(axis != v.at("axis").to<std::size_t>())
vec = vectorize::elements(ctx, axis, options.inputs);
options.set_launch_params(
v, compute_global_for(ctx, get_concat_elements(options.inputs) / vec.size, 256));
auto src = interpolate_string(
......
......@@ -33,13 +33,12 @@ struct test_concat_axis_2 : verify_program<test_concat_axis_2>
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s0{migraphx::shape::int32_type, {3, 2, 1}};
migraphx::shape s1{migraphx::shape::int32_type, {3, 2, 1}};
migraphx::shape s2{migraphx::shape::int32_type, {3, 2, 1}};
auto l0 = mm->add_parameter("x", s0);
auto l1 = mm->add_parameter("y", s1);
auto l2 = mm->add_parameter("z", s2);
mm->add_instruction(migraphx::make_op("concat", {{"axis", 2}}), l0, l1, l2);
migraphx::shape s{migraphx::shape::int32_type, {3, 2, 1}};
auto x0 = mm->add_parameter("x0", s);
auto x1 = mm->add_parameter("x1", s);
auto x2 = mm->add_parameter("x2", s);
auto x3 = mm->add_parameter("x3", s);
mm->add_instruction(migraphx::make_op("concat", {{"axis", 2}}), x0, x1, x2, x3);
return p;
}
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment