Unverified Commit 77164f3c authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Allow constructing an operation with a format string (#976)

Designed to allow a user to format the values needed for the json_string: migraphx::operation("reduce_mean", "{axes : [%i, %i, %i, %i]}", axes[0], axes[1], axes[2], axes[3]) instead of needing to use string concat or stringstream
parent a05113aa
...@@ -190,6 +190,7 @@ rocm_enable_cppcheck( ...@@ -190,6 +190,7 @@ rocm_enable_cppcheck(
shadowVariable shadowVariable
unsafeClassDivZero unsafeClassDivZero
definePrefix:*test/include/test.hpp definePrefix:*test/include/test.hpp
ctuOneDefinitionRuleViolation:*test/*
useSmartPointer:*src/api/api.cpp useSmartPointer:*src/api/api.cpp
useSmartPointer:*make_shared_array.hpp useSmartPointer:*make_shared_array.hpp
constParameter:*src/targets/gpu/*.cpp constParameter:*src/targets/gpu/*.cpp
......
This diff is collapsed.
...@@ -209,7 +209,8 @@ migraphx_status migraphx_operation_destroy(migraphx_operation_t operation); ...@@ -209,7 +209,8 @@ migraphx_status migraphx_operation_destroy(migraphx_operation_t operation);
migraphx_status migraphx_operation_create(migraphx_operation_t* operation, migraphx_status migraphx_operation_create(migraphx_operation_t* operation,
const char* name, const char* name,
const char* attributes); const char* attributes,
...);
migraphx_status migraphx_operation_name(char* out, size_t out_size, migraphx_operation_t operation); migraphx_status migraphx_operation_name(char* out, size_t out_size, migraphx_operation_t operation);
......
...@@ -599,9 +599,10 @@ struct operation : MIGRAPHX_HANDLE_BASE(operation) ...@@ -599,9 +599,10 @@ struct operation : MIGRAPHX_HANDLE_BASE(operation)
operation(migraphx_operation* p, borrow) { this->set_handle(p, borrow{}); } operation(migraphx_operation* p, borrow) { this->set_handle(p, borrow{}); }
operation(const char* name, const char* attributes = nullptr) template <class... Ts>
operation(const char* name, const char* attributes = nullptr, Ts... xs)
{ {
this->make_handle(&migraphx_operation_create, name, attributes); this->make_handle(&migraphx_operation_create, name, attributes, xs...);
} }
std::string name() std::string name()
......
...@@ -212,7 +212,9 @@ def program(h): ...@@ -212,7 +212,9 @@ def program(h):
@auto_handle() @auto_handle()
def operation(h): def operation(h):
h.constructor('create', h.constructor('create',
api.params(name='const char*', attributes='const char*'), api.params(name='const char*',
attributes='const char*',
vlist='...'),
fname='migraphx::create_op') fname='migraphx::create_op')
h.method('name', returns='std::string') h.method('name', returns='std::string')
......
...@@ -8,16 +8,22 @@ TEST_CASE(add_op) ...@@ -8,16 +8,22 @@ TEST_CASE(add_op)
EXPECT(add_op.name() == "add"); EXPECT(add_op.name() == "add");
} }
TEST_CASE(reduce_mean) TEST_CASE(reduce_mean_without_quotes)
{ {
auto rm = migraphx::operation("reduce_mean", "{axes : [1, 2, 3, 4]}"); auto rm = migraphx::operation("reduce_mean", "{axes : [1, 2, 3, 4]}");
EXPECT(rm.name() == "reduce_mean"); EXPECT(rm.name() == "reduce_mean");
} }
TEST_CASE(reduce_mean1) TEST_CASE(reduce_mean)
{ {
auto rm = migraphx::operation("reduce_mean", "{\"axes\" : [1, 2, 3, 4]}"); auto rm = migraphx::operation("reduce_mean", "{\"axes\" : [1, 2, 3, 4]}");
EXPECT(rm.name() == "reduce_mean"); EXPECT(rm.name() == "reduce_mean");
} }
TEST_CASE(reduce_mean_with_format)
{
auto rm = migraphx::operation("reduce_mean", "{axes : [%i, %i, %i, %i]}", 1, 2, 3, 4);
EXPECT(rm.name() == "reduce_mean");
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -35,6 +35,9 @@ class Type: ...@@ -35,6 +35,9 @@ class Type:
def is_const(self): def is_const(self):
return self.name.startswith('const ') return self.name.startswith('const ')
def is_variadic(self):
return self.name.startswith('...')
def add_pointer(self): def add_pointer(self):
return Type(self.name + '*') return Type(self.name + '*')
...@@ -101,9 +104,10 @@ ${error_type} ${name}(${params}); ...@@ -101,9 +104,10 @@ ${error_type} ${name}(${params});
c_api_impl = Template(''' c_api_impl = Template('''
extern "C" ${error_type} ${name}(${params}) extern "C" ${error_type} ${name}(${params})
{ {
return ${try_wrap}([&] { ${va_start}auto api_error_result = ${try_wrap}([&] {
${body}; ${body};
}); });
${va_end}return api_error_result;
} }
''') ''')
...@@ -113,6 +117,8 @@ class CFunction: ...@@ -113,6 +117,8 @@ class CFunction:
self.name = name self.name = name
self.params = [] self.params = []
self.body = [] self.body = []
self.va_start = []
self.va_end = []
def add_param(self, type, pname): def add_param(self, type, pname):
self.params.append('{} {}'.format(type, pname)) self.params.append('{} {}'.format(type, pname))
...@@ -120,12 +126,23 @@ class CFunction: ...@@ -120,12 +126,23 @@ class CFunction:
def add_statement(self, stmt): def add_statement(self, stmt):
self.body.append(stmt) self.body.append(stmt)
def add_vlist(self, name):
last_param = self.params[-1].split()[-1]
self.va_start = [
'va_list {};'.format(name),
'va_start({}, {});'.format(name, last_param)
]
self.va_end = ['va_end({});'.format(name)]
self.add_param('...', '')
def substitute(self, form): def substitute(self, form):
return form.substitute(error_type=error_type, return form.substitute(error_type=error_type,
try_wrap=try_wrap, try_wrap=try_wrap,
name=self.name, name=self.name,
params=', '.join(self.params), params=', '.join(self.params),
body=";\n ".join(self.body)) body=";\n ".join(self.body),
va_start="\n ".join(self.va_start),
va_end="\n ".join(self.va_end))
def generate_header(self): def generate_header(self):
return self.substitute(header_function) return self.substitute(header_function)
...@@ -256,6 +273,9 @@ class Parameter: ...@@ -256,6 +273,9 @@ class Parameter:
def add_to_cfunction(self, cfunction): def add_to_cfunction(self, cfunction):
for t, name in self.cparams: for t, name in self.cparams:
if t.startswith('...'):
cfunction.add_vlist(name)
else:
cfunction.add_param(self.substitute(t), self.substitute(name)) cfunction.add_param(self.substitute(t), self.substitute(name))
if self.bad_param_check: if self.bad_param_check:
msg = 'Bad parameter {name}: {msg}'.format( msg = 'Bad parameter {name}: {msg}'.format(
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#include <migraphx/json.hpp> #include <migraphx/json.hpp>
#include <migraphx/convert_to_json.hpp> #include <migraphx/convert_to_json.hpp>
#include <algorithm> #include <algorithm>
#include <cstdarg>
namespace migraphx { namespace migraphx {
...@@ -155,18 +156,30 @@ void quantize_int8_wrap(program& prog, const target& t, quantize_int8_options& o ...@@ -155,18 +156,30 @@ void quantize_int8_wrap(program& prog, const target& t, quantize_int8_options& o
migraphx::quantize_int8(prog, t, options.calibration, options.op_names); migraphx::quantize_int8(prog, t, options.calibration, options.op_names);
} }
operation create_op(const char* name, const char* attributes) #ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wformat-nonliteral"
#endif
operation create_op(const char* name, const char* attributes, va_list vlist)
{ {
std::string sattributes = attributes == nullptr ? "" : attributes;
std::vector<char> buffer(sattributes.size() * 2);
std::vsnprintf(buffer.data(), buffer.size(), sattributes.c_str(), vlist);
value v = value::object{}; value v = value::object{};
if(attributes != nullptr) if(attributes != nullptr)
{ {
v = from_json_string(convert_to_json(std::string(attributes))); v = from_json_string(convert_to_json(std::string(buffer.data())));
} }
auto op = make_op(name, v); auto op = make_op(name, v);
return op; return op;
} }
#ifdef __clang__
#pragma clang diagnostic pop
#endif
template <class T> template <class T>
bool equal(const T& x, const T& y) bool equal(const T& x, const T& y)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment