Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
9c91c08d
"vscode:/vscode.git/clone" did not exist on "1371159f0bfbd45cf8a323059e6d9bafdfe745cd"
Unverified
Commit
9c91c08d
authored
Jul 07, 2023
by
Chris Austen
Committed by
GitHub
Jul 07, 2023
Browse files
Merge branch 'develop' into enable_navi_32_ci
parents
a56bb11d
c1b8c975
Changes
124
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
583 additions
and
188 deletions
+583
-188
src/include/migraphx/op/convert.hpp
src/include/migraphx/op/convert.hpp
+1
-1
src/include/migraphx/op/multibroadcast.hpp
src/include/migraphx/op/multibroadcast.hpp
+14
-10
src/include/migraphx/op/reshape.hpp
src/include/migraphx/op/reshape.hpp
+117
-5
src/include/migraphx/operation.hpp
src/include/migraphx/operation.hpp
+5
-3
src/include/migraphx/permutation.hpp
src/include/migraphx/permutation.hpp
+2
-2
src/include/migraphx/program.hpp
src/include/migraphx/program.hpp
+3
-0
src/include/migraphx/raw_data.hpp
src/include/migraphx/raw_data.hpp
+1
-0
src/include/migraphx/replace_allocate.hpp
src/include/migraphx/replace_allocate.hpp
+2
-2
src/include/migraphx/shape.hpp
src/include/migraphx/shape.hpp
+23
-1
src/include/migraphx/source_location.hpp
src/include/migraphx/source_location.hpp
+73
-0
src/include/migraphx/target.hpp
src/include/migraphx/target.hpp
+5
-0
src/include/migraphx/value.hpp
src/include/migraphx/value.hpp
+16
-2
src/instruction.cpp
src/instruction.cpp
+2
-0
src/module.cpp
src/module.cpp
+6
-3
src/onnx/onnx_parser.cpp
src/onnx/onnx_parser.cpp
+82
-12
src/onnx/parse_instancenorm.cpp
src/onnx/parse_instancenorm.cpp
+81
-26
src/onnx/parse_where.cpp
src/onnx/parse_where.cpp
+1
-0
src/pass_manager.cpp
src/pass_manager.cpp
+10
-2
src/program.cpp
src/program.cpp
+138
-118
src/promote_literals.cpp
src/promote_literals.cpp
+1
-1
No files found.
src/include/migraphx/op/convert.hpp
View file @
9c91c08d
...
@@ -66,7 +66,7 @@ struct convert : unary<convert>
...
@@ -66,7 +66,7 @@ struct convert : unary<convert>
auto
type
=
target_type
;
auto
type
=
target_type
;
return
[
type
](
auto
x
)
{
return
[
type
](
auto
x
)
{
auto
y
=
x
;
auto
y
=
x
;
shape
::
visit
(
type
,
[
&
](
auto
as
)
{
y
=
std
::
min
(
std
::
max
(
as
(
x
),
as
.
min
()),
as
.
max
()
);
});
shape
::
visit
(
type
,
[
&
](
auto
as
)
{
y
=
as
(
x
);
});
return
y
;
return
y
;
};
};
}
}
...
...
src/include/migraphx/op/multibroadcast.hpp
View file @
9c91c08d
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -36,9 +36,9 @@ namespace op {
...
@@ -36,9 +36,9 @@ namespace op {
/**
/**
* Broadcast multiple dimensions between two tensors.
* Broadcast multiple dimensions between two tensors.
* Two versions of this operator:
one
input and
two
inputs.
* Two versions of this operator:
1
input and
2+
inputs.
* One input version uses output_lens attribute and broadcasts to it.
* One input version uses output_lens attribute and broadcasts to it.
*
Two
inputs version broadcasts
both
input
s
to the common shape at evaluation time.
*
2+
inputs version broadcasts
first
input to the common shape at evaluation time.
*/
*/
struct
multibroadcast
struct
multibroadcast
{
{
...
@@ -57,12 +57,12 @@ struct multibroadcast
...
@@ -57,12 +57,12 @@ struct multibroadcast
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
,
2
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
_at_least
(
1
);
auto
t
=
inputs
.
at
(
0
).
type
();
auto
t
=
inputs
.
at
(
0
).
type
();
auto
s0
=
inputs
.
at
(
0
);
auto
s0
=
inputs
.
at
(
0
);
if
(
s0
.
max_lens
().
empty
()
)
if
(
s0
.
ndim
()
<
1
)
{
{
MIGRAPHX_THROW
(
"MULTIBROADCAST: input dimensions should be > 0"
);
MIGRAPHX_THROW
(
"MULTIBROADCAST: input dimensions should be > 0"
);
}
}
...
@@ -81,6 +81,9 @@ struct multibroadcast
...
@@ -81,6 +81,9 @@ struct multibroadcast
if
(
inputs
.
size
()
==
1
)
if
(
inputs
.
size
()
==
1
)
{
{
if
(
s0
.
dynamic
())
MIGRAPHX_THROW
(
"MULTIBROADCAST: Single dynamic input shape not supported. Use two inputs."
);
if
(
s0
.
lens
().
size
()
>
output_lens
.
size
())
if
(
s0
.
lens
().
size
()
>
output_lens
.
size
())
{
{
MIGRAPHX_THROW
(
"MULTIBROADCAST: input dimensions should <= output size"
);
MIGRAPHX_THROW
(
"MULTIBROADCAST: input dimensions should <= output size"
);
...
@@ -102,19 +105,20 @@ struct multibroadcast
...
@@ -102,19 +105,20 @@ struct multibroadcast
}
}
else
else
{
{
//
two
inputs
//
2+
inputs
auto
s1
=
inputs
.
at
(
1
);
if
(
std
::
any_of
(
if
(
s0
.
dynamic
()
or
s1
.
dynamic
())
inputs
.
cbegin
(),
inputs
.
cend
(),
[](
auto
input
)
{
return
input
.
dynamic
()
;
})
)
{
{
if
(
not
output_dyn_dims
.
empty
())
if
(
not
output_dyn_dims
.
empty
())
{
{
return
{
t
,
output_dyn_dims
};
return
{
t
,
output_dyn_dims
};
}
}
return
{
t
,
compute_
broadcasted
_dyn_dims
(
s0
,
s1
)};
return
{
t
,
compute_
common
_dyn_dims
(
inputs
)};
}
}
else
else
{
{
auto
bcast_lens
=
compute_broadcasted_lens
(
s0
.
lens
(),
s1
.
lens
());
// output_lens will not be set for 2+ input version
auto
bcast_lens
=
compute_common_lens
(
inputs
);
auto
offset
=
bcast_lens
.
size
()
-
s0
.
lens
().
size
();
auto
offset
=
bcast_lens
.
size
()
-
s0
.
lens
().
size
();
auto
bcast_strides
=
make_bcast_strides
(
bcast_lens
,
offset
);
auto
bcast_strides
=
make_bcast_strides
(
bcast_lens
,
offset
);
return
{
t
,
std
::
move
(
bcast_lens
),
std
::
move
(
bcast_strides
)};
return
{
t
,
std
::
move
(
bcast_lens
),
std
::
move
(
bcast_strides
)};
...
...
src/include/migraphx/op/reshape.hpp
View file @
9c91c08d
...
@@ -29,6 +29,7 @@
...
@@ -29,6 +29,7 @@
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/value.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/optional.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -96,9 +97,115 @@ struct reshape
...
@@ -96,9 +97,115 @@ struct reshape
return
{
s0
.
type
(),
output_dyn_dims
};
return
{
s0
.
type
(),
output_dyn_dims
};
}
}
template
<
class
Iterator
>
static
auto
compute_end_dim
(
Iterator
start
,
Iterator
last
,
std
::
size_t
dim
)
{
std
::
size_t
x
=
1
;
auto
it
=
std
::
find_if
(
start
,
last
,
[
&
](
auto
i
)
{
x
*=
i
;
return
x
>=
dim
;
});
if
(
x
!=
dim
)
return
start
;
return
it
;
}
template
<
class
DimIterator
,
class
StrideIterator
>
static
auto
can_strides_merge
(
DimIterator
dim_start
,
DimIterator
dim_last
,
StrideIterator
stride_start
,
StrideIterator
stride_last
)
{
assert
(
std
::
distance
(
dim_start
,
dim_last
)
==
std
::
distance
(
stride_start
,
stride_last
));
auto
cstride
=
*
std
::
prev
(
stride_last
);
return
std
::
equal
(
std
::
make_reverse_iterator
(
dim_last
),
std
::
make_reverse_iterator
(
dim_start
+
1
),
std
::
make_reverse_iterator
(
stride_last
-
1
),
std
::
make_reverse_iterator
(
stride_start
),
[
&
](
auto
dim
,
auto
stride
)
{
cstride
*=
dim
;
return
stride
==
cstride
;
});
}
// This will reshape the dimesions of the input shape to use the lens of
// `rdims`. If this can't be done without changing memory layout then it
// will return nullopt
static
optional
<
shape
>
reshape_dims
(
const
shape
&
input
,
const
std
::
vector
<
std
::
size_t
>&
rdims
)
{
if
(
input
.
standard
())
return
shape
{
input
.
type
(),
rdims
};
const
auto
&
idims
=
input
.
lens
();
const
auto
&
istrides
=
input
.
strides
();
std
::
vector
<
std
::
size_t
>
rstrides
;
std
::
size_t
i
=
0
;
std
::
size_t
r
=
0
;
while
(
i
<
idims
.
size
()
and
r
<
rdims
.
size
())
{
auto
idim
=
idims
[
i
];
auto
rdim
=
rdims
[
r
];
if
(
rdim
==
idim
)
{
rstrides
.
push_back
(
istrides
[
i
]);
}
// squeeze
else
if
(
rdim
>
idim
)
{
auto
start
=
idims
.
begin
()
+
i
;
auto
it
=
compute_end_dim
(
start
,
idims
.
end
(),
rdim
);
if
(
it
==
start
)
return
nullopt
;
auto
n
=
it
-
start
;
assert
((
i
+
n
)
<=
istrides
.
size
());
if
(
not
can_strides_merge
(
start
,
it
+
1
,
istrides
.
begin
()
+
i
,
istrides
.
begin
()
+
i
+
n
+
1
))
return
nullopt
;
i
+=
n
;
rstrides
.
push_back
(
istrides
[
i
]);
}
// unsqueeze
else
// if(rdim < idim)
{
auto
start
=
rdims
.
begin
()
+
i
;
auto
it
=
compute_end_dim
(
start
,
rdims
.
end
(),
idim
);
if
(
it
==
start
)
return
nullopt
;
auto
n
=
it
-
start
;
assert
((
r
+
n
)
<=
rdims
.
size
());
auto
stride
=
istrides
[
i
]
*
idim
;
std
::
for_each
(
start
,
it
+
1
,
[
&
](
auto
dim
)
{
stride
/=
dim
;
rstrides
.
push_back
(
stride
);
});
r
+=
n
;
}
i
++
;
r
++
;
}
// Handle trailing 1s
if
(
rstrides
.
size
()
<
rdims
.
size
()
and
not
rstrides
.
empty
())
{
auto
stride
=
rstrides
.
back
();
for
(
auto
d
:
range
(
rdims
.
begin
()
+
rstrides
.
size
(),
rdims
.
end
()))
{
if
(
d
!=
1
)
return
nullopt
;
rstrides
.
push_back
(
stride
);
}
}
if
(
rdims
.
size
()
!=
rstrides
.
size
())
return
nullopt
;
return
shape
{
input
.
type
(),
rdims
,
rstrides
};
}
shape
static_compute_shape
(
std
::
vector
<
shape
>
inputs
,
std
::
size_t
n_neg_dims
)
const
shape
static_compute_shape
(
std
::
vector
<
shape
>
inputs
,
std
::
size_t
n_neg_dims
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
standard
(
);
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
auto
&&
idims
=
inputs
.
front
().
lens
();
auto
&&
idims
=
inputs
.
front
().
lens
();
std
::
vector
<
std
::
size_t
>
rdims
(
dims
.
begin
(),
dims
.
end
());
std
::
vector
<
std
::
size_t
>
rdims
(
dims
.
begin
(),
dims
.
end
());
...
@@ -125,12 +232,17 @@ struct reshape
...
@@ -125,12 +232,17 @@ struct reshape
}
}
}
}
shape
s
{
inputs
.
front
().
type
(),
rdims
};
auto
s
=
reshape_dims
(
inputs
.
front
(),
rdims
);
if
(
s
.
elements
()
!=
inputs
.
front
().
elements
())
if
(
not
s
.
has_value
())
MIGRAPHX_THROW
(
"Reshape on axis that is not packed."
);
if
(
s
->
elements
()
!=
inputs
.
front
().
elements
())
MIGRAPHX_THROW
(
"Reshape: Wrong number of elements for reshape: reshape has "
+
MIGRAPHX_THROW
(
"Reshape: Wrong number of elements for reshape: reshape has "
+
std
::
to_string
(
s
.
elements
())
+
" elements whereas the input has "
+
std
::
to_string
(
s
->
elements
())
+
" elements whereas the input has "
+
std
::
to_string
(
inputs
.
front
().
elements
()));
std
::
to_string
(
inputs
.
front
().
elements
()));
return
s
;
assert
(
s
->
bytes
()
==
inputs
.
front
().
bytes
());
return
*
s
;
}
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
...
...
src/include/migraphx/operation.hpp
View file @
9c91c08d
...
@@ -261,11 +261,13 @@ auto compute_op(rank<1>,
...
@@ -261,11 +261,13 @@ auto compute_op(rank<1>,
template
<
class
T
,
class
F
>
template
<
class
T
,
class
F
>
argument
compute_op
(
rank
<
0
>
,
argument
compute_op
(
rank
<
0
>
,
const
T
&
x
,
const
T
&
x
,
const
shape
&
,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
module_ref
>&
,
const
std
::
vector
<
module_ref
>&
module_args
,
F
)
F
)
{
{
if
(
module_args
.
empty
())
return
compute_op
(
x
,
output
,
inputs
);
std
::
string
name
=
x
.
name
();
std
::
string
name
=
x
.
name
();
MIGRAPHX_THROW
(
"Not computable: "
+
name
);
MIGRAPHX_THROW
(
"Not computable: "
+
name
);
}
}
...
...
src/include/migraphx/permutation.hpp
View file @
9c91c08d
...
@@ -56,12 +56,12 @@ inline std::vector<int64_t> sort_permutation(const Vector& data, Op op)
...
@@ -56,12 +56,12 @@ inline std::vector<int64_t> sort_permutation(const Vector& data, Op op)
}
}
/*!
/*!
* Returns the permutation
needed to apply to the shape
to undo the
current
permutation
* Returns the
inverse
permutation
that could be applied
to undo the
inputted
permutation
*/
*/
std
::
vector
<
int64_t
>
invert_permutation
(
const
std
::
vector
<
int64_t
>&
permutation
);
std
::
vector
<
int64_t
>
invert_permutation
(
const
std
::
vector
<
int64_t
>&
permutation
);
/*!
/*!
* Finds the permutation
most likely from a transpose operator that has been applied to the shape.
* Finds the permutation
that would make the shape not transposed (refering to shape.transposed())
*/
*/
std
::
vector
<
int64_t
>
find_permutation
(
const
shape
&
s
);
std
::
vector
<
int64_t
>
find_permutation
(
const
shape
&
s
);
std
::
vector
<
int64_t
>
find_permutation
(
const
std
::
vector
<
shape
>&
shapes
);
std
::
vector
<
int64_t
>
find_permutation
(
const
std
::
vector
<
shape
>&
shapes
);
...
...
src/include/migraphx/program.hpp
View file @
9c91c08d
...
@@ -79,6 +79,9 @@ struct program
...
@@ -79,6 +79,9 @@ struct program
std
::
vector
<
argument
>
eval
(
parameter_map
params
,
std
::
vector
<
argument
>
eval
(
parameter_map
params
,
execution_environment
exec_env
=
execution_environment
{})
const
;
execution_environment
exec_env
=
execution_environment
{})
const
;
void
finish
()
const
;
std
::
size_t
size
()
const
;
std
::
size_t
size
()
const
;
std
::
vector
<
shape
>
get_output_shapes
()
const
;
std
::
vector
<
shape
>
get_output_shapes
()
const
;
...
...
src/include/migraphx/raw_data.hpp
View file @
9c91c08d
...
@@ -187,6 +187,7 @@ struct raw_data : raw_data_base
...
@@ -187,6 +187,7 @@ struct raw_data : raw_data_base
std
::
string
to_string
()
const
std
::
string
to_string
()
const
{
{
std
::
stringstream
ss
;
std
::
stringstream
ss
;
ss
.
precision
(
std
::
numeric_limits
<
double
>::
max_digits10
);
ss
<<
static_cast
<
const
Derived
&>
(
*
this
);
ss
<<
static_cast
<
const
Derived
&>
(
*
this
);
return
ss
.
str
();
return
ss
.
str
();
}
}
...
...
src/include/migraphx/replace_allocate.hpp
View file @
9c91c08d
...
@@ -30,7 +30,7 @@
...
@@ -30,7 +30,7 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
module
;
struct
module
_pass_manager
;
/**
/**
* Replace `allocate` instructions with target allocations or output parameters.
* Replace `allocate` instructions with target allocations or output parameters.
...
@@ -40,7 +40,7 @@ struct replace_allocate
...
@@ -40,7 +40,7 @@ struct replace_allocate
allocation_model
model
;
allocation_model
model
;
bool
offload_copy
=
false
;
bool
offload_copy
=
false
;
std
::
string
name
()
const
{
return
"replace_allocate"
;
}
std
::
string
name
()
const
{
return
"replace_allocate"
;
}
void
apply
(
module
&
m
)
const
;
void
apply
(
module
_pass_manager
&
mp
m
)
const
;
};
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/shape.hpp
View file @
9c91c08d
...
@@ -156,14 +156,34 @@ struct shape
...
@@ -156,14 +156,34 @@ struct shape
shape
(
const
std
::
vector
<
shape
>&
subs
);
shape
(
const
std
::
vector
<
shape
>&
subs
);
/**
* Creates an output shape with dimensions equal to the input lengths and strides determined
* by the permutation argument such that find_permutation() of the output shape returns the
* inputted permuation.
*
* 2D example:
* parameters:
* l = [2, 3], perm = [1, 0]
* therefore:
* "original" shape = {lens = [3, 2], strides = [2, 1]}
* output_shape = {lens = [2, 3], strides = [1, 2]
*
* 3D example:
* parameters:
* l = [2, 3, 4], perm = [1, 2, 0]
* therefore:
* "original" shape = {lens = [3, 4, 2], strides = [8, 2, 1]}
* output_shape = {lens = [2, 3, 4], strides = [1, 8, 2]}
*/
static
shape
static
shape
from_permutation
(
type_t
t
,
const
std
::
vector
<
std
::
size_t
>&
l
,
const
std
::
vector
<
int64_t
>&
perm
);
from_permutation
(
type_t
t
,
const
std
::
vector
<
std
::
size_t
>&
l
,
const
std
::
vector
<
int64_t
>&
perm
);
type_t
type
()
const
;
type_t
type
()
const
;
const
std
::
vector
<
std
::
size_t
>&
lens
()
const
;
const
std
::
vector
<
std
::
size_t
>&
lens
()
const
;
const
std
::
vector
<
std
::
size_t
>&
strides
()
const
;
const
std
::
vector
<
std
::
size_t
>&
strides
()
const
;
/*!
/*!
* The number of dimensions in the shape.
* The number of dimensions in the shape
, either static or dynamic
.
* Same as the number of indices required to get a data value.
* Same as the number of indices required to get a data value.
*/
*/
std
::
size_t
ndim
()
const
;
std
::
size_t
ndim
()
const
;
...
@@ -279,6 +299,8 @@ struct shape
...
@@ -279,6 +299,8 @@ struct shape
type
min
()
const
{
return
std
::
numeric_limits
<
type
>::
lowest
();
}
type
min
()
const
{
return
std
::
numeric_limits
<
type
>::
lowest
();
}
type
nan
()
const
{
return
std
::
numeric_limits
<
type
>::
quiet_NaN
();
}
template
<
class
U
>
template
<
class
U
>
type
operator
()(
U
u
)
const
type
operator
()(
U
u
)
const
{
{
...
...
src/include/migraphx/source_location.hpp
0 → 100644
View file @
9c91c08d
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_SOURCE_LOCATION_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_SOURCE_LOCATION_HPP
#include <migraphx/config.hpp>
#if defined(CPPCHECK)
#define MIGRAPHX_HAS_SOURCE_LOCATION 1
#define MIGRAPHX_HAS_SOURCE_LOCATION_TS 1
#elif defined(__has_include)
#if __has_include(<source_location>) && __cplusplus >= 202003L
#define MIGRAPHX_HAS_SOURCE_LOCATION 1
#else
#define MIGRAPHX_HAS_SOURCE_LOCATION 0
#endif
#if __has_include(<experimental/source_location>) && __cplusplus >= 201103L
#define MIGRAPHX_HAS_SOURCE_LOCATION_TS 1
#else
#define MIGRAPHX_HAS_SOURCE_LOCATION_TS 0
#endif
#else
#define MIGRAPHX_HAS_SOURCE_LOCATION 0
#define MIGRAPHX_HAS_SOURCE_LOCATION_TS 0
#endif
#if MIGRAPHX_HAS_SOURCE_LOCATION
#include <source_location>
#elif MIGRAPHX_HAS_SOURCE_LOCATION_TS
#include <experimental/source_location>
#endif
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
#if MIGRAPHX_HAS_SOURCE_LOCATION
using
source_location
=
std
::
source_location
;
#elif MIGRAPHX_HAS_SOURCE_LOCATION_TS
using
source_location
=
std
::
experimental
::
source_location
;
#else
struct
source_location
{
static
constexpr
source_location
current
()
noexcept
{
return
source_location
{};
}
constexpr
std
::
uint_least32_t
line
()
const
noexcept
{
return
0
;
}
constexpr
std
::
uint_least32_t
column
()
const
noexcept
{
return
0
;
}
constexpr
const
char
*
file_name
()
const
noexcept
{
return
""
;
}
constexpr
const
char
*
function_name
()
const
noexcept
{
return
""
;
}
};
#endif
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_SOURCE_LOCATION_HPP
src/include/migraphx/target.hpp
View file @
9c91c08d
...
@@ -45,6 +45,8 @@
...
@@ -45,6 +45,8 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
value
;
#ifdef DOXYGEN
#ifdef DOXYGEN
/// An interface for a compilation target
/// An interface for a compilation target
...
@@ -467,6 +469,9 @@ inline const ValueType& any_cast(const target& x)
...
@@ -467,6 +469,9 @@ inline const ValueType& any_cast(const target& x)
#endif
#endif
void
migraphx_to_value
(
value
&
v
,
const
target
&
t
);
void
migraphx_from_value
(
const
value
&
v
,
target
&
t
);
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/include/migraphx/value.hpp
View file @
9c91c08d
...
@@ -32,6 +32,7 @@
...
@@ -32,6 +32,7 @@
#include <algorithm>
#include <algorithm>
#include <cassert>
#include <cassert>
#include <memory>
#include <memory>
#include <cstdint>
#include <sstream>
#include <sstream>
#include <type_traits>
#include <type_traits>
#include <tuple>
#include <tuple>
...
@@ -392,8 +393,8 @@ struct value
...
@@ -392,8 +393,8 @@ struct value
return; \
return; \
}
}
MIGRAPHX_VISIT_VALUE_TYPES
(
MIGRAPHX_VALUE_GENERATE_CASE_VALUE
)
MIGRAPHX_VISIT_VALUE_TYPES
(
MIGRAPHX_VALUE_GENERATE_CASE_VALUE
)
MIGRAPHX_VALUE_GENERATE_CASE
(
array
,
)
MIGRAPHX_VALUE_GENERATE_CASE
_VALUE
(
array
,
)
MIGRAPHX_VALUE_GENERATE_CASE
(
object
,
)
MIGRAPHX_VALUE_GENERATE_CASE
_VALUE
(
object
,
)
}
}
MIGRAPHX_THROW
(
"Unknown type"
);
MIGRAPHX_THROW
(
"Unknown type"
);
}
}
...
@@ -461,6 +462,8 @@ struct value
...
@@ -461,6 +462,8 @@ struct value
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
value
&
d
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
value
&
d
);
std
::
size_t
hash
()
const
;
void
debug_print
(
bool
show_type
=
false
)
const
;
void
debug_print
(
bool
show_type
=
false
)
const
;
type_t
get_type
()
const
;
type_t
get_type
()
const
;
...
@@ -481,4 +484,15 @@ struct value
...
@@ -481,4 +484,15 @@ struct value
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
namespace
std
{
template
<
>
struct
hash
<
migraphx
::
value
>
{
using
argument_type
=
migraphx
::
value
;
using
result_type
=
std
::
size_t
;
result_type
operator
()(
const
migraphx
::
value
&
x
)
const
{
return
x
.
hash
();
}
};
}
// namespace std
#endif
#endif
src/instruction.cpp
View file @
9c91c08d
...
@@ -473,7 +473,9 @@ operation instruction::normalized_operator() const
...
@@ -473,7 +473,9 @@ operation instruction::normalized_operator() const
return
o
;
return
o
;
}
}
std
::
size_t
instruction
::
get_target_id
()
const
{
return
target_id
;
}
std
::
size_t
instruction
::
get_target_id
()
const
{
return
target_id
;
}
void
instruction
::
set_target_id
(
std
::
size_t
tid
)
{
this
->
target_id
=
tid
;
}
void
instruction
::
set_target_id
(
std
::
size_t
tid
)
{
this
->
target_id
=
tid
;
}
std
::
vector
<
shape
>
to_shapes
(
const
std
::
vector
<
instruction_ref
>&
args
)
std
::
vector
<
shape
>
to_shapes
(
const
std
::
vector
<
instruction_ref
>&
args
)
{
{
std
::
vector
<
shape
>
shapes
(
args
.
size
());
std
::
vector
<
shape
>
shapes
(
args
.
size
());
...
...
src/module.cpp
View file @
9c91c08d
...
@@ -326,6 +326,8 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref
...
@@ -326,6 +326,8 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref
if
(
ins
==
std
::
prev
(
this
->
end
()))
if
(
ins
==
std
::
prev
(
this
->
end
()))
{
{
// "rep" instruction could be used earlier in the program and moving it at the end
// may cause invalid program, therefore make an identity operation in this case.
return
replace_instruction
(
ins
,
make_op
(
"identity"
),
rep
);
return
replace_instruction
(
ins
,
make_op
(
"identity"
),
rep
);
}
}
...
@@ -650,8 +652,9 @@ instruction_ref module::find_dangling_reference() const
...
@@ -650,8 +652,9 @@ instruction_ref module::find_dangling_reference() const
return
end
();
return
end
();
}
}
void
module
::
finalize
(
context
&
c
tx
)
void
module
::
finalize
(
std
::
vector
<
context
>
&
c
ontexts
)
{
{
assert
(
not
contexts
.
empty
());
const
bool
trace
=
enabled
(
MIGRAPHX_TRACE_FINALIZE
{});
const
bool
trace
=
enabled
(
MIGRAPHX_TRACE_FINALIZE
{});
for
(
auto
ins
:
iterator_for
(
*
this
))
for
(
auto
ins
:
iterator_for
(
*
this
))
{
{
...
@@ -660,10 +663,10 @@ void module::finalize(context& ctx)
...
@@ -660,10 +663,10 @@ void module::finalize(context& ctx)
std
::
cout
<<
"Finalize: "
;
std
::
cout
<<
"Finalize: "
;
this
->
debug_print
(
ins
);
this
->
debug_print
(
ins
);
}
}
ins
->
finalize
(
c
tx
);
ins
->
finalize
(
c
ontexts
[
ins
->
get_target_id
()]
);
for
(
const
auto
&
smod
:
ins
->
module_inputs
())
for
(
const
auto
&
smod
:
ins
->
module_inputs
())
{
{
smod
->
finalize
(
c
tx
);
smod
->
finalize
(
c
ontexts
);
}
}
}
}
...
...
src/onnx/onnx_parser.cpp
View file @
9c91c08d
...
@@ -38,6 +38,9 @@
...
@@ -38,6 +38,9 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
onnx
{
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_ONNX_PARSER
)
static
shape
shape_from_dyn_dims
(
shape
::
type_t
shape_type
,
static
shape
shape_from_dyn_dims
(
shape
::
type_t
shape_type
,
const
std
::
vector
<
shape
::
dynamic_dimension
>&
dyn_dims
)
const
std
::
vector
<
shape
::
dynamic_dimension
>&
dyn_dims
)
...
@@ -53,8 +56,6 @@ static shape shape_from_dyn_dims(shape::type_t shape_type,
...
@@ -53,8 +56,6 @@ static shape shape_from_dyn_dims(shape::type_t shape_type,
return
{
shape_type
,
dyn_dims
};
return
{
shape_type
,
dyn_dims
};
}
}
namespace
onnx
{
static
onnx_parser
::
attribute_map
get_attributes
(
const
onnx
::
NodeProto
&
node
)
static
onnx_parser
::
attribute_map
get_attributes
(
const
onnx
::
NodeProto
&
node
)
{
{
std
::
unordered_map
<
std
::
string
,
onnx
::
AttributeProto
>
result
;
std
::
unordered_map
<
std
::
string
,
onnx
::
AttributeProto
>
result
;
...
@@ -149,6 +150,25 @@ instruction_ref onnx_parser::node_info::add_broadcastable_binary_op(const std::s
...
@@ -149,6 +150,25 @@ instruction_ref onnx_parser::node_info::add_broadcastable_binary_op(const std::s
return
this
->
add_common_op
(
op_name
,
arg0
,
arg1
);
return
this
->
add_common_op
(
op_name
,
arg0
,
arg1
);
}
}
/**
* @brief A wrapper for insert_common_args(), which constructs an argument list
* and inserts multibroadcast and convert ops to match inputs to a common shape and type
* as required. The requested operation is placed after the added multibroadcast and convert ops,
* if any, so that their results are transparent to the programmer.
*
* Use add_common_op() to match input sizes when inputs may be
* either static or dynamic.
*
* @param op_name string; Name of operation (op) to add; valid names are the same as
* for make_op()
*
* @param inputs vector of instruction_ref. List of instructions for the new
* operator. Multibroadcast and convert operations, if needed, are deduced from these too.
*
* @return instruction_ref Returns an instruction_ref which is the result of the requested
* operation.
*
*/
instruction_ref
onnx_parser
::
node_info
::
add_common_op
(
const
std
::
string
&
op_name
,
instruction_ref
onnx_parser
::
node_info
::
add_common_op
(
const
std
::
string
&
op_name
,
std
::
vector
<
instruction_ref
>
inputs
)
const
std
::
vector
<
instruction_ref
>
inputs
)
const
{
{
...
@@ -278,16 +298,48 @@ int64_t onnx_parser::get_opset_version(const onnx::ModelProto& model)
...
@@ -278,16 +298,48 @@ int64_t onnx_parser::get_opset_version(const onnx::ModelProto& model)
return
version
;
return
version
;
}
}
std
::
vector
<
instruction_ref
>
void
print_added_instructions
(
module
*
mod
,
onnx_parser
::
parse_graph
(
module
*
mod
,
const
onnx
::
GraphProto
&
graph
,
bool
inlining
)
const
std
::
vector
<
instruction_ref
>&
args
,
const
std
::
vector
<
instruction_ref
>&
result
)
{
// Print instructions added by the parser not in args
std
::
vector
<
instruction_ref
>
added_instructions
;
fix
([
&
](
auto
self
,
auto
r
)
{
for
(
auto
ins
:
r
)
{
if
(
contains
(
args
,
ins
))
continue
;
if
(
contains
(
added_instructions
,
ins
))
continue
;
self
(
ins
->
inputs
());
added_instructions
.
push_back
(
ins
);
}
})(
result
);
mod
->
debug_print
(
added_instructions
);
}
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
parse_intializer
(
const
onnx_parser
&
parser
,
module
*
mod
,
const
onnx
::
GraphProto
&
graph
)
{
{
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
mod_insts
;
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
mod_insts
;
for
(
auto
&&
f
:
graph
.
initializer
())
for
(
auto
&&
f
:
graph
.
initializer
())
{
{
if
(
enabled
(
MIGRAPHX_TRACE_ONNX_PARSER
{}))
std
::
cout
<<
"initializer: "
<<
f
.
name
()
<<
std
::
endl
;
// backup instructions in parent mod
// backup instructions in parent mod
mod_insts
[
f
.
name
()]
=
mod
->
add_literal
(
parse_tensor
(
f
));
mod_insts
[
f
.
name
()]
=
mod
->
add_literal
(
parser
.
parse_tensor
(
f
));
if
(
enabled
(
MIGRAPHX_TRACE_ONNX_PARSER
{}))
mod
->
debug_print
(
mod_insts
[
f
.
name
()]);
}
}
return
mod_insts
;
}
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
parse_inputs
(
const
onnx_parser
&
parser
,
module
*
mod
,
const
onnx
::
GraphProto
&
graph
,
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
mod_insts
)
{
for
(
auto
&&
input
:
graph
.
input
())
for
(
auto
&&
input
:
graph
.
input
())
{
{
const
std
::
string
&
name
=
input
.
name
();
const
std
::
string
&
name
=
input
.
name
();
...
@@ -298,7 +350,7 @@ onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlini
...
@@ -298,7 +350,7 @@ onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlini
// scenario that a nested subgraph contains a parameter with the
// scenario that a nested subgraph contains a parameter with the
// name existed in its parent graph.
// name existed in its parent graph.
// In the current implementation, MIGraphX throws an exception for that.
// In the current implementation, MIGraphX throws an exception for that.
if
(
contains
(
instructions
,
name
))
if
(
contains
(
parser
.
instructions
,
name
))
{
{
MIGRAPHX_THROW
(
"module
\"
"
+
mod
->
name
()
+
"
\"
has parameter name
\"
"
+
name
+
MIGRAPHX_THROW
(
"module
\"
"
+
mod
->
name
()
+
"
\"
has parameter name
\"
"
+
name
+
"
\"
existing in parent graph!"
);
"
\"
existing in parent graph!"
);
...
@@ -306,28 +358,41 @@ onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlini
...
@@ -306,28 +358,41 @@ onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlini
shape
s
;
shape
s
;
std
::
vector
<
std
::
size_t
>
dims
;
std
::
vector
<
std
::
size_t
>
dims
;
if
(
map_input_dims
.
count
(
name
)
>
0
)
if
(
parser
.
map_input_dims
.
count
(
name
)
>
0
)
{
{
dims
=
map_input_dims
.
at
(
name
);
dims
=
parser
.
map_input_dims
.
at
(
name
);
s
=
parse_type
(
input
.
type
(),
dims
);
s
=
parser
.
parse_type
(
input
.
type
(),
dims
);
}
}
else
if
(
map_dyn_input_dims
.
count
(
name
)
>
0
)
else
if
(
parser
.
map_dyn_input_dims
.
count
(
name
)
>
0
)
{
{
shape
::
type_t
shape_type
=
get_type
(
input
.
type
().
tensor_type
().
elem_type
());
shape
::
type_t
shape_type
=
get_type
(
input
.
type
().
tensor_type
().
elem_type
());
s
=
shape_from_dyn_dims
(
shape_type
,
map_dyn_input_dims
.
at
(
name
));
s
=
shape_from_dyn_dims
(
shape_type
,
parser
.
map_dyn_input_dims
.
at
(
name
));
}
}
else
else
{
{
s
=
parse_type
(
input
.
type
(),
dims
);
s
=
parser
.
parse_type
(
input
.
type
(),
dims
);
}
}
mod_insts
[
name
]
=
mod
->
add_parameter
(
name
,
s
);
mod_insts
[
name
]
=
mod
->
add_parameter
(
name
,
s
);
}
}
}
}
return
mod_insts
;
}
std
::
vector
<
instruction_ref
>
onnx_parser
::
parse_graph
(
module
*
mod
,
const
onnx
::
GraphProto
&
graph
,
bool
inlining
)
{
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
mod_insts
=
parse_intializer
(
*
this
,
mod
,
graph
);
mod_insts
=
parse_inputs
(
*
this
,
mod
,
graph
,
mod_insts
);
std
::
copy
(
mod_insts
.
begin
(),
mod_insts
.
end
(),
std
::
inserter
(
instructions
,
instructions
.
end
()));
std
::
copy
(
mod_insts
.
begin
(),
mod_insts
.
end
(),
std
::
inserter
(
instructions
,
instructions
.
end
()));
for
(
auto
&&
node
:
graph
.
node
())
for
(
auto
&&
node
:
graph
.
node
())
{
{
if
(
enabled
(
MIGRAPHX_TRACE_ONNX_PARSER
{}))
std
::
cout
<<
"operator: "
<<
node
.
op_type
()
<<
std
::
endl
;
std
::
vector
<
instruction_ref
>
args
;
std
::
vector
<
instruction_ref
>
args
;
for
(
auto
&&
input
:
node
.
input
())
for
(
auto
&&
input
:
node
.
input
())
{
{
...
@@ -365,6 +430,11 @@ onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlini
...
@@ -365,6 +430,11 @@ onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlini
result
.
begin
(),
result
.
begin
(),
std
::
inserter
(
instructions
,
instructions
.
end
()),
std
::
inserter
(
instructions
,
instructions
.
end
()),
[](
auto
&&
x
,
auto
&&
y
)
{
return
std
::
make_pair
(
x
,
y
);
});
[](
auto
&&
x
,
auto
&&
y
)
{
return
std
::
make_pair
(
x
,
y
);
});
if
(
enabled
(
MIGRAPHX_TRACE_ONNX_PARSER
{}))
{
print_added_instructions
(
mod
,
args
,
result
);
}
}
}
// Find instructions corresponding to the output
// Find instructions corresponding to the output
...
...
src/onnx/parse_instancenorm.cpp
View file @
9c91c08d
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -21,10 +21,14 @@
...
@@ -21,10 +21,14 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
* THE SOFTWARE.
*/
*/
#include <iterator>
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/env.hpp>
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_DISABLE_FP16_INSTANCENORM_CONVERT
);
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -39,54 +43,105 @@ struct parse_instancenorm : op_parser<parse_instancenorm>
...
@@ -39,54 +43,105 @@ struct parse_instancenorm : op_parser<parse_instancenorm>
instruction_ref
parse
(
const
op_desc
&
opd
,
instruction_ref
parse
(
const
op_desc
&
opd
,
const
onnx_parser
&
parser
,
const
onnx_parser
&
parser
,
onnx_parser
::
node_info
info
,
onnx_parser
::
node_info
info
,
std
::
vector
<
instruction_ref
>
args
)
const
std
::
vector
<
instruction_ref
>
o
args
)
const
{
{
// y = scale * ( x - mean ) / sqrt ( variance + epsilon ) + bias
// y = scale * ( x - mean ) / sqrt ( variance + epsilon ) + bias
// mean = reduce_mean({D1, D2, ... Dk}, x)
// mean = reduce_mean({D1, D2, ... Dk}, x)
// variance = reduce_mean({D1, D2, ... Dk}, (x - mean)^2)
// variance = reduce_mean({D1, D2, ... Dk}, (x - mean)^2)
// Convert fp16 to fp32 to workaround for FP16 accuracy issues with reduce_mean/variance.
bool
convert_fp16
=
true
;
if
(
enabled
(
MIGRAPHX_DISABLE_FP16_INSTANCENORM_CONVERT
{}))
{
convert_fp16
=
false
;
}
float
epsilon
=
1e-5
f
;
float
epsilon
=
1e-5
f
;
if
(
contains
(
info
.
attributes
,
"epsilon"
))
if
(
contains
(
info
.
attributes
,
"epsilon"
))
{
{
epsilon
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"epsilon"
)).
at
<
float
>
();
epsilon
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"epsilon"
)).
at
<
float
>
();
}
}
auto
dtype
=
oargs
[
0
]
->
get_shape
().
type
();
auto
literal_dtype
=
dtype
;
std
::
vector
<
instruction_ref
>
args
;
// cppcheck-suppress knownConditionTrueFalse
if
(
dtype
==
shape
::
half_type
and
convert_fp16
)
{
std
::
transform
(
oargs
.
begin
(),
oargs
.
end
(),
std
::
back_inserter
(
args
),
[
&
](
const
auto
i
)
{
return
info
.
add_instruction
(
make_op
(
"convert"
,
{{
"target_type"
,
shape
::
float_type
}}),
i
);
});
literal_dtype
=
shape
::
float_type
;
}
else
{
args
=
oargs
;
}
auto
x
=
args
[
0
];
auto
x
=
args
[
0
];
auto
scale
=
args
[
1
];
auto
scale
=
args
[
1
];
auto
bias
=
args
[
2
];
auto
bias
=
args
[
2
];
auto
dims
=
x
->
get_shape
().
lens
();
auto
dims
=
x
->
get_shape
().
lens
();
auto
dtype
=
x
->
get_shape
().
type
();
if
(
not
contains
(
valid_types
,
dtype
))
if
(
not
contains
(
valid_types
,
dtype
))
MIGRAPHX_THROW
(
opd
.
op_name
+
": invalid output type: "
+
std
::
to_string
(
dtype
)
+
MIGRAPHX_THROW
(
opd
.
op_name
+
": invalid output type: "
+
std
::
to_string
(
dtype
)
+
". Valid types are 1 (float), 10 (half), and 11 (double)."
);
". Valid types are 1 (float), 10 (half), and 11 (double)."
);
auto
ndims
=
dims
.
size
();
bool
dyn_input
=
x
->
get_shape
().
dynamic
();
auto
ndims
=
x
->
get_shape
().
ndim
();
assert
(
ndims
>=
2
);
assert
(
ndims
>=
2
);
auto
kdims
=
ndims
-
2
;
auto
kdims
=
ndims
-
2
;
std
::
vector
<
int64_t
>
axes
(
kdims
);
std
::
vector
<
int64_t
>
axes
(
kdims
);
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
2
);
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
2
);
auto
mean
=
info
.
add_instruction
(
make_op
(
"reduce_mean"
,
{{
"axes"
,
axes
}}),
x
);
auto
mean
=
info
.
add_instruction
(
make_op
(
"reduce_mean"
,
{{
"axes"
,
axes
}}),
x
);
auto
mean_bcast
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
dims
}}),
mean
);
// Use add_common_op() to insert multibroadcast/convert instructions where needed when
auto
l0
=
info
.
add_instruction
(
make_op
(
"sqdiff"
),
x
,
mean_bcast
);
// inputs may be either static or dynamic.
auto
variance
=
info
.
add_instruction
(
make_op
(
"reduce_mean"
,
{{
"axes"
,
axes
}}),
l0
);
auto
l1
=
info
.
add_common_op
(
"sub"
,
x
,
mean
);
auto
l1
=
info
.
add_instruction
(
make_op
(
"sub"
),
x
,
mean_bcast
);
// for the fp16, if not converting to fp32 then divide `x` and `mean` by `sqrt(n)` and take
auto
epsilon_literal
=
info
.
add_literal
(
literal
{
shape
{
dtype
},
{
epsilon
}});
// reduce_sum to calculate variance i.e.
auto
epsilon_bcast
=
// var = reduce_sum((x/s_n - mean/s_n)^2) where s_n = sqrt(n)
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
dims
}}),
epsilon_literal
);
std
::
string
reduce_op_name
=
auto
variance_bcast
=
(
dtype
==
shape
::
half_type
and
not
convert_fp16
)
?
"reduce_sum"
:
"reduce_mean"
;
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
dims
}}),
variance
);
if
(
dtype
==
shape
::
half_type
and
not
convert_fp16
)
auto
l2
=
info
.
add_instruction
(
make_op
(
"add"
),
variance_bcast
,
epsilon_bcast
);
{
double
n
=
std
::
accumulate
(
dims
.
begin
()
+
2
,
dims
.
end
(),
1
,
[
&
](
const
auto
&
i
,
const
auto
&
j
)
{
return
i
*
j
;
});
n
=
1.0
/
std
::
sqrt
(
n
);
auto
n_literal
=
info
.
add_literal
(
literal
{
dtype
,
{
n
}});
x
=
info
.
add_common_op
(
"mul"
,
{
x
,
n_literal
});
}
auto
l0
=
info
.
add_common_op
(
"sqdiff"
,
x
,
mean
);
auto
variance
=
info
.
add_instruction
(
make_op
(
reduce_op_name
,
{{
"axes"
,
axes
}}),
l0
);
auto
epsilon_literal
=
info
.
add_literal
(
literal
{
shape
{
literal_dtype
},
{
epsilon
}});
auto
l2
=
info
.
add_common_op
(
"add"
,
variance
,
epsilon_literal
);
auto
l3
=
info
.
add_instruction
(
make_op
(
"rsqrt"
),
l2
);
auto
l3
=
info
.
add_instruction
(
make_op
(
"rsqrt"
),
l2
);
auto
l4
=
info
.
add_instruction
(
make_op
(
"mul"
),
l1
,
l3
);
auto
l4
=
info
.
add_common_op
(
"mul"
,
l1
,
l3
);
auto
scale_bcast
=
info
.
add_instruction
(
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
dims
}}),
scale
);
// add_common_op() doesn't apply the plain broadcast op, so we add that op explicitly for
;
// both scale and bias.
auto
bias_bcast
=
instruction_ref
scale_bcast
;
info
.
add_instruction
(
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
dims
}}),
bias
);
instruction_ref
bias_bcast
;
auto
l5
=
info
.
add_instruction
(
make_op
(
"mul"
),
l4
,
scale_bcast
);
if
(
dyn_input
)
return
info
.
add_instruction
(
make_op
(
"add"
),
l5
,
bias_bcast
);
{
scale_bcast
=
info
.
add_instruction
(
make_op
(
"broadcast"
,
{{
"axis"
,
1
}}),
scale
,
x
);
bias_bcast
=
info
.
add_instruction
(
make_op
(
"broadcast"
,
{{
"axis"
,
1
}}),
bias
,
x
);
}
else
{
scale_bcast
=
info
.
add_instruction
(
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
dims
}}),
scale
);
bias_bcast
=
info
.
add_instruction
(
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
dims
}}),
bias
);
}
auto
l5
=
info
.
add_instruction
(
make_op
(
"mul"
),
l4
,
scale_bcast
);
auto
ret
=
info
.
add_instruction
(
make_op
(
"add"
),
l5
,
bias_bcast
);
if
(
dtype
==
shape
::
half_type
and
convert_fp16
)
{
return
info
.
add_instruction
(
make_op
(
"convert"
,
{{
"target_type"
,
shape
::
half_type
}}),
ret
);
}
return
ret
;
}
}
};
};
...
...
src/onnx/parse_where.cpp
View file @
9c91c08d
...
@@ -56,6 +56,7 @@ struct parse_where : op_parser<parse_where>
...
@@ -56,6 +56,7 @@ struct parse_where : op_parser<parse_where>
auto
lens
=
auto
lens
=
compute_broadcasted_lens
(
args
[
0
]
->
get_shape
().
lens
(),
args
[
1
]
->
get_shape
().
lens
());
compute_broadcasted_lens
(
args
[
0
]
->
get_shape
().
lens
(),
args
[
1
]
->
get_shape
().
lens
());
lens
=
compute_broadcasted_lens
(
lens
,
args
[
2
]
->
get_shape
().
lens
());
lens
=
compute_broadcasted_lens
(
lens
,
args
[
2
]
->
get_shape
().
lens
());
if
(
args
[
0
]
->
get_shape
().
lens
()
!=
lens
)
if
(
args
[
0
]
->
get_shape
().
lens
()
!=
lens
)
{
{
args
[
0
]
=
args
[
0
]
=
...
...
src/pass_manager.cpp
View file @
9c91c08d
...
@@ -68,12 +68,18 @@ void run_pass(program& prog, const pass& p, tracer trace)
...
@@ -68,12 +68,18 @@ void run_pass(program& prog, const pass& p, tracer trace)
struct
module_pm
:
module_pass_manager
struct
module_pm
:
module_pass_manager
{
{
module
*
mod
=
nullptr
;
module
*
mod
=
nullptr
;
module
*
root_mod
=
nullptr
;
tracer
*
t
=
nullptr
;
tracer
*
t
=
nullptr
;
module
*
common_parent
=
nullptr
;
module
*
common_parent
=
nullptr
;
program
*
prog
=
nullptr
;
program
*
prog
=
nullptr
;
module_pm
(
module
*
pmod
=
nullptr
,
tracer
*
pt
=
nullptr
)
:
mod
(
pmod
),
t
(
pt
)
{}
module_pm
(
module
*
pmod
=
nullptr
,
tracer
*
pt
=
nullptr
)
:
mod
(
pmod
),
t
(
pt
)
{}
module_pm
(
module
*
pmod
=
nullptr
,
module
*
rmod
=
nullptr
,
tracer
*
pt
=
nullptr
)
:
mod
(
pmod
),
root_mod
(
rmod
),
t
(
pt
)
{
}
template
<
class
...
Ts
>
template
<
class
...
Ts
>
void
trace
(
Ts
&&
...
xs
)
const
void
trace
(
Ts
&&
...
xs
)
const
{
{
...
@@ -97,6 +103,8 @@ struct module_pm : module_pass_manager
...
@@ -97,6 +103,8 @@ struct module_pm : module_pass_manager
virtual
module
*
get_root_module
()
override
virtual
module
*
get_root_module
()
override
{
{
if
(
root_mod
!=
nullptr
)
return
root_mod
;
assert
(
prog
);
assert
(
prog
);
return
prog
->
get_main_module
();
return
prog
->
get_main_module
();
}
}
...
@@ -140,7 +148,7 @@ void run_passes(program& prog, module_ref root_mod, const std::vector<pass>& pas
...
@@ -140,7 +148,7 @@ void run_passes(program& prog, module_ref root_mod, const std::vector<pass>& pas
continue
;
continue
;
if
(
not
visited
.
insert
(
mod
).
second
)
if
(
not
visited
.
insert
(
mod
).
second
)
continue
;
continue
;
module_pm
mpm
{
mod
,
&
trace
};
module_pm
mpm
{
mod
,
root_mod
,
&
trace
};
mpm
.
prog
=
&
prog
;
mpm
.
prog
=
&
prog
;
auto
parents
=
range
(
tree
.
equal_range
(
mod
));
auto
parents
=
range
(
tree
.
equal_range
(
mod
));
auto
nparents
=
distance
(
parents
);
auto
nparents
=
distance
(
parents
);
...
@@ -164,7 +172,7 @@ void run_passes(module& mod, const std::vector<pass>& passes, tracer trace)
...
@@ -164,7 +172,7 @@ void run_passes(module& mod, const std::vector<pass>& passes, tracer trace)
trace
=
tracer
{
std
::
cout
};
trace
=
tracer
{
std
::
cout
};
for
(
const
auto
&
p
:
passes
)
for
(
const
auto
&
p
:
passes
)
{
{
module_pm
{
&
mod
,
&
trace
}.
run_pass
(
p
);
module_pm
{
&
mod
,
&
mod
,
&
trace
}.
run_pass
(
p
);
}
}
}
}
...
...
src/program.cpp
View file @
9c91c08d
...
@@ -70,9 +70,8 @@ struct program_impl
...
@@ -70,9 +70,8 @@ struct program_impl
{
{
// A map is used to keep references to modules of the program
// A map is used to keep references to modules of the program
std
::
unordered_map
<
std
::
string
,
module
>
modules
;
std
::
unordered_map
<
std
::
string
,
module
>
modules
;
context
ctx
;
std
::
string
target_name
;
std
::
vector
<
context
>
contexts
;
std
::
vector
<
context
>
contexts
;
std
::
vector
<
target
>
targets
;
};
};
program
::
program
()
:
impl
(
std
::
make_unique
<
program_impl
>
())
{
this
->
create_module
(
"main"
);
}
program
::
program
()
:
impl
(
std
::
make_unique
<
program_impl
>
())
{
this
->
create_module
(
"main"
);
}
...
@@ -96,14 +95,8 @@ void program::assign(const program& p)
...
@@ -96,14 +95,8 @@ void program::assign(const program& p)
{
{
impl
=
std
::
make_unique
<
program_impl
>
();
impl
=
std
::
make_unique
<
program_impl
>
();
}
}
else
if
(
not
impl
->
modules
.
empty
())
{
impl
->
modules
.
clear
();
}
impl
->
ctx
=
p
.
impl
->
ctx
;
*
impl
=
*
p
.
impl
;
impl
->
target_name
=
p
.
impl
->
target_name
;
impl
->
modules
=
p
.
impl
->
modules
;
// build a map from old ins to new ins
// build a map from old ins to new ins
// Build a map from old module to new module
// Build a map from old module to new module
...
@@ -166,7 +159,11 @@ std::vector<shape> program::get_output_shapes() const
...
@@ -166,7 +159,11 @@ std::vector<shape> program::get_output_shapes() const
return
mm
->
get_output_shapes
();
return
mm
->
get_output_shapes
();
}
}
context
&
program
::
get_context
()
const
{
return
impl
->
ctx
;
}
context
&
program
::
get_context
()
const
{
assert
(
impl
->
contexts
.
size
()
==
1
);
return
impl
->
contexts
.
front
();
}
instruction_ref
program
::
validate
()
const
instruction_ref
program
::
validate
()
const
{
{
...
@@ -217,7 +214,7 @@ target_assignments program::get_target_assignments(const std::vector<target>& ta
...
@@ -217,7 +214,7 @@ target_assignments program::get_target_assignments(const std::vector<target>& ta
return
p
;
return
p
;
}
}
bool
program
::
is_compiled
()
const
{
return
not
this
->
impl
->
target_name
.
empty
();
}
bool
program
::
is_compiled
()
const
{
return
not
this
->
impl
->
contexts
.
empty
();
}
void
program
::
compile
(
const
std
::
vector
<
target
>&
targets
,
std
::
vector
<
compile_options
>
compile_opts
)
void
program
::
compile
(
const
std
::
vector
<
target
>&
targets
,
std
::
vector
<
compile_options
>
compile_opts
)
{
{
...
@@ -299,24 +296,24 @@ void program::compile(const std::vector<target>& targets, std::vector<compile_op
...
@@ -299,24 +296,24 @@ void program::compile(const std::vector<target>& targets, std::vector<compile_op
MIGRAPHX_THROW
(
"Dangling reference in module "
+
current_mod
->
name
()
+
MIGRAPHX_THROW
(
"Dangling reference in module "
+
current_mod
->
name
()
+
" from instruction "
+
std
::
to_string
(
index
));
" from instruction "
+
std
::
to_string
(
index
));
}
}
current_mod
->
finalize
(
this
->
impl
->
contexts
[
root_target_id
]);
}
}
}
}
this
->
finalize
();
}
}
void
program
::
compile
(
const
target
&
t
,
compile_options
options
)
void
program
::
compile
(
const
target
&
t
,
compile_options
options
)
{
{
// todo: combine with multi-target compile method
// todo: combine with multi-target compile method
assert
(
not
this
->
is_compiled
());
assert
(
not
this
->
is_compiled
());
this
->
impl
->
target
_name
=
t
.
name
()
;
this
->
impl
->
target
s
=
{
t
}
;
this
->
impl
->
c
tx
=
t
.
get_context
();
this
->
impl
->
c
ontexts
=
{
t
.
get_context
()
}
;
if
(
enabled
(
MIGRAPHX_TRACE_COMPILE
{}))
if
(
enabled
(
MIGRAPHX_TRACE_COMPILE
{}))
options
.
trace
=
tracer
{
std
::
cout
};
options
.
trace
=
tracer
{
std
::
cout
};
options
.
trace
(
*
this
);
options
.
trace
(
*
this
);
options
.
trace
();
options
.
trace
();
auto
&&
passes
=
t
.
get_passes
(
this
->
impl
->
c
tx
,
options
);
auto
&&
passes
=
t
.
get_passes
(
this
->
impl
->
c
ontexts
.
front
()
,
options
);
run_passes
(
*
this
,
passes
,
options
.
trace
);
run_passes
(
*
this
,
passes
,
options
.
trace
);
auto
mods
=
this
->
get_modules
();
auto
mods
=
this
->
get_modules
();
// Validate and finalize
// Validate and finalize
...
@@ -335,14 +332,14 @@ void program::compile(const target& t, compile_options options)
...
@@ -335,14 +332,14 @@ void program::compile(const target& t, compile_options options)
MIGRAPHX_THROW
(
"Dangling reference in module "
+
mod
->
name
()
+
" from instruction "
+
MIGRAPHX_THROW
(
"Dangling reference in module "
+
mod
->
name
()
+
" from instruction "
+
std
::
to_string
(
index
));
std
::
to_string
(
index
));
}
}
mod
->
finalize
(
this
->
impl
->
c
tx
);
mod
->
finalize
(
this
->
impl
->
c
ontexts
);
}
}
}
}
void
program
::
finalize
()
void
program
::
finalize
()
{
{
auto
*
mm
=
this
->
get_main_module
();
auto
*
mm
=
this
->
get_main_module
();
mm
->
finalize
(
this
->
impl
->
c
tx
);
mm
->
finalize
(
this
->
impl
->
c
ontexts
);
}
}
template
<
class
T
>
template
<
class
T
>
...
@@ -359,6 +356,31 @@ std::string classify(T x)
...
@@ -359,6 +356,31 @@ std::string classify(T x)
}
}
}
}
void
print_statistics
(
std
::
ostream
&
os
,
const
argument
&
a
)
{
a
.
visit
(
[
&
](
auto
t
)
{
os
<<
"Min value: "
<<
*
std
::
min_element
(
t
.
begin
(),
t
.
end
())
<<
", "
;
os
<<
"Max value: "
<<
*
std
::
max_element
(
t
.
begin
(),
t
.
end
())
<<
", "
;
double
num_elements
=
t
.
size
();
auto
mean
=
std
::
accumulate
(
t
.
begin
(),
t
.
end
(),
0.0
)
/
num_elements
;
auto
stddev
=
std
::
sqrt
(
std
::
accumulate
(
t
.
begin
(),
t
.
end
(),
0.0
,
[
&
](
auto
r
,
auto
v
)
{
return
r
+
std
::
pow
((
v
-
mean
),
2.0
);
})
/
num_elements
);
os
<<
"Mean: "
<<
mean
<<
", "
;
os
<<
"StdDev: "
<<
stddev
<<
"
\n
"
;
},
[
&
](
const
auto
&
xs
)
{
for
(
const
auto
&
x
:
xs
)
{
print_statistics
(
os
,
x
);
}
});
}
std
::
unordered_set
<
std
::
string
>
classify_argument
(
const
argument
&
a
)
std
::
unordered_set
<
std
::
string
>
classify_argument
(
const
argument
&
a
)
{
{
std
::
unordered_set
<
std
::
string
>
result
;
std
::
unordered_set
<
std
::
string
>
result
;
...
@@ -404,16 +426,15 @@ void preview_argument(std::ostream& os, const argument& a)
...
@@ -404,16 +426,15 @@ void preview_argument(std::ostream& os, const argument& a)
template
<
class
F
>
template
<
class
F
>
std
::
vector
<
argument
>
generic_eval
(
const
module
*
mod
,
std
::
vector
<
argument
>
generic_eval
(
const
module
*
mod
,
context
&
ctx
,
std
::
vector
<
context
>
&
ctx
,
std
::
unordered_map
<
std
::
string
,
argument
>
params
,
std
::
unordered_map
<
std
::
string
,
argument
>
params
,
std
::
unordered_map
<
instruction_ref
,
argument
>
results
,
std
::
unordered_map
<
instruction_ref
,
argument
>
results
,
F
make_
trace
)
F
trace
)
{
{
assert
(
mod
->
validate
()
==
mod
->
end
());
assert
(
mod
->
validate
()
==
mod
->
end
());
results
.
reserve
(
mod
->
size
()
*
2
);
results
.
reserve
(
mod
->
size
()
*
2
);
std
::
vector
<
argument
>
values
;
std
::
vector
<
argument
>
values
;
values
.
reserve
(
16
);
values
.
reserve
(
16
);
auto
trace
=
make_trace
(
mod
);
for
(
auto
ins
:
iterator_for
(
*
mod
))
for
(
auto
ins
:
iterator_for
(
*
mod
))
{
{
assert
(
results
.
find
(
ins
)
==
results
.
end
());
assert
(
results
.
find
(
ins
)
==
results
.
end
());
...
@@ -469,14 +490,19 @@ std::vector<argument> generic_eval(const module* mod,
...
@@ -469,14 +490,19 @@ std::vector<argument> generic_eval(const module* mod,
const
auto
&
mod_args
=
ins
->
module_inputs
();
const
auto
&
mod_args
=
ins
->
module_inputs
();
auto
module_eval
=
[
&
](
module_ref
smod
,
auto
module_eval
=
[
&
](
module_ref
smod
,
const
std
::
unordered_map
<
std
::
string
,
argument
>&
inputs
)
{
const
std
::
unordered_map
<
std
::
string
,
argument
>&
inputs
)
{
auto
ssctx
=
ctx
;
return
generic_eval
(
smod
,
ctx
,
inputs
,
results
,
trace
);
return
generic_eval
(
smod
,
ssctx
,
inputs
,
results
,
make_trace
);
};
};
results
.
emplace
(
ins
,
trace
(
ins
,
[
&
]
{
results
.
emplace
(
return
ins
->
normalized_operator
().
compute
(
ins
,
trace
(
ins
,
[
&
]
{
ctx
,
ins
->
get_shape
(),
values
,
mod_args
,
module_eval
);
auto
op
=
ins
->
normalized_operator
();
}));
if
(
op
.
is_context_free
())
return
op
.
compute
(
ins
->
get_shape
(),
values
,
mod_args
,
module_eval
);
if
(
ins
->
get_target_id
()
>=
ctx
.
size
())
MIGRAPHX_THROW
(
"No context available for "
+
op
.
name
());
return
op
.
compute
(
ctx
[
ins
->
get_target_id
()],
ins
->
get_shape
(),
values
,
mod_args
,
module_eval
);
}));
}
}
assert
(
results
.
find
(
ins
)
!=
results
.
end
());
assert
(
results
.
find
(
ins
)
!=
results
.
end
());
if
(
not
ins
->
get_shape
().
any_of_dynamic
())
if
(
not
ins
->
get_shape
().
any_of_dynamic
())
...
@@ -489,44 +515,25 @@ std::vector<argument> generic_eval(const module* mod,
...
@@ -489,44 +515,25 @@ std::vector<argument> generic_eval(const module* mod,
template
<
class
F
>
template
<
class
F
>
std
::
vector
<
argument
>
generic_eval
(
const
program
&
p
,
std
::
vector
<
argument
>
generic_eval
(
const
program
&
p
,
context
&
ctx
,
std
::
vector
<
context
>
&
ctx
,
std
::
unordered_map
<
std
::
string
,
argument
>
params
,
std
::
unordered_map
<
std
::
string
,
argument
>
params
,
F
make_
trace
)
F
trace
)
{
{
const
module
*
mm
=
p
.
get_main_module
();
const
module
*
mm
=
p
.
get_main_module
();
return
generic_eval
(
mm
,
ctx
,
params
,
{},
make_
trace
);
return
generic_eval
(
mm
,
ctx
,
params
,
{},
trace
);
}
}
std
::
vector
<
argument
>
program
::
eval
(
parameter_map
params
,
execution_environment
exec_env
)
const
std
::
vector
<
argument
>
program
::
eval
(
parameter_map
params
,
execution_environment
exec_env
)
const
{
{
auto
&
ctx
=
this
->
impl
->
ctx
;
auto
&
contexts
=
this
->
impl
->
contexts
;
#ifndef NDEBUG
auto
with_check_context
=
[
&
](
auto
f
)
{
return
[
=
,
&
ctx
](
auto
&&
)
{
auto
sctx
=
std
::
make_shared
<
context
>
(
ctx
);
auto
check_context
=
[
=
,
&
ctx
](
auto
g
)
{
assert
(
is_shared
(
ctx
,
*
sctx
));
auto
x
=
g
();
*
sctx
=
ctx
;
return
x
;
};
return
[
=
](
auto
&&
...
xs
)
{
return
f
(
xs
...,
check_context
);
};
};
};
#else
auto
with_check_context
=
[](
auto
f
)
{
return
[
=
](
auto
&&
)
{
return
[
=
](
auto
&&
...
xs
)
{
return
f
(
xs
...,
[](
auto
g
)
{
return
g
();
});
};
};
};
#endif
auto
trace_level
=
value_of
(
MIGRAPHX_TRACE_EVAL
{});
auto
trace_level
=
value_of
(
MIGRAPHX_TRACE_EVAL
{});
std
::
vector
<
argument
>
ret
;
std
::
vector
<
argument
>
ret
;
if
(
exec_env
.
async
)
if
(
exec_env
.
async
)
{
{
ctx
.
wait_for
(
exec_env
.
queue
);
assert
(
contexts
.
size
()
==
1
);
contexts
.
front
().
wait_for
(
exec_env
.
queue
);
}
}
if
(
trace_level
>
0
)
if
(
trace_level
>
0
)
...
@@ -538,68 +545,79 @@ std::vector<argument> program::eval(parameter_map params, execution_environment
...
@@ -538,68 +545,79 @@ std::vector<argument> program::eval(parameter_map params, execution_environment
instruction
::
print
(
ss
,
x
,
ins_names
);
instruction
::
print
(
ss
,
x
,
ins_names
);
ins_out
[
x
]
=
ss
.
str
();
ins_out
[
x
]
=
ss
.
str
();
});
});
ret
=
generic_eval
(
*
this
,
contexts
,
std
::
move
(
params
),
[
&
](
instruction_ref
ins
,
auto
f
)
{
ret
=
generic_eval
(
*
this
,
auto
&
ctx
=
contexts
[
ins
->
get_target_id
()];
ctx
,
ctx
.
finish
();
std
::
move
(
params
),
std
::
cout
<<
"Run instruction: "
<<
ins_out
.
at
(
ins
)
<<
std
::
endl
;
with_check_context
([
&
](
auto
&
ins
,
auto
f
,
auto
&&
check_context
)
{
timer
t
{};
ctx
.
finish
();
auto
result
=
f
();
std
::
cout
<<
"Run instruction: "
<<
ins_out
.
at
(
ins
)
<<
std
::
endl
;
double
t1
=
t
.
record
<
milliseconds
>
();
timer
t
{};
ctx
.
finish
();
auto
result
=
check_context
(
f
);
double
t2
=
t
.
record
<
milliseconds
>
();
double
t1
=
t
.
record
<
milliseconds
>
();
std
::
cout
<<
"Time: "
<<
t1
<<
"ms, "
<<
t2
<<
"ms"
<<
std
::
endl
;
ctx
.
finish
();
if
(
trace_level
>
1
and
ins
->
name
().
front
()
!=
'@'
and
ins
->
name
()
!=
"load"
and
double
t2
=
t
.
record
<
milliseconds
>
();
not
result
.
empty
())
std
::
cout
<<
"Time: "
<<
t1
<<
"ms, "
<<
t2
<<
"ms"
<<
std
::
endl
;
{
if
(
trace_level
>
1
and
ins
->
name
().
front
()
!=
'@'
and
migraphx
::
argument
buffer
;
ins
->
name
()
!=
"load"
and
not
result
.
empty
())
try
{
{
target
tgt
=
make_target
(
this
->
impl
->
target_name
);
const
target
&
tgt
=
this
->
impl
->
targets
.
at
(
ins
->
get_target_id
());
auto
buffer
=
tgt
.
copy_from
(
result
);
buffer
=
tgt
.
copy_from
(
result
);
if
(
trace_level
==
2
)
}
{
catch
(
const
migraphx
::
exception
&
)
std
::
cout
<<
"Output has "
{
<<
to_string_range
(
classify_argument
(
buffer
))
// instruction was run on host then no need to copy buffer from target
<<
std
::
endl
;
buffer
=
result
;
std
::
cout
<<
"Output: "
;
}
preview_argument
(
std
::
cout
,
buffer
);
catch
(...)
std
::
cout
<<
std
::
endl
;
{
}
MIGRAPHX_THROW
(
"MIGraphX program execution with MIGRAPHX_TRACE_EVAL failed.
\n
"
);
else
}
{
if
(
trace_level
==
2
)
std
::
cout
<<
"Output: "
<<
buffer
<<
std
::
endl
;
{
}
std
::
cout
<<
"Output has "
<<
to_string_range
(
classify_argument
(
buffer
))
}
<<
std
::
endl
;
return
result
;
std
::
cout
<<
"Output: "
;
}));
preview_argument
(
std
::
cout
,
buffer
);
std
::
cout
<<
std
::
endl
;
print_statistics
(
std
::
cout
,
buffer
);
}
else
{
std
::
cout
<<
"Output: "
<<
buffer
<<
std
::
endl
;
}
}
return
result
;
});
}
}
else
else
{
{
ret
=
generic_eval
(
*
this
,
ret
=
generic_eval
(
*
this
,
contexts
,
std
::
move
(
params
),
[
&
](
auto
&&
,
auto
f
)
{
return
f
();
});
ctx
,
std
::
move
(
params
),
with_check_context
([
&
](
auto
&
,
auto
f
,
auto
&&
check_context
)
{
return
check_context
(
f
);
}));
}
}
if
(
exec_env
.
async
)
if
(
exec_env
.
async
)
{
{
ctx
.
finish_on
(
exec_env
.
queue
);
assert
(
contexts
.
size
()
==
1
);
contexts
.
front
().
finish_on
(
exec_env
.
queue
);
}
}
return
ret
;
return
ret
;
}
}
const
int
program_file_version
=
5
;
void
program
::
finish
()
const
{
for
(
const
auto
&
ctx
:
this
->
impl
->
contexts
)
ctx
.
finish
();
}
const
int
program_file_version
=
6
;
value
program
::
to_value
()
const
value
program
::
to_value
()
const
{
{
value
result
;
value
result
;
result
[
"version"
]
=
program_file_version
;
result
[
"version"
]
=
program_file_version
;
result
[
"target"
]
=
this
->
impl
->
target_name
;
result
[
"targets"
]
=
migraphx
::
to_value
(
this
->
impl
->
targets
);
if
(
not
this
->
impl
->
target_name
.
empty
())
result
[
"contexts"
]
=
migraphx
::
to_value
(
this
->
impl
->
contexts
);
result
[
"context"
]
=
this
->
impl
->
ctx
.
to_value
();
value
module_vals
=
value
::
object
{};
value
module_vals
=
value
::
object
{};
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
names
;
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
names
;
...
@@ -728,12 +746,12 @@ void program::from_value(const value& v)
...
@@ -728,12 +746,12 @@ void program::from_value(const value& v)
MIGRAPHX_THROW
(
"Warning: Program version mismatch"
);
MIGRAPHX_THROW
(
"Warning: Program version mismatch"
);
}
}
this
->
impl
->
target_name
=
v
.
at
(
"target"
).
to
<
std
::
string
>
();
migraphx
::
from_value
(
v
.
at
(
"targets"
),
this
->
impl
->
targets
);
if
(
not
this
->
impl
->
target_name
.
empty
())
for
(
auto
i
:
range
(
this
->
impl
->
targets
.
size
()))
{
{
target
t
=
make_target
(
this
->
impl
->
target_name
);
this
->
impl
->
contexts
.
push_back
(
this
->
impl
->
targets
[
i
].
get_context
());
this
->
impl
->
ctx
=
t
.
get_context
();
this
->
impl
->
contexts
.
back
().
from_value
(
v
.
at
(
"contexts"
)[
i
]);
this
->
impl
->
ctx
.
from_value
(
v
.
at
(
"context"
));
}
}
auto
module_vals
=
v
.
at
(
"modules"
);
auto
module_vals
=
v
.
at
(
"modules"
);
...
@@ -754,7 +772,9 @@ void program::from_value(const value& v)
...
@@ -754,7 +772,9 @@ void program::from_value(const value& v)
auto
*
mm
=
get_main_module
();
auto
*
mm
=
get_main_module
();
mod_from_val
(
mm
,
module_vals
,
map_insts
,
map_mods
);
mod_from_val
(
mm
,
module_vals
,
map_insts
,
map_mods
);
this
->
finalize
();
// Finalize a compiled model
if
(
not
this
->
impl
->
contexts
.
empty
())
this
->
finalize
();
}
}
double
common_average
(
const
std
::
vector
<
double
>&
v
)
double
common_average
(
const
std
::
vector
<
double
>&
v
)
...
@@ -774,19 +794,19 @@ std::string perf_group(const operation& op)
...
@@ -774,19 +794,19 @@ std::string perf_group(const operation& op)
void
program
::
mark
(
const
parameter_map
&
params
,
marker
&&
m
)
void
program
::
mark
(
const
parameter_map
&
params
,
marker
&&
m
)
{
{
auto
&
ctx
=
this
->
impl
->
c
tx
;
auto
&
ctx
=
this
->
impl
->
c
ontexts
;
// Run once by itself
// Run once by itself
eval
(
params
);
eval
(
params
);
ctx
.
finish
();
this
->
finish
();
// Start marking
// Start marking
m
.
mark_start
(
*
this
);
m
.
mark_start
(
*
this
);
generic_eval
(
*
this
,
ctx
,
params
,
always
(
[
&
](
auto
ins
,
auto
f
)
{
generic_eval
(
*
this
,
ctx
,
params
,
[
&
](
auto
ins
,
auto
f
)
{
argument
result
;
argument
result
;
m
.
mark_start
(
ins
);
m
.
mark_start
(
ins
);
result
=
f
();
result
=
f
();
m
.
mark_stop
(
ins
);
m
.
mark_stop
(
ins
);
return
result
;
return
result
;
})
)
;
});
m
.
mark_stop
(
*
this
);
m
.
mark_stop
(
*
this
);
}
}
...
@@ -795,10 +815,10 @@ void program::perf_report(std::ostream& os,
...
@@ -795,10 +815,10 @@ void program::perf_report(std::ostream& os,
parameter_map
params
,
parameter_map
params
,
std
::
size_t
batch
)
const
std
::
size_t
batch
)
const
{
{
auto
&
ctx
=
this
->
impl
->
c
tx
;
auto
&
ctx
=
this
->
impl
->
c
ontexts
;
// Run once by itself
// Run once by itself
eval
(
params
);
eval
(
params
);
ctx
.
finish
();
this
->
finish
();
// Run and time entire program
// Run and time entire program
std
::
vector
<
double
>
total_vec
;
std
::
vector
<
double
>
total_vec
;
total_vec
.
reserve
(
n
);
total_vec
.
reserve
(
n
);
...
@@ -806,28 +826,28 @@ void program::perf_report(std::ostream& os,
...
@@ -806,28 +826,28 @@ void program::perf_report(std::ostream& os,
{
{
total_vec
.
push_back
(
time
<
milliseconds
>
([
&
]
{
total_vec
.
push_back
(
time
<
milliseconds
>
([
&
]
{
eval
(
params
);
eval
(
params
);
ctx
.
finish
();
this
->
finish
();
}));
}));
}
}
std
::
sort
(
total_vec
.
begin
(),
total_vec
.
end
());
std
::
sort
(
total_vec
.
begin
(),
total_vec
.
end
());
std
::
unordered_map
<
instruction_ref
,
std
::
vector
<
double
>>
ins_vec
;
std
::
unordered_map
<
instruction_ref
,
std
::
vector
<
double
>>
ins_vec
;
// Fill the map
// Fill the map
generic_eval
(
*
this
,
ctx
,
params
,
always
(
[
&
](
auto
ins
,
auto
)
{
generic_eval
(
*
this
,
ctx
,
params
,
[
&
](
auto
ins
,
auto
)
{
ins_vec
[
ins
].
reserve
(
n
);
ins_vec
[
ins
].
reserve
(
n
);
return
argument
{
ins
->
get_shape
(),
nullptr
};
return
argument
{
ins
->
get_shape
(),
nullptr
};
})
)
;
});
// Run and time each instruction
// Run and time each instruction
for
(
std
::
size_t
i
=
0
;
i
<
n
;
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
n
;
i
++
)
{
{
generic_eval
(
*
this
,
ctx
,
params
,
always
(
[
&
](
auto
ins
,
auto
f
)
{
generic_eval
(
*
this
,
ctx
,
params
,
[
&
](
auto
ins
,
auto
f
)
{
argument
result
;
argument
result
;
ins_vec
[
ins
].
push_back
(
time
<
milliseconds
>
([
&
]
{
ins_vec
[
ins
].
push_back
(
time
<
milliseconds
>
([
&
]
{
result
=
f
();
result
=
f
();
ctx
.
finish
();
this
->
impl
->
contexts
[
ins
->
get_target_id
()]
.
finish
();
}));
}));
return
result
;
return
result
;
})
)
;
});
}
}
for
(
auto
&&
p
:
ins_vec
)
for
(
auto
&&
p
:
ins_vec
)
std
::
sort
(
p
.
second
.
begin
(),
p
.
second
.
end
());
std
::
sort
(
p
.
second
.
begin
(),
p
.
second
.
end
());
...
@@ -995,10 +1015,10 @@ void program::print_cpp(std::ostream& os) const
...
@@ -995,10 +1015,10 @@ void program::print_cpp(std::ostream& os) const
void
program
::
dry_run
(
std
::
unordered_map
<
std
::
string
,
argument
>
params
)
const
void
program
::
dry_run
(
std
::
unordered_map
<
std
::
string
,
argument
>
params
)
const
{
{
auto
&
ctx
=
this
->
impl
->
c
tx
;
auto
&
ctx
=
this
->
impl
->
c
ontexts
;
generic_eval
(
*
this
,
ctx
,
std
::
move
(
params
),
always
(
[](
auto
ins
,
auto
&&
...)
{
generic_eval
(
*
this
,
ctx
,
std
::
move
(
params
),
[](
auto
ins
,
auto
&&
...)
{
return
argument
{
ins
->
get_shape
(),
nullptr
};
return
argument
{
ins
->
get_shape
(),
nullptr
};
})
)
;
});
}
}
void
program
::
annotate
(
std
::
ostream
&
os
,
const
std
::
function
<
void
(
instruction_ref
)
>&
a
)
const
void
program
::
annotate
(
std
::
ostream
&
os
,
const
std
::
function
<
void
(
instruction_ref
)
>&
a
)
const
...
...
src/promote_literals.cpp
View file @
9c91c08d
...
@@ -34,7 +34,7 @@ void promote_literals::apply(module_pass_manager& mpm) const
...
@@ -34,7 +34,7 @@ void promote_literals::apply(module_pass_manager& mpm) const
{
{
module
&
m
=
mpm
.
get_module
();
module
&
m
=
mpm
.
get_module
();
module_ref
root_module
=
mpm
.
get_root_module
();
module_ref
root_module
=
mpm
.
get_root_module
();
if
(
m
.
name
()
==
"main"
)
if
(
m
==
*
root_module
)
return
;
return
;
for
(
auto
ins
:
iterator_for
(
m
))
for
(
auto
ins
:
iterator_for
(
m
))
...
...
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment