Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
c1ec929c
Commit
c1ec929c
authored
Mar 21, 2022
by
Shucai Xiao
Browse files
merge changes from develop branch
parents
abe2a889
03225b57
Changes
116
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
629 additions
and
63 deletions
+629
-63
.gitignore
.gitignore
+63
-1
src/CMakeLists.txt
src/CMakeLists.txt
+1
-0
src/api/api.cpp
src/api/api.cpp
+240
-0
src/api/include/migraphx/migraphx.h
src/api/include/migraphx/migraphx.h
+67
-0
src/api/include/migraphx/migraphx.hpp
src/api/include/migraphx/migraphx.hpp
+117
-17
src/api/migraphx.py
src/api/migraphx.py
+50
-0
src/driver/alexnet.cpp
src/driver/alexnet.cpp
+3
-3
src/driver/inceptionv3.cpp
src/driver/inceptionv3.cpp
+14
-14
src/driver/main.cpp
src/driver/main.cpp
+8
-0
src/driver/perf.cpp
src/driver/perf.cpp
+1
-1
src/driver/resnet50.cpp
src/driver/resnet50.cpp
+2
-2
src/eliminate_pad.cpp
src/eliminate_pad.cpp
+1
-1
src/include/migraphx/json.hpp
src/include/migraphx/json.hpp
+1
-0
src/include/migraphx/op/common.hpp
src/include/migraphx/op/common.hpp
+11
-0
src/include/migraphx/op/pooling.hpp
src/include/migraphx/op/pooling.hpp
+2
-1
src/include/migraphx/op/prefix_scan_op.hpp
src/include/migraphx/op/prefix_scan_op.hpp
+29
-9
src/include/migraphx/op/roialign.hpp
src/include/migraphx/op/roialign.hpp
+13
-11
src/include/migraphx/operation.hpp
src/include/migraphx/operation.hpp
+2
-2
src/include/migraphx/program.hpp
src/include/migraphx/program.hpp
+3
-0
src/insert_pad.cpp
src/insert_pad.cpp
+1
-1
No files found.
.gitignore
View file @
c1ec929c
*.swp
#==============================================================================#
# File extensions to be ignored anywhere in the tree.
#==============================================================================#
# Temp files created by most text editors
*~
# Merge files created by git
*.orig
# Byte compiled python modules
*.pyc
*.pyd
# Vim swap files
.*.sw?
.sw?
# Visual Studio
.vs
/.vscode/*
# Sublime Text settings
*.sublime-workspace
*.sublime-project
# Eclipse Project settings
*.*project
.settings
# OS X specific files
.DS_store
#==============================================================================#
# Explicit files to ignore (only matches one).
#==============================================================================#
# Various tags
/tags
/TAGS
/GPATH
/GRTAGS
/GSYMS
/GTAGS
/ID
.gitusers
/compile_commands.json
/CMakeSettings.json
#==============================================================================#
# Directories to ignore (do not add trailing '/'s, they skip symlinks).
#==============================================================================#
# Nested build directory
/build*
# Downloaded models
test/onnx/models
# VS2017 and VSCode config files.
.vscode
.vs
src/CMakeLists.txt
View file @
c1ec929c
...
...
@@ -38,6 +38,7 @@ add_library(migraphx
msgpack.cpp
normalize_attributes.cpp
normalize_ops.cpp
op_enums.cpp
operation.cpp
opt/memory_coloring.cpp
opt/memory_coloring_impl.cpp
...
...
src/api/api.cpp
View file @
c1ec929c
...
...
@@ -4,6 +4,7 @@
#include <migraphx/program.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/tf.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/quantization.hpp>
...
...
@@ -72,6 +73,23 @@ migraphx_shape_datatype_t to_shape_type(shape::type_t t)
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Unknown type"
);
}
template
<
class
T
>
auto
to_obj_vector
(
const
T
*
x
,
std
::
size_t
n
)
{
std
::
vector
<
decltype
((
*
x
)
->
object
)
>
result
;
std
::
transform
(
x
,
x
+
n
,
std
::
back_inserter
(
result
),
[
&
](
auto
&&
y
)
{
return
y
->
object
;
});
return
result
;
}
template
<
class
T
,
class
U
>
auto
to_objptr_vector
(
const
U
*
x
,
std
::
size_t
n
)
{
std
::
vector
<
T
>
result
;
std
::
transform
(
x
,
x
+
n
,
std
::
back_inserter
(
result
),
[
&
](
auto
&&
y
)
{
return
std
::
addressof
(
y
->
object
);
});
return
result
;
}
target
get_target
(
const
std
::
string
&
name
)
{
return
make_target
(
name
);
}
void
set_offload_copy
(
compile_options
&
options
,
bool
value
)
{
options
.
offload_copy
=
value
;
}
...
...
@@ -194,6 +212,8 @@ void print_program(const program& p) { std::cout << p << std::endl; }
void
print_module
(
const
module
&
m
)
{
std
::
cout
<<
m
<<
std
::
endl
;
}
migraphx
::
context
get_context
(
const
program
&
p
)
{
return
p
.
get_context
();
}
}
// namespace migraphx
template
<
class
T
,
class
U
,
class
Target
=
std
::
remove_pointer_t
<
T
>
>
...
...
@@ -289,6 +309,36 @@ struct migraphx_shapes
std
::
vector
<
migraphx
::
shape
>
object
;
};
extern
"C"
struct
migraphx_instruction
;
struct
migraphx_instruction
{
template
<
class
...
Ts
>
migraphx_instruction
(
Ts
&&
...
xs
)
:
object
(
std
::
forward
<
Ts
>
(
xs
)...)
{
}
migraphx
::
instruction_ref
object
;
};
extern
"C"
struct
migraphx_instructions
;
struct
migraphx_instructions
{
template
<
class
...
Ts
>
migraphx_instructions
(
Ts
&&
...
xs
)
:
object
(
std
::
forward
<
Ts
>
(
xs
)...)
{
}
std
::
vector
<
migraphx
::
instruction_ref
>
object
;
};
extern
"C"
struct
migraphx_modules
;
struct
migraphx_modules
{
template
<
class
...
Ts
>
migraphx_modules
(
Ts
&&
...
xs
)
:
object
(
std
::
forward
<
Ts
>
(
xs
)...)
{
}
std
::
vector
<
migraphx
::
module
*>
object
;
};
extern
"C"
struct
migraphx_module
;
struct
migraphx_module
{
...
...
@@ -379,6 +429,16 @@ struct migraphx_quantize_int8_options
migraphx
::
quantize_int8_options
object
;
};
extern
"C"
struct
migraphx_context
;
struct
migraphx_context
{
template
<
class
...
Ts
>
migraphx_context
(
Ts
&&
...
xs
)
:
object
(
std
::
forward
<
Ts
>
(
xs
)...)
{
}
migraphx
::
context
object
;
};
extern
"C"
migraphx_status
migraphx_shape_destroy
(
migraphx_shape_t
shape
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
destroy
((
shape
));
});
...
...
@@ -762,6 +822,77 @@ migraphx_shapes_get(const_migraphx_shape_t* out, migraphx_shapes_t shapes, size_
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_instruction_destroy
(
migraphx_instruction_t
instruction
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
destroy
((
instruction
));
});
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_instruction_assign_to
(
migraphx_instruction_t
output
,
const_migraphx_instruction_t
input
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
*
output
=
*
input
;
});
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_instructions_destroy
(
migraphx_instructions_t
instructions
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
destroy
((
instructions
));
});
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_instructions_assign_to
(
migraphx_instructions_t
output
,
const_migraphx_instructions_t
input
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
*
output
=
*
input
;
});
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_instructions_create
(
migraphx_instructions_t
*
instructions
,
const_migraphx_instruction_t
*
ptr
,
size_t
size
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
*
instructions
=
object_cast
<
migraphx_instructions_t
>
(
allocate
<
std
::
vector
<
migraphx
::
instruction_ref
>>
(
migraphx
::
to_obj_vector
<
const_migraphx_instruction_t
>
((
ptr
),
(
size
))));
});
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_modules_destroy
(
migraphx_modules_t
modules
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
destroy
((
modules
));
});
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_modules_assign_to
(
migraphx_modules_t
output
,
const_migraphx_modules_t
input
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
*
output
=
*
input
;
});
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_modules_create
(
migraphx_modules_t
*
modules
,
migraphx_module_t
*
ptr
,
size_t
size
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
*
modules
=
object_cast
<
migraphx_modules_t
>
(
allocate
<
std
::
vector
<
migraphx
::
module
*>>
(
migraphx
::
to_objptr_vector
<
migraphx
::
module
*>
((
ptr
),
(
size
))));
});
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_module_create
(
migraphx_module_t
*
module
,
char
*
name
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
if
(
name
==
nullptr
)
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Bad parameter name: Null pointer"
);
*
module
=
object_cast
<
migraphx_module_t
>
(
allocate
<
migraphx
::
module
>
((
std
::
string
(
name
))));
});
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_module_print
(
const_migraphx_module_t
module
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
...
...
@@ -772,6 +903,76 @@ extern "C" migraphx_status migraphx_module_print(const_migraphx_module_t module)
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_module_add_instruction
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
migraphx_operation_t
op
,
migraphx_instructions_t
args
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
if
(
module
==
nullptr
)
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Bad parameter module: Null pointer"
);
if
(
op
==
nullptr
)
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Bad parameter op: Null pointer"
);
if
(
args
==
nullptr
)
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Bad parameter args: Null pointer"
);
*
out
=
allocate
<
migraphx_instruction_t
>
(
(
module
->
object
).
add_instruction
((
op
->
object
),
(
args
->
object
)));
});
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_module_add_instruction_with_mod_args
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
migraphx_operation_t
op
,
migraphx_instructions_t
args
,
migraphx_modules_t
module_refs
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
if
(
module
==
nullptr
)
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Bad parameter module: Null pointer"
);
if
(
op
==
nullptr
)
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Bad parameter op: Null pointer"
);
if
(
args
==
nullptr
)
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Bad parameter args: Null pointer"
);
if
(
module_refs
==
nullptr
)
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Bad parameter module_refs: Null pointer"
);
*
out
=
allocate
<
migraphx_instruction_t
>
(
(
module
->
object
).
add_instruction
((
op
->
object
),
(
args
->
object
),
(
module_refs
->
object
)));
});
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_module_add_parameter
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
const
char
*
name
,
const_migraphx_shape_t
shape
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
if
(
module
==
nullptr
)
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Bad parameter module: Null pointer"
);
if
(
shape
==
nullptr
)
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Bad parameter shape: Null pointer"
);
*
out
=
allocate
<
migraphx_instruction_t
>
(
(
module
->
object
).
add_parameter
((
name
),
(
shape
->
object
)));
});
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_module_add_return
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
migraphx_instructions_t
args
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
if
(
module
==
nullptr
)
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Bad parameter module: Null pointer"
);
if
(
args
==
nullptr
)
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Bad parameter args: Null pointer"
);
*
out
=
allocate
<
migraphx_instruction_t
>
((
module
->
object
).
add_return
((
args
->
object
)));
});
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_program_destroy
(
migraphx_program_t
program
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
destroy
((
program
));
});
...
...
@@ -785,6 +986,13 @@ extern "C" migraphx_status migraphx_program_assign_to(migraphx_program_t output,
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_program_create
(
migraphx_program_t
*
program
)
{
auto
api_error_result
=
migraphx
::
try_
(
[
&
]
{
*
program
=
object_cast
<
migraphx_program_t
>
(
allocate
<
migraphx
::
program
>
());
});
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_program_get_main_module
(
migraphx_module_t
*
out
,
migraphx_program_t
program
)
{
...
...
@@ -796,6 +1004,17 @@ extern "C" migraphx_status migraphx_program_get_main_module(migraphx_module_t* o
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_program_create_module
(
migraphx_module_t
*
out
,
migraphx_program_t
program
,
const
char
*
name
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
if
(
program
==
nullptr
)
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Bad parameter program: Null pointer"
);
*
out
=
object_cast
<
migraphx_module_t
>
((
program
->
object
).
create_module
((
name
)));
});
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_program_compile
(
migraphx_program_t
program
,
migraphx_target_t
target
,
migraphx_compile_options_t
options
)
...
...
@@ -883,6 +1102,17 @@ migraphx_program_equal(bool* out, const_migraphx_program_t program, const_migrap
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_program_experimental_get_context
(
migraphx_context_t
*
out
,
const_migraphx_program_t
program
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
if
(
program
==
nullptr
)
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Bad parameter program: Null pointer"
);
*
out
=
allocate
<
migraphx_context_t
>
(
migraphx
::
get_context
((
program
->
object
)));
});
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_operation_destroy
(
migraphx_operation_t
operation
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
destroy
((
operation
));
});
...
...
@@ -1324,3 +1554,13 @@ extern "C" migraphx_status migraphx_quantize_int8(migraphx_program_t prog,
});
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_context_finish
(
const_migraphx_context_t
context
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
if
(
context
==
nullptr
)
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Bad parameter context: Null pointer"
);
(
context
->
object
).
finish
();
});
return
api_error_result
;
}
src/api/include/migraphx/migraphx.h
View file @
c1ec929c
...
...
@@ -64,6 +64,15 @@ typedef const struct migraphx_arguments* const_migraphx_arguments_t;
typedef
struct
migraphx_shapes
*
migraphx_shapes_t
;
typedef
const
struct
migraphx_shapes
*
const_migraphx_shapes_t
;
typedef
struct
migraphx_instruction
*
migraphx_instruction_t
;
typedef
const
struct
migraphx_instruction
*
const_migraphx_instruction_t
;
typedef
struct
migraphx_instructions
*
migraphx_instructions_t
;
typedef
const
struct
migraphx_instructions
*
const_migraphx_instructions_t
;
typedef
struct
migraphx_modules
*
migraphx_modules_t
;
typedef
const
struct
migraphx_modules
*
const_migraphx_modules_t
;
typedef
struct
migraphx_module
*
migraphx_module_t
;
typedef
const
struct
migraphx_module
*
const_migraphx_module_t
;
...
...
@@ -91,6 +100,9 @@ typedef const struct migraphx_quantize_op_names* const_migraphx_quantize_op_name
typedef
struct
migraphx_quantize_int8_options
*
migraphx_quantize_int8_options_t
;
typedef
const
struct
migraphx_quantize_int8_options
*
const_migraphx_quantize_int8_options_t
;
typedef
struct
migraphx_context
*
migraphx_context_t
;
typedef
const
struct
migraphx_context
*
const_migraphx_context_t
;
migraphx_status
migraphx_shape_destroy
(
migraphx_shape_t
shape
);
migraphx_status
migraphx_shape_assign_to
(
migraphx_shape_t
output
,
const_migraphx_shape_t
input
);
...
...
@@ -198,16 +210,66 @@ migraphx_status migraphx_shapes_size(size_t* out, migraphx_shapes_t shapes);
migraphx_status
migraphx_shapes_get
(
const_migraphx_shape_t
*
out
,
migraphx_shapes_t
shapes
,
size_t
idx
);
migraphx_status
migraphx_instruction_destroy
(
migraphx_instruction_t
instruction
);
migraphx_status
migraphx_instruction_assign_to
(
migraphx_instruction_t
output
,
const_migraphx_instruction_t
input
);
migraphx_status
migraphx_instructions_destroy
(
migraphx_instructions_t
instructions
);
migraphx_status
migraphx_instructions_assign_to
(
migraphx_instructions_t
output
,
const_migraphx_instructions_t
input
);
migraphx_status
migraphx_instructions_create
(
migraphx_instructions_t
*
instructions
,
const_migraphx_instruction_t
*
ptr
,
size_t
size
);
migraphx_status
migraphx_modules_destroy
(
migraphx_modules_t
modules
);
migraphx_status
migraphx_modules_assign_to
(
migraphx_modules_t
output
,
const_migraphx_modules_t
input
);
migraphx_status
migraphx_modules_create
(
migraphx_modules_t
*
modules
,
migraphx_module_t
*
ptr
,
size_t
size
);
migraphx_status
migraphx_module_create
(
migraphx_module_t
*
module
,
char
*
name
);
migraphx_status
migraphx_module_print
(
const_migraphx_module_t
module
);
migraphx_status
migraphx_module_add_instruction
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
migraphx_operation_t
op
,
migraphx_instructions_t
args
);
migraphx_status
migraphx_module_add_instruction_with_mod_args
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
migraphx_operation_t
op
,
migraphx_instructions_t
args
,
migraphx_modules_t
module_refs
);
migraphx_status
migraphx_module_add_parameter
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
const
char
*
name
,
const_migraphx_shape_t
shape
);
migraphx_status
migraphx_module_add_return
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
migraphx_instructions_t
args
);
migraphx_status
migraphx_program_destroy
(
migraphx_program_t
program
);
migraphx_status
migraphx_program_assign_to
(
migraphx_program_t
output
,
const_migraphx_program_t
input
);
migraphx_status
migraphx_program_create
(
migraphx_program_t
*
program
);
migraphx_status
migraphx_program_get_main_module
(
migraphx_module_t
*
out
,
migraphx_program_t
program
);
migraphx_status
migraphx_program_create_module
(
migraphx_module_t
*
out
,
migraphx_program_t
program
,
const
char
*
name
);
migraphx_status
migraphx_program_compile
(
migraphx_program_t
program
,
migraphx_target_t
target
,
migraphx_compile_options_t
options
);
...
...
@@ -229,6 +291,9 @@ migraphx_status migraphx_program_run(migraphx_arguments_t* out,
migraphx_status
migraphx_program_equal
(
bool
*
out
,
const_migraphx_program_t
program
,
const_migraphx_program_t
x
);
migraphx_status
migraphx_program_experimental_get_context
(
migraphx_context_t
*
out
,
const_migraphx_program_t
program
);
migraphx_status
migraphx_operation_destroy
(
migraphx_operation_t
operation
);
migraphx_status
migraphx_operation_assign_to
(
migraphx_operation_t
output
,
...
...
@@ -355,6 +420,8 @@ migraphx_status migraphx_quantize_int8(migraphx_program_t prog,
migraphx_target_t
target
,
migraphx_quantize_int8_options_t
options
);
migraphx_status
migraphx_context_finish
(
const_migraphx_context_t
context
);
#ifdef __cplusplus
}
#endif
...
...
src/api/include/migraphx/migraphx.hpp
View file @
c1ec929c
#ifndef MIGRAPHX_GUARD_API_RTGLIB_MIGRAPHX_HPP
#define MIGRAPHX_GUARD_API_RTGLIB_MIGRAPHX_HPP
#include "migraphx.h"
#include <initializer_list>
#include <migraphx/migraphx.h>
#include <memory>
#include <exception>
...
...
@@ -523,12 +525,116 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
};
};
struct
operation
:
MIGRAPHX_HANDLE_BASE
(
operation
)
{
operation
(
migraphx_operation
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
operation
(
migraphx_operation
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
template
<
class
...
Ts
>
operation
(
const
char
*
name
,
const
char
*
attributes
=
nullptr
,
Ts
...
xs
)
{
this
->
make_handle
(
&
migraphx_operation_create
,
name
,
attributes
,
xs
...);
}
std
::
string
name
()
{
std
::
array
<
char
,
1024
>
out_name
;
call
(
&
migraphx_operation_name
,
out_name
.
data
(),
1024
,
this
->
get_handle_ptr
());
return
{
out_name
.
data
()};
}
};
struct
instruction
:
MIGRAPHX_CONST_HANDLE_BASE
(
instruction
)
{
instruction
(
migraphx_instruction
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
};
struct
instructions
:
MIGRAPHX_HANDLE_BASE
(
instructions
)
{
instructions
(
migraphx_instructions
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
instructions
(
migraphx_instructions
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
template
<
class
...
Ts
>
instructions
(
Ts
...
xs
)
{
std
::
array
<
const_migraphx_instruction_t
,
sizeof
...(
Ts
)
>
a
{
xs
.
get_handle_ptr
()...};
this
->
make_handle
(
&
migraphx_instructions_create
,
a
.
data
(),
a
.
size
());
}
};
struct
module
;
struct
modules
:
MIGRAPHX_HANDLE_BASE
(
modules
)
{
modules
(
migraphx_modules
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
modules
(
migraphx_modules
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
template
<
class
...
Ts
>
modules
(
Ts
...
xs
)
{
std
::
array
<
migraphx_module_t
,
sizeof
...(
Ts
)
>
a
=
{
xs
.
mm
...};
this
->
make_handle
(
&
migraphx_modules_create
,
a
.
data
(),
a
.
size
());
}
};
struct
module
{
migraphx_module_t
mm
;
module
(
const
migraphx_module_t
&
m
)
:
mm
(
m
)
{}
void
print
()
const
{
call
(
&
migraphx_module_print
,
mm
);
}
instruction
add_instruction
(
const
migraphx
::
operation
&
op
,
const
migraphx
::
instructions
&
args
)
{
migraphx_instruction_t
op_ins
;
call
(
&
migraphx_module_add_instruction
,
&
op_ins
,
mm
,
op
.
get_handle_ptr
(),
args
.
get_handle_ptr
());
return
instruction
(
op_ins
,
own
{});
}
instruction
add_instruction
(
const
migraphx
::
operation
&
op
,
const
migraphx
::
instructions
&
args
,
const
migraphx
::
modules
&
module_args
)
{
migraphx_instruction_t
op_ins
;
call
(
&
migraphx_module_add_instruction_with_mod_args
,
&
op_ins
,
mm
,
op
.
get_handle_ptr
(),
args
.
get_handle_ptr
(),
module_args
.
get_handle_ptr
());
return
instruction
(
op_ins
,
own
{});
}
instruction
add_parameter
(
const
std
::
string
&
name
,
shape
s
)
{
migraphx_instruction_t
param_ins
;
call
(
&
migraphx_module_add_parameter
,
&
param_ins
,
mm
,
name
.
c_str
(),
s
.
get_handle_ptr
());
return
instruction
(
param_ins
,
own
{});
}
instruction
add_return
(
const
migraphx
::
instructions
&
args
)
{
migraphx_instruction_t
ret_ins
;
call
(
&
migraphx_module_add_return
,
&
ret_ins
,
mm
,
args
.
get_handle_ptr
());
return
instruction
(
ret_ins
,
own
{});
}
};
struct
context
{
migraphx_context_t
ctx
;
void
finish
()
const
{
call
(
&
migraphx_context_finish
,
ctx
);
}
};
struct
compile_options
:
MIGRAPHX_HANDLE_BASE
(
compile_options
)
...
...
@@ -557,7 +663,7 @@ struct compile_options : MIGRAPHX_HANDLE_BASE(compile_options)
/// A program represents the all computation graphs to be compiled and executed
struct
program
:
MIGRAPHX_HANDLE_BASE
(
program
)
{
program
()
{}
program
()
{
this
->
make_handle
(
&
migraphx_program_create
);
}
program
(
migraphx_program
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
...
...
@@ -627,27 +733,21 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
return
module
{
p_modu
};
}
friend
bool
operator
!=
(
const
program
&
px
,
const
program
&
py
)
{
return
!
(
px
==
py
);
}
};
struct
operation
:
MIGRAPHX_HANDLE_BASE
(
operation
)
{
operation
(
migraphx_operation
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
operation
(
migraphx_operation
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
template
<
class
...
Ts
>
operation
(
const
char
*
name
,
const
char
*
attributes
=
nullptr
,
Ts
...
xs
)
context
experimental_get_context
()
{
this
->
make_handle
(
&
migraphx_operation_create
,
name
,
attributes
,
xs
...);
migraphx_context_t
ctx
;
call
(
&
migraphx_program_experimental_get_context
,
&
ctx
,
this
->
get_handle_ptr
());
return
context
{
ctx
};
}
std
::
string
name
(
)
module
create_module
(
const
std
::
string
&
name
)
{
std
::
array
<
char
,
1024
>
out_name
;
call
(
&
migraphx_
operation_name
,
out_name
.
data
(),
1024
,
this
->
get_handle_ptr
());
return
{
out_name
.
data
()
};
migraphx_module_t
p_modu
;
call
(
&
migraphx_
program_create_module
,
&
p_modu
,
this
->
get_handle_ptr
()
,
name
.
data
()
);
return
module
{
p_modu
};
}
friend
bool
operator
!=
(
const
program
&
px
,
const
program
&
py
)
{
return
!
(
px
==
py
);
}
};
// options for migraphx file format options
...
...
src/api/migraphx.py
View file @
c1ec929c
...
...
@@ -178,14 +178,55 @@ def shapes(h):
returns
=
'const migraphx::shape&'
)
@
api
.
handle
(
'migraphx_instruction'
,
'migraphx::instruction_ref'
)
def
instruction
(
h
):
pass
@
api
.
handle
(
'migraphx_instructions'
,
'std::vector<migraphx::instruction_ref>'
)
def
instructions
(
h
):
h
.
constructor
(
'create'
,
api
.
params
(
ptr
=
'const_migraphx_instruction_t*'
,
size
=
'size_t'
),
fname
=
'migraphx::to_obj_vector<const_migraphx_instruction_t>'
)
@
api
.
handle
(
'migraphx_modules'
,
'std::vector<migraphx::module*>'
)
def
modules
(
h
):
h
.
constructor
(
'create'
,
api
.
params
(
ptr
=
'migraphx_module_t*'
,
size
=
'size_t'
),
fname
=
'migraphx::to_objptr_vector<migraphx::module*>'
)
@
auto_handle
(
ref
=
True
)
def
module
(
h
):
h
.
constructor
(
'create'
,
api
.
params
(
name
=
'std::string'
))
h
.
method
(
'print'
,
invoke
=
'migraphx::print_module($@)'
,
const
=
True
)
h
.
method
(
'add_instruction'
,
api
.
params
(
op
=
'migraphx::operation'
,
args
=
'std::vector<migraphx::instruction_ref>'
),
returns
=
'migraphx::instruction_ref'
)
h
.
method
(
'add_instruction_with_mod_args'
,
api
.
params
(
op
=
'migraphx::operation'
,
args
=
'std::vector<migraphx::instruction_ref>'
,
module_refs
=
'std::vector<migraphx::module*>'
),
fname
=
'add_instruction'
,
returns
=
'migraphx::instruction_ref'
)
h
.
method
(
'add_parameter'
,
api
.
params
(
name
=
'const char*'
,
shape
=
'const migraphx::shape&'
),
returns
=
'migraphx::instruction_ref'
)
h
.
method
(
'add_return'
,
api
.
params
(
args
=
'std::vector<migraphx::instruction_ref>'
),
returns
=
'migraphx::instruction_ref'
)
@
auto_handle
()
def
program
(
h
):
h
.
constructor
(
'create'
)
h
.
method
(
'get_main_module'
,
returns
=
'migraphx::module*'
)
h
.
method
(
'create_module'
,
api
.
params
(
name
=
'const char*'
),
returns
=
'migraphx::module*'
)
h
.
method
(
'compile'
,
api
.
params
(
target
=
'migraphx::target'
,
...
...
@@ -207,6 +248,10 @@ def program(h):
invoke
=
'migraphx::equal($@)'
,
returns
=
'bool'
,
const
=
True
)
h
.
method
(
'experimental_get_context'
,
invoke
=
'migraphx::get_context($@)'
,
const
=
True
,
returns
=
'migraphx::context'
)
@
auto_handle
()
...
...
@@ -353,3 +398,8 @@ api.add_function('migraphx_quantize_int8',
target
=
'migraphx::target'
,
options
=
'migraphx::quantize_int8_options'
),
fname
=
'migraphx::quantize_int8_wrap'
)
@
auto_handle
(
ref
=
True
)
def
context
(
h
):
h
.
method
(
'finish'
,
const
=
True
)
src/driver/alexnet.cpp
View file @
c1ec929c
...
...
@@ -61,7 +61,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx
::
op
::
relu
relu19
;
auto
mx19
=
mm
->
add_instruction
(
relu19
,
mx18
);
migraphx
::
op
::
pooling
pooling20
;
pooling20
.
mode
=
"max"
;
pooling20
.
mode
=
migraphx
::
op
::
pooling_mode
::
max
;
pooling20
.
padding
=
{
0
,
0
};
pooling20
.
stride
=
{
2
,
2
};
pooling20
.
lengths
=
{
3
,
3
};
...
...
@@ -81,7 +81,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx
::
op
::
relu
relu24
;
auto
mx24
=
mm
->
add_instruction
(
relu24
,
mx23
);
migraphx
::
op
::
pooling
pooling25
;
pooling25
.
mode
=
"max"
;
pooling25
.
mode
=
migraphx
::
op
::
pooling_mode
::
max
;
pooling25
.
padding
=
{
0
,
0
};
pooling25
.
stride
=
{
2
,
2
};
pooling25
.
lengths
=
{
3
,
3
};
...
...
@@ -129,7 +129,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx
::
op
::
relu
relu37
;
auto
mx37
=
mm
->
add_instruction
(
relu37
,
mx36
);
migraphx
::
op
::
pooling
pooling38
;
pooling38
.
mode
=
"max"
;
pooling38
.
mode
=
migraphx
::
op
::
pooling_mode
::
max
;
pooling38
.
padding
=
{
0
,
0
};
pooling38
.
stride
=
{
2
,
2
};
pooling38
.
lengths
=
{
3
,
3
};
...
...
src/driver/inceptionv3.cpp
View file @
c1ec929c
...
...
@@ -995,7 +995,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx
::
op
::
relu
relu492
;
auto
mx492
=
mm
->
add_instruction
(
relu492
,
mx491
);
migraphx
::
op
::
pooling
pooling493
;
pooling493
.
mode
=
"max"
;
pooling493
.
mode
=
migraphx
::
op
::
pooling_mode
::
max
;
pooling493
.
padding
=
{
0
,
0
};
pooling493
.
stride
=
{
2
,
2
};
pooling493
.
lengths
=
{
3
,
3
};
...
...
@@ -1025,7 +1025,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx
::
op
::
relu
relu499
;
auto
mx499
=
mm
->
add_instruction
(
relu499
,
mx498
);
migraphx
::
op
::
pooling
pooling500
;
pooling500
.
mode
=
"max"
;
pooling500
.
mode
=
migraphx
::
op
::
pooling_mode
::
max
;
pooling500
.
padding
=
{
0
,
0
};
pooling500
.
stride
=
{
2
,
2
};
pooling500
.
lengths
=
{
3
,
3
};
...
...
@@ -1103,7 +1103,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx
::
op
::
relu
relu518
;
auto
mx518
=
mm
->
add_instruction
(
relu518
,
mx517
);
migraphx
::
op
::
pooling
pooling519
;
pooling519
.
mode
=
"
average
"
;
pooling519
.
mode
=
migraphx
::
op
::
pooling_mode
::
average
;
pooling519
.
padding
=
{
1
,
1
};
pooling519
.
stride
=
{
1
,
1
};
pooling519
.
lengths
=
{
3
,
3
};
...
...
@@ -1196,7 +1196,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx
::
op
::
relu
relu541
;
auto
mx541
=
mm
->
add_instruction
(
relu541
,
mx540
);
migraphx
::
op
::
pooling
pooling542
;
pooling542
.
mode
=
"
average
"
;
pooling542
.
mode
=
migraphx
::
op
::
pooling_mode
::
average
;
pooling542
.
padding
=
{
1
,
1
};
pooling542
.
stride
=
{
1
,
1
};
pooling542
.
lengths
=
{
3
,
3
};
...
...
@@ -1289,7 +1289,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx
::
op
::
relu
relu564
;
auto
mx564
=
mm
->
add_instruction
(
relu564
,
mx563
);
migraphx
::
op
::
pooling
pooling565
;
pooling565
.
mode
=
"
average
"
;
pooling565
.
mode
=
migraphx
::
op
::
pooling_mode
::
average
;
pooling565
.
padding
=
{
1
,
1
};
pooling565
.
stride
=
{
1
,
1
};
pooling565
.
lengths
=
{
3
,
3
};
...
...
@@ -1358,7 +1358,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx
::
op
::
relu
relu581
;
auto
mx581
=
mm
->
add_instruction
(
relu581
,
mx580
);
migraphx
::
op
::
pooling
pooling582
;
pooling582
.
mode
=
"max"
;
pooling582
.
mode
=
migraphx
::
op
::
pooling_mode
::
max
;
pooling582
.
padding
=
{
0
,
0
};
pooling582
.
stride
=
{
2
,
2
};
pooling582
.
lengths
=
{
3
,
3
};
...
...
@@ -1475,7 +1475,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx
::
op
::
relu
relu610
;
auto
mx610
=
mm
->
add_instruction
(
relu610
,
mx609
);
migraphx
::
op
::
pooling
pooling611
;
pooling611
.
mode
=
"
average
"
;
pooling611
.
mode
=
migraphx
::
op
::
pooling_mode
::
average
;
pooling611
.
padding
=
{
1
,
1
};
pooling611
.
stride
=
{
1
,
1
};
pooling611
.
lengths
=
{
3
,
3
};
...
...
@@ -1604,7 +1604,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx
::
op
::
relu
relu642
;
auto
mx642
=
mm
->
add_instruction
(
relu642
,
mx641
);
migraphx
::
op
::
pooling
pooling643
;
pooling643
.
mode
=
"
average
"
;
pooling643
.
mode
=
migraphx
::
op
::
pooling_mode
::
average
;
pooling643
.
padding
=
{
1
,
1
};
pooling643
.
stride
=
{
1
,
1
};
pooling643
.
lengths
=
{
3
,
3
};
...
...
@@ -1733,7 +1733,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx
::
op
::
relu
relu674
;
auto
mx674
=
mm
->
add_instruction
(
relu674
,
mx673
);
migraphx
::
op
::
pooling
pooling675
;
pooling675
.
mode
=
"
average
"
;
pooling675
.
mode
=
migraphx
::
op
::
pooling_mode
::
average
;
pooling675
.
padding
=
{
1
,
1
};
pooling675
.
stride
=
{
1
,
1
};
pooling675
.
lengths
=
{
3
,
3
};
...
...
@@ -1862,7 +1862,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx
::
op
::
relu
relu706
;
auto
mx706
=
mm
->
add_instruction
(
relu706
,
mx705
);
migraphx
::
op
::
pooling
pooling707
;
pooling707
.
mode
=
"
average
"
;
pooling707
.
mode
=
migraphx
::
op
::
pooling_mode
::
average
;
pooling707
.
padding
=
{
1
,
1
};
pooling707
.
stride
=
{
1
,
1
};
pooling707
.
lengths
=
{
3
,
3
};
...
...
@@ -1955,7 +1955,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx
::
op
::
relu
relu729
;
auto
mx729
=
mm
->
add_instruction
(
relu729
,
mx728
);
migraphx
::
op
::
pooling
pooling730
;
pooling730
.
mode
=
"max"
;
pooling730
.
mode
=
migraphx
::
op
::
pooling_mode
::
max
;
pooling730
.
padding
=
{
0
,
0
};
pooling730
.
stride
=
{
2
,
2
};
pooling730
.
lengths
=
{
3
,
3
};
...
...
@@ -2066,7 +2066,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
concat757
.
axis
=
1
;
auto
mx757
=
mm
->
add_instruction
(
concat757
,
mx753
,
mx756
);
migraphx
::
op
::
pooling
pooling758
;
pooling758
.
mode
=
"
average
"
;
pooling758
.
mode
=
migraphx
::
op
::
pooling_mode
::
average
;
pooling758
.
padding
=
{
1
,
1
};
pooling758
.
stride
=
{
1
,
1
};
pooling758
.
lengths
=
{
3
,
3
};
...
...
@@ -2189,7 +2189,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
concat788
.
axis
=
1
;
auto
mx788
=
mm
->
add_instruction
(
concat788
,
mx784
,
mx787
);
migraphx
::
op
::
pooling
pooling789
;
pooling789
.
mode
=
"
average
"
;
pooling789
.
mode
=
migraphx
::
op
::
pooling_mode
::
average
;
pooling789
.
padding
=
{
1
,
1
};
pooling789
.
stride
=
{
1
,
1
};
pooling789
.
lengths
=
{
3
,
3
};
...
...
@@ -2210,7 +2210,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
concat793
.
axis
=
1
;
auto
mx793
=
mm
->
add_instruction
(
concat793
,
mx765
,
mx775
,
mx788
,
mx792
);
migraphx
::
op
::
pooling
pooling794
;
pooling794
.
mode
=
"
average
"
;
pooling794
.
mode
=
migraphx
::
op
::
pooling_mode
::
average
;
pooling794
.
padding
=
{
0
,
0
};
pooling794
.
stride
=
{
8
,
8
};
pooling794
.
lengths
=
{
8
,
8
};
...
...
src/driver/main.cpp
View file @
c1ec929c
...
...
@@ -505,8 +505,10 @@ struct roctx : command<roctx>
struct
op
:
command
<
op
>
{
bool
show_ops
=
false
;
std
::
string
op_name
{};
void
parse
(
argument_parser
&
ap
)
{
ap
(
op_name
,
{},
ap
.
metavar
(
"<MIGraphX operator name>"
));
ap
(
show_ops
,
{
"--list"
,
"-l"
},
ap
.
help
(
"List all the operators of MIGraphX"
),
...
...
@@ -519,6 +521,12 @@ struct op : command<op>
for
(
const
auto
&
name
:
get_operators
())
std
::
cout
<<
name
<<
std
::
endl
;
}
else
{
auto
op
=
load_op
(
op_name
);
std
::
cout
<<
op_name
<<
": "
<<
std
::
endl
;
std
::
cout
<<
to_pretty_json_string
(
op
.
to_value
())
<<
std
::
endl
;
}
}
};
...
...
src/driver/perf.cpp
View file @
c1ec929c
src/driver/resnet50.cpp
View file @
c1ec929c
...
...
@@ -561,7 +561,7 @@ migraphx::program resnet50(unsigned batch) // NOLINT(readability-function-size)
migraphx
::
op
::
relu
relu269
;
auto
mx269
=
mm
->
add_instruction
(
relu269
,
mx268
);
migraphx
::
op
::
pooling
pooling270
;
pooling270
.
mode
=
"max"
;
pooling270
.
mode
=
migraphx
::
op
::
pooling_mode
::
max
;
pooling270
.
padding
=
{
1
,
1
};
pooling270
.
stride
=
{
2
,
2
};
pooling270
.
lengths
=
{
3
,
3
};
...
...
@@ -1215,7 +1215,7 @@ migraphx::program resnet50(unsigned batch) // NOLINT(readability-function-size)
migraphx
::
op
::
relu
relu438
;
auto
mx438
=
mm
->
add_instruction
(
relu438
,
mx437
);
migraphx
::
op
::
pooling
pooling439
;
pooling439
.
mode
=
"
average
"
;
pooling439
.
mode
=
migraphx
::
op
::
pooling_mode
::
average
;
pooling439
.
padding
=
{
0
,
0
};
pooling439
.
stride
=
{
1
,
1
};
pooling439
.
lengths
=
{
7
,
7
};
...
...
src/eliminate_pad.cpp
View file @
c1ec929c
...
...
@@ -44,7 +44,7 @@ static void update_op(const instruction_ref& input, const instruction_ref& ins,
static
void
update_pooling
(
const
instruction_ref
&
input
,
const
instruction_ref
&
ins
,
module
&
m
)
{
auto
op
=
any_cast
<
op
::
pooling
>
(
ins
->
get_operator
());
if
(
op
.
mode
==
"
average
"
)
if
(
op
.
mode
==
op
::
pooling_mode
::
average
)
{
return
;
}
...
...
src/include/migraphx/json.hpp
View file @
c1ec929c
...
...
@@ -8,6 +8,7 @@
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
std
::
string
to_pretty_json_string
(
const
value
&
val
,
std
::
size_t
indent
=
4
);
std
::
string
to_json_string
(
const
value
&
val
);
value
from_json_string
(
const
std
::
string
&
str
);
value
from_json_string
(
const
char
*
str
,
std
::
size_t
size
);
...
...
src/include/migraphx/op/common.hpp
View file @
c1ec929c
#ifndef MIGRAPHX_GUARD_OPERATORS_COMMON_HPP
#define MIGRAPHX_GUARD_OPERATORS_COMMON_HPP
#include <ostream>
#include <vector>
#include <migraphx/config.hpp>
#include <utility>
...
...
@@ -15,6 +17,14 @@ enum padding_mode_t
valid
};
// The pooling modes must correspond 1-1 to the operators defined for struct parse_pooling.
// Used in pooling and roialign operators.
enum
class
pooling_mode
{
average
,
max
};
// indicate rnn computation direction
enum
class
rnn_direction
{
...
...
@@ -23,6 +33,7 @@ enum class rnn_direction
bidirectional
,
};
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
pooling_mode
v
);
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
rnn_direction
v
);
}
// namespace op
...
...
src/include/migraphx/op/pooling.hpp
View file @
c1ec929c
...
...
@@ -16,12 +16,13 @@
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
pooling
{
std
::
string
mode
=
"
average
"
;
pooling_mode
mode
=
{
pooling_mode
::
average
}
;
std
::
vector
<
std
::
size_t
>
padding
=
{
0
,
0
};
std
::
vector
<
std
::
size_t
>
stride
=
{
1
,
1
};
std
::
vector
<
std
::
size_t
>
lengths
=
{
1
,
1
};
...
...
src/include/migraphx/op/prefix_scan_op.hpp
100755 → 100644
View file @
c1ec929c
...
...
@@ -38,13 +38,33 @@ struct prefix_scan_op : op_name<Derived>
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
return
inputs
.
at
(
0
);
auto
s
=
inputs
.
front
();
if
(
s
.
broadcasted
())
{
return
{
s
.
type
(),
s
.
lens
()};
}
else
{
return
s
.
with_lens
(
s
.
lens
());
}
}
argument
compute
(
const
shape
&
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
auto
s
=
args
[
0
].
get_shape
();
if
(
s
==
output_shape
)
{
argument
result
=
args
[
0
].
copy
();
auto
s
=
result
.
get_shape
();
result
=
args
[
0
].
copy
();
}
else
{
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
input
)
{
par_for
(
output_shape
.
elements
(),
[
&
](
auto
i
)
{
output
[
output_shape
.
index
(
i
)]
=
input
[
s
.
index
(
i
)];
});
});
s
=
output_shape
;
}
auto
slice
=
shape
{
s
.
type
(),
{
s
.
lens
()[
axis
]},
{
s
.
strides
()[
axis
]}};
auto
lens
=
s
.
lens
();
lens
[
axis
]
=
1
;
...
...
src/include/migraphx/op/roialign.hpp
View file @
c1ec929c
...
...
@@ -3,6 +3,7 @@
#include <limits>
#include <migraphx/check_shapes.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/config.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp>
...
...
@@ -21,7 +22,7 @@ namespace op {
struct
roialign
{
std
::
string
coord_trans_mode
=
"half_pixel"
;
std
::
string
mode
=
"avg"
;
pooling_mode
mode
=
{
pooling_mode
::
average
}
;
int64_t
output_height
=
1
;
int64_t
output_width
=
1
;
int64_t
sampling_ratio
=
0
;
...
...
@@ -241,7 +242,8 @@ struct roialign
in_dims
[
0
]
*
in_dims
[
1
]);
double
output_val
;
std
::
tie
(
output_val
,
vec_index
[
c
])
=
(
mode
==
"avg"
)
?
this
->
calc_pooling
(
offset_bottom_data
,
(
mode
==
migraphx
::
op
::
pooling_mode
::
average
)
?
this
->
calc_pooling
(
offset_bottom_data
,
bin_grid_size
,
pre_calc
,
vec_index
[
c
],
...
...
src/include/migraphx/operation.hpp
View file @
c1ec929c
...
...
@@ -461,8 +461,8 @@ lifetime get_lifetime_op(const T&)
* shape compute_shape(const std::vector<shape>& input) const;
* shape compute_shape(const std::vector<shape>& inputs,const std::vector<module_ref>&
* mod_args) const; argument compute(context& ctx,const shape& output,const std::vector<argument>&
* input) const; argument compute(const shape& output,const std::vector<argument>& input)
*
const;
argument compute(const shape& output,const std::vector<argument>& input,const
* input) const; argument compute(const shape& output,const std::vector<argument>& input)
const;
*
argument compute(const shape& output,const std::vector<argument>& input,const
* std::vector<module_ref>& module_args,std::function<std::vector<argument>(module_ref&, const
* std::unordered_map<std::string, argument>&)> run) const; argument compute(context& ctx,const
* shape& output,const std::vector<argument>& input,const std::vector<module_ref>&
...
...
src/include/migraphx/program.hpp
View file @
c1ec929c
...
...
@@ -82,6 +82,9 @@ struct program
const
std
::
function
<
void
(
instruction_ref
,
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
)
>&
print_func
)
const
;
void
print
(
const
std
::
function
<
void
(
instruction_ref
ins
,
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
)
>&
print_func
)
const
;
void
print_graph
(
std
::
ostream
&
os
,
bool
brief
=
false
)
const
;
void
print_cpp
(
std
::
ostream
&
os
)
const
;
...
...
src/insert_pad.cpp
View file @
c1ec929c
...
...
@@ -44,7 +44,7 @@ static void update_op(const instruction_ref& input, const instruction_ref& ins,
static
void
update_pooling
(
const
instruction_ref
&
input
,
const
instruction_ref
&
ins
,
module
&
m
)
{
auto
op
=
any_cast
<
op
::
pooling
>
(
ins
->
get_operator
());
if
(
op
.
mode
==
"
average
"
)
if
(
op
.
mode
==
op
::
pooling_mode
::
average
)
{
return
;
}
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment