Unverified Commit 51fb672d authored by kahmed10's avatar kahmed10 Committed by GitHub
Browse files

Add tf C++ API (#770)



* fix relu6

* add more transposes

* add parse_tf calls

* progress on multi_outputs

* formatting

* add multi output test

* add comment and update migraphx.py

* fix compile

* formatting

* update tools/api

* formatting

* fix function call

* fix generate

* simplify tests

* formatting

* rename tests

* enclose braces

* add more tests

* update comments

* rename file and add default param

* formatting

* fix tidy and change type

* formatting older files
Co-authored-by: default avatarPaul Fultz II <pfultz2@yahoo.com>
parent 1a948d5b
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
#include <migraphx/tf.hpp>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/quantization.hpp> #include <migraphx/quantization.hpp>
...@@ -90,6 +91,10 @@ void set_default_dim_value(onnx_options& options, size_t value) ...@@ -90,6 +91,10 @@ void set_default_dim_value(onnx_options& options, size_t value)
options.default_dim_value = value; options.default_dim_value = value;
} }
void set_nhwc(tf_options& options, bool is_nhwc) { options.is_nhwc = is_nhwc; }
void set_default_dim_value(tf_options& options, size_t value) { options.batch_size = value; }
void set_input_parameter_shape(onnx_options& options, void set_input_parameter_shape(onnx_options& options,
const char* name, const char* name,
std::vector<std::size_t> dims) std::vector<std::size_t> dims)
...@@ -97,6 +102,16 @@ void set_input_parameter_shape(onnx_options& options, ...@@ -97,6 +102,16 @@ void set_input_parameter_shape(onnx_options& options,
options.map_input_dims[std::string(name)] = std::move(dims); options.map_input_dims[std::string(name)] = std::move(dims);
} }
void set_input_parameter_shape(tf_options& options, const char* name, std::vector<std::size_t> dims)
{
options.map_input_dims[std::string(name)] = std::move(dims);
}
void set_output_names(tf_options& options, std::vector<const char*> names)
{
options.output_node_names = std::vector<std::string>(names.begin(), names.end());
}
template <class Value> template <class Value>
std::vector<const char*> get_names(const std::unordered_map<std::string, Value>& m) std::vector<const char*> get_names(const std::unordered_map<std::string, Value>& m)
{ {
...@@ -303,6 +318,16 @@ struct migraphx_onnx_options ...@@ -303,6 +318,16 @@ struct migraphx_onnx_options
migraphx::onnx_options object; migraphx::onnx_options object;
}; };
extern "C" struct migraphx_tf_options;
struct migraphx_tf_options
{
template <class... Ts>
migraphx_tf_options(Ts&&... xs) : object(std::forward<Ts>(xs)...)
{
}
migraphx::tf_options object;
};
extern "C" struct migraphx_quantize_op_names; extern "C" struct migraphx_quantize_op_names;
struct migraphx_quantize_op_names struct migraphx_quantize_op_names
{ {
...@@ -839,6 +864,75 @@ extern "C" migraphx_status migraphx_parse_onnx_buffer(migraphx_program_t* out, ...@@ -839,6 +864,75 @@ extern "C" migraphx_status migraphx_parse_onnx_buffer(migraphx_program_t* out,
}); });
} }
extern "C" migraphx_status migraphx_tf_options_destroy(migraphx_tf_options_t tf_options)
{
return migraphx::try_([&] { destroy((tf_options)); });
}
extern "C" migraphx_status migraphx_tf_options_create(migraphx_tf_options_t* tf_options)
{
return migraphx::try_([&] {
*tf_options = object_cast<migraphx_tf_options_t>(allocate<migraphx::tf_options>());
});
}
extern "C" migraphx_status migraphx_tf_options_set_nhwc(migraphx_tf_options_t tf_options,
bool is_nhwc)
{
return migraphx::try_([&] {
if(tf_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter tf_options: Null pointer");
migraphx::set_nhwc((tf_options->object), (is_nhwc));
});
}
extern "C" migraphx_status migraphx_tf_options_set_input_parameter_shape(
migraphx_tf_options_t tf_options, const char* name, size_t* dims, size_t dims_size)
{
return migraphx::try_([&] {
if(tf_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter tf_options: Null pointer");
if(dims == nullptr and dims_size != 0)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter dims: Null pointer");
migraphx::set_input_parameter_shape(
(tf_options->object), (name), (std::vector<size_t>(dims, dims + dims_size)));
});
}
extern "C" migraphx_status
migraphx_tf_options_set_default_dim_value(migraphx_tf_options_t tf_options, size_t value)
{
return migraphx::try_([&] {
if(tf_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter tf_options: Null pointer");
migraphx::set_default_dim_value((tf_options->object), (value));
});
}
extern "C" migraphx_status migraphx_tf_options_set_output_names(migraphx_tf_options_t tf_options,
const char** names,
size_t names_size)
{
return migraphx::try_([&] {
if(tf_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter tf_options: Null pointer");
if(names == nullptr and names_size != 0)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter names: Null pointer");
migraphx::set_output_names((tf_options->object),
(std::vector<const char*>(names, names + names_size)));
});
}
extern "C" migraphx_status
migraphx_parse_tf(migraphx_program_t* out, const char* name, migraphx_tf_options_t options)
{
return migraphx::try_([&] {
if(options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
*out = allocate<migraphx_program_t>(migraphx::parse_tf((name), (options->object)));
});
}
extern "C" migraphx_status extern "C" migraphx_status
migraphx_quantize_op_names_destroy(migraphx_quantize_op_names_t quantize_op_names) migraphx_quantize_op_names_destroy(migraphx_quantize_op_names_t quantize_op_names)
{ {
......
...@@ -93,6 +93,9 @@ typedef const struct migraphx_operation* const_migraphx_operation_t; ...@@ -93,6 +93,9 @@ typedef const struct migraphx_operation* const_migraphx_operation_t;
typedef struct migraphx_onnx_options* migraphx_onnx_options_t; typedef struct migraphx_onnx_options* migraphx_onnx_options_t;
typedef const struct migraphx_onnx_options* const_migraphx_onnx_options_t; typedef const struct migraphx_onnx_options* const_migraphx_onnx_options_t;
typedef struct migraphx_tf_options* migraphx_tf_options_t;
typedef const struct migraphx_tf_options* const_migraphx_tf_options_t;
typedef struct migraphx_quantize_op_names* migraphx_quantize_op_names_t; typedef struct migraphx_quantize_op_names* migraphx_quantize_op_names_t;
typedef const struct migraphx_quantize_op_names* const_migraphx_quantize_op_names_t; typedef const struct migraphx_quantize_op_names* const_migraphx_quantize_op_names_t;
...@@ -247,6 +250,27 @@ migraphx_status migraphx_parse_onnx_buffer(migraphx_program_t* out, ...@@ -247,6 +250,27 @@ migraphx_status migraphx_parse_onnx_buffer(migraphx_program_t* out,
size_t size, size_t size,
migraphx_onnx_options_t options); migraphx_onnx_options_t options);
migraphx_status migraphx_tf_options_destroy(migraphx_tf_options_t tf_options);
migraphx_status migraphx_tf_options_create(migraphx_tf_options_t* tf_options);
migraphx_status migraphx_tf_options_set_nhwc(migraphx_tf_options_t tf_options, bool is_nhwc);
migraphx_status migraphx_tf_options_set_input_parameter_shape(migraphx_tf_options_t tf_options,
const char* name,
size_t* dims,
size_t dims_size);
migraphx_status migraphx_tf_options_set_default_dim_value(migraphx_tf_options_t tf_options,
size_t value);
migraphx_status migraphx_tf_options_set_output_names(migraphx_tf_options_t tf_options,
const char** names,
size_t names_size);
migraphx_status
migraphx_parse_tf(migraphx_program_t* out, const char* name, migraphx_tf_options_t options);
migraphx_status migraphx_quantize_op_names_destroy(migraphx_quantize_op_names_t quantize_op_names); migraphx_status migraphx_quantize_op_names_destroy(migraphx_quantize_op_names_t quantize_op_names);
migraphx_status migraphx_quantize_op_names_create(migraphx_quantize_op_names_t* quantize_op_names); migraphx_status migraphx_quantize_op_names_create(migraphx_quantize_op_names_t* quantize_op_names);
......
...@@ -625,7 +625,7 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options) ...@@ -625,7 +625,7 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options)
dim.size()); dim.size());
} }
/// When there is a dimension parameter than use this default value /// When there is a dimension parameter, then use this default value
void set_default_dim_value(unsigned int value) void set_default_dim_value(unsigned int value)
{ {
call(&migraphx_onnx_options_set_default_dim_value, this->get_handle_ptr(), value); call(&migraphx_onnx_options_set_default_dim_value, this->get_handle_ptr(), value);
...@@ -684,6 +684,60 @@ inline program parse_onnx_buffer(const std::string& buffer) ...@@ -684,6 +684,60 @@ inline program parse_onnx_buffer(const std::string& buffer)
own{}); own{});
} }
/// Options for parsing tf options
struct tf_options : MIGRAPHX_HANDLE_BASE(tf_options)
{
tf_options() { this->make_handle(&migraphx_tf_options_create); }
tf_options(migraphx_tf_options* p, own) { this->set_handle(p, own{}); }
/// Make tf parser treat an inputs with a certain dimensions
void set_input_parameter_shape(const std::string& name, std::vector<std::size_t> dim)
{
call(&migraphx_tf_options_set_input_parameter_shape,
this->get_handle_ptr(),
name.c_str(),
dim.data(),
dim.size());
}
/// Change data layout to NHWC (default is NCHW)
void set_nhwc(bool is_nhwc = true)
{
call(&migraphx_tf_options_set_nhwc, this->get_handle_ptr(), is_nhwc);
}
/// When there is a dimension parameter, then use this default value
void set_default_dim_value(unsigned int value)
{
call(&migraphx_tf_options_set_default_dim_value, this->get_handle_ptr(), value);
}
/// Set output node names to return specific outputs from graph
void set_output_names(std::vector<const char*> names)
{
call(&migraphx_tf_options_set_output_names,
this->get_handle_ptr(),
names.data(),
names.size());
}
};
/// Parse a tf file into a migraphx program
inline program parse_tf(const char* filename, const migraphx::tf_options& options)
{
return program(make<migraphx_program>(&migraphx_parse_tf, filename, options.get_handle_ptr()),
own{});
}
/// Parse a tf file into a migraphx program
inline program parse_tf(const char* filename)
{
migraphx::tf_options options;
return program(make<migraphx_program>(&migraphx_parse_tf, filename, options.get_handle_ptr()),
own{});
}
struct quantize_op_names : MIGRAPHX_HANDLE_BASE(quantize_op_names) struct quantize_op_names : MIGRAPHX_HANDLE_BASE(quantize_op_names)
{ {
quantize_op_names() { this->make_handle(&migraphx_quantize_op_names_create); } quantize_op_names() { this->make_handle(&migraphx_quantize_op_names_create); }
......
...@@ -55,6 +55,17 @@ def onnx_options_type_wrap(p): ...@@ -55,6 +55,17 @@ def onnx_options_type_wrap(p):
p.read = '${name} == nullptr ? migraphx::onnx_options{} : migraphx::to_onnx_options(*${name})' p.read = '${name} == nullptr ? migraphx::onnx_options{} : migraphx::to_onnx_options(*${name})'
@api.cwrap('migraphx::tf_options')
def tf_options_type_wrap(p):
if p.returns:
p.add_param('migraphx_tf_options *')
p.bad_param('${name} == nullptr', 'Null pointer')
p.write = ['*${name} = migraphx::to_tf_options(${result})']
else:
p.add_param('migraphx_tf_options *')
p.read = '${name} == nullptr ? migraphx::tf_options{} : migraphx::to_tf_options(*${name})'
def auto_handle(*args, **kwargs): def auto_handle(*args, **kwargs):
def with_handle(f): def with_handle(f):
return api.handle('migraphx_' + f.__name__, 'migraphx::' + f.__name__, return api.handle('migraphx_' + f.__name__, 'migraphx::' + f.__name__,
...@@ -248,6 +259,38 @@ api.add_function('migraphx_parse_onnx_buffer', ...@@ -248,6 +259,38 @@ api.add_function('migraphx_parse_onnx_buffer',
returns='migraphx::program') returns='migraphx::program')
@auto_handle()
def tf_options(h):
h.constructor('create')
h.method(
'set_nhwc',
api.params(is_nhwc='bool'),
invoke='migraphx::set_nhwc($@)',
)
h.method(
'set_input_parameter_shape',
api.params(name='const char*', dims='std::vector<size_t>'),
invoke='migraphx::set_input_parameter_shape($@)',
)
h.method(
'set_default_dim_value',
api.params(value='size_t'),
invoke='migraphx::set_default_dim_value($@)',
)
h.method(
'set_output_names',
api.params(names='std::vector<const char*>'),
invoke='migraphx::set_output_names($@)',
)
api.add_function('migraphx_parse_tf',
api.params(name='const char*',
options='migraphx::tf_options'),
fname='migraphx::parse_tf',
returns='migraphx::program')
@api.handle('migraphx_quantize_op_names', 'std::vector<std::string>') @api.handle('migraphx_quantize_op_names', 'std::vector<std::string>')
def quantize_op_names(h): def quantize_op_names(h):
h.constructor('create') h.constructor('create')
......
...@@ -649,7 +649,6 @@ inline auto has_attribute(const std::string& name) ...@@ -649,7 +649,6 @@ inline auto has_attribute(const std::string& name)
[=](instruction_ref ins) { return ins->get_operator().attributes().contains(name); }); [=](instruction_ref ins) { return ins->get_operator().attributes().contains(name); });
} }
} // namespace match } // namespace match
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -124,17 +124,18 @@ foreach(ONNX_TEST ${ONNX_TESTS}) ...@@ -124,17 +124,18 @@ foreach(ONNX_TEST ${ONNX_TESTS})
rocm_clang_tidy_check(${TEST_NAME}) rocm_clang_tidy_check(${TEST_NAME})
target_link_libraries(${TEST_NAME} migraphx_onnx migraphx_ref) target_link_libraries(${TEST_NAME} migraphx_onnx migraphx_ref)
target_include_directories(${TEST_NAME} PUBLIC include) target_include_directories(${TEST_NAME} PUBLIC include)
add_test(NAME ${TEST_NAME} COMMAND $<TARGET_FILE:${TEST_NAME}> WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/onnx) add_test(NAME ${TEST_NAME} COMMAND $<TARGET_FILE:${TEST_NAME}> WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_dependencies(tests ${TEST_NAME}) add_dependencies(tests ${TEST_NAME})
add_dependencies(check ${TEST_NAME}) add_dependencies(check ${TEST_NAME})
endforeach() endforeach()
# tf test # tf test
set(TEST_TF_DIR ${CMAKE_CURRENT_SOURCE_DIR}/tf)
add_executable(test_tf tf/tf_test.cpp) add_executable(test_tf tf/tf_test.cpp)
rocm_clang_tidy_check(test_tf) rocm_clang_tidy_check(test_tf)
target_link_libraries(test_tf migraphx_tf migraphx_ref) target_link_libraries(test_tf migraphx_tf migraphx_ref)
target_include_directories(test_tf PUBLIC include) target_include_directories(test_tf PUBLIC include)
add_test(NAME test_tf COMMAND $<TARGET_FILE:test_tf> WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/tf) add_test(NAME test_tf COMMAND $<TARGET_FILE:test_tf> WORKING_DIRECTORY ${TEST_TF_DIR})
add_dependencies(tests test_tf) add_dependencies(tests test_tf)
add_dependencies(check test_tf) add_dependencies(check test_tf)
......
function(add_api_test TEST_NAME TEST_SRC) function(add_api_test TEST_NAME TEST_SRC TEST_DIR)
set(NAME test_api_${TEST_NAME}) set(NAME test_api_${TEST_NAME})
add_executable(${NAME} EXCLUDE_FROM_ALL ${TEST_SRC}) add_executable(${NAME} EXCLUDE_FROM_ALL ${TEST_SRC})
rocm_clang_tidy_check(${NAME}) rocm_clang_tidy_check(${NAME})
target_link_libraries(${NAME} migraphx_c) target_link_libraries(${NAME} migraphx_c)
target_include_directories(${NAME} PUBLIC ../include) target_include_directories(${NAME} PUBLIC ../include)
add_test(NAME ${NAME} COMMAND $<TARGET_FILE:${NAME}> WORKING_DIRECTORY ${TEST_ONNX_DIR}) add_test(NAME ${NAME} COMMAND $<TARGET_FILE:${NAME}> WORKING_DIRECTORY ${TEST_DIR})
add_dependencies(tests ${NAME}) add_dependencies(tests ${NAME})
add_dependencies(check ${NAME}) add_dependencies(check ${NAME})
endfunction() endfunction()
add_api_test(ref test_cpu.cpp) add_api_test(ref test_cpu.cpp ${TEST_ONNX_DIR})
add_api_test(save_load test_save_load.cpp) add_api_test(save_load test_save_load.cpp ${TEST_ONNX_DIR})
add_api_test(op test_op_construct.cpp) add_api_test(op test_op_construct.cpp ${TEST_ONNX_DIR})
add_api_test(tf_parser test_tf_parser.cpp ${TEST_TF_DIR})
if(MIGRAPHX_ENABLE_GPU) if(MIGRAPHX_ENABLE_GPU)
add_api_test(gpu test_gpu.cpp) add_api_test(gpu test_gpu.cpp ${TEST_ONNX_DIR})
# GPU-based tests # GPU-based tests
endif() endif()
#include <migraphx/migraphx.h>
#include <migraphx/migraphx.hpp>
#include "test.hpp"
TEST_CASE(load_tf)
{
auto p = migraphx::parse_tf("add_test.pb");
auto shapes = p.get_output_shapes();
CHECK(shapes.size() == 1);
}
TEST_CASE(load_tf_default_dim)
{
migraphx::tf_options tf_options;
size_t batch = 2;
tf_options.set_default_dim_value(batch);
tf_options.set_nhwc();
auto p = migraphx::parse_tf("conv_batch_test.pb", tf_options);
auto shapes = p.get_output_shapes();
CHECK(shapes.size() == 1);
CHECK(shapes.front().lengths().front() == batch);
}
TEST_CASE(load_tf_param_shape)
{
migraphx::tf_options tf_options;
std::vector<size_t> new_shape{1, 3};
tf_options.set_input_parameter_shape("0", new_shape);
tf_options.set_input_parameter_shape("1", new_shape);
auto p = migraphx::parse_tf("add_test.pb", tf_options);
auto shapes = p.get_output_shapes();
CHECK(shapes.size() == 1);
CHECK(shapes.front().lengths() == new_shape);
}
TEST_CASE(load_tf_multi_outputs)
{
migraphx::tf_options tf_options;
tf_options.set_output_names({"relu", "tanh"});
auto p = migraphx::parse_tf("multi_output_test.pb", tf_options);
auto shapes = p.get_output_shapes();
CHECK(shapes.size() == 2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -1650,9 +1650,9 @@ TEST_CASE(lessorequal_test) ...@@ -1650,9 +1650,9 @@ TEST_CASE(lessorequal_test)
auto input1 = mm->add_parameter("x1", migraphx::shape{migraphx::shape::float_type, {3}}); auto input1 = mm->add_parameter("x1", migraphx::shape{migraphx::shape::float_type, {3}});
auto input2 = mm->add_parameter("x2", migraphx::shape{migraphx::shape::float_type, {3}}); auto input2 = mm->add_parameter("x2", migraphx::shape{migraphx::shape::float_type, {3}});
auto temp = mm->add_instruction(migraphx::make_op("greater"), input1, input2); auto temp = mm->add_instruction(migraphx::make_op("greater"), input1, input2);
auto le = mm->add_instruction(migraphx::make_op("not"), temp); auto le = mm->add_instruction(migraphx::make_op("not"), temp);
mm->add_return({le}); mm->add_return({le});
auto prog = migraphx::parse_onnx("lessorequal_test.onnx"); auto prog = migraphx::parse_onnx("lessorequal_test.onnx");
......
...@@ -152,18 +152,18 @@ TEST_CASE(lessorequal_test) ...@@ -152,18 +152,18 @@ TEST_CASE(lessorequal_test)
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
migraphx::shape s{migraphx::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
std::vector<float> data1 = { 0.25, 0.75, 0.9375}; std::vector<float> data1 = {0.25, 0.75, 0.9375};
std::vector<float> data2 = { 0.25, 0.74, 0.9411}; std::vector<float> data2 = {0.25, 0.74, 0.9411};
migraphx::parameter_map pp; migraphx::parameter_map pp;
pp["x1"] = migraphx::argument(s, data1.data()); pp["x1"] = migraphx::argument(s, data1.data());
pp["x2"] = migraphx::argument(s, data2.data()); pp["x2"] = migraphx::argument(s, data2.data());
auto result = p.eval(pp).back(); auto result = p.eval(pp).back();
std::vector<float> result_vector; std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1,0,1}; std::vector<float> gold = {1, 0, 1};
EXPECT(migraphx::verify_range(result_vector, gold)); EXPECT(migraphx::verify_range(result_vector, gold));
} }
......
...@@ -926,14 +926,14 @@ TEST_CASE(simplify_split_reduce0) ...@@ -926,14 +926,14 @@ TEST_CASE(simplify_split_reduce0)
auto input = m1.add_parameter("input", s); auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction( auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
auto y = m1.add_instruction( auto y = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input);
auto one = m1.add_literal(1); auto one = m1.add_literal(1);
auto two = m1.add_literal(2); auto two = m1.add_literal(2);
auto arx = m1.add_instruction(migraphx::make_op("contiguous"), x); auto arx = m1.add_instruction(migraphx::make_op("contiguous"), x);
auto ary = m1.add_instruction(migraphx::make_op("contiguous"), y); auto ary = m1.add_instruction(migraphx::make_op("contiguous"), y);
auto rmax0 = m1.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0, 1}}}), x); auto rmax0 = m1.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0, 1}}}), x);
auto rmin0 = m1.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {0, 1}}}), x); auto rmin0 = m1.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {0, 1}}}), x);
auto rmax1 = m1.add_instruction(migraphx::make_op("gather", {{"axis", 1}}), arx, one); auto rmax1 = m1.add_instruction(migraphx::make_op("gather", {{"axis", 1}}), arx, one);
...@@ -970,15 +970,15 @@ TEST_CASE(simplify_split_reduce1) ...@@ -970,15 +970,15 @@ TEST_CASE(simplify_split_reduce1)
{ {
auto input = m2.add_parameter("input", s); auto input = m2.add_parameter("input", s);
auto rmn = m2.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {0, 2}}}), input); auto rmn = m2.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {0, 2}}}), input);
auto slc0 = m2.add_instruction( auto slc0 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), rmn); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), rmn);
auto rmx = m2.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0, 2}}}), input); auto rmx = m2.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0, 2}}}), input);
auto slc1 = m2.add_instruction( auto slc1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), rmx); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), rmx);
auto slc2 = m2.add_instruction( auto slc2 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), rmn); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), rmn);
auto slc3 = m2.add_instruction( auto slc3 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), rmx); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), rmx);
m2.add_return({slc3, slc2, slc1, slc0}); m2.add_return({slc3, slc2, slc1, slc0});
} }
...@@ -1010,13 +1010,13 @@ TEST_CASE(simplify_split_reduce2) ...@@ -1010,13 +1010,13 @@ TEST_CASE(simplify_split_reduce2)
auto x = m2.add_instruction( auto x = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input);
auto rmn1 = m2.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {0, 1}}}), x); auto rmn1 = m2.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {0, 1}}}), x);
auto y = m2.add_instruction( auto y = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
auto rmn2 = m2.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {0, 1}}}), y); auto rmn2 = m2.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {0, 1}}}), y);
auto rms = m2.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0, 2}}}), input); auto rms = m2.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0, 2}}}), input);
auto slc0 = m2.add_instruction( auto slc0 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), rms); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), rms);
auto slc1 = m2.add_instruction( auto slc1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), rms); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), rms);
m2.add_return({slc1, rmn2, slc0, rmn1}); m2.add_return({slc1, rmn2, slc0, rmn1});
} }
......
...@@ -213,6 +213,19 @@ def conv_add_test(g1): ...@@ -213,6 +213,19 @@ def conv_add_test(g1):
tf.add(conv, conv, name='add1') tf.add(conv, conv, name='add1')
@tf_test
def conv_batch_test(g1):
with g1.as_default():
g1_input = tf.compat.v1.placeholder(tf.float32,
shape=(None, 16, 16, 3),
name='0')
g1_weights = tf.constant(value=1.0,
dtype=tf.float32,
shape=(3, 3, 3, 32),
name='1')
tf.nn.conv2d(g1_input, g1_weights, [1, 1, 1, 1], "SAME", name='conv1')
@tf_test @tf_test
def conv_nchw_test(g1): def conv_nchw_test(g1):
with g1.as_default(): with g1.as_default():
...@@ -643,6 +656,7 @@ if __name__ == '__main__': ...@@ -643,6 +656,7 @@ if __name__ == '__main__':
const_test() const_test()
conv_test() conv_test()
conv_add_test() conv_add_test()
conv_batch_test()
conv_nchw_test() conv_nchw_test()
conv_relu_test() conv_relu_test()
conv_relu6_test() conv_relu6_test()
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
#include <migraphx/tf.hpp>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/quantization.hpp> #include <migraphx/quantization.hpp>
...@@ -90,6 +91,10 @@ void set_default_dim_value(onnx_options& options, size_t value) ...@@ -90,6 +91,10 @@ void set_default_dim_value(onnx_options& options, size_t value)
options.default_dim_value = value; options.default_dim_value = value;
} }
void set_nhwc(tf_options& options, bool is_nhwc) { options.is_nhwc = is_nhwc; }
void set_default_dim_value(tf_options& options, size_t value) { options.batch_size = value; }
void set_input_parameter_shape(onnx_options& options, void set_input_parameter_shape(onnx_options& options,
const char* name, const char* name,
std::vector<std::size_t> dims) std::vector<std::size_t> dims)
...@@ -97,6 +102,16 @@ void set_input_parameter_shape(onnx_options& options, ...@@ -97,6 +102,16 @@ void set_input_parameter_shape(onnx_options& options,
options.map_input_dims[std::string(name)] = std::move(dims); options.map_input_dims[std::string(name)] = std::move(dims);
} }
void set_input_parameter_shape(tf_options& options, const char* name, std::vector<std::size_t> dims)
{
options.map_input_dims[std::string(name)] = std::move(dims);
}
void set_output_names(tf_options& options, std::vector<const char*> names)
{
options.output_node_names = std::vector<std::string>(names.begin(), names.end());
}
template <class Value> template <class Value>
std::vector<const char*> get_names(const std::unordered_map<std::string, Value>& m) std::vector<const char*> get_names(const std::unordered_map<std::string, Value>& m)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment