Commit 687c87d5 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

merge changes from develop branch and fixed a bug in checking broadcasted shape

parents 6ab58fbf 4480eb79
...@@ -383,6 +383,13 @@ extern "C" migraphx_status migraphx_shape_destroy(migraphx_shape_t shape) ...@@ -383,6 +383,13 @@ extern "C" migraphx_status migraphx_shape_destroy(migraphx_shape_t shape)
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_shape_assign_to(migraphx_shape_t output,
const_migraphx_shape_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status migraphx_shape_create(migraphx_shape_t* shape, extern "C" migraphx_status migraphx_shape_create(migraphx_shape_t* shape,
migraphx_shape_datatype_t type, migraphx_shape_datatype_t type,
int* lengths, int* lengths,
...@@ -499,6 +506,13 @@ extern "C" migraphx_status migraphx_argument_destroy(migraphx_argument_t argumen ...@@ -499,6 +506,13 @@ extern "C" migraphx_status migraphx_argument_destroy(migraphx_argument_t argumen
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_argument_assign_to(migraphx_argument_t output,
const_migraphx_argument_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status extern "C" migraphx_status
migraphx_argument_create(migraphx_argument_t* argument, const_migraphx_shape_t shape, void* buffer) migraphx_argument_create(migraphx_argument_t* argument, const_migraphx_shape_t shape, void* buffer)
{ {
...@@ -562,6 +576,13 @@ extern "C" migraphx_status migraphx_target_destroy(migraphx_target_t target) ...@@ -562,6 +576,13 @@ extern "C" migraphx_status migraphx_target_destroy(migraphx_target_t target)
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_target_assign_to(migraphx_target_t output,
const_migraphx_target_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status migraphx_target_create(migraphx_target_t* target, const char* name) extern "C" migraphx_status migraphx_target_create(migraphx_target_t* target, const char* name)
{ {
auto api_error_result = migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
...@@ -578,6 +599,14 @@ extern "C" migraphx_status migraphx_program_parameter_shapes_destroy( ...@@ -578,6 +599,14 @@ extern "C" migraphx_status migraphx_program_parameter_shapes_destroy(
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status
migraphx_program_parameter_shapes_assign_to(migraphx_program_parameter_shapes_t output,
const_migraphx_program_parameter_shapes_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status extern "C" migraphx_status
migraphx_program_parameter_shapes_size(size_t* out, migraphx_program_parameter_shapes_size(size_t* out,
migraphx_program_parameter_shapes_t program_parameter_shapes) migraphx_program_parameter_shapes_t program_parameter_shapes)
...@@ -628,6 +657,14 @@ migraphx_program_parameters_destroy(migraphx_program_parameters_t program_parame ...@@ -628,6 +657,14 @@ migraphx_program_parameters_destroy(migraphx_program_parameters_t program_parame
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status
migraphx_program_parameters_assign_to(migraphx_program_parameters_t output,
const_migraphx_program_parameters_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status extern "C" migraphx_status
migraphx_program_parameters_create(migraphx_program_parameters_t* program_parameters) migraphx_program_parameters_create(migraphx_program_parameters_t* program_parameters)
{ {
...@@ -660,6 +697,13 @@ extern "C" migraphx_status migraphx_arguments_destroy(migraphx_arguments_t argum ...@@ -660,6 +697,13 @@ extern "C" migraphx_status migraphx_arguments_destroy(migraphx_arguments_t argum
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_arguments_assign_to(migraphx_arguments_t output,
const_migraphx_arguments_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status migraphx_arguments_size(size_t* out, migraphx_arguments_t arguments) extern "C" migraphx_status migraphx_arguments_size(size_t* out, migraphx_arguments_t arguments)
{ {
auto api_error_result = migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
...@@ -687,6 +731,13 @@ extern "C" migraphx_status migraphx_shapes_destroy(migraphx_shapes_t shapes) ...@@ -687,6 +731,13 @@ extern "C" migraphx_status migraphx_shapes_destroy(migraphx_shapes_t shapes)
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_shapes_assign_to(migraphx_shapes_t output,
const_migraphx_shapes_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status migraphx_shapes_size(size_t* out, migraphx_shapes_t shapes) extern "C" migraphx_status migraphx_shapes_size(size_t* out, migraphx_shapes_t shapes)
{ {
auto api_error_result = migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
...@@ -724,6 +775,13 @@ extern "C" migraphx_status migraphx_program_destroy(migraphx_program_t program) ...@@ -724,6 +775,13 @@ extern "C" migraphx_status migraphx_program_destroy(migraphx_program_t program)
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_program_assign_to(migraphx_program_t output,
const_migraphx_program_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status migraphx_program_get_main_module(migraphx_module_t* out, extern "C" migraphx_status migraphx_program_get_main_module(migraphx_module_t* out,
migraphx_program_t program) migraphx_program_t program)
{ {
...@@ -828,6 +886,13 @@ extern "C" migraphx_status migraphx_operation_destroy(migraphx_operation_t opera ...@@ -828,6 +886,13 @@ extern "C" migraphx_status migraphx_operation_destroy(migraphx_operation_t opera
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_operation_assign_to(migraphx_operation_t output,
const_migraphx_operation_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status migraphx_operation_create(migraphx_operation_t* operation, extern "C" migraphx_status migraphx_operation_create(migraphx_operation_t* operation,
const char* name, const char* name,
const char* attributes, const char* attributes,
...@@ -889,6 +954,13 @@ extern "C" migraphx_status migraphx_onnx_options_destroy(migraphx_onnx_options_t ...@@ -889,6 +954,13 @@ extern "C" migraphx_status migraphx_onnx_options_destroy(migraphx_onnx_options_t
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_onnx_options_assign_to(migraphx_onnx_options_t output,
const_migraphx_onnx_options_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status migraphx_onnx_options_create(migraphx_onnx_options_t* onnx_options) extern "C" migraphx_status migraphx_onnx_options_create(migraphx_onnx_options_t* onnx_options)
{ {
auto api_error_result = migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
...@@ -940,6 +1012,13 @@ extern "C" migraphx_status migraphx_file_options_destroy(migraphx_file_options_t ...@@ -940,6 +1012,13 @@ extern "C" migraphx_status migraphx_file_options_destroy(migraphx_file_options_t
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_file_options_assign_to(migraphx_file_options_t output,
const_migraphx_file_options_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status migraphx_file_options_create(migraphx_file_options_t* file_options) extern "C" migraphx_status migraphx_file_options_create(migraphx_file_options_t* file_options)
{ {
auto api_error_result = migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
...@@ -966,6 +1045,14 @@ migraphx_compile_options_destroy(migraphx_compile_options_t compile_options) ...@@ -966,6 +1045,14 @@ migraphx_compile_options_destroy(migraphx_compile_options_t compile_options)
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status
migraphx_compile_options_assign_to(migraphx_compile_options_t output,
const_migraphx_compile_options_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status extern "C" migraphx_status
migraphx_compile_options_create(migraphx_compile_options_t* compile_options) migraphx_compile_options_create(migraphx_compile_options_t* compile_options)
{ {
...@@ -1031,6 +1118,13 @@ extern "C" migraphx_status migraphx_tf_options_destroy(migraphx_tf_options_t tf_ ...@@ -1031,6 +1118,13 @@ extern "C" migraphx_status migraphx_tf_options_destroy(migraphx_tf_options_t tf_
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_tf_options_assign_to(migraphx_tf_options_t output,
const_migraphx_tf_options_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status migraphx_tf_options_create(migraphx_tf_options_t* tf_options) extern "C" migraphx_status migraphx_tf_options_create(migraphx_tf_options_t* tf_options)
{ {
auto api_error_result = migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
...@@ -1108,6 +1202,14 @@ migraphx_quantize_op_names_destroy(migraphx_quantize_op_names_t quantize_op_name ...@@ -1108,6 +1202,14 @@ migraphx_quantize_op_names_destroy(migraphx_quantize_op_names_t quantize_op_name
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status
migraphx_quantize_op_names_assign_to(migraphx_quantize_op_names_t output,
const_migraphx_quantize_op_names_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status extern "C" migraphx_status
migraphx_quantize_op_names_create(migraphx_quantize_op_names_t* quantize_op_names) migraphx_quantize_op_names_create(migraphx_quantize_op_names_t* quantize_op_names)
{ {
...@@ -1160,6 +1262,14 @@ migraphx_quantize_int8_options_destroy(migraphx_quantize_int8_options_t quantize ...@@ -1160,6 +1262,14 @@ migraphx_quantize_int8_options_destroy(migraphx_quantize_int8_options_t quantize
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status
migraphx_quantize_int8_options_assign_to(migraphx_quantize_int8_options_t output,
const_migraphx_quantize_int8_options_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status extern "C" migraphx_status
migraphx_quantize_int8_options_create(migraphx_quantize_int8_options_t* quantize_int8_options) migraphx_quantize_int8_options_create(migraphx_quantize_int8_options_t* quantize_int8_options)
{ {
......
...@@ -91,6 +91,8 @@ typedef const struct migraphx_quantize_int8_options* const_migraphx_quantize_int ...@@ -91,6 +91,8 @@ typedef const struct migraphx_quantize_int8_options* const_migraphx_quantize_int
migraphx_status migraphx_shape_destroy(migraphx_shape_t shape); migraphx_status migraphx_shape_destroy(migraphx_shape_t shape);
migraphx_status migraphx_shape_assign_to(migraphx_shape_t output, const_migraphx_shape_t input);
migraphx_status migraphx_shape_create(migraphx_shape_t* shape, migraphx_status migraphx_shape_create(migraphx_shape_t* shape,
migraphx_shape_datatype_t type, migraphx_shape_datatype_t type,
int* lengths, int* lengths,
...@@ -121,6 +123,9 @@ migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_sha ...@@ -121,6 +123,9 @@ migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_sha
migraphx_status migraphx_argument_destroy(migraphx_argument_t argument); migraphx_status migraphx_argument_destroy(migraphx_argument_t argument);
migraphx_status migraphx_argument_assign_to(migraphx_argument_t output,
const_migraphx_argument_t input);
migraphx_status migraphx_status
migraphx_argument_create(migraphx_argument_t* argument, const_migraphx_shape_t shape, void* buffer); migraphx_argument_create(migraphx_argument_t* argument, const_migraphx_shape_t shape, void* buffer);
...@@ -137,11 +142,17 @@ migraphx_argument_generate(migraphx_argument_t* out, const_migraphx_shape_t s, s ...@@ -137,11 +142,17 @@ migraphx_argument_generate(migraphx_argument_t* out, const_migraphx_shape_t s, s
migraphx_status migraphx_target_destroy(migraphx_target_t target); migraphx_status migraphx_target_destroy(migraphx_target_t target);
migraphx_status migraphx_target_assign_to(migraphx_target_t output, const_migraphx_target_t input);
migraphx_status migraphx_target_create(migraphx_target_t* target, const char* name); migraphx_status migraphx_target_create(migraphx_target_t* target, const char* name);
migraphx_status migraphx_program_parameter_shapes_destroy( migraphx_status migraphx_program_parameter_shapes_destroy(
migraphx_program_parameter_shapes_t program_parameter_shapes); migraphx_program_parameter_shapes_t program_parameter_shapes);
migraphx_status
migraphx_program_parameter_shapes_assign_to(migraphx_program_parameter_shapes_t output,
const_migraphx_program_parameter_shapes_t input);
migraphx_status migraphx_program_parameter_shapes_size( migraphx_status migraphx_program_parameter_shapes_size(
size_t* out, migraphx_program_parameter_shapes_t program_parameter_shapes); size_t* out, migraphx_program_parameter_shapes_t program_parameter_shapes);
...@@ -156,6 +167,9 @@ migraphx_status migraphx_program_parameter_shapes_names( ...@@ -156,6 +167,9 @@ migraphx_status migraphx_program_parameter_shapes_names(
migraphx_status migraphx_status
migraphx_program_parameters_destroy(migraphx_program_parameters_t program_parameters); migraphx_program_parameters_destroy(migraphx_program_parameters_t program_parameters);
migraphx_status migraphx_program_parameters_assign_to(migraphx_program_parameters_t output,
const_migraphx_program_parameters_t input);
migraphx_status migraphx_status
migraphx_program_parameters_create(migraphx_program_parameters_t* program_parameters); migraphx_program_parameters_create(migraphx_program_parameters_t* program_parameters);
...@@ -165,6 +179,9 @@ migraphx_status migraphx_program_parameters_add(migraphx_program_parameters_t pr ...@@ -165,6 +179,9 @@ migraphx_status migraphx_program_parameters_add(migraphx_program_parameters_t pr
migraphx_status migraphx_arguments_destroy(migraphx_arguments_t arguments); migraphx_status migraphx_arguments_destroy(migraphx_arguments_t arguments);
migraphx_status migraphx_arguments_assign_to(migraphx_arguments_t output,
const_migraphx_arguments_t input);
migraphx_status migraphx_arguments_size(size_t* out, migraphx_arguments_t arguments); migraphx_status migraphx_arguments_size(size_t* out, migraphx_arguments_t arguments);
migraphx_status migraphx_status
...@@ -172,6 +189,8 @@ migraphx_arguments_get(const_migraphx_argument_t* out, migraphx_arguments_t argu ...@@ -172,6 +189,8 @@ migraphx_arguments_get(const_migraphx_argument_t* out, migraphx_arguments_t argu
migraphx_status migraphx_shapes_destroy(migraphx_shapes_t shapes); migraphx_status migraphx_shapes_destroy(migraphx_shapes_t shapes);
migraphx_status migraphx_shapes_assign_to(migraphx_shapes_t output, const_migraphx_shapes_t input);
migraphx_status migraphx_shapes_size(size_t* out, migraphx_shapes_t shapes); migraphx_status migraphx_shapes_size(size_t* out, migraphx_shapes_t shapes);
migraphx_status migraphx_status
...@@ -181,6 +200,9 @@ migraphx_status migraphx_module_print(const_migraphx_module_t module); ...@@ -181,6 +200,9 @@ migraphx_status migraphx_module_print(const_migraphx_module_t module);
migraphx_status migraphx_program_destroy(migraphx_program_t program); migraphx_status migraphx_program_destroy(migraphx_program_t program);
migraphx_status migraphx_program_assign_to(migraphx_program_t output,
const_migraphx_program_t input);
migraphx_status migraphx_program_get_main_module(migraphx_module_t* out, migraphx_status migraphx_program_get_main_module(migraphx_module_t* out,
migraphx_program_t program); migraphx_program_t program);
...@@ -207,6 +229,9 @@ migraphx_program_equal(bool* out, const_migraphx_program_t program, const_migrap ...@@ -207,6 +229,9 @@ migraphx_program_equal(bool* out, const_migraphx_program_t program, const_migrap
migraphx_status migraphx_operation_destroy(migraphx_operation_t operation); migraphx_status migraphx_operation_destroy(migraphx_operation_t operation);
migraphx_status migraphx_operation_assign_to(migraphx_operation_t output,
const_migraphx_operation_t input);
migraphx_status migraphx_operation_create(migraphx_operation_t* operation, migraphx_status migraphx_operation_create(migraphx_operation_t* operation,
const char* name, const char* name,
const char* attributes, const char* attributes,
...@@ -222,6 +247,9 @@ migraphx_save(migraphx_program_t p, const char* name, migraphx_file_options_t op ...@@ -222,6 +247,9 @@ migraphx_save(migraphx_program_t p, const char* name, migraphx_file_options_t op
migraphx_status migraphx_onnx_options_destroy(migraphx_onnx_options_t onnx_options); migraphx_status migraphx_onnx_options_destroy(migraphx_onnx_options_t onnx_options);
migraphx_status migraphx_onnx_options_assign_to(migraphx_onnx_options_t output,
const_migraphx_onnx_options_t input);
migraphx_status migraphx_onnx_options_create(migraphx_onnx_options_t* onnx_options); migraphx_status migraphx_onnx_options_create(migraphx_onnx_options_t* onnx_options);
migraphx_status migraphx_onnx_options_set_input_parameter_shape( migraphx_status migraphx_onnx_options_set_input_parameter_shape(
...@@ -236,6 +264,9 @@ migraphx_onnx_options_set_default_loop_iterations(migraphx_onnx_options_t onnx_o ...@@ -236,6 +264,9 @@ migraphx_onnx_options_set_default_loop_iterations(migraphx_onnx_options_t onnx_o
migraphx_status migraphx_file_options_destroy(migraphx_file_options_t file_options); migraphx_status migraphx_file_options_destroy(migraphx_file_options_t file_options);
migraphx_status migraphx_file_options_assign_to(migraphx_file_options_t output,
const_migraphx_file_options_t input);
migraphx_status migraphx_file_options_create(migraphx_file_options_t* file_options); migraphx_status migraphx_file_options_create(migraphx_file_options_t* file_options);
migraphx_status migraphx_file_options_set_file_format(migraphx_file_options_t file_options, migraphx_status migraphx_file_options_set_file_format(migraphx_file_options_t file_options,
...@@ -243,6 +274,9 @@ migraphx_status migraphx_file_options_set_file_format(migraphx_file_options_t fi ...@@ -243,6 +274,9 @@ migraphx_status migraphx_file_options_set_file_format(migraphx_file_options_t fi
migraphx_status migraphx_compile_options_destroy(migraphx_compile_options_t compile_options); migraphx_status migraphx_compile_options_destroy(migraphx_compile_options_t compile_options);
migraphx_status migraphx_compile_options_assign_to(migraphx_compile_options_t output,
const_migraphx_compile_options_t input);
migraphx_status migraphx_compile_options_create(migraphx_compile_options_t* compile_options); migraphx_status migraphx_compile_options_create(migraphx_compile_options_t* compile_options);
migraphx_status migraphx_status
...@@ -261,6 +295,9 @@ migraphx_status migraphx_parse_onnx_buffer(migraphx_program_t* out, ...@@ -261,6 +295,9 @@ migraphx_status migraphx_parse_onnx_buffer(migraphx_program_t* out,
migraphx_status migraphx_tf_options_destroy(migraphx_tf_options_t tf_options); migraphx_status migraphx_tf_options_destroy(migraphx_tf_options_t tf_options);
migraphx_status migraphx_tf_options_assign_to(migraphx_tf_options_t output,
const_migraphx_tf_options_t input);
migraphx_status migraphx_tf_options_create(migraphx_tf_options_t* tf_options); migraphx_status migraphx_tf_options_create(migraphx_tf_options_t* tf_options);
migraphx_status migraphx_tf_options_set_nhwc(migraphx_tf_options_t tf_options, bool is_nhwc); migraphx_status migraphx_tf_options_set_nhwc(migraphx_tf_options_t tf_options, bool is_nhwc);
...@@ -282,6 +319,9 @@ migraphx_parse_tf(migraphx_program_t* out, const char* name, migraphx_tf_options ...@@ -282,6 +319,9 @@ migraphx_parse_tf(migraphx_program_t* out, const char* name, migraphx_tf_options
migraphx_status migraphx_quantize_op_names_destroy(migraphx_quantize_op_names_t quantize_op_names); migraphx_status migraphx_quantize_op_names_destroy(migraphx_quantize_op_names_t quantize_op_names);
migraphx_status migraphx_quantize_op_names_assign_to(migraphx_quantize_op_names_t output,
const_migraphx_quantize_op_names_t input);
migraphx_status migraphx_quantize_op_names_create(migraphx_quantize_op_names_t* quantize_op_names); migraphx_status migraphx_quantize_op_names_create(migraphx_quantize_op_names_t* quantize_op_names);
migraphx_status migraphx_quantize_op_names_add(migraphx_quantize_op_names_t quantize_op_names, migraphx_status migraphx_quantize_op_names_add(migraphx_quantize_op_names_t quantize_op_names,
...@@ -295,6 +335,10 @@ migraphx_status migraphx_quantize_fp16(migraphx_program_t prog); ...@@ -295,6 +335,10 @@ migraphx_status migraphx_quantize_fp16(migraphx_program_t prog);
migraphx_status migraphx_status
migraphx_quantize_int8_options_destroy(migraphx_quantize_int8_options_t quantize_int8_options); migraphx_quantize_int8_options_destroy(migraphx_quantize_int8_options_t quantize_int8_options);
migraphx_status
migraphx_quantize_int8_options_assign_to(migraphx_quantize_int8_options_t output,
const_migraphx_quantize_int8_options_t input);
migraphx_status migraphx_status
migraphx_quantize_int8_options_create(migraphx_quantize_int8_options_t* quantize_int8_options); migraphx_quantize_int8_options_create(migraphx_quantize_int8_options_t* quantize_int8_options);
......
...@@ -159,7 +159,7 @@ struct borrow ...@@ -159,7 +159,7 @@ struct borrow
{ {
}; };
template <class T, class D, D Deleter> template <class T, class D, D Deleter, class A, A Assigner>
struct handle_base struct handle_base
{ {
handle_base() : m_handle(nullptr) {} handle_base() : m_handle(nullptr) {}
...@@ -190,6 +190,12 @@ struct handle_base ...@@ -190,6 +190,12 @@ struct handle_base
m_handle = std::shared_ptr<U>{ptr, [](U*) {}}; m_handle = std::shared_ptr<U>{ptr, [](U*) {}};
} }
template <class U>
void assign_to_handle(U* x)
{
Assigner(x, this->get_handle_ptr());
}
protected: protected:
std::shared_ptr<T> m_handle; std::shared_ptr<T> m_handle;
}; };
...@@ -197,10 +203,12 @@ struct handle_base ...@@ -197,10 +203,12 @@ struct handle_base
#ifdef DOXYGEN #ifdef DOXYGEN
#define MIGRAPHX_DETAIL_HANDLE_BASE(name, const_) handle_base<> #define MIGRAPHX_DETAIL_HANDLE_BASE(name, const_) handle_base<>
#else #else
#define MIGRAPHX_DETAIL_HANDLE_BASE(name, const_) \ #define MIGRAPHX_DETAIL_HANDLE_BASE(name, const_) \
handle_base<const_ migraphx_##name, \ handle_base<const_ migraphx_##name, \
decltype(&migraphx_##name##_destroy), \ decltype(&migraphx_##name##_destroy), \
migraphx_##name##_destroy> migraphx_##name##_destroy, \
decltype(&migraphx_##name##_assign_to), \
migraphx_##name##_assign_to>
#endif #endif
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_HANDLE_BASE(name) MIGRAPHX_DETAIL_HANDLE_BASE(name, ) #define MIGRAPHX_HANDLE_BASE(name) MIGRAPHX_DETAIL_HANDLE_BASE(name, )
......
...@@ -37,11 +37,11 @@ struct unsqueeze ...@@ -37,11 +37,11 @@ struct unsqueeze
std::string name() const { return "unsqueeze"; } std::string name() const { return "unsqueeze"; }
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1).standard_or_scalar(); check_shapes{inputs, *this}.has(1);
auto input_shape = inputs[0]; auto input_shape = inputs[0];
auto type = input_shape.type(); auto type = input_shape.type();
auto old_lens = input_shape.lens(); auto old_lens = input_shape.lens();
auto old_strides = input_shape.strides();
if(input_shape.scalar()) if(input_shape.scalar())
{ {
if(old_lens.size() == 1 and old_lens.front() == 1) if(old_lens.size() == 1 and old_lens.front() == 1)
...@@ -53,19 +53,29 @@ struct unsqueeze ...@@ -53,19 +53,29 @@ struct unsqueeze
int new_size = old_lens.size() + axes.size(); int new_size = old_lens.size() + axes.size();
std::vector<int> new_lens(new_size); std::vector<int> new_lens(new_size);
int p = 0; std::vector<int> new_strides(new_size);
for(int i = 0; i < new_size; i++) std::size_t p = 0;
for(auto i : range(new_size))
{ {
if(std::find(axes.begin(), axes.end(), i) != axes.end()) if(std::find(axes.begin(), axes.end(), i) != axes.end())
{ {
new_lens[i] = 1; new_lens[i] = 1;
if(p == 0) // unsqueeze on the first axes
{
new_strides[i] = old_lens[0] * old_strides[0];
}
else // unsqueeze on middle or last axes
{
new_strides[i] = (p < old_strides.size()) ? old_strides[p - 1] : 1;
}
} }
else else
{ {
new_lens[i] = old_lens[p++]; new_lens[i] = old_lens[p];
new_strides[i] = old_strides[p++];
} }
} }
return shape{type, new_lens}; return shape{type, new_lens, new_strides};
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
......
...@@ -246,9 +246,9 @@ bool shape::transposed() const ...@@ -246,9 +246,9 @@ bool shape::transposed() const
bool shape::broadcasted() const bool shape::broadcasted() const
{ {
assert(this->lens().size() == this->strides().size()); const auto& strides = this->strides();
return std::accumulate( assert(this->lens().size() == strides.size());
this->strides().begin(), this->strides().end(), int{1}, std::multiplies<int>()) == 0; return std::any_of(strides.begin(), strides.end(), [](auto i) { return i == 0; });
} }
bool shape::scalar() const bool shape::scalar() const
......
...@@ -10,6 +10,7 @@ function(add_api_test TEST_NAME TEST_SRC TEST_DIR) ...@@ -10,6 +10,7 @@ function(add_api_test TEST_NAME TEST_SRC TEST_DIR)
add_dependencies(check ${NAME}) add_dependencies(check ${NAME})
endfunction() endfunction()
add_api_test(assign test_assign.cpp ${TEST_ONNX_DIR})
add_api_test(compile_options test_compile_options.cpp ${TEST_ONNX_DIR}) add_api_test(compile_options test_compile_options.cpp ${TEST_ONNX_DIR})
add_api_test(ref test_cpu.cpp ${TEST_ONNX_DIR}) add_api_test(ref test_cpu.cpp ${TEST_ONNX_DIR})
add_api_test(save_load test_save_load.cpp ${TEST_ONNX_DIR}) add_api_test(save_load test_save_load.cpp ${TEST_ONNX_DIR})
......
#include <migraphx/migraphx.h>
#include <migraphx/migraphx.hpp>
#include "test.hpp"
TEST_CASE(shape_assign)
{
auto s1_cpp = migraphx::shape{migraphx_shape_float_type, {1, 3}};
std::vector<size_t> lens{2, 3};
// handle ptr is const, workaround to construct shape using C API
migraphx_shape_t s2;
migraphx_shape_create(&s2, migraphx_shape_float_type, lens.data(), lens.size());
auto s2_cpp = migraphx::shape(s2, migraphx::own{});
CHECK(bool{s1_cpp != s2_cpp});
// use C++ API for assignment
s1_cpp.assign_to_handle(s2);
CHECK(bool{s1_cpp == s2_cpp});
auto s3_cpp = migraphx::shape{migraphx_shape_float_type, lens};
// use C API for assignment
migraphx_shape_assign_to(s2, s3_cpp.get_handle_ptr());
CHECK(bool{s2_cpp == s3_cpp});
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -1513,6 +1513,55 @@ TEST_CASE(test_unsqueeze_scalar_tensor2) ...@@ -1513,6 +1513,55 @@ TEST_CASE(test_unsqueeze_scalar_tensor2)
throws_shape(migraphx::make_op("unsqueeze", {{"axes", {-2}}}), s); throws_shape(migraphx::make_op("unsqueeze", {{"axes", {-2}}}), s);
} }
TEST_CASE(test_unsqueeze_transpose)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 4, 3}, {12, 1, 4}};
migraphx::shape s2{migraphx::shape::float_type, {4, 4, 1, 3}, {12, 1, 1, 4}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1);
}
TEST_CASE(test_unsqueeze_multibroadcast)
{
migraphx::shape s1{migraphx::shape::float_type, {2, 3, 4}, {0, 1, 0}};
migraphx::shape s2{migraphx::shape::float_type, {2, 3, 1, 4}, {0, 1, 1, 0}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1);
}
TEST_CASE(test_unsqueeze_slice)
{
migraphx::shape s1{migraphx::shape::float_type, {2, 3, 4}, {108, 36, 1}};
migraphx::shape s2{migraphx::shape::float_type, {2, 3, 1, 4}, {108, 36, 36, 1}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1);
}
TEST_CASE(test_unsqueeze_axis_zero)
{
migraphx::shape s1{migraphx::shape::float_type, {2, 3, 4}};
migraphx::shape s2{migraphx::shape::float_type, {1, 2, 3, 4}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {0}}}), s1);
}
TEST_CASE(test_unsqueeze_axis_last)
{
migraphx::shape s1{migraphx::shape::float_type, {2, 3, 4}};
migraphx::shape s2{migraphx::shape::float_type, {2, 3, 4, 1}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {-1}}}), s1);
}
TEST_CASE(test_unsqueeze_multiple_axes_1)
{
migraphx::shape s1{migraphx::shape::float_type, {2, 3, 4}};
migraphx::shape s2{migraphx::shape::float_type, {1, 2, 3, 4, 1}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {0, -1}}}), s1);
}
TEST_CASE(test_unsqueeze_multiple_axes_2)
{
migraphx::shape s1{migraphx::shape::float_type, {2, 3, 4}};
migraphx::shape s2{migraphx::shape::float_type, {1, 1, 2, 3, 4}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {0, 1}}}), s1);
}
TEST_CASE(transpose_shape) TEST_CASE(transpose_shape)
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 2}}; migraphx::shape input{migraphx::shape::float_type, {2, 2}};
......
...@@ -57,14 +57,15 @@ TEST_CASE(squeeze_transpose_test) ...@@ -57,14 +57,15 @@ TEST_CASE(squeeze_transpose_test)
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 2, 3, 0, 4}}}), l0); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 2, 3, 0, 4}}}), l0);
mm->add_instruction(migraphx::make_op("squeeze"), l0_trans); mm->add_instruction(migraphx::make_op("squeeze"), l0_trans);
auto p_uncompiled = p; auto p_uncompiled = p;
// contiguous is required to read the values in standard shaped order
auto* mm_uncompiled = p_uncompiled.get_main_module();
mm_uncompiled->add_instruction(migraphx::make_op("contiguous"),
std::prev(mm_uncompiled->end()));
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
auto expected_result = p_uncompiled.eval({}).back(); auto expected_result = p_uncompiled.eval({}).back();
// contiguous is required to read the values in standard shaped order
auto tr_op = migraphx::make_op("contiguous");
auto std_expected_result = tr_op.compute(result.get_shape(), {expected_result});
EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {3, 4, 3}}); EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {3, 4, 3}});
EXPECT(result == std_expected_result); EXPECT(result == expected_result);
} }
TEST_CASE(squeeze_multibroadcast_test) TEST_CASE(squeeze_multibroadcast_test)
...@@ -76,14 +77,15 @@ TEST_CASE(squeeze_multibroadcast_test) ...@@ -76,14 +77,15 @@ TEST_CASE(squeeze_multibroadcast_test)
auto l0_brcst = mm->add_instruction( auto l0_brcst = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {4, 1, 3, 4, 3}}}), l0); migraphx::make_op("multibroadcast", {{"out_lens", {4, 1, 3, 4, 3}}}), l0);
mm->add_instruction(migraphx::make_op("squeeze"), l0_brcst); mm->add_instruction(migraphx::make_op("squeeze"), l0_brcst);
auto p_uncompiled = p; auto p_uncompiled = p;
auto* mm_uncompiled = p_uncompiled.get_main_module();
mm_uncompiled->add_instruction(migraphx::make_op("contiguous"),
std::prev(mm_uncompiled->end()));
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
auto expected_result = p_uncompiled.eval({}).back(); auto expected_result = p_uncompiled.eval({}).back();
auto tr_op = migraphx::make_op("contiguous");
auto std_expected_result = tr_op.compute(result.get_shape(), {expected_result});
EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {4, 3, 4, 3}}); EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {4, 3, 4, 3}});
EXPECT(result == std_expected_result); EXPECT(result == expected_result);
} }
TEST_CASE(squeeze_slice_test) TEST_CASE(squeeze_slice_test)
...@@ -95,14 +97,75 @@ TEST_CASE(squeeze_slice_test) ...@@ -95,14 +97,75 @@ TEST_CASE(squeeze_slice_test)
auto l0_slice = mm->add_instruction( auto l0_slice = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {2}}, {"ends", {3}}}), l0); migraphx::make_op("slice", {{"axes", {2}}, {"starts", {2}}, {"ends", {3}}}), l0);
mm->add_instruction(migraphx::make_op("squeeze"), l0_slice); mm->add_instruction(migraphx::make_op("squeeze"), l0_slice);
auto p_uncompiled = p; auto p_uncompiled = p;
auto* mm_uncompiled = p_uncompiled.get_main_module();
mm_uncompiled->add_instruction(migraphx::make_op("contiguous"),
std::prev(mm_uncompiled->end()));
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
auto expected_result = p_uncompiled.eval({}).back(); auto expected_result = p_uncompiled.eval({}).back();
auto tr_op = migraphx::make_op("contiguous");
auto std_expected_result = tr_op.compute(result.get_shape(), {expected_result});
EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {3, 3}}); EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {3, 3}});
EXPECT(result == std_expected_result); EXPECT(result == expected_result);
}
TEST_CASE(unsqueeze_transpose_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s1{migraphx::shape::float_type, {4, 3, 3}};
auto l0 = mm->add_literal(migraphx::generate_literal(s1));
auto l0_trans =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1}}}), l0);
mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), l0_trans);
auto p_uncompiled = p;
auto* mm_uncompiled = p_uncompiled.get_main_module();
mm_uncompiled->add_instruction(migraphx::make_op("contiguous"),
std::prev(mm_uncompiled->end()));
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
auto expected_result = p_uncompiled.eval({}).back();
EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {3, 4, 1, 3}});
EXPECT(result == expected_result);
}
TEST_CASE(unsqueeze_multibroadcast_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3}};
auto l0 = mm->add_literal(migraphx::generate_literal(s1));
auto l0_brcst =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {4, 4, 3, 3}}}), l0);
mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), l0_brcst);
auto p_uncompiled = p;
auto* mm_uncompiled = p_uncompiled.get_main_module();
mm_uncompiled->add_instruction(migraphx::make_op("contiguous"),
std::prev(mm_uncompiled->end()));
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
auto expected_result = p_uncompiled.eval({}).back();
EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {4, 4, 1, 3, 3}});
EXPECT(result == expected_result);
}
TEST_CASE(unsqueeze_slice_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s1{migraphx::shape::float_type, {2, 3, 4, 4}};
auto l0 = mm->add_literal(migraphx::generate_literal(s1));
auto l0_slice = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {2}}, {"ends", {3}}}), l0);
mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l0_slice);
auto p_uncompiled = p;
auto* mm_uncompiled = p_uncompiled.get_main_module();
mm_uncompiled->add_instruction(migraphx::make_op("contiguous"),
std::prev(mm_uncompiled->end()));
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
auto expected_result = p_uncompiled.eval({}).back();
EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {2, 1, 3, 4, 1}});
EXPECT(result == expected_result);
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -15,7 +15,7 @@ c_api_body_preamble: List[str] = [] ...@@ -15,7 +15,7 @@ c_api_body_preamble: List[str] = []
cpp_header_preamble: List[str] = [] cpp_header_preamble: List[str] = []
def bad_param_error(msg): def bad_param_error(msg: str):
return 'throw std::runtime_error("{}")'.format(msg) return 'throw std::runtime_error("{}")'.format(msg)
...@@ -89,7 +89,7 @@ class Type: ...@@ -89,7 +89,7 @@ class Type:
else: else:
return t.remove_const() return t.remove_const()
def const_compatible(self, t): def const_compatible(self, t: 'Type'):
if t.is_const(): if t.is_const():
return self.add_const() return self.add_const()
return self return self
...@@ -720,6 +720,7 @@ def add_handle(name: str, ...@@ -720,6 +720,7 @@ def add_handle(name: str,
destroy: Optional[str] = None, destroy: Optional[str] = None,
ref: Optional[bool] = None) -> None: ref: Optional[bool] = None) -> None:
opaque_type = ctype + '_t' opaque_type = ctype + '_t'
const_opaque_type = 'const_' + opaque_type
def handle_wrap(p): def handle_wrap(p):
t = Type(opaque_type) t = Type(opaque_type)
...@@ -747,6 +748,9 @@ def add_handle(name: str, ...@@ -747,6 +748,9 @@ def add_handle(name: str,
add_function(destroy or ctype + '_' + 'destroy', add_function(destroy or ctype + '_' + 'destroy',
params({name: opaque_type}), params({name: opaque_type}),
fname='destroy') fname='destroy')
add_function(ctype + '_' + 'assign_to',
params(output=opaque_type, input=const_opaque_type),
invoke='*output = *input')
add_handle_preamble() add_handle_preamble()
c_header_preamble.append(handle_typedef.substitute(locals())) c_header_preamble.append(handle_typedef.substitute(locals()))
c_api_body_preamble.append(handle_definition.substitute(locals())) c_api_body_preamble.append(handle_definition.substitute(locals()))
......
...@@ -14,4 +14,6 @@ function api { ...@@ -14,4 +14,6 @@ function api {
} }
api $DIR/api/migraphx.h $SRC_DIR/api/include/migraphx/migraphx.h api $DIR/api/migraphx.h $SRC_DIR/api/include/migraphx/migraphx.h
echo "Finished generating header migraphx.h"
api $DIR/api/api.cpp $SRC_DIR/api/api.cpp api $DIR/api/api.cpp $SRC_DIR/api/api.cpp
echo "Finished generating source api.cpp "
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment