"vscode:/vscode.git/clone" did not exist on "bc37ea69d5541debb89766c76ad3f38db88a5e5f"
Commit bdf91961 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'print_matmul_perf_flops' of...

Merge branch 'print_matmul_perf_flops' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into branch_for_ort2
parents 65ef1423 322283db
tensorflow==2.5.2 tensorflow==2.5.3
onnxruntime onnxruntime
tokenizers tokenizers
\ No newline at end of file
# Modifications Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved
# Copyright 2018 The Google AI Language Team Authors. # Copyright 2018 The Google AI Language Team Authors.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
......
...@@ -385,6 +385,13 @@ extern "C" migraphx_status migraphx_shape_destroy(migraphx_shape_t shape) ...@@ -385,6 +385,13 @@ extern "C" migraphx_status migraphx_shape_destroy(migraphx_shape_t shape)
return api_error_result; return api_error_result;
} }
/// Copy-assign the shape referenced by `input` onto the existing shape
/// referenced by `output`.
/// The assignment is executed through migraphx::try_, and the status produced
/// by that call is returned to the caller.
/// NOTE(review): neither handle is checked for null before it is dereferenced.
extern "C" migraphx_status migraphx_shape_assign_to(migraphx_shape_t output,
                                                    const_migraphx_shape_t input)
{
    return migraphx::try_([output, input] { *output = *input; });
}
extern "C" migraphx_status migraphx_shape_create(migraphx_shape_t* shape, extern "C" migraphx_status migraphx_shape_create(migraphx_shape_t* shape,
migraphx_shape_datatype_t type, migraphx_shape_datatype_t type,
size_t* lengths, size_t* lengths,
...@@ -502,6 +509,13 @@ extern "C" migraphx_status migraphx_argument_destroy(migraphx_argument_t argumen ...@@ -502,6 +509,13 @@ extern "C" migraphx_status migraphx_argument_destroy(migraphx_argument_t argumen
return api_error_result; return api_error_result;
} }
/// Copy-assign the argument referenced by `input` onto the existing argument
/// referenced by `output`.
/// The assignment is executed through migraphx::try_, and the status produced
/// by that call is returned to the caller.
/// NOTE(review): neither handle is checked for null before it is dereferenced.
extern "C" migraphx_status migraphx_argument_assign_to(migraphx_argument_t output,
                                                       const_migraphx_argument_t input)
{
    return migraphx::try_([output, input] { *output = *input; });
}
extern "C" migraphx_status extern "C" migraphx_status
migraphx_argument_create(migraphx_argument_t* argument, const_migraphx_shape_t shape, void* buffer) migraphx_argument_create(migraphx_argument_t* argument, const_migraphx_shape_t shape, void* buffer)
{ {
...@@ -565,6 +579,13 @@ extern "C" migraphx_status migraphx_target_destroy(migraphx_target_t target) ...@@ -565,6 +579,13 @@ extern "C" migraphx_status migraphx_target_destroy(migraphx_target_t target)
return api_error_result; return api_error_result;
} }
/// Copy-assign the target referenced by `input` onto the existing target
/// referenced by `output`.
/// The assignment is executed through migraphx::try_, and the status produced
/// by that call is returned to the caller.
/// NOTE(review): neither handle is checked for null before it is dereferenced.
extern "C" migraphx_status migraphx_target_assign_to(migraphx_target_t output,
                                                     const_migraphx_target_t input)
{
    return migraphx::try_([output, input] { *output = *input; });
}
extern "C" migraphx_status migraphx_target_create(migraphx_target_t* target, const char* name) extern "C" migraphx_status migraphx_target_create(migraphx_target_t* target, const char* name)
{ {
auto api_error_result = migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
...@@ -581,6 +602,14 @@ extern "C" migraphx_status migraphx_program_parameter_shapes_destroy( ...@@ -581,6 +602,14 @@ extern "C" migraphx_status migraphx_program_parameter_shapes_destroy(
return api_error_result; return api_error_result;
} }
/// Copy-assign the program-parameter-shapes object referenced by `input` onto
/// the existing object referenced by `output`.
/// The assignment is executed through migraphx::try_, and the status produced
/// by that call is returned to the caller.
/// NOTE(review): neither handle is checked for null before it is dereferenced.
extern "C" migraphx_status
migraphx_program_parameter_shapes_assign_to(migraphx_program_parameter_shapes_t output,
                                            const_migraphx_program_parameter_shapes_t input)
{
    return migraphx::try_([output, input] { *output = *input; });
}
extern "C" migraphx_status extern "C" migraphx_status
migraphx_program_parameter_shapes_size(size_t* out, migraphx_program_parameter_shapes_size(size_t* out,
migraphx_program_parameter_shapes_t program_parameter_shapes) migraphx_program_parameter_shapes_t program_parameter_shapes)
...@@ -631,6 +660,14 @@ migraphx_program_parameters_destroy(migraphx_program_parameters_t program_parame ...@@ -631,6 +660,14 @@ migraphx_program_parameters_destroy(migraphx_program_parameters_t program_parame
return api_error_result; return api_error_result;
} }
/// Copy-assign the program-parameters object referenced by `input` onto the
/// existing object referenced by `output`.
/// The assignment is executed through migraphx::try_, and the status produced
/// by that call is returned to the caller.
/// NOTE(review): neither handle is checked for null before it is dereferenced.
extern "C" migraphx_status
migraphx_program_parameters_assign_to(migraphx_program_parameters_t output,
                                      const_migraphx_program_parameters_t input)
{
    return migraphx::try_([output, input] { *output = *input; });
}
extern "C" migraphx_status extern "C" migraphx_status
migraphx_program_parameters_create(migraphx_program_parameters_t* program_parameters) migraphx_program_parameters_create(migraphx_program_parameters_t* program_parameters)
{ {
...@@ -663,6 +700,13 @@ extern "C" migraphx_status migraphx_arguments_destroy(migraphx_arguments_t argum ...@@ -663,6 +700,13 @@ extern "C" migraphx_status migraphx_arguments_destroy(migraphx_arguments_t argum
return api_error_result; return api_error_result;
} }
/// Copy-assign the arguments collection referenced by `input` onto the
/// existing collection referenced by `output`.
/// The assignment is executed through migraphx::try_, and the status produced
/// by that call is returned to the caller.
/// NOTE(review): neither handle is checked for null before it is dereferenced.
extern "C" migraphx_status migraphx_arguments_assign_to(migraphx_arguments_t output,
                                                        const_migraphx_arguments_t input)
{
    return migraphx::try_([output, input] { *output = *input; });
}
extern "C" migraphx_status migraphx_arguments_size(size_t* out, migraphx_arguments_t arguments) extern "C" migraphx_status migraphx_arguments_size(size_t* out, migraphx_arguments_t arguments)
{ {
auto api_error_result = migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
...@@ -690,6 +734,13 @@ extern "C" migraphx_status migraphx_shapes_destroy(migraphx_shapes_t shapes) ...@@ -690,6 +734,13 @@ extern "C" migraphx_status migraphx_shapes_destroy(migraphx_shapes_t shapes)
return api_error_result; return api_error_result;
} }
/// Copy-assign the shapes collection referenced by `input` onto the existing
/// collection referenced by `output`.
/// The assignment is executed through migraphx::try_, and the status produced
/// by that call is returned to the caller.
/// NOTE(review): neither handle is checked for null before it is dereferenced.
extern "C" migraphx_status migraphx_shapes_assign_to(migraphx_shapes_t output,
                                                     const_migraphx_shapes_t input)
{
    return migraphx::try_([output, input] { *output = *input; });
}
extern "C" migraphx_status migraphx_shapes_size(size_t* out, migraphx_shapes_t shapes) extern "C" migraphx_status migraphx_shapes_size(size_t* out, migraphx_shapes_t shapes)
{ {
auto api_error_result = migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
...@@ -727,6 +778,13 @@ extern "C" migraphx_status migraphx_program_destroy(migraphx_program_t program) ...@@ -727,6 +778,13 @@ extern "C" migraphx_status migraphx_program_destroy(migraphx_program_t program)
return api_error_result; return api_error_result;
} }
/// Copy-assign the program referenced by `input` onto the existing program
/// referenced by `output`.
/// The assignment is executed through migraphx::try_, and the status produced
/// by that call is returned to the caller.
/// NOTE(review): neither handle is checked for null before it is dereferenced.
extern "C" migraphx_status migraphx_program_assign_to(migraphx_program_t output,
                                                      const_migraphx_program_t input)
{
    return migraphx::try_([output, input] { *output = *input; });
}
extern "C" migraphx_status migraphx_program_get_main_module(migraphx_module_t* out, extern "C" migraphx_status migraphx_program_get_main_module(migraphx_module_t* out,
migraphx_program_t program) migraphx_program_t program)
{ {
...@@ -831,6 +889,13 @@ extern "C" migraphx_status migraphx_operation_destroy(migraphx_operation_t opera ...@@ -831,6 +889,13 @@ extern "C" migraphx_status migraphx_operation_destroy(migraphx_operation_t opera
return api_error_result; return api_error_result;
} }
/// Copy-assign the operation referenced by `input` onto the existing operation
/// referenced by `output`.
/// The assignment is executed through migraphx::try_, and the status produced
/// by that call is returned to the caller.
/// NOTE(review): neither handle is checked for null before it is dereferenced.
extern "C" migraphx_status migraphx_operation_assign_to(migraphx_operation_t output,
                                                        const_migraphx_operation_t input)
{
    return migraphx::try_([output, input] { *output = *input; });
}
extern "C" migraphx_status migraphx_operation_create(migraphx_operation_t* operation, extern "C" migraphx_status migraphx_operation_create(migraphx_operation_t* operation,
const char* name, const char* name,
const char* attributes, const char* attributes,
...@@ -891,6 +956,13 @@ extern "C" migraphx_status migraphx_onnx_options_destroy(migraphx_onnx_options_t ...@@ -891,6 +956,13 @@ extern "C" migraphx_status migraphx_onnx_options_destroy(migraphx_onnx_options_t
return api_error_result; return api_error_result;
} }
/// Copy-assign the ONNX options referenced by `input` onto the existing
/// options object referenced by `output`.
/// The assignment is executed through migraphx::try_, and the status produced
/// by that call is returned to the caller.
/// NOTE(review): neither handle is checked for null before it is dereferenced.
extern "C" migraphx_status migraphx_onnx_options_assign_to(migraphx_onnx_options_t output,
                                                           const_migraphx_onnx_options_t input)
{
    return migraphx::try_([output, input] { *output = *input; });
}
extern "C" migraphx_status migraphx_onnx_options_create(migraphx_onnx_options_t* onnx_options) extern "C" migraphx_status migraphx_onnx_options_create(migraphx_onnx_options_t* onnx_options)
{ {
auto api_error_result = migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
...@@ -942,6 +1014,13 @@ extern "C" migraphx_status migraphx_file_options_destroy(migraphx_file_options_t ...@@ -942,6 +1014,13 @@ extern "C" migraphx_status migraphx_file_options_destroy(migraphx_file_options_t
return api_error_result; return api_error_result;
} }
/// Copy-assign the file options referenced by `input` onto the existing
/// options object referenced by `output`.
/// The assignment is executed through migraphx::try_, and the status produced
/// by that call is returned to the caller.
/// NOTE(review): neither handle is checked for null before it is dereferenced.
extern "C" migraphx_status migraphx_file_options_assign_to(migraphx_file_options_t output,
                                                           const_migraphx_file_options_t input)
{
    return migraphx::try_([output, input] { *output = *input; });
}
extern "C" migraphx_status migraphx_file_options_create(migraphx_file_options_t* file_options) extern "C" migraphx_status migraphx_file_options_create(migraphx_file_options_t* file_options)
{ {
auto api_error_result = migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
...@@ -968,6 +1047,14 @@ migraphx_compile_options_destroy(migraphx_compile_options_t compile_options) ...@@ -968,6 +1047,14 @@ migraphx_compile_options_destroy(migraphx_compile_options_t compile_options)
return api_error_result; return api_error_result;
} }
/// Copy-assign the compile options referenced by `input` onto the existing
/// options object referenced by `output`.
/// The assignment is executed through migraphx::try_, and the status produced
/// by that call is returned to the caller.
/// NOTE(review): neither handle is checked for null before it is dereferenced.
extern "C" migraphx_status
migraphx_compile_options_assign_to(migraphx_compile_options_t output,
                                   const_migraphx_compile_options_t input)
{
    return migraphx::try_([output, input] { *output = *input; });
}
extern "C" migraphx_status extern "C" migraphx_status
migraphx_compile_options_create(migraphx_compile_options_t* compile_options) migraphx_compile_options_create(migraphx_compile_options_t* compile_options)
{ {
...@@ -1033,6 +1120,13 @@ extern "C" migraphx_status migraphx_tf_options_destroy(migraphx_tf_options_t tf_ ...@@ -1033,6 +1120,13 @@ extern "C" migraphx_status migraphx_tf_options_destroy(migraphx_tf_options_t tf_
return api_error_result; return api_error_result;
} }
/// Copy-assign the TensorFlow options referenced by `input` onto the existing
/// options object referenced by `output`.
/// The assignment is executed through migraphx::try_, and the status produced
/// by that call is returned to the caller.
/// NOTE(review): neither handle is checked for null before it is dereferenced.
extern "C" migraphx_status migraphx_tf_options_assign_to(migraphx_tf_options_t output,
                                                         const_migraphx_tf_options_t input)
{
    return migraphx::try_([output, input] { *output = *input; });
}
extern "C" migraphx_status migraphx_tf_options_create(migraphx_tf_options_t* tf_options) extern "C" migraphx_status migraphx_tf_options_create(migraphx_tf_options_t* tf_options)
{ {
auto api_error_result = migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
...@@ -1110,6 +1204,14 @@ migraphx_quantize_op_names_destroy(migraphx_quantize_op_names_t quantize_op_name ...@@ -1110,6 +1204,14 @@ migraphx_quantize_op_names_destroy(migraphx_quantize_op_names_t quantize_op_name
return api_error_result; return api_error_result;
} }
/// Copy-assign the quantize-op-names object referenced by `input` onto the
/// existing object referenced by `output`.
/// The assignment is executed through migraphx::try_, and the status produced
/// by that call is returned to the caller.
/// NOTE(review): neither handle is checked for null before it is dereferenced.
extern "C" migraphx_status
migraphx_quantize_op_names_assign_to(migraphx_quantize_op_names_t output,
                                     const_migraphx_quantize_op_names_t input)
{
    return migraphx::try_([output, input] { *output = *input; });
}
extern "C" migraphx_status extern "C" migraphx_status
migraphx_quantize_op_names_create(migraphx_quantize_op_names_t* quantize_op_names) migraphx_quantize_op_names_create(migraphx_quantize_op_names_t* quantize_op_names)
{ {
...@@ -1162,6 +1264,14 @@ migraphx_quantize_int8_options_destroy(migraphx_quantize_int8_options_t quantize ...@@ -1162,6 +1264,14 @@ migraphx_quantize_int8_options_destroy(migraphx_quantize_int8_options_t quantize
return api_error_result; return api_error_result;
} }
/// Copy-assign the int8 quantization options referenced by `input` onto the
/// existing options object referenced by `output`.
/// The assignment is executed through migraphx::try_, and the status produced
/// by that call is returned to the caller.
/// NOTE(review): neither handle is checked for null before it is dereferenced.
extern "C" migraphx_status
migraphx_quantize_int8_options_assign_to(migraphx_quantize_int8_options_t output,
                                         const_migraphx_quantize_int8_options_t input)
{
    return migraphx::try_([output, input] { *output = *input; });
}
extern "C" migraphx_status extern "C" migraphx_status
migraphx_quantize_int8_options_create(migraphx_quantize_int8_options_t* quantize_int8_options) migraphx_quantize_int8_options_create(migraphx_quantize_int8_options_t* quantize_int8_options)
{ {
......
...@@ -91,6 +91,8 @@ typedef const struct migraphx_quantize_int8_options* const_migraphx_quantize_int ...@@ -91,6 +91,8 @@ typedef const struct migraphx_quantize_int8_options* const_migraphx_quantize_int
migraphx_status migraphx_shape_destroy(migraphx_shape_t shape); migraphx_status migraphx_shape_destroy(migraphx_shape_t shape);
/// Copy the contents of the shape referenced by `input` into the existing
/// shape referenced by `output`; returns a migraphx_status code.
migraphx_status migraphx_shape_assign_to(migraphx_shape_t output, const_migraphx_shape_t input);
migraphx_status migraphx_shape_create(migraphx_shape_t* shape, migraphx_status migraphx_shape_create(migraphx_shape_t* shape,
migraphx_shape_datatype_t type, migraphx_shape_datatype_t type,
size_t* lengths, size_t* lengths,
...@@ -121,6 +123,9 @@ migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_sha ...@@ -121,6 +123,9 @@ migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_sha
migraphx_status migraphx_argument_destroy(migraphx_argument_t argument); migraphx_status migraphx_argument_destroy(migraphx_argument_t argument);
/// Copy the contents of the argument referenced by `input` into the existing
/// argument referenced by `output`; returns a migraphx_status code.
migraphx_status migraphx_argument_assign_to(migraphx_argument_t output,
const_migraphx_argument_t input);
migraphx_status migraphx_status
migraphx_argument_create(migraphx_argument_t* argument, const_migraphx_shape_t shape, void* buffer); migraphx_argument_create(migraphx_argument_t* argument, const_migraphx_shape_t shape, void* buffer);
...@@ -137,11 +142,17 @@ migraphx_argument_generate(migraphx_argument_t* out, const_migraphx_shape_t s, s ...@@ -137,11 +142,17 @@ migraphx_argument_generate(migraphx_argument_t* out, const_migraphx_shape_t s, s
migraphx_status migraphx_target_destroy(migraphx_target_t target); migraphx_status migraphx_target_destroy(migraphx_target_t target);
/// Copy the contents of the target referenced by `input` into the existing
/// target referenced by `output`; returns a migraphx_status code.
migraphx_status migraphx_target_assign_to(migraphx_target_t output, const_migraphx_target_t input);
migraphx_status migraphx_target_create(migraphx_target_t* target, const char* name); migraphx_status migraphx_target_create(migraphx_target_t* target, const char* name);
migraphx_status migraphx_program_parameter_shapes_destroy( migraphx_status migraphx_program_parameter_shapes_destroy(
migraphx_program_parameter_shapes_t program_parameter_shapes); migraphx_program_parameter_shapes_t program_parameter_shapes);
/// Copy the contents of the program-parameter-shapes object referenced by
/// `input` into the existing object referenced by `output`; returns a
/// migraphx_status code.
migraphx_status
migraphx_program_parameter_shapes_assign_to(migraphx_program_parameter_shapes_t output,
const_migraphx_program_parameter_shapes_t input);
migraphx_status migraphx_program_parameter_shapes_size( migraphx_status migraphx_program_parameter_shapes_size(
size_t* out, migraphx_program_parameter_shapes_t program_parameter_shapes); size_t* out, migraphx_program_parameter_shapes_t program_parameter_shapes);
...@@ -156,6 +167,9 @@ migraphx_status migraphx_program_parameter_shapes_names( ...@@ -156,6 +167,9 @@ migraphx_status migraphx_program_parameter_shapes_names(
migraphx_status migraphx_status
migraphx_program_parameters_destroy(migraphx_program_parameters_t program_parameters); migraphx_program_parameters_destroy(migraphx_program_parameters_t program_parameters);
/// Copy the contents of the program-parameters object referenced by `input`
/// into the existing object referenced by `output`; returns a migraphx_status
/// code.
migraphx_status migraphx_program_parameters_assign_to(migraphx_program_parameters_t output,
const_migraphx_program_parameters_t input);
migraphx_status migraphx_status
migraphx_program_parameters_create(migraphx_program_parameters_t* program_parameters); migraphx_program_parameters_create(migraphx_program_parameters_t* program_parameters);
...@@ -165,6 +179,9 @@ migraphx_status migraphx_program_parameters_add(migraphx_program_parameters_t pr ...@@ -165,6 +179,9 @@ migraphx_status migraphx_program_parameters_add(migraphx_program_parameters_t pr
migraphx_status migraphx_arguments_destroy(migraphx_arguments_t arguments); migraphx_status migraphx_arguments_destroy(migraphx_arguments_t arguments);
/// Copy the contents of the arguments collection referenced by `input` into
/// the existing collection referenced by `output`; returns a migraphx_status
/// code.
migraphx_status migraphx_arguments_assign_to(migraphx_arguments_t output,
const_migraphx_arguments_t input);
migraphx_status migraphx_arguments_size(size_t* out, migraphx_arguments_t arguments); migraphx_status migraphx_arguments_size(size_t* out, migraphx_arguments_t arguments);
migraphx_status migraphx_status
...@@ -172,6 +189,8 @@ migraphx_arguments_get(const_migraphx_argument_t* out, migraphx_arguments_t argu ...@@ -172,6 +189,8 @@ migraphx_arguments_get(const_migraphx_argument_t* out, migraphx_arguments_t argu
migraphx_status migraphx_shapes_destroy(migraphx_shapes_t shapes); migraphx_status migraphx_shapes_destroy(migraphx_shapes_t shapes);
/// Copy the contents of the shapes collection referenced by `input` into the
/// existing collection referenced by `output`; returns a migraphx_status code.
migraphx_status migraphx_shapes_assign_to(migraphx_shapes_t output, const_migraphx_shapes_t input);
migraphx_status migraphx_shapes_size(size_t* out, migraphx_shapes_t shapes); migraphx_status migraphx_shapes_size(size_t* out, migraphx_shapes_t shapes);
migraphx_status migraphx_status
...@@ -181,6 +200,9 @@ migraphx_status migraphx_module_print(const_migraphx_module_t module); ...@@ -181,6 +200,9 @@ migraphx_status migraphx_module_print(const_migraphx_module_t module);
migraphx_status migraphx_program_destroy(migraphx_program_t program); migraphx_status migraphx_program_destroy(migraphx_program_t program);
/// Copy the contents of the program referenced by `input` into the existing
/// program referenced by `output`; returns a migraphx_status code.
migraphx_status migraphx_program_assign_to(migraphx_program_t output,
const_migraphx_program_t input);
migraphx_status migraphx_program_get_main_module(migraphx_module_t* out, migraphx_status migraphx_program_get_main_module(migraphx_module_t* out,
migraphx_program_t program); migraphx_program_t program);
...@@ -207,6 +229,9 @@ migraphx_program_equal(bool* out, const_migraphx_program_t program, const_migrap ...@@ -207,6 +229,9 @@ migraphx_program_equal(bool* out, const_migraphx_program_t program, const_migrap
migraphx_status migraphx_operation_destroy(migraphx_operation_t operation); migraphx_status migraphx_operation_destroy(migraphx_operation_t operation);
/// Copy the contents of the operation referenced by `input` into the existing
/// operation referenced by `output`; returns a migraphx_status code.
migraphx_status migraphx_operation_assign_to(migraphx_operation_t output,
const_migraphx_operation_t input);
migraphx_status migraphx_operation_create(migraphx_operation_t* operation, migraphx_status migraphx_operation_create(migraphx_operation_t* operation,
const char* name, const char* name,
const char* attributes, const char* attributes,
...@@ -222,6 +247,9 @@ migraphx_save(migraphx_program_t p, const char* name, migraphx_file_options_t op ...@@ -222,6 +247,9 @@ migraphx_save(migraphx_program_t p, const char* name, migraphx_file_options_t op
migraphx_status migraphx_onnx_options_destroy(migraphx_onnx_options_t onnx_options); migraphx_status migraphx_onnx_options_destroy(migraphx_onnx_options_t onnx_options);
/// Copy the contents of the ONNX options referenced by `input` into the
/// existing options object referenced by `output`; returns a migraphx_status
/// code.
migraphx_status migraphx_onnx_options_assign_to(migraphx_onnx_options_t output,
const_migraphx_onnx_options_t input);
migraphx_status migraphx_onnx_options_create(migraphx_onnx_options_t* onnx_options); migraphx_status migraphx_onnx_options_create(migraphx_onnx_options_t* onnx_options);
migraphx_status migraphx_onnx_options_set_input_parameter_shape( migraphx_status migraphx_onnx_options_set_input_parameter_shape(
...@@ -236,6 +264,9 @@ migraphx_onnx_options_set_default_loop_iterations(migraphx_onnx_options_t onnx_o ...@@ -236,6 +264,9 @@ migraphx_onnx_options_set_default_loop_iterations(migraphx_onnx_options_t onnx_o
migraphx_status migraphx_file_options_destroy(migraphx_file_options_t file_options); migraphx_status migraphx_file_options_destroy(migraphx_file_options_t file_options);
/// Copy the contents of the file options referenced by `input` into the
/// existing options object referenced by `output`; returns a migraphx_status
/// code.
migraphx_status migraphx_file_options_assign_to(migraphx_file_options_t output,
const_migraphx_file_options_t input);
migraphx_status migraphx_file_options_create(migraphx_file_options_t* file_options); migraphx_status migraphx_file_options_create(migraphx_file_options_t* file_options);
migraphx_status migraphx_file_options_set_file_format(migraphx_file_options_t file_options, migraphx_status migraphx_file_options_set_file_format(migraphx_file_options_t file_options,
...@@ -243,6 +274,9 @@ migraphx_status migraphx_file_options_set_file_format(migraphx_file_options_t fi ...@@ -243,6 +274,9 @@ migraphx_status migraphx_file_options_set_file_format(migraphx_file_options_t fi
migraphx_status migraphx_compile_options_destroy(migraphx_compile_options_t compile_options); migraphx_status migraphx_compile_options_destroy(migraphx_compile_options_t compile_options);
/// Copy the contents of the compile options referenced by `input` into the
/// existing options object referenced by `output`; returns a migraphx_status
/// code.
migraphx_status migraphx_compile_options_assign_to(migraphx_compile_options_t output,
const_migraphx_compile_options_t input);
migraphx_status migraphx_compile_options_create(migraphx_compile_options_t* compile_options); migraphx_status migraphx_compile_options_create(migraphx_compile_options_t* compile_options);
migraphx_status migraphx_status
...@@ -261,6 +295,9 @@ migraphx_status migraphx_parse_onnx_buffer(migraphx_program_t* out, ...@@ -261,6 +295,9 @@ migraphx_status migraphx_parse_onnx_buffer(migraphx_program_t* out,
migraphx_status migraphx_tf_options_destroy(migraphx_tf_options_t tf_options); migraphx_status migraphx_tf_options_destroy(migraphx_tf_options_t tf_options);
/// Copy the contents of the TensorFlow options referenced by `input` into the
/// existing options object referenced by `output`; returns a migraphx_status
/// code.
migraphx_status migraphx_tf_options_assign_to(migraphx_tf_options_t output,
const_migraphx_tf_options_t input);
migraphx_status migraphx_tf_options_create(migraphx_tf_options_t* tf_options); migraphx_status migraphx_tf_options_create(migraphx_tf_options_t* tf_options);
migraphx_status migraphx_tf_options_set_nhwc(migraphx_tf_options_t tf_options, bool is_nhwc); migraphx_status migraphx_tf_options_set_nhwc(migraphx_tf_options_t tf_options, bool is_nhwc);
...@@ -282,6 +319,9 @@ migraphx_parse_tf(migraphx_program_t* out, const char* name, migraphx_tf_options ...@@ -282,6 +319,9 @@ migraphx_parse_tf(migraphx_program_t* out, const char* name, migraphx_tf_options
migraphx_status migraphx_quantize_op_names_destroy(migraphx_quantize_op_names_t quantize_op_names); migraphx_status migraphx_quantize_op_names_destroy(migraphx_quantize_op_names_t quantize_op_names);
/// Copy the contents of the quantize-op-names object referenced by `input`
/// into the existing object referenced by `output`; returns a migraphx_status
/// code.
migraphx_status migraphx_quantize_op_names_assign_to(migraphx_quantize_op_names_t output,
const_migraphx_quantize_op_names_t input);
migraphx_status migraphx_quantize_op_names_create(migraphx_quantize_op_names_t* quantize_op_names); migraphx_status migraphx_quantize_op_names_create(migraphx_quantize_op_names_t* quantize_op_names);
migraphx_status migraphx_quantize_op_names_add(migraphx_quantize_op_names_t quantize_op_names, migraphx_status migraphx_quantize_op_names_add(migraphx_quantize_op_names_t quantize_op_names,
...@@ -295,6 +335,10 @@ migraphx_status migraphx_quantize_fp16(migraphx_program_t prog); ...@@ -295,6 +335,10 @@ migraphx_status migraphx_quantize_fp16(migraphx_program_t prog);
migraphx_status migraphx_status
migraphx_quantize_int8_options_destroy(migraphx_quantize_int8_options_t quantize_int8_options); migraphx_quantize_int8_options_destroy(migraphx_quantize_int8_options_t quantize_int8_options);
/// Copy the contents of the int8 quantization options referenced by `input`
/// into the existing options object referenced by `output`; returns a
/// migraphx_status code.
migraphx_status
migraphx_quantize_int8_options_assign_to(migraphx_quantize_int8_options_t output,
const_migraphx_quantize_int8_options_t input);
migraphx_status migraphx_status
migraphx_quantize_int8_options_create(migraphx_quantize_int8_options_t* quantize_int8_options); migraphx_quantize_int8_options_create(migraphx_quantize_int8_options_t* quantize_int8_options);
......
...@@ -159,7 +159,7 @@ struct borrow ...@@ -159,7 +159,7 @@ struct borrow
{ {
}; };
template <class T, class D, D Deleter> template <class T, class D, D Deleter, class A, A Assigner>
struct handle_base struct handle_base
{ {
handle_base() : m_handle(nullptr) {} handle_base() : m_handle(nullptr) {}
...@@ -190,6 +190,12 @@ struct handle_base ...@@ -190,6 +190,12 @@ struct handle_base
m_handle = std::shared_ptr<U>{ptr, [](U*) {}}; m_handle = std::shared_ptr<U>{ptr, [](U*) {}};
} }
/// Assign the object held by this handle onto the raw handle `x`.
/// Invokes the `Assigner` template argument with `x` as its first argument and
/// this handle's own pointer as its second.
/// NOTE(review): any value returned by `Assigner` is discarded here, so a
/// failure reported through its return value would go unnoticed - confirm
/// that is intended.
template <class U>
void assign_to_handle(U* x)
{
    auto source = this->get_handle_ptr();
    Assigner(x, source);
}
protected: protected:
std::shared_ptr<T> m_handle; std::shared_ptr<T> m_handle;
}; };
...@@ -197,10 +203,12 @@ struct handle_base ...@@ -197,10 +203,12 @@ struct handle_base
#ifdef DOXYGEN #ifdef DOXYGEN
#define MIGRAPHX_DETAIL_HANDLE_BASE(name, const_) handle_base<> #define MIGRAPHX_DETAIL_HANDLE_BASE(name, const_) handle_base<>
#else #else
#define MIGRAPHX_DETAIL_HANDLE_BASE(name, const_) \ #define MIGRAPHX_DETAIL_HANDLE_BASE(name, const_) \
handle_base<const_ migraphx_##name, \ handle_base<const_ migraphx_##name, \
decltype(&migraphx_##name##_destroy), \ decltype(&migraphx_##name##_destroy), \
migraphx_##name##_destroy> migraphx_##name##_destroy, \
decltype(&migraphx_##name##_assign_to), \
migraphx_##name##_assign_to>
#endif #endif
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_HANDLE_BASE(name) MIGRAPHX_DETAIL_HANDLE_BASE(name, ) #define MIGRAPHX_HANDLE_BASE(name) MIGRAPHX_DETAIL_HANDLE_BASE(name, )
......
...@@ -10,8 +10,35 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -10,8 +10,35 @@ inline namespace MIGRAPHX_INLINE_NS {
void auto_contiguous::apply(module& p) const void auto_contiguous::apply(module& p) const
{ {
std::string key = "std_shape";
for(auto ins : reverse_iterator_for(p))
{
auto&& attr = ins->get_operator().attributes();
if((attr.contains(key) and attr.at(key).to<bool>()))
{
auto args = ins->inputs();
auto new_args = args;
std::transform(args.begin(), args.end(), new_args.begin(), [&](auto in) {
if(in->name() == "contiguous")
{
return in;
}
return p.insert_instruction(ins, make_op("contiguous"), in);
});
if(new_args != args)
{
p.replace_instruction(ins, ins->get_operator(), new_args);
}
}
}
auto last = std::prev(p.end());
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
{ {
// for last instruction that is NOT a return
if(ins->outputs().empty() and ins != last)
continue;
shape s = ins->get_shape(); shape s = ins->get_shape();
if(not s.standard() and s.elements() != 0) if(not s.standard() and s.elements() != 0)
{ {
...@@ -19,6 +46,26 @@ void auto_contiguous::apply(module& p) const ...@@ -19,6 +46,26 @@ void auto_contiguous::apply(module& p) const
p.replace_instruction(ins, c); p.replace_instruction(ins, c);
} }
} }
// if ops used as output param are alias 0, add a contiguous for the output
// so return outputs with standard shape
if(last->name() == "@return")
{
auto inputs = last->inputs();
for(auto ins : inputs)
{
if(ins->name() == "contiguous")
continue;
auto ins_alias = ins->get_operator().output_alias({});
if(ins_alias == 0 and ins->get_shape().element_space() !=
ins->inputs().front()->get_shape().element_space())
{
auto cont_ins = p.insert_instruction(last, make_op("contiguous"), ins);
p.replace_instruction(ins, cont_ins);
}
}
}
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -32,7 +32,12 @@ void cse_range(module& p, Range&& r) ...@@ -32,7 +32,12 @@ void cse_range(module& p, Range&& r)
continue; continue;
p.replace_instruction(ins, eq); p.replace_instruction(ins, eq);
processed_ins.emplace(ins); processed_ins.emplace(ins);
auto outputs = eq->outputs(); std::vector<instruction_ref> outputs;
std::copy_if(eq->outputs().begin(),
eq->outputs().end(),
std::back_inserter(outputs),
[&](auto x) { return p.has_instruction(x); });
std::sort(outputs.begin(), outputs.end(), [&](auto x, auto y) { std::sort(outputs.begin(), outputs.end(), [&](auto x, auto y) {
return std::distance(eq, x) < std::distance(eq, y); return std::distance(eq, x) < std::distance(eq, y);
}); });
......
...@@ -78,15 +78,16 @@ void eliminate_contiguous::apply(module& p) const ...@@ -78,15 +78,16 @@ void eliminate_contiguous::apply(module& p) const
continue; continue;
// Make a copy so we can modify it while we iterate // Make a copy so we can modify it while we iterate
auto args = ins->inputs(); auto args = ins->inputs();
auto new_args = args;
auto mod_args = ins->module_inputs();
for(auto arg : ins->inputs()) for(auto arg : ins->inputs())
{ {
if(arg->name() == op_name) if(arg->name() == op_name)
{ {
auto new_args = args; auto prev = arg->inputs().front();
auto prev = arg->inputs().front();
replace(new_args, arg, prev); replace(new_args, arg, prev);
if(try_compute_shape(ins, new_args, ins->module_inputs())) if(try_compute_shape(ins, new_args, mod_args))
{ {
instruction::replace_argument(ins, arg, prev); instruction::replace_argument(ins, arg, prev);
} }
......
...@@ -66,7 +66,7 @@ struct reduce_op : op_name<Derived> ...@@ -66,7 +66,7 @@ struct reduce_op : op_name<Derived>
{ {
value normalize; value normalize;
normalize["axes"] = value::array{normalize_attribute::include_min}; normalize["axes"] = value::array{normalize_attribute::include_min};
return {{"normalize_axes", normalize}}; return {{"normalize_axes", normalize}, {"std_shape", true}};
} }
std::vector<int64_t> tune_axes(std::size_t n_dim) const std::vector<int64_t> tune_axes(std::size_t n_dim) const
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp> #include <migraphx/lifetime.hpp>
#include <migraphx/value.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -26,6 +27,8 @@ struct reshape ...@@ -26,6 +27,8 @@ struct reshape
return pack(f(self.dims, "dims")); return pack(f(self.dims, "dims"));
} }
// Operator attribute map: sets the "std_shape" key to true.
// The auto_contiguous pass looks this key up and, when it is true, inserts a
// "contiguous" op in front of each input that is not already one.
value attributes() const
{
    value attrs = {{"std_shape", true}};
    return attrs;
}
std::string name() const { return "reshape"; } std::string name() const { return "reshape"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
......
...@@ -41,39 +41,45 @@ struct squeeze ...@@ -41,39 +41,45 @@ struct squeeze
auto input_shape = inputs[0]; auto input_shape = inputs[0];
auto type = input_shape.type(); auto type = input_shape.type();
auto old_lens = input_shape.lens(); auto old_lens = input_shape.lens();
auto old_strides = input_shape.strides();
if(std::any_of(axes.begin(), axes.end(), [&](auto axis) { return old_lens[axis] != 1; })) if(std::any_of(axes.begin(), axes.end(), [&](auto axis) { return old_lens[axis] != 1; }))
{ {
MIGRAPHX_THROW("squeeze axis dimension should be equal to 1"); MIGRAPHX_THROW("squeeze axis dimension should be equal to 1");
} }
std::vector<std::size_t> new_lens; std::vector<std::size_t> new_lens;
std::vector<std::size_t> new_strides;
if(axes.empty()) if(axes.empty())
{ {
std::copy_if(old_lens.begin(), for(auto i : range(old_lens.size()))
old_lens.end(), {
std::back_inserter(new_lens), if(old_lens[i] != 1)
[](auto len) { return len != 1; }); {
new_lens.push_back(old_lens[i]);
new_strides.push_back(old_strides[i]);
}
}
} }
else else
{ {
for(std::size_t i = 0; i < old_lens.size(); i++) for(auto i : range(old_lens.size()))
{ {
if(std::find(axes.begin(), axes.end(), i) == axes.end()) if(std::find(axes.begin(), axes.end(), i) == axes.end())
{ {
new_lens.push_back(old_lens[i]); new_lens.push_back(old_lens[i]);
new_strides.push_back(old_strides[i]);
} }
} }
} }
if(new_lens.empty()) if(new_lens.empty())
{ {
return shape{type}; return shape{type};
} }
else else
{ {
return shape{type, new_lens}; return shape{type, new_lens, new_strides};
} }
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
......
...@@ -37,11 +37,11 @@ struct unsqueeze ...@@ -37,11 +37,11 @@ struct unsqueeze
std::string name() const { return "unsqueeze"; } std::string name() const { return "unsqueeze"; }
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1).standard_or_scalar(); check_shapes{inputs, *this}.has(1);
auto input_shape = inputs[0]; auto input_shape = inputs[0];
auto type = input_shape.type(); auto type = input_shape.type();
auto old_lens = input_shape.lens(); auto old_lens = input_shape.lens();
auto old_strides = input_shape.strides();
if(input_shape.scalar()) if(input_shape.scalar())
{ {
if(old_lens.size() == 1 and old_lens.front() == 1) if(old_lens.size() == 1 and old_lens.front() == 1)
...@@ -53,19 +53,29 @@ struct unsqueeze ...@@ -53,19 +53,29 @@ struct unsqueeze
std::size_t new_size = old_lens.size() + axes.size(); std::size_t new_size = old_lens.size() + axes.size();
std::vector<std::size_t> new_lens(new_size); std::vector<std::size_t> new_lens(new_size);
std::vector<std::size_t> new_strides(new_size);
std::size_t p = 0; std::size_t p = 0;
for(std::size_t i = 0; i < new_size; i++) for(auto i : range(new_size))
{ {
if(std::find(axes.begin(), axes.end(), i) != axes.end()) if(std::find(axes.begin(), axes.end(), i) != axes.end())
{ {
new_lens[i] = 1; new_lens[i] = 1;
if(p == 0) // unsqueeze on the first axes
{
new_strides[i] = old_lens[0] * old_strides[0];
}
else // unsqueeze on middle or last axes
{
new_strides[i] = (p < old_strides.size()) ? old_strides[p - 1] : 1;
}
} }
else else
{ {
new_lens[i] = old_lens[p++]; new_lens[i] = old_lens[p];
new_strides[i] = old_strides[p++];
} }
} }
return shape{type, new_lens}; return shape{type, new_lens, new_strides};
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
......
...@@ -77,6 +77,9 @@ struct program ...@@ -77,6 +77,9 @@ struct program
void debug_print() const; void debug_print() const;
void debug_print(instruction_ref ins) const; void debug_print(instruction_ref ins) const;
void debug_print(std::ostream& os,
instruction_ref ins,
const std::unordered_map<instruction_ref, std::string>& names) const;
void print(std::unordered_map<instruction_ref, std::string>& names, void print(std::unordered_map<instruction_ref, std::string>& names,
const std::function<void(instruction_ref, const std::function<void(instruction_ref,
std::unordered_map<instruction_ref, std::string>)>& std::unordered_map<instruction_ref, std::string>)>&
...@@ -111,6 +114,7 @@ struct program ...@@ -111,6 +114,7 @@ struct program
private: private:
void assign(const program& p); void assign(const program& p);
int max_ins_length() const;
std::unique_ptr<program_impl> impl; std::unique_ptr<program_impl> impl;
}; };
......
...@@ -223,11 +223,10 @@ struct shape ...@@ -223,11 +223,10 @@ struct shape
static type_t parse_type(const std::string& s); static type_t parse_type(const std::string& s);
const std::vector<shape>& sub_shapes() const; const std::vector<shape>& sub_shapes() const;
std::size_t element_space() const;
private: private:
std::shared_ptr<const shape_impl> impl; std::shared_ptr<const shape_impl> impl;
std::size_t element_space() const;
}; };
void migraphx_to_value(value& v, const shape& s); void migraphx_to_value(value& v, const shape& s);
......
...@@ -627,8 +627,9 @@ std::unordered_map<instruction_ref, std::string> module::print( ...@@ -627,8 +627,9 @@ std::unordered_map<instruction_ref, std::string> module::print(
var_name = this->name(); var_name = this->name();
var_name.append((this->name().empty() ? "@" : ":@")); var_name.append((this->name().empty() ? "@" : ":@"));
var_name.append(std::to_string(count)); var_name.append(std::to_string(count));
count++;
} }
// make instruction index to be the line num in the printed module
count++;
names.emplace(ins, var_name); names.emplace(ins, var_name);
print_func(ins, names); print_func(ins, names);
......
#include "migraphx/instruction_ref.hpp"
#include <functional>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
...@@ -15,10 +17,12 @@ ...@@ -15,10 +17,12 @@
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/marker.hpp> #include <migraphx/marker.hpp>
#include <iostream> #include <iostream>
#include <numeric>
#include <sstream> #include <sstream>
#include <algorithm> #include <algorithm>
#include <set> #include <set>
#include <utility> #include <utility>
#include <iomanip>
#include <unordered_set> #include <unordered_set>
#include <map> #include <map>
...@@ -168,7 +172,7 @@ void program::compile(const target& t, compile_options options) ...@@ -168,7 +172,7 @@ void program::compile(const target& t, compile_options options)
{ {
auto index = std::distance(mod->begin(), dangling); auto index = std::distance(mod->begin(), dangling);
MIGRAPHX_THROW("Dangling reference in module " + mod->name() + " from instruction " + MIGRAPHX_THROW("Dangling reference in module " + mod->name() + " from instruction " +
std::to_string(index)); std::to_string(index) + ", (" + dangling->name() + ")");
} }
mod->finalize(this->impl->ctx); mod->finalize(this->impl->ctx);
} }
...@@ -325,6 +329,204 @@ std::vector<argument> generic_eval(const program& p, ...@@ -325,6 +329,204 @@ std::vector<argument> generic_eval(const program& p,
return generic_eval(mm, ctx, params, {}, make_trace); return generic_eval(mm, ctx, params, {}, make_trace);
} }
// Writes `n` space characters to `os` using formatted output.
// A zero or negative `n` writes nothing.
static void print_space(std::ostream& os, int n)
{
    int remaining = n;
    while(remaining > 0)
    {
        os << ' ';
        --remaining;
    }
}
// Callable that estimates the floating point operation count of one operator
// from the shapes handed to it.
using op_flops = std::function<double(const std::vector<shape>& vec_ss)>;

// Returns the table that maps an operator name ("gemm", "convolution") to its
// flop estimator. The table is built exactly once, on first use, instead of
// re-running the emplace calls (and rebuilding the std::function objects) on
// every call.
auto& get_flops_funcs()
{
    static std::unordered_map<std::string, op_flops> op_funcs = [] {
        std::unordered_map<std::string, op_flops> funcs;

        // gemm: 2 * m * n * k per matrix, times the product of all leading
        // (batch) dimensions of the first operand.
        funcs.emplace("gemm", [](const std::vector<shape>& vec_ss) {
            assert(vec_ss.size() >= 2);
            const auto& lens_a = vec_ss.front().lens();
            const auto& lens_b = vec_ss.at(1).lens();
            // rbegin() + 2 and size() - 2 below require at least two dimensions
            assert(lens_a.size() >= 2 and lens_b.size() >= 2);
            // size_t initial value so the product is accumulated in size_t;
            // an int initial value would truncate large batch products
            const auto batch = std::accumulate(lens_a.rbegin() + 2,
                                               lens_a.rend(),
                                               std::size_t{1},
                                               std::multiplies<std::size_t>{});
            const auto m = lens_a[lens_a.size() - 2];
            const auto k = lens_a.back();
            assert(k == lens_b[lens_b.size() - 2]);
            const auto n = lens_b.back();
            return 2.0 * m * n * k * batch;
        });

        // convolution: 2 * n * k * ho * wo * c * y * x.
        // NOTE(review): ho/wo are read from the LAST shape in vec_ss, which
        // presumably is the output buffer of the operator - confirm for every
        // caller, otherwise these are the filter dimensions.
        funcs.emplace("convolution", [](const std::vector<shape>& vec_ss) {
            assert(vec_ss.size() >= 2);
            const auto& alens = vec_ss.front().lens();
            const auto& blens = vec_ss.at(1).lens();
            const auto& olens = vec_ss.back().lens();
            const auto n      = alens.front();
            const auto k      = blens.front();
            const auto c      = alens.at(1);
            const auto y      = blens.at(2);
            const auto x      = blens.back();
            const auto ho     = olens.at(2);
            const auto wo     = olens.back();
            return 2.0 * n * k * ho * wo * c * y * x;
        });

        return funcs;
    }();
    return op_funcs;
}
// Returns the length, in characters, of the longest instruction line that
// instruction::print produces for any instruction visited by program::print.
// Used to align the columns printed after the instruction text.
int program::max_ins_length() const
{
    std::unordered_map<instruction_ref, std::string> names;
    // tracked as size_t to avoid comparing a signed int against string::length()
    std::size_t max_ins_len = 0;
    this->print(names, [&](auto ins, const auto& ins_names) {
        std::stringstream ss;
        instruction::print(ss, ins, ins_names);
        max_ins_len = std::max(max_ins_len, ss.str().length());
    });
    return static_cast<int>(max_ins_len);
}
// Column headings of the performance table, in print order. The length of
// each heading (including its trailing blank) is also used as the column
// width by the code that prints the rows.
static auto& get_titles()
{
    static std::vector<std::string> titles = {"Instructions",
                                              "Time(ms) ",
                                              "Percentage ",
                                              "(b, m, n, k) ",
                                              "Flops(TFlops/s) ",
                                              "Throughput(GB/s)"};
    return titles;
}

// Prints the heading row of the performance table to `os`.
//   max_ins_len      - length of the longest instruction line; the first
//                      heading is padded to max_ins_len + 1 characters and
//                      followed by a tab
//   print_percentage - when false the "Percentage" column is left out
static void print_title(std::ostream& os, std::size_t max_ins_len, bool print_percentage = true)
{
    // work on a copy: the first heading is padded and a column may be removed
    auto titles        = get_titles();
    std::string& first = titles.front();
    // Pad only when the column is wider than the heading. The unguarded
    // subtraction wraps around for short instruction lines and makes
    // std::string::append throw std::length_error.
    const std::size_t width = max_ins_len + 1;
    if(first.length() < width)
    {
        first.append(width - first.length(), ' ');
    }
    first.append(1, '\t');
    os << first;

    if(not print_percentage)
        titles.erase(titles.begin() + 2);

    for(std::size_t i = 1; i < titles.size(); ++i)
    {
        os << titles[i];
    }
    os << std::endl;
}
// Prints the performance columns of one table row to `os`, ended by std::endl.
//   titles           - column headings; entries 1..5 supply the column widths
//   ins              - instruction the row belongs to
//   t                - time spent in the instruction (presumably milliseconds,
//                      matching the "Time(ms)" heading - confirm with callers)
//   total_t          - total time used as the base of the percentage column
//   print_percentage - when false the percentage column is left empty
// Values wider than their heading are printed unpadded instead of throwing.
static void print_ins_perf(std::ostream& os,
                           const std::vector<std::string>& titles,
                           instruction_ref ins,
                           double t,
                           double total_t,
                           bool print_percentage = true)
{
    // Right-pad `s` with blanks up to `width`; leaves longer strings alone so
    // the unsigned subtraction can never wrap around.
    const auto pad_to = [](std::string& s, std::size_t width) {
        if(s.length() < width)
            s.append(width - s.length(), ' ');
    };
    // Keep at most `digits` characters after the decimal point, if any.
    const auto keep_decimals = [](std::string& s, std::size_t digits) {
        const auto dot = s.find('.');
        if(dot != std::string::npos and s.length() > dot + 1 + digits)
            s.erase(dot + 1 + digits);
    };

    const auto& time_str  = titles.at(1);
    const auto& time_per  = titles.at(2);
    const auto& size_str  = titles.at(3);
    const auto& flops_str = titles.at(4);
    const auto& thrpt_str = titles.at(5);
    auto& flops_funcs     = get_flops_funcs();

    std::string tms = std::to_string(t);
    pad_to(tms, time_str.length());

    std::string pers;
    if(print_percentage)
    {
        pers = std::to_string(100.0 * t / total_t);
        keep_decimals(pers, 5);
        pad_to(pers, time_per.length());
    }

    // calculate flops
    std::string szs;
    std::string flps;
    std::string op_name = ins->name();
    // strip a leading "prefix::" qualifier; names without one are kept intact
    const auto nloc = op_name.find("::");
    if(nloc != std::string::npos)
        op_name.erase(0, nloc + 2);

    auto inss = to_shapes(ins->inputs());
    if(contains(flops_funcs, op_name) and inss.size() >= 2)
    {
        const auto& alens = inss.front().lens();
        const auto& blens = inss.at(1).lens();
        // the size/flops computation needs two trailing dimensions
        if(alens.size() >= 2 and blens.size() >= 2)
        {
            // print size, in the order announced by the heading: b, m, n, k
            const auto mb = std::accumulate(
                alens.rbegin() + 2, alens.rend(), std::size_t{1}, std::multiplies<std::size_t>{});
            const auto mm = alens[alens.size() - 2];
            const auto mk = alens.back();
            const auto mn = blens.back();
            szs           = "{";
            szs += std::to_string(mb);
            szs += ',';
            szs += std::to_string(mm);
            szs += ',';
            szs += std::to_string(mn);
            szs += ',';
            szs += std::to_string(mk);
            szs += '}';

            double flops = flops_funcs.at(op_name)(inss);
            flops /= t;
            // with t in milliseconds, flop/ms divided by 1e9 is TFlops/s
            flops /= 1.0e9;
            flps = std::to_string(flops);
            keep_decimals(flps, 3);
        }
    }
    pad_to(szs, size_str.length());
    pad_to(flps, flops_str.length());

    // print throughput; the column is only filled when output_alias({}) is non-zero
    const auto alias_num = ins->get_operator().output_alias({});
    std::string thrpt;
    if(alias_num != 0)
    {
        const auto size =
            std::accumulate(inss.begin(), inss.end(), std::size_t{0}, [&](auto init, auto s) {
                return init + s.bytes();
            });
        double throughput = size / t;
        // with t in milliseconds, bytes/ms divided by 1e6 is GB/s
        throughput /= 1.0e6;
        thrpt = std::to_string(throughput);
        // truncate the throughput string itself (not the flops string)
        keep_decimals(thrpt, 3);
    }
    pad_to(thrpt, thrpt_str.length());

    os << tms << pers << szs << flps << thrpt << std::endl;
}
std::vector<argument> program::eval(parameter_map params) const std::vector<argument> program::eval(parameter_map params) const
{ {
auto& ctx = this->impl->ctx; auto& ctx = this->impl->ctx;
...@@ -353,37 +555,70 @@ std::vector<argument> program::eval(parameter_map params) const ...@@ -353,37 +555,70 @@ std::vector<argument> program::eval(parameter_map params) const
if(trace_level > 0) if(trace_level > 0)
{ {
auto max_ins_len = max_ins_length();
std::unordered_map<instruction_ref, std::string> ins_names;
this->print(ins_names, [&](auto, auto) {});
if(trace_level == 3)
{
std::string prefix = "Run instruction: ";
max_ins_len += prefix.length();
print_title(std::cout, max_ins_len, false);
}
return generic_eval(*this, return generic_eval(*this,
ctx, ctx,
std::move(params), std::move(params),
with_check_context([&](auto& ins, auto f, auto&& check_context) { with_check_context([&](auto& ins, auto f, auto&& check_context) {
ctx.finish(); ctx.finish();
std::cout << "Run instruction: "; std::stringstream ss;
this->debug_print(ins); ss << "Run instruction: ";
this->debug_print(ss, ins, ins_names);
timer t{}; timer t{};
auto result = check_context(f); auto result = check_context(f);
double t1 = t.record<milliseconds>(); double t1 = t.record<milliseconds>();
ctx.finish(); ctx.finish();
double t2 = t.record<milliseconds>(); double t2 = t.record<milliseconds>();
std::cout << "Time: " << t1 << "ms, " << t2 << "ms" << std::endl; if(trace_level < 3)
if(trace_level > 1 and ins->name().front() != '@' and
ins->name() != "load" and not result.empty())
{ {
target tgt = make_target(this->impl->target_name); std::cout << ss.str() << std::endl;
auto buffer = tgt.copy_from(result); std::cout << "Time: " << t1 << "ms, " << t2
if(trace_level == 2) << "ms, execution time:\t";
if(trace_level == 2 and ins->name().front() != '@' and
ins->name() != "load" and not result.empty())
{ {
std::cout << "Output has " target tgt = make_target(this->impl->target_name);
<< to_string_range(classify_argument(buffer)) auto buffer = tgt.copy_from(result);
<< std::endl; if(trace_level == 2)
std::cout << "Output: "; {
preview_argument(std::cout, buffer); std::cout << "Output has "
std::cout << std::endl; << to_string_range(classify_argument(buffer))
<< std::endl;
std::cout << "Output: ";
preview_argument(std::cout, buffer);
std::cout << std::endl;
}
else
{
std::cout << "Output: " << buffer << std::endl;
}
} }
else }
else if(trace_level == 3)
{
// count max instruction length
if(ins->get_operator().output_alias({}) == 0)
{ {
std::cout << "Output: " << buffer << std::endl; std::cout << ss.str() << std::endl;
return result;
} }
print_space(ss, max_ins_len - ss.str().length());
ss << '\t';
std::cout << ss.str();
auto titles = get_titles();
double exec_t = t2 - t1;
print_ins_perf(std::cout, titles, ins, exec_t, exec_t, false);
} }
return result; return result;
})); }));
...@@ -661,17 +896,26 @@ void program::perf_report(std::ostream& os, ...@@ -661,17 +896,26 @@ void program::perf_report(std::ostream& os,
double calculate_overhead_percent = calculate_overhead_time * 100.0 / total_time; double calculate_overhead_percent = calculate_overhead_time * 100.0 / total_time;
std::unordered_map<instruction_ref, std::string> names; std::unordered_map<instruction_ref, std::string> names;
// count max instruction length
auto titles = get_titles();
const int max_ins_len = max_ins_length();
print_title(os, max_ins_len);
this->print(names, [&](auto ins, auto ins_names) { this->print(names, [&](auto ins, auto ins_names) {
instruction::print(std::cout, ins, ins_names); std::stringstream ss;
instruction::print(ss, ins, ins_names);
os << ss.str();
// skip return instruction // skip return instruction
if(ins->name() == "@return") if(ins->name() == "@return")
return; return;
double avg = common_average(ins_vec[ins]); // insert space to align
double percent = std::ceil(100.0 * avg / total_instruction_time); print_space(os, max_ins_len - ss.str().length());
os << ": " << avg << "ms, " << percent << "%"; os << "\t";
os << std::endl; double avg = common_average(ins_vec[ins]);
print_ins_perf(os, titles, ins, avg, total_instruction_time);
}); });
os << std::endl; os << std::endl;
...@@ -731,6 +975,31 @@ void program::debug_print(instruction_ref ins) const ...@@ -731,6 +975,31 @@ void program::debug_print(instruction_ref ins) const
}); });
} }
// Prints a single instruction to `os`, using the already assigned `names`.
// Writes "End instruction" when `ins` is the end of one of the modules,
// "Instruction not part of program" when no module contains it, and nothing
// at all when `ins` has no entry in `names`.
void program::debug_print(std::ostream& os,
                          instruction_ref ins,
                          const std::unordered_map<instruction_ref, std::string>& names) const
{
    const auto& mods = this->impl->modules;
    const auto is_module_end = [&](const auto& entry) { return is_end(entry.second.end(), ins); };
    const auto owns_ins = [&](const auto& entry) { return entry.second.has_instruction(ins); };

    if(std::any_of(mods.begin(), mods.end(), is_module_end))
    {
        os << "End instruction" << std::endl;
        return;
    }
    if(std::none_of(mods.begin(), mods.end(), owns_ins))
    {
        os << "Instruction not part of program" << std::endl;
        return;
    }
    if(not contains(names, ins))
        return;
    instruction::print(os, ins, names);
}
void program::print( void program::print(
std::unordered_map<instruction_ref, std::string>& names, std::unordered_map<instruction_ref, std::string>& names,
const std::function<void(instruction_ref, std::unordered_map<instruction_ref, std::string>)>& const std::function<void(instruction_ref, std::unordered_map<instruction_ref, std::string>)>&
......
...@@ -120,6 +120,17 @@ struct find_nop_reshapes ...@@ -120,6 +120,17 @@ struct find_nop_reshapes
// Removes the matched no-op instruction by replacing it with its first input.
void apply(module& p, const match::matcher_result& mr) const
{
    auto ins = mr.result;
    // The output of reshape and contiguous is standard, so those can always be
    // removed. Any other matched instruction is kept when its output is used as
    // a return value, so that no extra contiguous has to be added for it.
    // (The previous condition also required name() == "contiguous", which
    // contradicts name() != "contiguous" and made this guard unreachable.)
    if(ins->name() != "contiguous" and ins->name() != "reshape")
    {
        const auto& outputs = ins->outputs();
        if(std::any_of(
               outputs.begin(), outputs.end(), [&](auto o) { return o->name() == "@return"; }))
        {
            return;
        }
    }
    p.replace_instruction(ins, ins->inputs().front());
}
}; };
......
...@@ -6,15 +6,32 @@ ...@@ -6,15 +6,32 @@
namespace migraphx { namespace migraphx {
template <class T>
struct remove_vec_impl
{
using type = T;
};
template <class T, index_int N>
struct remove_vec_impl<vec<T, N>>
{
using type = T;
};
template <class T>
using remove_vec = typename remove_vec_impl<T>::type;
template <class T, class... Shapes> template <class T, class... Shapes>
constexpr auto traverse_preload(Shapes... ss) constexpr auto traverse_preload(Shapes... ss)
{ {
return [=](auto f, auto... g) { return [=](auto f, auto... g) {
index_int offset = 0; index_int offset = 0;
auto each = [&](auto x) { auto each = [&](auto x) {
using type = remove_vec<typename decltype(x)::type>;
constexpr auto s = decltype(x.get_shape()){}; constexpr auto s = decltype(x.get_shape()){};
constexpr auto size = _c<s.element_space()>; constexpr auto size = _c<s.element_space()>;
if constexpr(not s.broadcasted() or (s.elements() - size) < 64) if constexpr(not s.broadcasted() or (s.elements() - size) < 64 or
not is_same<T, type>{})
return f(x, offset, false_type{}); return f(x, offset, false_type{});
else else
{ {
...@@ -78,23 +95,23 @@ __device__ auto preload_copy(index idx, F f, __shared__ T* buffer, Ts... xs) ...@@ -78,23 +95,23 @@ __device__ auto preload_copy(index idx, F f, __shared__ T* buffer, Ts... xs)
invoke); invoke);
} }
template <class T> template <class T, class Shape>
struct remove_vec struct shape_type : Shape
{ {
using type = T; using type = T;
}; };
template <class T, index_int N> template <class T>
struct remove_vec<vec<T, N>> constexpr auto make_shape_type(T)
{ {
using type = T; return shape_type<typename T::type, typename T::shape_type>{};
}; }
template <class T, class... Ts> template <class T, class... Ts>
__device__ auto preload(index idx, Ts... xs) __device__ auto preload(index idx, Ts... xs)
{ {
using type = typename remove_vec<T>::type; using type = remove_vec<T>;
constexpr auto size = decltype(compute_preload_size<type>(xs.get_shape()...)){}; constexpr auto size = decltype(compute_preload_size<type>(make_shape_type(xs)...)){};
const index_int max_size = 512 * sizeof(type); const index_int max_size = 512 * sizeof(type);
return [=](auto f) { return [=](auto f) {
if constexpr(size > 0 and size < max_size) if constexpr(size > 0 and size < max_size)
......
...@@ -9,7 +9,8 @@ namespace migraphx { ...@@ -9,7 +9,8 @@ namespace migraphx {
template <class T, class Shape> template <class T, class Shape>
struct tensor_view struct tensor_view
{ {
using type = T; using type = T;
using shape_type = Shape;
constexpr Shape get_shape() const { return Shape{}; } constexpr Shape get_shape() const { return Shape{}; }
constexpr index_int size() const { return get_shape().elements(); } constexpr index_int size() const { return get_shape().elements(); }
......
...@@ -25,6 +25,16 @@ struct is_convertible : bool_constant<__is_convertible(From, To)> ...@@ -25,6 +25,16 @@ struct is_convertible : bool_constant<__is_convertible(From, To)>
{ {
}; };
// Type trait: derives from false_type for two different types ...
template <typename Lhs, typename Rhs>
struct is_same : false_type
{
};

// ... and from true_type when both arguments are the same type.
template <typename Same>
struct is_same<Same, Same> : true_type
{
};
#define MIGRAPHX_REQUIRES(...) class = enable_if_t<__VA_ARGS__> #define MIGRAPHX_REQUIRES(...) class = enable_if_t<__VA_ARGS__>
} // namespace migraphx } // namespace migraphx
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment