Commit 6529e7c9 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'alias_0_as_ret_value' into print_matmul_perf_flops

parents 1c425f14 65de331e
# Modifications Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved
# Copyright 2018 The Google AI Language Team Authors. # Copyright 2018 The Google AI Language Team Authors.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
......
...@@ -385,6 +385,13 @@ extern "C" migraphx_status migraphx_shape_destroy(migraphx_shape_t shape) ...@@ -385,6 +385,13 @@ extern "C" migraphx_status migraphx_shape_destroy(migraphx_shape_t shape)
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_shape_assign_to(migraphx_shape_t output,
const_migraphx_shape_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status migraphx_shape_create(migraphx_shape_t* shape, extern "C" migraphx_status migraphx_shape_create(migraphx_shape_t* shape,
migraphx_shape_datatype_t type, migraphx_shape_datatype_t type,
size_t* lengths, size_t* lengths,
...@@ -502,6 +509,13 @@ extern "C" migraphx_status migraphx_argument_destroy(migraphx_argument_t argumen ...@@ -502,6 +509,13 @@ extern "C" migraphx_status migraphx_argument_destroy(migraphx_argument_t argumen
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_argument_assign_to(migraphx_argument_t output,
const_migraphx_argument_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status extern "C" migraphx_status
migraphx_argument_create(migraphx_argument_t* argument, const_migraphx_shape_t shape, void* buffer) migraphx_argument_create(migraphx_argument_t* argument, const_migraphx_shape_t shape, void* buffer)
{ {
...@@ -565,6 +579,13 @@ extern "C" migraphx_status migraphx_target_destroy(migraphx_target_t target) ...@@ -565,6 +579,13 @@ extern "C" migraphx_status migraphx_target_destroy(migraphx_target_t target)
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_target_assign_to(migraphx_target_t output,
const_migraphx_target_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status migraphx_target_create(migraphx_target_t* target, const char* name) extern "C" migraphx_status migraphx_target_create(migraphx_target_t* target, const char* name)
{ {
auto api_error_result = migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
...@@ -581,6 +602,14 @@ extern "C" migraphx_status migraphx_program_parameter_shapes_destroy( ...@@ -581,6 +602,14 @@ extern "C" migraphx_status migraphx_program_parameter_shapes_destroy(
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status
migraphx_program_parameter_shapes_assign_to(migraphx_program_parameter_shapes_t output,
const_migraphx_program_parameter_shapes_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status extern "C" migraphx_status
migraphx_program_parameter_shapes_size(size_t* out, migraphx_program_parameter_shapes_size(size_t* out,
migraphx_program_parameter_shapes_t program_parameter_shapes) migraphx_program_parameter_shapes_t program_parameter_shapes)
...@@ -631,6 +660,14 @@ migraphx_program_parameters_destroy(migraphx_program_parameters_t program_parame ...@@ -631,6 +660,14 @@ migraphx_program_parameters_destroy(migraphx_program_parameters_t program_parame
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status
migraphx_program_parameters_assign_to(migraphx_program_parameters_t output,
const_migraphx_program_parameters_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status extern "C" migraphx_status
migraphx_program_parameters_create(migraphx_program_parameters_t* program_parameters) migraphx_program_parameters_create(migraphx_program_parameters_t* program_parameters)
{ {
...@@ -663,6 +700,13 @@ extern "C" migraphx_status migraphx_arguments_destroy(migraphx_arguments_t argum ...@@ -663,6 +700,13 @@ extern "C" migraphx_status migraphx_arguments_destroy(migraphx_arguments_t argum
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_arguments_assign_to(migraphx_arguments_t output,
const_migraphx_arguments_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status migraphx_arguments_size(size_t* out, migraphx_arguments_t arguments) extern "C" migraphx_status migraphx_arguments_size(size_t* out, migraphx_arguments_t arguments)
{ {
auto api_error_result = migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
...@@ -690,6 +734,13 @@ extern "C" migraphx_status migraphx_shapes_destroy(migraphx_shapes_t shapes) ...@@ -690,6 +734,13 @@ extern "C" migraphx_status migraphx_shapes_destroy(migraphx_shapes_t shapes)
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_shapes_assign_to(migraphx_shapes_t output,
const_migraphx_shapes_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status migraphx_shapes_size(size_t* out, migraphx_shapes_t shapes) extern "C" migraphx_status migraphx_shapes_size(size_t* out, migraphx_shapes_t shapes)
{ {
auto api_error_result = migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
...@@ -727,6 +778,13 @@ extern "C" migraphx_status migraphx_program_destroy(migraphx_program_t program) ...@@ -727,6 +778,13 @@ extern "C" migraphx_status migraphx_program_destroy(migraphx_program_t program)
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_program_assign_to(migraphx_program_t output,
const_migraphx_program_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status migraphx_program_get_main_module(migraphx_module_t* out, extern "C" migraphx_status migraphx_program_get_main_module(migraphx_module_t* out,
migraphx_program_t program) migraphx_program_t program)
{ {
...@@ -831,6 +889,13 @@ extern "C" migraphx_status migraphx_operation_destroy(migraphx_operation_t opera ...@@ -831,6 +889,13 @@ extern "C" migraphx_status migraphx_operation_destroy(migraphx_operation_t opera
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_operation_assign_to(migraphx_operation_t output,
const_migraphx_operation_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status migraphx_operation_create(migraphx_operation_t* operation, extern "C" migraphx_status migraphx_operation_create(migraphx_operation_t* operation,
const char* name, const char* name,
const char* attributes, const char* attributes,
...@@ -891,6 +956,13 @@ extern "C" migraphx_status migraphx_onnx_options_destroy(migraphx_onnx_options_t ...@@ -891,6 +956,13 @@ extern "C" migraphx_status migraphx_onnx_options_destroy(migraphx_onnx_options_t
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_onnx_options_assign_to(migraphx_onnx_options_t output,
const_migraphx_onnx_options_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status migraphx_onnx_options_create(migraphx_onnx_options_t* onnx_options) extern "C" migraphx_status migraphx_onnx_options_create(migraphx_onnx_options_t* onnx_options)
{ {
auto api_error_result = migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
...@@ -942,6 +1014,13 @@ extern "C" migraphx_status migraphx_file_options_destroy(migraphx_file_options_t ...@@ -942,6 +1014,13 @@ extern "C" migraphx_status migraphx_file_options_destroy(migraphx_file_options_t
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_file_options_assign_to(migraphx_file_options_t output,
const_migraphx_file_options_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status migraphx_file_options_create(migraphx_file_options_t* file_options) extern "C" migraphx_status migraphx_file_options_create(migraphx_file_options_t* file_options)
{ {
auto api_error_result = migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
...@@ -968,6 +1047,14 @@ migraphx_compile_options_destroy(migraphx_compile_options_t compile_options) ...@@ -968,6 +1047,14 @@ migraphx_compile_options_destroy(migraphx_compile_options_t compile_options)
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status
migraphx_compile_options_assign_to(migraphx_compile_options_t output,
const_migraphx_compile_options_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status extern "C" migraphx_status
migraphx_compile_options_create(migraphx_compile_options_t* compile_options) migraphx_compile_options_create(migraphx_compile_options_t* compile_options)
{ {
...@@ -1033,6 +1120,13 @@ extern "C" migraphx_status migraphx_tf_options_destroy(migraphx_tf_options_t tf_ ...@@ -1033,6 +1120,13 @@ extern "C" migraphx_status migraphx_tf_options_destroy(migraphx_tf_options_t tf_
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_tf_options_assign_to(migraphx_tf_options_t output,
const_migraphx_tf_options_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status migraphx_tf_options_create(migraphx_tf_options_t* tf_options) extern "C" migraphx_status migraphx_tf_options_create(migraphx_tf_options_t* tf_options)
{ {
auto api_error_result = migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
...@@ -1110,6 +1204,14 @@ migraphx_quantize_op_names_destroy(migraphx_quantize_op_names_t quantize_op_name ...@@ -1110,6 +1204,14 @@ migraphx_quantize_op_names_destroy(migraphx_quantize_op_names_t quantize_op_name
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status
migraphx_quantize_op_names_assign_to(migraphx_quantize_op_names_t output,
const_migraphx_quantize_op_names_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status extern "C" migraphx_status
migraphx_quantize_op_names_create(migraphx_quantize_op_names_t* quantize_op_names) migraphx_quantize_op_names_create(migraphx_quantize_op_names_t* quantize_op_names)
{ {
...@@ -1162,6 +1264,14 @@ migraphx_quantize_int8_options_destroy(migraphx_quantize_int8_options_t quantize ...@@ -1162,6 +1264,14 @@ migraphx_quantize_int8_options_destroy(migraphx_quantize_int8_options_t quantize
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status
migraphx_quantize_int8_options_assign_to(migraphx_quantize_int8_options_t output,
const_migraphx_quantize_int8_options_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status extern "C" migraphx_status
migraphx_quantize_int8_options_create(migraphx_quantize_int8_options_t* quantize_int8_options) migraphx_quantize_int8_options_create(migraphx_quantize_int8_options_t* quantize_int8_options)
{ {
......
...@@ -91,6 +91,8 @@ typedef const struct migraphx_quantize_int8_options* const_migraphx_quantize_int ...@@ -91,6 +91,8 @@ typedef const struct migraphx_quantize_int8_options* const_migraphx_quantize_int
migraphx_status migraphx_shape_destroy(migraphx_shape_t shape); migraphx_status migraphx_shape_destroy(migraphx_shape_t shape);
migraphx_status migraphx_shape_assign_to(migraphx_shape_t output, const_migraphx_shape_t input);
migraphx_status migraphx_shape_create(migraphx_shape_t* shape, migraphx_status migraphx_shape_create(migraphx_shape_t* shape,
migraphx_shape_datatype_t type, migraphx_shape_datatype_t type,
size_t* lengths, size_t* lengths,
...@@ -121,6 +123,9 @@ migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_sha ...@@ -121,6 +123,9 @@ migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_sha
migraphx_status migraphx_argument_destroy(migraphx_argument_t argument); migraphx_status migraphx_argument_destroy(migraphx_argument_t argument);
migraphx_status migraphx_argument_assign_to(migraphx_argument_t output,
const_migraphx_argument_t input);
migraphx_status migraphx_status
migraphx_argument_create(migraphx_argument_t* argument, const_migraphx_shape_t shape, void* buffer); migraphx_argument_create(migraphx_argument_t* argument, const_migraphx_shape_t shape, void* buffer);
...@@ -137,11 +142,17 @@ migraphx_argument_generate(migraphx_argument_t* out, const_migraphx_shape_t s, s ...@@ -137,11 +142,17 @@ migraphx_argument_generate(migraphx_argument_t* out, const_migraphx_shape_t s, s
migraphx_status migraphx_target_destroy(migraphx_target_t target); migraphx_status migraphx_target_destroy(migraphx_target_t target);
migraphx_status migraphx_target_assign_to(migraphx_target_t output, const_migraphx_target_t input);
migraphx_status migraphx_target_create(migraphx_target_t* target, const char* name); migraphx_status migraphx_target_create(migraphx_target_t* target, const char* name);
migraphx_status migraphx_program_parameter_shapes_destroy( migraphx_status migraphx_program_parameter_shapes_destroy(
migraphx_program_parameter_shapes_t program_parameter_shapes); migraphx_program_parameter_shapes_t program_parameter_shapes);
migraphx_status
migraphx_program_parameter_shapes_assign_to(migraphx_program_parameter_shapes_t output,
const_migraphx_program_parameter_shapes_t input);
migraphx_status migraphx_program_parameter_shapes_size( migraphx_status migraphx_program_parameter_shapes_size(
size_t* out, migraphx_program_parameter_shapes_t program_parameter_shapes); size_t* out, migraphx_program_parameter_shapes_t program_parameter_shapes);
...@@ -156,6 +167,9 @@ migraphx_status migraphx_program_parameter_shapes_names( ...@@ -156,6 +167,9 @@ migraphx_status migraphx_program_parameter_shapes_names(
migraphx_status migraphx_status
migraphx_program_parameters_destroy(migraphx_program_parameters_t program_parameters); migraphx_program_parameters_destroy(migraphx_program_parameters_t program_parameters);
migraphx_status migraphx_program_parameters_assign_to(migraphx_program_parameters_t output,
const_migraphx_program_parameters_t input);
migraphx_status migraphx_status
migraphx_program_parameters_create(migraphx_program_parameters_t* program_parameters); migraphx_program_parameters_create(migraphx_program_parameters_t* program_parameters);
...@@ -165,6 +179,9 @@ migraphx_status migraphx_program_parameters_add(migraphx_program_parameters_t pr ...@@ -165,6 +179,9 @@ migraphx_status migraphx_program_parameters_add(migraphx_program_parameters_t pr
migraphx_status migraphx_arguments_destroy(migraphx_arguments_t arguments); migraphx_status migraphx_arguments_destroy(migraphx_arguments_t arguments);
migraphx_status migraphx_arguments_assign_to(migraphx_arguments_t output,
const_migraphx_arguments_t input);
migraphx_status migraphx_arguments_size(size_t* out, migraphx_arguments_t arguments); migraphx_status migraphx_arguments_size(size_t* out, migraphx_arguments_t arguments);
migraphx_status migraphx_status
...@@ -172,6 +189,8 @@ migraphx_arguments_get(const_migraphx_argument_t* out, migraphx_arguments_t argu ...@@ -172,6 +189,8 @@ migraphx_arguments_get(const_migraphx_argument_t* out, migraphx_arguments_t argu
migraphx_status migraphx_shapes_destroy(migraphx_shapes_t shapes); migraphx_status migraphx_shapes_destroy(migraphx_shapes_t shapes);
migraphx_status migraphx_shapes_assign_to(migraphx_shapes_t output, const_migraphx_shapes_t input);
migraphx_status migraphx_shapes_size(size_t* out, migraphx_shapes_t shapes); migraphx_status migraphx_shapes_size(size_t* out, migraphx_shapes_t shapes);
migraphx_status migraphx_status
...@@ -181,6 +200,9 @@ migraphx_status migraphx_module_print(const_migraphx_module_t module); ...@@ -181,6 +200,9 @@ migraphx_status migraphx_module_print(const_migraphx_module_t module);
migraphx_status migraphx_program_destroy(migraphx_program_t program); migraphx_status migraphx_program_destroy(migraphx_program_t program);
migraphx_status migraphx_program_assign_to(migraphx_program_t output,
const_migraphx_program_t input);
migraphx_status migraphx_program_get_main_module(migraphx_module_t* out, migraphx_status migraphx_program_get_main_module(migraphx_module_t* out,
migraphx_program_t program); migraphx_program_t program);
...@@ -207,6 +229,9 @@ migraphx_program_equal(bool* out, const_migraphx_program_t program, const_migrap ...@@ -207,6 +229,9 @@ migraphx_program_equal(bool* out, const_migraphx_program_t program, const_migrap
migraphx_status migraphx_operation_destroy(migraphx_operation_t operation); migraphx_status migraphx_operation_destroy(migraphx_operation_t operation);
migraphx_status migraphx_operation_assign_to(migraphx_operation_t output,
const_migraphx_operation_t input);
migraphx_status migraphx_operation_create(migraphx_operation_t* operation, migraphx_status migraphx_operation_create(migraphx_operation_t* operation,
const char* name, const char* name,
const char* attributes, const char* attributes,
...@@ -222,6 +247,9 @@ migraphx_save(migraphx_program_t p, const char* name, migraphx_file_options_t op ...@@ -222,6 +247,9 @@ migraphx_save(migraphx_program_t p, const char* name, migraphx_file_options_t op
migraphx_status migraphx_onnx_options_destroy(migraphx_onnx_options_t onnx_options); migraphx_status migraphx_onnx_options_destroy(migraphx_onnx_options_t onnx_options);
migraphx_status migraphx_onnx_options_assign_to(migraphx_onnx_options_t output,
const_migraphx_onnx_options_t input);
migraphx_status migraphx_onnx_options_create(migraphx_onnx_options_t* onnx_options); migraphx_status migraphx_onnx_options_create(migraphx_onnx_options_t* onnx_options);
migraphx_status migraphx_onnx_options_set_input_parameter_shape( migraphx_status migraphx_onnx_options_set_input_parameter_shape(
...@@ -236,6 +264,9 @@ migraphx_onnx_options_set_default_loop_iterations(migraphx_onnx_options_t onnx_o ...@@ -236,6 +264,9 @@ migraphx_onnx_options_set_default_loop_iterations(migraphx_onnx_options_t onnx_o
migraphx_status migraphx_file_options_destroy(migraphx_file_options_t file_options); migraphx_status migraphx_file_options_destroy(migraphx_file_options_t file_options);
migraphx_status migraphx_file_options_assign_to(migraphx_file_options_t output,
const_migraphx_file_options_t input);
migraphx_status migraphx_file_options_create(migraphx_file_options_t* file_options); migraphx_status migraphx_file_options_create(migraphx_file_options_t* file_options);
migraphx_status migraphx_file_options_set_file_format(migraphx_file_options_t file_options, migraphx_status migraphx_file_options_set_file_format(migraphx_file_options_t file_options,
...@@ -243,6 +274,9 @@ migraphx_status migraphx_file_options_set_file_format(migraphx_file_options_t fi ...@@ -243,6 +274,9 @@ migraphx_status migraphx_file_options_set_file_format(migraphx_file_options_t fi
migraphx_status migraphx_compile_options_destroy(migraphx_compile_options_t compile_options); migraphx_status migraphx_compile_options_destroy(migraphx_compile_options_t compile_options);
migraphx_status migraphx_compile_options_assign_to(migraphx_compile_options_t output,
const_migraphx_compile_options_t input);
migraphx_status migraphx_compile_options_create(migraphx_compile_options_t* compile_options); migraphx_status migraphx_compile_options_create(migraphx_compile_options_t* compile_options);
migraphx_status migraphx_status
...@@ -261,6 +295,9 @@ migraphx_status migraphx_parse_onnx_buffer(migraphx_program_t* out, ...@@ -261,6 +295,9 @@ migraphx_status migraphx_parse_onnx_buffer(migraphx_program_t* out,
migraphx_status migraphx_tf_options_destroy(migraphx_tf_options_t tf_options); migraphx_status migraphx_tf_options_destroy(migraphx_tf_options_t tf_options);
migraphx_status migraphx_tf_options_assign_to(migraphx_tf_options_t output,
const_migraphx_tf_options_t input);
migraphx_status migraphx_tf_options_create(migraphx_tf_options_t* tf_options); migraphx_status migraphx_tf_options_create(migraphx_tf_options_t* tf_options);
migraphx_status migraphx_tf_options_set_nhwc(migraphx_tf_options_t tf_options, bool is_nhwc); migraphx_status migraphx_tf_options_set_nhwc(migraphx_tf_options_t tf_options, bool is_nhwc);
...@@ -282,6 +319,9 @@ migraphx_parse_tf(migraphx_program_t* out, const char* name, migraphx_tf_options ...@@ -282,6 +319,9 @@ migraphx_parse_tf(migraphx_program_t* out, const char* name, migraphx_tf_options
migraphx_status migraphx_quantize_op_names_destroy(migraphx_quantize_op_names_t quantize_op_names); migraphx_status migraphx_quantize_op_names_destroy(migraphx_quantize_op_names_t quantize_op_names);
migraphx_status migraphx_quantize_op_names_assign_to(migraphx_quantize_op_names_t output,
const_migraphx_quantize_op_names_t input);
migraphx_status migraphx_quantize_op_names_create(migraphx_quantize_op_names_t* quantize_op_names); migraphx_status migraphx_quantize_op_names_create(migraphx_quantize_op_names_t* quantize_op_names);
migraphx_status migraphx_quantize_op_names_add(migraphx_quantize_op_names_t quantize_op_names, migraphx_status migraphx_quantize_op_names_add(migraphx_quantize_op_names_t quantize_op_names,
...@@ -295,6 +335,10 @@ migraphx_status migraphx_quantize_fp16(migraphx_program_t prog); ...@@ -295,6 +335,10 @@ migraphx_status migraphx_quantize_fp16(migraphx_program_t prog);
migraphx_status migraphx_status
migraphx_quantize_int8_options_destroy(migraphx_quantize_int8_options_t quantize_int8_options); migraphx_quantize_int8_options_destroy(migraphx_quantize_int8_options_t quantize_int8_options);
migraphx_status
migraphx_quantize_int8_options_assign_to(migraphx_quantize_int8_options_t output,
const_migraphx_quantize_int8_options_t input);
migraphx_status migraphx_status
migraphx_quantize_int8_options_create(migraphx_quantize_int8_options_t* quantize_int8_options); migraphx_quantize_int8_options_create(migraphx_quantize_int8_options_t* quantize_int8_options);
......
...@@ -159,7 +159,7 @@ struct borrow ...@@ -159,7 +159,7 @@ struct borrow
{ {
}; };
template <class T, class D, D Deleter> template <class T, class D, D Deleter, class A, A Assigner>
struct handle_base struct handle_base
{ {
handle_base() : m_handle(nullptr) {} handle_base() : m_handle(nullptr) {}
...@@ -190,6 +190,12 @@ struct handle_base ...@@ -190,6 +190,12 @@ struct handle_base
m_handle = std::shared_ptr<U>{ptr, [](U*) {}}; m_handle = std::shared_ptr<U>{ptr, [](U*) {}};
} }
template <class U>
void assign_to_handle(U* x)
{
Assigner(x, this->get_handle_ptr());
}
protected: protected:
std::shared_ptr<T> m_handle; std::shared_ptr<T> m_handle;
}; };
...@@ -197,10 +203,12 @@ struct handle_base ...@@ -197,10 +203,12 @@ struct handle_base
#ifdef DOXYGEN #ifdef DOXYGEN
#define MIGRAPHX_DETAIL_HANDLE_BASE(name, const_) handle_base<> #define MIGRAPHX_DETAIL_HANDLE_BASE(name, const_) handle_base<>
#else #else
#define MIGRAPHX_DETAIL_HANDLE_BASE(name, const_) \ #define MIGRAPHX_DETAIL_HANDLE_BASE(name, const_) \
handle_base<const_ migraphx_##name, \ handle_base<const_ migraphx_##name, \
decltype(&migraphx_##name##_destroy), \ decltype(&migraphx_##name##_destroy), \
migraphx_##name##_destroy> migraphx_##name##_destroy, \
decltype(&migraphx_##name##_assign_to), \
migraphx_##name##_assign_to>
#endif #endif
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_HANDLE_BASE(name) MIGRAPHX_DETAIL_HANDLE_BASE(name, ) #define MIGRAPHX_HANDLE_BASE(name) MIGRAPHX_DETAIL_HANDLE_BASE(name, )
......
...@@ -10,8 +10,35 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -10,8 +10,35 @@ inline namespace MIGRAPHX_INLINE_NS {
void auto_contiguous::apply(module& p) const void auto_contiguous::apply(module& p) const
{ {
std::string key = "std_shape";
for(auto ins : reverse_iterator_for(p))
{
auto&& attr = ins->get_operator().attributes();
if((attr.contains(key) and attr.at(key).to<bool>()))
{
auto args = ins->inputs();
auto new_args = args;
std::transform(args.begin(), args.end(), new_args.begin(), [&](auto in) {
if(in->name() == "contiguous")
{
return in;
}
return p.insert_instruction(ins, make_op("contiguous"), in);
});
if(new_args != args)
{
p.replace_instruction(ins, ins->get_operator(), new_args);
}
}
}
auto last = std::prev(p.end());
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
{ {
// for last instruction that is NOT a return
if(ins->outputs().empty() and ins != last)
continue;
shape s = ins->get_shape(); shape s = ins->get_shape();
if(not s.standard() and s.elements() != 0) if(not s.standard() and s.elements() != 0)
{ {
...@@ -19,6 +46,26 @@ void auto_contiguous::apply(module& p) const ...@@ -19,6 +46,26 @@ void auto_contiguous::apply(module& p) const
p.replace_instruction(ins, c); p.replace_instruction(ins, c);
} }
} }
// if ops used as output param are alias 0, add a contiguous for the output
// so return outputs with standard shape
if(last->name() == "@return")
{
auto inputs = last->inputs();
for(auto ins : inputs)
{
if(ins->name() == "contiguous")
continue;
auto ins_alias = ins->get_operator().output_alias({});
if(ins_alias == 0 and ins->get_shape().element_space() !=
ins->inputs().front()->get_shape().element_space())
{
auto cont_ins = p.insert_instruction(last, make_op("contiguous"), ins);
p.replace_instruction(ins, cont_ins);
}
}
}
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -32,7 +32,12 @@ void cse_range(module& p, Range&& r) ...@@ -32,7 +32,12 @@ void cse_range(module& p, Range&& r)
continue; continue;
p.replace_instruction(ins, eq); p.replace_instruction(ins, eq);
processed_ins.emplace(ins); processed_ins.emplace(ins);
auto outputs = eq->outputs(); std::vector<instruction_ref> outputs;
std::copy_if(eq->outputs().begin(),
eq->outputs().end(),
std::back_inserter(outputs),
[&](auto x) { return p.has_instruction(x); });
std::sort(outputs.begin(), outputs.end(), [&](auto x, auto y) { std::sort(outputs.begin(), outputs.end(), [&](auto x, auto y) {
return std::distance(eq, x) < std::distance(eq, y); return std::distance(eq, x) < std::distance(eq, y);
}); });
......
...@@ -78,15 +78,16 @@ void eliminate_contiguous::apply(module& p) const ...@@ -78,15 +78,16 @@ void eliminate_contiguous::apply(module& p) const
continue; continue;
// Make a copy so we can modify it while we iterate // Make a copy so we can modify it while we iterate
auto args = ins->inputs(); auto args = ins->inputs();
auto new_args = args;
auto mod_args = ins->module_inputs();
for(auto arg : ins->inputs()) for(auto arg : ins->inputs())
{ {
if(arg->name() == op_name) if(arg->name() == op_name)
{ {
auto new_args = args; auto prev = arg->inputs().front();
auto prev = arg->inputs().front();
replace(new_args, arg, prev); replace(new_args, arg, prev);
if(try_compute_shape(ins, new_args, ins->module_inputs())) if(try_compute_shape(ins, new_args, mod_args))
{ {
instruction::replace_argument(ins, arg, prev); instruction::replace_argument(ins, arg, prev);
} }
......
...@@ -66,7 +66,7 @@ struct reduce_op : op_name<Derived> ...@@ -66,7 +66,7 @@ struct reduce_op : op_name<Derived>
{ {
value normalize; value normalize;
normalize["axes"] = value::array{normalize_attribute::include_min}; normalize["axes"] = value::array{normalize_attribute::include_min};
return {{"normalize_axes", normalize}}; return {{"normalize_axes", normalize}, {"std_shape", true}};
} }
std::vector<int64_t> tune_axes(std::size_t n_dim) const std::vector<int64_t> tune_axes(std::size_t n_dim) const
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp> #include <migraphx/lifetime.hpp>
#include <migraphx/value.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -26,6 +27,8 @@ struct reshape ...@@ -26,6 +27,8 @@ struct reshape
return pack(f(self.dims, "dims")); return pack(f(self.dims, "dims"));
} }
value attributes() const { return {{"std_shape", true}}; }
std::string name() const { return "reshape"; } std::string name() const { return "reshape"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
......
...@@ -37,11 +37,11 @@ struct unsqueeze ...@@ -37,11 +37,11 @@ struct unsqueeze
std::string name() const { return "unsqueeze"; } std::string name() const { return "unsqueeze"; }
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1).standard_or_scalar(); check_shapes{inputs, *this}.has(1);
auto input_shape = inputs[0]; auto input_shape = inputs[0];
auto type = input_shape.type(); auto type = input_shape.type();
auto old_lens = input_shape.lens(); auto old_lens = input_shape.lens();
auto old_strides = input_shape.strides();
if(input_shape.scalar()) if(input_shape.scalar())
{ {
if(old_lens.size() == 1 and old_lens.front() == 1) if(old_lens.size() == 1 and old_lens.front() == 1)
...@@ -53,19 +53,29 @@ struct unsqueeze ...@@ -53,19 +53,29 @@ struct unsqueeze
std::size_t new_size = old_lens.size() + axes.size(); std::size_t new_size = old_lens.size() + axes.size();
std::vector<std::size_t> new_lens(new_size); std::vector<std::size_t> new_lens(new_size);
std::vector<std::size_t> new_strides(new_size);
std::size_t p = 0; std::size_t p = 0;
for(std::size_t i = 0; i < new_size; i++) for(auto i : range(new_size))
{ {
if(std::find(axes.begin(), axes.end(), i) != axes.end()) if(std::find(axes.begin(), axes.end(), i) != axes.end())
{ {
new_lens[i] = 1; new_lens[i] = 1;
if(p == 0) // unsqueeze on the first axes
{
new_strides[i] = old_lens[0] * old_strides[0];
}
else // unsqueeze on middle or last axes
{
new_strides[i] = (p < old_strides.size()) ? old_strides[p - 1] : 1;
}
} }
else else
{ {
new_lens[i] = old_lens[p++]; new_lens[i] = old_lens[p];
new_strides[i] = old_strides[p++];
} }
} }
return shape{type, new_lens}; return shape{type, new_lens, new_strides};
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
......
...@@ -223,11 +223,10 @@ struct shape ...@@ -223,11 +223,10 @@ struct shape
static type_t parse_type(const std::string& s); static type_t parse_type(const std::string& s);
const std::vector<shape>& sub_shapes() const; const std::vector<shape>& sub_shapes() const;
std::size_t element_space() const;
private: private:
std::shared_ptr<const shape_impl> impl; std::shared_ptr<const shape_impl> impl;
std::size_t element_space() const;
}; };
void migraphx_to_value(value& v, const shape& s); void migraphx_to_value(value& v, const shape& s);
......
...@@ -627,8 +627,9 @@ std::unordered_map<instruction_ref, std::string> module::print( ...@@ -627,8 +627,9 @@ std::unordered_map<instruction_ref, std::string> module::print(
var_name = this->name(); var_name = this->name();
var_name.append((this->name().empty() ? "@" : ":@")); var_name.append((this->name().empty() ? "@" : ":@"));
var_name.append(std::to_string(count)); var_name.append(std::to_string(count));
count++;
} }
// make instruction index to be the line num in the printed module
count++;
names.emplace(ins, var_name); names.emplace(ins, var_name);
print_func(ins, names); print_func(ins, names);
......
...@@ -172,7 +172,7 @@ void program::compile(const target& t, compile_options options) ...@@ -172,7 +172,7 @@ void program::compile(const target& t, compile_options options)
{ {
auto index = std::distance(mod->begin(), dangling); auto index = std::distance(mod->begin(), dangling);
MIGRAPHX_THROW("Dangling reference in module " + mod->name() + " from instruction " + MIGRAPHX_THROW("Dangling reference in module " + mod->name() + " from instruction " +
std::to_string(index)); std::to_string(index) + ", (" + dangling->name() + ")");
} }
mod->finalize(this->impl->ctx); mod->finalize(this->impl->ctx);
} }
......
...@@ -120,6 +120,17 @@ struct find_nop_reshapes ...@@ -120,6 +120,17 @@ struct find_nop_reshapes
void apply(module& p, const match::matcher_result& mr) const void apply(module& p, const match::matcher_result& mr) const
{ {
auto ins = mr.result; auto ins = mr.result;
// output of reshape and contiguous is standard, so no need to add another contiguous
// if the output is used an a ret value
if(ins->name() == "contiguous" and ins->name() != "contiguous" and ins->name() != "reshape")
{
auto& outputs = ins->outputs();
if(std::any_of(
outputs.begin(), outputs.end(), [&](auto o) { return o->name() == "@return"; }))
{
return;
}
}
p.replace_instruction(ins, ins->inputs().front()); p.replace_instruction(ins, ins->inputs().front());
} }
}; };
......
...@@ -93,10 +93,10 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -93,10 +93,10 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{}, dead_code_elimination{},
eliminate_common_subexpression{}, eliminate_common_subexpression{},
dead_code_elimination{}, dead_code_elimination{},
auto_contiguous{},
simplify_algebra{}, simplify_algebra{},
simplify_reshapes{}, simplify_reshapes{},
simplify_algebra{}, simplify_algebra{},
auto_contiguous{},
simplify_reshapes{}, simplify_reshapes{},
propagate_constant{}, propagate_constant{},
dead_code_elimination{}, dead_code_elimination{},
......
...@@ -10,6 +10,7 @@ function(add_api_test TEST_NAME TEST_SRC TEST_DIR) ...@@ -10,6 +10,7 @@ function(add_api_test TEST_NAME TEST_SRC TEST_DIR)
add_dependencies(check ${NAME}) add_dependencies(check ${NAME})
endfunction() endfunction()
add_api_test(assign test_assign.cpp ${TEST_ONNX_DIR})
add_api_test(compile_options test_compile_options.cpp ${TEST_ONNX_DIR}) add_api_test(compile_options test_compile_options.cpp ${TEST_ONNX_DIR})
add_api_test(ref test_cpu.cpp ${TEST_ONNX_DIR}) add_api_test(ref test_cpu.cpp ${TEST_ONNX_DIR})
add_api_test(save_load test_save_load.cpp ${TEST_ONNX_DIR}) add_api_test(save_load test_save_load.cpp ${TEST_ONNX_DIR})
......
#include <migraphx/migraphx.h>
#include <migraphx/migraphx.hpp>
#include "test.hpp"
TEST_CASE(shape_assign)
{
auto s1_cpp = migraphx::shape{migraphx_shape_float_type, {1, 3}};
std::vector<size_t> lens{2, 3};
// handle ptr is const, workaround to construct shape using C API
migraphx_shape_t s2;
migraphx_shape_create(&s2, migraphx_shape_float_type, lens.data(), lens.size());
auto s2_cpp = migraphx::shape(s2, migraphx::own{});
CHECK(bool{s1_cpp != s2_cpp});
// use C++ API for assignment
s1_cpp.assign_to_handle(s2);
CHECK(bool{s1_cpp == s2_cpp});
auto s3_cpp = migraphx::shape{migraphx_shape_float_type, lens};
// use C API for assignment
migraphx_shape_assign_to(s2, s3_cpp.get_handle_ptr());
CHECK(bool{s2_cpp == s3_cpp});
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -135,4 +135,52 @@ TEST_CASE(two_transpose_gather) ...@@ -135,4 +135,52 @@ TEST_CASE(two_transpose_gather)
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
TEST_CASE(standard_reshape)
{
migraphx::module m1;
{
auto data = m1.add_parameter("2x2", {migraphx::shape::float_type, {2, 3, 4, 5}});
auto add = m1.add_instruction(migraphx::make_op("add"), data, data);
auto r = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 1, 12, 5}}}), add);
m1.add_return({r});
}
run_pass(m1);
migraphx::module m2;
{
auto data = m2.add_parameter("2x2", {migraphx::shape::float_type, {2, 3, 4, 5}});
auto add = m2.add_instruction(migraphx::make_op("add"), data, data);
auto ca = m2.add_instruction(migraphx::make_op("contiguous"), add);
auto r = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 1, 12, 5}}}), ca);
m2.add_return({r});
}
EXPECT(m1 == m2);
}
TEST_CASE(dead_instruction)
{
migraphx::module m1;
{
auto data = m1.add_parameter("2x2", {migraphx::shape::float_type, {2, 3, 4, 5}});
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1, 3}}}), data);
auto r = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1, 3}}}),
data);
m1.add_return({r});
}
run_pass(m1);
migraphx::module m2;
{
auto data = m2.add_parameter("2x2", {migraphx::shape::float_type, {2, 3, 4, 5}});
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1, 3}}}), data);
auto r = m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1, 3}}}),
data);
auto cr = m2.add_instruction(migraphx::make_op("contiguous"), r);
m2.add_return({cr});
}
EXPECT(m1 == m2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -6,6 +6,12 @@ ...@@ -6,6 +6,12 @@
#include <test.hpp> #include <test.hpp>
void run_pass(migraphx::program& p)
{
migraphx::run_passes(
p, {migraphx::eliminate_common_subexpression{}, migraphx::dead_code_elimination{}});
}
void run_pass(migraphx::module& m) void run_pass(migraphx::module& m)
{ {
migraphx::run_passes( migraphx::run_passes(
...@@ -142,4 +148,59 @@ TEST_CASE(cse_test_literal) ...@@ -142,4 +148,59 @@ TEST_CASE(cse_test_literal)
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
TEST_CASE(cse_test_submodule)
{
migraphx::shape si{migraphx::shape::int64_type};
migraphx::shape s{migraphx::shape::int64_type, {1}};
migraphx::shape sc{migraphx::shape::bool_type};
auto create_program = [&](bool remove_literal = false) {
migraphx::program p;
std::vector<bool> vc = {true};
std::vector<int64_t> vd = {3};
auto* mm = p.get_main_module();
auto in_cond = mm->add_parameter("ccond", sc);
auto in_val = mm->add_parameter("val", s);
auto b0 = mm->add_literal(migraphx::literal(sc, vc));
auto b1 = b0;
if(not(remove_literal))
b1 = mm->add_literal(migraphx::literal(sc, vc));
auto* body1 = p.create_module("loop_module1");
body1->add_parameter("#loop_module_in_1", sc);
auto in_v1 = body1->add_parameter("#loop_module_in_2", s);
auto l1 = body1->add_literal(migraphx::literal(si, vd));
auto ad1 = body1->add_instruction(migraphx::make_op("add"), l1, l1);
auto val1 = body1->add_instruction(migraphx::make_op("add"), in_v1, ad1);
auto cond1 = body1->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::bool_type}}), b0);
auto cond2 = body1->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::bool_type}}), b1);
body1->add_return({cond1, cond2, val1, val1});
auto* body2 = p.create_module("loop_module2");
body2->add_parameter("#loop_module_in_1", sc);
auto in_v2 = body2->add_parameter("#loop_module_in_2", s);
auto l2 = body2->add_literal(migraphx::literal(si, vd));
auto ad2 = body2->add_instruction(migraphx::make_op("add"), l2, l2);
auto val2 = body2->add_instruction(migraphx::make_op("add"), in_v2, ad2);
auto cond3 = body2->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::bool_type}}), b1);
body2->add_return({cond3, val2, val2});
auto loop1 = mm->add_instruction(
migraphx::make_op("loop", {{"max_iterations", 1}}), {in_cond, in_val}, {body1});
auto loop2 = mm->add_instruction(
migraphx::make_op("loop", {{"max_iterations", 1}}), {in_cond, in_val}, {body2});
mm->add_return({loop1, loop2});
return p;
};
auto p = create_program();
run_pass(p);
EXPECT(p == create_program(true));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -1513,6 +1513,55 @@ TEST_CASE(test_unsqueeze_scalar_tensor2) ...@@ -1513,6 +1513,55 @@ TEST_CASE(test_unsqueeze_scalar_tensor2)
throws_shape(migraphx::make_op("unsqueeze", {{"axes", {-2}}}), s); throws_shape(migraphx::make_op("unsqueeze", {{"axes", {-2}}}), s);
} }
TEST_CASE(test_unsqueeze_transpose)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 4, 3}, {12, 1, 4}};
migraphx::shape s2{migraphx::shape::float_type, {4, 4, 1, 3}, {12, 1, 1, 4}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1);
}
TEST_CASE(test_unsqueeze_multibroadcast)
{
migraphx::shape s1{migraphx::shape::float_type, {2, 3, 4}, {0, 1, 0}};
migraphx::shape s2{migraphx::shape::float_type, {2, 3, 1, 4}, {0, 1, 1, 0}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1);
}
TEST_CASE(test_unsqueeze_slice)
{
migraphx::shape s1{migraphx::shape::float_type, {2, 3, 4}, {108, 36, 1}};
migraphx::shape s2{migraphx::shape::float_type, {2, 3, 1, 4}, {108, 36, 36, 1}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1);
}
TEST_CASE(test_unsqueeze_axis_zero)
{
migraphx::shape s1{migraphx::shape::float_type, {2, 3, 4}};
migraphx::shape s2{migraphx::shape::float_type, {1, 2, 3, 4}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {0}}}), s1);
}
TEST_CASE(test_unsqueeze_axis_last)
{
migraphx::shape s1{migraphx::shape::float_type, {2, 3, 4}};
migraphx::shape s2{migraphx::shape::float_type, {2, 3, 4, 1}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {-1}}}), s1);
}
TEST_CASE(test_unsqueeze_multiple_axes_1)
{
migraphx::shape s1{migraphx::shape::float_type, {2, 3, 4}};
migraphx::shape s2{migraphx::shape::float_type, {1, 2, 3, 4, 1}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {0, -1}}}), s1);
}
TEST_CASE(test_unsqueeze_multiple_axes_2)
{
migraphx::shape s1{migraphx::shape::float_type, {2, 3, 4}};
migraphx::shape s2{migraphx::shape::float_type, {1, 1, 2, 3, 4}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {0, 1}}}), s1);
}
TEST_CASE(transpose_shape) TEST_CASE(transpose_shape)
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 2}}; migraphx::shape input{migraphx::shape::float_type, {2, 2}};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment