gaoqiong / MIGraphX / Commits / 0ffcccbc

Unverified commit 0ffcccbc, authored Feb 03, 2023 by Chris Austen, committed by GitHub on Feb 03, 2023

Merge branch 'develop' into jit-reduce-reg

parents 4f12db9e 2b5c5f5e
Changes: 32 in the full commit; this page shows 20 changed files with 923 additions and 75 deletions (+923 -75):

.clang-tidy                                      +2   -0
.github/workflows/ci.yaml                        +4   -3
Dockerfile                                       +1   -1
hip-clang.docker                                 +1   -1
src/CMakeLists.txt                               +1   -2
src/include/migraphx/half.hpp                    +2   -2
src/include/migraphx/instruction_ref.hpp         +2   -2
src/include/migraphx/memory_coloring.hpp         +1   -1
src/include/migraphx/op/gathernd.hpp             +88  -17
src/include/migraphx/op/scatternd_op.hpp         +66  -21
src/memory_coloring.cpp                          +405 -0
src/onnx/include/migraphx/onnx/onnx_parser.hpp   +2   -1
src/onnx/onnx_parser.cpp                         +13  -7
src/onnx/parse_if.cpp                            +20  -2
src/onnx/parse_loop.cpp                          +1   -1
src/pass_manager.cpp                             +11  -10
src/program.cpp                                  +2   -1
test/memory_coloring_test.cpp                    +22  -3
test/onnx/gathernd_dyn_test.onnx                 +0   -0
test/onnx/gen_onnx.py                            +279 -0
.clang-tidy

@@ -9,6 +9,8 @@ CheckOptions:
     value: risky
   - key: modernize-loop-convert.NamingStyle
     value: lower_case
+  - key: misc-const-correctness.AnalyzeValues
+    value: 'false'
   - key: performance-unnecessary-copy-initialization.AllowedTypes
     value: 'shape'
   - key: performance-unnecessary-value-param.AllowedTypes
.github/workflows/ci.yaml

@@ -32,7 +32,8 @@ jobs:
       # In this step, this action saves a list of existing images,
       # the cache is created without them in the post run.
       # It also restores the cache if it exists.
-      - uses: satackey/action-docker-layer-caching@v0.0.11
+      # name: Docker Layer Caching2
+      - uses: jpribyl/action-docker-layer-caching@v0.1.1
        # Ignore the failure of a step and avoid terminating the job.
        continue-on-error: true

@@ -81,7 +82,7 @@ jobs:
       # In this step, this action saves a list of existing images,
       # the cache is created without them in the post run.
       # It also restores the cache if it exists.
-      - uses: satackey/action-docker-layer-caching@v0.0.11
+      - uses: jpribyl/action-docker-layer-caching@v0.1.1
        # Ignore the failure of a step and avoid terminating the job.
        continue-on-error: true

@@ -126,7 +127,7 @@ jobs:
       # In this step, this action saves a list of existing images,
       # the cache is created without them in the post run.
       # It also restores the cache if it exists.
-      - uses: satackey/action-docker-layer-caching@v0.0.11
+      - uses: jpribyl/action-docker-layer-caching@v0.1.1
        # Ignore the failure of a step and avoid terminating the job.
        continue-on-error: true
Dockerfile

@@ -10,7 +10,7 @@ RUN apt-get update && apt-get install -y gnupg2 --no-install-recommends curl &&
     curl -sL http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add -

 # Add rocm repository
-RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/5.3/ ubuntu main > /etc/apt/sources.list.d/rocm.list'
+RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/5.4.2/ ubuntu main > /etc/apt/sources.list.d/rocm.list'

 # Install dependencies
 RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \
hip-clang.docker

@@ -6,7 +6,7 @@ ARG PREFIX=/usr/local
 RUN dpkg --add-architecture i386

 # Add rocm repository
-RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/5.3/ ubuntu main > /etc/apt/sources.list.d/rocm.list'
+RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/5.4.2/ ubuntu main > /etc/apt/sources.list.d/rocm.list'

 # Install dependencies
 RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \
src/CMakeLists.txt 100755 → 100644

@@ -58,6 +58,7 @@ add_library(migraphx
     layout_nhwc.cpp
     load_save.cpp
     make_op.cpp
+    memory_coloring.cpp
     module.cpp
     msgpack.cpp
     normalize_attributes.cpp

@@ -65,8 +66,6 @@ add_library(migraphx
     op_enums.cpp
     operation.cpp
     optimize_module.cpp
-    opt/memory_coloring.cpp
-    opt/memory_coloring_impl.cpp
     pad_calc.cpp
     pass_manager.cpp
     permutation.cpp
src/include/migraphx/half.hpp

@@ -58,12 +58,12 @@ using deduce = typename detail::deduce<T>::type;
 namespace std {

 template <class T>
-struct common_type<migraphx::half, T> : std::common_type<float, T>
+struct common_type<migraphx::half, T> : std::common_type<float, T> // NOLINT
 {
 };

 template <class T>
-struct common_type<T, migraphx::half> : std::common_type<float, T>
+struct common_type<T, migraphx::half> : std::common_type<float, T> // NOLINT
 {
 };
src/include/migraphx/instruction_ref.hpp

@@ -41,7 +41,7 @@ migraphx::instruction* as_address(const instruction_ref& ins) noexcept;
 namespace std {

 template <>
-struct hash<migraphx::instruction_ref>
+struct hash<migraphx::instruction_ref> // NOLINT
 {
     using argument_type = migraphx::instruction_ref;
     using result_type   = std::size_t;

@@ -52,7 +52,7 @@ struct hash<migraphx::instruction_ref>
 };

 template <>
-struct equal_to<migraphx::instruction_ref>
+struct equal_to<migraphx::instruction_ref> // NOLINT
 {
     using argument_type = migraphx::instruction_ref;
     using result_type   = bool;
src/include/migraphx/memory_coloring.hpp

@@ -39,7 +39,7 @@ struct memory_coloring
 {
     std::string allocation_op{};
     bool verify = false;
-    std::string name() const { return "memory coloring"; }
+    std::string name() const { return "memory_coloring"; }
     void apply(module& m) const;
 };
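The pass keeps its old interface: allocation_op names the allocation instruction to color, and name() now uses the usual snake_case pass name. A minimal usage sketch, assuming the run_passes helper from <migraphx/pass_manager.hpp> and an "allocate" allocation op (the same combination test/memory_coloring_test.cpp exercises below):

#include <migraphx/memory_coloring.hpp>
#include <migraphx/pass_manager.hpp>

void color(migraphx::module& m)
{
    // Replace every "allocate" instruction in m with a load at some offset
    // into a single shared "scratch" buffer parameter.
    migraphx::run_passes(m, {migraphx::memory_coloring{"allocate"}});
}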
src/include/migraphx/op/gathernd.hpp

 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal

@@ -25,6 +25,7 @@
 #define MIGRAPHX_GUARD_OPERATORS_GATHERND_HPP

 #include <migraphx/check_shapes.hpp>
+#include <migraphx/dyn_output.hpp>
 #include <migraphx/shape_for_each.hpp>
 #include <migraphx/par_for.hpp>
 #include <migraphx/argument.hpp>

@@ -47,33 +48,103 @@ struct gathernd
     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(2);
-        auto r = inputs.front().lens().size();
-        auto q = inputs.back().lens().size();
-        auto k = inputs.back().lens().back();
+        check_shapes{inputs, *this, true}.has(2);
+        auto i_shape    = inputs.back();
+        auto data_shape = inputs.front();
+        auto r          = data_shape.ndim();
+        auto q          = i_shape.ndim();
+        size_t k;
+        if(i_shape.dynamic())
+        {
+            // the rank of the output is a function of k, so it must be fixed.
+            if(not i_shape.dyn_dims().back().is_fixed())
+            {
+                MIGRAPHX_THROW(
+                    "GATHERND: last dimension of indices tensor must be fixed (min=max)");
+            }
+            k = i_shape.dyn_dims().back().min;
+        }
+        else
+            k = i_shape.lens().back();
+
+        // Begin input validation checks.
+        int output_ndim = int(q) + r - k - batch_dims - 1;
         if(k > r - batch_dims)
         {
             MIGRAPHX_THROW("GATHERND: Indices of length " + std::to_string(k) +
                            " cannot be used to access data of rank " +
                            std::to_string(r - batch_dims));
         }
-        auto indices_lens_iter = inputs.back().lens().begin();
-        auto output_lens_size  = q + r - k - batch_dims - 1;
-        std::vector<std::size_t> output_lens(output_lens_size);
-        std::copy(indices_lens_iter, indices_lens_iter + (q - 1), output_lens.begin());
-        if(k < r - batch_dims)
-        {
-            auto data_lens = inputs.front().lens();
-            std::copy(data_lens.begin() + batch_dims + k, data_lens.end(),
-                      output_lens.begin() + q - 1);
-        }
-        shape output_shape{inputs.front().type(), output_lens};
-        return output_shape;
+        if(batch_dims >= q or batch_dims >= r)
+        {
+            MIGRAPHX_THROW("GATHERND: rank of an input cannot be less than batch_dims=" +
+                           std::to_string(batch_dims));
+        }
+        if(output_ndim < 0)
+        {
+            MIGRAPHX_THROW("GATHERND: Indices too large for static data input: k=" +
+                           std::to_string(k));
+        }
+        if(migraphx::none_of(inputs, [](auto v) { return v.dynamic(); }))
+        {
+            auto indices_lens_iter = i_shape.lens().begin();
+            // A rank 0 output is a scalar
+            if(output_ndim == 0)
+                return shape{data_shape.type(), {1}};
+            // Part of the output shape comes from indices tensor, part from data tensor
+            std::vector<std::size_t> output_lens(output_ndim);
+            std::copy(indices_lens_iter, indices_lens_iter + (q - 1), output_lens.begin());
+            // fill the rest of output shape from data tensor
+            if(k + batch_dims < r)
+            {
+                auto data_lens = data_shape.lens();
+                std::copy(data_lens.begin() + batch_dims + k,
+                          data_lens.end(),
+                          output_lens.begin() + q - 1);
+            }
+            shape output_shape{data_shape.type(), output_lens};
+            return output_shape;
+        }
+        else
+        {
+            // If one or both inputs are dynamic shapes, the output is dynamic.
+            // Make both inputs dynamic to simplify computations.
+            data_shape = data_shape.to_dynamic();
+            i_shape    = i_shape.to_dynamic();
+            // A rank 0 output is a scalar
+            if(output_ndim == 0)
+                return shape(data_shape.type(), {shape::dynamic_dimension({1, 1, 0})});
+            // Part of the output shape comes from indices tensor, part from data tensor
+            std::vector<shape::dynamic_dimension> output_dims(output_ndim);
+            std::copy(i_shape.dyn_dims().begin(),
+                      i_shape.dyn_dims().begin() + q - 1,
+                      output_dims.begin());
+            // fill the rest of output shape from data tensor
+            if(k + batch_dims < r)
+            {
+                auto data_dims = data_shape.dyn_dims();
+                std::copy(data_dims.begin() + batch_dims + k,
+                          data_dims.begin() + r,
+                          output_dims.begin() + q - 1);
+            }
+            shape output_shape(data_shape.type(), output_dims);
+            return output_shape;
+        }
     }

-    argument compute(const shape& output_shape, std::vector<argument> args) const
+    argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
     {
-        argument result{output_shape};
+        argument result{dyn_out.computed_shape};
         visit_all(result, args[0])([&](auto output, auto data) {
             args[1].visit([&](auto indices) {
                 auto indices_shape = indices.get_shape();
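The reworked compute_shape centres on the output-rank rule output_ndim = q + r - k - batch_dims - 1, which is also why a dynamic last indices dimension must be fixed: k determines the rank of the output, not just a size. A standalone sketch of that arithmetic (not MIGraphX code; the function name is illustrative):

#include <cassert>
#include <cstddef>

// Rank rule used by gathernd::compute_shape above: r is the data rank, q the
// indices rank, k the size of the last indices dimension.
int gathernd_output_ndim(std::size_t r, std::size_t q, std::size_t k, std::size_t batch_dims)
{
    return int(q) + int(r) - int(k) - int(batch_dims) - 1;
}

int main()
{
    // Matches gathernd_dyn_test in test/onnx/gen_onnx.py below: data
    // [None, 2] (r = 2), indices [2, 2] (q = 2, k = 2), batch_dims = 0
    // gives a rank-1 output with shape [2].
    assert(gathernd_output_ndim(2, 2, 2, 0) == 1);
}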
src/include/migraphx/op/scatternd_op.hpp

@@ -28,44 +28,89 @@
 #include <migraphx/check_shapes.hpp>
 #include <migraphx/argument.hpp>
 #include <migraphx/par_for.hpp>
+#include <migraphx/ranges.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace op {

+/**
+ * @brief
+ * N-dimensional Scatter operations. This struct is parent class to ops which differ in what formula
+ * is used to reduce (combine old and new values of) the scattered value. It was originally based
+ * on Onnx ScatterND operation (see
+ * https://github.com/onnx/onnx/blob/main/docs/Operators.md#ScatterND) and is also similar to Numpy
+ * numpy.add.at().
+ *
+ * @tparam Derived a template parameter in the CRTP inheritance idiom, represents one of the child
+ * operations.
+ */
 template <class Derived>
 struct scatternd_op : op_name<Derived>
 {
+    /** Validate input shapes and return the correct output shape. For Scatter ops, the output
+     * is the same shape as the data tensor (first input), but cast to a standard shape.
+     */
     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(3);
-        auto r         = inputs.front().lens().size();
-        auto q         = inputs.at(1).lens().size();
-        auto k         = inputs.at(1).lens().back();
-        auto ind_lens  = inputs.at(1).lens();
-        auto upd_lens  = inputs.back().lens();
-        auto data_lens = inputs.front().lens();
+        check_shapes{inputs, *this, true}.has(3);
+        auto data_shape  = inputs.front();
+        auto index_shape = inputs.at(1);
+        auto upd_shape   = inputs.back();
+        auto r           = data_shape.ndim();
+        auto q           = index_shape.ndim();
+        size_t k;
+        if(index_shape.dynamic())
+        {
+            // the rank of the output is a function of k, so k must be fixed.
+            if(not index_shape.dyn_dims().back().is_fixed())
+            {
+                MIGRAPHX_THROW(
+                    "GATHERND: last dimension of indices tensor must be fixed (min=max)");
+            }
+            k = index_shape.dyn_dims().back().min;
+        }
+        else
+            k = index_shape.lens().back();
+
+        // Checks on the sizes of input tensors
+        if(q + r != upd_shape.ndim() + k + 1)
+            MIGRAPHX_THROW("ScatterND: ranks of inputs don't match. " + std::to_string(q) +
+                           " + " + std::to_string(r) + " - " + std::to_string(k) +
+                           " - 1 != " + std::to_string(upd_shape.ndim()));
         if(k > r)
             MIGRAPHX_THROW("ScatterND: index of size " + std::to_string(k) +
                            " is too large for tensor of rank " + std::to_string(r));
-        if(not(std::equal(ind_lens.begin(), ind_lens.begin() + q - 1, upd_lens.begin()) and
-               std::equal(data_lens.begin() + k, data_lens.end(), upd_lens.begin() + q - 1)))
-            MIGRAPHX_THROW("ScatterND: incorrect update shape. update.lens != indices.lens[0:q-1] "
-                           "++ data.lens[k:r-1]");
-        auto s = inputs.front();
-        if(s.broadcasted())
+        // Convert all static shape dimensions to dynamic so they can be compared.
+        // It's possible for some of the 3 inputs to be dynamic shapes and some static,
+        // but any dynamic dimension that's compared to a static dimension must be fixed.
+        auto ind_dims  = index_shape.to_dynamic().dyn_dims();
+        auto upd_dims  = upd_shape.to_dynamic().dyn_dims();
+        auto data_dims = data_shape.to_dynamic().dyn_dims();
+        // Check that corresponding portions of tensor shapes match.
+        if(not(std::equal(ind_dims.begin(), ind_dims.begin() + q - 1, upd_dims.begin()) and
+               std::equal(data_dims.begin() + k, data_dims.end(), upd_dims.begin() + q - 1)))
+            MIGRAPHX_THROW("ScatterND: incorrect update shape. Update dimensions must match "
+                           "indices and data.");
+        if(data_shape.dynamic())
+            return data_shape;
+        else if(data_shape.broadcasted())
         {
-            return {s.type(), s.lens()};
+            return {data_shape.type(), data_shape.lens()};
         }
         else
         {
-            return s.with_lens(s.lens());
+            return data_shape.with_lens(data_shape.lens());
         }
     }

-    argument compute(const shape& output_shape, std::vector<argument> args) const
+    argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
     {
-        argument result{output_shape};
+        argument result{dyn_out.computed_shape};
         auto& self = static_cast<const Derived&>(*this);
         visit_all(result, args[0], args[2])([&](auto output, auto data, auto updates) {
             std::copy(data.begin(), data.end(), output.begin());

@@ -74,8 +119,8 @@ struct scatternd_op : op_name<Derived>
                 auto updates_std   = shape{updates_shape.type(), updates_shape.lens()};
                 auto indices_shape = indices.get_shape();
                 auto k             = indices_shape.lens().back();
-                auto q             = indices_shape.lens().size();
-                auto r             = output_shape.lens().size();
+                auto q             = indices_shape.ndim();
+                auto r             = dyn_out.computed_shape.ndim();
                 par_for(updates_shape.elements(), [&](const auto i) {
                     auto updates_idx = updates_std.multi(i);
                     std::vector<std::size_t> indices_idx(q, 0);

@@ -89,7 +134,7 @@ struct scatternd_op : op_name<Derived>
                     std::copy(index_start, index_end, out_idx.begin());
                     std::copy(updates_idx.begin() + q - 1, updates_idx.end(),
                               out_idx.begin() + k);
-                    self.reduction()(output[output_shape.index(out_idx)], updates[i]);
+                    self.reduction()(output[dyn_out.computed_shape.index(out_idx)], updates[i]);
                 });
             });
         });
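The new rank check replaces the old lens-based comparison: with data rank r, indices rank q, updates rank u, and k the last indices dimension, a valid ScatterND must satisfy q + r = u + k + 1. A standalone sketch of that constraint (not MIGraphX code; names are illustrative):

#include <cassert>
#include <cstddef>

// Rank constraint enforced by scatternd_op::compute_shape above.
bool scatternd_ranks_match(std::size_t r, std::size_t q, std::size_t u, std::size_t k)
{
    return q + r == u + k + 1;
}

int main()
{
    // Matches scatternd_dyn_test in test/onnx/gen_onnx.py below: data
    // [None, 2, 2] (r = 3), indices [None, 1, 2] (q = 3, k = 2), updates
    // [None, 1, 2] (u = 3): 3 + 3 == 3 + 2 + 1.
    assert(scatternd_ranks_match(3, 3, 3, 2));
}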
src/memory_coloring.cpp 0 → 100644 (new file)
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/memory_coloring.hpp>
#include <migraphx/module.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/algorithm.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <unordered_set>
#include <unordered_map>
#include <map>
#include <set>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DEBUG_MEMORY_COLORING);

using instruction_set     = std::unordered_set<instruction_ref>;
using instruction_set_map = std::unordered_map<instruction_ref, instruction_set>;

// This will do liveness analysis on the module, and it will call the
// function `f` with the instruction and the set of the other instructions
// that are live
template <class F>
void liveness(const module& m, F f)
{
    auto implicit_deps = m.calc_implicit_deps();
    instruction_set live_set;
    auto rp = reverse(m);
    for(auto rins : iterator_for(rp)) // NOLINT
    {
        // The base iterator is one ahead, so we need to use the previous iterator
        auto ins = std::prev(rins.base());
        // Add live variables
        auto add_live_variables = [&](const auto& inputs) {
            for(auto input : inputs)
            {
                auto i = instruction::get_output_alias(input);
                // Skip if variable comes from parent
                if(not m.has_instruction(i))
                    continue;
                live_set.insert(i);
            }
        };
        add_live_variables(ins->inputs());
        add_live_variables(implicit_deps[ins]);
        // Remove last usage
        auto it = live_set.find(ins);
        if(it != live_set.end())
        {
            live_set.erase(it);
            f(ins, live_set);
        }
    }
}

// This will build the conflict table or interference graph. This is
// essentially a map from one instruction to a set of instruction that are
// used together. Each instruction will be the allocation instruction.
instruction_set_map build_conflict_table(const module& m, std::string allocation_op)
{
    instruction_set_map conflict_table;
    liveness(m, [&](auto ins, auto live_set) {
        // Skip variables that aren't allocations
        if(ins->name() != allocation_op)
            return;
        // Skip zero allocations
        if(ins->get_shape().bytes() == 0)
            return;
        conflict_table[ins];
        for(auto i : live_set)
        {
            if(i == ins)
                continue;
            // Skip variables that aren't allocations
            if(i->name() != allocation_op)
                continue;
            // Skip zero allocations
            if(i->get_shape().bytes() == 0)
                continue;
            conflict_table[i].insert(ins);
            conflict_table[ins].insert(i);
        }
    });
    assert(std::all_of(conflict_table.begin(), conflict_table.end(), [](auto&& pp) {
        return pp.second.count(pp.first) == 0;
    }));
    return conflict_table;
}

// Check if intervals overlap
bool is_overlap(std::pair<std::size_t, std::size_t> x, std::pair<std::size_t, std::size_t> y)
{
    return std::max(x.first, y.first) < std::min(x.second, y.second);
}

struct allocation_segment
{
    using segment = std::pair<std::size_t, std::size_t>;
    std::unordered_map<instruction_ref, segment> ins2segment;

    const segment* add_segment(instruction_ref ins, segment s) { return &(ins2segment[ins] = s); }

    const segment* get_segment(instruction_ref ins) const
    {
        auto it = ins2segment.find(ins);
        if(it == ins2segment.end())
            return nullptr;
        return &it->second;
    }

    // Remove segment for an instruction
    void remove(instruction_ref ins)
    {
        auto it = ins2segment.find(ins);
        if(it != ins2segment.end())
        {
            ins2segment.erase(it);
        }
    }

    std::size_t max()
    {
        std::size_t n = 0;
        for(auto&& pp : ins2segment)
        {
            auto seg = pp.second;
            n        = std::max(n, seg.second);
        }
        return n;
    }

    template <class Iterator>
    static bool overlaps(Iterator first, Iterator last, const segment& s)
    {
        return std::any_of(first, last, [&](auto&& t) { return is_overlap(s, t); });
    }

    static bool overlaps(const std::set<segment>& segments, const segment& s)
    {
        return overlaps(segments.begin(), segments.end(), s);
    }

    static auto find_gap(const std::set<segment>& segments, std::size_t n)
    {
        std::size_t max_end = 0;
        return std::adjacent_find(segments.begin(), segments.end(), [&](segment x, segment y) {
            if(x.second < max_end)
                return false;
            max_end = x.second;
            if(is_overlap(x, y))
                return false;
            assert(y.first >= x.second);
            auto k = y.first - x.second;
            return (k >= n);
        });
    }

    static std::size_t max_type_size(const shape& s)
    {
        return std::accumulate(
            s.sub_shapes().begin(),
            s.sub_shapes().end(),
            s.type_size(),
            [](auto size, const auto& sub) { return std::max(size, max_type_size(sub)); });
    }

    static std::size_t compute_alignment(instruction_ref ins)
    {
        auto alignment = max_type_size(ins->get_shape());
        // A rough estimate for the total number of elements
        auto n = ins->get_shape().bytes() / alignment;
        // Check for vectorized alignment
        if(n > 4)
        {
            auto d = n % 4;
            if(d == 0)
                alignment *= 4;
            if(d == 2)
                alignment *= 2;
        }
        return alignment;
    }

    static segment
    next_segment(std::set<segment>& segments, instruction_ref ins, std::size_t alignment)
    {
        assert(ins->get_shape().bytes() > 0);
        // Compute alignment
        auto n = 1 + (ins->get_shape().bytes() - 1) / alignment;
        assert(n > 0);
        auto start = 0;
        // Insert at end if it cant fit at the begining
        if(segments.empty() or segments.begin()->first <= n)
        {
            auto it = find_gap(segments, n);
            if(it == segments.end())
                it = std::max_element(segments.begin(), segments.end(), [&](segment x, segment y) {
                    return x.second < y.second;
                });
            if(it != segments.end())
                start = it->second;
        }
        auto s = segment{start, start + n};
        assert(not overlaps(segments, s));
        segments.insert(s);
        return s;
    }

    static std::unordered_map<instruction_ref, int>
    create_allocation_index(const module& m, const instruction_set_map& conflict_table)
    {
        std::unordered_map<instruction_ref, int> result;
        int i = 0;
        for(auto ins : iterator_for(m))
        {
            if(not contains(conflict_table, ins))
                continue;
            result[ins] = i++;
        }
        return result;
    }

    // Build the allocation_color class from the conflict_table
    static allocation_segment
    build(const module& m, const instruction_set_map& conflict_table, std::size_t alignment)
    {
        allocation_segment as{};
        std::vector<instruction_ref> conflict_queue;
        // Add all allocations to the conflict_queue
        std::transform(conflict_table.begin(),
                       conflict_table.end(),
                       std::back_inserter(conflict_queue),
                       [](auto&& pp) { return pp.first; });
        auto alloc_index = create_allocation_index(m, conflict_table);
        // Sort the conflict queue so we process the allocation with the most
        // number of adjacent allocations first
        std::sort(conflict_queue.begin(), conflict_queue.end(), by(std::greater<>{}, [&](auto x) {
                      return std::make_tuple(
                          conflict_table.at(x).size(), x->get_shape().bytes(), alloc_index.at(x));
                  }));
        // Process the conflict_queue, we refer to the current allocation as
        // the parent and the adjacent allocations as children
        for(auto parent : conflict_queue)
        {
            // Sort children by size
            std::vector<instruction_ref> children(conflict_table.at(parent).begin(),
                                                  conflict_table.at(parent).end());
            std::sort(children.begin(), children.end(), by(std::less<>{}, [&](auto x) {
                          return std::make_tuple(x->get_shape().bytes(), alloc_index.at(x));
                      }));
            assert(not contains(children, parent));
            // This set is to track the segments already processed
            std::set<segment> segments;
            // Add all segments for the children to the segments already processed
            transform_if(
                children.begin(),
                children.end(),
                std::inserter(segments, segments.begin()),
                [&](auto child) { return as.get_segment(child); },
                [&](auto child) { return *as.get_segment(child); });
            assert(as.get_segment(parent) == nullptr);
            as.add_segment(parent, next_segment(segments, parent, alignment));
        }
        // Reduce the number of segments
        for(std::size_t n = 0; n < 3; n++)
        {
            for(auto parent : conflict_queue)
            {
                auto children = conflict_table.at(parent);
                // This set is to track the segments already processed
                std::set<segment> segments;
                // Add all segments for the children to the segments already processed
                transform_if(
                    children.begin(),
                    children.end(),
                    std::inserter(segments, segments.begin()),
                    [&](auto child) { return as.get_segment(child); },
                    [&](auto child) { return *as.get_segment(child); });
                // Get the segment for the parent
                const auto* parent_segment = as.get_segment(parent);
                assert(parent_segment != nullptr);
                auto s = next_segment(segments, parent, alignment);
                if(s != *parent_segment and s.second <= as.max())
                {
                    as.add_segment(parent, s);
                }
            }
        }
        return as;
    }
};

static std::size_t find_max_alignment(const module& m, const std::string& allocation_op)
{
    std::size_t alignment = 1;
    for(auto ins : iterator_for(m))
    {
        if(ins->name() != allocation_op)
            continue;
        alignment = std::max(allocation_segment::compute_alignment(ins), alignment);
    }
    return alignment;
}

void memory_coloring::apply(module& m) const
{
    const std::size_t alignment = find_max_alignment(m, allocation_op);
    auto conflict_table         = build_conflict_table(m, allocation_op);
    auto as                     = allocation_segment::build(m, conflict_table, alignment);

    // All allocations should have a segment
    assert(std::all_of(conflict_table.begin(), conflict_table.end(), [&](auto&& pp) {
        return as.get_segment(pp.first);
    }));
    // Adjacent allocations should not have overlapping segments
    assert(std::none_of(conflict_table.begin(), conflict_table.end(), [&](auto&& pp) {
        auto* x = as.get_segment(pp.first);
        return std::any_of(pp.second.begin(), pp.second.end(), [&](auto ins) {
            auto* y = as.get_segment(ins);
            assert(x and y);
            return is_overlap(*x, *y);
        });
    }));

    // Print out segments
    if(enabled(MIGRAPHX_DEBUG_MEMORY_COLORING{}))
    {
        for(auto&& pp : conflict_table)
        {
            std::cout << "------- conflict -------" << std::endl;
            auto s1 = as.ins2segment.at(pp.first);
            std::cout << s1.first << ", " << s1.second << ": ";
            m.debug_print(pp.first);
            for(auto ins : pp.second)
            {
                auto s2 = as.ins2segment.at(ins);
                std::cout << s2.first << ", " << s2.second << ": ";
                m.debug_print(ins);
            }
        }
    }

    // Total memory
    std::size_t n = as.max() * alignment;
    // Replace allocations
    auto mem = m.add_parameter("scratch", shape{shape::int8_type, {n}});
    for(auto&& [ins, seg] : as.ins2segment)
    {
        assert(ins->name() == allocation_op);
        auto s             = ins->get_shape();
        std::size_t offset = seg.first * alignment;
        assert(offset < n);
        m.replace_instruction(ins, op::load{s, offset}, mem);
    }
    // Replace zero allocation
    for(auto ins : iterator_for(m))
    {
        if(ins->name() != allocation_op)
            continue;
        assert(ins->get_shape().bytes() == 0);
        m.replace_instruction(ins, op::load{ins->get_shape(), 0}, mem);
    }
    // Remove scratch parameter if its not used
    if(mem->outputs().empty())
    {
        m.remove_instruction(mem);
    }
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
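The pass is a classic interference-graph allocator: liveness computes, for each allocation, the set of allocations live at the same time; build then greedily assigns byte ranges so that conflicting buffers never overlap, processing the most-constrained allocations first. A much-simplified standalone sketch of the placement step (the real next_segment also searches for gaps between existing segments and works in alignment units rather than bytes):

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <set>
#include <utility>

// A segment is a half-open byte range [first, second).
using segment = std::pair<std::size_t, std::size_t>;

bool is_overlap(segment x, segment y)
{
    return std::max(x.first, y.first) < std::min(x.second, y.second);
}

// Place a new n-byte buffer at the smallest offset past every segment it
// conflicts with (first-fit at the end; no gap reuse in this sketch).
segment place_after(const std::set<segment>& taken, std::size_t n)
{
    std::size_t start = 0;
    for(auto&& s : taken)
        start = std::max(start, s.second);
    return {start, start + n};
}

int main()
{
    std::set<segment> taken;
    taken.insert(place_after(taken, 32)); // e.g. float[8]  -> [0, 32)
    taken.insert(place_after(taken, 20)); // e.g. half[10]  -> [32, 52)
    for(auto&& x : taken)
        for(auto&& y : taken)
            assert(x == y or not is_overlap(x, y));
}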
src/onnx/include/migraphx/onnx/onnx_parser.hpp

@@ -113,7 +113,8 @@ struct onnx_parser
     void parse_from(std::istream& is, std::string name = "");
     void parse_from(const void* data, std::size_t size);
-    void parse_graph(module* mod, const onnx::GraphProto& graph);
+    std::vector<instruction_ref>
+    parse_graph(module* mod, const onnx::GraphProto& graph, bool inlining = false);
     literal parse_value(const onnx::AttributeProto& attr) const;
     literal parse_tensor(const onnx::TensorProto& t) const;
     shape parse_type(const onnx::TypeProto& t, const std::vector<std::size_t>& input_dims) const;
src/onnx/onnx_parser.cpp

@@ -220,7 +220,7 @@ void onnx_parser::parse_from(std::istream& is, std::string name)
         if(model.has_graph())
         {
-            this->parse_graph(mm, model.graph());
+            (void)this->parse_graph(mm, model.graph());
         }
     }
     else

@@ -240,7 +240,7 @@ void onnx_parser::parse_from(const void* data, std::size_t size)
         if(model.has_graph())
         {
-            this->parse_graph(mm, model.graph());
+            (void)this->parse_graph(mm, model.graph());
         }
     }
     else

@@ -264,7 +264,8 @@ int64_t onnx_parser::get_opset_version(const onnx::ModelProto& model)
     return version;
 }

-void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
+std::vector<instruction_ref>
+onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlining)
 {
     std::unordered_map<std::string, instruction_ref> mod_insts;
     for(auto&& f : graph.initializer())

@@ -372,11 +373,16 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
                    std::back_inserter(output_ins),
                    [&](const auto& name) { return instructions[name]; });
-    // add the return instuction
-    mod->add_return(output_ins);
+    if(not inlining)
+    {
+        // add the return instuction
+        mod->add_return(output_ins);
+        // Remove instructions added in module (this is turned off for subgraph inlining)
+        erase_if(instructions, [&](auto&& p) { return mod->has_instruction(p.second); });
+    }

-    // remove instructions added in this mod
-    erase_if(instructions, [&](auto&& p) { return mod->has_instruction(p.second); });
+    return output_ins;
 }

 literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const
src/onnx/parse_if.cpp

@@ -51,6 +51,24 @@ struct parse_if : op_parser<parse_if>
                            " condition input can have only one element!");
         }

+        // Fold instruction if condition is constant thus can be evaled
+        // prior to inference
+        if(args.front()->can_eval())
+        {
+            auto cond_arg = args.front()->eval();
+            auto* mod     = info.mod;
+            // then branch
+            if(cond_arg.at<bool>())
+            {
+                return parser.parse_graph(mod, then_graph, true);
+            }
+            // else branch
+            else
+            {
+                return parser.parse_graph(mod, else_graph, true);
+            }
+        }
+
         std::string then_name = info.name + "_if";
         module_ref then_mdl   = parser.prog.create_module(then_name);

@@ -58,10 +76,10 @@ struct parse_if : op_parser<parse_if>
         module_ref else_mdl = parser.prog.create_module(else_name);

         // parse the then sub_graph
-        parser.parse_graph(then_mdl, then_graph);
+        (void)parser.parse_graph(then_mdl, then_graph);

         // parse_the else sub_graph
-        parser.parse_graph(else_mdl, else_graph);
+        (void)parser.parse_graph(else_mdl, else_graph);

         auto then_out_shapes = then_mdl->get_output_shapes();
         auto else_out_shapes = else_mdl->get_output_shapes();
src/onnx/parse_loop.cpp

@@ -71,7 +71,7 @@ struct parse_loop : op_parser<parse_loop>
         module_ref sub_mod = parser.prog.create_module(mod_name);

         // parse the sub_graph
-        parser.parse_graph(sub_mod, sub_graph);
+        (void)parser.parse_graph(sub_mod, sub_graph);

         auto ret = info.add_instruction(
             make_op("loop", {{"max_iterations", max_iterations}}), args, {sub_mod});
src/pass_manager.cpp

@@ -39,6 +39,7 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_PASSES)
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TIME_PASSES)

 void validate_pass(module& mod, const pass& p, tracer trace)
 {

@@ -94,19 +95,19 @@ struct module_pm : module_pass_manager
     virtual void run_pass(const pass& p) override
     {
         assert(mod);
-        timer ts{};
-        using seconds = std::chrono::duration<double>;
         trace("Module: ", mod->name(), ", Pass: ", p.name());
-        const double t1 = ts.record<seconds>();
         assert(mod->validate() == mod->end());
-        p.apply(*this);
+        if(enabled(MIGRAPHX_TIME_PASSES{}))
+        {
+            using milliseconds = std::chrono::duration<double, std::milli>;
+            auto ms            = time<milliseconds>([&] { p.apply(*this); });
+            std::cout << p.name() << ": " << ms << "ms\n";
+        }
+        else
+        {
+            p.apply(*this);
+        }
         trace(*mod);
         validate_pass(*mod, p, *t);
-        const double t2 = ts.record<seconds>();
-        trace("Pass: ", p.name(), " completed in (s): ", (t2 - t1));
     }
 };
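With this change, per-pass timing happens only when MIGRAPHX_TIME_PASSES is set, and goes through a generic time<Duration>(f) helper instead of explicit timer bookkeeping. A minimal standalone sketch of that pattern, assuming a stand-in helper rather than MIGraphX's own:

#include <chrono>
#include <iostream>

// Run the callable once and return how long it took, in the requested
// chrono duration type (sketch of the time<Duration>(f) pattern above).
template <class Duration, class F>
auto time(F f)
{
    auto start = std::chrono::steady_clock::now();
    f();
    auto stop = std::chrono::steady_clock::now();
    return std::chrono::duration_cast<Duration>(stop - start).count();
}

int main()
{
    using milliseconds = std::chrono::duration<double, std::milli>;
    auto ms = time<milliseconds>([] {
        volatile long x = 0;
        for(long i = 0; i < 1000000; i++)
            x += i;
    });
    std::cout << "pass: " << ms << "ms\n";
}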
src/program.cpp

@@ -336,7 +336,8 @@ std::vector<argument> generic_eval(const module* mod,
                     if(not ins->get_shape().dynamic() and
                        param.get_shape() != ins->get_shape())
                     {
                         MIGRAPHX_THROW("Incorrect shape {" + to_string(param.get_shape()) +
-                                       "} for parameter: " + param_name);
+                                       "} for parameter: " + param_name +
+                                       " should be: " + to_string(ins->get_shape()));
                     }
                     return param;
                 }));
test/memory_coloring_test.cpp

@@ -691,7 +691,7 @@ TEST_CASE(test38)
     auto p83 = m.add_instruction(pass_op{}, p78, p77);
     m.add_instruction(pass_op{}, output, p83, p63);
     run_pass(m);
-    CHECK(m.get_parameter_shape("scratch").bytes() == 7225344); // Optimal solution is 6422528
+    CHECK(m.get_parameter_shape("scratch").bytes() == 6422528);
     CHECK(no_allocate(m));
 }

@@ -729,7 +729,7 @@ TEST_CASE(test39)
         run_pass(*smod);
     }
-    CHECK(mm->get_parameter_shape("scratch").bytes() == 4);
+    CHECK(mm->get_parameter_shape("scratch").bytes() == 1);
     CHECK(then_mod->get_parameter_shape("scratch").bytes() == 24);
     CHECK(else_mod->get_parameter_shape("scratch").bytes() == 24);
     CHECK(no_allocate(*mm));

@@ -3374,7 +3374,7 @@ TEST_CASE(rnn_dom)
     m.add_instruction(pass_op{}, moutput, mx250, mx249, mx248);
     run_pass(m);
-    CHECK(m.get_parameter_shape("scratch").bytes() == 1600);
+    CHECK(m.get_parameter_shape("scratch").bytes() == 1824); // Optimal is 1600
     CHECK(no_allocate(m));
     CHECK(is_disjoint({mx0, mx8}));
     CHECK(is_disjoint({mx0, mx8}));

@@ -3790,4 +3790,23 @@ TEST_CASE(literal_test)
     CHECK(lit == result);
 }

+TEST_CASE(test_tuple)
+{
+    migraphx::module m;
+    auto s1 = migraphx::shape{migraphx::shape::float_type, {8}};
+    auto s2 = migraphx::shape{migraphx::shape::half_type, {10}};
+    auto s  = migraphx::shape{{s1, s2}};
+    auto a1 = add_alloc(m, s);
+    auto m1 = m.add_instruction(pass_op{}, a1);
+    auto a2 = add_alloc(m, {migraphx::shape::float_type, {4}});
+    m.add_instruction(pass_op{}, a2, m1);
+    run_pass(m);
+    CHECK(m.get_parameter_shape("scratch").bytes() == 68);
+    CHECK(no_allocate(m));
+    CHECK(is_disjoint({a1, a2}));
+}
+
 int main(int argc, const char* argv[]) { test::run(argc, argv); }
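For reference, the 68-byte expectation in test_tuple follows from the shapes: the tuple allocation a1 needs 8·4 + 10·2 = 52 bytes (assuming its sub-shapes are packed back to back, which is what shape::bytes() reports for a tuple), a2 needs 4·4 = 16 bytes, and since a1's result is still live when a2 is used, the two must occupy disjoint segments: 52 + 16 = 68.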
test/onnx/gathernd_dyn_test.onnx 0 → 100644

File added
test/onnx/gen_onnx.py

@@ -2132,6 +2132,19 @@ def gathernd_test():
     return ([node], [x, i], [y])


+@onnx_test()
+def gathernd_dyn_test():
+    x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [None, 2])
+    i = helper.make_tensor_value_info('indices', TensorProto.INT64, [2, 2])
+    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2])
+
+    node = onnx.helper.make_node('GatherND',
+                                 inputs=['data', 'indices'],
+                                 outputs=['y'])
+
+    return ([node], [x, i], [y])
+
+
 @onnx_test()
 def gathernd_batch_dims_test():
     x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [2, 2, 2])

@@ -2498,6 +2511,58 @@ def if_else_test():
 @onnx_test()
 def if_else_test():
+    x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3])
+    y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 3])
+    then_out = onnx.helper.make_tensor_value_info('then_out', onnx.TensorProto.FLOAT, [2, 3])
+    else_out = onnx.helper.make_tensor_value_info('else_out', onnx.TensorProto.FLOAT, [2, 3])
+
+    xt = np.ones((2, 3)).astype(np.float)
+    xt_tensor = helper.make_tensor(name='xt',
+                                   data_type=TensorProto.FLOAT,
+                                   dims=xt.shape,
+                                   vals=xt.flatten().astype(np.float32))
+    yt = np.random.randn(2, 3).astype(np.float)
+    yt_tensor = helper.make_tensor(name='yt',
+                                   data_type=TensorProto.FLOAT,
+                                   dims=yt.shape,
+                                   vals=yt.flatten().astype(np.float32))
+
+    then_add_node = onnx.helper.make_node('Add',
+                                          inputs=['x', 'xt'],
+                                          outputs=['then_out'])
+    else_mul_node = onnx.helper.make_node('Mul',
+                                          inputs=['y', 'yt'],
+                                          outputs=['else_out'])
+
+    then_body = onnx.helper.make_graph([then_add_node], 'then_body', [], [then_out])
+    else_body = onnx.helper.make_graph([else_mul_node], 'else_body', [], [else_out])
+
+    cond_tensor = onnx.helper.make_tensor_value_info("cond", onnx.TensorProto.BOOL, [1])
+    res = onnx.helper.make_tensor_value_info('res', TensorProto.FLOAT, [])
+
+    node = onnx.helper.make_node('If',
+                                 inputs=['cond'],
+                                 outputs=['res'],
+                                 then_branch=then_body,
+                                 else_branch=else_body)
+
+    return ([node], [x, y, cond_tensor], [res], [xt_tensor, yt_tensor])
+
+
+@onnx_test()
+def if_else_test_inlined():
     x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3])
     y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 3])
     then_out = onnx.helper.make_tensor_value_info('then_out', onnx.TensorProto.FLOAT, [2, 3])

@@ -2547,6 +2612,149 @@ def if_else_test():
     return ([node], [x, y], [res], [cond_tensor, xt_tensor, yt_tensor])


+@onnx_test()
+def if_then_else_multi_output_shapes_inlined_test():
+    x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3, 1])
+    y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 3])
+    then_out = onnx.helper.make_tensor_value_info('then_out', onnx.TensorProto.FLOAT, [2, 3, 1])
+    then_out2 = onnx.helper.make_tensor_value_info('then_out2', onnx.TensorProto.FLOAT, [2, 3, 1])
+    else_out = onnx.helper.make_tensor_value_info('else_out', onnx.TensorProto.FLOAT, [2, 3])
+    else_out2 = onnx.helper.make_tensor_value_info('else_out2', onnx.TensorProto.FLOAT, [2, 3])
+
+    xt = np.ones((2, 3, 1)).astype(np.float)
+    xt_tensor = helper.make_tensor(name='xt',
+                                   data_type=TensorProto.FLOAT,
+                                   dims=xt.shape,
+                                   vals=xt.flatten().astype(np.float32))
+    yt = np.random.randn(2, 3).astype(np.float)
+    yt_tensor = helper.make_tensor(name='yt',
+                                   data_type=TensorProto.FLOAT,
+                                   dims=yt.shape,
+                                   vals=yt.flatten().astype(np.float32))
+
+    then_add_node = onnx.helper.make_node('Add', inputs=['x', 'xt'], outputs=['then_out'])
+    then_add_node2 = onnx.helper.make_node('Add', inputs=['x', 'x'], outputs=['then_out2'])
+    else_mul_node = onnx.helper.make_node('Mul', inputs=['y', 'yt'], outputs=['else_out'])
+    else_sub_node = onnx.helper.make_node('Sub', inputs=['y', 'yt'], outputs=['else_out2'])
+
+    then_body = onnx.helper.make_graph([then_add_node, then_add_node2], 'then_body', [],
+                                       [then_out, then_out2])
+    else_body = onnx.helper.make_graph([else_mul_node, else_sub_node], 'else_body', [],
+                                       [else_out, else_out2])
+
+    cond = np.array([1]).astype(np.bool)
+    cond_tensor = helper.make_tensor(name="cond",
+                                     data_type=TensorProto.BOOL,
+                                     dims=cond.shape,
+                                     vals=cond.astype(bool))
+
+    res1 = onnx.helper.make_tensor_value_info('res1', TensorProto.FLOAT, [])
+    res2 = onnx.helper.make_tensor_value_info('res2', TensorProto.FLOAT, [])
+
+    node = onnx.helper.make_node('If',
+                                 inputs=['cond'],
+                                 outputs=['res1', 'res2'],
+                                 then_branch=then_body,
+                                 else_branch=else_body)
+
+    return ([node], [x, y], [res1, res2], [cond_tensor, xt_tensor, yt_tensor])
+
+
+@onnx_test()
+def if_then_else_multi_output_shapes_test():
+    x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3, 1])
+    y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 3, 1])
+    then_out = onnx.helper.make_tensor_value_info('then_out', onnx.TensorProto.FLOAT, [2, 3, 1])
+    then_out2 = onnx.helper.make_tensor_value_info('then_out2', onnx.TensorProto.FLOAT, [2, 3, 1])
+    else_out = onnx.helper.make_tensor_value_info('else_out', onnx.TensorProto.FLOAT, [2, 3, 1])
+    else_out2 = onnx.helper.make_tensor_value_info('else_out2', onnx.TensorProto.FLOAT, [2, 3, 1])
+
+    xt = np.ones((2, 3, 1)).astype(np.float)
+    xt_tensor = helper.make_tensor(name='xt',
+                                   data_type=TensorProto.FLOAT,
+                                   dims=xt.shape,
+                                   vals=xt.flatten().astype(np.float32))
+    yt = np.random.randn(2, 3, 1).astype(np.float)
+    yt_tensor = helper.make_tensor(name='yt',
+                                   data_type=TensorProto.FLOAT,
+                                   dims=yt.shape,
+                                   vals=yt.flatten().astype(np.float32))
+
+    then_add_node = onnx.helper.make_node('Add', inputs=['x', 'xt'], outputs=['then_out'])
+    then_add_node2 = onnx.helper.make_node('Add', inputs=['x', 'x'], outputs=['then_out2'])
+    else_mul_node = onnx.helper.make_node('Mul', inputs=['y', 'yt'], outputs=['else_out'])
+    else_sub_node = onnx.helper.make_node('Sub', inputs=['y', 'yt'], outputs=['else_out2'])
+
+    then_body = onnx.helper.make_graph([then_add_node, then_add_node2], 'then_body', [],
+                                       [then_out, then_out2])
+    else_body = onnx.helper.make_graph([else_mul_node, else_sub_node], 'else_body', [],
+                                       [else_out, else_out2])
+
+    cond_tensor = onnx.helper.make_tensor_value_info("cond", onnx.TensorProto.BOOL, [1])
+    res1 = onnx.helper.make_tensor_value_info('res1', TensorProto.FLOAT, [])
+    res2 = onnx.helper.make_tensor_value_info('res2', TensorProto.FLOAT, [])
+
+    node = onnx.helper.make_node('If',
+                                 inputs=['cond'],
+                                 outputs=['res1', 'res2'],
+                                 then_branch=then_body,
+                                 else_branch=else_body)
+
+    return ([node], [x, y, cond_tensor], [res1, res2], [xt_tensor, yt_tensor])
+
+
 @onnx_test()
 def if_literal_test():
     then_out = onnx.helper.make_tensor_value_info('then_out',

@@ -2807,6 +3015,59 @@ def if_then_test():
 @onnx_test()
 def if_then_test():
+    x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3])
+    y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 3])
+    then_out = onnx.helper.make_tensor_value_info('then_out', onnx.TensorProto.FLOAT, [2, 3])
+    else_out = onnx.helper.make_tensor_value_info('else_out', onnx.TensorProto.FLOAT, [2, 3])
+
+    xt = np.ones((2, 3)).astype(np.float)
+    xt_tensor = helper.make_tensor(name='xt',
+                                   data_type=TensorProto.FLOAT,
+                                   dims=xt.shape,
+                                   vals=xt.flatten().astype(np.float32))
+    yt = np.random.randn(2, 3).astype(np.float)
+    yt_tensor = helper.make_tensor(name='yt',
+                                   data_type=TensorProto.FLOAT,
+                                   dims=yt.shape,
+                                   vals=yt.flatten().astype(np.float32))
+
+    then_add_node = onnx.helper.make_node('Add', inputs=['x', 'xt'], outputs=['then_out'])
+    else_mul_node = onnx.helper.make_node('Mul', inputs=['y', 'yt'], outputs=['else_out'])
+
+    then_body = onnx.helper.make_graph([then_add_node], 'then_body', [], [then_out])
+    else_body = onnx.helper.make_graph([else_mul_node], 'else_body', [], [else_out])
+
+    cond_tensor = onnx.helper.make_tensor_value_info("cond", onnx.TensorProto.BOOL, [1])
+    res = onnx.helper.make_tensor_value_info('res', TensorProto.FLOAT, [])
+
+    node = onnx.helper.make_node('If',
+                                 inputs=['cond'],
+                                 outputs=['res'],
+                                 then_branch=then_body,
+                                 else_branch=else_body)
+
+    return ([node], [x, y, cond_tensor], [res], [xt_tensor, yt_tensor])
+
+
+@onnx_test()
+def if_then_test_inlined():
     x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3])
     y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 3])
     then_out = onnx.helper.make_tensor_value_info('then_out', onnx.TensorProto.FLOAT, [2, 3])

@@ -5707,6 +5968,24 @@ def scatternd_test():
     return ([node], [data, indices, updates], [output])


+@onnx_test()
+def scatternd_dyn_test():
+    data = helper.make_tensor_value_info('data', TensorProto.FLOAT, [None, 2, 2])
+    indices = helper.make_tensor_value_info('indices', TensorProto.INT64, [None, 1, 2])
+    updates = helper.make_tensor_value_info('updates', TensorProto.FLOAT, [None, 1, 2])
+    output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [None, 2, 2])
+
+    node = onnx.helper.make_node('ScatterND',
+                                 inputs=['data', 'indices', 'updates'],
+                                 outputs=['output'])
+
+    return ([node], [data, indices, updates], [output])
+
+
 @onnx_test()
 def selu_test():
     x = helper.make_tensor_value_info('x', TensorProto.DOUBLE, [2, 3])
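The new gathernd_dyn_test and scatternd_dyn_test generators use None dimensions to produce dynamic-shape models for the dynamic compute_shape paths added above; the checked-in test/onnx/gathernd_dyn_test.onnx is presumably the output of running this generator script, which appears to be how the repository's .onnx fixtures are produced.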