"vscode:/vscode.git/clone" did not exist on "89f7ac0dc2ab3f3f20c9a0f0095fb624948c607c"
Commit 31065c7d authored by charlie's avatar charlie
Browse files

Merge branch 'dyn_squeeze' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_model_test

parents 6bec381f 6acbd4e4
......@@ -115,6 +115,7 @@ def shape(h):
const=True)
h.method('strides', returns='const std::vector<size_t>&', const=True)
h.method('type', returns='migraphx::shape::type_t', const=True)
h.method('elements', returns='size_t', const=True)
h.method('bytes', returns='size_t', const=True)
h.method('equal',
api.params(x='const migraphx::shape&'),
......@@ -122,6 +123,7 @@ def shape(h):
returns='bool',
const=True)
h.method('standard', returns='bool', const=True)
h.method('index', api.params(i='size_t'), returns='size_t', const=True)
@auto_handle()
......@@ -274,6 +276,13 @@ def program(h):
params='std::unordered_map<std::string, migraphx::argument>'),
invoke='migraphx::run($@)',
returns='std::vector<migraphx::argument>')
h.method('run_async',
api.params(
params='std::unordered_map<std::string, migraphx::argument>',
s='void*',
name='const char *'),
invoke='migraphx::run_async($@)',
returns='std::vector<migraphx::argument>')
h.method('equal',
api.params(x='const migraphx::program&'),
invoke='migraphx::equal($@)',
......@@ -450,4 +459,8 @@ def experimental_custom_op(h):
h.virtual('compute_shape',
api.params(inputs='std::vector<migraphx::shape>'),
returns='migraphx::shape')
h.virtual('output_alias',
api.params(inputs='std::vector<migraphx::shape>'),
returns='std::vector<size_t>')
h.virtual('runs_on_offload_target', returns='bool')
h.method('register', invoke='migraphx::register_custom_op($@)')
......@@ -39,7 +39,7 @@ instruction_ref insert_apply_alpha_beta(module& m,
auto a = args[0];
auto b = args[1];
auto input_type = a->get_shape().type();
if(!float_equal(alpha.at<float>(0), 1.0))
if(not float_equal(alpha.at<float>(0), 1.0))
{
auto alpha_literal = m.add_literal(alpha);
a = insert_common_op(m, pos, migraphx::make_op("mul"), {alpha_literal, a});
......
......@@ -27,6 +27,7 @@
#include <migraphx/algorithm.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -43,6 +44,7 @@ inline namespace MIGRAPHX_INLINE_NS {
// In this case we need to broadcast the (:,:,1:,:) axis
// of s0 plus the 1st dimension of s1 giving
// output_lens = (3,2,7,5)
//
std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
std::vector<std::size_t> s1)
{
......@@ -50,25 +52,67 @@ std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
return s0;
if(s0.size() > s1.size())
s0.swap(s1);
std::vector<std::size_t> out_lens(s1);
auto offset = s1.size() - s0.size();
std::transform(
s0.begin(), s0.end(), s1.begin() + offset, out_lens.begin() + offset, [&](auto a, auto b) {
if(a != b and a != 1 and b != 1)
{
MIGRAPHX_THROW("COMPUTE_BROADCASTLEN: shape {" + to_string_range(s0) + "} and {" +
to_string_range(s1) + "} mismatch!");
MIGRAPHX_THROW("COMPUTE_BROADCASTLEN: shape {" + migraphx::to_string_range(s0) +
"} and {" + migraphx::to_string_range(s1) + "} mismatch!");
}
return std::max(a, b);
});
return out_lens;
}
// Compute the broadcasted dynamic dimensions of two shapes, at least one of
// which must be dynamic (asserted). A static operand is first converted with
// to_dynamic(). The lower-rank shape is aligned against the trailing
// dimensions of the higher-rank shape; the leading dimensions of the
// higher-rank shape are copied to the output unchanged.
//
// For each aligned pair of dimensions:
//  - equal dimensions are kept as is;
//  - if either one equals {1, 1, 0}, the result takes the larger min and the
//    larger max of the pair, with the third field set to 0;
//  - any other combination throws via MIGRAPHX_THROW.
std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, shape s1)
{
    assert(s0.dynamic() or s1.dynamic());

    // Work on the dynamic_dimension representation of both operands
    if(not s0.dynamic())
        s0 = s0.to_dynamic();
    if(not s1.dynamic())
        s1 = s1.to_dynamic();

    // From here on s1 is the operand with the most dimensions
    if(s0.ndim() > s1.ndim())
        std::swap(s0, s1);

    const auto& short_dims = s0.dyn_dims();
    const auto& long_dims  = s1.dyn_dims();
    const auto lead        = s1.ndim() - s0.ndim();

    // A dimension fixed at 1 is the only one that may be stretched
    shape::dynamic_dimension unit_dim{1, 1, 0};

    auto broadcast_dim = [&](auto lhs, auto rhs) {
        if(lhs == rhs)
            return lhs;
        if(not(lhs == unit_dim or rhs == unit_dim))
        {
            MIGRAPHX_THROW("COMPUTE_BROADCASTED_DYN_DIMS: dynamic shapes {" +
                           migraphx::to_string_range(short_dims) + "} and {" +
                           migraphx::to_string_range(long_dims) + "} mismatch!");
        }
        // setting opt to 0, may need to be changed
        return shape::dynamic_dimension{
            std::max(lhs.min, rhs.min), std::max(lhs.max, rhs.max), 0};
    };

    // Start from the higher-rank dims, then overwrite the aligned tail
    std::vector<shape::dynamic_dimension> result(long_dims);
    std::transform(short_dims.cbegin(),
                   short_dims.cend(),
                   long_dims.cbegin() + lead,
                   result.begin() + lead,
                   broadcast_dim);
    return result;
}
// Compute the common (broadcasted) dimensions of a list of fixed shapes
std::vector<std::size_t> compute_common_lens(const std::vector<shape>& shapes)
{
assert(not shapes.empty());
assert(
std::none_of(shapes.cbegin(), shapes.cend(), [](auto shape) { return shape.dynamic(); }));
return transform_accumulate(shapes.begin() + 1,
shapes.end(),
shapes.front().lens(),
......@@ -114,20 +158,63 @@ instruction_ref insert_common_op(module& m,
const operation& op,
std::vector<instruction_ref> inputs)
{
auto common = common_shape(to_shapes(inputs));
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
if(input->get_shape().lens() != common.lens())
if(std::any_of(
inputs.cbegin(), inputs.cend(), [](auto input) { return input->get_shape().dynamic(); }))
{
// currently only handles the binary case
if(inputs.size() != 2)
{
input = m.insert_instruction(
ins, make_op("multibroadcast", {{"out_lens", common.lens()}}), input);
MIGRAPHX_THROW("INSERT_COMMON_OP: not handled; " + migraphx::to_string(inputs.size()) +
"inputs, only handle two inputs if any are dynamic shape");
}
if(input->get_shape().type() != common.type())
auto c_type = compute_common_types(to_shapes(inputs));
auto c_dyn_dims =
compute_broadcasted_dyn_dims(inputs[0]->get_shape(), inputs[1]->get_shape());
// following should work for a static or dynamic shape
if(inputs[0]->get_shape().dyn_dims() != c_dyn_dims)
{
input = m.insert_instruction(
ins, make_op("convert", {{"target_type", common.type()}}), input);
inputs[0] = m.insert_instruction(
ins,
make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}),
inputs[0],
inputs[1]);
}
return input;
});
if(inputs[1]->get_shape().dyn_dims() != c_dyn_dims)
{
inputs[1] = m.insert_instruction(
ins,
make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}),
inputs[1],
inputs[0]);
}
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
if(input->get_shape().type() != c_type)
{
input =
m.insert_instruction(ins, make_op("convert", {{"target_type", c_type}}), input);
}
return input;
});
}
else
{
auto common = common_shape(to_shapes(inputs));
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
if(input->get_shape().lens() != common.lens())
{
input = m.insert_instruction(
ins, make_op("multibroadcast", {{"out_lens", common.lens()}}), input);
}
if(input->get_shape().type() != common.type())
{
input = m.insert_instruction(
ins, make_op("convert", {{"target_type", common.type()}}), input);
}
return input;
});
}
return m.insert_instruction(ins, op, inputs);
}
......
......@@ -25,13 +25,10 @@
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/json.hpp>
#include "models.hpp"
namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {
migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
{
migraphx::program p;
......@@ -42,179 +39,153 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1}}, 1)));
auto x_main_module_2 = mmain->add_literal(migraphx::abs(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1}}, 2)));
auto x_input_1 = mmain->add_parameter(
"input.1", migraphx::shape{migraphx::shape::float_type, {batch, 3, 224, 224}});
auto x_0 = mmain->add_parameter(
"0", migraphx::shape{migraphx::shape::float_type, {batch, 3, 224, 224}});
auto x_main_module_4 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096, 4096}}, 3));
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000}}, 3));
auto x_main_module_5 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096}}, 4));
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000, 4096}}, 4));
auto x_main_module_6 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096, 9216}}, 5));
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096}}, 5));
auto x_main_module_7 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096}}, 6));
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096, 4096}}, 6));
auto x_main_module_8 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000, 4096}}, 7));
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096}}, 7));
auto x_main_module_9 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000}}, 8));
auto x_main_module_10 = mmain->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {256, 384, 3, 3}}, 9));
auto x_main_module_11 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 10));
auto x_main_module_12 = mmain->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {384, 192, 3, 3}}, 11));
auto x_main_module_13 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 12));
auto x_main_module_14 = mmain->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {192, 64, 5, 5}}, 13));
auto x_main_module_15 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 14));
auto x_main_module_16 = mmain->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {256, 256, 3, 3}}, 15));
auto x_main_module_17 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 16));
auto x_main_module_18 = mmain->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {64, 3, 11, 11}}, 17));
auto x_main_module_19 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 18));
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096, 9216}}, 8));
auto x_main_module_10 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 9));
auto x_main_module_11 = mmain->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {256, 256, 3, 3}}, 10));
auto x_main_module_12 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 11));
auto x_main_module_13 = mmain->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {256, 384, 3, 3}}, 12));
auto x_main_module_14 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 13));
auto x_main_module_15 = mmain->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {384, 192, 3, 3}}, 14));
auto x_main_module_16 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 15));
auto x_main_module_17 = mmain->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {192, 64, 5, 5}}, 16));
auto x_main_module_18 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 17));
auto x_main_module_19 = mmain->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {64, 3, 11, 11}}, 18));
auto x_main_module_20 = mmain->add_instruction(
migraphx::make_op(
migraphx::make_json_op(
"convolution",
migraphx::from_json_string(
"{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[4,4]}")),
x_input_1,
x_main_module_18);
auto x_main_module_21 = mmain->add_instruction(
migraphx::make_op("broadcast",
migraphx::from_json_string("{axis:1,out_lens:[1,64,55,55]}")),
"{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[4,4]}"),
x_0,
x_main_module_19);
auto x_main_module_21 = mmain->add_instruction(
migraphx::make_json_op("broadcast", "{axis:1,out_lens:[1,64,55,55]}"), x_main_module_18);
auto x_main_module_22 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_20, x_main_module_21);
auto x_main_module_23 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_22);
auto x_main_module_24 = mmain->add_instruction(
migraphx::make_op(
migraphx::make_json_op(
"pooling",
migraphx::from_json_string(
"{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}")),
"{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}"),
x_main_module_23);
auto x_main_module_25 = mmain->add_instruction(
migraphx::make_op(
migraphx::make_json_op(
"convolution",
migraphx::from_json_string(
"{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[1,1]}")),
"{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[1,1]}"),
x_main_module_24,
x_main_module_14);
x_main_module_17);
auto x_main_module_26 = mmain->add_instruction(
migraphx::make_op("broadcast",
migraphx::from_json_string("{axis:1,out_lens:[1,192,27,27]}")),
x_main_module_15);
migraphx::make_json_op("broadcast", "{axis:1,out_lens:[1,192,27,27]}"), x_main_module_16);
auto x_main_module_27 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_25, x_main_module_26);
auto x_main_module_28 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_27);
auto x_main_module_29 = mmain->add_instruction(
migraphx::make_op(
migraphx::make_json_op(
"pooling",
migraphx::from_json_string(
"{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}")),
"{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}"),
x_main_module_28);
auto x_main_module_30 = mmain->add_instruction(
migraphx::make_op(
migraphx::make_json_op(
"convolution",
migraphx::from_json_string(
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,1]}")),
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,1]}"),
x_main_module_29,
x_main_module_12);
x_main_module_15);
auto x_main_module_31 = mmain->add_instruction(
migraphx::make_op("broadcast",
migraphx::from_json_string("{axis:1,out_lens:[1,384,13,13]}")),
x_main_module_13);
migraphx::make_json_op("broadcast", "{axis:1,out_lens:[1,384,13,13]}"), x_main_module_14);
auto x_main_module_32 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_30, x_main_module_31);
auto x_main_module_33 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_32);
auto x_main_module_34 = mmain->add_instruction(
migraphx::make_op(
migraphx::make_json_op(
"convolution",
migraphx::from_json_string(
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,1]}")),
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,1]}"),
x_main_module_33,
x_main_module_10);
x_main_module_13);
auto x_main_module_35 = mmain->add_instruction(
migraphx::make_op("broadcast",
migraphx::from_json_string("{axis:1,out_lens:[1,256,13,13]}")),
x_main_module_11);
migraphx::make_json_op("broadcast", "{axis:1,out_lens:[1,256,13,13]}"), x_main_module_12);
auto x_main_module_36 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_34, x_main_module_35);
auto x_main_module_37 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_36);
auto x_main_module_38 = mmain->add_instruction(
migraphx::make_op(
migraphx::make_json_op(
"convolution",
migraphx::from_json_string(
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,1]}")),
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,1]}"),
x_main_module_37,
x_main_module_16);
x_main_module_11);
auto x_main_module_39 = mmain->add_instruction(
migraphx::make_op("broadcast",
migraphx::from_json_string("{axis:1,out_lens:[1,256,13,13]}")),
x_main_module_17);
migraphx::make_json_op("broadcast", "{axis:1,out_lens:[1,256,13,13]}"), x_main_module_10);
auto x_main_module_40 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_38, x_main_module_39);
auto x_main_module_41 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_40);
auto x_main_module_42 = mmain->add_instruction(
migraphx::make_op(
migraphx::make_json_op(
"pooling",
migraphx::from_json_string(
"{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}")),
"{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}"),
x_main_module_41);
auto x_main_module_43 = mmain->add_instruction(
migraphx::make_op("reshape", migraphx::from_json_string("{dims:[1,9216]}")),
x_main_module_42);
auto x_main_module_44 = mmain->add_instruction(
migraphx::make_op("transpose", migraphx::from_json_string("{permutation:[1,0]}")),
x_main_module_6);
auto x_main_module_45 =
mmain->add_instruction(migraphx::make_op("dot"), x_main_module_43, x_main_module_44);
auto x_main_module_46 = mmain->add_instruction(
migraphx::make_op("multibroadcast", migraphx::from_json_string("{out_lens:[1,4096]}")),
x_main_module_7);
auto x_main_module_43 =
mmain->add_instruction(migraphx::make_json_op("flatten", "{axis:1}"), x_main_module_42);
auto x_main_module_44 = mmain->add_instruction(migraphx::make_op("identity"), x_main_module_43);
auto x_main_module_45 = mmain->add_instruction(
migraphx::make_json_op("transpose", "{permutation:[1,0]}"), x_main_module_9);
auto x_main_module_46 =
mmain->add_instruction(migraphx::make_op("dot"), x_main_module_44, x_main_module_45);
auto x_main_module_47 = mmain->add_instruction(
migraphx::make_op("multibroadcast", migraphx::from_json_string("{out_lens:[1,4096]}")),
x_main_module_2);
auto x_main_module_48 =
mmain->add_instruction(migraphx::make_op("mul"), x_main_module_46, x_main_module_47);
migraphx::make_json_op("multibroadcast", "{out_lens:[1,4096]}"), x_main_module_8);
auto x_main_module_48 = mmain->add_instruction(
migraphx::make_json_op("multibroadcast", "{out_lens:[1,4096]}"), x_main_module_2);
auto x_main_module_49 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_45, x_main_module_48);
auto x_main_module_50 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_49);
auto x_main_module_51 = mmain->add_instruction(
migraphx::make_op("transpose", migraphx::from_json_string("{permutation:[1,0]}")),
x_main_module_4);
auto x_main_module_52 =
mmain->add_instruction(migraphx::make_op("dot"), x_main_module_50, x_main_module_51);
mmain->add_instruction(migraphx::make_op("mul"), x_main_module_47, x_main_module_48);
auto x_main_module_50 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_46, x_main_module_49);
auto x_main_module_51 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_50);
auto x_main_module_52 = mmain->add_instruction(migraphx::make_op("identity"), x_main_module_51);
auto x_main_module_53 = mmain->add_instruction(
migraphx::make_op("multibroadcast", migraphx::from_json_string("{out_lens:[1,4096]}")),
x_main_module_5);
auto x_main_module_54 = mmain->add_instruction(
migraphx::make_op("multibroadcast", migraphx::from_json_string("{out_lens:[1,4096]}")),
x_main_module_1);
auto x_main_module_55 =
mmain->add_instruction(migraphx::make_op("mul"), x_main_module_53, x_main_module_54);
auto x_main_module_56 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_52, x_main_module_55);
auto x_main_module_57 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_56);
auto x_main_module_58 = mmain->add_instruction(
migraphx::make_op("transpose", migraphx::from_json_string("{permutation:[1,0]}")),
x_main_module_8);
auto x_main_module_59 =
mmain->add_instruction(migraphx::make_op("dot"), x_main_module_57, x_main_module_58);
migraphx::make_json_op("transpose", "{permutation:[1,0]}"), x_main_module_7);
auto x_main_module_54 =
mmain->add_instruction(migraphx::make_op("dot"), x_main_module_52, x_main_module_53);
auto x_main_module_55 = mmain->add_instruction(
migraphx::make_json_op("multibroadcast", "{out_lens:[1,4096]}"), x_main_module_6);
auto x_main_module_56 = mmain->add_instruction(
migraphx::make_json_op("multibroadcast", "{out_lens:[1,4096]}"), x_main_module_1);
auto x_main_module_57 =
mmain->add_instruction(migraphx::make_op("mul"), x_main_module_55, x_main_module_56);
auto x_main_module_58 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_54, x_main_module_57);
auto x_main_module_59 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_58);
auto x_main_module_60 = mmain->add_instruction(
migraphx::make_op("multibroadcast", migraphx::from_json_string("{out_lens:[1,1000]}")),
x_main_module_9);
auto x_main_module_61 = mmain->add_instruction(
migraphx::make_op("multibroadcast", migraphx::from_json_string("{out_lens:[1,1000]}")),
x_main_module_0);
auto x_main_module_62 =
mmain->add_instruction(migraphx::make_op("mul"), x_main_module_60, x_main_module_61);
auto x_main_module_63 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_59, x_main_module_62);
mmain->add_return({x_main_module_63});
migraphx::make_json_op("transpose", "{permutation:[1,0]}"), x_main_module_5);
auto x_main_module_61 =
mmain->add_instruction(migraphx::make_op("dot"), x_main_module_59, x_main_module_60);
auto x_main_module_62 = mmain->add_instruction(
migraphx::make_json_op("multibroadcast", "{out_lens:[1,1000]}"), x_main_module_4);
auto x_main_module_63 = mmain->add_instruction(
migraphx::make_json_op("multibroadcast", "{out_lens:[1,1000]}"), x_main_module_0);
auto x_main_module_64 =
mmain->add_instruction(migraphx::make_op("mul"), x_main_module_62, x_main_module_63);
auto x_main_module_65 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_61, x_main_module_64);
mmain->add_return({x_main_module_65});
return p;
}
......
This diff is collapsed.
......@@ -44,7 +44,6 @@
#include <migraphx/propagate_constant.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/rewrite_batchnorm.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/register_target.hpp>
......@@ -221,7 +220,6 @@ struct loader
{
migraphx::run_passes(*p.get_main_module(),
{
migraphx::rewrite_batchnorm{},
migraphx::eliminate_identity{},
migraphx::dead_code_elimination{},
migraphx::simplify_algebra{},
......
This diff is collapsed.
......@@ -145,7 +145,7 @@ void verify_reduced(program p,
auto* mm = p.get_main_module();
auto last = std::prev(mm->end(), n + 1);
mm->remove_instructions(last, mm->end());
std::cout << "Verify: " << std::endl;
std::cout << "Verify: " << n << std::endl;
std::cout << p << std::endl;
verify_program(std::to_string(n), p, t, options, quantize, inputs, tolerance);
}
......@@ -159,6 +159,7 @@ void verify_reduced_program(const program& p,
{
const auto* mm = p.get_main_module();
auto n = std::distance(mm->begin(), mm->end());
std::cout << "Verify steps: " << n << std::endl;
for(std::size_t i = 0; i < n; i++)
{
verify_reduced(p, i, t, options, quantize, inputs, tolerance);
......
......@@ -60,7 +60,7 @@ void eliminate_concat::apply(module& m) const
auto lens = ins->inputs().front()->get_shape().lens();
auto concat_op = concat_opt.get_concat(ins->get_operator());
std::size_t axis_index = tune_axis(lens.size(), concat_op.axis, concat_op.name());
if(axis_index == 0 ||
if(axis_index == 0 or
std::all_of(lens.begin(), lens.begin() + axis_index, [](auto x) { return x == 1; }))
{
// Last input should be an allocation
......
......@@ -71,7 +71,7 @@ static bool try_compute_shape(instruction_ref ins,
return (arg == ins) ? new_shape : arg->get_shape();
});
if(!try_compute_shape(output, input_shapes, mods))
if(not try_compute_shape(output, input_shapes, mods))
{
return false;
}
......
......@@ -39,7 +39,7 @@ T generic_read_file(const std::string& filename)
is.seekg(0, std::ios::beg);
T buffer(size, 0);
if(!is.read(&buffer[0], size))
if(not is.read(&buffer[0], size))
MIGRAPHX_THROW("Error reading file: " + filename);
return buffer;
}
......
......@@ -39,7 +39,7 @@ static literal get_scalar(instruction_ref ins)
if(ins->name() == "contiguous")
return get_scalar(ins->inputs().front());
const auto& s = ins->get_shape();
if(not(s.elements() == 1 or s.scalar()))
if(s.elements() != 1 && not(s.scalar()))
return {};
if(not ins->can_eval())
return {};
......
......@@ -205,7 +205,7 @@ struct allocation_model
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type(
PrivateDetailTypeErasedT value,
typename std::enable_if<!std::is_reference<PrivateDetailTypeErasedU>::value,
typename std::enable_if<not std::is_reference<PrivateDetailTypeErasedU>::value,
int>::type* = nullptr) noexcept
: private_detail_te_value(std::move(value))
{
......@@ -267,7 +267,7 @@ struct allocation_model
private_detail_te_handle_base_type& private_detail_te_get_handle()
{
assert(private_detail_te_handle_mem_var != nullptr);
if(!private_detail_te_handle_mem_var.unique())
if(not private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
}
......
......@@ -107,6 +107,7 @@ struct argument : raw_data<argument>
data_t m_data{};
};
std::vector<shape> to_shapes(const std::vector<argument>& args);
void migraphx_to_value(value& v, const argument& a);
void migraphx_from_value(const value& v, argument& a);
......
......@@ -101,7 +101,7 @@ struct check_shapes
const check_shapes& nelements(std::size_t n) const
{
if(!this->all_of([&](const shape& s) { return s.elements() == n; }))
if(not this->all_of([&](const shape& s) { return s.elements() == n; }))
MIGRAPHX_THROW(prefix() + "Shapes must have only " + std::to_string(n) + " elements");
return *this;
}
......@@ -164,7 +164,7 @@ struct check_shapes
*/
const check_shapes& same_shape() const
{
if(!this->same([](const shape& s) { return s; }))
if(not this->same([](const shape& s) { return s; }))
MIGRAPHX_THROW(prefix() + "Shapes do not match");
return *this;
}
......@@ -174,7 +174,7 @@ struct check_shapes
*/
const check_shapes& same_type() const
{
if(!this->same([](const shape& s) { return s.type(); }))
if(not this->same([](const shape& s) { return s.type(); }))
MIGRAPHX_THROW(prefix() + "Types do not match");
return *this;
}
......@@ -184,10 +184,10 @@ struct check_shapes
*/
const check_shapes& same_dims() const
{
if(!this->same([](const shape& s) { return s.max_lens(); }))
if(not this->same([](const shape& s) { return s.max_lens(); }))
MIGRAPHX_THROW(prefix() + "Dimensions do not match");
if(this->any_of([&](const shape& s) { return s.dynamic(); }))
if(!this->same([](const shape& s) { return s.min_lens(); }))
if(not this->same([](const shape& s) { return s.min_lens(); }))
MIGRAPHX_THROW(prefix() + "Min dynamic dimensions do not match");
return *this;
}
......@@ -197,7 +197,7 @@ struct check_shapes
*/
const check_shapes& same_ndims() const
{
if(!this->same([](const shape& s) { return s.max_lens().size(); }))
if(not this->same([](const shape& s) { return s.max_lens().size(); }))
MIGRAPHX_THROW(prefix() + "Number of dimensions do not match");
return *this;
}
......@@ -207,7 +207,7 @@ struct check_shapes
*/
const check_shapes& standard() const
{
if(!this->all_of([](const shape& s) { return s.standard(); }))
if(not this->all_of([](const shape& s) { return s.standard(); }))
MIGRAPHX_THROW(prefix() + "Shapes are not in standard layout");
return *this;
}
......@@ -217,7 +217,7 @@ struct check_shapes
*/
const check_shapes& standard_or_scalar() const
{
if(!this->all_of([](const shape& s) { return s.standard() or s.scalar(); }))
if(not this->all_of([](const shape& s) { return s.standard() or s.scalar(); }))
MIGRAPHX_THROW(prefix() + "Shapes are not a scalar or in standard layout");
return *this;
}
......@@ -227,7 +227,7 @@ struct check_shapes
*/
const check_shapes& packed() const
{
if(!this->all_of([](const shape& s) { return s.packed(); }))
if(not this->all_of([](const shape& s) { return s.packed(); }))
MIGRAPHX_THROW(prefix() + "Shapes are not packed");
return *this;
}
......@@ -237,7 +237,7 @@ struct check_shapes
*/
const check_shapes& packed_or_broadcasted() const
{
if(!this->all_of([](const shape& s) { return s.packed() or s.broadcasted(); }))
if(not this->all_of([](const shape& s) { return s.packed() or s.broadcasted(); }))
MIGRAPHX_THROW(prefix() + "Shapes are not packed nor broadcasted");
return *this;
}
......@@ -247,7 +247,7 @@ struct check_shapes
*/
const check_shapes& tuple_type() const
{
if(!this->all_of([](const shape& s) { return s.type() == shape::tuple_type; }))
if(not this->all_of([](const shape& s) { return s.type() == shape::tuple_type; }))
MIGRAPHX_THROW(prefix() + "Shapes are not tuple!");
return *this;
}
......@@ -257,7 +257,7 @@ struct check_shapes
*/
const check_shapes& not_transposed() const
{
if(!this->all_of([](const shape& s) { return not s.transposed(); }))
if(not this->all_of([](const shape& s) { return not s.transposed(); }))
MIGRAPHX_THROW(prefix() + "Shapes are transposed");
return *this;
}
......@@ -267,7 +267,7 @@ struct check_shapes
*/
const check_shapes& not_broadcasted() const
{
if(!this->all_of([](const shape& s) { return not s.broadcasted(); }))
if(not this->all_of([](const shape& s) { return not s.broadcasted(); }))
MIGRAPHX_THROW(prefix() + "Shapes are broadcasted");
return *this;
}
......@@ -278,7 +278,7 @@ struct check_shapes
*/
const check_shapes& elements(std::size_t n) const
{
if(!this->all_of([&](const shape& s) { return s.elements() == n; }))
if(not this->all_of([&](const shape& s) { return s.elements() == n; }))
MIGRAPHX_THROW(prefix() + "Wrong number of elements");
return *this;
}
......@@ -288,7 +288,8 @@ struct check_shapes
*/
const check_shapes& batch_not_transposed() const
{
if(!this->all_of([&](const shape& s) { return batch_not_transposed_strides(s.strides()); }))
if(not this->all_of(
[&](const shape& s) { return batch_not_transposed_strides(s.strides()); }))
MIGRAPHX_THROW(prefix() + "Batch size is transposed");
return *this;
}
......
......@@ -36,6 +36,9 @@ struct operation;
std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
std::vector<std::size_t> s1);
std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, shape s1);
shape common_shape(const std::vector<shape>& shapes);
instruction_ref insert_common_op(module& m,
......
......@@ -183,7 +183,7 @@ struct concat_optimization
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type(
PrivateDetailTypeErasedT value,
typename std::enable_if<!std::is_reference<PrivateDetailTypeErasedU>::value,
typename std::enable_if<not std::is_reference<PrivateDetailTypeErasedU>::value,
int>::type* = nullptr) noexcept
: private_detail_te_value(std::move(value))
{
......@@ -233,7 +233,7 @@ struct concat_optimization
private_detail_te_handle_base_type& private_detail_te_get_handle()
{
assert(private_detail_te_handle_mem_var != nullptr);
if(!private_detail_te_handle_mem_var.unique())
if(not private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
}
......
......@@ -66,6 +66,15 @@ any_ptr get_queue_context(T&)
{
return {};
}
// Default wait_for hook: does nothing with either argument.
// The type-erased context falls back to this free function when the wrapped
// type has no usable wait_for member (see private_detail_te_default_wait_for).
template <typename Context>
void wait_for_context(Context&, any_ptr)
{
    // intentionally a no-op
}
// Default finish_on hook: does nothing with either argument.
// The type-erased context falls back to this free function when the wrapped
// type has no usable finish_on member (see private_detail_te_default_finish_on).
template <typename Context>
void finish_on_context(Context&, any_ptr)
{
    // intentionally a no-op
}
#ifdef TYPE_ERASED_DECLARATION
......@@ -78,6 +87,10 @@ struct context
void from_value(const value& v);
// (optional)
any_ptr get_queue();
// (optional)
void wait_for(any_ptr queue);
// (optional)
void finish_on(any_ptr queue);
//
void finish() const;
};
......@@ -165,6 +178,18 @@ struct context
return (*this).private_detail_te_get_handle().get_queue();
}
// Forward wait_for to the wrapped object through the type-erased handle.
// The wrapper must hold a value (asserted).
void wait_for(any_ptr queue)
{
    assert(this->private_detail_te_handle_mem_var);
    auto&& handle = this->private_detail_te_get_handle();
    handle.wait_for(queue);
}
// Forwards finish_on(queue) to the type-erased implementation held by this
// context. The wrapper must hold an implementation (checked by assert).
void finish_on(any_ptr queue)
{
    assert(this->private_detail_te_handle_mem_var);
    auto& handle = this->private_detail_te_get_handle();
    handle.finish_on(queue);
}
void finish() const
{
assert((*this).private_detail_te_handle_mem_var);
......@@ -187,6 +212,8 @@ struct context
virtual value to_value() const = 0;
virtual void from_value(const value& v) = 0;
virtual any_ptr get_queue() = 0;
virtual void wait_for(any_ptr queue) = 0;
virtual void finish_on(any_ptr queue) = 0;
virtual void finish() const = 0;
};
......@@ -231,6 +258,33 @@ struct context
return get_queue_context(private_detail_te_self);
}
// Preferred dispatch overload for wait_for: callers pass char(0), which matches
// the char tag exactly. The trailing decltype removes this overload from
// consideration when T has no wait_for(queue) member, leaving the float-tagged
// fallback.
template <class T>
static auto private_detail_te_default_wait_for(char, T&& private_detail_te_self, any_ptr queue)
    -> decltype(private_detail_te_self.wait_for(queue))
{
    // The declared return type is whatever the member returns, so return the
    // call. This is valid for void, and avoids flowing off the end of a non-void
    // function when the wrapped wait_for returns a value.
    return private_detail_te_self.wait_for(queue);
}
// Fallback dispatch overload for wait_for. The tag parameter is float, so a call
// made with a char argument needs a conversion to reach it; it is therefore only
// chosen when the char-tagged overload is unavailable (T has no wait_for(queue)
// member). Delegates to the free function wait_for_context.
template <class T>
static void private_detail_te_default_wait_for(float, T&& private_detail_te_self, any_ptr queue)
{
wait_for_context(private_detail_te_self, queue);
}
// Preferred dispatch overload for finish_on: callers pass char(0), which matches
// the char tag exactly. The trailing decltype removes this overload from
// consideration when T has no finish_on(queue) member, leaving the float-tagged
// fallback.
template <class T>
static auto private_detail_te_default_finish_on(char, T&& private_detail_te_self, any_ptr queue)
    -> decltype(private_detail_te_self.finish_on(queue))
{
    // The declared return type is whatever the member returns, so return the
    // call. This is valid for void, and avoids flowing off the end of a non-void
    // function when the wrapped finish_on returns a value.
    return private_detail_te_self.finish_on(queue);
}
// Fallback dispatch overload for finish_on. The tag parameter is float, so a call
// made with a char argument needs a conversion to reach it; it is therefore only
// chosen when the char-tagged overload is unavailable (T has no finish_on(queue)
// member). Delegates to the free function finish_on_context.
template <class T>
static void
private_detail_te_default_finish_on(float, T&& private_detail_te_self, any_ptr queue)
{
finish_on_context(private_detail_te_self, queue);
}
template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type : private_detail_te_handle_base_type
{
......@@ -246,9 +300,9 @@ struct context
// Constructs the handle from a by-value argument. The enable_if on the defaulted
// template parameter restricts this constructor to wrapped types that are not
// references.
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type(
    PrivateDetailTypeErasedT value,
    typename std::enable_if<not std::is_reference<PrivateDetailTypeErasedU>::value,
                            int>::type* = nullptr) noexcept
    // `value` is a by-value sink parameter, so move it into the member instead
    // of copying it a second time.
    : private_detail_te_value(std::move(value))
{
}
......@@ -277,6 +331,18 @@ struct context
return private_detail_te_default_get_queue(char(0), private_detail_te_value);
}
// Implements the virtual wait_for for the wrapped value. The char(0) tag makes
// overload resolution prefer the member-calling helper over the fallback one.
void wait_for(any_ptr queue) override
{
    auto& wrapped = private_detail_te_value;
    private_detail_te_default_wait_for(char(0), wrapped, queue);
}
// Implements the virtual finish_on for the wrapped value. The char(0) tag makes
// overload resolution prefer the member-calling helper over the fallback one.
void finish_on(any_ptr queue) override
{
    auto& wrapped = private_detail_te_value;
    private_detail_te_default_finish_on(char(0), wrapped, queue);
}
// Implements the virtual finish by calling finish() directly on the wrapped
// value; unlike wait_for/finish_on there is no fallback dispatch here.
void finish() const override
{
    private_detail_te_value.finish();
}
PrivateDetailTypeErasedT private_detail_te_value;
......@@ -306,7 +372,7 @@ struct context
// Returns a mutable reference to the held implementation object.
//
// If the handle is currently shared with another owner, it is first replaced by
// its own clone, so changes made through the returned reference are not seen by
// the other owners.
private_detail_te_handle_base_type& private_detail_te_get_handle()
{
    assert(private_detail_te_handle_mem_var != nullptr);
    // use_count() > 1 is the same test as "not unique()" for a non-null pointer;
    // shared_ptr::unique() is deprecated in C++17 and removed in C++20.
    if(private_detail_te_handle_mem_var.use_count() > 1)
        private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
    return *private_detail_te_handle_mem_var;
}
......
......@@ -21,36 +21,55 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_DYN_OUTPUT_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_DYN_OUTPUT_HPP
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/concat.hpp>
#include <migraphx/gpu/device/contiguous.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
argument concat(hipStream_t stream,
const migraphx::shape&,
std::vector<migraphx::argument> args,
std::vector<std::size_t> offsets)
// Pair of shapes produced when an output shape is resolved at evaluation time.
// For an instruction whose shape is not dynamic, both members hold the same shape.
struct dyn_output
{
    // original shape from the instruction
    shape ins_shape;
    // shape computed at eval time using input arguments
    shape computed_shape;
};
/**
* Handle dynamic and static shape at evaluation time.
* If converted to shape type, returns original ins_shape.
* If converted to dyn_output type, will compute an output shape using the input arguments.
*/
// Lazily resolves an output shape through implicit conversion.
//
// F is a callable that is invoked with a visitor; the visitor receives a first
// argument that is forwarded to compute_shape, the instruction's shape, and the
// input arguments.
// NOTE(review): the first visitor argument is presumably the operator being
// evaluated - confirm at the call site.
template <class F>
struct compute_output_shape
{
    F ins_inputs;

    // Conversion used when the caller wants both shapes. The output shape is
    // recomputed from the input arguments only when the instruction's shape is
    // dynamic; otherwise the instruction's shape is used for both members.
    operator dyn_output() const
    {
        return ins_inputs([](const auto& x, shape ins_shape, const std::vector<argument>& inputs) {
            if(ins_shape.dynamic())
                return dyn_output{ins_shape, compute_shape(x, to_shapes(inputs))};
            return dyn_output{ins_shape, ins_shape};
        });
    }

    // Conversion used when the caller only wants the instruction's original
    // shape; nothing is recomputed.
    operator shape() const
    {
        return ins_inputs(
            [](const auto&, shape ins_shape, const std::vector<argument>&) { return ins_shape; });
    }
};
// Factory that deduces F and wraps the callable in a compute_output_shape.
template <class F>
compute_output_shape<F> make_compute_output_shape(F f)
{
    return compute_output_shape<F>{f};
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -21,22 +21,21 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_ACOSH_HPP
#define MIGRAPHX_GUARD_RTGLIB_ACOSH_HPP
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_EXECUTION_ENV_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_EXECUTION_ENV_HPP
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/acosh.hpp>
#include <migraphx/any_ptr.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct hip_acosh : unary_device<hip_acosh, device::acosh>
struct execution_environment
{
any_ptr queue = any_ptr{};
bool async = false;
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#endif /* MIGRAPHX_GUARD_MIGRAPHLIB_EXECUTION_ENV_HPP */
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment