Commit 4f07b8f1 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into test_branch_for_ort2

parents af110526 1e0bbd78
#ifndef MIGRAPHX_GUARD_GPU_COMPILER_HPP
#define MIGRAPHX_GUARD_GPU_COMPILER_HPP
#include <migraphx/config.hpp>
#include <migraphx/auto_register.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/value.hpp>
#include <migraphx/module.hpp>
#include <migraphx/instruction.hpp>
#include <functional>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
using compiler_replace = std::function<void(module& m, instruction_ref ins)>;
using compiler_compile = std::function<compiler_replace(context&, instruction_ref, operation)>;
using compiler_compile_op =
std::function<operation(context&, const std::vector<shape>& inputs, const value&)>;
void register_compiler(const std::string& name, compiler_compile c, compiler_compile_op cop);
bool has_compiler_for(const std::string& name);
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op);
operation
compile_op(const std::string& name, context& ctx, const std::vector<shape>& inputs, const value& v);
/// Instantiates compiler T and registers its compile/compile_op entry
/// points under every operator name the compiler reports via names().
template <class T>
void register_compiler()
{
    T instance;
    // Build the two type-erased callbacks once; each registration below
    // receives its own copy (the compiler object is captured by value).
    auto compile_fn = [=](auto&&... xs) {
        return instance.compile(std::forward<decltype(xs)>(xs)...);
    };
    auto compile_op_fn = [=](auto&&... xs) {
        return instance.compile_op(std::forward<decltype(xs)>(xs)...);
    };
    for(const auto& name : instance.names())
        register_compiler(name, compile_fn, compile_op_fn);
}
// Action type plugged into the auto_register machinery: causes
// register_compiler<T>() to run during static initialization for every
// type wrapped in auto_register_compiler below.
struct register_compiler_action
{
template <class T>
static void apply()
{
register_compiler<T>();
}
};
template <class T>
using auto_register_compiler = auto_register<register_compiler_action, T>;
// CRTP base for GPU op compilers. Deriving from this auto-registers the
// derived compiler (via auto_register_compiler) under the names it reports.
template <class Derived>
struct compiler : auto_register_compiler<Derived>
{
// Default replacement strategy: swap the instruction for the compiled
// operation, reusing the instruction's original inputs unchanged.
auto replace(const operation& op) const
{
return
[=](module& m, instruction_ref ins) { m.replace_instruction(ins, op, ins->inputs()); };
}
// Default compile_op returns an empty operation; compilers that support
// standalone op compilation shadow this in the derived class.
operation compile_op(context&, const std::vector<shape>&, const value&) const { return {}; }
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_COMPILER_HPP
......@@ -154,6 +154,13 @@ struct hip_device
// Number of compute units (multiprocessors) on the device.
std::size_t get_cu_count() const { return device_props.multiProcessorCount; }
// Maximum number of resident work-items (threads) per compute unit.
std::size_t get_max_workitems_per_cu() const
{
return device_props.maxThreadsPerMultiProcessor;
}
// Maximum work-group (thread block) size supported by the device.
std::size_t get_max_workitems_per_block() const { return device_props.maxThreadsPerBlock; }
private:
std::size_t device_id = 0;
std::size_t current_stream = 0;
......
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// HIP source template for fused pointwise kernels. The ${preamble},
// ${params}, ${lambda} and ${args} placeholders are substituted by
// interpolate_string in pointwise_compiler::compile_op before compilation.
static const char* const pointwise_kernel = R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/pointwise.hpp>
#include <args.hpp>
namespace migraphx {
${preamble}
extern "C" {
__global__ void kernel(${params})
{
pointwise(${lambda}, ${args});
}
}
} // namespace migraphx
)__migraphx__";
// Compiles fused elementwise ("pointwise") modules into HIP code objects.
struct pointwise_compiler : compiler<pointwise_compiler>
{
    std::vector<std::string> names() const { return {"pointwise"}; }

    // Thread oversubscription factor: launch fewer threads per element
    // when any input is broadcasted, otherwise oversubscribe 4x.
    static std::size_t oversubscribe(const std::vector<shape>& inputs)
    {
        const bool has_broadcast = std::any_of(
            inputs.begin(), inputs.end(), [](const auto& s) { return s.broadcasted(); });
        return has_broadcast ? 1 : 4;
    }

    // Effective element count after vectorization: divide by 4 (or 2) when
    // every input is packed or broadcasted and the count is divisible.
    static std::size_t vectorize_elements(const std::vector<shape>& inputs)
    {
        const std::size_t nelem = inputs.front().elements();
        const bool vectorizable = std::all_of(inputs.begin(), inputs.end(), [](const auto& s) {
            return s.packed() or s.broadcasted();
        });
        if(not vectorizable)
            return nelem;
        if(nelem % 4 == 0)
            return nelem / 4;
        if(nelem % 2 == 0)
            return nelem / 2;
        return nelem;
    }

    // Instantiate the kernel template with the lambda/preamble carried in v
    // and compile it to a code object.
    operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
    {
        hip_compile_options options;
        options.inputs         = inputs;
        options.output         = inputs.back();
        options.virtual_inputs = reduce_dims(inputs);
        options.params         = "-Wno-float-equal";
        options.set_launch_params(
            v, compute_global_for(ctx, vectorize_elements(inputs), oversubscribe(inputs)));
        auto src = interpolate_string(pointwise_kernel,
                                      {{"params", enum_params(inputs.size(), "void * private_p")},
                                       {"args", enum_params(inputs.size(), "private_p")},
                                       {"lambda", v.at("lambda").to<std::string>()},
                                       {"preamble", v.get("preamble", std::string{})}});
        return compile_hip_code_object(src, options);
    }

    // Lower the instruction's pointwise submodule to device code: clean it
    // up, generate a __device__ function from it, then compile.
    compiler_replace compile(context& ctx, instruction_ref ins, const operation&) const
    {
        assert(not ins->module_inputs().empty());
        auto* pm = ins->module_inputs().front();
        run_passes(*pm, {eliminate_common_subexpression{}, dead_code_elimination{}});
        cpp_generator gen;
        gen.fmap([](const std::string& fname) { return "migraphx::" + fname; });
        gen.add_point_op("where", "${function:where}(${0}, ${1}, ${2})");
        gen.add_point_op("prelu", "${function:where}(${0} < 0, ${0} * ${1}, ${0})");
        gen.add_point_op("sign",
                         "${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))");
        gen.add_point_op("equal", "migraphx::abs(${0} == ${1})");
        gen.add_point_op("less", "migraphx::abs(${0} < ${1})");
        gen.add_point_op("greater", "migraphx::abs(${0} > ${1})");
        gen.add_point_op("not", "migraphx::abs(not ${0})");
        // Add explicit conversions to the result type
        gen.fresult(
            [](const shape& s) { return "migraphx::convert<" + shape::cpp_type(s.type()) + ">"; });
        auto fname = gen.create_function(
            gen.generate_module(*pm).set_attributes({"__device__"}).set_generic_types(*pm));
        return replace(compile_op(ctx,
                                  to_shapes(ins->inputs()),
                                  {{"lambda", "MIGRAPHX_LIFT(" + fname + ")"},
                                   {"preamble", gen.str()}}));
    }
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/gpu/compile_roialign.hpp>
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -37,46 +43,46 @@ __global__ void roialign_kernel(void* in_x, void* in_rois, void* in_ind, void* y
} // namespace migraphx
int main() {}
)__migraphx__";
operation compile_roialign(context&, const std::vector<shape>& io_shapes, const value& val)
struct roialign_compiler : compiler<roialign_compiler>
{
hip_compile_options options;
auto out_s = io_shapes.back();
options.local = 128;
options.global = compute_global(out_s.elements(), options.local);
options.inputs = io_shapes;
options.output = out_s;
options.kernel_name = "roialign_kernel";
options.virtual_inputs = io_shapes;
// sampling_ratio
assert(val.contains("sampling_ratio"));
auto sampling_ratio = val.at("sampling_ratio").to<int64_t>();
options.params += " -DSAMPLING_RATIO=" + std::to_string(sampling_ratio);
// pooling_mode
assert(val.contains("mode"));
auto mode = val.at("mode").to<migraphx::op::pooling_mode>();
bool is_avg_pooling = (mode == migraphx::op::pooling_mode::average);
options.params += " -DIS_AVG_POOLING=" + std::to_string(static_cast<int>(is_avg_pooling));
// coord_trans_mode
assert(val.contains("coordinate_transformation_mode"));
auto ctm = val.at("coordinate_transformation_mode").to<std::string>();
float rois_offset = (ctm == "output_half_pixel") ? -0.5f : 0.0f;
options.params += " -DROIS_OFFSET=" + std::to_string(rois_offset);
// spatial_scale
assert(val.contains("spatial_scale"));
float spatial_scale = val.at("spatial_scale").to<float>();
options.params += " -DSPATIAL_SCALE=" + std::to_string(spatial_scale);
return compile_hip_code_object(roialign_kernel, options);
}
} // namespace gpu
// The operator name this compiler handles.
std::vector<std::string> names() const { return {"roialign"}; }
// Builds the roialign HIP code object. Operator attributes are baked into
// the kernel as preprocessor definitions via options.params.
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
hip_compile_options options;
// Launch sized from the output element count, 128 work-items per block.
options.set_launch_params(v, compute_global_for(ctx, inputs.back().elements()), 128);
options.output = inputs.back();
options.inputs = inputs;
options.kernel_name = "roialign_kernel";
// sampling_ratio
options.params += " -DSAMPLING_RATIO=" + v.at("sampling_ratio").to<std::string>();
// pooling_mode: average vs max, passed to the kernel as a boolean define
auto mode = v.at("mode").to<migraphx::op::pooling_mode>();
std::string is_avg_pooling =
(mode == migraphx::op::pooling_mode::average) ? "true" : "false";
options.params += " -DIS_AVG_POOLING=" + is_avg_pooling;
// coord_trans_mode: "output_half_pixel" shifts ROI coordinates by -0.5
auto ctm = v.at("coordinate_transformation_mode").to<std::string>();
float rois_offset = (ctm == "output_half_pixel") ? -0.5f : 0.0f;
options.params += " -DROIS_OFFSET=" + std::to_string(rois_offset);
// spatial_scale
options.params += " -DSPATIAL_SCALE=" + v.at("spatial_scale").to<std::string>();
return compile_hip_code_object(roialign_kernel, options);
}
// Compile the instruction's operator and replace the instruction with the
// resulting code object, keeping its original inputs.
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value()));
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/gpu/compile_scatternd.hpp>
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -25,7 +31,7 @@ extern "C" {
__global__ void scatternd_kernel(void* in_indices, void* in_updates, void* output)
{
make_tensors()(in_indices, in_updates, output)([](auto&&... xs) {
scatternd(xs..., REDUCTION);
scatternd(xs..., ${reduction}{});
});
}
......@@ -33,28 +39,50 @@ __global__ void scatternd_kernel(void* in_indices, void* in_updates, void* outpu
} // namespace migraphx
int main() {}
)__migraphx__";
operation
compile_scatternd(context&, const std::vector<shape>& io_shapes, const std::string& reduction)
struct scatternd_compiler : compiler<scatternd_compiler>
{
hip_compile_options options;
auto out_s = io_shapes.back();
options.local = 1024;
options.global = compute_global(io_shapes.at(1).elements(), options.local);
options.inputs = io_shapes;
options.output = out_s;
options.kernel_name = "scatternd_kernel";
options.virtual_inputs = io_shapes;
options.params += " -DREDUCTION=assign_" + reduction + "{}";
return compile_hip_code_object(scatternd_kernel, options);
}
// All scatternd variants share this compiler; the reduction mode is
// encoded in the operator-name suffix and recovered in compile().
std::vector<std::string> names() const
{
return {"scatternd_none", "scatternd_add", "scatternd_mul"};
}
} // namespace gpu
// Builds the scatternd HIP code object; the reduction functor name is
// interpolated into the kernel source template.
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
hip_compile_options options;
// One launch dimension per element of inputs[1] (the updates tensor,
// given the kernel signature in_indices/in_updates/output).
options.set_launch_params(v, compute_global_for(ctx, inputs.at(1).elements()));
auto out_s = inputs.back();
options.inputs = inputs;
options.output = out_s;
options.kernel_name = "scatternd_kernel";
options.virtual_inputs = inputs;
// Defaults to plain assignment when no reduction is requested.
auto reduction = "assign_" + v.get("reduction", std::string{"none"});
auto src = interpolate_string(scatternd_kernel, {{"reduction", reduction}});
return compile_hip_code_object(src, options);
}
// Compile a scatternd_<reduction> instruction. The reduction mode is
// derived from the operator-name suffix; the substring offset is computed
// from the prefix itself instead of the magic constant 10, so the prefix
// string and the offset cannot drift apart.
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
static const std::string prefix = "scatternd_";
assert(starts_with(op.name(), prefix));
auto reduction = op.name().substr(prefix.size());
// The first input is not passed to the kernel; insert() copies it into
// the output buffer and drops it from the argument list.
return insert(compile_op(ctx,
                         to_shapes({ins->inputs().begin() + 1, ins->inputs().end()}),
                         {{"reduction", reduction}}));
}
// Custom replacement for scatternd: the kernel only writes the scattered
// elements, so the original data tensor must first be copied into the
// output buffer, then dropped from the kernel's argument list.
compiler_replace insert(const operation& op) const
{
return [=](module& m, instruction_ref ins) {
auto args = ins->inputs();
// Copy the data tensor (first input) into the allocated output buffer.
args.back() =
m.insert_instruction(ins, make_op("hip::copy"), args.front(), args.back());
// The kernel takes only indices, updates and output.
args.erase(args.begin());
return m.replace_instruction(ins, op, args);
};
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -20,8 +20,6 @@
#include <migraphx/gpu/abs.hpp>
#include <migraphx/gpu/batch_norm_inference.hpp>
#include <migraphx/gpu/compile_roialign.hpp>
#include <migraphx/gpu/compile_scatternd.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/convolution.hpp>
#include <migraphx/gpu/deconvolution.hpp>
......@@ -42,6 +40,7 @@
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/unary_not.hpp>
#include <migraphx/gpu/where.hpp>
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/program.hpp>
#include <utility>
......@@ -195,8 +194,6 @@ struct miopen_apply
add_extend_op("softmax");
add_extend_op("topk");
add_precompile_op("pointwise");
add_batch_norm_inference_op();
add_convolution_op();
add_deconvolution_op();
......@@ -207,8 +204,6 @@ struct miopen_apply
add_neg_op();
add_nms_op();
add_quant_convolution_op();
add_roialign();
add_scatternd();
}
void copy_params()
......@@ -262,11 +257,28 @@ struct miopen_apply
{
check_shape(s, apply_map.at(it->name())(it));
}
else if(has_compiler_for(it->name()))
{
check_shape(s, insert_precompile_op(it));
}
}
copy_params();
}
// Wraps an instruction whose operator has a registered compiler into a
// gpu::precompile_op, appending the output allocation as the final
// argument (gpu ops write into a caller-provided buffer). Any attached
// submodules are carried over unchanged.
instruction_ref insert_precompile_op(instruction_ref ins)
{
auto output = insert_allocation(ins, ins->get_shape());
std::vector<instruction_ref> refs = ins->inputs();
refs.push_back(output);
return mod->replace_instruction(
ins,
make_op("gpu::precompile_op", {{"op", to_value(ins->get_operator())}}),
refs,
ins->module_inputs());
}
instruction_ref insert_allocation(instruction_ref ins, const shape& s, std::string tag = "")
{
// Instruction's output is an input of the ret instruction
......@@ -396,21 +408,6 @@ struct miopen_apply
});
}
void add_precompile_op(const std::string& name)
{
apply_map.emplace(name, [=](instruction_ref ins) {
auto output = insert_allocation(ins, ins->get_shape());
std::vector<instruction_ref> refs = ins->inputs();
refs.push_back(output);
return mod->replace_instruction(
ins,
make_op("gpu::precompile_op", {{"op", to_value(ins->get_operator())}}),
refs,
ins->module_inputs());
});
}
void add_batch_norm_inference_op()
{
apply_map.emplace("batch_norm_inference", [=](instruction_ref ins) {
......@@ -501,75 +498,6 @@ struct miopen_apply
});
}
void add_roialign()
{
apply_map.emplace("roialign", [=](instruction_ref ins) {
auto s = ins->get_shape();
auto op_val = ins->get_operator().to_value();
auto output = insert_allocation(ins, s);
auto args = ins->inputs();
args.push_back(output);
auto io_shapes = to_shapes(args);
auto co = compile_roialign(get_context(), io_shapes, op_val);
return mod->replace_instruction(ins, co, args);
});
}
void add_scatternd()
{
apply_map.emplace("scatternd_none", [=](instruction_ref ins) {
auto s = ins->get_shape();
auto op_val = ins->get_operator().to_value();
auto output = insert_allocation(ins, s);
auto args = ins->inputs();
args.push_back(output);
auto io_shapes = to_shapes(args);
io_shapes.erase(io_shapes.begin());
const std::string reduction = "none";
auto co = compile_scatternd(get_context(), io_shapes, reduction);
auto copy = mod->insert_instruction(ins, make_op("hip::copy"), args.front(), output);
args.back() = copy;
args.erase(args.begin());
return mod->replace_instruction(ins, co, args);
});
apply_map.emplace("scatternd_add", [=](instruction_ref ins) {
auto s = ins->get_shape();
auto op_val = ins->get_operator().to_value();
auto output = insert_allocation(ins, s);
auto args = ins->inputs();
args.push_back(output);
auto io_shapes = to_shapes(args);
io_shapes.erase(io_shapes.begin());
const std::string reduction = "add";
auto co = compile_scatternd(get_context(), io_shapes, reduction);
auto copy = mod->insert_instruction(ins, make_op("hip::copy"), args.front(), output);
args.back() = copy;
args.erase(args.begin());
return mod->replace_instruction(ins, co, args);
});
apply_map.emplace("scatternd_mul", [=](instruction_ref ins) {
auto s = ins->get_shape();
auto op_val = ins->get_operator().to_value();
auto output = insert_allocation(ins, s);
auto args = ins->inputs();
args.push_back(output);
auto io_shapes = to_shapes(args);
io_shapes.erase(io_shapes.begin());
const std::string reduction = "mul";
auto co = compile_scatternd(get_context(), io_shapes, reduction);
auto copy = mod->insert_instruction(ins, make_op("hip::copy"), args.front(), output);
args.back() = copy;
args.erase(args.begin());
return mod->replace_instruction(ins, co, args);
});
}
// replace the loop operator with gpu_loop operator
void add_loop_op()
{
......
......@@ -4,6 +4,7 @@
#include <migraphx/errors.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/value.hpp>
#include <migraphx/optional.hpp>
#include <unordered_map>
#include <utility>
......@@ -138,6 +139,7 @@ value::value(const std::string& pkey, const value& rhs)
{
}
value::value(const std::string& pkey, const char* i) : value(pkey, std::string(i)) {}
value::value(const char* i) : value(std::string(i)) {}
#define MIGRAPHX_VALUE_GENERATE_DEFINE_METHODS(vt, cpp_type) \
......@@ -161,6 +163,12 @@ value::value(const char* i) : value(std::string(i)) {}
const cpp_type* value::if_##vt() const { return x ? x->if_##vt() : nullptr; }
MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_DEFINE_METHODS)
// Assigning a C string stores it as a std::string value, mirroring the
// const char* constructor.
value& value::operator=(const char* c)
{
*this = std::string{c};
return *this;
}
value& value::operator=(std::nullptr_t)
{
x = nullptr;
......@@ -410,25 +418,12 @@ value value::with_key(const std::string& pkey) const
return result;
}
template <class F, class T, class U, class Common = typename std::common_type<T, U>::type>
auto compare_common_impl(
rank<1>, F f, const std::string& keyx, const T& x, const std::string& keyy, const U& y)
{
return f(std::forward_as_tuple(keyx, Common(x)), std::forward_as_tuple(keyy, Common(y)));
}
template <class F>
auto compare_common_impl(
rank<1>, F f, const std::string& keyx, std::nullptr_t, const std::string& keyy, std::nullptr_t)
{
return f(std::forward_as_tuple(keyx, 0), std::forward_as_tuple(keyy, 0));
}
template <class F, class T, class U>
auto compare_common_impl(rank<0>, F, const std::string&, const T&, const std::string&, const U&)
template <class T>
const T& compare_decay(const T& x)
{
return false;
return x;
}
int compare_decay(std::nullptr_t) { return 0; }
template <class F>
bool compare(const value& x, const value& y, F f)
......@@ -436,7 +431,11 @@ bool compare(const value& x, const value& y, F f)
bool result = false;
x.visit_value([&](auto&& a) {
y.visit_value([&](auto&& b) {
result = compare_common_impl(rank<1>{}, f, x.get_key(), a, y.get_key(), b);
if constexpr(std::is_same<decltype(a), decltype(b)>{})
result = f(std::forward_as_tuple(x.get_key(), compare_decay(a)),
std::forward_as_tuple(y.get_key(), compare_decay(b)));
else
assert(false); // NOLINT
});
});
return result;
......@@ -455,11 +454,16 @@ bool operator==(const value& x, const value& y)
return false;
return compare(x, y, std::equal_to<>{});
}
bool operator!=(const value& x, const value& y) { return !(x == y); }
bool operator<(const value& x, const value& y) { return compare(x, y, std::less<>{}); }
bool operator<=(const value& x, const value& y) { return x == y or x < y; }
bool operator!=(const value& x, const value& y) { return not(x == y); }
// Values of different types are ordered by their type tag first, giving a
// total order over heterogeneous values; same-type values are compared
// element-wise via compare().
bool operator<(const value& x, const value& y)
{
if(x.get_type() != y.get_type())
return x.get_type() < y.get_type();
return compare(x, y, std::less<>{});
}
bool operator<=(const value& x, const value& y) { return not(x > y); }
bool operator>(const value& x, const value& y) { return y < x; }
bool operator>=(const value& x, const value& y) { return x == y or x > y; }
bool operator>=(const value& x, const value& y) { return not(x < y); }
void print_value(std::ostream& os, std::nullptr_t) { os << "null"; }
......
......@@ -11,6 +11,7 @@ function(add_api_test TEST_NAME TEST_SRC TEST_DIR)
endfunction()
add_api_test(assign test_assign.cpp ${TEST_ONNX_DIR})
add_api_test(custom_op test_custom_op.cpp ${TEST_ONNX_DIR})
add_api_test(compile_options test_compile_options.cpp ${TEST_ONNX_DIR})
add_api_test(lookup test_lookup.cpp ${TEST_ONNX_DIR})
add_api_test(module_construct test_module_construct.cpp ${TEST_ONNX_DIR})
......
#include <migraphx/migraphx.h>
#include <migraphx/migraphx.hpp>
#include "test.hpp"
struct simple_custom_op final : migraphx::experimental_custom_op_base
{
virtual std::string name() const override { return "simple_custom_op"; }
virtual migraphx::shape compute_shape(migraphx::shapes inputs) const override
{
return inputs.front();
}
};
// After registration, the custom op should be constructible by name
// through the generic operation factory.
TEST_CASE(register_custom_op)
{
simple_custom_op simple_op;
migraphx::register_experimental_custom_op(simple_op);
auto op = migraphx::operation("simple_custom_op");
EXPECT(op.name() == "simple_custom_op");
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <migraphx/migraphx.hpp>
#include <migraphx/rank.hpp>
#include "test.hpp"
template <class T>
......
......@@ -10,7 +10,7 @@
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_pointwise.hpp>
#include <migraphx/gpu/compiler.hpp>
// NOLINTNEXTLINE
const std::string write_2s = R"__migraphx__(
......@@ -230,7 +230,8 @@ TEST_CASE(compile_pointwise)
migraphx::shape input{migraphx::shape::float_type, {5, 2}};
migraphx::gpu::context ctx;
auto co = migraphx::gpu::compile_pointwise(ctx, {input, input}, "[](auto x) { return x + 1; }");
auto co = migraphx::gpu::compile_op(
"pointwise", ctx, {input, input}, {{"lambda", "[](auto x) { return x + 1; }"}});
migraphx::program p;
auto* mm = p.get_main_module();
......
......@@ -68,9 +68,9 @@ struct nop
{
static std::string as_string() { return ""; }
template <class T>
static decltype(auto) call(T&& x)
static auto call(T&& x)
{
return x;
return static_cast<T&&>(x);
}
};
......@@ -113,6 +113,33 @@ inline auto operator<<(Stream& s, const Range& v) -> decltype(stream_range(s, v.
return s;
}
// Identity fallback for get_value: plain values pass through unchanged.
// Expression wrapper types shadow this with a friend overload that
// evaluates the wrapped expression instead.
template <class T>
const T& get_value(const T& x)
{
return x;
}
template <class T, class Operator = nop>
struct lhs_expression;
template <class T>
lhs_expression<T> make_lhs_expression(T&& lhs);
template <class T, class Operator>
lhs_expression<T, Operator> make_lhs_expression(T&& lhs, Operator);
// NOLINTNEXTLINE
#define TEST_EXPR_BINARY_OPERATOR(op, name) \
template <class V> \
auto operator op(const V& rhs2) const \
{ \
return make_expression(*this, rhs2, name{}); /* NOLINT */ \
}
// NOLINTNEXTLINE
#define TEST_EXPR_UNARY_OPERATOR(op, name) \
auto operator op() const { return make_lhs_expression(lhs, name{}); /* NOLINT */ }
template <class T, class U, class Operator>
struct expression
{
......@@ -125,7 +152,12 @@ struct expression
return s;
}
decltype(auto) value() const { return Operator::call(lhs, rhs); };
friend decltype(auto) get_value(const expression& e) { return e.value(); }
decltype(auto) value() const { return Operator::call(get_value(lhs), get_value(rhs)); };
TEST_FOREACH_UNARY_OPERATORS(TEST_EXPR_UNARY_OPERATOR)
TEST_FOREACH_BINARY_OPERATORS(TEST_EXPR_BINARY_OPERATOR)
};
// TODO: Remove rvalue references
......@@ -135,9 +167,6 @@ expression<T, U, Operator> make_expression(T&& rhs, U&& lhs, Operator)
return {std::forward<T>(rhs), std::forward<U>(lhs)};
}
template <class T, class Operator = nop>
struct lhs_expression;
// TODO: Remove rvalue reference
template <class T>
lhs_expression<T> make_lhs_expression(T&& lhs)
......@@ -166,22 +195,12 @@ struct lhs_expression
return s;
}
decltype(auto) value() const { return Operator::call(lhs); }
// NOLINTNEXTLINE
#define TEST_LHS_BINARY_OPERATOR(op, name) \
template <class U> \
auto operator op(const U& rhs) const \
{ \
return make_expression(lhs, rhs, name{}); /* NOLINT */ \
}
friend decltype(auto) get_value(const lhs_expression& e) { return e.value(); }
TEST_FOREACH_BINARY_OPERATORS(TEST_LHS_BINARY_OPERATOR)
decltype(auto) value() const { return Operator::call(get_value(lhs)); }
// NOLINTNEXTLINE
#define TEST_LHS_UNARY_OPERATOR(op, name) \
auto operator op() const { return make_lhs_expression(lhs, name{}); /* NOLINT */ }
TEST_FOREACH_UNARY_OPERATORS(TEST_LHS_UNARY_OPERATOR)
TEST_FOREACH_BINARY_OPERATORS(TEST_EXPR_BINARY_OPERATOR)
TEST_FOREACH_UNARY_OPERATORS(TEST_EXPR_UNARY_OPERATOR)
// NOLINTNEXTLINE
#define TEST_LHS_REOPERATOR(op) \
......@@ -223,6 +242,13 @@ auto make_predicate(const std::string& msg, F f)
return make_lhs_expression(predicate<F>{msg, f}, function{});
}
// Render a boolean as the literal text "true" or "false".
inline std::string as_string(bool x) { return x ? "true" : "false"; }
template <class T>
std::string as_string(const T& x)
{
......@@ -627,18 +653,21 @@ inline void run(int argc, const char* argv[])
} // namespace test
// NOLINTNEXTLINE
#define TEST_CAPTURE(...) test::capture{}->*__VA_ARGS__
// NOLINTNEXTLINE
#define CHECK(...) \
test::failed( \
test::capture{}->*__VA_ARGS__, #__VA_ARGS__, __PRETTY_FUNCTION__, __FILE__, __LINE__, [] { \
})
// NOLINTNEXTLINE
#define EXPECT(...) \
test::failed(test::capture{}->*__VA_ARGS__, \
#__VA_ARGS__, \
__PRETTY_FUNCTION__, \
__FILE__, \
__LINE__, \
#define EXPECT(...) \
test::failed(TEST_CAPTURE(__VA_ARGS__), \
#__VA_ARGS__, \
__PRETTY_FUNCTION__, \
__FILE__, \
__LINE__, \
&test::fail)
// NOLINTNEXTLINE
#define STATUS(...) EXPECT((__VA_ARGS__) == 0)
......
......@@ -27,6 +27,7 @@ add_py_test(ref test_cpu.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(save_load test_save_load.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(op test_op.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(shape test_shape.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(module_construct test_module_construct.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
if(MIGRAPHX_ENABLE_GPU)
add_py_test(gpu_offload test_gpu_offload.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(gpu test_gpu.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
......
import migraphx
def test_add_op():
    """Build a 3x3 elementwise-add program, run it on the ref target, and
    check the result against a Python-side sum of the same arguments."""
    prog = migraphx.program()
    main = prog.get_main_module()
    shape = migraphx.shape(lens=[3, 3], type="float")
    lhs = main.add_parameter("x", shape)
    rhs = main.add_parameter("y", shape)
    total = main.add_instruction(migraphx.op("add"), [lhs, rhs])
    main.add_return([total])
    prog.compile(migraphx.get_target("ref"))
    params = {
        "x": migraphx.generate_argument(shape),
        "y": migraphx.generate_argument(shape),
    }
    output = prog.run(params)[-1].tolist()
    expected = [
        a + b for a, b in zip(params["x"].tolist(), params["y"].tolist())
    ]
    assert output == expected
def test_if_then_else():
    """Exercise the 'if' operator: the then-branch yields x, the
    else-branch yields y; verify both paths on the ref target."""
    param_shape = migraphx.shape(lens=[3, 3], type="float")
    cond_shape = migraphx.shape(type="bool", lens=[1], strides=[0])

    def build():
        prog = migraphx.program()
        main = prog.get_main_module()
        cond = main.add_parameter("cond", cond_shape)
        x = main.add_parameter("x", param_shape)
        y = main.add_parameter("y", param_shape)
        # Each branch is a submodule that returns one of the parameters.
        then_mod = prog.create_module("If_0_if")
        then_mod.add_return(
            [then_mod.add_instruction(migraphx.op("identity"), [x])])
        else_mod = prog.create_module("If_0_else")
        else_mod.add_return(
            [else_mod.add_instruction(migraphx.op("identity"), [y])])
        if_ins = main.add_instruction(migraphx.op("if"), [cond],
                                      [then_mod, else_mod])
        # The 'if' op returns a tuple; extract its first element.
        out = main.add_instruction(
            migraphx.op("get_tuple_elem", **{"index": 0}), [if_ins])
        main.add_return([out])
        return prog

    params = {
        "x": migraphx.generate_argument(param_shape),
        "y": migraphx.generate_argument(param_shape),
    }

    def run_prog(cond):
        prog = build()
        prog.compile(migraphx.get_target("ref"))
        params["cond"] = migraphx.fill_argument(cond_shape, cond)
        return prog.run(params)[-1]

    assert run_prog(True) == params["x"]
    assert run_prog(False) == params["y"]
if __name__ == "__main__":
test_add_op()
test_if_then_else()
......@@ -16,6 +16,17 @@ def test_create_shape_broadcast():
def test_create_shape_type():
    """Check that type-name aliases map to the expected type_string and
    element sizes (None entries mean that property is not checked)."""
    cases = [
        ('uint8', 'uint8_type', None),
        ('int64_t', 'int64_type', 8),
        ('uint8_t', 'uint8_type', 1),
        ('float', None, 4),
    ]
    for type_name, expected_string, expected_size in cases:
        s = migraphx.shape(type=type_name)
        if expected_string is not None:
            assert s.type_string() == expected_string
        if expected_size is not None:
            assert s.type_size() == expected_size
if __name__ == "__main__":
test_create_shape()
test_create_shape_broadcast()
test_create_shape_type()
#include <iostream>
#include <migraphx/program.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/onnx.hpp>
#include "test.hpp"
/*!
* Example MIGraphX programs for following the Contributor's Guide.
*/
TEST_CASE(add_two_literals)
{
/*!
* Simple MIGraphX program to add two literal values.
* Equivalent to adding two constant scalar values together.
*/
// create the program and get a pointer to the main module
migraphx::program p;
auto* mm = p.get_main_module();
// add two literals to the program
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
// make the "add" operation between the two literals and add it to the program
mm->add_instruction(migraphx::make_op("add"), one, two);
// compile the program on the reference device
p.compile(migraphx::ref::target{});
// evaluate the program and retrieve the result
auto result = p.eval({}).back();
std::cout << "add_two_literals: 1 + 2 = " << result << "\n";
EXPECT(result.at<int>() == 3);
}
TEST_CASE(add_parameters)
{
/*!
* Modified version of MIGraphX program seen in add_two_literals to accept a parameter.
* Equivalent to adding a constant scalar value with another scalar input.
*/
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int32_type, {1}};
// add an "x" parameter with the shape s
auto x = mm->add_parameter("x", s);
auto two = mm->add_literal(2);
// add the "add" instruction between the "x" parameter and "two" to the module
mm->add_instruction(migraphx::make_op("add"), x, two);
p.compile(migraphx::ref::target{});
// create a parameter_map object for passing a value to the "x" parameter
std::vector<int> data = {4};
migraphx::parameter_map params;
params["x"] = migraphx::argument(s, data.data());
// evaluate the program with the supplied parameters and check the sum
auto result = p.eval(params).back();
std::cout << "add_parameters: 4 + 2 = " << result << "\n";
EXPECT(result.at<int>() == 6);
}
TEST_CASE(handling_tensors)
{
    /*!
     * This example does a convolution operation over an input tensor using the given weighting
     * tensor. This is meant to show an example of working with tensors in MIGraphX. The output
     * tensor is compared against a precomputed solution tensor at the end of the program.
     */
    migraphx::program p;
    auto* mm = p.get_main_module();
    // create shape objects for the input tensor and weights
    // input is {2, 3, 4, 4}; weights are {2, 3, 3, 3} — presumably NCHW / OIHW
    // layout as is conventional for "convolution", TODO confirm
    migraphx::shape input_shape{migraphx::shape::float_type, {2, 3, 4, 4}};
    migraphx::shape weights_shape{migraphx::shape::float_type, {2, 3, 3, 3}};
    // create the parameters and add the "convolution" operation to the module
    auto input = mm->add_parameter("X", input_shape);
    auto weights = mm->add_parameter("W", weights_shape);
    mm->add_instruction(migraphx::make_op("convolution", {{"padding", {1, 1}}, {"stride", {2, 2}}}),
                        input,
                        weights);
    p.compile(migraphx::ref::target{});
    // Buffers allocated by the user; the program reads them through
    // migraphx::argument views created below (no copy is made here).
    std::vector<float> a = {
        2.71567607,  -0.9960829,  0.91671127,  0.28140706,  0.63235772,  0.08077253,  0.80927712,
        -0.59108931, -1.05421555, -2.76622486, -0.85044265, -0.52049929, 0.67726439,  -0.65290606,
        0.02345525,  -0.33579525, 0.38901961,  1.05473483,  -1.31188095, 1.8963089,   -0.07265259,
        0.947339,    0.41949373,  -0.70814759, 0.25892952,  1.07311416,  1.2571274,   -0.62318051,
        -0.19951548, -0.94232577, -0.29393643, 0.42292568,  -0.80230367, 1.40909171,  0.63617158,
        0.13900366,  1.09253144,  -0.15265895, 1.54781747,  0.72780299,  1.09189606,  -0.38068101,
        0.97057933,  -0.58958799, 1.56188643,  0.21474874,  0.58725154,  -1.27097559, -0.03024297,
        1.09437096,  -0.4897908,  0.34838957,  -1.31042492, -1.69069934, 0.86956722,  -0.40457946,
        0.46691212,  1.29273605,  0.26464137,  0.22073045,  -1.02178168, 0.22163901,  -1.84387338,
        0.75522131,  -0.45775682, -0.42241111, -1.50944722, 1.07256448,  -1.95876884, -0.28106022,
        0.3341668,   2.13129425,  -1.14728117, -1.06555498, -0.298444,   -0.88322699, -0.65866792,
        -2.06007552, 0.01374334,  0.45612028,  0.52715492,  1.01914406,  -1.72659791, 0.80650896,
        0.16860051,  2.24112225,  -0.78620857, 0.36566174,  -0.07020134, -0.47976932, -0.68230027,
        -0.94711417, -0.54506505, 1.66504931,  -0.71860826, 0.61132306};
    std::vector<float> c = {
        -0.14601797, -0.13000923, 0.06521662,  0.06178288,  -0.11083675, 0.10154136,  0.09990512,
        0.06030385,  -0.11374587, -0.17523311, -0.14344215, 0.17802463,  0.06300922,  -0.15325832,
        0.07066704,  0.05166031,  0.00615084,  -0.02606523, 0.08083995,  -0.17913306, 0.0624622,
        0.0735731,   -0.04198661, -0.0164391,  -0.06374192, 0.16569914,  0.10681538,  0.07370754,
        0.02802075,  0.00282027,  0.15104802,  -0.11084409, -0.00197773, 0.07924436,  0.03528272,
        0.04765259,  -0.15896152, 0.07917164,  0.12125669,  -0.1154705,  -0.11999125, 0.12749968,
        -0.06269585, 0.18658121,  -0.03944227, 0.0111798,   -0.17731084, 0.11789055,  -0.09982193,
        0.08142821,  0.0729029,   0.11303909,  0.12735154,  0.03885292};
    // Solution vector (precomputed expected convolution output, 16 elements)
    std::vector<float> sol = {-0.20817225,
                              0.87965256,
                              0.14958936,
                              -1.24887264,
                              -0.06540672,
                              0.20778663,
                              0.40456355,
                              -0.99900877,
                              0.4917807,
                              0.1994698,
                              0.64205718,
                              0.37798831,
                              -0.25315839,
                              0.44276932,
                              -0.16138598,
                              0.79344082};
    // Create the arguments in a parameter_map
    migraphx::parameter_map params;
    params["X"] = migraphx::argument(input_shape, a.data());
    params["W"] = migraphx::argument(weights_shape, c.data());
    // Evaluate and confirm the result
    auto result = p.eval(params).back();
    // Initial size 64 is irrelevant: assign() below resizes to the actual
    // output element count (16 here, matching sol).
    std::vector<float> results_vector(64);
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
    EXPECT(migraphx::verify_range(results_vector, sol));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -12,6 +12,11 @@ TEST_CASE(make_target)
}
}
TEST_CASE(make_invalid_target)
{
    // Requesting a target name that was never registered must throw.
    const auto make_bad_target = [] { migraphx::make_target("mi100"); };
    EXPECT(test::throws(make_bad_target));
}
TEST_CASE(targets)
{
auto ts = migraphx::get_targets();
......
......@@ -57,6 +57,15 @@ TEST_CASE(value_construct_string)
EXPECT(v.get_key().empty());
}
TEST_CASE(value_construct_key_string_literal_pair)
{
    // Parens (rather than braces) select the key/value pair constructor,
    // producing a string value tagged with a key.
    migraphx::value val("key", "one");
    EXPECT(val.is_string());
    EXPECT(val.get_key() == "key");
    EXPECT(val.get_string() == "one");
}
TEST_CASE(value_construct_float)
{
migraphx::value v = 1.0;
......@@ -167,6 +176,15 @@ TEST_CASE(value_copy_assign_keyless)
EXPECT(v1.without_key() == v2.without_key());
}
TEST_CASE(value_assign_key_string_literal_pair)
{
    // Assigning a string literal through operator[] stores a string value
    // whose key matches the subscripted name.
    migraphx::value val = migraphx::value::object{};
    val["key"] = "one";
    EXPECT(val["key"].get_key() == "key");
    EXPECT(val["key"].is_string());
    EXPECT(val["key"].get_string() == "one");
}
TEST_CASE(value_construct_array)
{
migraphx::value v = {1, 2, 3};
......@@ -522,6 +540,14 @@ TEST_CASE(value_construct_object_string_mixed_value)
EXPECT(v.at("two").get_int64() == 2);
}
// Wraps a captured comparison expression in a test predicate whose failure
// message shows both the expression text and the boolean it evaluated to.
template <class Expression>
auto compare_predicate(const Expression& e)
{
    // Evaluate once up front so the printed result matches what is asserted.
    bool result = e.value();
    return test::make_predicate(test::as_string(e) + " => " + test::as_string(result),
                                [=] { return result; });
}
TEST_CASE(value_compare)
{
EXPECT(migraphx::value(1) == migraphx::value(1));
......@@ -535,6 +561,46 @@ TEST_CASE(value_compare)
EXPECT(migraphx::value(2) > migraphx::value(1));
EXPECT(migraphx::value(2) >= migraphx::value(1));
EXPECT(migraphx::value(1) >= migraphx::value(1));
EXPECT(migraphx::value(1) != migraphx::value("1"));
EXPECT(migraphx::value(1) != migraphx::value());
}
// Capture a comparison so failures print the expression and its result.
// NOLINTNEXTLINE
#define MIGRAPHX_VALUE_TEST_COMPARE(...) compare_predicate(TEST_CAPTURE(__VA_ARGS__))
// Assert the axioms of a total order for one ordered pair (x, y):
// totality of <=/>=, trichotomy of </>/==, and mutual consistency of
// <, <=, >, >=, ==, != (each negation pairs with its complement).
// NOLINTNEXTLINE
#define EXPECT_TOTALLY_ORDERED_IMPL(_, x, y) \
    EXPECT(_(x <= y) or _(x >= y)); \
    EXPECT(_(x < y) or _(x > y) or _(x == y)); \
    EXPECT((_(x < y) or _(x > y)) == _(x != y)); \
    EXPECT(_(x < y) == _(y > x)); \
    EXPECT(_(x <= y) == _(y >= x)); \
    EXPECT(_(x < y) != _(x >= y)); \
    EXPECT(_(x > y) != _(x <= y)); \
    EXPECT(_(x == y) != _(x != y))
// Check the ordering axioms in both argument orders.
// NOLINTNEXTLINE
#define EXPECT_TOTALLY_ORDERED(x, y) \
    EXPECT_TOTALLY_ORDERED_IMPL(MIGRAPHX_VALUE_TEST_COMPARE, x, y); \
    EXPECT_TOTALLY_ORDERED_IMPL(MIGRAPHX_VALUE_TEST_COMPARE, y, x)
// NOLINTNEXTLINE(readability-function-size)
TEST_CASE(value_compare_ordered)
{
    // null and equal scalars
    EXPECT_TOTALLY_ORDERED(migraphx::value(), migraphx::value());
    EXPECT_TOTALLY_ORDERED(migraphx::value(1), migraphx::value(1));
    EXPECT_TOTALLY_ORDERED(migraphx::value(1), migraphx::value(2));
    // keyed values: same/different keys, same/different payload types
    EXPECT_TOTALLY_ORDERED(migraphx::value("key", 1), migraphx::value("key", 1));
    EXPECT_TOTALLY_ORDERED(migraphx::value("key1", 1), migraphx::value("key2", 2));
    EXPECT_TOTALLY_ORDERED(migraphx::value("key", 1), migraphx::value("key", 2));
    EXPECT_TOTALLY_ORDERED(migraphx::value("key1", 1), migraphx::value("key2", 2));
    EXPECT_TOTALLY_ORDERED(migraphx::value("key", 1), migraphx::value("key", "2"));
    EXPECT_TOTALLY_ORDERED(migraphx::value("key1", 1), migraphx::value("key2", "2"));
    // mixed signed/unsigned integers must still order totally
    EXPECT_TOTALLY_ORDERED(migraphx::value(std::int64_t{1}), migraphx::value(std::uint64_t{1}));
    EXPECT_TOTALLY_ORDERED(migraphx::value(std::int64_t{1}), migraphx::value(std::uint64_t{2}));
    EXPECT_TOTALLY_ORDERED(migraphx::value(std::int64_t{2}), migraphx::value(std::uint64_t{1}));
    // cross-type (int vs string, int vs null) comparisons
    EXPECT_TOTALLY_ORDERED(migraphx::value(1), migraphx::value("1"));
    EXPECT_TOTALLY_ORDERED(migraphx::value(1), migraphx::value());
}
TEST_CASE(value_to_from_string)
......@@ -835,4 +901,38 @@ TEST_CASE(value_or_null)
EXPECT(v.value_or(3) == 3);
}
TEST_CASE(value_get_default)
{
    // get(key, default) returns the stored value when the key exists,
    // otherwise the supplied fallback.
    migraphx::value val = {{"key", 1}};
    EXPECT(val.get("key", 3) == 1);
    EXPECT(val.get("missing", 3) == 3);
}
TEST_CASE(value_get_default_vector)
{
    // Vector values round-trip through get(); absent keys yield the fallback
    // whether it is passed as a named vector or as a braced initializer list.
    std::vector<int> stored = {1, 2, 3};
    std::vector<int> dflt = {-1};
    migraphx::value val = {{"key", stored}};
    EXPECT(val.get("key", dflt) == stored);
    EXPECT(val.get("missing", dflt) == dflt);
    EXPECT(val.get("missing", {-1}) == dflt);
}
TEST_CASE(value_get_default_string_literal)
{
    // String literals work both as stored values and as get() fallbacks.
    migraphx::value val = {{"key", "hello"}};
    EXPECT(val.get("missing", "none") == "none");
    EXPECT(val.get("key", "none") == "hello");
}
TEST_CASE(value_get_default_string_literal_vector)
{
    // get() with vector<string> values: present keys return the stored
    // vector; missing keys return the fallback, including when the fallback
    // is spelled as a braced list of string literals.
    std::vector<std::string> strings = {"1", "2", "3"};
    std::vector<std::string> fallback = {"none"};
    migraphx::value v = {{"key", strings}};
    EXPECT(v.get("key", fallback) == strings);
    EXPECT(v.get("missing", fallback) == fallback);
    EXPECT(v.get("missing", {"none"}) == fallback);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
// Verify program exercising roialign on an input with a non-standard memory
// layout: from_permutation builds a {5, 4, 10, 10} shape whose strides follow
// the permutation {0, 2, 3, 1} (i.e. not packed NCHW order).
struct test_roialign_nonstandard : verify_program<test_roialign_nonstandard>
{
    migraphx::program create_program() const
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
        // Feature map with permuted (non-standard) strides
        auto x_s = migraphx::shape::from_permutation(
            migraphx::shape::float_type, {5, 4, 10, 10}, {0, 2, 3, 1});
        // 5 ROIs; roi rows are presumably (x1, y1, x2, y2) — TODO confirm
        migraphx::shape roi_s{migraphx::shape::float_type, {5, 4}};
        // One batch index per ROI
        migraphx::shape ind_s{migraphx::shape::int64_type, {5}};
        std::vector<int64_t> ind_vec = {0, 2, 3, 4, 1};
        auto x   = mm->add_parameter("x", x_s);
        auto roi = mm->add_parameter("roi", roi_s);
        auto ind = mm->add_literal(migraphx::literal(ind_s, ind_vec));
        // Pool each ROI to a fixed 5x5 output with 2 sampling points per bin
        auto r = mm->add_instruction(migraphx::make_op("roialign",
                                                       {{"spatial_scale", 1.0},
                                                        {"output_height", 5},
                                                        {"output_width", 5},
                                                        {"sampling_ratio", 2}}),
                                     x,
                                     roi,
                                     ind);
        mm->add_return({r});
        return p;
    }
};
......@@ -102,6 +102,10 @@ header_function = Template('''
${error_type} ${name}(${params});
''')
function_pointer_typedef = Template('''
typedef ${error_type} (*${fname})(${params});
''')
c_api_impl = Template('''
extern "C" ${error_type} ${name}(${params})
{
......@@ -136,18 +140,23 @@ class CFunction:
self.va_end = ['va_end({});'.format(name)]
self.add_param('...', '')
def substitute(self, form: Template) -> str:
def substitute(self, form: Template, **kwargs) -> str:
return form.substitute(error_type=error_type,
try_wrap=try_wrap,
name=self.name,
params=', '.join(self.params),
body=";\n ".join(self.body),
va_start="\n ".join(self.va_start),
va_end="\n ".join(self.va_end))
va_end="\n ".join(self.va_end),
**kwargs)
def generate_header(self) -> str:
return self.substitute(header_function)
def generate_function_pointer(self, name: Optional[str] = None) -> str:
return self.substitute(function_pointer_typedef,
fname=name or self.name)
def generate_body(self) -> str:
return self.substitute(c_api_impl)
......@@ -163,7 +172,9 @@ class Parameter:
name: str,
type: str,
optional: bool = False,
returns: bool = False) -> None:
returns: bool = False,
virtual: bool = False,
this: bool = False) -> None:
self.name = name
self.type = Type(type)
self.optional = optional
......@@ -175,7 +186,11 @@ class Parameter:
self.cpp_read = '${name}'
self.cpp_write = '${name}'
self.returns = returns
self.virtual = virtual
self.this = this
self.bad_param_check: Optional[BadParam] = None
self.virtual_read: Optional[List[str]] = None
self.virtual_write: Optional[str] = None
def get_name(self, prefix: Optional[str] = None) -> str:
if prefix:
......@@ -248,6 +263,48 @@ class Parameter:
raise ValueError("Error for {}: write cannot be a string".format(
self.type.str()))
def virtual_arg(self, prefix: Optional[str] = None) -> List[str]:
    """Expressions used to pass this parameter through to a virtual
    (function-pointer) call.

    Prefers the explicitly-set ``virtual_read`` templates; otherwise derives
    them from the ``write`` templates by taking the right-hand side of each
    ``lhs = rhs`` assignment and substituting ``${result}`` with ``${name}``.
    """
    read = self.virtual_read
    # Fall back to reusing the write templates when there is one write per
    # underlying C parameter (so each RHS can act as a read expression).
    if not read and len(self.write) >= len(self.cparams):
        read = [
            Template(w.partition('=')[2]).safe_substitute(result='${name}')
            for w in self.write
        ]
    if not read:
        raise ValueError("No virtual_read parameter provided for: " +
                         self.type.str())
    # Guard against a lone string template; a list of templates is required.
    if isinstance(read, str):
        raise ValueError(
            "Error for {}: virtual_read cannot be a string".format(
                self.type.str()))
    return [self.substitute(r, prefix=prefix) for r in read]
def virtual_param(self, prefix: Optional[str] = None) -> str:
    """C declaration ('type name') for this parameter in a virtual-call signature."""
    return self.substitute('${type} ${name}', prefix=prefix)
def virtual_output_args(self, prefix: Optional[str] = None) -> List[str]:
    """Address-of arguments ('&<prefix><name>'), one per underlying C
    parameter, used to receive this parameter's outputs from a virtual call.
    """
    # Only the cparam name matters here; its C type is ignored (was an
    # unused loop variable `t`).
    return [
        '&{prefix}{n}'.format(prefix=prefix or '', n=n)
        for _, n in self.cparams
    ]
def virtual_output_declarations(self,
                                prefix: Optional[str] = None) -> List[str]:
    """Local variable declarations holding the outputs of a virtual call,
    one per underlying C parameter.

    Each cparam type is expected to be a pointer (an out-parameter), so the
    local is declared as the pointee type via std::remove_pointer_t.
    """
    return [
        'std::remove_pointer_t<{type}> {prefix}{n};'.format(
            type=Type(t).str(), prefix=prefix or '', n=n)
        for t, n in self.cparams
    ]
def virtual_output(self, prefix: Optional[str] = None) -> str:
    """Expression converting the received output locals back into the C++
    return value of the virtual method.

    Uses ``virtual_write`` when set; otherwise reuses the ``read`` template,
    passing the local's address when the template dereferences it.
    """
    write = self.virtual_write
    if not write:
        # NOTE(review): '*' in self.read would also match a multiplication;
        # assumed read templates only use '*'/'->' for dereference -- confirm.
        if '*' in self.read or '->' in self.read:
            write = Template(self.read).safe_substitute(name='(&${name})')
        else:
            write = self.read
    return self.substitute(write, prefix=prefix)
def cpp_param(self, prefix: Optional[str] = None) -> str:
    """C++ declaration ('cpptype name') for this parameter in a C++ wrapper signature."""
    return self.substitute('${cpptype} ${name}', prefix=prefix)
......@@ -311,6 +368,7 @@ class Function:
invoke: Optional[str] = None,
fname: Optional[str] = None,
return_name: Optional[str] = None,
virtual: bool = False,
**kwargs) -> None:
self.name = name
self.params = params or []
......@@ -321,6 +379,10 @@ class Function:
self.return_name = return_name or 'out'
self.returns = Parameter(self.return_name, returns,
returns=True) if returns else None
for p in self.params:
p.virtual = virtual
if self.returns:
self.returns.virtual = virtual
def share_params(self) -> None:
if self.shared_size == True:
......@@ -556,6 +618,9 @@ def params(virtual: Optional[Dict[str, str]] = None,
return result
gparams = params
def add_function(name: str, *args, **kwargs) -> Function:
f = Function(name, *args, **kwargs)
functions.append(f)
......@@ -627,7 +692,7 @@ extern "C" struct ${ctype};
struct ${ctype} {
template<class... Ts>
${ctype}(Ts&&... xs)
: object(std::forward<Ts>(xs)...)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{}
${cpptype} object;
};
......@@ -656,6 +721,55 @@ void destroy(T* x)
{
delete x; // NOLINT
}
// TODO: Move to interface preamble
template <class C, class D>
struct manage_generic_ptr
{
manage_generic_ptr() = default;
manage_generic_ptr(std::nullptr_t)
{
}
manage_generic_ptr(void* pdata, C pcopier, D pdeleter)
: data(nullptr), copier(pcopier), deleter(pdeleter)
{
copier(&data, pdata);
}
manage_generic_ptr(const manage_generic_ptr& rhs)
: data(nullptr), copier(rhs.copier), deleter(rhs.deleter)
{
if(copier)
copier(&data, rhs.data);
}
manage_generic_ptr(manage_generic_ptr&& other) noexcept
: data(other.data), copier(other.copier), deleter(other.deleter)
{
other.data = nullptr;
other.copier = nullptr;
other.deleter = nullptr;
}
manage_generic_ptr& operator=(manage_generic_ptr rhs)
{
std::swap(data, rhs.data);
std::swap(copier, rhs.copier);
std::swap(deleter, rhs.deleter);
return *this;
}
~manage_generic_ptr()
{
if(data != nullptr)
deleter(data);
}
void* data = nullptr;
C copier = nullptr;
D deleter = nullptr;
};
'''
cpp_handle_preamble = '''
......@@ -718,30 +832,40 @@ def add_handle(name: str,
ctype: str,
cpptype: str,
destroy: Optional[str] = None,
ref: Optional[bool] = None) -> None:
ref=False,
skip_def=False) -> None:
opaque_type = ctype + '_t'
const_opaque_type = 'const_' + opaque_type
def handle_wrap(p):
def handle_wrap(p: Parameter):
t = Type(opaque_type)
if p.type.is_const():
t = Type('const_' + opaque_type)
if p.returns:
# p.read = 'object_cast<${ctype}>(&(${name}))'
if p.virtual:
p.add_param(t)
elif p.returns:
p.add_param(t.add_pointer())
if p.type.is_reference():
p.cpp_write = '${cpptype}(${name}, false)'
p.write = ['*${name} = object_cast<${ctype}>(&(${result}))']
elif p.type.is_pointer():
p.cpp_write = '${cpptype}(${name}, false)'
p.write = ['*${name} = object_cast<${ctype}>(${result})']
else:
p.cpp_write = '${cpptype}(${name})'
p.write = ['*${name} = allocate<${ctype}>(${result})']
else:
p.add_param(t)
p.bad_param('${name} == nullptr', 'Null pointer')
if p.type.is_reference():
p.virtual_read = ['object_cast<${ctype}>(&(${name}))']
p.cpp_write = '${cpptype}(${name}, false)'
p.write = ['*${name} = object_cast<${ctype}>(&(${result}))']
elif p.type.is_pointer():
p.virtual_read = ['object_cast<${ctype}>(${result})']
p.cpp_write = '${cpptype}(${name}, false)'
p.write = ['*${name} = object_cast<${ctype}>(${result})']
else:
p.virtual_read = ['object_cast<${ctype}>(&(${name}))']
p.cpp_write = '${cpptype}(${name})'
p.write = ['*${name} = allocate<${ctype}>(${result})']
if skip_def:
p.read = '*${name}'
else:
p.read = '${name}->object'
p.cpp_read = '${name}.get_handle_ptr()'
p.cpp_read = '${name}.get_handle_ptr()'
type_map[cpptype] = handle_wrap
if not ref:
......@@ -753,7 +877,8 @@ def add_handle(name: str,
invoke='*output = *input')
add_handle_preamble()
c_header_preamble.append(handle_typedef.substitute(locals()))
c_api_body_preamble.append(handle_definition.substitute(locals()))
if not skip_def:
c_api_body_preamble.append(handle_definition.substitute(locals()))
@cwrap('std::vector')
......@@ -763,30 +888,32 @@ def vector_c_wrap(p: Parameter) -> None:
if not inner:
return
t = inner.add_pointer()
if p.type.is_reference():
if p.type.is_const():
t = t.add_const()
if p.returns:
if p.type.is_reference():
if p.type.is_const():
t = t.add_const()
p.add_param(t.add_pointer())
p.add_size_param()
p.bad_param('${name} == nullptr or ${size} == nullptr',
'Null pointer')
p.cpp_write = '${type}(${name}, ${name}+${size})'
p.write = [
'*${name} = ${result}.data()', '*${size} = ${result}.size()'
]
else:
p.add_param(t)
p.bad_param('${name} == nullptr', 'Null pointer')
p.cpp_write = '${type}(${name}, ${name}+${size})'
p.write = [
'std::copy(${result}.begin(), ${result}.end(), ${name})'
]
else:
p.add_param(t)
p.add_size_param()
p.bad_param('${name} == nullptr and ${size} != 0', 'Null pointer')
p.read = '${type}(${name}, ${name}+${size})'
p.read = '${type}(${name}, ${name}+${size})'
p.cpp_write = '${type}(${name}, ${name}+${size})'
p.virtual_read = ['${name}.data()', '${name}.size()']
if p.type.is_reference():
p.write = [
'*${name} = ${result}.data()', '*${size} = ${result}.size()'
]
else:
p.write = ['std::copy(${result}.begin(), ${result}.end(), ${name})']
@cwrap('std::string')
......@@ -796,34 +923,34 @@ def string_c_wrap(p: Parameter) -> None:
if p.type.is_reference():
p.add_param(t.add_pointer())
p.bad_param('${name} == nullptr', 'Null pointer')
p.cpp_write = '${type}(${name})'
p.write = ['*${name} = ${result}.c_str()']
else:
p.add_param(t)
p.add_param('size_t', p.name + '_size')
p.bad_param('${name} == nullptr', 'Null pointer')
p.cpp_write = '${type}(${name})'
p.write = [
'auto* it = std::copy_n(${result}.begin(), std::min(${result}.size(), ${name}_size - 1), ${name});'
'*it = \'\\0\''
]
else:
p.add_param(t)
p.bad_param('${name} == nullptr', 'Null pointer')
p.read = '${type}(${name})'
p.read = '${type}(${name})'
p.cpp_write = '${type}(${name})'
p.virtual_read = ['${name}.c_str()']
if p.type.is_reference():
p.write = ['*${name} = ${result}.c_str()']
else:
p.write = [
'auto* it = std::copy_n(${result}.begin(), std::min(${result}.size(), ${name}_size - 1), ${name});'
'*it = \'\\0\''
]
class Handle:
def __init__(self,
name: str,
ctype: str,
cpptype: str,
ref: Optional[bool] = None) -> None:
def __init__(self, name: str, ctype: str, cpptype: str, **kwargs) -> None:
self.name = name
self.ctype = ctype
self.cpptype = cpptype
self.opaque_type = self.ctype + '_t'
self.cpp_class = CPPClass(name, ctype)
add_handle(name, ctype, cpptype, ref=ref)
add_handle(name, ctype, cpptype, **kwargs)
cpp_type_map[cpptype] = name
def cname(self, name: str) -> str:
......@@ -833,6 +960,7 @@ class Handle:
return Template(s).safe_substitute(name=self.name,
ctype=self.ctype,
cpptype=self.cpptype,
opaque_type=self.opaque_type,
**kwargs)
def constructor(self,
......@@ -887,6 +1015,137 @@ class Handle:
cpp_classes.append(self.cpp_class)
interface_handle_definition = Template('''
extern "C" struct ${ctype};
struct ${ctype} {
template<class... Ts>
${ctype}(void* p, ${copier} c, ${deleter} d, Ts&&... xs)
: object_ptr(p, c, d), xobject(std::forward<Ts>(xs)...)
{}
manage_generic_ptr<${copier}, ${deleter}> object_ptr = nullptr;
${cpptype} xobject;
${functions}
};
''')
c_api_virtual_impl = Template('''
${return_type} ${name}(${params}) const
{
${output_decls}
if (${fname} == nullptr)
throw std::runtime_error("${name} function is missing.");
auto api_error_result = ${fname}(${args});
if (api_error_result != ${success})
throw std::runtime_error("Error in ${name}.");
return ${output};
}
''')
def generate_virtual_impl(f: Function, fname: str) -> str:
    """Render the C++ wrapper that invokes the function pointer `fname` for
    virtual function `f`.

    WARNING: every local below is a template field -- the final
    ``c_api_virtual_impl.substitute(locals())`` consumes them by name, so
    renaming any local variable changes the generated code.
    """
    success = success_type
    name = f.name
    return_type = 'void'
    output_decls = ''
    output = ''
    largs = []
    lparams = []
    # Outputs are received through address-of out-arguments, declared first
    # and converted back to the C++ return value at the end.
    if f.returns:
        return_type = f.returns.type.str()
        output_decls = '\n'.join(f.returns.virtual_output_declarations())
        largs += f.returns.virtual_output_args()
        output = f.returns.virtual_output()
    largs += [arg for p in f.params for arg in p.virtual_arg()]
    # The implicit 'this' parameter is passed as an argument but never
    # declared in the C++ member-function signature.
    lparams += [p.virtual_param() for p in f.params if not p.this]
    args = ', '.join(largs)
    params = ', '.join(lparams)
    return c_api_virtual_impl.substitute(locals())
class Interface(Handle):
    """Handle variant whose methods dispatch through user-registered C
    function pointers ("virtual" methods) instead of a fixed C++ object.

    The opaque struct definition is emitted by generate() using
    interface_handle_definition (hence skip_def=True for the base Handle).
    """
    def __init__(self, name: str, ctype: str, cpptype: str) -> None:
        super().__init__(name, ctype, cpptype, skip_def=True)
        # Virtual functions declared via virtual(); rendered in generate().
        self.ifunctions: List[Function] = []
        self.members: List[str] = []
    def mname(self, name: str) -> str:
        # Name of the function-pointer member stored on the opaque struct.
        return name + "_f"
    def constructor(  # type: ignore
            self,
            name: str,
            params: Optional[List[Parameter]] = None,
            **kwargs) -> 'Interface':
        """Emit a C constructor that allocates the opaque struct; the first
        three parameters are the user object pointer plus its copier and
        deleter function pointers."""
        create = self.substitute('allocate<${opaque_type}>($@)')
        initial_params = gparams(obj='void*',
                                 c=self.cname('copy'),
                                 d=self.cname('delete'))
        add_function(self.cname(name),
                     params=initial_params + (params or []),
                     invoke=create,
                     returns=self.opaque_type,
                     return_name=self.name,
                     **kwargs)
        return self
    def method(self, *args, **kwargs) -> 'Interface':
        # Same as Handle.method, but returns self to allow chaining.
        super().method(*args, **kwargs)
        return self
    def virtual(self,
                name: str,
                params: Optional[List[Parameter]] = None,
                const: Optional[bool] = None,
                **kwargs) -> 'Interface':
        """Declare a virtual method: records the Function and emits a
        C 'set_<name>' function to install the user's function pointer.

        NOTE(review): `const` is currently unused here (the generated wrapper
        is always const in c_api_virtual_impl) -- confirm intent.
        """
        # Add this parameter to the function
        this = Parameter('obj', 'void*', this=True)
        this.virtual_read = ['object_ptr.data']
        f = Function(name,
                     params=[this] + (params or []),
                     virtual=True,
                     **kwargs)
        self.ifunctions.append(f)
        # Setter storing the pointer on the opaque struct's <name>_f member.
        add_function(self.cname('set_' + name),
                     params=gparams(obj=self.opaque_type,
                                    input=self.cname(name)),
                     invoke='${{obj}}->{name} = ${{input}}'.format(
                         name=self.mname(name)))
        return self
    def generate_function(self, f: Function):
        # Function-pointer member (default nullptr) plus its C++ wrapper.
        cname = self.cname(f.name)
        mname = self.mname(f.name)
        function = generate_virtual_impl(f, fname=mname)
        return f"{cname} {mname} = nullptr;{function}"
    def generate(self):
        """Emit the typedefs and the opaque struct definition for this
        interface into the C header / API-body preambles.

        WARNING: the locals near the end (ctype, cpptype, copier, deleter,
        functions) are consumed by name via substitute(locals()).
        """
        # copy/delete are mandatory for manage_generic_ptr ownership.
        required_functions = [
            Function('copy',
                     params=gparams(out='void**', input='void*'),
                     virtual=True),
            Function('delete', params=gparams(input='void*'), virtual=True)
        ]
        for f in self.ifunctions + required_functions:
            f.update()
        c_header_preamble.extend([
            f.get_cfunction().generate_function_pointer(self.cname(f.name))
            for f in self.ifunctions + required_functions
        ])
        function_list = [self.generate_function(f) for f in self.ifunctions]
        ctype = self.ctype
        cpptype = self.cpptype
        copier = self.cname('copy')
        deleter = self.cname('delete')
        functions = '\n'.join(function_list)
        c_api_body_preamble.append(
            interface_handle_definition.substitute(locals()))
def handle(ctype: str,
cpptype: str,
name: Optional[str] = None,
......@@ -906,6 +1165,23 @@ def handle(ctype: str,
return with_handle
def interface(ctype: str, cpptype: str,
              name: Optional[str] = None) -> Callable:
    """Decorator registering an Interface.

    The decorated function receives the Interface object so it can declare
    constructors and virtual methods; generate() then emits the C
    declarations/definitions. The wrapped function itself is returned
    unchanged (registration happens once, at decoration time).
    """
    def with_interface(f):
        n = name or f.__name__
        h = Interface(n, ctype, cpptype)
        f(h)
        h.generate()
        @wraps(f)
        def decorated(*args, **kwargs):
            return f(*args, **kwargs)
        return decorated
    return with_interface
def template_eval(template, **kwargs):
start = '<%'
end = '%>'
......@@ -928,7 +1204,7 @@ def run(args: List[str]) -> None:
else:
sys.stdout.write(generate_c_header())
sys.stdout.write(generate_c_api_body())
sys.stdout.write(generate_cpp_header())
# sys.stdout.write(generate_cpp_header())
if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment