"...lm-evaluation-harness.git" did not exist on "b898bdaa09507c45cb3a6ff87808ff36fb89353d"
Commit 870a396b authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

manual merge

parents 228b665c d309e02f
@@ -27,6 +27,7 @@
#include <migraphx/algorithm.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
+#include <migraphx/ranges.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
@@ -43,6 +44,7 @@ inline namespace MIGRAPHX_INLINE_NS {
// In this case we need to broadcast the (:,:,1:,:) axis
// of s0 plus the 1st dimension of s1 giving
// output_lens = (3,2,7,5)
+//
std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
                                                  std::vector<std::size_t> s1)
{
@@ -50,25 +52,62 @@ std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
        return s0;
    if(s0.size() > s1.size())
        s0.swap(s1);
    std::vector<std::size_t> out_lens(s1);
    auto offset = s1.size() - s0.size();
    std::transform(
        s0.begin(), s0.end(), s1.begin() + offset, out_lens.begin() + offset, [&](auto a, auto b) {
            if(a != b and a != 1 and b != 1)
            {
-               MIGRAPHX_THROW("COMPUTE_BROADCASTLEN: shape {" + to_string_range(s0) + "} and {" +
-                              to_string_range(s1) + "} mismatch!");
+               MIGRAPHX_THROW("COMPUTE_BROADCASTLEN: shape {" + migraphx::to_string_range(s0) +
+                              "} and {" + migraphx::to_string_range(s1) + "} mismatch!");
            }
            return std::max(a, b);
        });
    return out_lens;
}
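For reference, the rule implemented above is standard NumPy-style broadcasting: shapes are right-aligned and a dimension of 1 stretches to match the other shape. A minimal standalone sketch of the same rule (illustration only, not the library code):

#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <vector>

std::vector<std::size_t> broadcast_lens(std::vector<std::size_t> a, std::vector<std::size_t> b)
{
    if(a.size() > b.size())
        a.swap(b); // make b the longer shape
    std::vector<std::size_t> out(b);
    auto offset = b.size() - a.size();
    for(std::size_t i = 0; i < a.size(); i++)
    {
        auto x = a[i];
        auto y = b[i + offset];
        if(x != y and x != 1 and y != 1)
            throw std::runtime_error("shapes do not broadcast");
        out[i + offset] = std::max(x, y);
    }
    return out;
}
// broadcast_lens({3, 2, 1, 5}, {2, 7, 5}) yields {3, 2, 7, 5}, matching the comment above.
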
std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, shape s1)
{
// change both shapes to dynamic_dimension representation
s0 = s0.to_dynamic();
s1 = s1.to_dynamic();
if(s0.ndim() > s1.ndim())
{
std::swap(s0, s1);
}
auto offset = s1.ndim() - s0.ndim();
std::vector<shape::dynamic_dimension> out_dims(s1.dyn_dims());
std::transform(
s0.dyn_dims().cbegin(),
s0.dyn_dims().cend(),
s1.dyn_dims().cbegin() + offset,
out_dims.begin() + offset,
[&](auto a, auto b) {
if(a == b)
{
return a;
}
else if(a == 1 or b == 1)
{
// setting opt to 0, may need to be changed
return shape::dynamic_dimension{std::max(a.min, b.min), std::max(a.max, b.max), 0};
}
else
{
MIGRAPHX_THROW("COMPUTE_BROADCASTED_DYN_DIMS: dynamic shapes {" +
migraphx::to_string_range(s0.dyn_dims()) + "} and {" +
migraphx::to_string_range(s1.dyn_dims()) + "} mismatch!");
}
});
return out_dims;
}
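For the dynamic case above, when either side may be 1 the merged range keeps the wider bounds and clears the optimal-value hint. A hypothetical standalone mirror of that rule (dyn_dim below is invented for illustration; the real type is shape::dynamic_dimension):

#include <algorithm>
#include <cstddef>

struct dyn_dim
{
    std::size_t min;
    std::size_t max;
    std::size_t opt; // preferred value hint; 0 means unset
};

// merge two ranges where one side may be a broadcastable 1
dyn_dim merge_broadcast(dyn_dim a, dyn_dim b)
{
    return {std::max(a.min, b.min), std::max(a.max, b.max), 0};
}
// merge_broadcast({1, 1, 0}, {2, 8, 4}) yields {2, 8, 0}
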
+// Compute the common (broadcasted) dimensions of a list of fixed shapes
std::vector<std::size_t> compute_common_lens(const std::vector<shape>& shapes)
{
    assert(not shapes.empty());
+   assert(
+       std::none_of(shapes.cbegin(), shapes.cend(), [](auto shape) { return shape.dynamic(); }));
    return transform_accumulate(shapes.begin() + 1,
                                shapes.end(),
                                shapes.front().lens(),
@@ -114,20 +153,63 @@ instruction_ref insert_common_op(module& m,
                                 const operation& op,
                                 std::vector<instruction_ref> inputs)
{
-   auto common = common_shape(to_shapes(inputs));
-   std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
-       if(input->get_shape().lens() != common.lens())
-       {
-           input = m.insert_instruction(
-               ins, make_op("multibroadcast", {{"out_lens", common.lens()}}), input);
-       }
-       if(input->get_shape().type() != common.type())
-       {
-           input = m.insert_instruction(
-               ins, make_op("convert", {{"target_type", common.type()}}), input);
-       }
-       return input;
-   });
+   if(std::any_of(
+          inputs.cbegin(), inputs.cend(), [](auto input) { return input->get_shape().dynamic(); }))
+   {
+       // currently only handles the binary case
+       if(inputs.size() != 2)
+       {
+           MIGRAPHX_THROW("INSERT_COMMON_OP: not handled; " + migraphx::to_string(inputs.size()) +
+                          " inputs, only handle two inputs if any are dynamic shape");
+       }
+       auto c_type = compute_common_types(to_shapes(inputs));
+       auto c_dyn_dims =
+           compute_broadcasted_dyn_dims(inputs[0]->get_shape(), inputs[1]->get_shape());
+       // following should work for a static or dynamic shape
+       if(inputs[0]->get_shape().dyn_dims() != c_dyn_dims)
+       {
+           inputs[0] = m.insert_instruction(
+               ins,
+               make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}),
+               inputs[0],
+               inputs[1]);
+       }
+       if(inputs[1]->get_shape().dyn_dims() != c_dyn_dims)
+       {
+           inputs[1] = m.insert_instruction(
+               ins,
+               make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}),
+               inputs[1],
+               inputs[0]);
+       }
+       std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
+           if(input->get_shape().type() != c_type)
+           {
+               input =
+                   m.insert_instruction(ins, make_op("convert", {{"target_type", c_type}}), input);
+           }
+           return input;
+       });
+   }
+   else
+   {
+       auto common = common_shape(to_shapes(inputs));
+       std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
+           if(input->get_shape().lens() != common.lens())
+           {
+               input = m.insert_instruction(
+                   ins, make_op("multibroadcast", {{"out_lens", common.lens()}}), input);
+           }
+           if(input->get_shape().type() != common.type())
+           {
+               input = m.insert_instruction(
+                   ins, make_op("convert", {{"target_type", common.type()}}), input);
+           }
+           return input;
+       });
+   }
    return m.insert_instruction(ins, op, inputs);
}
......
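A rough usage sketch of insert_common_op as declared in this diff (the surrounding setup, headers, and insertion point are assumptions for illustration; only the insert_common_op call itself comes from this change):

#include <migraphx/common.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/module.hpp>

migraphx::instruction_ref insert_add(migraphx::module& m,
                                     migraphx::instruction_ref where,
                                     migraphx::instruction_ref x,
                                     migraphx::instruction_ref y)
{
    // broadcasts x and y to a common shape, converts to a common type,
    // then inserts the "add" before 'where'
    return migraphx::insert_common_op(m, where, migraphx::make_op("add"), {x, y});
}
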
@@ -51,8 +51,8 @@ void dead_code_elimination::apply(module& m) const
        // Skip instruction with empty shape as output unless its [dynamic, builtin, undefined,
        // identity, allocate]
        if((not i->get_shape().dynamic() and i->get_shape().elements() == 0) and
-          i->name().front() != '@' and
-          not contains({"undefined", "identity", "allocate"}, i->name()))
+          not(i->name().front() == '@') and not contains({"identity", "allocate"}, i->name()) and
+          not i->is_undefined())
            continue;
        assert(std::distance(m.begin(), i) <= std::distance(m.begin(), last));
        std::unordered_set<instruction_ref> visited;
@@ -74,9 +74,9 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
    auto x_main_module_19 = mmain->add_literal(migraphx::generate_literal(
        migraphx::shape{migraphx::shape::float_type, {64, 3, 11, 11}}, 18));
    auto x_main_module_20 = mmain->add_instruction(
-       migraphx::make_json_op("convolution",
-                              "{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[4,"
-                              "4],use_dynamic_same_auto_pad:0}"),
+       migraphx::make_json_op(
+           "convolution",
+           "{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[4,4]}"),
        x_0,
        x_main_module_19);
    auto x_main_module_21 = mmain->add_instruction(
@@ -90,9 +90,9 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
        "{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}"),
        x_main_module_23);
    auto x_main_module_25 = mmain->add_instruction(
-       migraphx::make_json_op("convolution",
-                              "{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[1,"
-                              "1],use_dynamic_same_auto_pad:0}"),
+       migraphx::make_json_op(
+           "convolution",
+           "{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[1,1]}"),
        x_main_module_24,
        x_main_module_17);
    auto x_main_module_26 = mmain->add_instruction(
@@ -106,9 +106,9 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
        "{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}"),
        x_main_module_28);
    auto x_main_module_30 = mmain->add_instruction(
-       migraphx::make_json_op("convolution",
-                              "{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,"
-                              "1],use_dynamic_same_auto_pad:0}"),
+       migraphx::make_json_op(
+           "convolution",
+           "{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,1]}"),
        x_main_module_29,
        x_main_module_15);
    auto x_main_module_31 = mmain->add_instruction(
@@ -117,9 +117,9 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
        mmain->add_instruction(migraphx::make_op("add"), x_main_module_30, x_main_module_31);
    auto x_main_module_33 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_32);
    auto x_main_module_34 = mmain->add_instruction(
-       migraphx::make_json_op("convolution",
-                              "{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,"
-                              "1],use_dynamic_same_auto_pad:0}"),
+       migraphx::make_json_op(
+           "convolution",
+           "{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,1]}"),
        x_main_module_33,
        x_main_module_13);
    auto x_main_module_35 = mmain->add_instruction(
@@ -128,9 +128,9 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
        mmain->add_instruction(migraphx::make_op("add"), x_main_module_34, x_main_module_35);
    auto x_main_module_37 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_36);
    auto x_main_module_38 = mmain->add_instruction(
-       migraphx::make_json_op("convolution",
-                              "{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,"
-                              "1],use_dynamic_same_auto_pad:0}"),
+       migraphx::make_json_op(
+           "convolution",
+           "{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,1]}"),
        x_main_module_37,
        x_main_module_11);
    auto x_main_module_39 = mmain->add_instruction(
......
(This source diff could not be displayed because it is too large.)
@@ -44,7 +44,6 @@
#include <migraphx/propagate_constant.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/register_op.hpp>
-#include <migraphx/rewrite_batchnorm.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/register_target.hpp>
@@ -110,8 +109,12 @@ struct loader
        ap(brief, {"--brief"}, ap.help("Make the output brief."), ap.set_value(true));
        ap(output_type,
           {"--cpp"},
-          ap.help("Print out the program as cpp program."),
+          ap.help("Print out the program as C++ program."),
           ap.set_value("cpp"));
+       ap(output_type,
+          {"--python", "--py"},
+          ap.help("Print out the program as python program."),
+          ap.set_value("py"));
        ap(output_type, {"--json"}, ap.help("Print out program as json."), ap.set_value("json"));
        ap(output_type,
           {"--text"},
@@ -221,7 +224,6 @@ struct loader
        {
            migraphx::run_passes(*p.get_main_module(),
                                 {
-                                    migraphx::rewrite_batchnorm{},
                                     migraphx::eliminate_identity{},
                                     migraphx::dead_code_elimination{},
                                     migraphx::simplify_algebra{},
@@ -261,7 +263,9 @@ struct loader
            type = "binary";
        }
-       if(type == "cpp")
+       if(type == "py")
+           p.print_py(*os);
+       else if(type == "cpp")
            p.print_cpp(*os);
        else if(type == "graphviz")
            p.print_graph(*os, brief);
......
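A hedged sketch of reaching the new printer from C++ rather than through the driver flag (parse_onnx is assumed as the loading entry point; print_py is the method the --py option calls via p.print_py(*os) above):

#include <iostream>
#include <migraphx/onnx.hpp>
#include <migraphx/program.hpp>

int main()
{
    migraphx::program p = migraphx::parse_onnx("model.onnx"); // assumed front end
    p.print_py(std::cout); // emit the program as a Python script
}
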
(This source diff could not be displayed because it is too large.)
@@ -145,7 +145,7 @@ void verify_reduced(program p,
    auto* mm  = p.get_main_module();
    auto last = std::prev(mm->end(), n + 1);
    mm->remove_instructions(last, mm->end());
-   std::cout << "Verify: " << std::endl;
+   std::cout << "Verify: " << n << std::endl;
    std::cout << p << std::endl;
    verify_program(std::to_string(n), p, t, options, quantize, inputs, tolerance);
}
@@ -159,6 +159,7 @@ void verify_reduced_program(const program& p,
{
    const auto* mm = p.get_main_module();
    auto n         = std::distance(mm->begin(), mm->end());
+   std::cout << "Verify steps: " << n << std::endl;
    for(std::size_t i = 0; i < n; i++)
    {
        verify_reduced(p, i, t, options, quantize, inputs, tolerance);
......
@@ -42,6 +42,13 @@ static bool try_compute_shape(instruction_ref ins,
    try
    {
        shape new_shape = ins->get_operator().compute_shape(inputs, mods);
+       // Cannot tell if a dynamic shape will need to be made contiguous
+       if(new_shape.dynamic())
+       {
+           return false;
+       }
+
        // If the output shape is a standard shape, no need to try its output
        if(new_shape.standard())
        {
@@ -133,14 +140,20 @@ static void remove_contiguous(const std::string& op_name, module& m, F f)
        }
    }

-   // Perform evaluations in parallel
+   // Perform static contiguous evaluations in parallel
    std::vector<argument> literals(const_instructions.size());
    par_for(const_instructions.size(), 1, [&](const auto i) {
        auto c    = op::contiguous{};
        auto prev = const_instructions[i]->inputs().front();
-       literals[i] = c.compute(c.compute_shape({prev->get_shape()}), {prev->eval()});
+       // compute the output contiguous shape from the previous instruction shape
+       shape computed_shape = c.compute_shape({prev->get_shape()});
+       const std::vector<argument>& prev_eval = {prev->eval()};
+       // prev_eval should not be used in make_compute_output_shape() as computed_shape is static
+       auto co_shape = make_compute_output_shape(pack(c, computed_shape, prev_eval));
+       literals[i]   = c.compute(co_shape, prev_eval);
    });

+   // Replace static contiguous operations with a literal
    for(size_t i = 0; i < const_instructions.size(); i++)
    {
        auto l = m.add_literal(literals[i].get_shape(), literals[i].data());
@@ -30,23 +30,31 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

template <class T>
-T generic_read_file(const std::string& filename)
+T generic_read_file(const std::string& filename, size_t offset = 0, size_t nbytes = 0)
{
    std::ifstream is(filename, std::ios::binary | std::ios::ate);
-   std::streamsize size = is.tellg();
-   if(size < 1)
+   if(nbytes == 0)
+   {
+       // if there is a non-zero offset and nbytes is not set,
+       // calculate size of remaining bytes to read
+       nbytes = is.tellg();
+       if(offset > nbytes)
+           MIGRAPHX_THROW("offset is larger than file size");
+       nbytes -= offset;
+   }
+   if(nbytes < 1)
        MIGRAPHX_THROW("Invalid size for: " + filename);
-   is.seekg(0, std::ios::beg);
+   is.seekg(offset, std::ios::beg);

-   T buffer(size, 0);
-   if(not is.read(&buffer[0], size))
+   T buffer(nbytes, 0);
+   if(not is.read(&buffer[0], nbytes))
        MIGRAPHX_THROW("Error reading file: " + filename);
    return buffer;
}

-std::vector<char> read_buffer(const std::string& filename)
+std::vector<char> read_buffer(const std::string& filename, size_t offset, size_t nbytes)
{
-   return generic_read_file<std::vector<char>>(filename);
+   return generic_read_file<std::vector<char>>(filename, offset, nbytes);
}

std::string read_string(const std::string& filename)
......
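A hedged usage sketch of the extended read_buffer signature (assuming migraphx/file_buffer.hpp is the declaring header):

#include <cstddef>
#include <vector>
#include <migraphx/file_buffer.hpp> // assumption: declares read_buffer

std::size_t example()
{
    // read 16 bytes starting at byte offset 128
    std::vector<char> chunk = migraphx::read_buffer("weights.bin", 128, 16);
    // nbytes omitted: read everything from offset 128 to end of file
    std::vector<char> rest = migraphx::read_buffer("weights.bin", 128);
    return chunk.size() + rest.size();
}
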
@@ -39,13 +39,22 @@ static literal get_scalar(instruction_ref ins)
    if(ins->name() == "contiguous")
        return get_scalar(ins->inputs().front());
    const auto& s = ins->get_shape();
-   if(not(s.elements() == 1 or s.scalar()))
+   if(s.elements() != 1 && not(s.scalar()))
        return {};
    if(not ins->can_eval())
        return {};
    auto e = ins->eval();
    literal r{};
-   e.visit_at([&](auto x) { r = literal{x}; });
+   // needed for bool as visit_at invokes as() which promotes bool to int8
+   // Without this we'll break type checks for logical ops that are fused.
+   if(e.get_shape().type() == shape::bool_type)
+   {
+       r = literal{e.at<bool>()};
+   }
+   else
+   {
+       e.visit_at([&](auto x) { r = literal{x}; });
+   }
    return r;
}
......
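A standalone sketch (not MIGraphX code) of the promotion pitfall the bool branch guards against: a generic visitor that hands bool data over as int8_t loses the bool type, so anything constructed from the visited value ends up int8-typed:

#include <cstdint>
#include <type_traits>

template <class F>
void visit_bool_as_int8(bool b, F f)
{
    f(static_cast<int8_t>(b)); // the callback never sees 'bool'
}

int main()
{
    visit_bool_as_int8(true, [](auto x) {
        static_assert(std::is_same<decltype(x), int8_t>::value, "bool type was lost");
    });
}
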
@@ -107,6 +107,7 @@ struct argument : raw_data<argument>
    data_t m_data{};
};

+std::vector<shape> to_shapes(const std::vector<argument>& args);
void migraphx_to_value(value& v, const argument& a);
void migraphx_from_value(const value& v, argument& a);
......
@@ -198,7 +198,7 @@ struct check_shapes
    */
    const check_shapes& same_ndims() const
    {
-       if(not this->same([](const shape& s) { return s.max_lens().size(); }))
+       if(not this->same([](const shape& s) { return s.ndim(); }))
            MIGRAPHX_THROW(prefix() + "Number of dimensions do not match");
        return *this;
    }
......
@@ -36,6 +36,9 @@ struct operation;
std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
                                                  std::vector<std::size_t> s1);

+std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, shape s1);
+
shape common_shape(const std::vector<shape>& shapes);

instruction_ref insert_common_op(module& m,
......
@@ -66,6 +66,15 @@ any_ptr get_queue_context(T&)
{
    return {};
}
template <class T>
void wait_for_context(T&, any_ptr)
{
}
template <class T>
void finish_on_context(T&, any_ptr)
{
}
#ifdef TYPE_ERASED_DECLARATION

@@ -78,6 +87,10 @@ struct context
    void from_value(const value& v);
    // (optional)
    any_ptr get_queue();
// (optional)
void wait_for(any_ptr queue);
// (optional)
void finish_on(any_ptr queue);
    //
    void finish() const;
};

@@ -165,6 +178,18 @@ struct context
        return (*this).private_detail_te_get_handle().get_queue();
    }
void wait_for(any_ptr queue)
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().wait_for(queue);
}
void finish_on(any_ptr queue)
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().finish_on(queue);
}
    void finish() const
    {
        assert((*this).private_detail_te_handle_mem_var);

@@ -187,6 +212,8 @@ struct context
        virtual value to_value() const          = 0;
        virtual void from_value(const value& v) = 0;
        virtual any_ptr get_queue()             = 0;
virtual void wait_for(any_ptr queue) = 0;
virtual void finish_on(any_ptr queue) = 0;
        virtual void finish() const = 0;
    };

@@ -231,6 +258,33 @@ struct context
        return get_queue_context(private_detail_te_self);
    }
template <class T>
static auto private_detail_te_default_wait_for(char, T&& private_detail_te_self, any_ptr queue)
-> decltype(private_detail_te_self.wait_for(queue))
{
private_detail_te_self.wait_for(queue);
}
template <class T>
static void private_detail_te_default_wait_for(float, T&& private_detail_te_self, any_ptr queue)
{
wait_for_context(private_detail_te_self, queue);
}
template <class T>
static auto private_detail_te_default_finish_on(char, T&& private_detail_te_self, any_ptr queue)
-> decltype(private_detail_te_self.finish_on(queue))
{
private_detail_te_self.finish_on(queue);
}
template <class T>
static void
private_detail_te_default_finish_on(float, T&& private_detail_te_self, any_ptr queue)
{
finish_on_context(private_detail_te_self, queue);
}
    template <typename PrivateDetailTypeErasedT>
    struct private_detail_te_handle_type : private_detail_te_handle_base_type
    {

@@ -248,7 +302,7 @@ struct context
            PrivateDetailTypeErasedT value,
            typename std::enable_if<not std::is_reference<PrivateDetailTypeErasedU>::value,
                                    int>::type* = nullptr) noexcept
-           : private_detail_te_value(std::move(value))
+           : private_detail_te_value(value)
        {
        }

@@ -277,6 +331,18 @@ struct context
            return private_detail_te_default_get_queue(char(0), private_detail_te_value);
        }
void wait_for(any_ptr queue) override
{
private_detail_te_default_wait_for(char(0), private_detail_te_value, queue);
}
void finish_on(any_ptr queue) override
{
private_detail_te_default_finish_on(char(0), private_detail_te_value, queue);
}
        void finish() const override { private_detail_te_value.finish(); }

        PrivateDetailTypeErasedT private_detail_te_value;
......
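The default functions above rely on a common overload-ranking trick: passing char(0) prefers the char overload, which SFINAE keeps viable only when T has the member; otherwise char converts to float and the fallback wins. A standalone sketch of the same dispatch:

#include <iostream>

struct with_hook
{
    void wait_for(int q) { std::cout << "member wait_for(" << q << ")\n"; }
};
struct without_hook
{
};

// preferred: exact char match, enabled only if T::wait_for exists
template <class T>
auto dispatch_wait_for(char, T& t, int q) -> decltype(t.wait_for(q))
{
    t.wait_for(q);
}

// fallback: char(0) -> float conversion ranks lower
template <class T>
void dispatch_wait_for(float, T&, int)
{
    std::cout << "default no-op\n";
}

int main()
{
    with_hook a;
    without_hook b;
    dispatch_wait_for(char(0), a, 1); // calls the member
    dispatch_wait_for(char(0), b, 1); // falls back to the no-op
}
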
@@ -21,36 +21,55 @@
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
+#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_DYN_OUTPUT_HPP
+#define MIGRAPHX_GUARD_MIGRAPHLIB_DYN_OUTPUT_HPP
+
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
-#include <migraphx/gpu/device/concat.hpp>
-#include <migraphx/gpu/device/contiguous.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
-namespace gpu {
-namespace device {

-argument concat(hipStream_t stream,
-                const migraphx::shape&,
-                std::vector<migraphx::argument> args,
-                std::vector<std::size_t> offsets)
-{
-   auto ninputs = args.size() - 1;
-   for(std::size_t j = 0; j < ninputs; j++)
-   {
-       auto&& arg        = args[j];
-       auto offset       = offsets[j];
-       auto byte_offset  = offset * arg.get_shape().type_size();
-       auto output_shape = shape{
-           arg.get_shape().type(), arg.get_shape().lens(), args.back().get_shape().strides()};
-       auto output = argument{output_shape, args.back().data() + byte_offset};
-       contiguous(stream, output, arg);
-   }
-   return args.back();
-}
+struct dyn_output
+{
+   // original shape from the instruction
+   shape ins_shape;
+   // shape computed at eval time using input arguments
+   shape computed_shape;
+};
+
+/**
+ * Handle dynamic and static shape at evaluation time.
+ * If converted to shape type, returns original ins_shape.
+ * If converted to dyn_output type, will compute an output shape using the input arguments.
+ */
+template <class F>
+struct compute_output_shape
+{
+   F ins_inputs;
+
+   operator dyn_output() const
+   {
+       return ins_inputs([](const auto& x, shape ins_shape, const std::vector<argument>& inputs) {
+           if(ins_shape.dynamic())
+               return dyn_output{ins_shape, compute_shape(x, to_shapes(inputs))};
+           return dyn_output{ins_shape, ins_shape};
+       });
+   }
+
+   operator shape() const
+   {
+       return ins_inputs(
+           [](const auto&, shape ins_shape, const std::vector<argument>&) { return ins_shape; });
+   }
+};
+
+template <class F>
+compute_output_shape<F> make_compute_output_shape(F f)
+{
+   return {f};
+}

-} // namespace device
-} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
+
+#endif
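A hedged sketch of the intended call pattern for dyn_output (the compute() signature below is an assumption based on this header, not this commit's operator interface):

#include <vector>
#include <migraphx/argument.hpp>
#include <migraphx/dyn_output.hpp>

migraphx::argument compute_example(const migraphx::dyn_output& dyn_out,
                                   const std::vector<migraphx::argument>& args)
{
    // computed_shape is the instruction's shape when static, or the shape
    // computed from the input arguments at eval time when dynamic
    migraphx::argument result{dyn_out.computed_shape};
    // ... fill result from args ...
    return result;
}
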
@@ -21,22 +21,21 @@
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
-#ifndef MIGRAPHX_GUARD_RTGLIB_ADD_HPP
-#define MIGRAPHX_GUARD_RTGLIB_ADD_HPP
+#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_EXECUTION_ENV_HPP
+#define MIGRAPHX_GUARD_MIGRAPHLIB_EXECUTION_ENV_HPP

-#include <migraphx/gpu/oper.hpp>
-#include <migraphx/gpu/device/add.hpp>
+#include <migraphx/any_ptr.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
-namespace gpu {

-struct hip_add : binary_device<hip_add, device::add>
+struct execution_environment
{
+   any_ptr queue = any_ptr{};
+   bool async    = false;
};

-} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

-#endif
+#endif /* MIGRAPHX_GUARD_MIGRAPHLIB_EXECUTION_ENV_HPP */
@@ -31,7 +31,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

-std::vector<char> read_buffer(const std::string& filename);
+std::vector<char> read_buffer(const std::string& filename, size_t offset = 0, size_t nbytes = 0);
std::string read_string(const std::string& filename);

void write_buffer(const std::string& filename, const char* buffer, std::size_t size);
......
@@ -121,6 +121,8 @@ struct instruction
    bool can_eval() const;

+   bool is_undefined() const;
+
    argument eval(bool check_eval = true) const;

    void finalize(context& ctx);
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_NHWC_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_NHWC_HPP
......
@@ -80,6 +80,7 @@ struct literal : raw_data<literal>
        fill(start, end);
    }

+   // Directly copies buffer of x
    template <class T, MIGRAPHX_REQUIRES(sizeof(T) == 1)>
    literal(const shape& s, T* x) : buffer(make_shared_array<char>(s.bytes())), m_shape(s)
    {
@@ -107,25 +108,15 @@ struct literal : raw_data<literal>
    std::shared_ptr<char> buffer;
    shape m_shape;

+   // Keeps the same data ordering as the given container
    template <class Iterator>
    void fill(Iterator start, Iterator end)
    {
        assert(std::distance(start, end) == m_shape.elements());
-       if(m_shape.standard())
-       {
-           m_shape.visit_type([&](auto as) { std::copy(start, end, as.from(buffer.get())); });
-       }
-       else
-       {
-           auto it = start;
-           m_shape.visit_type([&](auto as) {
-               auto output = make_view(m_shape, as.from(buffer.get()));
-               shape_for_each(output.get_shape(), [&](const auto& idx) {
-                   output(idx.begin(), idx.end()) = *it; // NOLINT(bugprone-signed-char-misuse)
-                   it++;
-               });
-           });
-       }
+       m_shape.visit_type([&](auto as) {
+           auto output = make_view(m_shape, as.from(buffer.get()));
+           std::copy(start, end, output.begin());
+       });
    }
};
......
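A hedged usage sketch of the simplified fill(): building a literal from a container now copies elements in container order through a shape-aware view, for standard and non-standard layouts alike:

#include <vector>
#include <migraphx/literal.hpp>
#include <migraphx/shape.hpp>

int main()
{
    migraphx::shape s{migraphx::shape::float_type, {2, 2}};
    std::vector<float> data = {1.0f, 2.0f, 3.0f, 4.0f};
    migraphx::literal l{s, data}; // fill() copies all four elements
}
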