Commit fe493c28 authored by Alan Turner's avatar Alan Turner
Browse files

Merge remote-tracking branch 'origin/develop' into ck-gsg

parents ba0b3794 cce35871
...@@ -40,7 +40,11 @@ struct quantizelinear ...@@ -40,7 +40,11 @@ struct quantizelinear
std::string name() const { return "quantizelinear"; } std::string name() const { return "quantizelinear"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.same_dims(); check_shapes{inputs, *this}.same_dims().has(2, 3);
if(inputs[0].type() != inputs[1].type())
{
MIGRAPHX_THROW("QUANTIZELINEAR: Scales and input must be the same type");
}
if(inputs.size() == 3) if(inputs.size() == 3)
{ {
return {inputs[2].type(), inputs[0].lens(), inputs[0].strides()}; return {inputs[2].type(), inputs[0].lens(), inputs[0].strides()};
...@@ -61,8 +65,7 @@ struct quantizelinear ...@@ -61,8 +65,7 @@ struct quantizelinear
argument result{output_shape}; argument result{output_shape};
visit_all(result, y_zero_point)([&](auto output, auto zero_pts) { visit_all(result, y_zero_point)([&](auto output, auto zero_pts) {
x.visit([&](auto input) { visit_all(x, y_scale)([&](auto input, auto scales) {
y_scale.visit([&](auto scales) {
using quant_type = typename decltype(output)::value_type; using quant_type = typename decltype(output)::value_type;
auto min_value = std::numeric_limits<quant_type>::min(); auto min_value = std::numeric_limits<quant_type>::min();
auto max_value = std::numeric_limits<quant_type>::max(); auto max_value = std::numeric_limits<quant_type>::max();
...@@ -74,7 +77,6 @@ struct quantizelinear ...@@ -74,7 +77,6 @@ struct quantizelinear
}); });
}); });
}); });
});
return result; return result;
} }
......
...@@ -91,7 +91,7 @@ struct reduce_op : op_name<Derived> ...@@ -91,7 +91,7 @@ struct reduce_op : op_name<Derived>
{ {
value normalize; value normalize;
normalize["axes"] = value::array{normalize_attribute::include_min}; normalize["axes"] = value::array{normalize_attribute::include_min};
return {{"normalize_axes", normalize}}; return {{"normalize_axes", normalize}, {"reduce", true}};
} }
std::vector<int64_t> tune_axes(std::size_t n_dim) const std::vector<int64_t> tune_axes(std::size_t n_dim) const
...@@ -123,9 +123,7 @@ struct reduce_op : op_name<Derived> ...@@ -123,9 +123,7 @@ struct reduce_op : op_name<Derived>
auto tuned_axes = tune_axes(output_dyn_dims.size()); auto tuned_axes = tune_axes(output_dyn_dims.size());
for(const auto& axis : tuned_axes) for(const auto& axis : tuned_axes)
{ {
// At the time of writing, there's no functional difference between output_dyn_dims[axis] = {1, 1};
// optimum of 0 (no opt) or 1.
output_dyn_dims[axis] = {1, 1, 0};
} }
return shape{s.type(), output_dyn_dims}; return shape{s.type(), output_dyn_dims};
......
...@@ -57,6 +57,7 @@ struct select_module ...@@ -57,6 +57,7 @@ struct select_module
param_names.cend(), param_names.cend(),
std::back_inserter(ret), std::back_inserter(ret),
[](auto pn) { return not contains(pn, "#output_"); }); [](auto pn) { return not contains(pn, "#output_"); });
std::sort(ret.begin(), ret.end());
return ret; return ret;
} }
...@@ -68,6 +69,8 @@ struct select_module ...@@ -68,6 +69,8 @@ struct select_module
param_names.cend(), param_names.cend(),
std::back_inserter(ret), std::back_inserter(ret),
[](auto pn) { return contains(pn, "#output_"); }); [](auto pn) { return contains(pn, "#output_"); });
// needs to be sorted to ensure output parameter ordering
std::sort(ret.begin(), ret.end());
return ret; return ret;
} }
...@@ -111,6 +114,7 @@ struct select_module ...@@ -111,6 +114,7 @@ struct select_module
// One tuple output parameter in main module to multiple output parameters in submodule // One tuple output parameter in main module to multiple output parameters in submodule
auto out_param_names = get_output_parameter_names(module_to_run); auto out_param_names = get_output_parameter_names(module_to_run);
auto param_shapes = module_to_run->get_parameter_shapes();
auto output_sub_objects = args.back().get_sub_objects(); auto output_sub_objects = args.back().get_sub_objects();
assert(out_param_names.size() == output_sub_objects.size()); assert(out_param_names.size() == output_sub_objects.size());
std::transform(out_param_names.begin(), std::transform(out_param_names.begin(),
...@@ -118,7 +122,7 @@ struct select_module ...@@ -118,7 +122,7 @@ struct select_module
output_sub_objects.begin(), output_sub_objects.begin(),
std::inserter(p_map, p_map.end()), std::inserter(p_map, p_map.end()),
[&](auto&& name, auto&& a) { [&](auto&& name, auto&& a) {
auto ps = module_to_run->get_parameter_shape(name); auto ps = param_shapes.at(name);
if(a.get_shape() != ps) if(a.get_shape() != ps)
{ {
assert(ps.bytes() == a.get_shape().bytes()); assert(ps.bytes() == a.get_shape().bytes());
......
...@@ -111,16 +111,15 @@ struct slice ...@@ -111,16 +111,15 @@ struct slice
// For a static shape, old_lens will be adjusted to a new size // For a static shape, old_lens will be adjusted to a new size
// for those axes that are sliced. // for those axes that are sliced.
// For dynamic shape, the adjusted old_lens become the new max values, // For dynamic shape, the adjusted old_lens become the new max values,
// while updating the old mins and opts if possible. // while updating the old mins and optimals if possible.
std::vector<std::size_t> new_mins; std::vector<std::size_t> new_mins;
std::vector<std::size_t> new_opts;
std::vector<std::size_t> old_lens; std::vector<std::size_t> old_lens;
std::vector<std::size_t> old_strides; std::vector<std::size_t> old_strides;
// Doesn't handle optimals
if(input_shape.dynamic()) if(input_shape.dynamic())
{ {
old_lens = input_shape.max_lens(); old_lens = input_shape.max_lens();
new_mins = input_shape.min_lens(); new_mins = input_shape.min_lens();
new_opts = input_shape.opt_lens();
} }
else else
{ {
...@@ -146,17 +145,11 @@ struct slice ...@@ -146,17 +145,11 @@ struct slice
std::size_t sliced_min_length = ends[i] - starts[i]; std::size_t sliced_min_length = ends[i] - starts[i];
// if the slice size is smaller than maxes but larger than mins // if the slice size is smaller than maxes but larger than mins
new_mins[axis] = std::min(sliced_min_length, new_mins[axis]); new_mins[axis] = std::min(sliced_min_length, new_mins[axis]);
auto sliced_opt_length = ends[i] - starts[i];
if(new_opts[axis] != 0)
new_opts[axis] = sliced_opt_length;
if(new_opts[axis] < new_mins[axis] or new_opts[axis] > new_lens[axis])
new_opts[axis] = 0;
} }
} }
if(input_shape.dynamic()) if(input_shape.dynamic())
{ {
return shape{t, new_mins, new_lens, new_opts}; return shape{t, new_mins, new_lens, {}};
} }
else else
{ {
......
...@@ -81,7 +81,7 @@ struct unsqueeze ...@@ -81,7 +81,7 @@ struct unsqueeze
{ {
if(std::find(axes.begin(), axes.end(), i) != axes.end()) if(std::find(axes.begin(), axes.end(), i) != axes.end())
{ {
dyn_dims.push_back({1, 1, 0}); dyn_dims.push_back({1, 1});
} }
else else
{ {
......
...@@ -39,6 +39,7 @@ struct module_pass_manager ...@@ -39,6 +39,7 @@ struct module_pass_manager
virtual module& get_module() = 0; virtual module& get_module() = 0;
virtual module* create_module(const std::string& name) = 0; virtual module* create_module(const std::string& name) = 0;
virtual module* get_common_parent() = 0; virtual module* get_common_parent() = 0;
virtual module* get_root_module() = 0;
virtual void run_pass(const pass& p) = 0; virtual void run_pass(const pass& p) = 0;
protected: protected:
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/filesystem.hpp> #include <migraphx/filesystem.hpp>
#include <functional>
#include <string> #include <string>
#include <memory> #include <memory>
...@@ -36,6 +37,7 @@ struct process_impl; ...@@ -36,6 +37,7 @@ struct process_impl;
struct process struct process
{ {
using writer = std::function<void(const char*, std::size_t)>;
process(const std::string& cmd); process(const std::string& cmd);
// move constructor // move constructor
...@@ -49,6 +51,7 @@ struct process ...@@ -49,6 +51,7 @@ struct process
process& cwd(const fs::path& p); process& cwd(const fs::path& p);
void exec(); void exec();
void write(std::function<void(process::writer)> pipe_in);
private: private:
std::unique_ptr<process_impl> impl; std::unique_ptr<process_impl> impl;
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_PROMOTE_LITERALS_HPP
#define MIGRAPHX_GUARD_RTGLIB_PROMOTE_LITERALS_HPP
#include <string>
#include <migraphx/pass_manager.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
/**
* Replace literals in submodules with literals in the root module.
* Intended to allow for reuse of the literals between submodules.
*/
struct promote_literals
{
    // Pass name used by the pass manager for registration, tracing, and lookup.
    std::string name() const { return "promote_literals"; }
    // Runs the pass over the program via the module pass manager
    // (implementation lives in the corresponding .cpp file).
    void apply(module_pass_manager&) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -188,7 +188,8 @@ auto from_value_impl(rank<3>, const value& v, T& x) ...@@ -188,7 +188,8 @@ auto from_value_impl(rank<3>, const value& v, T& x)
} }
template <class T> template <class T>
auto from_value_impl(rank<4>, const value& v, T& x) -> decltype(x.insert(*x.begin()), void()) auto from_value_impl(rank<4>, const value& v, T& x)
-> decltype(x.insert(*x.begin()), std::declval<typename T::mapped_type>(), void())
{ {
x.clear(); x.clear();
for(auto&& e : v) for(auto&& e : v)
......
...@@ -29,10 +29,12 @@ ...@@ -29,10 +29,12 @@
#include <ostream> #include <ostream>
#include <numeric> #include <numeric>
#include <memory> #include <memory>
#include <set>
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
#include <migraphx/errors.hpp> #include <migraphx/errors.hpp>
#include <migraphx/half.hpp> #include <migraphx/half.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
namespace migraphx { namespace migraphx {
...@@ -87,12 +89,12 @@ struct shape ...@@ -87,12 +89,12 @@ struct shape
{ {
std::size_t min = 0; std::size_t min = 0;
std::size_t max = 0; std::size_t max = 0;
std::size_t opt = 0; std::set<std::size_t> optimals{};
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
return pack(f(self.min, "min"), f(self.max, "max"), f(self.opt, "opt")); return pack(f(self.min, "min"), f(self.max, "max"), f(self.optimals, "optimals"));
} }
bool is_fixed() const; bool is_fixed() const;
...@@ -132,11 +134,12 @@ struct shape ...@@ -132,11 +134,12 @@ struct shape
shape(type_t t, std::vector<dynamic_dimension> dims); shape(type_t t, std::vector<dynamic_dimension> dims);
// Construct a dynamic shape from three sets of lengths (of the same rank) // Construct a dynamic shape from vectors of mins, maxes, and optimals.
// optimals_list is a vector of optimals that corresponds to each min and max.
shape(type_t t, shape(type_t t,
std::vector<std::size_t> mins, std::vector<std::size_t> mins,
std::vector<std::size_t> maxes, std::vector<std::size_t> maxes,
std::vector<std::size_t> opts); std::vector<std::set<std::size_t>> optimals_list);
template <class Range> template <class Range>
shape(type_t t, const Range& l) : shape(t, std::vector<std::size_t>(l.begin(), l.end())) shape(type_t t, const Range& l) : shape(t, std::vector<std::size_t>(l.begin(), l.end()))
...@@ -186,21 +189,21 @@ struct shape ...@@ -186,21 +189,21 @@ struct shape
/*! /*!
* Minimum lengths for dynamic shape. * Minimum lengths for dynamic shape.
* lens() for fixed shape. * lens() for static shape.
*/ */
std::vector<std::size_t> min_lens() const; std::vector<std::size_t> min_lens() const;
/*! /*!
* Maximum lengths for dynamic shape. * Maximum lengths for dynamic shape.
* lens() for fixed shape. * lens() for static shape.
*/ */
std::vector<std::size_t> max_lens() const; std::vector<std::size_t> max_lens() const;
/*! /*!
* Optimum lengths for dynamic shape. * Optimum lengths for dynamic shape.
* lens() for fixed shape. * Empty for static shape.
*/ */
std::vector<std::size_t> opt_lens() const; std::vector<std::set<std::size_t>> opt_lens() const;
/// Map multiple indices to space index /// Map multiple indices to space index
std::size_t index(std::initializer_list<std::size_t> l) const; std::size_t index(std::initializer_list<std::size_t> l) const;
...@@ -253,9 +256,12 @@ struct shape ...@@ -253,9 +256,12 @@ struct shape
shape with_type(type_t t) const; shape with_type(type_t t) const;
// convert the shape to an equivalent dynamic shape // convert the shape to an equivalent dynamic shape with empty optimals
shape to_dynamic() const; shape to_dynamic() const;
// convert the shape to a static one setting any non-fixed dynamic_dimensions to x
shape to_static(std::size_t x) const;
friend bool operator==(const shape& x, const shape& y); friend bool operator==(const shape& x, const shape& y);
friend bool operator!=(const shape& x, const shape& y); friend bool operator!=(const shape& x, const shape& y);
friend std::ostream& operator<<(std::ostream& os, const shape& x); friend std::ostream& operator<<(std::ostream& os, const shape& x);
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_SPLIT_SINGLE_DYN_DIM_HPP
#define MIGRAPHX_GUARD_RTGLIB_SPLIT_SINGLE_DYN_DIM_HPP
#include <string>
#include <migraphx/pass_manager.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
/**
* Split dynamic dimension over submodules if exactly one dimension in the parameter list is
* dynamic.
*/
struct split_single_dyn_dim
{
    // Pass name used by the pass manager for registration, tracing, and lookup.
    std::string name() const { return "split_single_dyn_dim"; }
    // Runs the pass over the program via the module pass manager
    // (implementation lives in the corresponding .cpp file).
    void apply(module_pass_manager&) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -595,6 +595,14 @@ std::vector<shape> module::get_output_shapes() const ...@@ -595,6 +595,14 @@ std::vector<shape> module::get_output_shapes() const
} }
} }
// Return the instructions produced as the module's outputs: the inputs of the
// trailing @return instruction when one exists, otherwise the final
// instruction itself as a single-element result.
std::vector<instruction_ref> module::get_returns() const
{
    auto final_ins = std::prev(this->end());
    if(final_ins->name() != "@return")
        return {final_ins};
    return final_ins->inputs();
}
instruction_ref module::validate() const instruction_ref module::validate() const
{ {
return std::find_if( return std::find_if(
......
...@@ -172,6 +172,22 @@ struct vector_stream ...@@ -172,6 +172,22 @@ struct vector_stream
} }
}; };
// Minimal output-stream adapter so msgpack::pack can emit bytes through an
// arbitrary callback instead of an owned buffer.
struct writer_stream
{
    // Callback invoked with each chunk of encoded bytes.
    std::function<void(const char*, std::size_t)> writer;

    // Stream-style write required by msgpack's packer: forwards the chunk to
    // the callback and returns *this to allow chaining.
    writer_stream& write(const char* b, std::size_t n)
    {
        writer(b, n);
        return *this;
    }
};
// Serialize v to msgpack, streaming the encoded bytes to the writer callback
// rather than materializing the whole encoding in memory.
void to_msgpack(const value& v, std::function<void(const char*, std::size_t)> writer)
{
    writer_stream ws{std::move(writer)};
    msgpack::pack(ws, v);
}
std::vector<char> to_msgpack(const value& v) std::vector<char> to_msgpack(const value& v)
{ {
vector_stream vs; vector_stream vs;
......
...@@ -94,7 +94,7 @@ struct onnx_parser ...@@ -94,7 +94,7 @@ struct onnx_parser
node_map nodes; node_map nodes;
std::unordered_map<std::string, instruction_ref> instructions; std::unordered_map<std::string, instruction_ref> instructions;
program prog = program(); program prog = program();
shape::dynamic_dimension default_dyn_dim_value = {1, 1, 0}; shape::dynamic_dimension default_dyn_dim_value = {1, 1};
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims; std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims;
std::unordered_map<std::string, std::vector<shape::dynamic_dimension>> map_dyn_input_dims; std::unordered_map<std::string, std::vector<shape::dynamic_dimension>> map_dyn_input_dims;
bool use_dyn_output = false; bool use_dyn_output = false;
......
...@@ -46,14 +46,14 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs) ...@@ -46,14 +46,14 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs)
auto dim_val = options.default_dim_value; auto dim_val = options.default_dim_value;
if(dim_val != 0) if(dim_val != 0)
{ {
if(options.default_dyn_dim_value != shape::dynamic_dimension{1, 1, 0}) if(options.default_dyn_dim_value != shape::dynamic_dimension{1, 1})
{ {
MIGRAPHX_THROW("PARSE_ONNX_FROM: both default_dim_value and default_dyn_dim_value" MIGRAPHX_THROW("PARSE_ONNX_FROM: both default_dim_value and default_dyn_dim_value"
"set to non-default value"); "set to non-default value");
} }
else else
{ {
parser.default_dyn_dim_value = {dim_val, dim_val, 0}; parser.default_dyn_dim_value = {dim_val, dim_val};
} }
} }
else else
......
...@@ -496,7 +496,7 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t, ...@@ -496,7 +496,7 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t,
return default_dyn_dim_value; return default_dyn_dim_value;
} }
std::size_t tmp = d.dim_value(); std::size_t tmp = d.dim_value();
return {tmp, tmp, 0}; return {tmp, tmp};
} }
else else
{ {
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/tune_axis.hpp> #include <migraphx/tune_axis.hpp>
#include <migraphx/common.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -47,18 +48,15 @@ struct parse_quantizelinear : op_parser<parse_quantizelinear> ...@@ -47,18 +48,15 @@ struct parse_quantizelinear : op_parser<parse_quantizelinear>
auto input_lens = args[0]->get_shape().lens(); auto input_lens = args[0]->get_shape().lens();
auto n_dim = input_lens.size(); auto n_dim = input_lens.size();
instruction_ref y_scale; instruction_ref y_scale = args[1];
if(args[1]->get_shape().elements() != 1) if(args[1]->get_shape().elements() != 1)
{ {
auto tuned_axis = tune_axis(n_dim, axis, opd.op_name); auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
y_scale = info.add_instruction( y_scale = info.add_instruction(
make_op("broadcast", {{"axis", tuned_axis}, {"out_lens", input_lens}}), args[1]); make_op("broadcast", {{"axis", tuned_axis}, {"out_lens", input_lens}}), args[1]);
} }
else
{ auto common_args = add_common_args(*info.mod, {args[0], y_scale});
y_scale = info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}),
args[1]);
}
if(args.size() == 3) if(args.size() == 3)
{ {
...@@ -76,10 +74,10 @@ struct parse_quantizelinear : op_parser<parse_quantizelinear> ...@@ -76,10 +74,10 @@ struct parse_quantizelinear : op_parser<parse_quantizelinear>
make_op("multibroadcast", {{"out_lens", input_lens}}), y_zero_point); make_op("multibroadcast", {{"out_lens", input_lens}}), y_zero_point);
} }
return info.add_instruction(make_op("quantizelinear"), args[0], y_scale, y_zero_point); common_args.push_back(y_zero_point);
} }
return info.add_instruction(make_op("quantizelinear"), args[0], y_scale); return info.add_instruction(make_op("quantizelinear"), common_args);
} }
}; };
......
...@@ -53,8 +53,8 @@ struct parse_reshape : op_parser<parse_reshape> ...@@ -53,8 +53,8 @@ struct parse_reshape : op_parser<parse_reshape>
s.visit([&](auto v) { copy(v, std::back_inserter(dims)); }); s.visit([&](auto v) { copy(v, std::back_inserter(dims)); });
} }
return info.add_instruction(make_op("reshape", {{"dims", dims}}), auto cont = info.add_instruction(make_op("contiguous"), args[0]);
info.make_contiguous(args[0])); return info.add_instruction(make_op("reshape", {{"dims", dims}}), cont);
} }
}; };
......
...@@ -86,12 +86,21 @@ struct module_pm : module_pass_manager ...@@ -86,12 +86,21 @@ struct module_pm : module_pass_manager
assert(mod); assert(mod);
return *mod; return *mod;
} }
virtual module* create_module(const std::string& name) override virtual module* create_module(const std::string& name) override
{ {
assert(prog); assert(prog);
return prog->create_module(name); return prog->create_module(name);
} }
virtual module* get_common_parent() override { return common_parent; } virtual module* get_common_parent() override { return common_parent; }
virtual module* get_root_module() override
{
assert(prog);
return prog->get_main_module();
}
virtual void run_pass(const pass& p) override virtual void run_pass(const pass& p) override
{ {
assert(mod); assert(mod);
......
...@@ -38,27 +38,42 @@ std::function<void(const char*)> redirect_to(std::ostream& os) ...@@ -38,27 +38,42 @@ std::function<void(const char*)> redirect_to(std::ostream& os)
return [&](const char* x) { os << x; }; return [&](const char* x) { os << x; };
} }
int exec(const std::string& cmd, const std::function<void(const char*)>& std_out) template <class F>
int exec(const std::string& cmd, const char* type, F f)
{ {
int ec = 0; int ec = 0;
if(enabled(MIGRAPHX_TRACE_CMD_EXECUTE{})) if(enabled(MIGRAPHX_TRACE_CMD_EXECUTE{}))
std::cout << cmd << std::endl; std::cout << cmd << std::endl;
auto closer = [&](FILE* stream) { auto closer = [&](FILE* stream) {
auto status = pclose(stream); auto status = pclose(stream);
ec = WIFEXITED(status) ? 0 : WEXITSTATUS(status); // NOLINT ec = WIFEXITED(status) ? WEXITSTATUS(status) : 0; // NOLINT
}; };
{ {
// TODO: Use execve instead of popen // TODO: Use execve instead of popen
std::unique_ptr<FILE, decltype(closer)> pipe(popen(cmd.c_str(), "r"), closer); // NOLINT std::unique_ptr<FILE, decltype(closer)> pipe(popen(cmd.c_str(), type), closer); // NOLINT
if(not pipe) if(not pipe)
MIGRAPHX_THROW("popen() failed: " + cmd); MIGRAPHX_THROW("popen() failed: " + cmd);
std::array<char, 128> buffer; f(pipe.get());
while(fgets(buffer.data(), buffer.size(), pipe.get()) != nullptr)
std_out(buffer.data());
} }
return ec; return ec;
} }
// Run cmd through popen in read mode, forwarding each chunk of the child's
// stdout to the std_out callback. Returns the command's exit status.
int exec(const std::string& cmd, const std::function<void(const char*)>& std_out)
{
    auto drain_stdout = [&](FILE* stream) {
        std::array<char, 128> chunk;
        for(;;)
        {
            if(fgets(chunk.data(), chunk.size(), stream) == nullptr)
                break;
            std_out(chunk.data());
        }
    };
    return exec(cmd, "r", drain_stdout);
}
}
// Run cmd through popen in write mode; std_in is handed a writer that pipes
// bytes into the child's stdin. Returns the command's exit status.
int exec(const std::string& cmd, std::function<void(process::writer)> std_in)
{
    auto feed_stdin = [&](FILE* stream) {
        auto sink = [&](const char* data, std::size_t len) {
            std::fwrite(data, 1, len, stream);
        };
        std_in(sink);
    };
    return exec(cmd, "w", feed_stdin);
}
struct process_impl struct process_impl
{ {
std::string command{}; std::string command{};
...@@ -72,6 +87,15 @@ struct process_impl ...@@ -72,6 +87,15 @@ struct process_impl
result += command; result += command;
return result; return result;
} }
// Forward the arguments to migraphx::exec and throw if the command exits
// with a non-zero status, embedding the command line and the status code
// in the error message.
template <class... Ts>
void check_exec(Ts&&... xs) const
{
    int ec = migraphx::exec(std::forward<Ts>(xs)...);
    if(ec != 0)
        MIGRAPHX_THROW("Command " + get_command() + " exited with status " +
                       std::to_string(ec));
}
}; };
process::process(const std::string& cmd) : impl(std::make_unique<process_impl>()) process::process(const std::string& cmd) : impl(std::make_unique<process_impl>())
...@@ -95,12 +119,11 @@ process& process::cwd(const fs::path& p) ...@@ -95,12 +119,11 @@ process& process::cwd(const fs::path& p)
return *this; return *this;
} }
void process::exec() void process::exec() { impl->check_exec(impl->get_command(), redirect_to(std::cout)); }
void process::write(std::function<void(process::writer)> pipe_in)
{ {
auto ec = migraphx::exec(impl->get_command(), redirect_to(std::cout)); impl->check_exec(impl->get_command(), std::move(pipe_in));
if(ec != 0)
MIGRAPHX_THROW("Command " + impl->get_command() + " exited with status " +
std::to_string(ec));
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment