Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
a0edd061
Commit
a0edd061
authored
Jul 11, 2022
by
Paul
Browse files
Merge branch 'develop' into jit-improve
parents
6deee23b
2781ccd8
Changes
59
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
610 additions
and
82 deletions
+610
-82
Dockerfile
Dockerfile
+1
-1
src/CMakeLists.txt
src/CMakeLists.txt
+1
-0
src/eliminate_contiguous.cpp
src/eliminate_contiguous.cpp
+36
-18
src/include/migraphx/assignment_options.hpp
src/include/migraphx/assignment_options.hpp
+40
-0
src/include/migraphx/check_shapes.hpp
src/include/migraphx/check_shapes.hpp
+5
-0
src/include/migraphx/matcher.hpp
src/include/migraphx/matcher.hpp
+13
-11
src/include/migraphx/module.hpp
src/include/migraphx/module.hpp
+8
-2
src/include/migraphx/module_ref.hpp
src/include/migraphx/module_ref.hpp
+2
-1
src/include/migraphx/op/unsqueeze.hpp
src/include/migraphx/op/unsqueeze.hpp
+22
-7
src/include/migraphx/permutation.hpp
src/include/migraphx/permutation.hpp
+6
-0
src/include/migraphx/program.hpp
src/include/migraphx/program.hpp
+5
-0
src/include/migraphx/ranges.hpp
src/include/migraphx/ranges.hpp
+6
-0
src/include/migraphx/shape.hpp
src/include/migraphx/shape.hpp
+75
-1
src/include/migraphx/support_metric.hpp
src/include/migraphx/support_metric.hpp
+38
-0
src/include/migraphx/target.hpp
src/include/migraphx/target.hpp
+49
-0
src/include/migraphx/target_assignments.hpp
src/include/migraphx/target_assignments.hpp
+47
-0
src/module.cpp
src/module.cpp
+22
-12
src/program.cpp
src/program.cpp
+38
-14
src/serialize.cpp
src/serialize.cpp
+2
-2
src/shape.cpp
src/shape.cpp
+194
-13
No files found.
Dockerfile
View file @
a0edd061
...
@@ -86,7 +86,7 @@ RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXR
...
@@ -86,7 +86,7 @@ RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXR
ADD
tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh
ADD
tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh
RUN
PATH
=
/opt/cmake/bin:
$PATH
cget
-p
/usr/local
install
ROCmSoftwarePlatform/llvm-project-mlir@
02078ce236ad90e3aec04c0c770ef5bfc99e49c2
RUN
cget
-p
/usr/local
install
ROCmSoftwarePlatform/llvm-project-mlir@
26a4b3cfc0a1a15181490f24ae461608fef1b04e
-DBUILD_MIXR_TARGET
=
On
ENV
MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db
ENV
MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db
ENV
MIOPEN_USER_DB_PATH=/tmp/miopen/user-db
ENV
MIOPEN_USER_DB_PATH=/tmp/miopen/user-db
...
...
src/CMakeLists.txt
View file @
a0edd061
...
@@ -88,6 +88,7 @@ add_library(migraphx
...
@@ -88,6 +88,7 @@ add_library(migraphx
shape.cpp
shape.cpp
simplify_algebra.cpp
simplify_algebra.cpp
simplify_reshapes.cpp
simplify_reshapes.cpp
target_assignments.cpp
tmp_dir.cpp
tmp_dir.cpp
value.cpp
value.cpp
verify_args.cpp
verify_args.cpp
...
...
src/eliminate_contiguous.cpp
View file @
a0edd061
...
@@ -93,9 +93,11 @@ static bool try_compute_shape(instruction_ref ins,
...
@@ -93,9 +93,11 @@ static bool try_compute_shape(instruction_ref ins,
return
try_compute_shape
(
ins
,
inputs
,
mods
);
return
try_compute_shape
(
ins
,
inputs
,
mods
);
}
}
void
eliminate_contiguous
::
apply
(
module
&
m
)
const
template
<
class
F
>
static
void
remove_contiguous
(
const
std
::
string
&
op_name
,
module
&
m
,
F
f
)
{
{
std
::
vector
<
instruction_ref
>
const_instruction
;
auto
last
=
std
::
prev
(
m
.
end
());
std
::
vector
<
instruction_ref
>
const_instructions
;
for
(
auto
ins
:
iterator_for
(
m
))
for
(
auto
ins
:
iterator_for
(
m
))
{
{
...
@@ -103,6 +105,12 @@ void eliminate_contiguous::apply(module& m) const
...
@@ -103,6 +105,12 @@ void eliminate_contiguous::apply(module& m) const
if
(
ins
->
name
()
==
"@return"
)
if
(
ins
->
name
()
==
"@return"
)
continue
;
continue
;
if
(
ins
!=
last
and
ins
->
outputs
().
empty
())
continue
;
if
(
not
f
(
ins
))
continue
;
// Make a copy so we can modify it while we iterate
// Make a copy so we can modify it while we iterate
auto
args
=
ins
->
inputs
();
auto
args
=
ins
->
inputs
();
auto
new_args
=
args
;
auto
new_args
=
args
;
...
@@ -110,36 +118,46 @@ void eliminate_contiguous::apply(module& m) const
...
@@ -110,36 +118,46 @@ void eliminate_contiguous::apply(module& m) const
for
(
auto
arg
:
ins
->
inputs
())
for
(
auto
arg
:
ins
->
inputs
())
{
{
if
(
arg
->
name
()
==
op_name
)
if
(
arg
->
name
()
!=
op_name
)
continue
;
auto
prev
=
arg
->
inputs
().
front
();
replace
(
new_args
,
arg
,
prev
);
if
(
try_compute_shape
(
ins
,
new_args
,
mod_args
))
{
instruction
::
replace_argument
(
ins
,
arg
,
prev
);
}
else
if
(
prev
->
can_eval
())
{
{
auto
prev
=
arg
->
inputs
().
front
();
const_instructions
.
push_back
(
arg
);
replace
(
new_args
,
arg
,
prev
);
if
(
try_compute_shape
(
ins
,
new_args
,
mod_args
))
{
instruction
::
replace_argument
(
ins
,
arg
,
prev
);
}
else
if
(
prev
->
can_eval
())
{
const_instruction
.
push_back
(
arg
);
}
}
}
}
}
}
}
// Perform evaluations in parallel
// Perform evaluations in parallel
std
::
vector
<
argument
>
literals
(
const_instruction
.
size
());
std
::
vector
<
argument
>
literals
(
const_instruction
s
.
size
());
par_for
(
const_instruction
.
size
(),
1
,
[
&
](
const
auto
i
)
{
par_for
(
const_instruction
s
.
size
(),
1
,
[
&
](
const
auto
i
)
{
auto
c
=
op
::
contiguous
{};
auto
c
=
op
::
contiguous
{};
auto
prev
=
const_instruction
[
i
]
->
inputs
().
front
();
auto
prev
=
const_instruction
s
[
i
]
->
inputs
().
front
();
literals
[
i
]
=
c
.
compute
(
c
.
compute_shape
({
prev
->
get_shape
()}),
{
prev
->
eval
()});
literals
[
i
]
=
c
.
compute
(
c
.
compute_shape
({
prev
->
get_shape
()}),
{
prev
->
eval
()});
});
});
for
(
size_t
i
=
0
;
i
<
const_instruction
.
size
();
i
++
)
for
(
size_t
i
=
0
;
i
<
const_instruction
s
.
size
();
i
++
)
{
{
auto
l
=
m
.
add_literal
(
literals
[
i
].
get_shape
(),
literals
[
i
].
data
());
auto
l
=
m
.
add_literal
(
literals
[
i
].
get_shape
(),
literals
[
i
].
data
());
m
.
replace_instruction
(
const_instruction
[
i
],
l
);
m
.
replace_instruction
(
const_instruction
s
[
i
],
l
);
}
}
}
}
void
eliminate_contiguous
::
apply
(
module
&
m
)
const
{
// Skip contiguous from splits first
remove_contiguous
(
op_name
,
m
,
[](
auto
ins
)
{
if
(
ins
->
name
()
!=
"slice"
)
return
true
;
return
(
ins
->
inputs
().
front
()
->
outputs
().
size
()
==
1
);
});
remove_contiguous
(
op_name
,
m
,
[](
auto
)
{
return
true
;
});
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
src/include/migraphx/assignment_options.hpp
0 → 100644
View file @
a0edd061
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_ASSIGNMENT_OPTIONS_HPP
#define MIGRAPHX_GUARD_RTGLIB_ASSIGNMENT_OPTIONS_HPP
#include <migraphx/support_metric.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
assignment_options
{
support_metric
metric
=
support_metric
::
latency
;
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_RTGLIB_ASSIGNMENT_OPTIONS_HPP
src/include/migraphx/check_shapes.hpp
View file @
a0edd061
...
@@ -71,6 +71,11 @@ struct check_shapes
...
@@ -71,6 +71,11 @@ struct check_shapes
return
end
-
begin
;
return
end
-
begin
;
}
}
/*!
* Check if the number of shape objects is equal to atleast one of the
* given sizes.
* \param ns template parameter pack of sizes to check against
*/
template
<
class
...
Ts
>
template
<
class
...
Ts
>
const
check_shapes
&
has
(
Ts
...
ns
)
const
const
check_shapes
&
has
(
Ts
...
ns
)
const
{
{
...
...
src/include/migraphx/matcher.hpp
View file @
a0edd061
...
@@ -349,25 +349,27 @@ match::matcher_result find_match(module& modl, M&& m)
...
@@ -349,25 +349,27 @@ match::matcher_result find_match(module& modl, M&& m)
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_MATCHES
)
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_MATCHES
)
/// Find matches for an instruction in the module
/// Find matches for an instruction in the module
template
<
class
...
Ms
>
template
<
class
Mod
,
class
...
Ms
>
void
find_matches
(
m
od
ule
&
mod
,
instruction_ref
ins
,
Ms
&&
...
ms
)
void
find_matches
(
M
od
&
mod
,
instruction_ref
ins
,
Ms
&&
...
ms
)
{
{
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5
const
const
#endif
#endif
bool
trace
=
enabled
(
MIGRAPHX_TRACE_MATCHES
{});
int
trace
=
value_of
(
MIGRAPHX_TRACE_MATCHES
{});
bool
match
=
false
;
bool
match
=
false
;
each_args
(
each_args
(
[
&
](
auto
&&
m
)
{
[
&
](
auto
&&
m
)
{
if
(
match
)
if
(
match
)
return
;
return
;
auto
r
=
match_instruction
(
mod
,
ins
,
m
.
matcher
());
if
(
trace
>
1
)
if
(
r
.
result
==
mod
.
end
())
std
::
cout
<<
"Match: "
<<
get_type_name
(
m
)
<<
std
::
endl
;
auto
r
=
match_instruction
(
get_module
(
mod
),
ins
,
m
.
matcher
());
if
(
r
.
result
==
get_module
(
mod
).
end
())
return
;
return
;
if
(
trace
)
if
(
trace
>
0
)
{
{
std
::
cout
<<
"Matched by "
<<
get_type_name
(
m
)
<<
std
::
endl
;
std
::
cout
<<
"Matched by "
<<
get_type_name
(
m
)
<<
std
::
endl
;
mod
.
debug_print
(
ins
);
get_module
(
mod
)
.
debug_print
(
ins
);
}
}
m
.
apply
(
mod
,
r
);
m
.
apply
(
mod
,
r
);
match
=
true
;
match
=
true
;
...
@@ -376,10 +378,10 @@ void find_matches(module& mod, instruction_ref ins, Ms&&... ms)
...
@@ -376,10 +378,10 @@ void find_matches(module& mod, instruction_ref ins, Ms&&... ms)
}
}
/// Find matches in a module
/// Find matches in a module
template
<
class
...
Ms
>
template
<
class
Mod
,
class
...
Ms
>
void
find_matches
(
m
od
ule
&
mod
,
Ms
&&
...
ms
)
void
find_matches
(
M
od
&
mod
,
Ms
&&
...
ms
)
{
{
for
(
auto
ins
:
iterator_for
(
mod
))
for
(
auto
ins
:
iterator_for
(
get_module
(
mod
)
))
{
{
find_matches
(
mod
,
ins
,
ms
...);
find_matches
(
mod
,
ins
,
ms
...);
}
}
...
...
src/include/migraphx/module.hpp
View file @
a0edd061
...
@@ -124,7 +124,7 @@ struct module
...
@@ -124,7 +124,7 @@ struct module
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
map_ins
=
{});
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
map_ins
=
{});
std
::
vector
<
instruction_ref
>
std
::
vector
<
instruction_ref
>
add_instructions
(
module_ref
m
,
add_instructions
(
const_
module_ref
m
,
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
map_ins
=
{});
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
map_ins
=
{});
std
::
vector
<
instruction_ref
>
std
::
vector
<
instruction_ref
>
...
@@ -139,7 +139,7 @@ struct module
...
@@ -139,7 +139,7 @@ struct module
std
::
vector
<
instruction_ref
>
std
::
vector
<
instruction_ref
>
insert_instructions
(
instruction_ref
ins
,
insert_instructions
(
instruction_ref
ins
,
module_ref
m
,
const_
module_ref
m
,
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
map_ins
=
{});
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
map_ins
=
{});
std
::
vector
<
instruction_ref
>
std
::
vector
<
instruction_ref
>
...
@@ -164,6 +164,10 @@ struct module
...
@@ -164,6 +164,10 @@ struct module
instruction_ref
replace_return
(
std
::
vector
<
instruction_ref
>
args
);
instruction_ref
replace_return
(
std
::
vector
<
instruction_ref
>
args
);
instruction_ref
insert_literal
(
instruction_ref
ins
,
literal
l
);
instruction_ref
insert_parameter
(
instruction_ref
ins
,
std
::
string
name
,
shape
s
);
std
::
vector
<
std
::
string
>
get_parameter_names
()
const
;
std
::
vector
<
std
::
string
>
get_parameter_names
()
const
;
shape
get_parameter_shape
(
std
::
string
name
)
const
;
shape
get_parameter_shape
(
std
::
string
name
)
const
;
...
@@ -227,6 +231,8 @@ struct module
...
@@ -227,6 +231,8 @@ struct module
std
::
unique_ptr
<
module_impl
>
impl
;
std
::
unique_ptr
<
module_impl
>
impl
;
};
};
inline
module
&
get_module
(
module
&
m
)
{
return
m
;
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/include/migraphx/module_ref.hpp
View file @
a0edd061
...
@@ -32,7 +32,8 @@ namespace migraphx {
...
@@ -32,7 +32,8 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
module
;
struct
module
;
using
module_ref
=
module
*
;
using
module_ref
=
module
*
;
using
const_module_ref
=
const
module
*
;
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/include/migraphx/op/unsqueeze.hpp
View file @
a0edd061
...
@@ -42,11 +42,12 @@ namespace op {
...
@@ -42,11 +42,12 @@ namespace op {
struct
unsqueeze
struct
unsqueeze
{
{
std
::
vector
<
int64_t
>
axes
;
std
::
vector
<
int64_t
>
axes
;
std
::
vector
<
int64_t
>
steps
;
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
{
{
return
pack
(
f
(
self
.
axes
,
"axes"
));
return
pack
(
f
(
self
.
axes
,
"axes"
)
,
f
(
self
.
steps
,
"steps"
)
);
}
}
value
attributes
()
const
value
attributes
()
const
...
@@ -73,6 +74,9 @@ struct unsqueeze
...
@@ -73,6 +74,9 @@ struct unsqueeze
MIGRAPHX_THROW
(
"UNSQUEEZE: Input must be a scalar"
);
MIGRAPHX_THROW
(
"UNSQUEEZE: Input must be a scalar"
);
}
}
if
(
steps
.
size
()
>
axes
.
size
())
MIGRAPHX_THROW
(
"UNSQUEEZE: Steps provided with no axis"
);
std
::
size_t
new_size
=
old_lens
.
size
()
+
axes
.
size
();
std
::
size_t
new_size
=
old_lens
.
size
()
+
axes
.
size
();
std
::
vector
<
std
::
size_t
>
new_lens
(
new_size
);
std
::
vector
<
std
::
size_t
>
new_lens
(
new_size
);
...
@@ -80,16 +84,27 @@ struct unsqueeze
...
@@ -80,16 +84,27 @@ struct unsqueeze
std
::
size_t
p
=
0
;
std
::
size_t
p
=
0
;
for
(
auto
i
:
range
(
new_size
))
for
(
auto
i
:
range
(
new_size
))
{
{
if
(
std
::
find
(
axes
.
begin
(),
axes
.
end
(),
i
)
!=
axes
.
end
())
auto
axis_idx
=
std
::
find
(
axes
.
begin
(),
axes
.
end
(),
i
)
-
axes
.
begin
();
if
(
axis_idx
<
axes
.
size
())
{
{
new_lens
[
i
]
=
1
;
std
::
int64_t
step
=
1
;
if
(
p
==
0
)
// unsqueeze on the first axes
if
(
axis_idx
<
steps
.
size
())
step
=
steps
[
axis_idx
];
if
(
step
==
0
)
MIGRAPHX_THROW
(
"UNSQUEEZE: step must be non-zero"
);
new_lens
[
i
]
=
step
;
if
(
p
<
old_strides
.
size
())
{
{
new_strides
[
i
]
=
old_lens
[
0
]
*
old_strides
[
0
];
if
((
old_lens
[
p
]
%
step
)
!=
0
)
MIGRAPHX_THROW
(
"UNSQUEEZE: Axis dimenstion is not divisible by step"
);
old_lens
[
p
]
/=
step
;
new_strides
[
i
]
=
old_strides
[
p
]
*
old_lens
[
p
];
}
}
else
// unsqueeze on middle or last axes
else
{
{
new_strides
[
i
]
=
(
p
<
old_strides
.
size
())
?
old_strides
[
p
-
1
]
:
1
;
if
(
step
!=
1
)
MIGRAPHX_THROW
(
"UNSQUEEZE: Step must be 1 for extra axes"
);
new_strides
[
i
]
=
1
;
}
}
}
}
else
else
...
...
src/include/migraphx/permutation.hpp
View file @
a0edd061
...
@@ -55,8 +55,14 @@ inline std::vector<int64_t> sort_permutation(const Vector& data, Op op)
...
@@ -55,8 +55,14 @@ inline std::vector<int64_t> sort_permutation(const Vector& data, Op op)
return
result
;
return
result
;
}
}
/*!
* Returns the permutation needed to apply to the shape to undo the current permutation
*/
std
::
vector
<
int64_t
>
invert_permutation
(
const
std
::
vector
<
int64_t
>&
permutation
);
std
::
vector
<
int64_t
>
invert_permutation
(
const
std
::
vector
<
int64_t
>&
permutation
);
/*!
* Finds the permutation most likely from a transpose operator that has been applied to the shape.
*/
std
::
vector
<
int64_t
>
find_permutation
(
const
shape
&
s
);
std
::
vector
<
int64_t
>
find_permutation
(
const
shape
&
s
);
std
::
vector
<
int64_t
>
find_permutation
(
const
std
::
vector
<
shape
>&
shapes
);
std
::
vector
<
int64_t
>
find_permutation
(
const
std
::
vector
<
shape
>&
shapes
);
...
...
src/include/migraphx/program.hpp
View file @
a0edd061
...
@@ -33,6 +33,8 @@
...
@@ -33,6 +33,8 @@
#include <migraphx/instruction_ref.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/target.hpp>
#include <migraphx/target.hpp>
#include <migraphx/compile_options.hpp>
#include <migraphx/compile_options.hpp>
#include <migraphx/target_assignments.hpp>
#include <migraphx/assignment_options.hpp>
#include <migraphx/env.hpp>
#include <migraphx/env.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <algorithm>
#include <algorithm>
...
@@ -84,6 +86,9 @@ struct program
...
@@ -84,6 +86,9 @@ struct program
instruction_ref
validate
()
const
;
instruction_ref
validate
()
const
;
target_assignments
get_target_assignments
(
const
std
::
vector
<
target
>&
targets
,
assignment_options
options
=
assignment_options
{});
void
compile
(
const
target
&
t
,
compile_options
options
=
compile_options
{});
void
compile
(
const
target
&
t
,
compile_options
options
=
compile_options
{});
bool
is_compiled
()
const
;
bool
is_compiled
()
const
;
...
...
src/include/migraphx/ranges.hpp
View file @
a0edd061
...
@@ -198,6 +198,12 @@ void transform(Range&& r, Iterator it, F f)
...
@@ -198,6 +198,12 @@ void transform(Range&& r, Iterator it, F f)
std
::
transform
(
r
.
begin
(),
r
.
end
(),
it
,
f
);
std
::
transform
(
r
.
begin
(),
r
.
end
(),
it
,
f
);
}
}
template
<
class
Range1
,
class
Range2
,
class
Iterator
,
class
F
>
void
transform
(
Range1
&&
r1
,
Range2
&&
r2
,
Iterator
it
,
F
f
)
{
std
::
transform
(
r1
.
begin
(),
r1
.
end
(),
r2
.
begin
(),
it
,
f
);
}
template
<
class
Range
>
template
<
class
Range
>
auto
reverse
(
Range
&
r
)
auto
reverse
(
Range
&
r
)
{
{
...
...
src/include/migraphx/shape.hpp
View file @
a0edd061
...
@@ -82,6 +82,23 @@ struct shape
...
@@ -82,6 +82,23 @@ struct shape
{
{
};
};
struct
dynamic_dimension
{
std
::
size_t
min
=
0
;
std
::
size_t
max
=
0
;
std
::
size_t
opt
=
0
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
);
bool
is_fixed
()
const
;
bool
has_optimal
()
const
;
friend
bool
operator
==
(
const
dynamic_dimension
&
x
,
const
dynamic_dimension
&
y
);
friend
bool
operator
!=
(
const
dynamic_dimension
&
x
,
const
dynamic_dimension
&
y
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
dynamic_dimension
&
x
);
};
static
const
std
::
vector
<
type_t
>&
types
();
static
const
std
::
vector
<
type_t
>&
types
();
static
std
::
string
name
(
type_t
t
);
static
std
::
string
name
(
type_t
t
);
...
@@ -92,6 +109,12 @@ struct shape
...
@@ -92,6 +109,12 @@ struct shape
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
);
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
);
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
,
std
::
vector
<
std
::
size_t
>
s
);
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
,
std
::
vector
<
std
::
size_t
>
s
);
// Force all calls of the format `shape( type_t, { size_t compatibles } )` to map to
// shape(type_t, std::vector<std::size_t> l)
shape
(
type_t
t
,
std
::
initializer_list
<
std
::
size_t
>
d
);
shape
(
type_t
t
,
std
::
vector
<
dynamic_dimension
>
dims
);
template
<
class
Range
>
template
<
class
Range
>
shape
(
type_t
t
,
const
Range
&
l
)
:
shape
(
t
,
std
::
vector
<
std
::
size_t
>
(
l
.
begin
(),
l
.
end
()))
shape
(
type_t
t
,
const
Range
&
l
)
:
shape
(
t
,
std
::
vector
<
std
::
size_t
>
(
l
.
begin
(),
l
.
end
()))
{
{
...
@@ -112,10 +135,44 @@ struct shape
...
@@ -112,10 +135,44 @@ struct shape
type_t
type
()
const
;
type_t
type
()
const
;
const
std
::
vector
<
std
::
size_t
>&
lens
()
const
;
const
std
::
vector
<
std
::
size_t
>&
lens
()
const
;
const
std
::
vector
<
std
::
size_t
>&
strides
()
const
;
const
std
::
vector
<
std
::
size_t
>&
strides
()
const
;
/*!
* Return the number of elements in the tensor.
*/
std
::
size_t
elements
()
const
;
std
::
size_t
elements
()
const
;
/*!
* Return the number of total bytes used for storage of the tensor data; includes subshapes.
* For dynamic shape, returns the maximum number of bytes presuming a packed shape.
*/
std
::
size_t
bytes
()
const
;
std
::
size_t
bytes
()
const
;
/*!
* Return the size of the type of the main shape.
* Returns 0 if there are subshapes.
*/
std
::
size_t
type_size
()
const
;
std
::
size_t
type_size
()
const
;
const
std
::
vector
<
dynamic_dimension
>&
dyn_dims
()
const
;
/*!
* Minimum lengths for dynamic shape.
* lens() for fixed shape.
*/
std
::
vector
<
std
::
size_t
>
min_lens
()
const
;
/*!
* Maximum lengths for dynamic shape.
* lens() for fixed shape.
*/
std
::
vector
<
std
::
size_t
>
max_lens
()
const
;
/*!
* Optimum lengths for dynamic shape.
* lens() for fixed shape.
*/
std
::
vector
<
std
::
size_t
>
opt_lens
()
const
;
/// Map multiple indices to space index
/// Map multiple indices to space index
std
::
size_t
index
(
std
::
initializer_list
<
std
::
size_t
>
l
)
const
;
std
::
size_t
index
(
std
::
initializer_list
<
std
::
size_t
>
l
)
const
;
/// Map multiple indices to space index
/// Map multiple indices to space index
...
@@ -136,19 +193,27 @@ struct shape
...
@@ -136,19 +193,27 @@ struct shape
std
::
vector
<
std
::
size_t
>
multi
(
std
::
size_t
i
)
const
;
std
::
vector
<
std
::
size_t
>
multi
(
std
::
size_t
i
)
const
;
void
multi_copy
(
std
::
size_t
i
,
std
::
size_t
*
start
,
const
std
::
size_t
*
end
)
const
;
void
multi_copy
(
std
::
size_t
i
,
std
::
size_t
*
start
,
const
std
::
size_t
*
end
)
const
;
/// Returns true if the shape is packed with no padding
/// Returns true if the shape is packed (number of elements and buffer size the same) with no
/// padding
bool
packed
()
const
;
bool
packed
()
const
;
/// Returns true is the shape has been transposed. That is the strides are not in descending
/// Returns true is the shape has been transposed. That is the strides are not in descending
/// order
/// order
bool
transposed
()
const
;
bool
transposed
()
const
;
/// Returns true if the shape is broadcasting a dimension. That is, one of the strides are zero
/// Returns true if the shape is broadcasting a dimension. That is, one of the strides are zero
bool
broadcasted
()
const
;
bool
broadcasted
()
const
;
/// Returns true if the shape is in its standard format. That is, the shape is both packed and
/// Returns true if the shape is in its standard format. That is, the shape is both packed and
/// not transposed.
/// not transposed.
bool
standard
()
const
;
bool
standard
()
const
;
/// Returns true if all strides are equal to 0 (scalar tensor)
/// Returns true if all strides are equal to 0 (scalar tensor)
bool
scalar
()
const
;
bool
scalar
()
const
;
/// Return true if the shape is dynamic
bool
dynamic
()
const
;
shape
normalize_standard
()
const
;
shape
normalize_standard
()
const
;
shape
with_lens
(
type_t
t
,
const
std
::
vector
<
std
::
size_t
>&
l
)
const
;
shape
with_lens
(
type_t
t
,
const
std
::
vector
<
std
::
size_t
>&
l
)
const
;
...
@@ -191,6 +256,10 @@ struct shape
...
@@ -191,6 +256,10 @@ struct shape
std
::
size_t
size
(
std
::
size_t
n
=
1
)
const
{
return
sizeof
(
type
)
*
n
;
}
std
::
size_t
size
(
std
::
size_t
n
=
1
)
const
{
return
sizeof
(
type
)
*
n
;
}
auto
is_integral
()
const
{
return
std
::
is_integral
<
type
>
{};
}
auto
is_signed
()
const
{
return
std
::
is_signed
<
type
>
{};
}
auto
is_unsigned
()
const
{
return
std
::
is_unsigned
<
type
>
{};
}
template
<
class
U
>
template
<
class
U
>
type
*
from
(
U
*
buffer
,
std
::
size_t
n
=
0
)
const
type
*
from
(
U
*
buffer
,
std
::
size_t
n
=
0
)
const
{
{
...
@@ -248,6 +317,11 @@ struct shape
...
@@ -248,6 +317,11 @@ struct shape
const
std
::
vector
<
shape
>&
sub_shapes
()
const
;
const
std
::
vector
<
shape
>&
sub_shapes
()
const
;
/*!
* Returns the number of elements in the data buffer.
* For a dynamic shape, returns the maximum number of elements of the data buffer and assumes it
* is packed.
*/
std
::
size_t
element_space
()
const
;
std
::
size_t
element_space
()
const
;
private:
private:
...
...
src/include/migraphx/support_metric.hpp
0 → 100644
View file @
a0edd061
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_SUPPORT_METRIC_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_SUPPORT_METRIC_HPP
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
enum
class
support_metric
{
latency
,
throughput
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_SUPPORT_METRIC_HPP
src/include/migraphx/target.hpp
View file @
a0edd061
...
@@ -37,6 +37,8 @@
...
@@ -37,6 +37,8 @@
#include <migraphx/compile_options.hpp>
#include <migraphx/compile_options.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/rank.hpp>
#include <migraphx/rank.hpp>
#include <migraphx/support_metric.hpp>
#include <migraphx/instruction_ref.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -61,6 +63,13 @@ struct target
...
@@ -61,6 +63,13 @@ struct target
* @return The context to be used during compilation and execution.
* @return The context to be used during compilation and execution.
*/
*/
context
get_context
()
const
;
context
get_context
()
const
;
/**
* @brief Check how well an instruction is supported on a target with the given metric
* @param ins Instruction to check if it's supported
* @param metric Used to define how the return value should be interpreted
* @return The value based on the chosen metric. Negative numbers mean unsupported
*/
float
is_supported
(
T
&
,
instruction_ref
ins
,
support_metric
m
)
const
;
/**
/**
* @brief copy an argument to the current target.
* @brief copy an argument to the current target.
*
*
...
@@ -105,6 +114,12 @@ argument copy_from_target(T&, const argument& arg)
...
@@ -105,6 +114,12 @@ argument copy_from_target(T&, const argument& arg)
return
arg
;
return
arg
;
}
}
template
<
class
T
>
float
target_is_supported
(
T
&
,
instruction_ref
,
support_metric
)
{
return
0
;
}
#ifdef TYPE_ERASED_DECLARATION
#ifdef TYPE_ERASED_DECLARATION
// Type-erased interface for:
// Type-erased interface for:
...
@@ -117,6 +132,8 @@ struct target
...
@@ -117,6 +132,8 @@ struct target
//
//
context
get_context
()
const
;
context
get_context
()
const
;
// (optional)
// (optional)
float
is_supported
(
instruction_ref
ins
,
support_metric
m
)
const
;
// (optional)
argument
copy_to
(
const
argument
&
input
)
const
;
argument
copy_to
(
const
argument
&
input
)
const
;
// (optional)
// (optional)
argument
copy_from
(
const
argument
&
input
)
const
;
argument
copy_from
(
const
argument
&
input
)
const
;
...
@@ -207,6 +224,12 @@ struct target
...
@@ -207,6 +224,12 @@ struct target
return
(
*
this
).
private_detail_te_get_handle
().
get_context
();
return
(
*
this
).
private_detail_te_get_handle
().
get_context
();
}
}
float
is_supported
(
instruction_ref
ins
,
support_metric
m
)
const
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
return
(
*
this
).
private_detail_te_get_handle
().
is_supported
(
ins
,
m
);
}
argument
copy_to
(
const
argument
&
input
)
const
argument
copy_to
(
const
argument
&
input
)
const
{
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
assert
((
*
this
).
private_detail_te_handle_mem_var
);
...
@@ -242,11 +265,31 @@ struct target
...
@@ -242,11 +265,31 @@ struct target
virtual
std
::
vector
<
pass
>
get_passes
(
context
&
ctx
,
virtual
std
::
vector
<
pass
>
get_passes
(
context
&
ctx
,
const
compile_options
&
options
)
const
=
0
;
const
compile_options
&
options
)
const
=
0
;
virtual
context
get_context
()
const
=
0
;
virtual
context
get_context
()
const
=
0
;
virtual
float
is_supported
(
instruction_ref
ins
,
support_metric
m
)
const
=
0
;
virtual
argument
copy_to
(
const
argument
&
input
)
const
=
0
;
virtual
argument
copy_to
(
const
argument
&
input
)
const
=
0
;
virtual
argument
copy_from
(
const
argument
&
input
)
const
=
0
;
virtual
argument
copy_from
(
const
argument
&
input
)
const
=
0
;
virtual
argument
allocate
(
const
shape
&
s
)
const
=
0
;
virtual
argument
allocate
(
const
shape
&
s
)
const
=
0
;
};
};
template
<
class
T
>
static
auto
private_detail_te_default_is_supported
(
char
,
T
&&
private_detail_te_self
,
instruction_ref
ins
,
support_metric
m
)
->
decltype
(
private_detail_te_self
.
is_supported
(
ins
,
m
))
{
return
private_detail_te_self
.
is_supported
(
ins
,
m
);
}
template
<
class
T
>
static
float
private_detail_te_default_is_supported
(
float
,
T
&&
private_detail_te_self
,
instruction_ref
ins
,
support_metric
m
)
{
return
target_is_supported
(
private_detail_te_self
,
ins
,
m
);
}
template
<
class
T
>
template
<
class
T
>
static
auto
static
auto
private_detail_te_default_copy_to
(
char
,
T
&&
private_detail_te_self
,
const
argument
&
input
)
private_detail_te_default_copy_to
(
char
,
T
&&
private_detail_te_self
,
const
argument
&
input
)
...
@@ -329,6 +372,12 @@ struct target
...
@@ -329,6 +372,12 @@ struct target
context
get_context
()
const
override
{
return
private_detail_te_value
.
get_context
();
}
context
get_context
()
const
override
{
return
private_detail_te_value
.
get_context
();
}
float
is_supported
(
instruction_ref
ins
,
support_metric
m
)
const
override
{
return
private_detail_te_default_is_supported
(
char
(
0
),
private_detail_te_value
,
ins
,
m
);
}
argument
copy_to
(
const
argument
&
input
)
const
override
argument
copy_to
(
const
argument
&
input
)
const
override
{
{
...
...
src/include/migraphx/target_assignments.hpp
0 → 100644
View file @
a0edd061
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_ASSIGNMENT_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_ASSIGNMENT_HPP
#include <unordered_map>
#include <migraphx/instruction_ref.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
target_assignments
{
void
add_assignment
(
instruction_ref
ins
,
const
std
::
string
&
target
);
auto
begin
()
const
{
return
assignments
.
cbegin
();
}
auto
end
()
const
{
return
assignments
.
cend
();
}
private:
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
assignments
;
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_ASSIGNMENT_HPP
src/module.cpp
View file @
a0edd061
...
@@ -399,7 +399,8 @@ module::add_instructions(const std::vector<instruction_ref>& instructions,
...
@@ -399,7 +399,8 @@ module::add_instructions(const std::vector<instruction_ref>& instructions,
}
}
std
::
vector
<
instruction_ref
>
std
::
vector
<
instruction_ref
>
module
::
add_instructions
(
module_ref
m
,
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
map_ins
)
module
::
add_instructions
(
const_module_ref
m
,
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
map_ins
)
{
{
return
this
->
insert_instructions
(
this
->
end
(),
m
,
std
::
move
(
map_ins
));
return
this
->
insert_instructions
(
this
->
end
(),
m
,
std
::
move
(
map_ins
));
}
}
...
@@ -420,8 +421,10 @@ module::insert_instructions(instruction_ref ins,
...
@@ -420,8 +421,10 @@ module::insert_instructions(instruction_ref ins,
return
insert_generic_instructions
(
*
this
,
ins
,
instructions
,
std
::
move
(
map_ins
));
return
insert_generic_instructions
(
*
this
,
ins
,
instructions
,
std
::
move
(
map_ins
));
}
}
std
::
vector
<
instruction_ref
>
module
::
insert_instructions
(
std
::
vector
<
instruction_ref
>
instruction_ref
ins
,
module_ref
m
,
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
map_ins
)
module
::
insert_instructions
(
instruction_ref
ins
,
const_module_ref
m
,
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
map_ins
)
{
{
return
insert_generic_instructions
(
*
this
,
ins
,
iterator_for
(
*
m
),
std
::
move
(
map_ins
));
return
insert_generic_instructions
(
*
this
,
ins
,
iterator_for
(
*
m
),
std
::
move
(
map_ins
));
}
}
...
@@ -436,11 +439,7 @@ module::insert_instructions(instruction_ref ins,
...
@@ -436,11 +439,7 @@ module::insert_instructions(instruction_ref ins,
return
insert_generic_instructions
(
*
this
,
ins
,
iterator_for
(
r
),
std
::
move
(
map_ins
));
return
insert_generic_instructions
(
*
this
,
ins
,
iterator_for
(
r
),
std
::
move
(
map_ins
));
}
}
instruction_ref
module
::
add_literal
(
literal
l
)
instruction_ref
module
::
add_literal
(
literal
l
)
{
return
insert_literal
(
begin
(),
std
::
move
(
l
));
}
{
impl
->
emplace_front
(
std
::
move
(
l
));
return
impl
->
instructions
.
begin
();
}
instruction_ref
module
::
add_outline
(
const
shape
&
s
)
instruction_ref
module
::
add_outline
(
const
shape
&
s
)
{
{
...
@@ -450,10 +449,7 @@ instruction_ref module::add_outline(const shape& s)
...
@@ -450,10 +449,7 @@ instruction_ref module::add_outline(const shape& s)
instruction_ref
module
::
add_parameter
(
std
::
string
name
,
shape
s
)
instruction_ref
module
::
add_parameter
(
std
::
string
name
,
shape
s
)
{
{
assert
(
get_parameter_shape
(
name
)
==
shape
{});
return
insert_parameter
(
begin
(),
std
::
move
(
name
),
std
::
move
(
s
));
impl
->
push_front
({
builtin
::
param
{
std
::
move
(
name
),
impl
->
nparams
},
std
::
move
(
s
),
{}});
impl
->
nparams
++
;
return
impl
->
instructions
.
begin
();
}
}
instruction_ref
module
::
add_return
(
std
::
vector
<
instruction_ref
>
args
)
instruction_ref
module
::
add_return
(
std
::
vector
<
instruction_ref
>
args
)
...
@@ -466,6 +462,20 @@ instruction_ref module::add_return(std::vector<instruction_ref> args)
...
@@ -466,6 +462,20 @@ instruction_ref module::add_return(std::vector<instruction_ref> args)
return
result
;
return
result
;
}
}
instruction_ref
module
::
insert_literal
(
instruction_ref
ins
,
literal
l
)
{
impl
->
emplace
(
ins
,
std
::
move
(
l
));
return
std
::
prev
(
ins
);
}
instruction_ref
module
::
insert_parameter
(
instruction_ref
ins
,
std
::
string
name
,
shape
s
)
{
assert
(
get_parameter_shape
(
name
)
==
shape
{});
impl
->
insert
(
ins
,
{
builtin
::
param
{
std
::
move
(
name
),
impl
->
nparams
},
std
::
move
(
s
),
{}});
impl
->
nparams
++
;
return
std
::
prev
(
ins
);
}
instruction_ref
module
::
replace_return
(
std
::
vector
<
instruction_ref
>
args
)
instruction_ref
module
::
replace_return
(
std
::
vector
<
instruction_ref
>
args
)
{
{
auto
last
=
std
::
prev
(
this
->
end
());
auto
last
=
std
::
prev
(
this
->
end
());
...
...
src/program.cpp
View file @
a0edd061
...
@@ -159,6 +159,25 @@ instruction_ref program::validate() const
...
@@ -159,6 +159,25 @@ instruction_ref program::validate() const
return
mm
->
validate
();
return
mm
->
validate
();
}
}
target_assignments
program
::
get_target_assignments
(
const
std
::
vector
<
target
>&
targets
,
assignment_options
options
)
{
const
auto
m
=
options
.
metric
;
target_assignments
p
;
const
auto
*
mod
=
get_main_module
();
for
(
auto
it
:
iterator_for
(
*
mod
))
{
auto
t
=
std
::
max_element
(
targets
.
begin
(),
targets
.
end
(),
[
it
,
m
](
const
target
&
lhs
,
const
target
&
rhs
)
{
return
lhs
.
is_supported
(
it
,
m
)
<
rhs
.
is_supported
(
it
,
m
);
});
p
.
add_assignment
(
it
,
t
->
name
());
}
return
p
;
}
bool
program
::
is_compiled
()
const
{
return
not
this
->
impl
->
target_name
.
empty
();
}
bool
program
::
is_compiled
()
const
{
return
not
this
->
impl
->
target_name
.
empty
();
}
void
program
::
compile
(
const
target
&
t
,
compile_options
options
)
void
program
::
compile
(
const
target
&
t
,
compile_options
options
)
...
@@ -504,12 +523,14 @@ static void mod_from_val(module_ref mod,
...
@@ -504,12 +523,14 @@ static void mod_from_val(module_ref mod,
if
(
name
==
"@param"
)
if
(
name
==
"@param"
)
{
{
output
=
mod
->
add_parameter
(
fields
[
"parameter"
].
to
<
std
::
string
>
(),
output
=
mod
->
insert_parameter
(
mod
->
end
(),
migraphx
::
from_value
<
shape
>
(
node
.
at
(
"shape"
)));
fields
[
"parameter"
].
to
<
std
::
string
>
(),
migraphx
::
from_value
<
shape
>
(
node
.
at
(
"shape"
)));
}
}
else
if
(
name
==
"@literal"
)
else
if
(
name
==
"@literal"
)
{
{
output
=
mod
->
add_literal
(
migraphx
::
from_value
<
literal
>
(
node
.
at
(
"literal"
)));
output
=
mod
->
insert_literal
(
mod
->
end
(),
migraphx
::
from_value
<
literal
>
(
node
.
at
(
"literal"
)));
}
}
else
else
{
{
...
@@ -544,11 +565,11 @@ static void mod_from_val(module_ref mod,
...
@@ -544,11 +565,11 @@ static void mod_from_val(module_ref mod,
}
}
else
if
(
module_inputs
.
empty
())
else
if
(
module_inputs
.
empty
())
{
{
output
=
mod
->
add
_instruction
(
op
,
inputs
);
output
=
mod
->
insert
_instruction
(
mod
->
end
(),
op
,
inputs
);
}
}
else
else
{
{
output
=
mod
->
add
_instruction
(
op
,
inputs
,
module_inputs
);
output
=
mod
->
insert
_instruction
(
mod
->
end
(),
op
,
inputs
,
module_inputs
);
}
}
}
}
output
->
set_normalized
(
normalized
);
output
->
set_normalized
(
normalized
);
...
@@ -681,11 +702,13 @@ void program::perf_report(std::ostream& os,
...
@@ -681,11 +702,13 @@ void program::perf_report(std::ostream& os,
double
overhead_percent
=
overhead_time
*
100.0
/
total_time
;
double
overhead_percent
=
overhead_time
*
100.0
/
total_time
;
double
total_instruction_time
=
0.0
;
double
total_instruction_time
=
0.0
;
std
::
unordered_map
<
std
::
string
,
double
>
op_times
;
std
::
unordered_map
<
std
::
string
,
double
>
op_times
;
std
::
unordered_map
<
std
::
string
,
std
::
size_t
>
op_n
;
for
(
auto
&&
p
:
ins_vec
)
for
(
auto
&&
p
:
ins_vec
)
{
{
double
avg
=
common_average
(
p
.
second
);
double
avg
=
common_average
(
p
.
second
);
op_times
[
perf_group
(
p
.
first
->
get_operator
())]
+=
avg
;
op_times
[
perf_group
(
p
.
first
->
get_operator
())]
+=
avg
;
total_instruction_time
+=
avg
;
total_instruction_time
+=
avg
;
op_n
[
perf_group
(
p
.
first
->
get_operator
())]
++
;
}
}
double
calculate_overhead_time
=
total_time
-
total_instruction_time
;
double
calculate_overhead_time
=
total_time
-
total_instruction_time
;
double
calculate_overhead_percent
=
calculate_overhead_time
*
100.0
/
total_time
;
double
calculate_overhead_percent
=
calculate_overhead_time
*
100.0
/
total_time
;
...
@@ -706,18 +729,19 @@ void program::perf_report(std::ostream& os,
...
@@ -706,18 +729,19 @@ void program::perf_report(std::ostream& os,
os
<<
std
::
endl
;
os
<<
std
::
endl
;
os
<<
"Summary:"
<<
std
::
endl
;
os
<<
"Summary:"
<<
std
::
endl
;
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
op_times_sorted
;
std
::
vector
<
std
::
tuple
<
double
,
std
::
size_t
,
std
::
string
>>
op_times_sorted
;
std
::
transform
(
op_times
.
begin
(),
std
::
transform
(
op_times
.
end
(),
op_times
.
begin
(),
op_times
.
end
(),
std
::
back_inserter
(
op_times_sorted
),
[
&
](
auto
p
)
{
std
::
back_inserter
(
op_times_sorted
),
auto
&&
name
=
p
.
first
;
[](
auto
p
)
{
return
std
::
make_pair
(
p
.
second
,
p
.
first
);
});
return
std
::
make_tuple
(
p
.
second
,
op_n
.
at
(
name
),
name
);
});
std
::
sort
(
op_times_sorted
.
begin
(),
op_times_sorted
.
end
(),
std
::
greater
<>
{});
std
::
sort
(
op_times_sorted
.
begin
(),
op_times_sorted
.
end
(),
std
::
greater
<>
{});
for
(
auto
&&
p
:
op_times_sorted
)
for
(
auto
&&
[
avg
,
nn
,
name
]
:
op_times_sorted
)
{
{
auto
&&
name
=
p
.
second
;
double
avg
=
p
.
first
;
double
percent
=
std
::
ceil
(
100.0
*
avg
/
total_instruction_time
);
double
percent
=
std
::
ceil
(
100.0
*
avg
/
total_instruction_time
);
os
<<
name
<<
": "
<<
avg
<<
"ms, "
<<
percent
<<
"%"
<<
std
::
endl
;
double
per_ins
=
avg
/
nn
;
os
<<
name
<<
": "
<<
avg
<<
"ms / "
<<
nn
<<
" = "
<<
per_ins
<<
"ms, "
<<
percent
<<
"%"
<<
std
::
endl
;
}
}
os
<<
std
::
endl
;
os
<<
std
::
endl
;
...
...
src/serialize.cpp
View file @
a0edd061
...
@@ -36,7 +36,7 @@ void raw_data_to_value(value& v, const RawData& rd)
...
@@ -36,7 +36,7 @@ void raw_data_to_value(value& v, const RawData& rd)
result
[
"shape"
]
=
migraphx
::
to_value
(
rd
.
get_shape
());
result
[
"shape"
]
=
migraphx
::
to_value
(
rd
.
get_shape
());
if
(
rd
.
get_shape
().
type
()
==
shape
::
tuple_type
)
if
(
rd
.
get_shape
().
type
()
==
shape
::
tuple_type
)
result
[
"sub"
]
=
migraphx
::
to_value
(
rd
.
get_sub_objects
());
result
[
"sub"
]
=
migraphx
::
to_value
(
rd
.
get_sub_objects
());
else
else
if
(
not
rd
.
empty
())
result
[
"data"
]
=
migraphx
::
value
::
binary
(
rd
.
data
(),
rd
.
get_shape
().
bytes
());
result
[
"data"
]
=
migraphx
::
value
::
binary
(
rd
.
data
(),
rd
.
get_shape
().
bytes
());
v
=
result
;
v
=
result
;
}
}
...
@@ -56,7 +56,7 @@ void migraphx_from_value(const value& v, argument& a)
...
@@ -56,7 +56,7 @@ void migraphx_from_value(const value& v, argument& a)
literal
l
=
migraphx
::
from_value
<
literal
>
(
v
);
literal
l
=
migraphx
::
from_value
<
literal
>
(
v
);
a
=
l
.
get_argument
();
a
=
l
.
get_argument
();
}
}
else
else
if
(
v
.
contains
(
"sub"
))
{
{
a
=
migraphx
::
from_value
<
std
::
vector
<
argument
>>
(
v
.
at
(
"sub"
));
a
=
migraphx
::
from_value
<
std
::
vector
<
argument
>>
(
v
.
at
(
"sub"
));
}
}
...
...
src/shape.cpp
View file @
a0edd061
...
@@ -26,6 +26,7 @@
...
@@ -26,6 +26,7 @@
#include <migraphx/stringutils.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/ranges.hpp>
#include <numeric>
#include <numeric>
#include <algorithm>
#include <algorithm>
#include <functional>
#include <functional>
...
@@ -65,13 +66,21 @@ struct shape_impl
...
@@ -65,13 +66,21 @@ struct shape_impl
std
::
is_sorted
(
m_strides
.
rbegin
(),
m_strides
.
rend
());
std
::
is_sorted
(
m_strides
.
rbegin
(),
m_strides
.
rend
());
}
}
shape_impl
(
shape
::
type_t
t
,
std
::
vector
<
shape
::
dynamic_dimension
>
dims
)
:
m_type
(
t
),
m_dyn_dims
(
std
::
move
(
dims
))
{
}
shape_impl
(
const
std
::
vector
<
shape
>&
subs
)
:
m_type
(
shape
::
tuple_type
),
m_shapes
(
subs
)
{}
shape_impl
(
const
std
::
vector
<
shape
>&
subs
)
:
m_type
(
shape
::
tuple_type
),
m_shapes
(
subs
)
{}
shape
::
type_t
m_type
;
shape
::
type_t
m_type
;
std
::
vector
<
std
::
size_t
>
m_lens
=
{};
std
::
vector
<
std
::
size_t
>
m_lens
=
{};
std
::
vector
<
std
::
size_t
>
m_strides
=
{};
std
::
vector
<
std
::
size_t
>
m_strides
=
{};
std
::
vector
<
shape
>
m_shapes
=
{};
std
::
vector
<
shape
>
m_shapes
=
{};
bool
m_standard
=
false
;
bool
m_standard
=
false
;
std
::
vector
<
shape
::
dynamic_dimension
>
m_dyn_dims
=
{};
void
calculate_strides
()
void
calculate_strides
()
{
{
m_strides
.
clear
();
m_strides
.
clear
();
...
@@ -87,6 +96,12 @@ struct shape_impl
...
@@ -87,6 +96,12 @@ struct shape_impl
std
::
size_t
element_space
()
const
std
::
size_t
element_space
()
const
{
{
if
(
not
m_dyn_dims
.
empty
())
{
auto
maxes
=
max_lens
();
return
std
::
accumulate
(
maxes
.
begin
(),
maxes
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
());
}
assert
(
m_lens
.
size
()
==
m_strides
.
size
());
assert
(
m_lens
.
size
()
==
m_strides
.
size
());
if
(
m_lens
.
empty
())
if
(
m_lens
.
empty
())
return
0
;
return
0
;
...
@@ -101,6 +116,11 @@ struct shape_impl
...
@@ -101,6 +116,11 @@ struct shape_impl
std
::
size_t
elements
()
const
std
::
size_t
elements
()
const
{
{
if
(
not
m_dyn_dims
.
empty
())
{
MIGRAPHX_THROW
(
"SHAPE: elements() called on dynamic shape"
);
}
assert
(
m_lens
.
size
()
==
m_strides
.
size
());
assert
(
m_lens
.
size
()
==
m_strides
.
size
());
if
(
m_lens
.
empty
())
if
(
m_lens
.
empty
())
return
0
;
return
0
;
...
@@ -108,6 +128,35 @@ struct shape_impl
...
@@ -108,6 +128,35 @@ struct shape_impl
m_lens
.
begin
(),
m_lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
m_lens
.
begin
(),
m_lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
}
}
std
::
vector
<
std
::
size_t
>
min_lens
()
const
{
std
::
vector
<
std
::
size_t
>
ret
(
m_dyn_dims
.
size
());
std
::
transform
(
m_dyn_dims
.
cbegin
(),
m_dyn_dims
.
cend
(),
ret
.
begin
(),
[](
shape
::
dynamic_dimension
x
)
{
return
x
.
min
;
});
return
ret
;
}
std
::
vector
<
std
::
size_t
>
max_lens
()
const
{
std
::
vector
<
std
::
size_t
>
ret
(
m_dyn_dims
.
size
());
std
::
transform
(
m_dyn_dims
.
cbegin
(),
m_dyn_dims
.
cend
(),
ret
.
begin
(),
[](
shape
::
dynamic_dimension
x
)
{
return
x
.
max
;
});
return
ret
;
}
std
::
vector
<
std
::
size_t
>
opt_lens
()
const
{
std
::
vector
<
std
::
size_t
>
ret
(
m_dyn_dims
.
size
());
std
::
transform
(
m_dyn_dims
.
cbegin
(),
m_dyn_dims
.
cend
(),
ret
.
begin
(),
[](
shape
::
dynamic_dimension
x
)
{
return
x
.
opt
;
});
return
ret
;
}
// Does the shape skip over elements?
// Does the shape skip over elements?
bool
skips
()
const
bool
skips
()
const
{
{
...
@@ -165,6 +214,16 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
...
@@ -165,6 +214,16 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
{
{
}
}
shape
::
shape
(
type_t
t
,
std
::
initializer_list
<
std
::
size_t
>
d
)
:
shape
::
shape
(
t
,
std
::
vector
<
std
::
size_t
>
{
d
.
begin
(),
d
.
end
()})
{
}
shape
::
shape
(
type_t
t
,
std
::
vector
<
shape
::
dynamic_dimension
>
dims
)
:
impl
(
std
::
make_shared
<
shape_impl
>
(
t
,
std
::
move
(
dims
)))
{
}
shape
::
shape
(
const
std
::
vector
<
shape
>&
subs
)
:
impl
(
std
::
make_shared
<
shape_impl
>
(
subs
))
{}
shape
::
shape
(
const
std
::
vector
<
shape
>&
subs
)
:
impl
(
std
::
make_shared
<
shape_impl
>
(
subs
))
{}
shape
::
shape
(
std
::
shared_ptr
<
shape_impl
>
pimpl
)
:
impl
(
std
::
move
(
pimpl
))
{}
shape
::
shape
(
std
::
shared_ptr
<
shape_impl
>
pimpl
)
:
impl
(
std
::
move
(
pimpl
))
{}
...
@@ -180,9 +239,13 @@ shape shape::from_permutation(type_t t,
...
@@ -180,9 +239,13 @@ shape shape::from_permutation(type_t t,
}
}
shape
::
type_t
shape
::
type
()
const
{
return
impl
->
m_type
;
}
shape
::
type_t
shape
::
type
()
const
{
return
impl
->
m_type
;
}
const
std
::
vector
<
std
::
size_t
>&
shape
::
lens
()
const
{
return
impl
->
m_lens
;
}
const
std
::
vector
<
std
::
size_t
>&
shape
::
lens
()
const
{
return
impl
->
m_lens
;
}
const
std
::
vector
<
std
::
size_t
>&
shape
::
strides
()
const
{
return
impl
->
m_strides
;
}
const
std
::
vector
<
std
::
size_t
>&
shape
::
strides
()
const
{
return
impl
->
m_strides
;
}
std
::
size_t
shape
::
elements
()
const
{
return
impl
->
elements
();
}
std
::
size_t
shape
::
elements
()
const
{
return
impl
->
elements
();
}
std
::
size_t
shape
::
bytes
()
const
std
::
size_t
shape
::
bytes
()
const
{
{
if
(
this
->
sub_shapes
().
empty
())
if
(
this
->
sub_shapes
().
empty
())
...
@@ -199,6 +262,7 @@ std::size_t shape::bytes() const
...
@@ -199,6 +262,7 @@ std::size_t shape::bytes() const
[
&
](
auto
x
,
auto
y
)
{
return
x
+
y
.
bytes
();
});
[
&
](
auto
x
,
auto
y
)
{
return
x
+
y
.
bytes
();
});
}
}
}
}
std
::
size_t
shape
::
type_size
()
const
std
::
size_t
shape
::
type_size
()
const
{
{
std
::
size_t
n
=
0
;
std
::
size_t
n
=
0
;
...
@@ -206,20 +270,35 @@ std::size_t shape::type_size() const
...
@@ -206,20 +270,35 @@ std::size_t shape::type_size() const
this
->
visit_type
([
&
](
auto
as
)
{
n
=
as
.
size
();
});
this
->
visit_type
([
&
](
auto
as
)
{
n
=
as
.
size
();
});
return
n
;
return
n
;
}
}
std
::
size_t
shape
::
index
(
std
::
initializer_list
<
std
::
size_t
>
l
)
const
std
::
size_t
shape
::
index
(
std
::
initializer_list
<
std
::
size_t
>
l
)
const
{
{
if
(
this
->
dynamic
())
{
MIGRAPHX_THROW
(
"SHAPE: index() called on dynamic shape"
);
}
assert
(
l
.
size
()
<=
this
->
lens
().
size
());
assert
(
l
.
size
()
<=
this
->
lens
().
size
());
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
return
std
::
inner_product
(
l
.
begin
(),
l
.
end
(),
this
->
strides
().
begin
(),
std
::
size_t
{
0
});
return
std
::
inner_product
(
l
.
begin
(),
l
.
end
(),
this
->
strides
().
begin
(),
std
::
size_t
{
0
});
}
}
std
::
size_t
shape
::
index
(
const
std
::
vector
<
std
::
size_t
>&
l
)
const
std
::
size_t
shape
::
index
(
const
std
::
vector
<
std
::
size_t
>&
l
)
const
{
{
if
(
this
->
dynamic
())
{
MIGRAPHX_THROW
(
"SHAPE: index() called on dynamic shape"
);
}
assert
(
l
.
size
()
<=
this
->
lens
().
size
());
assert
(
l
.
size
()
<=
this
->
lens
().
size
());
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
return
std
::
inner_product
(
l
.
begin
(),
l
.
end
(),
this
->
strides
().
begin
(),
std
::
size_t
{
0
});
return
std
::
inner_product
(
l
.
begin
(),
l
.
end
(),
this
->
strides
().
begin
(),
std
::
size_t
{
0
});
}
}
std
::
size_t
shape
::
index
(
std
::
size_t
i
)
const
std
::
size_t
shape
::
index
(
std
::
size_t
i
)
const
{
{
if
(
this
->
dynamic
())
{
MIGRAPHX_THROW
(
"SHAPE: index() called on dynamic shape"
);
}
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
if
(
this
->
standard
())
if
(
this
->
standard
())
return
i
;
return
i
;
...
@@ -267,12 +346,20 @@ void shape::multi_copy(std::size_t i, std::size_t* start, const std::size_t* end
...
@@ -267,12 +346,20 @@ void shape::multi_copy(std::size_t i, std::size_t* start, const std::size_t* end
bool
shape
::
packed
()
const
bool
shape
::
packed
()
const
{
{
if
(
this
->
dynamic
())
{
return
false
;
}
return
this
->
sub_shapes
().
empty
()
and
not
impl
->
skips
()
and
return
this
->
sub_shapes
().
empty
()
and
not
impl
->
skips
()
and
this
->
elements
()
==
this
->
element_space
();
this
->
elements
()
==
this
->
element_space
();
}
}
bool
shape
::
transposed
()
const
bool
shape
::
transposed
()
const
{
{
if
(
this
->
dynamic
())
{
return
false
;
}
if
(
this
->
broadcasted
())
if
(
this
->
broadcasted
())
{
{
// TODO: Use a filter_iterator instead
// TODO: Use a filter_iterator instead
...
@@ -292,6 +379,10 @@ bool shape::transposed() const
...
@@ -292,6 +379,10 @@ bool shape::transposed() const
bool
shape
::
broadcasted
()
const
bool
shape
::
broadcasted
()
const
{
{
if
(
this
->
dynamic
())
{
return
false
;
}
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
return
std
::
any_of
(
return
std
::
any_of
(
this
->
strides
().
begin
(),
this
->
strides
().
end
(),
[](
auto
x
)
{
return
x
==
0
;
});
this
->
strides
().
begin
(),
this
->
strides
().
end
(),
[](
auto
x
)
{
return
x
==
0
;
});
...
@@ -299,6 +390,10 @@ bool shape::broadcasted() const
...
@@ -299,6 +390,10 @@ bool shape::broadcasted() const
bool
shape
::
scalar
()
const
bool
shape
::
scalar
()
const
{
{
if
(
this
->
dynamic
())
{
return
false
;
}
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
// if any stride > 0, then accumulate will return false
// if any stride > 0, then accumulate will return false
return
this
->
sub_shapes
().
empty
()
and
return
this
->
sub_shapes
().
empty
()
and
...
@@ -317,6 +412,10 @@ shape shape::normalize_standard() const
...
@@ -317,6 +412,10 @@ shape shape::normalize_standard() const
shape
shape
::
with_lens
(
type_t
t
,
const
std
::
vector
<
std
::
size_t
>&
l
)
const
shape
shape
::
with_lens
(
type_t
t
,
const
std
::
vector
<
std
::
size_t
>&
l
)
const
{
{
if
(
this
->
dynamic
())
{
MIGRAPHX_THROW
(
"SHAPE: with_lens() called on dynamic shape"
);
}
assert
(
l
.
size
()
==
this
->
lens
().
size
());
assert
(
l
.
size
()
==
this
->
lens
().
size
());
auto
perm
=
find_permutation
(
*
this
);
auto
perm
=
find_permutation
(
*
this
);
return
shape
::
from_permutation
(
t
,
l
,
perm
);
return
shape
::
from_permutation
(
t
,
l
,
perm
);
...
@@ -324,6 +423,10 @@ shape shape::with_lens(type_t t, const std::vector<std::size_t>& l) const
...
@@ -324,6 +423,10 @@ shape shape::with_lens(type_t t, const std::vector<std::size_t>& l) const
shape
shape
::
with_lens
(
const
std
::
vector
<
std
::
size_t
>&
l
)
const
shape
shape
::
with_lens
(
const
std
::
vector
<
std
::
size_t
>&
l
)
const
{
{
if
(
this
->
dynamic
())
{
MIGRAPHX_THROW
(
"SHAPE: with_lens() called on dynamic shape"
);
}
return
this
->
with_lens
(
this
->
type
(),
l
);
return
this
->
with_lens
(
this
->
type
(),
l
);
}
}
...
@@ -338,20 +441,80 @@ std::size_t shape::element_space() const { return impl->element_space(); }
...
@@ -338,20 +441,80 @@ std::size_t shape::element_space() const { return impl->element_space(); }
std
::
string
shape
::
type_string
()
const
{
return
name
(
this
->
type
());
}
std
::
string
shape
::
type_string
()
const
{
return
name
(
this
->
type
());
}
bool
shape
::
dynamic
()
const
{
return
not
impl
->
m_dyn_dims
.
empty
();
}
const
std
::
vector
<
shape
::
dynamic_dimension
>&
shape
::
dyn_dims
()
const
{
return
impl
->
m_dyn_dims
;
}
std
::
vector
<
std
::
size_t
>
shape
::
min_lens
()
const
{
return
this
->
dynamic
()
?
impl
->
min_lens
()
:
this
->
lens
();
}
std
::
vector
<
std
::
size_t
>
shape
::
max_lens
()
const
{
return
this
->
dynamic
()
?
impl
->
max_lens
()
:
this
->
lens
();
}
std
::
vector
<
std
::
size_t
>
shape
::
opt_lens
()
const
{
return
this
->
dynamic
()
?
impl
->
opt_lens
()
:
this
->
lens
();
}
bool
shape
::
dynamic_dimension
::
is_fixed
()
const
{
return
this
->
min
==
this
->
max
;
}
bool
shape
::
dynamic_dimension
::
has_optimal
()
const
{
return
opt
!=
0
;
}
template
<
class
Self
,
class
F
>
auto
shape
::
dynamic_dimension
::
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
min
,
"min"
),
f
(
self
.
max
,
"max"
),
f
(
self
.
opt
,
"opt"
));
}
bool
operator
==
(
const
shape
::
dynamic_dimension
&
x
,
const
shape
::
dynamic_dimension
&
y
)
{
return
(
x
.
min
==
y
.
min
and
x
.
max
==
y
.
max
and
x
.
opt
==
y
.
opt
);
}
bool
operator
!=
(
const
shape
::
dynamic_dimension
&
x
,
const
shape
::
dynamic_dimension
&
y
)
{
return
!
(
x
==
y
);
}
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
shape
::
dynamic_dimension
&
x
)
{
os
<<
"["
<<
x
.
min
<<
", "
<<
x
.
max
<<
", "
<<
x
.
opt
<<
"]"
;
return
os
;
}
bool
operator
==
(
const
shape
&
x
,
const
shape
&
y
)
bool
operator
==
(
const
shape
&
x
,
const
shape
&
y
)
{
{
return
x
.
impl
==
y
.
impl
or
(
x
.
type
()
==
y
.
type
()
and
x
.
lens
()
==
y
.
lens
()
and
if
(
x
.
dynamic
()
and
y
.
dynamic
())
x
.
strides
()
==
y
.
strides
()
and
x
.
sub_shapes
()
==
y
.
sub_shapes
());
{
return
x
.
impl
==
y
.
impl
or
(
x
.
type
()
==
y
.
type
()
and
x
.
dyn_dims
()
==
y
.
dyn_dims
()
and
x
.
sub_shapes
()
==
y
.
sub_shapes
());
}
return
x
.
impl
==
y
.
impl
or
(
x
.
dynamic
()
==
y
.
dynamic
()
and
x
.
type
()
==
y
.
type
()
and
x
.
lens
()
==
y
.
lens
()
and
x
.
strides
()
==
y
.
strides
()
and
x
.
sub_shapes
()
==
y
.
sub_shapes
());
}
}
bool
operator
!=
(
const
shape
&
x
,
const
shape
&
y
)
{
return
!
(
x
==
y
);
}
bool
operator
!=
(
const
shape
&
x
,
const
shape
&
y
)
{
return
!
(
x
==
y
);
}
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
shape
&
x
)
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
shape
&
x
)
{
{
if
(
x
.
sub_shapes
().
empty
())
if
(
x
.
sub_shapes
().
empty
())
{
{
os
<<
x
.
type_string
()
<<
", "
;
if
(
x
.
dynamic
())
os
<<
"{"
<<
to_string_range
(
x
.
lens
())
<<
"}, "
;
{
os
<<
"{"
<<
to_string_range
(
x
.
strides
())
<<
"}"
;
os
<<
"dynamic, "
;
os
<<
x
.
type_string
()
<<
", "
;
os
<<
"{"
<<
to_string_range
(
x
.
dyn_dims
())
<<
"}"
;
}
else
{
os
<<
x
.
type_string
()
<<
", "
;
os
<<
"{"
<<
to_string_range
(
x
.
lens
())
<<
"}, "
;
os
<<
"{"
<<
to_string_range
(
x
.
strides
())
<<
"}"
;
}
}
}
else
else
{
{
...
@@ -375,12 +538,14 @@ const std::vector<shape>& shape::sub_shapes() const { return impl->m_shapes; }
...
@@ -375,12 +538,14 @@ const std::vector<shape>& shape::sub_shapes() const { return impl->m_shapes; }
void
migraphx_to_value
(
value
&
v
,
const
shape
&
s
)
void
migraphx_to_value
(
value
&
v
,
const
shape
&
s
)
{
{
value
result
;
value
result
;
result
[
"type"
]
=
migraphx
::
to_value
(
s
.
type_string
());
result
[
"type"
]
=
migraphx
::
to_value
(
s
.
type_string
());
result
[
"lens"
]
=
migraphx
::
to_value
(
s
.
lens
());
result
[
"lens"
]
=
migraphx
::
to_value
(
s
.
lens
());
result
[
"strides"
]
=
migraphx
::
to_value
(
s
.
strides
());
result
[
"strides"
]
=
migraphx
::
to_value
(
s
.
strides
());
result
[
"sub_shapes"
]
=
migraphx
::
to_value
(
s
.
sub_shapes
());
result
[
"sub_shapes"
]
=
migraphx
::
to_value
(
s
.
sub_shapes
());
v
=
result
;
result
[
"dynamic_dimensions"
]
=
migraphx
::
to_value
(
s
.
dyn_dims
());
v
=
result
;
}
}
void
migraphx_from_value
(
const
value
&
v
,
shape
&
s
)
void
migraphx_from_value
(
const
value
&
v
,
shape
&
s
)
{
{
auto
t
=
v
.
at
(
"type"
).
get_string
();
auto
t
=
v
.
at
(
"type"
).
get_string
();
...
@@ -390,9 +555,25 @@ void migraphx_from_value(const value& v, shape& s)
...
@@ -390,9 +555,25 @@ void migraphx_from_value(const value& v, shape& s)
}
}
else
else
{
{
s
=
shape
{
shape
::
parse_type
(
t
),
if
(
v
.
at
(
"dynamic_dimensions"
).
empty
())
v
.
at
(
"lens"
).
to_vector
<
std
::
size_t
>
(),
{
v
.
at
(
"strides"
).
to_vector
<
std
::
size_t
>
()};
s
=
shape
{
shape
::
parse_type
(
t
),
v
.
at
(
"lens"
).
to_vector
<
std
::
size_t
>
(),
v
.
at
(
"strides"
).
to_vector
<
std
::
size_t
>
()};
}
else
{
auto
v_dd
=
v
.
at
(
"dynamic_dimensions"
);
std
::
vector
<
shape
::
dynamic_dimension
>
dyn_dims
(
v
.
at
(
"dynamic_dimensions"
).
size
());
std
::
transform
(
v_dd
.
begin
(),
v_dd
.
end
(),
dyn_dims
.
begin
(),
[](
migraphx
::
value
x
)
{
auto
x_min
=
x
.
at
(
"min"
).
template
to
<
size_t
>();
auto
x_max
=
x
.
at
(
"max"
).
template
to
<
size_t
>();
auto
x_opt
=
x
.
at
(
"opt"
).
template
to
<
size_t
>();
return
shape
::
dynamic_dimension
{
x_min
,
x_max
,
x_opt
};
});
s
=
shape
{
shape
::
parse_type
(
t
),
dyn_dims
};
}
}
}
}
}
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment