"testing/vscode:/vscode.git/clone" did not exist on "7959d78662d9846f6786e81ec5ec12d2932ecc5e"
Commit c06d254a authored by Paul's avatar Paul
Browse files

Merge

parents 33189188 1555f298
...@@ -105,15 +105,7 @@ static void remove_contiguous(const std::string& op_name, module& m, F f) ...@@ -105,15 +105,7 @@ static void remove_contiguous(const std::string& op_name, module& m, F f)
} }
else if(prev->can_eval()) else if(prev->can_eval())
{ {
replace(new_args, arg, prev); const_instructions.push_back(arg);
if(try_compute_shape(ins, new_args, mod_args))
{
instruction::replace_argument(ins, arg, prev);
}
else if(prev->can_eval())
{
const_instructions.push_back(arg);
}
} }
} }
} }
......
...@@ -19,11 +19,12 @@ namespace op { ...@@ -19,11 +19,12 @@ namespace op {
struct unsqueeze struct unsqueeze
{ {
std::vector<int64_t> axes; std::vector<int64_t> axes;
std::vector<int64_t> steps;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
return pack(f(self.axes, "axes")); return pack(f(self.axes, "axes"), f(self.steps, "steps"));
} }
value attributes() const value attributes() const
...@@ -57,16 +58,27 @@ struct unsqueeze ...@@ -57,16 +58,27 @@ struct unsqueeze
std::size_t p = 0; std::size_t p = 0;
for(auto i : range(new_size)) for(auto i : range(new_size))
{ {
if(std::find(axes.begin(), axes.end(), i) != axes.end()) auto axis_idx = std::find(axes.begin(), axes.end(), i) - axes.begin();
if(axis_idx < axes.size())
{ {
new_lens[i] = 1; std::int64_t step = 1;
if(p == 0) // unsqueeze on the first axes if(axis_idx < steps.size())
step = steps[axis_idx];
if(step == 0)
MIGRAPHX_THROW("UNSQUEEZE: step must be non-zero");
new_lens[i] = step;
if(p < old_strides.size())
{ {
new_strides[i] = old_lens[0] * old_strides[0]; if((old_lens[p] % step) != 0)
MIGRAPHX_THROW("UNSQUEEZE: Axis dimenstion is not divisible by step");
old_lens[p] /= step;
new_strides[i] = old_strides[p] * old_lens[p];
} }
else // unsqueeze on middle or last axes else
{ {
new_strides[i] = (p < old_strides.size()) ? old_strides[p - 1] : 1; if(step != 1)
MIGRAPHX_THROW("UNSQUEEZE: Step must be 1 for extra axes");
new_strides[i] = 1;
} }
} }
else else
......
...@@ -128,8 +128,11 @@ struct find_transpose ...@@ -128,8 +128,11 @@ struct find_transpose
{ {
auto matcher() const auto matcher() const
{ {
return match::name("transpose")(match::none_of( auto output_not_transpose =
match::skip_output(match::name("contiguous"))(match::name("transpose")))); match::none_of(match::skip_output(match::name("contiguous"))(match::name("transpose")));
auto input_has_transpose =
match::skip(match::name("contiguous"))(match::args(match::name("transpose")));
return match::name("transpose")(output_not_transpose, input_has_transpose);
} }
void apply(module& m, const match::matcher_result& mr) const void apply(module& m, const match::matcher_result& mr) const
...@@ -578,7 +581,7 @@ struct find_transpose_contiguous_reshaper_unary ...@@ -578,7 +581,7 @@ struct find_transpose_contiguous_reshaper_unary
} }
}; };
struct find_transpose_slice struct find_slice_transpose
{ {
auto matcher() const auto matcher() const
{ {
...@@ -626,10 +629,9 @@ struct find_transpose_slice ...@@ -626,10 +629,9 @@ struct find_transpose_slice
auto split = splits[i]; auto split = splits[i];
auto t = transposes[i]; auto t = transposes[i];
auto op = any_cast<op::slice>(split->get_operator()); auto op = any_cast<op::slice>(split->get_operator());
for(auto& axis : op.axes) std::transform(op.axes.begin(), op.axes.end(), op.axes.begin(), [&](auto axis) {
{ return iperm[axis];
axis = iperm[axis]; });
}
auto new_ins = m.insert_instruction(t, op, pre); auto new_ins = m.insert_instruction(t, op, pre);
if(t->get_operator() != pre->get_operator()) if(t->get_operator() != pre->get_operator())
{ {
...@@ -642,9 +644,94 @@ struct find_transpose_slice ...@@ -642,9 +644,94 @@ struct find_transpose_slice
} }
}; };
struct find_transpose_slice
{
auto matcher() const
{
return match::name("transpose")(match::all_of[match::outputs()](match::name("slice")));
}
static std::vector<int64_t> slice_distance(const op::slice& op)
{
assert(op.starts.size() == op.ends.size());
std::vector<int64_t> result(op.starts.size());
std::transform(
op.ends.begin(), op.ends.end(), op.starts.begin(), result.begin(), std::minus<>{});
return result;
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto slices = ins->outputs();
if(slices.empty())
return;
auto slice = any_cast<op::slice>(slices.front()->get_operator());
auto sdistance = slice_distance(slice);
// Check all distances and axes are the same
if(std::any_of(slices.begin(), slices.end(), [&](auto sins) {
auto s = any_cast<op::slice>(sins->get_operator());
return s.axes != slice.axes or slice_distance(s) != sdistance;
}))
return;
// Check distances are divisible by axes
auto mod_by_distance = [&](const auto& v, auto f) {
return std::inner_product(v.begin(),
v.end(),
sdistance.begin(),
0,
std::plus<>{},
[&](auto x, auto d) -> uint64_t {
if(d == 0)
return 1;
return f(x) % d;
});
};
if(mod_by_distance(slice.axes, [&](auto x) { return ins->get_shape().lens()[x]; }) != 0 or
mod_by_distance(slice.starts, id{}) != 0 or mod_by_distance(slice.ends, id{}) != 0)
return;
// TODO: Handle multiple axes
if(sdistance.size() != 1)
return;
auto axis = slice.axes.front();
// Skip if axis would be packed
if(std::all_of(ins->get_shape().lens().begin(),
ins->get_shape().lens().begin() + axis,
[](auto x) { return x == 1; }))
return;
// Compute axis before transpose to use for unsqueeeze
auto perm = ins->get_operator().to_value()["permutation"].to_vector<int64_t>();
auto preaxis = std::find(perm.begin(), perm.end(), axis) - perm.begin();
// Make unsqeeze
auto unsqueeze = m.insert_instruction(
ins, make_op("unsqueeze", {{"axes", {preaxis}}, {"steps", sdistance}}), ins->inputs());
// Make transpose
std::transform(perm.begin(), perm.end(), perm.begin(), [&](auto i) {
if(i > preaxis)
return i + 1;
return i;
});
perm.insert(perm.begin(), preaxis + 1);
auto transpose =
m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), unsqueeze);
// Slice and sqeeze
for(auto s : slices)
{
auto op = any_cast<op::slice>(s->get_operator());
op.axes = {0};
op.starts = {op.starts.front() / sdistance.front()};
op.ends = {op.ends.front() / sdistance.front()};
auto slice_ins = m.insert_instruction(ins, op, transpose);
auto squeeze =
m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), slice_ins);
m.replace_instruction(s, squeeze);
}
}
};
void simplify_reshapes::apply(module& m) const void simplify_reshapes::apply(module& m) const
{ {
for(int i = 0; i < 2; i++) for(int i = 0; i < 4; i++)
{ {
match::find_matches(m, match::find_matches(m,
find_where_op{}, find_where_op{},
...@@ -658,6 +745,7 @@ void simplify_reshapes::apply(module& m) const ...@@ -658,6 +745,7 @@ void simplify_reshapes::apply(module& m) const
find_nested_slice{}, find_nested_slice{},
find_nested_concat{}, find_nested_concat{},
find_transpose_slice{}, find_transpose_slice{},
find_slice_transpose{},
find_transpose_contiguous_reshaper_unary{}); find_transpose_contiguous_reshaper_unary{});
dead_code_elimination{}.apply(m); dead_code_elimination{}.apply(m);
} }
......
...@@ -10,7 +10,7 @@ namespace gen { ...@@ -10,7 +10,7 @@ namespace gen {
static std::vector<std::size_t> vector_sizes(const std::vector<shape>& inputs) static std::vector<std::size_t> vector_sizes(const std::vector<shape>& inputs)
{ {
// If all inputs is half then only use half2 // If all inputs are half then only use half2
if(std::all_of(inputs.begin(), inputs.end(), [](const auto& s) { if(std::all_of(inputs.begin(), inputs.end(), [](const auto& s) {
return s.type() == shape::half_type; return s.type() == shape::half_type;
})) }))
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#define MIGRAPHX_GUARD_RTGLIB_QUANT_CONVOLUTION_HPP #define MIGRAPHX_GUARD_RTGLIB_QUANT_CONVOLUTION_HPP
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/op/quant_convolution.hpp> #include <migraphx/op/quant_convolution.hpp>
#include <migraphx/gpu/miopen.hpp> #include <migraphx/gpu/miopen.hpp>
...@@ -14,6 +15,7 @@ struct context; ...@@ -14,6 +15,7 @@ struct context;
struct miopen_quant_convolution struct miopen_quant_convolution
{ {
op::quant_convolution op; op::quant_convolution op;
bool int8_x4_format = false;
shared<convolution_descriptor> cd; shared<convolution_descriptor> cd;
miopenConvFwdAlgorithm_t algo{}; miopenConvFwdAlgorithm_t algo{};
miopenHandle_t handle = nullptr; miopenHandle_t handle = nullptr;
...@@ -22,7 +24,8 @@ struct miopen_quant_convolution ...@@ -22,7 +24,8 @@ struct miopen_quant_convolution
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
// TODO: Add algo // TODO: Add algo
return op::quant_convolution::reflect(self.op, f); return pack_join(migraphx::reflect(self.op, f),
pack(f(self.int8_x4_format, "int8_x4_format")));
} }
std::string name() const { return "gpu::quant_convolution"; } std::string name() const { return "gpu::quant_convolution"; }
......
...@@ -364,8 +364,22 @@ struct miopen_apply ...@@ -364,8 +364,22 @@ struct miopen_apply
{ {
apply_map.emplace("quant_convolution", [=](instruction_ref ins) { apply_map.emplace("quant_convolution", [=](instruction_ref ins) {
auto&& op = any_cast<op::quant_convolution>(ins->get_operator()); auto&& op = any_cast<op::quant_convolution>(ins->get_operator());
auto conv = miopen_quant_convolution{op, make_conv(op)}; shape ws;
auto ws = conv.compile(get_context(), ins->get_shape(), to_shapes(ins->inputs())); miopen_quant_convolution conv;
auto compile_quant_conv_with_format = [&](bool format) {
conv = miopen_quant_convolution{op, format, make_conv(op)};
ws = conv.compile(get_context(), ins->get_shape(), to_shapes(ins->inputs()));
};
try
{
compile_quant_conv_with_format(int8_x4_format);
}
catch(migraphx::exception&)
{
// In case no solver supports the default format, retry using the other format.
compile_quant_conv_with_format(!int8_x4_format);
}
auto args = ins->inputs(); auto args = ins->inputs();
auto workspace = insert_allocation(ins, ws, "workspace"); auto workspace = insert_allocation(ins, ws, "workspace");
......
...@@ -118,7 +118,7 @@ void pack_int8_args::apply(module& m) const ...@@ -118,7 +118,7 @@ void pack_int8_args::apply(module& m) const
assert(val.contains("int8_x4_format")); assert(val.contains("int8_x4_format"));
if(not val.at("int8_x4_format").to<bool>()) if(not val.at("int8_x4_format").to<bool>())
{ {
return; continue;
} }
auto inputs = ins->inputs(); auto inputs = ins->inputs();
auto lens = inputs.at(0)->get_shape().lens(); auto lens = inputs.at(0)->get_shape().lens();
...@@ -156,6 +156,12 @@ void pack_int8_args::apply(module& m) const ...@@ -156,6 +156,12 @@ void pack_int8_args::apply(module& m) const
} }
else if(ins->name() == "gpu::quant_convolution") else if(ins->name() == "gpu::quant_convolution")
{ {
auto val = ins->get_operator().to_value();
if(not val.at("int8_x4_format").to<bool>())
{
continue;
}
auto inputs = ins->inputs(); auto inputs = ins->inputs();
auto packed_x = m.insert_instruction( auto packed_x = m.insert_instruction(
ins, ins,
......
...@@ -16,8 +16,8 @@ argument miopen_quant_convolution::compute(context& ctx, ...@@ -16,8 +16,8 @@ argument miopen_quant_convolution::compute(context& ctx,
const shape& output_shape, const shape& output_shape,
const std::vector<argument>& args) const const std::vector<argument>& args) const
{ {
auto x_desc = make_tensor(args[0].get_shape(), true); auto x_desc = make_tensor(args[0].get_shape(), int8_x4_format);
auto w_desc = make_tensor(args[1].get_shape(), true); auto w_desc = make_tensor(args[1].get_shape(), int8_x4_format);
auto y_desc = make_tensor(output_shape); auto y_desc = make_tensor(output_shape);
float alpha = 1; float alpha = 1;
...@@ -49,8 +49,8 @@ shape miopen_quant_convolution::compile(context& ctx, ...@@ -49,8 +49,8 @@ shape miopen_quant_convolution::compile(context& ctx,
std::vector<shape> inputs) std::vector<shape> inputs)
{ {
shape workspace_shape{}; shape workspace_shape{};
auto x_desc = make_tensor(inputs[0], true); auto x_desc = make_tensor(inputs[0], int8_x4_format);
auto w_desc = make_tensor(inputs[1], true); auto w_desc = make_tensor(inputs[1], int8_x4_format);
auto y_desc = make_tensor(output_shape); auto y_desc = make_tensor(output_shape);
std::size_t workspace_size = 0; std::size_t workspace_size = 0;
...@@ -62,8 +62,15 @@ shape miopen_quant_convolution::compile(context& ctx, ...@@ -62,8 +62,15 @@ shape miopen_quant_convolution::compile(context& ctx,
&workspace_size); &workspace_size);
workspace_shape = shape{shape::int8_type, {workspace_size}}; workspace_shape = shape{shape::int8_type, {workspace_size}};
auto arg_vec4_x = to_gpu(generate_argument(pack_int8_shape(inputs[0]))); auto x_shape = inputs[0];
auto arg_vec4_w = to_gpu(generate_argument(pack_int8_shape(inputs[1]))); auto w_shape = inputs[1];
if(int8_x4_format)
{
x_shape = pack_int8_shape(x_shape);
w_shape = pack_int8_shape(w_shape);
}
auto arg_vec4_x = to_gpu(generate_argument(x_shape));
auto arg_vec4_w = to_gpu(generate_argument(w_shape));
auto y = allocate_gpu(output_shape); auto y = allocate_gpu(output_shape);
auto workspace = allocate_gpu(workspace_shape); auto workspace = allocate_gpu(workspace_shape);
......
...@@ -1510,15 +1510,40 @@ TEST_CASE(test_squeeze_wrong_axis) ...@@ -1510,15 +1510,40 @@ TEST_CASE(test_squeeze_wrong_axis)
TEST_CASE(test_unsqueeze) TEST_CASE(test_unsqueeze)
{ {
migraphx::shape s1{migraphx::shape::float_type, {4, 3, 3}}; migraphx::shape s1{migraphx::shape::float_type, {4, 5, 3}};
migraphx::shape s2{migraphx::shape::float_type, {4, 3, 1, 3}}; migraphx::shape s2{migraphx::shape::float_type, {4, 5, 1, 3}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1); expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1);
} }
TEST_CASE(test_unsqueeze_step)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 5, 12}};
migraphx::shape s2{migraphx::shape::float_type, {4, 5, 2, 6}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {2}}}), s1);
}
TEST_CASE(test_unsqueeze_step_non_divisable)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 5, 3}};
throws_shape(migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {2}}}), s1);
}
TEST_CASE(test_unsqueeze_step_non_zero)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 5, 12}};
throws_shape(migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {0}}}), s1);
}
TEST_CASE(test_unsqueeze_step_at_end)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 5, 12}};
throws_shape(migraphx::make_op("unsqueeze", {{"axes", {3}}, {"steps", {2}}}), s1);
}
TEST_CASE(test_unsqueeze_negative_axis) TEST_CASE(test_unsqueeze_negative_axis)
{ {
migraphx::shape s1{migraphx::shape::float_type, {4, 3, 3}}; migraphx::shape s1{migraphx::shape::float_type, {4, 5, 3}};
migraphx::shape s2{migraphx::shape::float_type, {4, 3, 1, 3}}; migraphx::shape s2{migraphx::shape::float_type, {4, 5, 1, 3}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {-2}}}), s1); expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {-2}}}), s1);
} }
...@@ -1544,21 +1569,28 @@ TEST_CASE(test_unsqueeze_scalar_tensor2) ...@@ -1544,21 +1569,28 @@ TEST_CASE(test_unsqueeze_scalar_tensor2)
TEST_CASE(test_unsqueeze_transpose) TEST_CASE(test_unsqueeze_transpose)
{ {
migraphx::shape s1{migraphx::shape::float_type, {4, 4, 3}, {12, 1, 4}}; migraphx::shape s1{migraphx::shape::float_type, {4, 4, 3}, {12, 1, 4}};
migraphx::shape s2{migraphx::shape::float_type, {4, 4, 1, 3}, {12, 1, 1, 4}}; migraphx::shape s2{migraphx::shape::float_type, {4, 4, 1, 3}, {12, 1, 12, 4}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1); expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1);
} }
TEST_CASE(test_unsqueeze_transpose_step)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 4, 6}, {24, 1, 4}};
migraphx::shape s2{migraphx::shape::float_type, {4, 4, 2, 3}, {24, 1, 12, 4}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {2}}}), s1);
}
TEST_CASE(test_unsqueeze_multibroadcast) TEST_CASE(test_unsqueeze_multibroadcast)
{ {
migraphx::shape s1{migraphx::shape::float_type, {2, 3, 4}, {0, 1, 0}}; migraphx::shape s1{migraphx::shape::float_type, {2, 3, 4}, {0, 1, 0}};
migraphx::shape s2{migraphx::shape::float_type, {2, 3, 1, 4}, {0, 1, 1, 0}}; migraphx::shape s2{migraphx::shape::float_type, {2, 3, 1, 4}, {0, 1, 0, 0}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1); expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1);
} }
TEST_CASE(test_unsqueeze_slice) TEST_CASE(test_unsqueeze_slice)
{ {
migraphx::shape s1{migraphx::shape::float_type, {2, 3, 4}, {108, 36, 1}}; migraphx::shape s1{migraphx::shape::float_type, {2, 3, 4}, {108, 36, 1}};
migraphx::shape s2{migraphx::shape::float_type, {2, 3, 1, 4}, {108, 36, 36, 1}}; migraphx::shape s2{migraphx::shape::float_type, {2, 3, 1, 4}, {108, 36, 4, 1}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1); expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1);
} }
...@@ -1590,6 +1622,20 @@ TEST_CASE(test_unsqueeze_multiple_axes_2) ...@@ -1590,6 +1622,20 @@ TEST_CASE(test_unsqueeze_multiple_axes_2)
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {0, 1}}}), s1); expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {0, 1}}}), s1);
} }
TEST_CASE(test_unsqueeze_multiple_axes_3)
{
migraphx::shape s1{migraphx::shape::float_type, {3, 4, 5}};
migraphx::shape s2{migraphx::shape::float_type, {3, 4, 1, 5, 1, 1}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2, 4, 5}}}), s1);
}
TEST_CASE(test_unsqueeze_multiple_axes_4)
{
migraphx::shape s1{migraphx::shape::float_type, {3, 4, 5}};
migraphx::shape s2{migraphx::shape::float_type, {3, 4, 1, 5, 1, 1}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {5, 4, 2}}}), s1);
}
TEST_CASE(transpose_shape) TEST_CASE(transpose_shape)
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 2}}; migraphx::shape input{migraphx::shape::float_type, {2, 2}};
......
...@@ -1220,4 +1220,82 @@ TEST_CASE(transpose_slice_single_transpose) ...@@ -1220,4 +1220,82 @@ TEST_CASE(transpose_slice_single_transpose)
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
TEST_CASE(transpose_slice_non_packed_axis)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 384, 36, 64}});
auto transpose =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), x);
auto slice = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {12}}}),
transpose);
auto sqrt = m1.add_instruction(migraphx::make_op("sqrt"), slice);
m1.add_return({sqrt});
}
auto output_shapes = m1.get_output_shapes();
run_pass(m1);
EXPECT(m1.get_output_shapes() == output_shapes);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 384, 36, 64}});
auto unsqueeze =
m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {12}}}), x);
auto transpose = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {3, 0, 2, 1, 4}}}), unsqueeze);
auto slice = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), transpose);
auto squeeze = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice);
auto sqrt = m2.add_instruction(migraphx::make_op("sqrt"), squeeze);
m2.add_return({sqrt});
}
EXPECT(m1 == m2);
}
TEST_CASE(transpose_slice_non_packed_multi_axis)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 384, 36, 64}});
auto transpose =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), x);
auto slice1 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {12}}}),
transpose);
auto slice2 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {12}}, {"ends", {24}}}),
transpose);
auto transpose2 = m1.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), slice2);
auto slice3 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {24}}, {"ends", {36}}}),
transpose);
m1.add_return({slice1, transpose2, slice3});
}
auto output_shapes = m1.get_output_shapes();
run_pass(m1);
EXPECT(m1.get_output_shapes() == output_shapes);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 384, 36, 64}});
auto unsqueeze =
m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {12}}}), x);
auto transpose = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {3, 0, 2, 1, 4}}}), unsqueeze);
auto slice1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), transpose);
auto squeeze1 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice1);
auto slice2 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), transpose);
auto squeeze2 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice2);
auto transpose2 = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), squeeze2);
auto slice3 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {3}}}), transpose);
auto squeeze3 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice3);
m2.add_return({squeeze1, transpose2, squeeze3});
}
EXPECT(m1.sort() == m2.sort());
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/op/quant_convolution.hpp>
struct quant_conv_int8x4_default : verify_program<quant_conv_int8x4_default>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::int8_type, {16, 16, 4, 4}};
auto pa = mm->add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {16, 16, 3, 3}};
auto pc = mm->add_parameter("c", c_shape);
mm->add_instruction(
migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}, migraphx::op::same},
pa,
pc);
return p;
}
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment