Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
23cb7917
Unverified
Commit
23cb7917
authored
Aug 16, 2023
by
Brian Pickrell
Committed by
GitHub
Aug 16, 2023
Browse files
Merge branch 'develop' into blas_tuning
parents
b5fcc0bc
ea32ca70
Changes
458
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1113 additions
and
388 deletions
+1113
-388
src/api/api.cpp
src/api/api.cpp
+286
-3
src/api/include/migraphx/migraphx.h
src/api/include/migraphx/migraphx.h
+340
-229
src/api/include/migraphx/migraphx.hpp
src/api/include/migraphx/migraphx.hpp
+134
-2
src/api/migraphx.py
src/api/migraphx.py
+58
-44
src/common.cpp
src/common.cpp
+50
-62
src/cpp_generator.cpp
src/cpp_generator.cpp
+25
-2
src/dead_code_elimination.cpp
src/dead_code_elimination.cpp
+4
-2
src/driver/CMakeLists.txt
src/driver/CMakeLists.txt
+8
-6
src/driver/argument_parser.hpp
src/driver/argument_parser.hpp
+14
-4
src/driver/main.cpp
src/driver/main.cpp
+84
-23
src/driver/perf.cpp
src/driver/perf.cpp
+34
-0
src/driver/perf.hpp
src/driver/perf.hpp
+9
-0
src/dynamic_loader.cpp
src/dynamic_loader.cpp
+13
-1
src/fuse_pointwise.cpp
src/fuse_pointwise.cpp
+7
-2
src/fuse_reduce.cpp
src/fuse_reduce.cpp
+2
-2
src/generate.cpp
src/generate.cpp
+1
-1
src/include/migraphx/adjust_allocation.hpp
src/include/migraphx/adjust_allocation.hpp
+1
-1
src/include/migraphx/algorithm.hpp
src/include/migraphx/algorithm.hpp
+38
-0
src/include/migraphx/allocation_model.hpp
src/include/migraphx/allocation_model.hpp
+3
-3
src/include/migraphx/analyze_streams.hpp
src/include/migraphx/analyze_streams.hpp
+2
-1
No files found.
src/api/api.cpp
View file @
23cb7917
...
@@ -24,6 +24,7 @@
...
@@ -24,6 +24,7 @@
#include <migraphx/execution_environment.hpp>
#include <migraphx/execution_environment.hpp>
#include <migraphx/migraphx.h>
#include <migraphx/migraphx.h>
#include <migraphx/rank.hpp>
#include <migraphx/rank.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/onnx.hpp>
...
@@ -43,7 +44,7 @@ namespace migraphx {
...
@@ -43,7 +44,7 @@ namespace migraphx {
static
thread_local
bool
disable_exception_catch
=
false
;
// NOLINT
static
thread_local
bool
disable_exception_catch
=
false
;
// NOLINT
extern
"C"
void
migraphx_test_private_disable_exception_catch
(
bool
b
)
extern
"C"
MIGRAPHX_C_EXPORT
void
migraphx_test_private_disable_exception_catch
(
bool
b
)
{
{
disable_exception_catch
=
b
;
disable_exception_catch
=
b
;
}
}
...
@@ -145,6 +146,11 @@ void set_default_dim_value(onnx_options& options, size_t value)
...
@@ -145,6 +146,11 @@ void set_default_dim_value(onnx_options& options, size_t value)
options
.
default_dim_value
=
value
;
options
.
default_dim_value
=
value
;
}
}
void
set_default_dyn_dim_value
(
onnx_options
&
options
,
const
shape
::
dynamic_dimension
&
dd
)
{
options
.
default_dyn_dim_value
=
dd
;
}
void
set_default_loop_iterations
(
onnx_options
&
options
,
int64_t
value
)
void
set_default_loop_iterations
(
onnx_options
&
options
,
int64_t
value
)
{
{
options
.
max_loop_iterations
=
value
;
options
.
max_loop_iterations
=
value
;
...
@@ -161,6 +167,13 @@ void set_input_parameter_shape(onnx_options& options,
...
@@ -161,6 +167,13 @@ void set_input_parameter_shape(onnx_options& options,
options
.
map_input_dims
[
std
::
string
(
name
)]
=
std
::
move
(
dims
);
options
.
map_input_dims
[
std
::
string
(
name
)]
=
std
::
move
(
dims
);
}
}
void
set_dyn_input_parameter_shape
(
onnx_options
&
options
,
const
char
*
name
,
std
::
vector
<
shape
::
dynamic_dimension
>
dyn_dims
)
{
options
.
map_dyn_input_dims
[
std
::
string
(
name
)]
=
std
::
move
(
dyn_dims
);
}
void
set_input_parameter_shape
(
tf_options
&
options
,
const
char
*
name
,
std
::
vector
<
std
::
size_t
>
dims
)
void
set_input_parameter_shape
(
tf_options
&
options
,
const
char
*
name
,
std
::
vector
<
std
::
size_t
>
dims
)
{
{
options
.
map_input_dims
[
std
::
string
(
name
)]
=
std
::
move
(
dims
);
options
.
map_input_dims
[
std
::
string
(
name
)]
=
std
::
move
(
dims
);
...
@@ -187,6 +200,12 @@ std::vector<const char*> get_names(const std::unordered_map<std::string, Value>&
...
@@ -187,6 +200,12 @@ std::vector<const char*> get_names(const std::unordered_map<std::string, Value>&
return
result
;
return
result
;
}
}
template
<
class
T
>
std
::
set
<
T
>
make_set
(
const
T
*
x
,
std
::
size_t
n
)
{
return
{
x
,
x
+
n
};
}
void
quantize_fp16_with_op_names
(
program
&
prog
,
std
::
vector
<
std
::
string
>&
names
)
void
quantize_fp16_with_op_names
(
program
&
prog
,
std
::
vector
<
std
::
string
>&
names
)
{
{
if
(
names
.
empty
())
if
(
names
.
empty
())
...
@@ -346,7 +365,10 @@ const Target* object_cast(const U* x)
...
@@ -346,7 +365,10 @@ const Target* object_cast(const U* x)
template
<
class
T
,
class
...
Ts
,
class
Target
=
std
::
remove_pointer_t
<
T
>
>
template
<
class
T
,
class
...
Ts
,
class
Target
=
std
::
remove_pointer_t
<
T
>
>
Target
*
allocate
(
Ts
&&
...
xs
)
Target
*
allocate
(
Ts
&&
...
xs
)
{
{
return
new
Target
(
std
::
forward
<
Ts
>
(
xs
)...);
// NOLINT
if
constexpr
(
std
::
is_aggregate
<
Target
>
{})
return
new
Target
{
std
::
forward
<
Ts
>
(
xs
)...};
// NOLINT
else
return
new
Target
(
std
::
forward
<
Ts
>
(
xs
)...);
// NOLINT
}
}
template
<
class
T
>
template
<
class
T
>
...
@@ -409,6 +431,39 @@ struct manage_generic_ptr
...
@@ -409,6 +431,39 @@ struct manage_generic_ptr
D
deleter
=
nullptr
;
D
deleter
=
nullptr
;
};
};
extern
"C"
struct
migraphx_optimals
;
struct
migraphx_optimals
{
template
<
class
...
Ts
>
migraphx_optimals
(
Ts
&&
...
xs
)
:
object
(
std
::
forward
<
Ts
>
(
xs
)...)
// NOLINT(readability-redundant-member-init)
{
}
std
::
set
<
size_t
>
object
;
};
extern
"C"
struct
migraphx_dynamic_dimension
;
struct
migraphx_dynamic_dimension
{
template
<
class
...
Ts
>
migraphx_dynamic_dimension
(
Ts
&&
...
xs
)
:
object
(
std
::
forward
<
Ts
>
(
xs
)...)
// NOLINT(readability-redundant-member-init)
{
}
migraphx
::
shape
::
dynamic_dimension
object
;
};
extern
"C"
struct
migraphx_dynamic_dimensions
;
struct
migraphx_dynamic_dimensions
{
template
<
class
...
Ts
>
migraphx_dynamic_dimensions
(
Ts
&&
...
xs
)
:
object
(
std
::
forward
<
Ts
>
(
xs
)...)
// NOLINT(readability-redundant-member-init)
{
}
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
object
;
};
extern
"C"
struct
migraphx_shape
;
extern
"C"
struct
migraphx_shape
;
struct
migraphx_shape
struct
migraphx_shape
{
{
...
@@ -736,6 +791,152 @@ struct migraphx_experimental_custom_op
...
@@ -736,6 +791,152 @@ struct migraphx_experimental_custom_op
}
}
};
};
extern
"C"
migraphx_status
migraphx_optimals_destroy
(
migraphx_optimals_t
optimals
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
destroy
((
optimals
));
});
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_optimals_assign_to
(
migraphx_optimals_t
output
,
const_migraphx_optimals_t
input
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
*
output
=
*
input
;
});
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_optimals_create
(
migraphx_optimals_t
*
optimals
,
const
size_t
*
ptr
,
size_t
size
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
*
optimals
=
object_cast
<
migraphx_optimals_t
>
(
allocate
<
std
::
set
<
size_t
>>
(
migraphx
::
make_set
<
size_t
>
((
ptr
),
(
size
))));
});
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_dynamic_dimension_destroy
(
migraphx_dynamic_dimension_t
dynamic_dimension
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
destroy
((
dynamic_dimension
));
});
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_dynamic_dimension_assign_to
(
migraphx_dynamic_dimension_t
output
,
const_migraphx_dynamic_dimension_t
input
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
*
output
=
*
input
;
});
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_dynamic_dimension_create_min_max
(
migraphx_dynamic_dimension_t
*
dynamic_dimension
,
size_t
min
,
size_t
max
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
*
dynamic_dimension
=
object_cast
<
migraphx_dynamic_dimension_t
>
(
allocate
<
migraphx
::
shape
::
dynamic_dimension
>
((
min
),
(
max
)));
});
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_dynamic_dimension_create_min_max_optimals
(
migraphx_dynamic_dimension_t
*
dynamic_dimension
,
size_t
min
,
size_t
max
,
migraphx_optimals_t
optimals
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
if
(
optimals
==
nullptr
)
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Bad parameter optimals: Null pointer"
);
*
dynamic_dimension
=
object_cast
<
migraphx_dynamic_dimension_t
>
(
allocate
<
migraphx
::
shape
::
dynamic_dimension
>
((
min
),
(
max
),
(
optimals
->
object
)));
});
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_dynamic_dimension_is_fixed
(
bool
*
out
,
const_migraphx_dynamic_dimension_t
dynamic_dimension
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
if
(
dynamic_dimension
==
nullptr
)
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Bad parameter dynamic_dimension: Null pointer"
);
*
out
=
(
dynamic_dimension
->
object
).
is_fixed
();
});
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_dynamic_dimension_equal
(
bool
*
out
,
const_migraphx_dynamic_dimension_t
dynamic_dimension
,
const_migraphx_dynamic_dimension_t
x
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
if
(
dynamic_dimension
==
nullptr
)
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Bad parameter dynamic_dimension: Null pointer"
);
if
(
x
==
nullptr
)
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Bad parameter x: Null pointer"
);
*
out
=
migraphx
::
equal
((
dynamic_dimension
->
object
),
(
x
->
object
));
});
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_dynamic_dimensions_destroy
(
migraphx_dynamic_dimensions_t
dynamic_dimensions
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
destroy
((
dynamic_dimensions
));
});
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_dynamic_dimensions_assign_to
(
migraphx_dynamic_dimensions_t
output
,
const_migraphx_dynamic_dimensions_t
input
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
*
output
=
*
input
;
});
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_dynamic_dimensions_create
(
migraphx_dynamic_dimensions_t
*
dynamic_dimensions
,
const
const_migraphx_dynamic_dimension_t
*
ptr
,
size_t
size
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
*
dynamic_dimensions
=
object_cast
<
migraphx_dynamic_dimensions_t
>
(
allocate
<
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>>
(
migraphx
::
to_obj_vector
<
const_migraphx_dynamic_dimension_t
>
((
ptr
),
(
size
))));
});
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_dynamic_dimensions_size
(
size_t
*
out
,
migraphx_dynamic_dimensions_t
dynamic_dimensions
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
if
(
dynamic_dimensions
==
nullptr
)
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Bad parameter dynamic_dimensions: Null pointer"
);
*
out
=
(
dynamic_dimensions
->
object
).
size
();
});
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_dynamic_dimensions_get
(
const_migraphx_dynamic_dimension_t
*
out
,
migraphx_dynamic_dimensions_t
dynamic_dimensions
,
size_t
idx
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
if
(
dynamic_dimensions
==
nullptr
)
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Bad parameter dynamic_dimensions: Null pointer"
);
*
out
=
object_cast
<
const_migraphx_dynamic_dimension_t
>
(
&
((
dynamic_dimensions
->
object
).
at
((
idx
))));
});
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_shape_destroy
(
migraphx_shape_t
shape
)
extern
"C"
migraphx_status
migraphx_shape_destroy
(
migraphx_shape_t
shape
)
{
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
destroy
((
shape
));
});
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
destroy
((
shape
));
});
...
@@ -794,6 +995,19 @@ extern "C" migraphx_status migraphx_shape_create_scalar(migraphx_shape_t* shape,
...
@@ -794,6 +995,19 @@ extern "C" migraphx_status migraphx_shape_create_scalar(migraphx_shape_t* shape,
return
api_error_result
;
return
api_error_result
;
}
}
extern
"C"
migraphx_status
migraphx_shape_create_dynamic
(
migraphx_shape_t
*
shape
,
migraphx_shape_datatype_t
type
,
migraphx_dynamic_dimensions_t
dims
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
if
(
dims
==
nullptr
)
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Bad parameter dims: Null pointer"
);
*
shape
=
object_cast
<
migraphx_shape_t
>
(
allocate
<
migraphx
::
shape
>
((
migraphx
::
to_shape_type
(
type
)),
(
dims
->
object
)));
});
return
api_error_result
;
}
extern
"C"
migraphx_status
extern
"C"
migraphx_status
migraphx_shape_lengths
(
const
size_t
**
out
,
size_t
*
out_size
,
const_migraphx_shape_t
shape
)
migraphx_shape_lengths
(
const
size_t
**
out
,
size_t
*
out_size
,
const_migraphx_shape_t
shape
)
{
{
...
@@ -824,6 +1038,17 @@ migraphx_shape_strides(const size_t** out, size_t* out_size, const_migraphx_shap
...
@@ -824,6 +1038,17 @@ migraphx_shape_strides(const size_t** out, size_t* out_size, const_migraphx_shap
return
api_error_result
;
return
api_error_result
;
}
}
extern
"C"
migraphx_status
migraphx_shape_dyn_dims
(
migraphx_dynamic_dimensions_t
*
out
,
const_migraphx_shape_t
shape
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
if
(
shape
==
nullptr
)
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Bad parameter shape: Null pointer"
);
*
out
=
allocate
<
migraphx_dynamic_dimensions_t
>
((
shape
->
object
).
dyn_dims
());
});
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_shape_type
(
migraphx_shape_datatype_t
*
out
,
extern
"C"
migraphx_status
migraphx_shape_type
(
migraphx_shape_datatype_t
*
out
,
const_migraphx_shape_t
shape
)
const_migraphx_shape_t
shape
)
{
{
...
@@ -857,6 +1082,16 @@ extern "C" migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shap
...
@@ -857,6 +1082,16 @@ extern "C" migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shap
return
api_error_result
;
return
api_error_result
;
}
}
extern
"C"
migraphx_status
migraphx_shape_ndim
(
size_t
*
out
,
const_migraphx_shape_t
shape
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
if
(
shape
==
nullptr
)
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Bad parameter shape: Null pointer"
);
*
out
=
(
shape
->
object
).
ndim
();
});
return
api_error_result
;
}
extern
"C"
migraphx_status
extern
"C"
migraphx_status
migraphx_shape_equal
(
bool
*
out
,
const_migraphx_shape_t
shape
,
const_migraphx_shape_t
x
)
migraphx_shape_equal
(
bool
*
out
,
const_migraphx_shape_t
shape
,
const_migraphx_shape_t
x
)
{
{
...
@@ -880,6 +1115,16 @@ extern "C" migraphx_status migraphx_shape_standard(bool* out, const_migraphx_sha
...
@@ -880,6 +1115,16 @@ extern "C" migraphx_status migraphx_shape_standard(bool* out, const_migraphx_sha
return
api_error_result
;
return
api_error_result
;
}
}
extern
"C"
migraphx_status
migraphx_shape_dynamic
(
bool
*
out
,
const_migraphx_shape_t
shape
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
if
(
shape
==
nullptr
)
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Bad parameter shape: Null pointer"
);
*
out
=
(
shape
->
object
).
dynamic
();
});
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_shape_index
(
size_t
*
out
,
const_migraphx_shape_t
shape
,
size_t
i
)
extern
"C"
migraphx_status
migraphx_shape_index
(
size_t
*
out
,
const_migraphx_shape_t
shape
,
size_t
i
)
{
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
...
@@ -915,6 +1160,17 @@ migraphx_argument_create(migraphx_argument_t* argument, const_migraphx_shape_t s
...
@@ -915,6 +1160,17 @@ migraphx_argument_create(migraphx_argument_t* argument, const_migraphx_shape_t s
return
api_error_result
;
return
api_error_result
;
}
}
extern
"C"
migraphx_status
migraphx_argument_create_empty
(
migraphx_argument_t
*
argument
,
const_migraphx_shape_t
shape
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
if
(
shape
==
nullptr
)
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Bad parameter shape: Null pointer"
);
*
argument
=
object_cast
<
migraphx_argument_t
>
(
allocate
<
migraphx
::
argument
>
((
shape
->
object
)));
});
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_argument_shape
(
const_migraphx_shape_t
*
out
,
extern
"C"
migraphx_status
migraphx_argument_shape
(
const_migraphx_shape_t
*
out
,
const_migraphx_argument_t
argument
)
const_migraphx_argument_t
argument
)
{
{
...
@@ -1176,7 +1432,7 @@ extern "C" migraphx_status migraphx_instructions_assign_to(migraphx_instructions
...
@@ -1176,7 +1432,7 @@ extern "C" migraphx_status migraphx_instructions_assign_to(migraphx_instructions
}
}
extern
"C"
migraphx_status
migraphx_instructions_create
(
migraphx_instructions_t
*
instructions
,
extern
"C"
migraphx_status
migraphx_instructions_create
(
migraphx_instructions_t
*
instructions
,
const_migraphx_instruction_t
*
ptr
,
const
const_migraphx_instruction_t
*
ptr
,
size_t
size
)
size_t
size
)
{
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
...
@@ -1590,6 +1846,19 @@ extern "C" migraphx_status migraphx_onnx_options_set_input_parameter_shape(
...
@@ -1590,6 +1846,19 @@ extern "C" migraphx_status migraphx_onnx_options_set_input_parameter_shape(
return
api_error_result
;
return
api_error_result
;
}
}
extern
"C"
migraphx_status
migraphx_onnx_options_set_dyn_input_parameter_shape
(
migraphx_onnx_options_t
onnx_options
,
const
char
*
name
,
migraphx_dynamic_dimensions_t
dims
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
if
(
onnx_options
==
nullptr
)
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Bad parameter onnx_options: Null pointer"
);
if
(
dims
==
nullptr
)
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Bad parameter dims: Null pointer"
);
migraphx
::
set_dyn_input_parameter_shape
((
onnx_options
->
object
),
(
name
),
(
dims
->
object
));
});
return
api_error_result
;
}
extern
"C"
migraphx_status
extern
"C"
migraphx_status
migraphx_onnx_options_set_default_dim_value
(
migraphx_onnx_options_t
onnx_options
,
size_t
value
)
migraphx_onnx_options_set_default_dim_value
(
migraphx_onnx_options_t
onnx_options
,
size_t
value
)
{
{
...
@@ -1601,6 +1870,20 @@ migraphx_onnx_options_set_default_dim_value(migraphx_onnx_options_t onnx_options
...
@@ -1601,6 +1870,20 @@ migraphx_onnx_options_set_default_dim_value(migraphx_onnx_options_t onnx_options
return
api_error_result
;
return
api_error_result
;
}
}
extern
"C"
migraphx_status
migraphx_onnx_options_set_default_dyn_dim_value
(
migraphx_onnx_options_t
onnx_options
,
const_migraphx_dynamic_dimension_t
dd
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
if
(
onnx_options
==
nullptr
)
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Bad parameter onnx_options: Null pointer"
);
if
(
dd
==
nullptr
)
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Bad parameter dd: Null pointer"
);
migraphx
::
set_default_dyn_dim_value
((
onnx_options
->
object
),
(
dd
->
object
));
});
return
api_error_result
;
}
extern
"C"
migraphx_status
extern
"C"
migraphx_status
migraphx_onnx_options_set_default_loop_iterations
(
migraphx_onnx_options_t
onnx_options
,
migraphx_onnx_options_set_default_loop_iterations
(
migraphx_onnx_options_t
onnx_options
,
int64_t
value
)
int64_t
value
)
...
...
src/api/include/migraphx/migraphx.h
View file @
23cb7917
...
@@ -26,6 +26,9 @@
...
@@ -26,6 +26,9 @@
#include <stdlib.h>
#include <stdlib.h>
#include <stdbool.h>
#include <stdbool.h>
#include <migraphx/api/export.h>
// Add new types here
// Add new types here
// clang-format off
// clang-format off
#define MIGRAPHX_SHAPE_VISIT_TYPES(m) \
#define MIGRAPHX_SHAPE_VISIT_TYPES(m) \
...
@@ -66,6 +69,15 @@ typedef enum
...
@@ -66,6 +69,15 @@ typedef enum
}
migraphx_shape_datatype_t
;
}
migraphx_shape_datatype_t
;
#undef MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES
#undef MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES
typedef
struct
migraphx_optimals
*
migraphx_optimals_t
;
typedef
const
struct
migraphx_optimals
*
const_migraphx_optimals_t
;
typedef
struct
migraphx_dynamic_dimension
*
migraphx_dynamic_dimension_t
;
typedef
const
struct
migraphx_dynamic_dimension
*
const_migraphx_dynamic_dimension_t
;
typedef
struct
migraphx_dynamic_dimensions
*
migraphx_dynamic_dimensions_t
;
typedef
const
struct
migraphx_dynamic_dimensions
*
const_migraphx_dynamic_dimensions_t
;
typedef
struct
migraphx_shape
*
migraphx_shape_t
;
typedef
struct
migraphx_shape
*
migraphx_shape_t
;
typedef
const
struct
migraphx_shape
*
const_migraphx_shape_t
;
typedef
const
struct
migraphx_shape
*
const_migraphx_shape_t
;
...
@@ -157,360 +169,460 @@ typedef migraphx_status (*migraphx_experimental_custom_op_copy)(void** out, void
...
@@ -157,360 +169,460 @@ typedef migraphx_status (*migraphx_experimental_custom_op_copy)(void** out, void
typedef
migraphx_status
(
*
migraphx_experimental_custom_op_delete
)(
void
*
input
);
typedef
migraphx_status
(
*
migraphx_experimental_custom_op_delete
)(
void
*
input
);
migraphx_status
migraphx_shape_destroy
(
migraphx_shape_t
shape
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_optimals_destroy
(
migraphx_optimals_t
optimals
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_optimals_assign_to
(
migraphx_optimals_t
output
,
const_migraphx_optimals_t
input
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_optimals_create
(
migraphx_optimals_t
*
optimals
,
const
size_t
*
ptr
,
size_t
size
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_dynamic_dimension_destroy
(
migraphx_dynamic_dimension_t
dynamic_dimension
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_dynamic_dimension_assign_to
(
migraphx_dynamic_dimension_t
output
,
const_migraphx_dynamic_dimension_t
input
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_dynamic_dimension_create_min_max
(
migraphx_dynamic_dimension_t
*
dynamic_dimension
,
size_t
min
,
size_t
max
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_dynamic_dimension_create_min_max_optimals
(
migraphx_dynamic_dimension_t
*
dynamic_dimension
,
size_t
min
,
size_t
max
,
migraphx_optimals_t
optimals
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_dynamic_dimension_is_fixed
(
bool
*
out
,
const_migraphx_dynamic_dimension_t
dynamic_dimension
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_dynamic_dimension_equal
(
bool
*
out
,
const_migraphx_dynamic_dimension_t
dynamic_dimension
,
const_migraphx_dynamic_dimension_t
x
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_dynamic_dimensions_destroy
(
migraphx_dynamic_dimensions_t
dynamic_dimensions
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_dynamic_dimensions_assign_to
(
migraphx_dynamic_dimensions_t
output
,
const_migraphx_dynamic_dimensions_t
input
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_dynamic_dimensions_create
(
migraphx_dynamic_dimensions_t
*
dynamic_dimensions
,
const
const_migraphx_dynamic_dimension_t
*
ptr
,
size_t
size
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_dynamic_dimensions_size
(
size_t
*
out
,
migraphx_dynamic_dimensions_t
dynamic_dimensions
);
migraphx_status
migraphx_shape_assign_to
(
migraphx_shape_t
output
,
const_migraphx_shape_t
input
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_dynamic_dimensions_get
(
const_migraphx_dynamic_dimension_t
*
out
,
migraphx_dynamic_dimensions_t
dynamic_dimensions
,
size_t
idx
);
migraphx_status
migraphx_shape_create
(
migraphx_shape_t
*
shape
,
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_shape_destroy
(
migraphx_shape_t
shape
);
migraphx_shape_datatype_t
type
,
size_t
*
lengths
,
size_t
lengths_size
);
migraphx_status
migraphx_shape_create_with_strides
(
migraphx_shape_t
*
shape
,
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_shape_assign_to
(
migraphx_shape_t
output
,
migraphx_shape_datatype_t
type
,
const_migraphx_shape_t
input
);
size_t
*
lengths
,
size_t
lengths_size
,
size_t
*
strides
,
size_t
strides_size
);
migraphx_status
migraphx_shape_create_scalar
(
migraphx_shape_t
*
shape
,
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_shape_create
(
migraphx_shape_t
*
shape
,
migraphx_shape_datatype_t
type
);
migraphx_shape_datatype_t
type
,
size_t
*
lengths
,
size_t
lengths_size
);
migraphx_status
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_shape_create_with_strides
(
migraphx_shape_t
*
shape
,
migraphx_shape_lengths
(
const
size_t
**
out
,
size_t
*
out_size
,
const_migraphx_shape_t
shape
);
migraphx_shape_datatype_t
type
,
size_t
*
lengths
,
size_t
lengths_size
,
size_t
*
strides
,
size_t
strides_size
);
migraphx_status
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_shape_create_scalar
(
migraphx_shape_t
*
shape
,
migraphx_shape_strides
(
const
size_t
**
out
,
size_t
*
out_size
,
const_migraphx_sha
pe_t
sha
pe
);
migraphx_shape_dataty
pe_t
ty
pe
);
migraphx_status
migraphx_shape_type
(
migraphx_shape_datatype_t
*
out
,
const_migraphx_shape_t
shape
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_shape_create_dynamic
(
migraphx_shape_t
*
shape
,
migraphx_shape_datatype_t
type
,
migraphx_dynamic_dimensions_t
dims
);
migraphx_status
migraphx_shape_elements
(
size_t
*
out
,
const_migraphx_shape_t
shape
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_shape_lengths
(
const
size_t
**
out
,
size_t
*
out_size
,
const_migraphx_shape_t
shape
);
migraphx_status
migraphx_shape_bytes
(
size_t
*
out
,
const_migraphx_shape_t
shape
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_shape_strides
(
const
size_t
**
out
,
size_t
*
out_size
,
const_migraphx_shape_t
shape
);
migraphx_status
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_shape_dyn_dims
(
migraphx_dynamic_dimensions_t
*
out
,
migraphx_shape_equal
(
bool
*
out
,
const_migraphx_shape_t
shape
,
const_migraphx_shape_t
x
);
const_migraphx_shape_t
shape
);
migraphx_status
migraphx_shape_standard
(
bool
*
out
,
const_migraphx_shape_t
shape
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_shape_type
(
migraphx_shape_datatype_t
*
out
,
const_migraphx_shape_t
shape
);
migraphx_status
migraphx_shape_index
(
size_t
*
out
,
const_migraphx_shape_t
shape
,
size_t
i
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_shape_elements
(
size_t
*
out
,
const_migraphx_shape_t
shape
);
migraphx_status
migraphx_
argument_destroy
(
migraphx_argument_t
argument
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_
shape_bytes
(
size_t
*
out
,
const_migraphx_shape_t
shape
);
migraphx_status
migraphx_argument_assign_to
(
migraphx_argument_t
output
,
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_shape_ndim
(
size_t
*
out
,
const_migraphx_shape_t
shape
);
const_migraphx_argument_t
input
);
migraphx_status
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_shape_equal
(
bool
*
out
,
migraphx_argument_create
(
migraphx_argument_t
*
argument
,
const_migraphx_shape_t
shape
,
void
*
buffer
);
const_migraphx_shape_t
shape
,
const_migraphx_shape_t
x
);
migraphx_status
migraphx_argument_shape
(
const_migraphx_shape_t
*
out
,
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_shape_standard
(
bool
*
out
,
const_migraphx_shape_t
shape
);
const_migraphx_argument_t
argument
);
migraphx_status
migraphx_
argument_buffer
(
char
*
*
out
,
const_migraphx_
argument_t
argument
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_
shape_dynamic
(
bool
*
out
,
const_migraphx_
shape_t
shape
);
migraphx_status
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_shape_index
(
size_t
*
out
,
migraphx_argument_equal
(
bool
*
out
,
const_migraphx_argument_t
argument
,
const_migraphx_argument_t
x
);
const_migraphx_shape_t
shape
,
size_t
i
);
migraphx_status
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_argument_destroy
(
migraphx_argument_t
argument
);
migraphx_argument_generate
(
migraphx_argument_t
*
out
,
const_migraphx_shape_t
s
,
size_t
seed
);
migraphx_status
migraphx_target_destroy
(
migraphx_target_t
target
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_argument_assign_to
(
migraphx_argument_t
output
,
const_migraphx_argument_t
input
);
migraphx_status
migraphx_target_assign_to
(
migraphx_target_t
output
,
const_migraphx_target_t
input
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_argument_create
(
migraphx_argument_t
*
argument
,
const_migraphx_shape_t
shape
,
void
*
buffer
);
migraphx_status
migraphx_target_create
(
migraphx_target_t
*
target
,
const
char
*
name
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_argument_create_empty
(
migraphx_argument_t
*
argument
,
const_migraphx_shape_t
shape
);
migraphx_status
migraphx_program_parameter_shapes_destroy
(
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_argument_shape
(
const_migraphx_shape_t
*
out
,
const_migraphx_argument_t
argument
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_argument_buffer
(
char
**
out
,
const_migraphx_argument_t
argument
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_argument_equal
(
bool
*
out
,
const_migraphx_argument_t
argument
,
const_migraphx_argument_t
x
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_argument_generate
(
migraphx_argument_t
*
out
,
const_migraphx_shape_t
s
,
size_t
seed
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_target_destroy
(
migraphx_target_t
target
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_target_assign_to
(
migraphx_target_t
output
,
const_migraphx_target_t
input
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_target_create
(
migraphx_target_t
*
target
,
const
char
*
name
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_program_parameter_shapes_destroy
(
migraphx_program_parameter_shapes_t
program_parameter_shapes
);
migraphx_program_parameter_shapes_t
program_parameter_shapes
);
migraphx_status
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_program_parameter_shapes_assign_to
(
migraphx_program_parameter_shapes_assign_to
(
migraphx_program_parameter_shapes_t
output
,
migraphx_program_parameter_shapes_t
output
,
const_migraphx_program_parameter_shapes_t
input
);
const_migraphx_program_parameter_shapes_t
input
);
migraphx_status
migraphx_program_parameter_shapes_size
(
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_program_parameter_shapes_size
(
size_t
*
out
,
migraphx_program_parameter_shapes_t
program_parameter_shapes
);
size_t
*
out
,
migraphx_program_parameter_shapes_t
program_parameter_shapes
);
migraphx_status
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_program_parameter_shapes_get
(
const_migraphx_shape_t
*
out
,
migraphx_program_parameter_shapes_get
(
const_migraphx_shape_t
*
out
,
migraphx_program_parameter_shapes_t
program_parameter_shapes
,
migraphx_program_parameter_shapes_t
program_parameter_shapes
,
const
char
*
name
);
const
char
*
name
);
migraphx_status
migraphx_program_parameter_shapes_names
(
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_program_parameter_shapes_names
(
const
char
**
out
,
migraphx_program_parameter_shapes_t
program_parameter_shapes
);
const
char
**
out
,
migraphx_program_parameter_shapes_t
program_parameter_shapes
);
migraphx_status
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_program_parameters_destroy
(
migraphx_program_parameters_t
program_parameters
);
migraphx_program_parameters_destroy
(
migraphx_program_parameters_t
program_parameters
);
migraphx_status
migraphx_program_parameters_assign_to
(
migraphx_program_parameters_t
output
,
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_program_parameters_assign_to
(
const_migraphx_program_parameters_t
input
);
migraphx_program_parameters_t
output
,
const_migraphx_program_parameters_t
input
);
migraphx_status
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_program_parameters_create
(
migraphx_program_parameters_t
*
program_parameters
);
migraphx_program_parameters_create
(
migraphx_program_parameters_t
*
program_parameters
);
migraphx_status
migraphx_program_parameters_add
(
migraphx_program_parameters_t
program_parameters
,
MIGRAPHX_C_EXPORT
migraphx_status
const
char
*
name
,
migraphx_program_parameters_add
(
migraphx_program_parameters_t
program_parameters
,
const_migraphx_argument_t
argument
);
const
char
*
name
,
const_migraphx_argument_t
argument
);
migraphx_status
migraphx_arguments_destroy
(
migraphx_arguments_t
arguments
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_arguments_destroy
(
migraphx_arguments_t
arguments
);
migraphx_status
migraphx_arguments_assign_to
(
migraphx_arguments_t
output
,
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_arguments_assign_to
(
migraphx_arguments_t
output
,
const_migraphx_arguments_t
input
);
const_migraphx_arguments_t
input
);
migraphx_status
migraphx_arguments_size
(
size_t
*
out
,
migraphx_arguments_t
arguments
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_arguments_size
(
size_t
*
out
,
migraphx_arguments_t
arguments
);
migraphx_status
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_arguments_get
(
const_migraphx_argument_t
*
out
,
migraphx_arguments_get
(
const_migraphx_argument_t
*
out
,
migraphx_arguments_t
arguments
,
size_t
idx
);
migraphx_arguments_t
arguments
,
size_t
idx
);
migraphx_status
migraphx_shapes_destroy
(
migraphx_shapes_t
shapes
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_shapes_destroy
(
migraphx_shapes_t
shapes
);
migraphx_status
migraphx_shapes_assign_to
(
migraphx_shapes_t
output
,
const_migraphx_shapes_t
input
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_shapes_assign_to
(
migraphx_shapes_t
output
,
const_migraphx_shapes_t
input
);
migraphx_status
migraphx_shapes_size
(
size_t
*
out
,
migraphx_shapes_t
shapes
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_shapes_size
(
size_t
*
out
,
migraphx_shapes_t
shapes
);
migraphx_status
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_shapes_get
(
const_migraphx_shape_t
*
out
,
migraphx_shapes_get
(
const_migraphx_shape_t
*
out
,
migraphx_shapes_t
shapes
,
size_t
idx
);
migraphx_shapes_t
shapes
,
size_t
idx
);
migraphx_status
migraphx_instruction_destroy
(
migraphx_instruction_t
instruction
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_instruction_destroy
(
migraphx_instruction_t
instruction
);
migraphx_status
migraphx_instruction_assign_to
(
migraphx_instruction_t
output
,
MIGRAPHX_C_EXPORT
migraphx_status
const_migraphx_instruction_t
input
);
migraphx_instruction_assign_to
(
migraphx_instruction_t
output
,
const_migraphx_instruction_t
input
);
migraphx_status
migraphx_instructions_destroy
(
migraphx_instructions_t
instructions
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_instructions_destroy
(
migraphx_instructions_t
instructions
);
migraphx_status
migraphx_instructions_assign_to
(
migraphx_instructions_t
output
,
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_instructions_assign_to
(
const_migraphx_instructions_t
input
);
migraphx_instructions_t
output
,
const_migraphx_instructions_t
input
);
migraphx_status
migraphx_instructions_create
(
migraphx_instructions_t
*
instructions
,
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_instructions_create
(
const_migraphx_instruction_t
*
ptr
,
migraphx_instructions_t
*
instructions
,
const
const_migraphx_instruction_t
*
ptr
,
size_t
size
);
size_t
size
);
migraphx_status
migraphx_modules_destroy
(
migraphx_modules_t
modules
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_modules_destroy
(
migraphx_modules_t
modules
);
migraphx_status
migraphx_modules_assign_to
(
migraphx_modules_t
output
,
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_modules_assign_to
(
migraphx_modules_t
output
,
const_migraphx_modules_t
input
);
const_migraphx_modules_t
input
);
migraphx_status
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_modules_create
(
migraphx_modules_t
*
modules
,
migraphx_modules_create
(
migraphx_modules_t
*
modules
,
migraphx_module_t
*
ptr
,
size_t
size
);
migraphx_module_t
*
ptr
,
size_t
size
);
migraphx_status
migraphx_module_create
(
migraphx_module_t
*
module
,
char
*
name
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_module_create
(
migraphx_module_t
*
module
,
char
*
name
);
migraphx_status
migraphx_module_print
(
const_migraphx_module_t
module
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_module_print
(
const_migraphx_module_t
module
);
migraphx_status
migraphx_module_add_instruction
(
migraphx_instruction_t
*
out
,
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_module_add_instruction
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
migraphx_module_t
module
,
migraphx_operation_t
op
,
migraphx_operation_t
op
,
migraphx_instructions_t
args
);
migraphx_instructions_t
args
);
migraphx_status
migraphx_module_add_instruction_with_mod_args
(
migraphx_instruction_t
*
out
,
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_module_t
module
,
migraphx_module_add_instruction_with_mod_args
(
migraphx_instruction_t
*
out
,
migraphx_operation_t
op
,
migraphx_module_t
module
,
migraphx_instructions_t
args
,
migraphx_operation_t
op
,
migraphx_modules_t
module_refs
);
migraphx_instructions_t
args
,
migraphx_modules_t
module_refs
);
migraphx_status
migraphx_module_add_literal
(
migraphx_instruction_t
*
out
,
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_module_add_literal
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
migraphx_module_t
module
,
const_migraphx_shape_t
shape
,
const_migraphx_shape_t
shape
,
const
char
*
buffer
);
const
char
*
buffer
);
migraphx_status
migraphx_module_add_parameter
(
migraphx_instruction_t
*
out
,
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_module_add_parameter
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
migraphx_module_t
module
,
const
char
*
name
,
const
char
*
name
,
const_migraphx_shape_t
shape
);
const_migraphx_shape_t
shape
);
migraphx_status
migraphx_module_add_return
(
migraphx_instruction_t
*
out
,
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_module_add_return
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
migraphx_module_t
module
,
migraphx_instructions_t
args
);
migraphx_instructions_t
args
);
migraphx_status
migraphx_module_add_allocation
(
migraphx_instruction_t
*
out
,
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_module_add_allocation
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
migraphx_module_t
module
,
const_migraphx_shape_t
s
);
const_migraphx_shape_t
s
);
migraphx_status
migraphx_program_destroy
(
migraphx_program_t
program
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_program_destroy
(
migraphx_program_t
program
);
migraphx_status
migraphx_program_assign_to
(
migraphx_program_t
output
,
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_program_assign_to
(
migraphx_program_t
output
,
const_migraphx_program_t
input
);
const_migraphx_program_t
input
);
migraphx_status
migraphx_program_create
(
migraphx_program_t
*
program
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_program_create
(
migraphx_program_t
*
program
);
migraphx_status
migraphx_program_get_main_module
(
migraphx_module_t
*
out
,
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_program_get_main_module
(
migraphx_module_t
*
out
,
migraphx_program_t
program
);
migraphx_program_t
program
);
migraphx_status
migraphx_program_create_module
(
migraphx_module_t
*
out
,
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_program_create_module
(
migraphx_module_t
*
out
,
migraphx_program_t
program
,
migraphx_program_t
program
,
const
char
*
name
);
const
char
*
name
);
migraphx_status
migraphx_program_compile
(
migraphx_program_t
program
,
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_program_compile
(
migraphx_program_t
program
,
migraphx_target_t
target
,
migraphx_target_t
target
,
migraphx_compile_options_t
options
);
migraphx_compile_options_t
options
);
migraphx_status
migraphx_program_get_parameter_shapes
(
migraphx_program_parameter_shapes_t
*
out
,
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_program_get_parameter_shapes
(
migraphx_program_t
program
);
migraphx_program_parameter_shapes_t
*
out
,
migraphx_program_t
program
);
migraphx_status
migraphx_program_get_output_shapes
(
migraphx_shapes_t
*
out
,
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_program_get_output_shapes
(
migraphx_shapes_t
*
out
,
migraphx_program_t
program
);
migraphx_program_t
program
);
migraphx_status
migraphx_program_print
(
const_migraphx_program_t
program
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_program_print
(
const_migraphx_program_t
program
);
migraphx_status
migraphx_program_sort
(
migraphx_program_t
program
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_program_sort
(
migraphx_program_t
program
);
migraphx_status
migraphx_program_run
(
migraphx_arguments_t
*
out
,
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_program_run
(
migraphx_arguments_t
*
out
,
migraphx_program_t
program
,
migraphx_program_t
program
,
migraphx_program_parameters_t
params
);
migraphx_program_parameters_t
params
);
migraphx_status
migraphx_program_run_async
(
migraphx_arguments_t
*
out
,
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_program_run_async
(
migraphx_arguments_t
*
out
,
migraphx_program_t
program
,
migraphx_program_t
program
,
migraphx_program_parameters_t
params
,
migraphx_program_parameters_t
params
,
void
*
s
,
void
*
s
,
const
char
*
name
);
const
char
*
name
);
migraphx_status
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_program_equal
(
bool
*
out
,
migraphx_program_equal
(
bool
*
out
,
const_migraphx_program_t
program
,
const_migraphx_program_t
x
);
const_migraphx_program_t
program
,
const_migraphx_program_t
x
);
migraphx_status
migraphx_program_experimental_get_context
(
migraphx_context_t
*
out
,
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_program_experimental_get_context
(
const_migraphx_program_t
program
);
migraphx_context_t
*
out
,
const_migraphx_program_t
program
);
migraphx_status
migraphx_operation_destroy
(
migraphx_operation_t
operation
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_operation_destroy
(
migraphx_operation_t
operation
);
migraphx_status
migraphx_operation_assign_to
(
migraphx_operation_t
output
,
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_operation_assign_to
(
migraphx_operation_t
output
,
const_migraphx_operation_t
input
);
const_migraphx_operation_t
input
);
migraphx_status
migraphx_operation_create
(
migraphx_operation_t
*
operation
,
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_operation_create
(
migraphx_operation_t
*
operation
,
const
char
*
name
,
const
char
*
name
,
const
char
*
attributes
,
const
char
*
attributes
,
...);
...);
migraphx_status
migraphx_operation_name
(
char
*
out
,
size_t
out_size
,
migraphx_operation_t
operation
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_operation_name
(
char
*
out
,
size_t
out_size
,
migraphx_operation_t
operation
);
migraphx_status
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_load
(
migraphx_program_t
*
out
,
migraphx_load
(
migraphx_program_t
*
out
,
const
char
*
name
,
migraphx_file_options_t
options
);
const
char
*
name
,
migraphx_file_options_t
options
);
migraphx_status
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_save
(
migraphx_program_t
p
,
migraphx_save
(
migraphx_program_t
p
,
const
char
*
name
,
migraphx_file_options_t
options
);
const
char
*
name
,
migraphx_file_options_t
options
);
migraphx_status
migraphx_onnx_options_destroy
(
migraphx_onnx_options_t
onnx_options
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_onnx_options_destroy
(
migraphx_onnx_options_t
onnx_options
);
migraphx_status
migraphx_onnx_options_assign_to
(
migraphx_onnx_options_t
output
,
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_onnx_options_assign_to
(
const_migraphx_onnx_options_t
input
);
migraphx_onnx_options_t
output
,
const_migraphx_onnx_options_t
input
);
migraphx_status
migraphx_onnx_options_create
(
migraphx_onnx_options_t
*
onnx_options
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_onnx_options_create
(
migraphx_onnx_options_t
*
onnx_options
);
migraphx_status
migraphx_onnx_options_set_input_parameter_shape
(
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_onnx_options_set_input_parameter_shape
(
migraphx_onnx_options_t
onnx_options
,
const
char
*
name
,
size_t
*
dims
,
size_t
dims_size
);
migraphx_onnx_options_t
onnx_options
,
const
char
*
name
,
size_t
*
dims
,
size_t
dims_size
);
migraphx_status
migraphx_onnx_options_set_default_dim_value
(
migraphx_onnx_options_t
onnx_options
,
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_onnx_options_set_dyn_input_parameter_shape
(
size_t
value
);
migraphx_onnx_options_t
onnx_options
,
const
char
*
name
,
migraphx_dynamic_dimensions_t
dims
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_onnx_options_set_default_dim_value
(
migraphx_onnx_options_t
onnx_options
,
size_t
value
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_onnx_options_set_default_dyn_dim_value
(
migraphx_onnx_options_t
onnx_options
,
const_migraphx_dynamic_dimension_t
dd
);
migraphx_status
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_onnx_options_set_default_loop_iterations
(
migraphx_onnx_options_set_default_loop_iterations
(
migraphx_onnx_options_t
onnx_options
,
migraphx_onnx_options_t
onnx_options
,
int64_t
value
);
int64_t
value
);
migraphx_status
migraphx_file_options_destroy
(
migraphx_file_options_t
file_options
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_file_options_destroy
(
migraphx_file_options_t
file_options
);
migraphx_status
migraphx_file_options_assign_to
(
migraphx_file_options_t
output
,
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_file_options_assign_to
(
const_migraphx_file_options_t
input
);
migraphx_file_options_t
output
,
const_migraphx_file_options_t
input
);
migraphx_status
migraphx_file_options_create
(
migraphx_file_options_t
*
file_options
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_file_options_create
(
migraphx_file_options_t
*
file_options
);
migraphx_status
migraphx_file_options_set_file_format
(
migraphx_file_options_t
file_options
,
MIGRAPHX_C_EXPORT
migraphx_status
const
char
*
format
);
migraphx_file_options_set_file_format
(
migraphx_file_options_t
file_options
,
const
char
*
format
);
migraphx_status
migraphx_compile_options_destroy
(
migraphx_compile_options_t
compile_options
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_compile_options_destroy
(
migraphx_compile_options_t
compile_options
);
migraphx_status
migraphx_compile_options_assign_to
(
migraphx_compile_options_t
output
,
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_compile_options_assign_to
(
const_migraphx_compile_options_t
input
);
migraphx_compile_options_t
output
,
const_migraphx_compile_options_t
input
);
migraphx_status
migraphx_compile_options_create
(
migraphx_compile_options_t
*
compile_options
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_compile_options_create
(
migraphx_compile_options_t
*
compile_options
);
migraphx_status
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_compile_options_set_offload_copy
(
migraphx_compile_options_t
compile_options
,
bool
value
);
migraphx_compile_options_set_offload_copy
(
migraphx_compile_options_t
compile_options
,
bool
value
);
migraphx_status
migraphx_compile_options_set_fast_math
(
migraphx_compile_options_t
compile_options
,
MIGRAPHX_C_EXPORT
migraphx_status
bool
value
);
migraphx_compile_options_set_fast_math
(
migraphx_compile_options_t
compile_options
,
bool
value
);
migraphx_status
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_compile_options_set_exhaustive_tune_flag
(
migraphx_compile_options_set_exhaustive_tune_flag
(
migraphx_compile_options_t
compile_options
,
migraphx_compile_options_t
compile_options
,
bool
value
);
bool
value
);
migraphx_status
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_parse_onnx
(
migraphx_program_t
*
out
,
migraphx_parse_onnx
(
migraphx_program_t
*
out
,
const
char
*
name
,
migraphx_onnx_options_t
options
);
const
char
*
name
,
migraphx_onnx_options_t
options
);
migraphx_status
migraphx_parse_onnx_buffer
(
migraphx_program_t
*
out
,
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_parse_onnx_buffer
(
migraphx_program_t
*
out
,
const
void
*
data
,
const
void
*
data
,
size_t
size
,
size_t
size
,
migraphx_onnx_options_t
options
);
migraphx_onnx_options_t
options
);
migraphx_status
migraphx_tf_options_destroy
(
migraphx_tf_options_t
tf_options
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_tf_options_destroy
(
migraphx_tf_options_t
tf_options
);
migraphx_status
migraphx_tf_options_assign_to
(
migraphx_tf_options_t
output
,
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_tf_options_assign_to
(
migraphx_tf_options_t
output
,
const_migraphx_tf_options_t
input
);
const_migraphx_tf_options_t
input
);
migraphx_status
migraphx_tf_options_create
(
migraphx_tf_options_t
*
tf_options
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_tf_options_create
(
migraphx_tf_options_t
*
tf_options
);
migraphx_status
migraphx_tf_options_set_nhwc
(
migraphx_tf_options_t
tf_options
,
bool
is_nhwc
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_tf_options_set_nhwc
(
migraphx_tf_options_t
tf_options
,
bool
is_nhwc
);
migraphx_status
migraphx_tf_options_set_input_parameter_shape
(
migraphx_tf_options_t
tf_options
,
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_tf_options_set_input_parameter_shape
(
const
char
*
name
,
migraphx_tf_options_t
tf_options
,
const
char
*
name
,
size_t
*
dims
,
size_t
dims_size
);
size_t
*
dims
,
size_t
dims_size
);
migraphx_status
migraphx_tf_options_set_default_dim_value
(
migraphx_tf_options_t
tf_options
,
MIGRAPHX_C_EXPORT
migraphx_status
size_t
value
);
migraphx_tf_options_set_default_dim_value
(
migraphx_tf_options_t
tf_options
,
size_t
value
);
migraphx_status
migraphx_tf_options_set_output_names
(
migraphx_tf_options_t
tf_options
,
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_tf_options_set_output_names
(
const
char
**
names
,
migraphx_tf_options_t
tf_options
,
const
char
**
names
,
size_t
names_size
);
size_t
names_size
);
migraphx_status
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_parse_tf
(
migraphx_program_t
*
out
,
migraphx_parse_tf
(
migraphx_program_t
*
out
,
const
char
*
name
,
migraphx_tf_options_t
options
);
const
char
*
name
,
migraphx_tf_options_t
options
);
migraphx_status
migraphx_quantize_op_names_destroy
(
migraphx_quantize_op_names_t
quantize_op_names
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_quantize_op_names_destroy
(
migraphx_quantize_op_names_t
quantize_op_names
);
migraphx_status
migraphx_quantize_op_names_assign_to
(
migraphx_quantize_op_names_t
output
,
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_quantize_op_names_assign_to
(
const_migraphx_quantize_op_names_t
input
);
migraphx_quantize_op_names_t
output
,
const_migraphx_quantize_op_names_t
input
);
migraphx_status
migraphx_quantize_op_names_create
(
migraphx_quantize_op_names_t
*
quantize_op_names
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_quantize_op_names_create
(
migraphx_quantize_op_names_t
*
quantize_op_names
);
migraphx_status
migraphx_quantize_op_names_add
(
migraphx_quantize_op_names_t
quantize_op_names
,
MIGRAPHX_C_EXPORT
migraphx_status
const
char
*
name
);
migraphx_quantize_op_names_add
(
migraphx_quantize_op_names_t
quantize_op_names
,
const
char
*
name
);
migraphx_status
migraphx_quantize_fp16_with_op_names
(
migraphx_program_t
prog
,
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_quantize_op_names_t
name
);
migraphx_quantize_fp16_with_op_names
(
migraphx_program_t
prog
,
migraphx_quantize_op_names_t
name
);
migraphx_status
migraphx_quantize_fp16
(
migraphx_program_t
prog
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_quantize_fp16
(
migraphx_program_t
prog
);
migraphx_status
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_quantize_int8_options_destroy
(
migraphx_quantize_int8_options_t
quantize_int8_options
);
migraphx_quantize_int8_options_destroy
(
migraphx_quantize_int8_options_t
quantize_int8_options
);
migraphx_status
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_quantize_int8_options_assign_to
(
migraphx_quantize_int8_options_assign_to
(
migraphx_quantize_int8_options_t
output
,
migraphx_quantize_int8_options_t
output
,
const_migraphx_quantize_int8_options_t
input
);
const_migraphx_quantize_int8_options_t
input
);
migraphx_status
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_quantize_int8_options_create
(
migraphx_quantize_int8_options_t
*
quantize_int8_options
);
migraphx_quantize_int8_options_create
(
migraphx_quantize_int8_options_t
*
quantize_int8_options
);
migraphx_status
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_quantize_int8_options_add_op_name
(
migraphx_quantize_int8_options_add_op_name
(
migraphx_quantize_int8_options_t
quantize_int8_options
,
migraphx_quantize_int8_options_t
quantize_int8_options
,
const
char
*
name
);
const
char
*
name
);
migraphx_status
migraphx_quantize_int8_options_add_calibration_data
(
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_quantize_int8_options_add_calibration_data
(
migraphx_quantize_int8_options_t
quantize_int8_options
,
migraphx_program_parameters_t
data
);
migraphx_quantize_int8_options_t
quantize_int8_options
,
migraphx_program_parameters_t
data
);
migraphx_status
migraphx_quantize_int8
(
migraphx_program_t
prog
,
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_quantize_int8
(
migraphx_program_t
prog
,
migraphx_target_t
target
,
migraphx_target_t
target
,
migraphx_quantize_int8_options_t
options
);
migraphx_quantize_int8_options_t
options
);
migraphx_status
migraphx_context_finish
(
const_migraphx_context_t
context
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_context_finish
(
const_migraphx_context_t
context
);
migraphx_status
migraphx_context_get_queue
(
void
**
out
,
migraphx_context_t
context
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_context_get_queue
(
void
**
out
,
migraphx_context_t
context
);
migraphx_status
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_experimental_custom_op_destroy
(
migraphx_experimental_custom_op_t
experimental_custom_op
);
migraphx_experimental_custom_op_destroy
(
migraphx_experimental_custom_op_t
experimental_custom_op
);
migraphx_status
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_experimental_custom_op_assign_to
(
migraphx_experimental_custom_op_assign_to
(
migraphx_experimental_custom_op_t
output
,
migraphx_experimental_custom_op_t
output
,
const_migraphx_experimental_custom_op_t
input
);
const_migraphx_experimental_custom_op_t
input
);
migraphx_status
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_experimental_custom_op_create
(
migraphx_experimental_custom_op_t
*
experimental_custom_op
,
migraphx_experimental_custom_op_create
(
migraphx_experimental_custom_op_t
*
experimental_custom_op
,
void
*
obj
,
void
*
obj
,
migraphx_experimental_custom_op_copy
c
,
migraphx_experimental_custom_op_copy
c
,
...
@@ -518,21 +630,20 @@ migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experi
...
@@ -518,21 +630,20 @@ migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experi
const
char
*
obj_typename
,
const
char
*
obj_typename
,
const
char
*
name
);
const
char
*
name
);
migraphx_status
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_experimental_custom_op_set_compute
(
migraphx_experimental_custom_op_set_compute
(
migraphx_experimental_custom_op_t
obj
,
migraphx_experimental_custom_op_t
obj
,
migraphx_experimental_custom_op_compute
input
);
migraphx_experimental_custom_op_compute
input
);
migraphx_status
migraphx_experimental_custom_op_set_compute_shape
(
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_experimental_custom_op_set_compute_shape
(
migraphx_experimental_custom_op_t
obj
,
migraphx_experimental_custom_op_compute_shape
input
);
migraphx_experimental_custom_op_t
obj
,
migraphx_experimental_custom_op_compute_shape
input
);
migraphx_status
migraphx_experimental_custom_op_set_output_alias
(
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_experimental_custom_op_set_output_alias
(
migraphx_experimental_custom_op_t
obj
,
migraphx_experimental_custom_op_output_alias
input
);
migraphx_experimental_custom_op_t
obj
,
migraphx_experimental_custom_op_output_alias
input
);
migraphx_status
migraphx_experimental_custom_op_set_runs_on_offload_target
(
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_experimental_custom_op_set_runs_on_offload_target
(
migraphx_experimental_custom_op_t
obj
,
migraphx_experimental_custom_op_t
obj
,
migraphx_experimental_custom_op_runs_on_offload_target
input
);
migraphx_experimental_custom_op_runs_on_offload_target
input
);
migraphx_status
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_experimental_custom_op_register
(
migraphx_experimental_custom_op_t
experimental_custom_op
);
migraphx_experimental_custom_op_register
(
migraphx_experimental_custom_op_t
experimental_custom_op
);
#ifdef __cplusplus
#ifdef __cplusplus
...
...
src/api/include/migraphx/migraphx.hpp
View file @
23cb7917
...
@@ -571,10 +571,90 @@ using require_interface =
...
@@ -571,10 +571,90 @@ using require_interface =
// NOLINTNEXTLINE
// NOLINTNEXTLINE
#define MIGRAPHX_CONST_HANDLE_BASE(name) MIGRAPHX_DETAIL_HANDLE_BASE(name, const)
#define MIGRAPHX_CONST_HANDLE_BASE(name) MIGRAPHX_DETAIL_HANDLE_BASE(name, const)
/**
* Container to hold optimal dynamic dimension values.
*/
struct
optimals
:
MIGRAPHX_HANDLE_BASE
(
optimals
)
{
MIGRAPHX_HANDLE_CONSTRUCTOR
(
optimals
)
optimals
(
std
::
initializer_list
<
size_t
>
init_list
)
{
this
->
make_handle
(
&
migraphx_optimals_create
,
init_list
.
begin
(),
init_list
.
size
());
}
};
/**
* @brief Dynamic dimension object.
* @details minimum, maximum, and optimal dimensions
*/
struct
dynamic_dimension
:
MIGRAPHX_CONST_HANDLE_BASE
(
dynamic_dimension
)
{
MIGRAPHX_HANDLE_CONSTRUCTOR
(
dynamic_dimension
)
dynamic_dimension
(
size_t
min
,
size_t
max
)
{
this
->
make_handle
(
&
migraphx_dynamic_dimension_create_min_max
,
min
,
max
);
}
dynamic_dimension
(
size_t
min
,
size_t
max
,
const
optimals
&
opts
)
{
this
->
make_handle
(
&
migraphx_dynamic_dimension_create_min_max_optimals
,
min
,
max
,
opts
.
get_handle_ptr
());
}
bool
is_fixed
()
const
{
bool
result
=
false
;
call
(
&
migraphx_dynamic_dimension_is_fixed
,
&
result
,
this
->
get_handle_ptr
());
return
result
;
}
friend
bool
operator
==
(
const
dynamic_dimension
&
x
,
const
dynamic_dimension
&
y
)
{
bool
pout
;
call
(
&
migraphx_dynamic_dimension_equal
,
&
pout
,
x
.
get_handle_ptr
(),
y
.
get_handle_ptr
());
return
pout
;
}
friend
bool
operator
!=
(
const
dynamic_dimension
&
x
,
const
dynamic_dimension
&
y
)
{
return
not
(
x
==
y
);
}
};
/**
* Container to hold dynamic_dimension objects.
*/
struct
dynamic_dimensions
:
MIGRAPHX_HANDLE_BASE
(
dynamic_dimensions
)
{
MIGRAPHX_HANDLE_CONSTRUCTOR
(
dynamic_dimensions
)
template
<
class
...
Ts
>
dynamic_dimensions
(
Ts
...
xs
)
{
std
::
array
<
const_migraphx_dynamic_dimension_t
,
sizeof
...(
Ts
)
>
a
{
xs
.
get_handle_ptr
()...};
this
->
make_handle
(
&
migraphx_dynamic_dimensions_create
,
a
.
data
(),
a
.
size
());
}
size_t
size
()
const
{
size_t
pout
;
call
(
&
migraphx_dynamic_dimensions_size
,
&
pout
,
this
->
get_handle_ptr
());
return
pout
;
}
dynamic_dimension
operator
[](
size_t
pidx
)
const
{
const_migraphx_dynamic_dimension_t
pout
;
call
(
&
migraphx_dynamic_dimensions_get
,
&
pout
,
this
->
get_handle_ptr
(),
pidx
);
return
{
pout
,
this
->
share_handle
()};
}
};
/**
/**
* @brief Describe shape of tensor
* @brief Describe shape of tensor
* @details A shape consists of a data type, lengths of multi-dimension tensor, and strides
* @details A shape consists of a data type, lengths of multi-dimension tensor, and strides
*
*/
*/
struct
shape
:
MIGRAPHX_CONST_HANDLE_BASE
(
shape
)
struct
shape
:
MIGRAPHX_CONST_HANDLE_BASE
(
shape
)
{
{
...
@@ -598,6 +678,13 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
...
@@ -598,6 +678,13 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
this
->
make_handle
(
&
migraphx_shape_create
,
type
,
plengths
.
data
(),
plengths
.
size
());
this
->
make_handle
(
&
migraphx_shape_create
,
type
,
plengths
.
data
(),
plengths
.
size
());
}
}
// Force all calls of the format `shape( type_t, { size_t compatibles } )` to map to
// shape(type_t, std::vector<std::size_t> l)
shape
(
migraphx_shape_datatype_t
t
,
std
::
initializer_list
<
std
::
size_t
>
d
)
:
shape
::
shape
(
t
,
std
::
vector
<
std
::
size_t
>
{
d
.
begin
(),
d
.
end
()})
{
}
shape
(
migraphx_shape_datatype_t
type
,
shape
(
migraphx_shape_datatype_t
type
,
std
::
vector
<
size_t
>
plengths
,
std
::
vector
<
size_t
>
plengths
,
std
::
vector
<
size_t
>
pstrides
)
std
::
vector
<
size_t
>
pstrides
)
...
@@ -610,6 +697,11 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
...
@@ -610,6 +697,11 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
pstrides
.
size
());
pstrides
.
size
());
}
}
shape
(
migraphx_shape_datatype_t
type
,
const
dynamic_dimensions
&
dyn_dims
)
{
this
->
make_handle
(
&
migraphx_shape_create_dynamic
,
type
,
dyn_dims
.
get_handle_ptr
());
}
std
::
vector
<
size_t
>
lengths
()
const
std
::
vector
<
size_t
>
lengths
()
const
{
{
const
size_t
*
pout
;
const
size_t
*
pout
;
...
@@ -626,6 +718,14 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
...
@@ -626,6 +718,14 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
return
{
pout
,
pout
+
pout_size
};
return
{
pout
,
pout
+
pout_size
};
}
}
/// Get the dynamic dimensions of the shape
dynamic_dimensions
dyn_dims
()
const
{
migraphx_dynamic_dimensions_t
pout
;
call
(
&
migraphx_shape_dyn_dims
,
&
pout
,
this
->
get_handle_ptr
());
return
{
pout
,
own
{}};
}
migraphx_shape_datatype_t
type
()
const
migraphx_shape_datatype_t
type
()
const
{
{
migraphx_shape_datatype_t
pout
;
migraphx_shape_datatype_t
pout
;
...
@@ -654,6 +754,14 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
...
@@ -654,6 +754,14 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
return
result
;
return
result
;
}
}
/// Is the shape dynamic
bool
dynamic
()
const
{
bool
result
=
false
;
call
(
&
migraphx_shape_dynamic
,
&
result
,
this
->
get_handle_ptr
());
return
result
;
}
// map element index to space index
// map element index to space index
size_t
index
(
size_t
i
)
const
size_t
index
(
size_t
i
)
const
{
{
...
@@ -687,6 +795,11 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
...
@@ -687,6 +795,11 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
MIGRAPHX_DEPRECATED
(
"Contructor without lifetime annotation is deprecated."
)
MIGRAPHX_DEPRECATED
(
"Contructor without lifetime annotation is deprecated."
)
argument
(
const
migraphx_argument
*
p
)
{
this
->
set_handle
(
p
,
borrow
{});
}
argument
(
const
migraphx_argument
*
p
)
{
this
->
set_handle
(
p
,
borrow
{});
}
argument
(
shape
pshape
)
{
this
->
make_handle
(
&
migraphx_argument_create_empty
,
pshape
.
get_handle_ptr
());
}
argument
(
shape
pshape
,
void
*
pbuffer
)
argument
(
shape
pshape
,
void
*
pbuffer
)
{
{
this
->
make_handle
(
&
migraphx_argument_create
,
pshape
.
get_handle_ptr
(),
pbuffer
);
this
->
make_handle
(
&
migraphx_argument_create
,
pshape
.
get_handle_ptr
(),
pbuffer
);
...
@@ -1182,12 +1295,27 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options)
...
@@ -1182,12 +1295,27 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options)
dim
.
size
());
dim
.
size
());
}
}
void
set_dyn_input_parameter_shape
(
const
std
::
string
&
name
,
const
dynamic_dimensions
&
dyn_dims
)
{
call
(
&
migraphx_onnx_options_set_dyn_input_parameter_shape
,
this
->
get_handle_ptr
(),
name
.
c_str
(),
dyn_dims
.
get_handle_ptr
());
}
/// When there is a dimension parameter, then use this default value
/// When there is a dimension parameter, then use this default value
void
set_default_dim_value
(
unsigned
int
value
)
void
set_default_dim_value
(
unsigned
int
value
)
{
{
call
(
&
migraphx_onnx_options_set_default_dim_value
,
this
->
get_handle_ptr
(),
value
);
call
(
&
migraphx_onnx_options_set_default_dim_value
,
this
->
get_handle_ptr
(),
value
);
}
}
void
set_default_dyn_dim_value
(
const
dynamic_dimension
&
dd
)
{
call
(
&
migraphx_onnx_options_set_default_dyn_dim_value
,
this
->
get_handle_ptr
(),
dd
.
get_handle_ptr
());
}
/// Set default max iteration number for the loop operator
/// Set default max iteration number for the loop operator
void
set_default_loop_iterations
(
int64_t
value
)
void
set_default_loop_iterations
(
int64_t
value
)
{
{
...
@@ -1359,13 +1487,17 @@ quantize_int8(const program& prog, const target& ptarget, const quantize_int8_op
...
@@ -1359,13 +1487,17 @@ quantize_int8(const program& prog, const target& ptarget, const quantize_int8_op
struct
experimental_custom_op_base
struct
experimental_custom_op_base
{
{
experimental_custom_op_base
()
=
default
;
experimental_custom_op_base
(
const
experimental_custom_op_base
&
)
=
default
;
experimental_custom_op_base
&
operator
=
(
const
experimental_custom_op_base
&
)
=
default
;
virtual
~
experimental_custom_op_base
()
=
default
;
virtual
std
::
string
name
()
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
argument
compute
(
context
ctx
,
shape
output
,
arguments
inputs
)
const
=
0
;
virtual
argument
compute
(
context
ctx
,
shape
output
,
arguments
inputs
)
const
=
0
;
virtual
shape
compute_shape
(
shapes
inputs
)
const
=
0
;
virtual
shape
compute_shape
(
shapes
inputs
)
const
=
0
;
virtual
std
::
vector
<
size_t
>
output_alias
(
shapes
)
const
{
return
{};
}
virtual
std
::
vector
<
size_t
>
output_alias
(
shapes
)
const
{
return
{};
}
// TODO: Return target string instead of bool
// TODO: Return target string instead of bool
virtual
bool
runs_on_offload_target
()
const
=
0
;
virtual
bool
runs_on_offload_target
()
const
=
0
;
virtual
~
experimental_custom_op_base
()
=
default
;
};
};
struct
experimental_custom_op
:
interface_base
<
MIGRAPHX_HANDLE_BASE
(
experimental_custom_op
)
>
struct
experimental_custom_op
:
interface_base
<
MIGRAPHX_HANDLE_BASE
(
experimental_custom_op
)
>
...
...
src/api/migraphx.py
View file @
23cb7917
...
@@ -45,56 +45,49 @@ def shape_type_wrap(p):
...
@@ -45,56 +45,49 @@ def shape_type_wrap(p):
p
.
read
=
'migraphx::to_shape_type(${name})'
p
.
read
=
'migraphx::to_shape_type(${name})'
@
api
.
cwrap
(
'migraphx::compile_options'
)
def
auto_handle
(
*
args
,
**
kwargs
):
def
compile_options_type_wrap
(
p
):
def
with_handle
(
f
):
if
p
.
returns
:
return
api
.
handle
(
'migraphx_'
+
f
.
__name__
,
'migraphx::'
+
f
.
__name__
,
p
.
add_param
(
'migraphx_compile_options *'
)
*
args
,
**
kwargs
)(
f
)
p
.
bad_param
(
'${name} == nullptr'
,
'Null pointer'
)
p
.
write
=
[
'*${name} = migraphx::to_compile_options(${result})'
]
else
:
p
.
add_param
(
'migraphx_compile_options *'
)
p
.
read
=
'${name} == nullptr ? migraphx::compile_options{} : migraphx::to_compile_options(*${name})'
@
api
.
cwrap
(
'migraphx::file_options'
)
def
file_options_type_wrap
(
p
):
if
p
.
returns
:
p
.
add_param
(
'migraphx_file_options *'
)
p
.
bad_param
(
'${name} == nullptr'
,
'Null pointer'
)
p
.
write
=
[
'*${name} = migraphx::to_file_options(${result})'
]
else
:
p
.
add_param
(
'migraphx_file_options *'
)
p
.
read
=
'${name} == nullptr ? migraphx::file_options{} : migraphx::to_file_options(*${name})'
return
with_handle
@
api
.
cwrap
(
'migraphx::onnx_options'
)
def
onnx_options_type_wrap
(
p
):
if
p
.
returns
:
p
.
add_param
(
'migraphx_onnx_options *'
)
p
.
bad_param
(
'${name} == nullptr'
,
'Null pointer'
)
p
.
write
=
[
'*${name} = migraphx::to_onnx_options(${result})'
]
else
:
p
.
add_param
(
'migraphx_onnx_options *'
)
p
.
read
=
'${name} == nullptr ? migraphx::onnx_options{} : migraphx::to_onnx_options(*${name})'
@
api
.
handle
(
'migraphx_optimals'
,
'std::set<size_t>'
)
def
optimals
(
h
):
h
.
constructor
(
'create'
,
api
.
params
(
ptr
=
'const size_t*'
,
size
=
'size_t'
),
fname
=
'migraphx::make_set<size_t>'
)
@
api
.
cwrap
(
'migraphx::tf_options'
)
def
tf_options_type_wrap
(
p
):
if
p
.
returns
:
p
.
add_param
(
'migraphx_tf_options *'
)
p
.
bad_param
(
'${name} == nullptr'
,
'Null pointer'
)
p
.
write
=
[
'*${name} = migraphx::to_tf_options(${result})'
]
else
:
p
.
add_param
(
'migraphx_tf_options *'
)
p
.
read
=
'${name} == nullptr ? migraphx::tf_options{} : migraphx::to_tf_options(*${name})'
@
api
.
handle
(
'migraphx_dynamic_dimension'
,
'migraphx::shape::dynamic_dimension'
)
def
dynamic_dimension
(
h
):
h
.
constructor
(
'create_min_max'
,
api
.
params
(
min
=
'size_t'
,
max
=
'size_t'
))
h
.
constructor
(
'create_min_max_optimals'
,
api
.
params
(
min
=
'size_t'
,
max
=
'size_t'
,
optimals
=
'std::set<size_t>'
))
h
.
method
(
'is_fixed'
,
returns
=
'bool'
,
const
=
True
)
h
.
method
(
'equal'
,
api
.
params
(
x
=
'const migraphx::shape::dynamic_dimension&'
),
invoke
=
'migraphx::equal($@)'
,
returns
=
'bool'
,
const
=
True
)
def
auto_handle
(
*
args
,
**
kwargs
):
def
with_handle
(
f
):
return
api
.
handle
(
'migraphx_'
+
f
.
__name__
,
'migraphx::'
+
f
.
__name__
,
*
args
,
**
kwargs
)(
f
)
return
with_handle
@
api
.
handle
(
'migraphx_dynamic_dimensions'
,
'std::vector<migraphx::shape::dynamic_dimension>'
)
def
dynamic_dimensions
(
h
):
h
.
constructor
(
'create'
,
api
.
params
(
ptr
=
'const const_migraphx_dynamic_dimension_t*'
,
size
=
'size_t'
),
fname
=
'migraphx::to_obj_vector<const_migraphx_dynamic_dimension_t>'
)
h
.
method
(
'size'
,
returns
=
'size_t'
)
h
.
method
(
'get'
,
api
.
params
(
idx
=
'size_t'
),
fname
=
'at'
,
cpp_name
=
'operator[]'
,
returns
=
'const migraphx::shape::dynamic_dimension&'
)
@
auto_handle
()
@
auto_handle
()
...
@@ -109,20 +102,29 @@ def shape(h):
...
@@ -109,20 +102,29 @@ def shape(h):
lengths
=
'std::vector<size_t>'
,
lengths
=
'std::vector<size_t>'
,
strides
=
'std::vector<size_t>'
))
strides
=
'std::vector<size_t>'
))
h
.
constructor
(
'create_scalar'
,
api
.
params
(
type
=
'migraphx::shape::type_t'
))
h
.
constructor
(
'create_scalar'
,
api
.
params
(
type
=
'migraphx::shape::type_t'
))
h
.
constructor
(
'create_dynamic'
,
api
.
params
(
type
=
'migraphx::shape::type_t'
,
dims
=
'std::vector<migraphx::shape::dynamic_dimension>'
))
h
.
method
(
'lengths'
,
h
.
method
(
'lengths'
,
fname
=
'lens'
,
fname
=
'lens'
,
returns
=
'const std::vector<size_t>&'
,
returns
=
'const std::vector<size_t>&'
,
const
=
True
)
const
=
True
)
h
.
method
(
'strides'
,
returns
=
'const std::vector<size_t>&'
,
const
=
True
)
h
.
method
(
'strides'
,
returns
=
'const std::vector<size_t>&'
,
const
=
True
)
h
.
method
(
'dyn_dims'
,
returns
=
'std::vector<migraphx::shape::dynamic_dimension>'
,
const
=
True
)
h
.
method
(
'type'
,
returns
=
'migraphx::shape::type_t'
,
const
=
True
)
h
.
method
(
'type'
,
returns
=
'migraphx::shape::type_t'
,
const
=
True
)
h
.
method
(
'elements'
,
returns
=
'size_t'
,
const
=
True
)
h
.
method
(
'elements'
,
returns
=
'size_t'
,
const
=
True
)
h
.
method
(
'bytes'
,
returns
=
'size_t'
,
const
=
True
)
h
.
method
(
'bytes'
,
returns
=
'size_t'
,
const
=
True
)
h
.
method
(
'ndim'
,
returns
=
'size_t'
,
const
=
True
)
h
.
method
(
'equal'
,
h
.
method
(
'equal'
,
api
.
params
(
x
=
'const migraphx::shape&'
),
api
.
params
(
x
=
'const migraphx::shape&'
),
invoke
=
'migraphx::equal($@)'
,
invoke
=
'migraphx::equal($@)'
,
returns
=
'bool'
,
returns
=
'bool'
,
const
=
True
)
const
=
True
)
h
.
method
(
'standard'
,
returns
=
'bool'
,
const
=
True
)
h
.
method
(
'standard'
,
returns
=
'bool'
,
const
=
True
)
h
.
method
(
'dynamic'
,
returns
=
'bool'
,
const
=
True
)
h
.
method
(
'index'
,
api
.
params
(
i
=
'size_t'
),
returns
=
'size_t'
,
const
=
True
)
h
.
method
(
'index'
,
api
.
params
(
i
=
'size_t'
),
returns
=
'size_t'
,
const
=
True
)
...
@@ -130,6 +132,7 @@ def shape(h):
...
@@ -130,6 +132,7 @@ def shape(h):
def
argument
(
h
):
def
argument
(
h
):
h
.
constructor
(
'create'
,
h
.
constructor
(
'create'
,
api
.
params
(
shape
=
'const migraphx::shape&'
,
buffer
=
'void*'
))
api
.
params
(
shape
=
'const migraphx::shape&'
,
buffer
=
'void*'
))
h
.
constructor
(
'create_empty'
,
api
.
params
(
shape
=
'const migraphx::shape&'
))
h
.
method
(
'shape'
,
h
.
method
(
'shape'
,
fname
=
'get_shape'
,
fname
=
'get_shape'
,
cpp_name
=
'get_shape'
,
cpp_name
=
'get_shape'
,
...
@@ -213,7 +216,7 @@ def instruction(h):
...
@@ -213,7 +216,7 @@ def instruction(h):
def
instructions
(
h
):
def
instructions
(
h
):
h
.
constructor
(
h
.
constructor
(
'create'
,
'create'
,
api
.
params
(
ptr
=
'const_migraphx_instruction_t*'
,
size
=
'size_t'
),
api
.
params
(
ptr
=
'const
const
_migraphx_instruction_t*'
,
size
=
'size_t'
),
fname
=
'migraphx::to_obj_vector<const_migraphx_instruction_t>'
)
fname
=
'migraphx::to_obj_vector<const_migraphx_instruction_t>'
)
...
@@ -325,11 +328,22 @@ def onnx_options(h):
...
@@ -325,11 +328,22 @@ def onnx_options(h):
api
.
params
(
name
=
'const char*'
,
dims
=
'std::vector<size_t>'
),
api
.
params
(
name
=
'const char*'
,
dims
=
'std::vector<size_t>'
),
invoke
=
'migraphx::set_input_parameter_shape($@)'
,
invoke
=
'migraphx::set_input_parameter_shape($@)'
,
)
)
h
.
method
(
'set_dyn_input_parameter_shape'
,
api
.
params
(
name
=
'const char*'
,
dims
=
'std::vector<migraphx::shape::dynamic_dimension>'
),
invoke
=
'migraphx::set_dyn_input_parameter_shape($@)'
,
)
h
.
method
(
h
.
method
(
'set_default_dim_value'
,
'set_default_dim_value'
,
api
.
params
(
value
=
'size_t'
),
api
.
params
(
value
=
'size_t'
),
invoke
=
'migraphx::set_default_dim_value($@)'
,
invoke
=
'migraphx::set_default_dim_value($@)'
,
)
)
h
.
method
(
'set_default_dyn_dim_value'
,
api
.
params
(
dd
=
'const migraphx::shape::dynamic_dimension&'
),
invoke
=
'migraphx::set_default_dyn_dim_value($@)'
,
)
h
.
method
(
h
.
method
(
'set_default_loop_iterations'
,
'set_default_loop_iterations'
,
api
.
params
(
value
=
'int64_t'
),
api
.
params
(
value
=
'int64_t'
),
...
...
src/common.cpp
View file @
23cb7917
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -31,20 +31,6 @@
...
@@ -31,20 +31,6 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
// Example:
// s0 = (3,2,4,5) and s1 = (2,1,1)
//
// In this case we need to broadcast (:,1,1) portion of
// s1 plus broadcast the 1st dimension of s1
// giving output_lens = (3,2,4,5)
//
// Another example:
// s0 = (3,2,1,5) and s1 = (2,7,5)
// In this case we need to broadcast the (:,:,1:,:) axis
// of s0 plus the 1st dimension of s1 giving
// output_lens = (3,2,7,5)
//
std
::
vector
<
std
::
size_t
>
compute_broadcasted_lens
(
std
::
vector
<
std
::
size_t
>
s0
,
std
::
vector
<
std
::
size_t
>
compute_broadcasted_lens
(
std
::
vector
<
std
::
size_t
>
s0
,
std
::
vector
<
std
::
size_t
>
s1
)
std
::
vector
<
std
::
size_t
>
s1
)
{
{
...
@@ -77,32 +63,38 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha
...
@@ -77,32 +63,38 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha
}
}
auto
offset
=
s1
.
ndim
()
-
s0
.
ndim
();
auto
offset
=
s1
.
ndim
()
-
s0
.
ndim
();
std
::
vector
<
shape
::
dynamic_dimension
>
out_dims
(
s1
.
dyn_dims
());
std
::
vector
<
shape
::
dynamic_dimension
>
out_dims
(
s1
.
dyn_dims
());
std
::
transform
(
std
::
transform
(
s0
.
dyn_dims
().
cbegin
(),
s0
.
dyn_dims
().
cbegin
(),
s0
.
dyn_dims
().
cend
(),
s0
.
dyn_dims
().
cend
(),
s1
.
dyn_dims
().
cbegin
()
+
offset
,
s1
.
dyn_dims
().
cbegin
()
+
offset
,
out_dims
.
begin
()
+
offset
,
out_dims
.
begin
()
+
offset
,
[
&
](
auto
a
,
auto
b
)
{
[
&
](
auto
a
,
auto
b
)
{
if
(
a
==
b
or
b
==
1
)
if
(
a
==
b
)
{
{
return
a
;
return
a
;
}
}
else
if
(
a
==
1
)
else
if
(
a
==
1
or
b
==
1
)
{
{
return
b
;
// setting optimals to empty, may need to be changed
}
return
shape
::
dynamic_dimension
{
std
::
max
(
a
.
min
,
b
.
min
),
std
::
max
(
a
.
max
,
b
.
max
)};
else
}
{
else
MIGRAPHX_THROW
(
"COMPUTE_BROADCASTED_DYN_DIMS: dynamic shapes {"
+
{
migraphx
::
to_string_range
(
s0
.
dyn_dims
())
+
"} and {"
+
MIGRAPHX_THROW
(
"COMPUTE_BROADCASTED_DYN_DIMS: dynamic shapes {"
+
migraphx
::
to_string_range
(
s1
.
dyn_dims
())
+
"} mismatch!"
);
migraphx
::
to_string_range
(
s0
.
dyn_dims
())
+
"} and {"
+
}
migraphx
::
to_string_range
(
s1
.
dyn_dims
())
+
"} mismatch!"
);
});
}
});
return
out_dims
;
return
out_dims
;
}
}
// Compute the common (broadcasted) dimensions of a list of fixed shapes
std
::
vector
<
shape
::
dynamic_dimension
>
compute_common_dyn_dims
(
const
std
::
vector
<
shape
>&
shapes
)
{
auto
ret_shape
=
shapes
.
at
(
0
);
std
::
for_each
(
shapes
.
cbegin
()
+
1
,
shapes
.
cend
(),
[
&
](
auto
s
)
{
ret_shape
=
shape
{
ret_shape
.
type
(),
compute_broadcasted_dyn_dims
(
ret_shape
,
s
)};
});
return
ret_shape
.
dyn_dims
();
}
std
::
vector
<
std
::
size_t
>
compute_common_lens
(
const
std
::
vector
<
shape
>&
shapes
)
std
::
vector
<
std
::
size_t
>
compute_common_lens
(
const
std
::
vector
<
shape
>&
shapes
)
{
{
assert
(
not
shapes
.
empty
());
assert
(
not
shapes
.
empty
());
...
@@ -154,34 +146,30 @@ insert_common_args(module& m, instruction_ref ins, std::vector<instruction_ref>
...
@@ -154,34 +146,30 @@ insert_common_args(module& m, instruction_ref ins, std::vector<instruction_ref>
if
(
std
::
any_of
(
if
(
std
::
any_of
(
inputs
.
cbegin
(),
inputs
.
cend
(),
[](
auto
input
)
{
return
input
->
get_shape
().
dynamic
();
}))
inputs
.
cbegin
(),
inputs
.
cend
(),
[](
auto
input
)
{
return
input
->
get_shape
().
dynamic
();
}))
{
{
// currently only handles the binary case
auto
input_shapes
=
to_shapes
(
inputs
);
if
(
inputs
.
size
()
!=
2
)
auto
c_type
=
compute_common_types
(
input_shapes
);
{
auto
c_dyn_dims
=
compute_common_dyn_dims
(
input_shapes
);
MIGRAPHX_THROW
(
"INSERT_COMMON_OP: not handled; "
+
migraphx
::
to_string
(
inputs
.
size
())
+
"inputs, only handle two inputs if any are dynamic shape"
);
}
auto
c_type
=
compute_common_types
(
to_shapes
(
inputs
));
auto
c_dyn_dims
=
compute_broadcasted_dyn_dims
(
inputs
[
0
]
->
get_shape
(),
inputs
[
1
]
->
get_shape
());
// following should work for a static or dynamic
shape
auto
s0
=
inputs
[
0
]
->
get_
shape
();
if
(
inputs
[
0
]
->
get_shape
()
.
dyn_dims
()
!=
c_dyn_dims
)
if
(
not
s0
.
dynamic
()
or
s0
.
dyn_dims
()
!=
c_dyn_dims
)
{
{
inputs
[
0
]
=
m
.
insert_instruction
(
inputs
[
0
]
=
m
.
insert_instruction
(
ins
,
ins
,
make_op
(
"multibroadcast"
,
{{
"out_dyn_dims"
,
to_value
(
c_dyn_dims
)}}),
inputs
);
make_op
(
"multibroadcast"
,
{{
"out_dyn_dims"
,
to_value
(
c_dyn_dims
)}}),
inputs
[
0
],
inputs
[
1
]);
}
if
(
inputs
[
1
]
->
get_shape
().
dyn_dims
()
!=
c_dyn_dims
)
{
inputs
[
1
]
=
m
.
insert_instruction
(
ins
,
make_op
(
"multibroadcast"
,
{{
"out_dyn_dims"
,
to_value
(
c_dyn_dims
)}}),
inputs
[
1
],
inputs
[
0
]);
}
}
std
::
transform
(
inputs
.
begin
()
+
1
,
inputs
.
end
(),
inputs
.
begin
()
+
1
,
[
&
](
auto
input
)
{
// uses previous input to avoid recalculating the common shape from the
// full set of input shapes at runtime
auto
s
=
input
->
get_shape
();
if
(
not
s
.
dynamic
()
or
s
.
dyn_dims
()
!=
c_dyn_dims
)
{
return
m
.
insert_instruction
(
ins
,
make_op
(
"multibroadcast"
,
{{
"out_dyn_dims"
,
to_value
(
c_dyn_dims
)}}),
input
,
inputs
[
0
]);
}
return
input
;
});
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
inputs
.
begin
(),
[
&
](
auto
input
)
{
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
inputs
.
begin
(),
[
&
](
auto
input
)
{
if
(
input
->
get_shape
().
type
()
!=
c_type
)
if
(
input
->
get_shape
().
type
()
!=
c_type
)
{
{
...
...
src/cpp_generator.cpp
View file @
23cb7917
...
@@ -106,6 +106,11 @@ cpp_generator::function& cpp_generator::function::set_generic_types(const module
...
@@ -106,6 +106,11 @@ cpp_generator::function& cpp_generator::function::set_generic_types(const module
return
*
this
;
return
*
this
;
}
}
cpp_generator
::
function
&
cpp_generator
::
function
::
unused_param
(
const
std
::
string
&
pname
)
{
body
.
insert
(
0
,
"(void)"
+
pname
+
";
\n
"
);
return
*
this
;
}
cpp_generator
::
function
&
cpp_generator
::
function
::
add_generic_param
(
const
std
::
string
&
pname
)
cpp_generator
::
function
&
cpp_generator
::
function
::
add_generic_param
(
const
std
::
string
&
pname
)
{
{
params
.
push_back
({
pname
,
"T"
+
pname
});
params
.
push_back
({
pname
,
"T"
+
pname
});
...
@@ -174,6 +179,8 @@ std::string cpp_generator::generate_point_op(const operation& op,
...
@@ -174,6 +179,8 @@ std::string cpp_generator::generate_point_op(const operation& op,
else
if
(
with_char
(
::
isdigit
)(
key
[
0
]))
else
if
(
with_char
(
::
isdigit
)(
key
[
0
]))
{
{
auto
i
=
std
::
stoul
(
key
);
auto
i
=
std
::
stoul
(
key
);
if
(
i
>=
args
.
size
())
MIGRAPHX_THROW
(
"Invalid argument index: "
+
key
);
return
args
.
at
(
i
);
return
args
.
at
(
i
);
}
}
else
if
(
v
.
contains
(
key
))
else
if
(
v
.
contains
(
key
))
...
@@ -201,8 +208,24 @@ cpp_generator::function cpp_generator::generate_module(const module& m,
...
@@ -201,8 +208,24 @@ cpp_generator::function cpp_generator::generate_module(const module& m,
f
.
set_name
(
name
).
set_types
(
m
).
set_body
(
f
.
set_name
(
name
).
set_types
(
m
).
set_body
(
m
,
[
&
](
instruction_ref
ins
,
const
auto
&
names
)
->
std
::
string
{
m
,
[
&
](
instruction_ref
ins
,
const
auto
&
names
)
->
std
::
string
{
if
(
ins
->
name
()
==
"@literal"
)
if
(
ins
->
name
()
==
"@literal"
)
return
shape
::
cpp_type
(
ins
->
get_shape
().
type
())
+
"("
+
{
ins
->
get_literal
().
to_string
()
+
")"
;
std
::
string
string_literal
;
ins
->
get_literal
().
visit
([
&
](
auto
v
)
{
assert
(
v
.
size
()
==
1
);
auto
x
=
v
.
front
();
if
(
std
::
isinf
(
x
))
{
string_literal
=
"__builtin_huge_val()"
;
if
(
x
<
0
)
string_literal
=
"-__builtin_huge_val()"
;
}
else
if
(
std
::
isnan
(
x
))
string_literal
=
"__builtin_nan()"
;
else
string_literal
=
ins
->
get_literal
().
to_string
();
});
return
shape
::
cpp_type
(
ins
->
get_shape
().
type
())
+
"("
+
string_literal
+
")"
;
}
auto
s
=
g
(
ins
,
names
);
auto
s
=
g
(
ins
,
names
);
if
(
impl
->
fresult
)
if
(
impl
->
fresult
)
return
impl
->
fresult
(
ins
->
get_shape
())
+
'('
+
s
+
')'
;
return
impl
->
fresult
(
ins
->
get_shape
())
+
'('
+
s
+
')'
;
...
...
src/dead_code_elimination.cpp
View file @
23cb7917
...
@@ -49,8 +49,10 @@ void dead_code_elimination::apply(module& m) const
...
@@ -49,8 +49,10 @@ void dead_code_elimination::apply(module& m) const
if
(
i
==
last
)
if
(
i
==
last
)
break
;
break
;
// Skip instruction with empty shape as output unless its [dynamic, builtin, undefined,
// Skip instruction with empty shape as output unless its [dynamic, builtin, undefined,
// identity, allocate]
// identity, allocate or tuple_type]
if
((
not
i
->
get_shape
().
dynamic
()
and
i
->
get_shape
().
elements
()
==
0
)
and
if
((
not
i
->
get_shape
().
dynamic
()
and
(
i
->
get_shape
().
elements
()
==
0
and
i
->
get_shape
().
type
()
!=
migraphx
::
shape
::
tuple_type
))
and
not
(
i
->
name
().
front
()
==
'@'
)
and
not
contains
({
"identity"
,
"allocate"
},
i
->
name
())
and
not
(
i
->
name
().
front
()
==
'@'
)
and
not
contains
({
"identity"
,
"allocate"
},
i
->
name
())
and
not
i
->
is_undefined
())
not
i
->
is_undefined
())
continue
;
continue
;
...
...
src/driver/CMakeLists.txt
View file @
23cb7917
...
@@ -32,18 +32,20 @@ add_executable(driver
...
@@ -32,18 +32,20 @@ add_executable(driver
marker_roctx.cpp
marker_roctx.cpp
)
)
set_target_properties
(
driver PROPERTIES OUTPUT_NAME migraphx-driver
)
set_target_properties
(
driver PROPERTIES OUTPUT_NAME migraphx-driver
)
# Copy driver for backwards compatibility
if
(
NOT WIN32
)
add_custom_command
(
# Copy driver for backwards compatibility (Linux only)
TARGET driver
add_custom_command
(
TARGET driver
POST_BUILD COMMAND
${
CMAKE_COMMAND
}
-E copy
POST_BUILD COMMAND
${
CMAKE_COMMAND
}
-E copy
$<TARGET_FILE:driver>
$<TARGET_FILE:driver>
${
CMAKE_RUNTIME_OUTPUT_DIRECTORY
}
/driver
${
CMAKE_RUNTIME_OUTPUT_DIRECTORY
}
/driver
BYPRODUCTS
${
CMAKE_RUNTIME_OUTPUT_DIRECTORY
}
/driver
BYPRODUCTS
${
CMAKE_RUNTIME_OUTPUT_DIRECTORY
}
/driver
)
)
set_directory_properties
(
PROPERTIES ADDITIONAL_MAKE_CLEAN_FILES
${
CMAKE_RUNTIME_OUTPUT_DIRECTORY
}
/driver
)
set_directory_properties
(
PROPERTIES ADDITIONAL_CLEAN_FILES
${
CMAKE_RUNTIME_OUTPUT_DIRECTORY
}
/driver
)
endif
()
rocm_clang_tidy_check
(
driver
)
rocm_clang_tidy_check
(
driver
)
target_link_libraries
(
driver migraphx_all_targets migraphx_onnx migraphx_tf
)
target_link_libraries
(
driver migraphx_all_targets migraphx_onnx migraphx_tf
migraphx_py
)
rocm_install_targets
(
rocm_install_targets
(
TARGETS driver
TARGETS driver
...
...
src/driver/argument_parser.hpp
View file @
23cb7917
...
@@ -338,11 +338,22 @@ struct argument_parser
...
@@ -338,11 +338,22 @@ struct argument_parser
MIGRAPHX_DRIVER_STATIC
auto
file_exist
()
MIGRAPHX_DRIVER_STATIC
auto
file_exist
()
{
{
return
validate
([](
auto
&
,
auto
&
,
auto
&
params
)
{
return
validate
([](
auto
&
,
auto
&
,
const
auto
&
params
)
{
if
(
params
.
empty
())
if
(
params
.
empty
())
throw
std
::
runtime_error
(
"No argument passed."
);
throw
std
::
runtime_error
(
"No argument passed."
);
if
(
not
fs
::
exists
(
params
.
back
()))
if
(
not
fs
::
exists
(
params
.
back
()))
throw
std
::
runtime_error
(
"Path does not exists: "
+
params
.
back
());
throw
std
::
runtime_error
(
"Path does not exist: "
+
params
.
back
());
});
}
MIGRAPHX_DRIVER_STATIC
auto
matches
(
const
std
::
unordered_set
<
std
::
string
>&
names
)
{
return
validate
([
=
](
auto
&
,
auto
&
,
const
auto
&
params
)
{
auto
invalid_param
=
std
::
find_if
(
params
.
begin
(),
params
.
end
(),
[
&
](
const
auto
&
p
)
{
return
names
.
count
(
p
)
==
0
;
});
if
(
invalid_param
!=
params
.
end
())
throw
std
::
runtime_error
(
"Invalid argument: "
+
*
invalid_param
+
". Valid arguments are {"
+
to_string_range
(
names
)
+
"}"
);
});
});
}
}
...
@@ -570,8 +581,7 @@ struct argument_parser
...
@@ -570,8 +581,7 @@ struct argument_parser
continue
;
continue
;
if
(
flag
[
0
]
!=
'-'
)
if
(
flag
[
0
]
!=
'-'
)
continue
;
continue
;
auto
d
=
std
::
ptrdiff_t
d
=
levenshtein_distance
(
flag
,
input
);
levenshtein_distance
(
flag
.
begin
(),
flag
.
end
(),
input
.
begin
(),
input
.
end
());
if
(
d
<
result
.
distance
)
if
(
d
<
result
.
distance
)
result
=
result_t
{
&
arg
,
flag
,
input
,
d
};
result
=
result_t
{
&
arg
,
flag
,
input
,
d
};
}
}
...
...
src/driver/main.cpp
View file @
23cb7917
...
@@ -32,6 +32,7 @@
...
@@ -32,6 +32,7 @@
#include <migraphx/tf.hpp>
#include <migraphx/tf.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/py.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/convert_to_json.hpp>
#include <migraphx/convert_to_json.hpp>
#include <migraphx/load_save.hpp>
#include <migraphx/load_save.hpp>
...
@@ -81,6 +82,7 @@ struct loader
...
@@ -81,6 +82,7 @@ struct loader
{
"--model"
},
{
"--model"
},
ap
.
help
(
"Load model"
),
ap
.
help
(
"Load model"
),
ap
.
type
(
"resnet50|inceptionv3|alexnet"
),
ap
.
type
(
"resnet50|inceptionv3|alexnet"
),
ap
.
matches
({
"resnet50"
,
"inceptionv3"
,
"alexnet"
}),
ap
.
group
(
"input"
));
ap
.
group
(
"input"
));
ap
(
file_type
,
{
"--onnx"
},
ap
.
help
(
"Load as onnx"
),
ap
.
set_value
(
"onnx"
));
ap
(
file_type
,
{
"--onnx"
},
ap
.
help
(
"Load as onnx"
),
ap
.
set_value
(
"onnx"
));
ap
(
file_type
,
{
"--tf"
},
ap
.
help
(
"Load as tensorflow"
),
ap
.
set_value
(
"tf"
));
ap
(
file_type
,
{
"--tf"
},
ap
.
help
(
"Load as tensorflow"
),
ap
.
set_value
(
"tf"
));
...
@@ -241,6 +243,20 @@ struct loader
...
@@ -241,6 +243,20 @@ struct loader
return
options
;
return
options
;
}
}
static
std
::
string
get_file_type
(
const
std
::
string
&
file
)
{
if
(
ends_with
(
file
,
".onnx"
))
return
"onnx"
;
else
if
(
ends_with
(
file
,
".pb"
))
return
"tf"
;
else
if
(
ends_with
(
file
,
".json"
))
return
"json"
;
else
if
(
ends_with
(
file
,
".py"
))
return
"py"
;
else
return
"migraphx"
;
}
program
load
()
program
load
()
{
{
program
p
;
program
p
;
...
@@ -248,14 +264,7 @@ struct loader
...
@@ -248,14 +264,7 @@ struct loader
{
{
if
(
file_type
.
empty
())
if
(
file_type
.
empty
())
{
{
if
(
ends_with
(
file
,
".onnx"
))
file_type
=
get_file_type
(
file
);
file_type
=
"onnx"
;
else
if
(
ends_with
(
file
,
".pb"
))
file_type
=
"tf"
;
else
if
(
ends_with
(
file
,
".json"
))
file_type
=
"json"
;
else
file_type
=
"migraphx"
;
}
}
std
::
cout
<<
"Reading: "
<<
file
<<
std
::
endl
;
std
::
cout
<<
"Reading: "
<<
file
<<
std
::
endl
;
if
(
file_type
==
"onnx"
)
if
(
file_type
==
"onnx"
)
...
@@ -272,6 +281,10 @@ struct loader
...
@@ -272,6 +281,10 @@ struct loader
options
.
format
=
"json"
;
options
.
format
=
"json"
;
p
=
migraphx
::
load
(
file
,
options
);
p
=
migraphx
::
load
(
file
,
options
);
}
}
else
if
(
file_type
==
"py"
)
{
p
=
migraphx
::
load_py
(
file
);
}
else
if
(
file_type
==
"migraphx"
)
else
if
(
file_type
==
"migraphx"
)
{
{
p
=
migraphx
::
load
(
file
);
p
=
migraphx
::
load
(
file
);
...
@@ -415,7 +428,8 @@ struct compiler
...
@@ -415,7 +428,8 @@ struct compiler
program_params
parameters
;
program_params
parameters
;
compiler_target
ct
;
compiler_target
ct
;
compile_options
co
;
compile_options
co
;
precision
quantize
=
precision
::
fp32
;
bool
to_fp16
=
false
;
bool
to_int8
=
false
;
std
::
vector
<
std
::
string
>
fill0
;
std
::
vector
<
std
::
string
>
fill0
;
std
::
vector
<
std
::
string
>
fill1
;
std
::
vector
<
std
::
string
>
fill1
;
...
@@ -436,13 +450,8 @@ struct compiler
...
@@ -436,13 +450,8 @@ struct compiler
{
"--exhaustive-tune"
},
{
"--exhaustive-tune"
},
ap
.
help
(
"Exhastively search for best tuning parameters for kernels"
),
ap
.
help
(
"Exhastively search for best tuning parameters for kernels"
),
ap
.
set_value
(
true
));
ap
.
set_value
(
true
));
ap
(
co
.
split_single_dyn_dim
,
ap
(
to_fp16
,
{
"--fp16"
},
ap
.
help
(
"Quantize for fp16"
),
ap
.
set_value
(
true
));
{
"--split-single-dyn-dim"
},
ap
(
to_int8
,
{
"--int8"
},
ap
.
help
(
"Quantize for int8"
),
ap
.
set_value
(
true
));
ap
.
help
(
"If there is a single non-fixed dynamic dimension in the model, then split to "
"static submodules"
),
ap
.
set_value
(
true
));
ap
(
quantize
,
{
"--fp16"
},
ap
.
help
(
"Quantize for fp16"
),
ap
.
set_value
(
precision
::
fp16
));
ap
(
quantize
,
{
"--int8"
},
ap
.
help
(
"Quantize for int8"
),
ap
.
set_value
(
precision
::
int8
));
}
}
auto
params
(
const
program
&
p
)
auto
params
(
const
program
&
p
)
...
@@ -450,20 +459,46 @@ struct compiler
...
@@ -450,20 +459,46 @@ struct compiler
return
parameters
.
generate
(
p
,
ct
.
get_target
(),
co
.
offload_copy
,
l
.
batch
);
return
parameters
.
generate
(
p
,
ct
.
get_target
(),
co
.
offload_copy
,
l
.
batch
);
}
}
auto
host_params
(
const
program
&
p
)
{
return
parameters
.
generate
(
p
,
ct
.
get_target
(),
true
,
l
.
batch
);
}
program
compile
()
program
compile
()
{
{
auto
p
=
l
.
load
();
auto
p
=
l
.
load
();
// Dont compile if its already been compiled
// Dont compile if its already been compiled
if
(
p
.
is_compiled
())
if
(
p
.
is_compiled
())
{
if
(
ct
.
target_name
==
"gpu"
)
{
if
(
is_offload_copy_set
(
p
)
and
not
co
.
offload_copy
)
{
std
::
cout
<<
"MIGraphX program was likely compiled with offload_copy set, Try "
"passing "
"`--enable-offload-copy` if program run fails.
\n
"
;
}
else
if
(
co
.
offload_copy
)
{
std
::
cout
<<
"MIGraphX program was likely compiled without "
"offload_copy set, Try "
"removing "
"`--enable-offload-copy` flag if passed to driver, if program run "
"fails.
\n
"
;
}
}
return
p
;
return
p
;
}
auto
t
=
ct
.
get_target
();
auto
t
=
ct
.
get_target
();
if
(
quantize
==
precision
::
fp16
)
if
(
to_
fp16
)
{
{
quantize_fp16
(
p
);
quantize_fp16
(
p
);
}
}
else
if
(
quantize
==
precision
::
int8
)
if
(
to_
int8
)
{
{
quantize_int8
(
p
,
t
,
{
params
(
p
)});
quantize_int8
(
p
,
t
,
{
host_
params
(
p
)});
}
}
p
.
compile
(
t
,
co
);
p
.
compile
(
t
,
co
);
l
.
save
(
p
);
l
.
save
(
p
);
...
@@ -522,17 +557,23 @@ struct verify : command<verify>
...
@@ -522,17 +557,23 @@ struct verify : command<verify>
auto
t
=
c
.
ct
.
get_target
();
auto
t
=
c
.
ct
.
get_target
();
auto
m
=
c
.
parameters
.
generate
(
p
,
t
,
true
,
c
.
l
.
batch
);
auto
m
=
c
.
parameters
.
generate
(
p
,
t
,
true
,
c
.
l
.
batch
);
auto
quantize
=
precision
::
fp32
;
if
(
c
.
to_fp16
)
quantize
=
precision
::
fp16
;
if
(
c
.
to_int8
)
quantize
=
precision
::
int8
;
if
(
per_instruction
)
if
(
per_instruction
)
{
{
verify_instructions
(
p
,
t
,
c
.
co
,
c
.
quantize
,
tolerance
);
verify_instructions
(
p
,
t
,
c
.
co
,
quantize
,
tolerance
);
}
}
else
if
(
reduce
)
else
if
(
reduce
)
{
{
verify_reduced_program
(
p
,
t
,
c
.
co
,
c
.
quantize
,
m
,
tolerance
);
verify_reduced_program
(
p
,
t
,
c
.
co
,
quantize
,
m
,
tolerance
);
}
}
else
else
{
{
verify_program
(
c
.
l
.
file
,
p
,
t
,
c
.
co
,
c
.
quantize
,
m
,
tolerance
);
verify_program
(
c
.
l
.
file
,
p
,
t
,
c
.
co
,
quantize
,
m
,
tolerance
);
}
}
}
}
};
};
...
@@ -662,6 +703,26 @@ struct onnx : command<onnx>
...
@@ -662,6 +703,26 @@ struct onnx : command<onnx>
}
}
};
};
struct
tf
:
command
<
tf
>
{
bool
show_ops
=
false
;
void
parse
(
argument_parser
&
ap
)
{
ap
(
show_ops
,
{
"--list"
,
"-l"
},
ap
.
help
(
"List all tf operators supported by MIGraphX"
),
ap
.
set_value
(
true
));
}
void
run
()
const
{
if
(
show_ops
)
{
for
(
const
auto
&
name
:
get_tf_operators
())
std
::
cout
<<
name
<<
std
::
endl
;
}
}
};
struct
main_command
struct
main_command
{
{
static
std
::
string
get_command_help
(
const
std
::
string
&
title
=
colorize
(
color
::
fg_yellow
,
static
std
::
string
get_command_help
(
const
std
::
string
&
title
=
colorize
(
color
::
fg_yellow
,
...
@@ -709,7 +770,7 @@ struct main_command
...
@@ -709,7 +770,7 @@ struct main_command
{
{
std
::
cout
<<
"'"
<<
color
::
fg_yellow
<<
wrong_commands
.
front
()
<<
color
::
reset
std
::
cout
<<
"'"
<<
color
::
fg_yellow
<<
wrong_commands
.
front
()
<<
color
::
reset
<<
"' is not a valid command."
<<
std
::
endl
;
<<
"' is not a valid command."
<<
std
::
endl
;
std
::
cout
<<
get_command_help
(
"Available commands:"
)
<<
std
::
endl
;
std
::
cout
<<
get_command_help
(
"Available commands:"
);
}
}
else
else
{
{
...
...
src/driver/perf.cpp
View file @
23cb7917
...
@@ -24,6 +24,8 @@
...
@@ -24,6 +24,8 @@
#include "perf.hpp"
#include "perf.hpp"
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/register_target.hpp>
#ifdef HAVE_GPU
#ifdef HAVE_GPU
#include <migraphx/gpu/hip.hpp>
#include <migraphx/gpu/hip.hpp>
...
@@ -97,6 +99,38 @@ target get_target(bool gpu)
...
@@ -97,6 +99,38 @@ target get_target(bool gpu)
return
make_target
(
"cpu"
);
return
make_target
(
"cpu"
);
}
}
bool
is_offload_copy_set
(
const
program
&
p
)
{
assert
(
p
.
is_compiled
());
const
module
*
mm
=
p
.
get_main_module
();
std
::
vector
<
std
::
string
>
param_names
=
mm
->
get_parameter_names
();
std
::
unordered_set
<
instruction_ref
>
param_ins
;
std
::
transform
(
param_names
.
begin
(),
param_names
.
end
(),
std
::
inserter
(
param_ins
,
param_ins
.
begin
()),
[
&
](
const
auto
&
i
)
{
return
mm
->
get_parameter
(
i
);
});
for
(
const
auto
&
i
:
*
mm
)
{
if
(
i
.
name
()
==
"hip::copy_to_gpu"
)
{
auto
copy_arg
=
instruction
::
get_output_alias
(
i
.
inputs
().
front
(),
true
);
param_ins
.
erase
(
copy_arg
);
}
else
if
(
i
.
name
()
==
"@return"
)
{
auto
return_args
=
i
.
inputs
();
for
(
const
auto
&
j
:
return_args
)
{
auto
alias_ins
=
instruction
::
get_output_alias
(
j
,
true
);
if
((
alias_ins
->
name
()
==
"@param"
&&
param_ins
.
erase
(
alias_ins
)
==
0
)
or
(
alias_ins
->
name
()
!=
"hip::copy_from_gpu"
))
return
false
;
}
}
}
return
param_ins
.
empty
();
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace driver
}
// namespace driver
}
// namespace migraphx
}
// namespace migraphx
src/driver/perf.hpp
View file @
23cb7917
...
@@ -39,6 +39,15 @@ parameter_map create_param_map(const program& p, const target& t, bool offload =
...
@@ -39,6 +39,15 @@ parameter_map create_param_map(const program& p, const target& t, bool offload =
parameter_map
fill_param_map
(
parameter_map
&
m
,
const
program
&
p
,
bool
gpu
);
parameter_map
fill_param_map
(
parameter_map
&
m
,
const
program
&
p
,
bool
gpu
);
parameter_map
create_param_map
(
const
program
&
p
,
bool
gpu
=
true
);
parameter_map
create_param_map
(
const
program
&
p
,
bool
gpu
=
true
);
target
get_target
(
bool
gpu
);
target
get_target
(
bool
gpu
);
/**
* @brief Checks if MIGraphX program compiled for "GPU" has offload_copy set of not. This is
intended to print a HINT for the users and would not always correctly classify compiled program as
with or without offload_copy in all cases.
* @param p Compiled MIGraphX program for GPU backend
* @return true if program is classified as compiled with "offload_copy" set
*/
bool
is_offload_copy_set
(
const
program
&
p
);
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace driver
}
// namespace driver
...
...
src/dynamic_loader.cpp
View file @
23cb7917
...
@@ -48,7 +48,7 @@ struct dynamic_loader_impl
...
@@ -48,7 +48,7 @@ struct dynamic_loader_impl
#pragma GCC diagnostic ignored "-Wignored-attributes"
#pragma GCC diagnostic ignored "-Wignored-attributes"
#endif
#endif
dynamic_loader_impl
(
const
fs
::
path
&
p
,
std
::
shared_ptr
<
tmp_dir
>
t
=
nullptr
)
dynamic_loader_impl
(
const
fs
::
path
&
p
,
std
::
shared_ptr
<
tmp_dir
>
t
=
nullptr
)
:
handle
(
dlopen
(
p
.
string
().
c_str
(),
RTLD_
LAZY
),
:
handle
(
dlopen
(
p
.
string
().
c_str
(),
RTLD_
GLOBAL
|
RTLD_NOW
),
manage_deleter
<
decltype
(
&
dlclose
),
&
dlclose
>
{}),
manage_deleter
<
decltype
(
&
dlclose
),
&
dlclose
>
{}),
temp
(
std
::
move
(
t
))
temp
(
std
::
move
(
t
))
{
{
...
@@ -81,6 +81,18 @@ fs::path dynamic_loader::path(void* address)
...
@@ -81,6 +81,18 @@ fs::path dynamic_loader::path(void* address)
return
p
;
return
p
;
}
}
optional
<
dynamic_loader
>
dynamic_loader
::
try_load
(
const
fs
::
path
&
p
)
{
try
{
return
dynamic_loader
{
p
};
}
catch
(
const
std
::
exception
&
)
{
return
nullopt
;
}
}
dynamic_loader
::
dynamic_loader
(
const
fs
::
path
&
p
)
:
impl
(
std
::
make_shared
<
dynamic_loader_impl
>
(
p
))
dynamic_loader
::
dynamic_loader
(
const
fs
::
path
&
p
)
:
impl
(
std
::
make_shared
<
dynamic_loader_impl
>
(
p
))
{
{
}
}
...
...
src/fuse_pointwise.cpp
View file @
23cb7917
...
@@ -31,6 +31,8 @@
...
@@ -31,6 +31,8 @@
#include <migraphx/ranges.hpp>
#include <migraphx/ranges.hpp>
#include <iterator>
#include <iterator>
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_DISABLE_POINTWISE_FUSION
)
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -39,7 +41,7 @@ static literal get_scalar(instruction_ref ins)
...
@@ -39,7 +41,7 @@ static literal get_scalar(instruction_ref ins)
if
(
ins
->
name
()
==
"contiguous"
)
if
(
ins
->
name
()
==
"contiguous"
)
return
get_scalar
(
ins
->
inputs
().
front
());
return
get_scalar
(
ins
->
inputs
().
front
());
const
auto
&
s
=
ins
->
get_shape
();
const
auto
&
s
=
ins
->
get_shape
();
if
(
s
.
elements
()
!=
1
&&
not
(
s
.
scalar
()))
if
(
s
.
elements
()
!=
1
and
not
(
s
.
scalar
()))
return
{};
return
{};
if
(
not
ins
->
can_eval
())
if
(
not
ins
->
can_eval
())
return
{};
return
{};
...
@@ -67,7 +69,6 @@ static void create_pointwise_modules(module_pass_manager& mpm)
...
@@ -67,7 +69,6 @@ static void create_pointwise_modules(module_pass_manager& mpm)
continue
;
continue
;
if
(
ins
->
get_operator
().
name
()
==
"layout"
)
if
(
ins
->
get_operator
().
name
()
==
"layout"
)
continue
;
continue
;
assert
(
ins
->
get_operator
().
attributes
().
contains
(
"point_op"
));
auto
*
pm
=
mpm
.
create_module
(
mpm
.
get_module
().
name
()
+
":pointwise"
+
std
::
to_string
(
n
++
));
auto
*
pm
=
mpm
.
create_module
(
mpm
.
get_module
().
name
()
+
":pointwise"
+
std
::
to_string
(
n
++
));
pm
->
set_bypass
();
pm
->
set_bypass
();
...
@@ -193,6 +194,10 @@ void fuse_pointwise::apply(module_pass_manager& mpm) const
...
@@ -193,6 +194,10 @@ void fuse_pointwise::apply(module_pass_manager& mpm) const
{
{
create_pointwise_modules
(
mpm
);
create_pointwise_modules
(
mpm
);
mpm
.
run_pass
(
dead_code_elimination
{});
mpm
.
run_pass
(
dead_code_elimination
{});
if
(
enabled
(
MIGRAPHX_DISABLE_POINTWISE_FUSION
{}))
{
return
;
}
for
(
int
i
=
0
;
i
<
8
;
i
++
)
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
{
if
(
not
find_pointwise_modules
(
mpm
.
get_module
()))
if
(
not
find_pointwise_modules
(
mpm
.
get_module
()))
...
...
src/fuse_reduce.cpp
View file @
23cb7917
...
@@ -52,7 +52,7 @@ struct fused_reduce
...
@@ -52,7 +52,7 @@ struct fused_reduce
{
{
if
(
mods
.
size
()
!=
1
)
if
(
mods
.
size
()
!=
1
)
MIGRAPHX_THROW
(
"should have one submodule."
);
MIGRAPHX_THROW
(
"should have one submodule."
);
auto
*
sm
=
mods
.
front
();
const
auto
*
sm
=
mods
.
front
();
if
(
sm
->
get_output_shapes
().
size
()
!=
1
)
if
(
sm
->
get_output_shapes
().
size
()
!=
1
)
MIGRAPHX_THROW
(
"Only one output supported"
);
MIGRAPHX_THROW
(
"Only one output supported"
);
auto
names
=
sm
->
get_parameter_names
();
auto
names
=
sm
->
get_parameter_names
();
...
@@ -143,7 +143,7 @@ insert_module_in_submodule(module_ref sm,
...
@@ -143,7 +143,7 @@ insert_module_in_submodule(module_ref sm,
}
}
static
std
::
vector
<
instruction_ref
>
static
std
::
vector
<
instruction_ref
>
find_inputs
(
module_ref
sm
,
find_inputs
(
const_
module_ref
sm
,
const
module
&
parent
,
const
module
&
parent
,
const
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>&
map_ins
)
const
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>&
map_ins
)
{
{
...
...
src/generate.cpp
View file @
23cb7917
...
@@ -26,7 +26,7 @@
...
@@ -26,7 +26,7 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
argument
fill_argument
(
shape
s
,
unsigned
long
value
)
argument
fill_argument
(
shape
s
,
double
value
)
{
{
argument
result
;
argument
result
;
if
(
s
.
type
()
==
shape
::
tuple_type
)
if
(
s
.
type
()
==
shape
::
tuple_type
)
...
...
src/include/migraphx/adjust_allocation.hpp
View file @
23cb7917
...
@@ -32,7 +32,7 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -32,7 +32,7 @@ inline namespace MIGRAPHX_INLINE_NS {
struct
module
;
struct
module
;
struct
adjust_allocation
struct
MIGRAPHX_EXPORT
adjust_allocation
{
{
allocation_model
model
;
allocation_model
model
;
std
::
string
name
()
const
{
return
"adjust_allocation"
;
}
std
::
string
name
()
const
{
return
"adjust_allocation"
;
}
...
...
src/include/migraphx/algorithm.hpp
View file @
23cb7917
...
@@ -26,6 +26,8 @@
...
@@ -26,6 +26,8 @@
#include <algorithm>
#include <algorithm>
#include <numeric>
#include <numeric>
#include <string>
#include <vector>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
namespace
migraphx
{
...
@@ -90,6 +92,42 @@ levenshtein_distance(Iterator1 first1, Iterator1 last1, Iterator2 first2, Iterat
...
@@ -90,6 +92,42 @@ levenshtein_distance(Iterator1 first1, Iterator1 last1, Iterator2 first2, Iterat
return
std
::
ptrdiff_t
{
1
}
+
std
::
min
({
x1
,
x2
,
x3
});
return
std
::
ptrdiff_t
{
1
}
+
std
::
min
({
x1
,
x2
,
x3
});
}
}
inline
size_t
levenshtein_distance
(
const
std
::
string
&
s1
,
const
std
::
string
&
s2
)
{
const
size_t
l1
=
s1
.
length
();
const
size_t
l2
=
s2
.
length
();
if
(
l1
<
l2
)
levenshtein_distance
(
s2
,
s1
);
std
::
vector
<
size_t
>
d
(
l2
+
1
);
std
::
iota
(
d
.
begin
(),
d
.
end
(),
0
);
for
(
size_t
i
=
1
;
i
<=
l1
;
i
++
)
{
size_t
prev_cost
=
d
[
0
];
d
[
0
]
=
i
;
for
(
size_t
j
=
1
;
j
<=
l2
;
j
++
)
{
if
(
s1
[
i
-
1
]
==
s2
[
j
-
1
])
{
d
[
j
]
=
prev_cost
;
}
else
{
size_t
cost_insert_or_delete
=
std
::
min
(
d
[
j
-
1
],
d
[
j
]);
size_t
cost_substitute
=
prev_cost
;
prev_cost
=
d
[
j
];
d
[
j
]
=
std
::
min
(
cost_substitute
,
cost_insert_or_delete
)
+
1
;
}
}
}
return
d
[
l2
];
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/include/migraphx/allocation_model.hpp
View file @
23cb7917
...
@@ -60,7 +60,7 @@ struct allocation_model
...
@@ -60,7 +60,7 @@ struct allocation_model
#ifdef TYPE_ERASED_DECLARATION
#ifdef TYPE_ERASED_DECLARATION
// Type-erased interface for:
// Type-erased interface for:
struct
allocation_model
struct
MIGRAPHX_EXPORT
allocation_model
{
{
//
//
std
::
string
name
()
const
;
std
::
string
name
()
const
;
...
@@ -96,7 +96,7 @@ struct allocation_model
...
@@ -96,7 +96,7 @@ struct allocation_model
{
{
using
std
::
swap
;
using
std
::
swap
;
auto
*
derived
=
this
->
any_cast
<
PrivateDetailTypeErasedT
>
();
auto
*
derived
=
this
->
any_cast
<
PrivateDetailTypeErasedT
>
();
if
(
derived
and
private_detail_te_handle_mem_var
.
u
nique
()
)
if
(
derived
and
private_detail_te_handle_mem_var
.
u
se_count
()
==
1
)
{
{
*
derived
=
std
::
forward
<
PrivateDetailTypeErasedT
>
(
value
);
*
derived
=
std
::
forward
<
PrivateDetailTypeErasedT
>
(
value
);
}
}
...
@@ -267,7 +267,7 @@ struct allocation_model
...
@@ -267,7 +267,7 @@ struct allocation_model
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
not
private_detail_te_handle_mem_var
.
u
nique
()
)
if
(
private_detail_te_handle_mem_var
.
u
se_count
()
>
1
)
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
return
*
private_detail_te_handle_mem_var
;
}
}
...
...
src/include/migraphx/analyze_streams.hpp
View file @
23cb7917
...
@@ -39,7 +39,8 @@ struct stream_race
...
@@ -39,7 +39,8 @@ struct stream_race
instruction_ref
before
;
instruction_ref
before
;
};
};
std
::
vector
<
stream_race
>
analyze_streams
(
const
module
&
m
,
const
stream_model
&
strmm
);
MIGRAPHX_EXPORT
std
::
vector
<
stream_race
>
analyze_streams
(
const
module
&
m
,
const
stream_model
&
strmm
);
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
Prev
1
2
3
4
5
6
7
8
…
23
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment