Unverified Commit 77164f3c authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Allow constructing an operation with a format string (#976)

Designed to allow a user to format the values needed for the json_string: migraphx::operation("reduce_mean", "{axes : [%i, %i, %i, %i]}", axes[0], axes[1], axes[2], axes[3]) instead of needing to use string concat or stringstream
parent a05113aa
......@@ -190,6 +190,7 @@ rocm_enable_cppcheck(
shadowVariable
unsafeClassDivZero
definePrefix:*test/include/test.hpp
ctuOneDefinitionRuleViolation:*test/*
useSmartPointer:*src/api/api.cpp
useSmartPointer:*make_shared_array.hpp
constParameter:*src/targets/gpu/*.cpp
......
This diff is collapsed.
......@@ -209,7 +209,8 @@ migraphx_status migraphx_operation_destroy(migraphx_operation_t operation);
migraphx_status migraphx_operation_create(migraphx_operation_t* operation,
const char* name,
const char* attributes);
const char* attributes,
...);
migraphx_status migraphx_operation_name(char* out, size_t out_size, migraphx_operation_t operation);
......
......@@ -599,9 +599,10 @@ struct operation : MIGRAPHX_HANDLE_BASE(operation)
operation(migraphx_operation* p, borrow) { this->set_handle(p, borrow{}); }
operation(const char* name, const char* attributes = nullptr)
template <class... Ts>
operation(const char* name, const char* attributes = nullptr, Ts... xs)
{
this->make_handle(&migraphx_operation_create, name, attributes);
this->make_handle(&migraphx_operation_create, name, attributes, xs...);
}
std::string name()
......
......@@ -212,7 +212,9 @@ def program(h):
@auto_handle()
def operation(h):
h.constructor('create',
api.params(name='const char*', attributes='const char*'),
api.params(name='const char*',
attributes='const char*',
vlist='...'),
fname='migraphx::create_op')
h.method('name', returns='std::string')
......
......@@ -8,16 +8,22 @@ TEST_CASE(add_op)
EXPECT(add_op.name() == "add");
}
TEST_CASE(reduce_mean)
TEST_CASE(reduce_mean_without_quotes)
{
auto rm = migraphx::operation("reduce_mean", "{axes : [1, 2, 3, 4]}");
EXPECT(rm.name() == "reduce_mean");
}
TEST_CASE(reduce_mean1)
TEST_CASE(reduce_mean)
{
auto rm = migraphx::operation("reduce_mean", "{\"axes\" : [1, 2, 3, 4]}");
EXPECT(rm.name() == "reduce_mean");
}
TEST_CASE(reduce_mean_with_format)
{
auto rm = migraphx::operation("reduce_mean", "{axes : [%i, %i, %i, %i]}", 1, 2, 3, 4);
EXPECT(rm.name() == "reduce_mean");
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -35,6 +35,9 @@ class Type:
def is_const(self):
return self.name.startswith('const ')
def is_variadic(self):
return self.name.startswith('...')
def add_pointer(self):
return Type(self.name + '*')
......@@ -101,9 +104,10 @@ ${error_type} ${name}(${params});
c_api_impl = Template('''
extern "C" ${error_type} ${name}(${params})
{
return ${try_wrap}([&] {
${va_start}auto api_error_result = ${try_wrap}([&] {
${body};
});
${va_end}return api_error_result;
}
''')
......@@ -113,6 +117,8 @@ class CFunction:
self.name = name
self.params = []
self.body = []
self.va_start = []
self.va_end = []
def add_param(self, type, pname):
self.params.append('{} {}'.format(type, pname))
......@@ -120,12 +126,23 @@ class CFunction:
def add_statement(self, stmt):
self.body.append(stmt)
def add_vlist(self, name):
last_param = self.params[-1].split()[-1]
self.va_start = [
'va_list {};'.format(name),
'va_start({}, {});'.format(name, last_param)
]
self.va_end = ['va_end({});'.format(name)]
self.add_param('...', '')
def substitute(self, form):
return form.substitute(error_type=error_type,
try_wrap=try_wrap,
name=self.name,
params=', '.join(self.params),
body=";\n ".join(self.body))
body=";\n ".join(self.body),
va_start="\n ".join(self.va_start),
va_end="\n ".join(self.va_end))
def generate_header(self):
return self.substitute(header_function)
......@@ -256,7 +273,10 @@ class Parameter:
def add_to_cfunction(self, cfunction):
for t, name in self.cparams:
cfunction.add_param(self.substitute(t), self.substitute(name))
if t.startswith('...'):
cfunction.add_vlist(name)
else:
cfunction.add_param(self.substitute(t), self.substitute(name))
if self.bad_param_check:
msg = 'Bad parameter {name}: {msg}'.format(
name=self.name, msg=self.bad_param_check.msg)
......
......@@ -13,6 +13,7 @@
#include <migraphx/json.hpp>
#include <migraphx/convert_to_json.hpp>
#include <algorithm>
#include <cstdarg>
namespace migraphx {
......@@ -155,18 +156,30 @@ void quantize_int8_wrap(program& prog, const target& t, quantize_int8_options& o
migraphx::quantize_int8(prog, t, options.calibration, options.op_names);
}
operation create_op(const char* name, const char* attributes)
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wformat-nonliteral"
#endif
operation create_op(const char* name, const char* attributes, va_list vlist)
{
std::string sattributes = attributes == nullptr ? "" : attributes;
std::vector<char> buffer(sattributes.size() * 2);
std::vsnprintf(buffer.data(), buffer.size(), sattributes.c_str(), vlist);
value v = value::object{};
if(attributes != nullptr)
{
v = from_json_string(convert_to_json(std::string(attributes)));
v = from_json_string(convert_to_json(std::string(buffer.data())));
}
auto op = make_op(name, v);
return op;
}
#ifdef __clang__
#pragma clang diagnostic pop
#endif
template <class T>
bool equal(const T& x, const T& y)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment