gaoqiong / MIGraphX · Commits

Commit 0ffcccbc (unverified)
Authored Feb 03, 2023 by Chris Austen; committed by GitHub on Feb 03, 2023

Merge branch 'develop' into jit-reduce-reg

Parents: 4f12db9e, 2b5c5f5e
Changes: 32 files in total; this page shows 20 changed files with 923 additions and 75 deletions (+923 −75).
Files changed on this page:

  .clang-tidy                                      +2   −0
  .github/workflows/ci.yaml                        +4   −3
  Dockerfile                                       +1   −1
  hip-clang.docker                                 +1   −1
  src/CMakeLists.txt                               +1   −2
  src/include/migraphx/half.hpp                    +2   −2
  src/include/migraphx/instruction_ref.hpp         +2   −2
  src/include/migraphx/memory_coloring.hpp         +1   −1
  src/include/migraphx/op/gathernd.hpp             +88  −17
  src/include/migraphx/op/scatternd_op.hpp         +66  −21
  src/memory_coloring.cpp                          +405 −0
  src/onnx/include/migraphx/onnx/onnx_parser.hpp   +2   −1
  src/onnx/onnx_parser.cpp                         +13  −7
  src/onnx/parse_if.cpp                            +20  −2
  src/onnx/parse_loop.cpp                          +1   −1
  src/pass_manager.cpp                             +11  −10
  src/program.cpp                                  +2   −1
  test/memory_coloring_test.cpp                    +22  −3
  test/onnx/gathernd_dyn_test.onnx                 +0   −0
  test/onnx/gen_onnx.py                            +279 −0
.clang-tidy

@@ -9,6 +9,8 @@ CheckOptions:
     value: risky
   - key: modernize-loop-convert.NamingStyle
     value: lower_case
+  - key: misc-const-correctness.AnalyzeValues
+    value: 'false'
   - key: performance-unnecessary-copy-initialization.AllowedTypes
     value: 'shape'
   - key: performance-unnecessary-value-param.AllowedTypes
.github/workflows/ci.yaml

@@ -32,7 +32,8 @@ jobs:
       # In this step, this action saves a list of existing images,
       # the cache is created without them in the post run.
       # It also restores the cache if it exists.
-    - uses: satackey/action-docker-layer-caching@v0.0.11
+    # name: Docker Layer Caching2
+    - uses: jpribyl/action-docker-layer-caching@v0.1.1
       # Ignore the failure of a step and avoid terminating the job.
       continue-on-error: true
@@ -81,7 +82,7 @@ jobs:
       # In this step, this action saves a list of existing images,
       # the cache is created without them in the post run.
       # It also restores the cache if it exists.
-    - uses: satackey/action-docker-layer-caching@v0.0.11
+    - uses: jpribyl/action-docker-layer-caching@v0.1.1
       # Ignore the failure of a step and avoid terminating the job.
       continue-on-error: true
@@ -126,7 +127,7 @@ jobs:
       # In this step, this action saves a list of existing images,
       # the cache is created without them in the post run.
       # It also restores the cache if it exists.
-    - uses: satackey/action-docker-layer-caching@v0.0.11
+    - uses: jpribyl/action-docker-layer-caching@v0.1.1
       # Ignore the failure of a step and avoid terminating the job.
       continue-on-error: true
Dockerfile

@@ -10,7 +10,7 @@ RUN apt-get update && apt-get install -y gnupg2 --no-install-recommends curl &&
     curl -sL http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add -

 # Add rocm repository
-RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/5.3/ ubuntu main > /etc/apt/sources.list.d/rocm.list'
+RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/5.4.2/ ubuntu main > /etc/apt/sources.list.d/rocm.list'

 # Install dependencies
 RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \
hip-clang.docker

@@ -6,7 +6,7 @@ ARG PREFIX=/usr/local
 RUN dpkg --add-architecture i386

 # Add rocm repository
-RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/5.3/ ubuntu main > /etc/apt/sources.list.d/rocm.list'
+RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/5.4.2/ ubuntu main > /etc/apt/sources.list.d/rocm.list'

 # Install dependencies
 RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \
src/CMakeLists.txt (mode 100755 → 100644)

@@ -58,6 +58,7 @@ add_library(migraphx
     layout_nhwc.cpp
     load_save.cpp
     make_op.cpp
+    memory_coloring.cpp
     module.cpp
     msgpack.cpp
     normalize_attributes.cpp
@@ -65,8 +66,6 @@ add_library(migraphx
     op_enums.cpp
     operation.cpp
     optimize_module.cpp
-    opt/memory_coloring.cpp
-    opt/memory_coloring_impl.cpp
     pad_calc.cpp
     pass_manager.cpp
     permutation.cpp
src/include/migraphx/half.hpp

@@ -58,12 +58,12 @@ using deduce = typename detail::deduce<T>::type;
 namespace std {

 template <class T>
-struct common_type<migraphx::half, T> : std::common_type<float, T>
+struct common_type<migraphx::half, T> : std::common_type<float, T> // NOLINT
 {
 };

 template <class T>
-struct common_type<T, migraphx::half> : std::common_type<float, T>
+struct common_type<T, migraphx::half> : std::common_type<float, T> // NOLINT
 {
 };
src/include/migraphx/instruction_ref.hpp

@@ -41,7 +41,7 @@ migraphx::instruction* as_address(const instruction_ref& ins) noexcept;
 namespace std {

 template <>
-struct hash<migraphx::instruction_ref>
+struct hash<migraphx::instruction_ref> // NOLINT
 {
     using argument_type = migraphx::instruction_ref;
     using result_type   = std::size_t;
@@ -52,7 +52,7 @@ struct hash<migraphx::instruction_ref>
 };

 template <>
-struct equal_to<migraphx::instruction_ref>
+struct equal_to<migraphx::instruction_ref> // NOLINT
 {
     using argument_type = migraphx::instruction_ref;
     using result_type   = bool;
src/include/migraphx/memory_coloring.hpp

@@ -39,7 +39,7 @@ struct memory_coloring
 {
     std::string allocation_op{};
     bool verify = false;
-    std::string name() const { return "memory coloring"; }
+    std::string name() const { return "memory_coloring"; }
     void apply(module& m) const;
 };
src/include/migraphx/op/gathernd.hpp

@@ -1,7 +1,7 @@
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -25,6 +25,7 @@
 #define MIGRAPHX_GUARD_OPERATORS_GATHERND_HPP

 #include <migraphx/check_shapes.hpp>
+#include <migraphx/dyn_output.hpp>
 #include <migraphx/shape_for_each.hpp>
 #include <migraphx/par_for.hpp>
 #include <migraphx/argument.hpp>
@@ -47,33 +48,103 @@ struct gathernd
     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(2);
-        auto r = inputs.front().lens().size();
-        auto q = inputs.back().lens().size();
-        auto k = inputs.back().lens().back();
+        check_shapes{inputs, *this, true}.has(2);
+        auto i_shape    = inputs.back();
+        auto data_shape = inputs.front();
+        auto r          = data_shape.ndim();
+        auto q          = i_shape.ndim();
+        size_t k;
+        if(i_shape.dynamic())
+        {
+            // the rank of the output is a function of k, so it must be fixed.
+            if(not i_shape.dyn_dims().back().is_fixed())
+            {
+                MIGRAPHX_THROW(
+                    "GATHERND: last dimension of indices tensor must be fixed (min=max)");
+            }
+            k = i_shape.dyn_dims().back().min;
+        }
+        else
+            k = i_shape.lens().back();
+
+        // Begin input validation checks.
+        int output_ndim = int(q) + r - k - batch_dims - 1;
         if(k > r - batch_dims)
         {
             MIGRAPHX_THROW("GATHERND: Indices of length " + std::to_string(k) +
                            " cannot be used to access data of rank " +
                            std::to_string(r - batch_dims));
         }
-        auto indices_lens_iter = inputs.back().lens().begin();
-        auto output_lens_size  = q + r - k - batch_dims - 1;
-        std::vector<std::size_t> output_lens(output_lens_size);
-        std::copy(indices_lens_iter, indices_lens_iter + (q - 1), output_lens.begin());
-        if(k < r - batch_dims)
+        if(batch_dims >= q or batch_dims >= r)
         {
-            auto data_lens = inputs.front().lens();
-            std::copy(
-                data_lens.begin() + batch_dims + k, data_lens.end(), output_lens.begin() + q - 1);
+            MIGRAPHX_THROW("GATHERND: rank of an input cannot be less than batch_dims=" +
+                           std::to_string(batch_dims));
         }
-        shape output_shape{inputs.front().type(), output_lens};
-        return output_shape;
+        if(output_ndim < 0)
+        {
+            MIGRAPHX_THROW("GATHERND: Indices too large for static data input: k=" +
+                           std::to_string(k));
+        }
+        if(migraphx::none_of(inputs, [](auto v) { return v.dynamic(); }))
+        {
+            auto indices_lens_iter = i_shape.lens().begin();
+            // A rank 0 output is a scalar
+            if(output_ndim == 0)
+                return shape{data_shape.type(), {1}};
+            // Part of the output shape comes from indices tensor, part from data tensor
+            std::vector<std::size_t> output_lens(output_ndim);
+            std::copy(indices_lens_iter, indices_lens_iter + (q - 1), output_lens.begin());
+            // fill the rest of output shape from data tensor
+            if(k + batch_dims < r)
+            {
+                auto data_lens = data_shape.lens();
+                std::copy(data_lens.begin() + batch_dims + k,
+                          data_lens.end(),
+                          output_lens.begin() + q - 1);
+            }
+            shape output_shape{data_shape.type(), output_lens};
+            return output_shape;
+        }
+        else
+        {
+            // If one or both inputs are dynamic shapes, the output is dynamic.
+            // Make both inputs dynamic to simplify computations.
+            data_shape = data_shape.to_dynamic();
+            i_shape    = i_shape.to_dynamic();
+            // A rank 0 output is a scalar
+            if(output_ndim == 0)
+                return shape(data_shape.type(), {shape::dynamic_dimension({1, 1, 0})});
+            // Part of the output shape comes from indices tensor, part from data tensor
+            std::vector<shape::dynamic_dimension> output_dims(output_ndim);
+            std::copy(i_shape.dyn_dims().begin(),
+                      i_shape.dyn_dims().begin() + q - 1,
+                      output_dims.begin());
+            // fill the rest of output shape from data tensor
+            if(k + batch_dims < r)
+            {
+                auto data_dims = data_shape.dyn_dims();
+                std::copy(data_dims.begin() + batch_dims + k,
+                          data_dims.begin() + r,
+                          output_dims.begin() + q - 1);
+            }
+            shape output_shape(data_shape.type(), output_dims);
+            return output_shape;
+        }
     }

-    argument compute(const shape& output_shape, std::vector<argument> args) const
+    argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
     {
-        argument result{output_shape};
+        argument result{dyn_out.computed_shape};
         visit_all(result, args[0])([&](auto output, auto data) {
             args[1].visit([&](auto indices) {
                 auto indices_shape = indices.get_shape();
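The invariant both branches of the new compute_shape enforce is the output rank formula: with r = rank(data), q = rank(indices), and k the (now necessarily fixed) last indices dimension, the output has rank q + r - k - batch_dims - 1. A minimal Python sketch of the static-shape branch, not MIGraphX code (names are illustrative):

```python
# a minimal sketch of the static-shape branch above, not MIGraphX code:
# r = rank(data), q = rank(indices), k = indices.shape[-1]
def gathernd_output_shape(data_shape, indices_shape, batch_dims=0):
    r, q, k = len(data_shape), len(indices_shape), indices_shape[-1]
    if k > r - batch_dims:
        raise ValueError("indices of length %d cannot access data of rank %d"
                         % (k, r - batch_dims))
    output_ndim = q + r - k - batch_dims - 1
    if output_ndim == 0:
        return (1,)  # a rank-0 output is treated as a scalar
    # first q-1 dims come from indices; the rest from data past batch_dims + k
    return tuple(indices_shape[:q - 1]) + tuple(data_shape[batch_dims + k:])

# e.g. data of rank 2 with indices [2, 2]: k = 2 consumes both data axes,
# so the output is rank 1, matching the new gathernd_dyn_test below
assert gathernd_output_shape((3, 2), (2, 2)) == (2,)
```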
src/include/migraphx/op/scatternd_op.hpp

@@ -28,44 +28,89 @@
 #include <migraphx/check_shapes.hpp>
 #include <migraphx/argument.hpp>
 #include <migraphx/par_for.hpp>
+#include <migraphx/ranges.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace op {

+/**
+ * @brief
+ * N-dimensional Scatter operations. This struct is parent class to ops which differ in what
+ * formula is used to reduce (combine old and new values of) the scattered value. It was
+ * originally based on the Onnx ScatterND operation (see
+ * https://github.com/onnx/onnx/blob/main/docs/Operators.md#ScatterND) and is also similar to
+ * Numpy numpy.add.at().
+ *
+ * @tparam Derived a template parameter in the CRTP inheritance idiom, represents one of the
+ * child operations.
+ */
 template <class Derived>
 struct scatternd_op : op_name<Derived>
 {
+    /** Validate input shapes and return the correct output shape. For Scatter ops, the output
+     * is the same shape as the data tensor (first input), but cast to a standard shape.
+     */
     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(3);
-        auto r         = inputs.front().lens().size();
-        auto q         = inputs.at(1).lens().size();
-        auto k         = inputs.at(1).lens().back();
-        auto ind_lens  = inputs.at(1).lens();
-        auto upd_lens  = inputs.back().lens();
-        auto data_lens = inputs.front().lens();
+        check_shapes{inputs, *this, true}.has(3);
+        auto data_shape  = inputs.front();
+        auto index_shape = inputs.at(1);
+        auto upd_shape   = inputs.back();
+        auto r           = data_shape.ndim();
+        auto q           = index_shape.ndim();
+        size_t k;
+        if(index_shape.dynamic())
+        {
+            // the rank of the output is a function of k, so k must be fixed.
+            if(not index_shape.dyn_dims().back().is_fixed())
+            {
+                MIGRAPHX_THROW(
+                    "GATHERND: last dimension of indices tensor must be fixed (min=max)");
+            }
+            k = index_shape.dyn_dims().back().min;
+        }
+        else
+            k = index_shape.lens().back();
+
+        // Checks on the sizes of input tensors
+        if(q + r != upd_shape.ndim() + k + 1)
+            MIGRAPHX_THROW("ScatterND: ranks of inputs don't match. " + std::to_string(q) +
+                           " + " + std::to_string(r) + " - " + std::to_string(k) +
+                           " - 1 != " + std::to_string(upd_shape.ndim()));
         if(k > r)
             MIGRAPHX_THROW("ScatterND: index of size " + std::to_string(k) +
                            " is too large for tensor of rank " + std::to_string(r));
-        if(not(std::equal(ind_lens.begin(), ind_lens.begin() + q - 1, upd_lens.begin()) and
-               std::equal(data_lens.begin() + k, data_lens.end(), upd_lens.begin() + q - 1)))
-            MIGRAPHX_THROW("ScatterND: incorrect update shape. update.lens != indices.lens[0:q-1] "
-                           "++ data.lens[k:r-1]");
-        auto s = inputs.front();
-        if(s.broadcasted())
+
+        // Convert all static shape dimensions to dynamic so they can be compared.
+        // It's possible for some of the 3 inputs to be dynamic shapes and some static,
+        // but any dynamic dimension that's compared to a static dimension must be fixed.
+        auto ind_dims  = index_shape.to_dynamic().dyn_dims();
+        auto upd_dims  = upd_shape.to_dynamic().dyn_dims();
+        auto data_dims = data_shape.to_dynamic().dyn_dims();
+        // Check that corresponding portions of tensor shapes match.
+        if(not(std::equal(ind_dims.begin(), ind_dims.begin() + q - 1, upd_dims.begin()) and
+               std::equal(data_dims.begin() + k, data_dims.end(), upd_dims.begin() + q - 1)))
+            MIGRAPHX_THROW("ScatterND: incorrect update shape. Update dimensions must match "
+                           "indices and data.");
+        if(data_shape.dynamic())
+            return data_shape;
+        else if(data_shape.broadcasted())
         {
-            return {s.type(), s.lens()};
+            return {data_shape.type(), data_shape.lens()};
         }
         else
         {
-            return s.with_lens(s.lens());
+            return data_shape.with_lens(data_shape.lens());
         }
     }

-    argument compute(const shape& output_shape, std::vector<argument> args) const
+    argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
     {
-        argument result{output_shape};
+        argument result{dyn_out.computed_shape};
         auto& self = static_cast<const Derived&>(*this);
         visit_all(result, args[0], args[2])([&](auto output, auto data, auto updates) {
             std::copy(data.begin(), data.end(), output.begin());
@@ -74,8 +119,8 @@ struct scatternd_op : op_name<Derived>
             auto updates_std   = shape{updates_shape.type(), updates_shape.lens()};
             auto indices_shape = indices.get_shape();
             auto k             = indices_shape.lens().back();
-            auto q             = indices_shape.lens().size();
-            auto r             = output_shape.lens().size();
+            auto q             = indices_shape.ndim();
+            auto r             = dyn_out.computed_shape.ndim();
             par_for(updates_shape.elements(), [&](const auto i) {
                 auto updates_idx = updates_std.multi(i);
                 std::vector<std::size_t> indices_idx(q, 0);
@@ -89,7 +134,7 @@ struct scatternd_op : op_name<Derived>
                 std::copy(index_start, index_end, out_idx.begin());
                 std::copy(updates_idx.begin() + q - 1, updates_idx.end(), out_idx.begin() + k);
-                self.reduction()(output[output_shape.index(out_idx)], updates[i]);
+                self.reduction()(output[dyn_out.computed_shape.index(out_idx)], updates[i]);
             });
         });
     });
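The new rank check, q + r == rank(updates) + k + 1, falls out of the ScatterND semantics: each of the indices tensor's leading q-1 positions selects one length-k index row, and that row addresses a slice of data with r - k trailing axes that the matching update row overwrites. A NumPy reference sketch of those semantics (not the MIGraphX kernel; child ops swap '=' for other reductions):

```python
import numpy as np

# reference ScatterND: each length-k index row picks out a slice of data
# that the matching update row overwrites
def scatternd(data, indices, updates):
    out = data.copy()
    for i in np.ndindex(indices.shape[:-1]):
        out[tuple(indices[i])] = updates[i]
    return out

data = np.zeros((2, 2, 2), dtype=np.float32)               # r = 3
indices = np.array([[[0, 0]], [[1, 1]]])                   # q = 3, k = 2
updates = np.arange(4, dtype=np.float32).reshape(2, 1, 2)  # rank q - 1 + r - k = 3
# rank identity checked by the diff: q + r == rank(updates) + k + 1 -> 3 + 3 == 3 + 2 + 1
print(scatternd(data, indices, updates))
```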
src/memory_coloring.cpp (new file, mode 100644)

/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
#include <migraphx/memory_coloring.hpp>
#include <migraphx/module.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/algorithm.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <unordered_set>
#include <unordered_map>
#include <map>
#include <set>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DEBUG_MEMORY_COLORING);

using instruction_set     = std::unordered_set<instruction_ref>;
using instruction_set_map = std::unordered_map<instruction_ref, instruction_set>;

// This will do liveness analysis on the module, and it will call the
// function `f` with the instruction and the set of the other instructions
// that are live
template <class F>
void liveness(const module& m, F f)
{
    auto implicit_deps = m.calc_implicit_deps();
    instruction_set live_set;
    auto rp = reverse(m);
    for(auto rins : iterator_for(rp)) // NOLINT
    {
        // The base iterator is one ahead, so we need to use the previous iterator
        auto ins = std::prev(rins.base());
        // Add live variables
        auto add_live_variables = [&](const auto& inputs) {
            for(auto input : inputs)
            {
                auto i = instruction::get_output_alias(input);
                // Skip if variable comes from parent
                if(not m.has_instruction(i))
                    continue;
                live_set.insert(i);
            }
        };
        add_live_variables(ins->inputs());
        add_live_variables(implicit_deps[ins]);
        // Remove last usage
        auto it = live_set.find(ins);
        if(it != live_set.end())
        {
            live_set.erase(it);
            f(ins, live_set);
        }
    }
}

// This will build the conflict table or interference graph. This is
// essentially a map from one instruction to a set of instructions that are
// used together. Each instruction will be the allocation instruction.
instruction_set_map build_conflict_table(const module& m, std::string allocation_op)
{
    instruction_set_map conflict_table;
    liveness(m, [&](auto ins, auto live_set) {
        // Skip variables that aren't allocations
        if(ins->name() != allocation_op)
            return;
        // Skip zero allocations
        if(ins->get_shape().bytes() == 0)
            return;
        conflict_table[ins];
        for(auto i : live_set)
        {
            if(i == ins)
                continue;
            // Skip variables that aren't allocations
            if(i->name() != allocation_op)
                continue;
            // Skip zero allocations
            if(i->get_shape().bytes() == 0)
                continue;
            conflict_table[i].insert(ins);
            conflict_table[ins].insert(i);
        }
    });
    assert(std::all_of(conflict_table.begin(), conflict_table.end(), [](auto&& pp) {
        return pp.second.count(pp.first) == 0;
    }));
    return conflict_table;
}

// Check if intervals overlap
bool is_overlap(std::pair<std::size_t, std::size_t> x, std::pair<std::size_t, std::size_t> y)
{
    return std::max(x.first, y.first) < std::min(x.second, y.second);
}

struct allocation_segment
{
    using segment = std::pair<std::size_t, std::size_t>;

    std::unordered_map<instruction_ref, segment> ins2segment;

    const segment* add_segment(instruction_ref ins, segment s) { return &(ins2segment[ins] = s); }

    const segment* get_segment(instruction_ref ins) const
    {
        auto it = ins2segment.find(ins);
        if(it == ins2segment.end())
            return nullptr;
        return &it->second;
    }

    // Remove segment for an instruction
    void remove(instruction_ref ins)
    {
        auto it = ins2segment.find(ins);
        if(it != ins2segment.end())
        {
            ins2segment.erase(it);
        }
    }

    std::size_t max()
    {
        std::size_t n = 0;
        for(auto&& pp : ins2segment)
        {
            auto seg = pp.second;
            n        = std::max(n, seg.second);
        }
        return n;
    }

    template <class Iterator>
    static bool overlaps(Iterator first, Iterator last, const segment& s)
    {
        return std::any_of(first, last, [&](auto&& t) { return is_overlap(s, t); });
    }

    static bool overlaps(const std::set<segment>& segments, const segment& s)
    {
        return overlaps(segments.begin(), segments.end(), s);
    }

    static auto find_gap(const std::set<segment>& segments, std::size_t n)
    {
        std::size_t max_end = 0;
        return std::adjacent_find(segments.begin(), segments.end(), [&](segment x, segment y) {
            if(x.second < max_end)
                return false;
            max_end = x.second;
            if(is_overlap(x, y))
                return false;
            assert(y.first >= x.second);
            auto k = y.first - x.second;
            return (k >= n);
        });
    }

    static std::size_t max_type_size(const shape& s)
    {
        return std::accumulate(s.sub_shapes().begin(),
                               s.sub_shapes().end(),
                               s.type_size(),
                               [](auto size, const auto& sub) {
                                   return std::max(size, max_type_size(sub));
                               });
    }

    static std::size_t compute_alignment(instruction_ref ins)
    {
        auto alignment = max_type_size(ins->get_shape());
        // A rough estimate for the total number of elements
        auto n = ins->get_shape().bytes() / alignment;
        // Check for vectorized alignment
        if(n > 4)
        {
            auto d = n % 4;
            if(d == 0)
                alignment *= 4;
            if(d == 2)
                alignment *= 2;
        }
        return alignment;
    }

    static segment
    next_segment(std::set<segment>& segments, instruction_ref ins, std::size_t alignment)
    {
        assert(ins->get_shape().bytes() > 0);
        // Compute alignment
        auto n = 1 + (ins->get_shape().bytes() - 1) / alignment;
        assert(n > 0);
        auto start = 0;
        // Insert at end if it can't fit at the beginning
        if(segments.empty() or segments.begin()->first <= n)
        {
            auto it = find_gap(segments, n);
            if(it == segments.end())
                it = std::max_element(segments.begin(), segments.end(), [&](segment x, segment y) {
                    return x.second < y.second;
                });
            if(it != segments.end())
                start = it->second;
        }
        auto s = segment{start, start + n};
        assert(not overlaps(segments, s));
        segments.insert(s);
        return s;
    }

    static std::unordered_map<instruction_ref, int>
    create_allocation_index(const module& m, const instruction_set_map& conflict_table)
    {
        std::unordered_map<instruction_ref, int> result;
        int i = 0;
        for(auto ins : iterator_for(m))
        {
            if(not contains(conflict_table, ins))
                continue;
            result[ins] = i++;
        }
        return result;
    }

    // Build the allocation_segment class from the conflict_table
    static allocation_segment
    build(const module& m, const instruction_set_map& conflict_table, std::size_t alignment)
    {
        allocation_segment as{};
        std::vector<instruction_ref> conflict_queue;
        // Add all allocations to the conflict_queue
        std::transform(conflict_table.begin(),
                       conflict_table.end(),
                       std::back_inserter(conflict_queue),
                       [](auto&& pp) { return pp.first; });
        auto alloc_index = create_allocation_index(m, conflict_table);

        // Sort the conflict queue so we process the allocation with the most
        // number of adjacent allocations first
        std::sort(conflict_queue.begin(), conflict_queue.end(), by(std::greater<>{}, [&](auto x) {
                      return std::make_tuple(
                          conflict_table.at(x).size(), x->get_shape().bytes(), alloc_index.at(x));
                  }));

        // Process the conflict_queue, we refer to the current allocation as
        // the parent and the adjacent allocations as children
        for(auto parent : conflict_queue)
        {
            // Sort children by size
            std::vector<instruction_ref> children(conflict_table.at(parent).begin(),
                                                  conflict_table.at(parent).end());
            std::sort(children.begin(), children.end(), by(std::less<>{}, [&](auto x) {
                          return std::make_tuple(x->get_shape().bytes(), alloc_index.at(x));
                      }));
            assert(not contains(children, parent));
            // This set is to track the segments already processed
            std::set<segment> segments;
            // Add all segments for the children to the segments already processed
            transform_if(children.begin(),
                         children.end(),
                         std::inserter(segments, segments.begin()),
                         [&](auto child) { return as.get_segment(child); },
                         [&](auto child) { return *as.get_segment(child); });
            assert(as.get_segment(parent) == nullptr);
            as.add_segment(parent, next_segment(segments, parent, alignment));
        }

        // Reduce the number of segments
        for(std::size_t n = 0; n < 3; n++)
        {
            for(auto parent : conflict_queue)
            {
                auto children = conflict_table.at(parent);
                // This set is to track the segments already processed
                std::set<segment> segments;
                // Add all segments for the children to the segments already processed
                transform_if(children.begin(),
                             children.end(),
                             std::inserter(segments, segments.begin()),
                             [&](auto child) { return as.get_segment(child); },
                             [&](auto child) { return *as.get_segment(child); });
                // Get the segment for the parent
                const auto* parent_segment = as.get_segment(parent);
                assert(parent_segment != nullptr);
                auto s = next_segment(segments, parent, alignment);
                if(s != *parent_segment and s.second <= as.max())
                {
                    as.add_segment(parent, s);
                }
            }
        }
        return as;
    }
};

static std::size_t find_max_alignment(const module& m, const std::string& allocation_op)
{
    std::size_t alignment = 1;
    for(auto ins : iterator_for(m))
    {
        if(ins->name() != allocation_op)
            continue;
        alignment = std::max(allocation_segment::compute_alignment(ins), alignment);
    }
    return alignment;
}

void memory_coloring::apply(module& m) const
{
    const std::size_t alignment = find_max_alignment(m, allocation_op);
    auto conflict_table         = build_conflict_table(m, allocation_op);
    auto as                     = allocation_segment::build(m, conflict_table, alignment);

    // All allocations should have a segment
    assert(std::all_of(conflict_table.begin(), conflict_table.end(), [&](auto&& pp) {
        return as.get_segment(pp.first);
    }));
    // Adjacent allocations should not have overlapping segments
    assert(std::none_of(conflict_table.begin(), conflict_table.end(), [&](auto&& pp) {
        auto* x = as.get_segment(pp.first);
        return std::any_of(pp.second.begin(), pp.second.end(), [&](auto ins) {
            auto* y = as.get_segment(ins);
            assert(x and y);
            return is_overlap(*x, *y);
        });
    }));

    // Print out segments
    if(enabled(MIGRAPHX_DEBUG_MEMORY_COLORING{}))
    {
        for(auto&& pp : conflict_table)
        {
            std::cout << "------- conflict -------" << std::endl;
            auto s1 = as.ins2segment.at(pp.first);
            std::cout << s1.first << ", " << s1.second << ": ";
            m.debug_print(pp.first);
            for(auto ins : pp.second)
            {
                auto s2 = as.ins2segment.at(ins);
                std::cout << s2.first << ", " << s2.second << ": ";
                m.debug_print(ins);
            }
        }
    }

    // Total memory
    std::size_t n = as.max() * alignment;

    // Replace allocations
    auto mem = m.add_parameter("scratch", shape{shape::int8_type, {n}});
    for(auto&& [ins, seg] : as.ins2segment)
    {
        assert(ins->name() == allocation_op);
        auto s             = ins->get_shape();
        std::size_t offset = seg.first * alignment;
        assert(offset < n);
        m.replace_instruction(ins, op::load{s, offset}, mem);
    }

    // Replace zero allocation
    for(auto ins : iterator_for(m))
    {
        if(ins->name() != allocation_op)
            continue;
        assert(ins->get_shape().bytes() == 0);
        m.replace_instruction(ins, op::load{ins->get_shape(), 0}, mem);
    }

    // Remove scratch parameter if it's not used
    if(mem->outputs().empty())
    {
        m.remove_instruction(mem);
    }
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
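The pass walks the module backwards to find which allocations are live together (the conflict table, i.e. the interference graph), then greedily assigns each allocation a byte segment of one shared scratch buffer that avoids the segments of everything it conflicts with. A minimal Python sketch of that greedy assignment, not MIGraphX code (names and the toy conflict data are illustrative, and alignment handling is omitted):

```python
# allocations live together must get disjoint byte ranges of one scratch buffer
def is_overlap(x, y):
    return max(x[0], y[0]) < min(x[1], y[1])

def assign_segments(sizes, conflicts):
    segments = {}
    # most-conflicted, then largest first, like the pass's conflict_queue order
    for a in sorted(sizes, key=lambda a: (len(conflicts[a]), sizes[a]), reverse=True):
        taken = sorted(segments[c] for c in conflicts[a] if c in segments)
        start = 0
        for seg in taken:  # first gap that fits, else just past the last conflict
            if not is_overlap((start, start + sizes[a]), seg):
                break
            start = seg[1]
        segments[a] = (start, start + sizes[a])
    return segments

segs = assign_segments({'a': 8, 'b': 8, 'c': 4},
                       {'a': {'b'}, 'b': {'a', 'c'}, 'c': {'b'}})
print(segs)  # 'b' overlaps neither; 'a' and 'c' never conflict, so they may share bytes
```

The buffer's total size is then the maximum segment end, which is what the updated scratch-size checks in test/memory_coloring_test.cpp below assert.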
src/onnx/include/migraphx/onnx/onnx_parser.hpp

@@ -113,7 +113,8 @@ struct onnx_parser
     void parse_from(std::istream& is, std::string name = "");
     void parse_from(const void* data, std::size_t size);
-    void parse_graph(module* mod, const onnx::GraphProto& graph);
+    std::vector<instruction_ref>
+    parse_graph(module* mod, const onnx::GraphProto& graph, bool inlining = false);
     literal parse_value(const onnx::AttributeProto& attr) const;
     literal parse_tensor(const onnx::TensorProto& t) const;
     shape parse_type(const onnx::TypeProto& t, const std::vector<std::size_t>& input_dims) const;
src/onnx/onnx_parser.cpp

@@ -220,7 +220,7 @@ void onnx_parser::parse_from(std::istream& is, std::string name)
     if(model.has_graph())
     {
-        this->parse_graph(mm, model.graph());
+        (void)this->parse_graph(mm, model.graph());
     }
 }
 else
@@ -240,7 +240,7 @@ void onnx_parser::parse_from(const void* data, std::size_t size)
     if(model.has_graph())
     {
-        this->parse_graph(mm, model.graph());
+        (void)this->parse_graph(mm, model.graph());
     }
 }
 else
@@ -264,7 +264,8 @@ int64_t onnx_parser::get_opset_version(const onnx::ModelProto& model)
     return version;
 }

-void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
+std::vector<instruction_ref>
+onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlining)
 {
     std::unordered_map<std::string, instruction_ref> mod_insts;
     for(auto&& f : graph.initializer())
@@ -372,11 +373,16 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
                    std::back_inserter(output_ins),
                    [&](const auto& name) { return instructions[name]; });
-    // add the return instruction
-    mod->add_return(output_ins);
+    if(not inlining)
+    {
+        // add the return instruction
+        mod->add_return(output_ins);
+        // Remove instructions added in module (this is turned off for subgraph inlining)
+        erase_if(instructions, [&](auto&& p) { return mod->has_instruction(p.second); });
+    }

-    // remove instructions added in this mod
-    erase_if(instructions, [&](auto&& p) { return mod->has_instruction(p.second); });
+    return output_ins;
 }

 literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const
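The signature change is the enabler for the If folding in parse_if.cpp below: parse_graph now hands back the graph's output instructions, and the inlining flag suppresses the trailing return so a sub-graph's body can be spliced into its parent module. A Python sketch of that contract, not the MIGraphX API (Mod and parse_node are stand-ins to make the control flow concrete):

```python
class Mod:
    def __init__(self):
        self.instructions = []
    def add_return(self, outs):
        self.instructions.append(("@return", outs))

def parse_node(mod, node):
    mod.instructions.append(node)
    return node

def parse_graph(mod, nodes, inlining=False):
    # returns the graph's output instructions; inlining=True suppresses the
    # trailing return so the body can be spliced into a parent module
    output_ins = [parse_node(mod, n) for n in nodes]
    if not inlining:
        mod.add_return(output_ins)
    return output_ins

m = Mod()
outs = parse_graph(m, ["add", "mul"], inlining=True)  # no return added
```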
src/onnx/parse_if.cpp

@@ -51,6 +51,24 @@ struct parse_if : op_parser<parse_if>
                        " condition input can have only one element!");
     }

+    // Fold the If when the condition is constant and can thus be evaluated
+    // prior to inference
+    if(args.front()->can_eval())
+    {
+        auto cond_arg = args.front()->eval();
+        auto* mod     = info.mod;
+        // then branch
+        if(cond_arg.at<bool>())
+        {
+            return parser.parse_graph(mod, then_graph, true);
+        }
+        // else branch
+        else
+        {
+            return parser.parse_graph(mod, else_graph, true);
+        }
+    }
+
     std::string then_name = info.name + "_if";
     module_ref then_mdl   = parser.prog.create_module(then_name);
@@ -58,10 +76,10 @@ struct parse_if : op_parser<parse_if>
     module_ref else_mdl = parser.prog.create_module(else_name);

     // parse the then sub_graph
-    parser.parse_graph(then_mdl, then_graph);
+    (void)parser.parse_graph(then_mdl, then_graph);

     // parse the else sub_graph
-    parser.parse_graph(else_mdl, else_graph);
+    (void)parser.parse_graph(else_mdl, else_graph);

     auto then_out_shapes = then_mdl->get_output_shapes();
     auto else_out_shapes = else_mdl->get_output_shapes();
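The fold only fires when the condition is a compile-time constant (for example an initializer in the ONNX graph); the chosen branch is parsed with inlining=true so its outputs stand in for the If node's outputs in the parent module, and the untaken branch is never built. A rough sketch of that control flow with stand-in types, not the MIGraphX API:

```python
# when cond is known at parse time, only the chosen branch is parsed,
# inlined into the parent module; otherwise a runtime `if` op is emitted
def parse_if(parser, mod, cond_value, then_graph, else_graph):
    if cond_value is not None:  # condition evaluable at parse time
        chosen = then_graph if cond_value else else_graph
        return parser.parse_graph(mod, chosen, inlining=True)
    # otherwise build then/else sub-modules and emit a runtime `if` op
    raise NotImplementedError("runtime If path elided in this sketch")
```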
src/onnx/parse_loop.cpp

@@ -71,7 +71,7 @@ struct parse_loop : op_parser<parse_loop>
     module_ref sub_mod = parser.prog.create_module(mod_name);

     // parse the sub_graph
-    parser.parse_graph(sub_mod, sub_graph);
+    (void)parser.parse_graph(sub_mod, sub_graph);

     auto ret = info.add_instruction(
         make_op("loop", {{"max_iterations", max_iterations}}), args, {sub_mod});
src/pass_manager.cpp

@@ -39,6 +39,7 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_PASSES);
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TIME_PASSES);

 void validate_pass(module& mod, const pass& p, tracer trace)
 {
@@ -94,19 +95,19 @@ struct module_pm : module_pass_manager
     virtual void run_pass(const pass& p) override
     {
         assert(mod);
-        timer ts{};
-        using seconds = std::chrono::duration<double>;
         trace("Module: ", mod->name(), ", Pass: ", p.name());
-        const double t1 = ts.record<seconds>();
         assert(mod->validate() == mod->end());
-        p.apply(*this);
+        if(enabled(MIGRAPHX_TIME_PASSES{}))
+        {
+            using milliseconds = std::chrono::duration<double, std::milli>;
+            auto ms            = time<milliseconds>([&] { p.apply(*this); });
+            std::cout << p.name() << ": " << ms << "ms\n";
+        }
+        else
+        {
+            p.apply(*this);
+        }
         trace(*mod);
         validate_pass(*mod, p, *t);
-        const double t2 = ts.record<seconds>();
-        trace("Pass: ", p.name(), " completed in (s): ", (t2 - t1));
     }
 };
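In effect, always-on per-pass timing is replaced with timing gated behind the new MIGRAPHX_TIME_PASSES environment variable. A rough Python analogue of the gated behavior, not the MIGraphX API:

```python
import os
import time

# apply the pass inside a timer only when the environment flag is set
def run_pass(apply_fn, name):
    if os.environ.get("MIGRAPHX_TIME_PASSES"):
        start = time.perf_counter()
        apply_fn()
        print(f"{name}: {(time.perf_counter() - start) * 1000:.3f}ms")
    else:
        apply_fn()

run_pass(lambda: sum(range(10**6)), "example_pass")
```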
src/program.cpp

@@ -336,7 +336,8 @@ std::vector<argument> generic_eval(const module* mod,
     if(not ins->get_shape().dynamic() and param.get_shape() != ins->get_shape())
     {
         MIGRAPHX_THROW("Incorrect shape {" + to_string(param.get_shape()) +
-                       "} for parameter: " + param_name);
+                       "} for parameter: " + param_name +
+                       " should be: " + to_string(ins->get_shape()));
     }
     return param;
 }));
test/memory_coloring_test.cpp

@@ -691,7 +691,7 @@ TEST_CASE(test38)
     auto p83 = m.add_instruction(pass_op{}, p78, p77);
     m.add_instruction(pass_op{}, output, p83, p63);
     run_pass(m);
-    CHECK(m.get_parameter_shape("scratch").bytes() == 7225344); // Optimal solution is 6422528
+    CHECK(m.get_parameter_shape("scratch").bytes() == 6422528);
     CHECK(no_allocate(m));
 }
@@ -729,7 +729,7 @@ TEST_CASE(test39)
         run_pass(*smod);
     }
-    CHECK(mm->get_parameter_shape("scratch").bytes() == 4);
+    CHECK(mm->get_parameter_shape("scratch").bytes() == 1);
     CHECK(then_mod->get_parameter_shape("scratch").bytes() == 24);
     CHECK(else_mod->get_parameter_shape("scratch").bytes() == 24);
     CHECK(no_allocate(*mm));
@@ -3374,7 +3374,7 @@ TEST_CASE(rnn_dom)
     m.add_instruction(pass_op{}, moutput, mx250, mx249, mx248);
     run_pass(m);
-    CHECK(m.get_parameter_shape("scratch").bytes() == 1600);
+    CHECK(m.get_parameter_shape("scratch").bytes() == 1824); // Optimal is 1600
     CHECK(no_allocate(m));
     CHECK(is_disjoint({mx0, mx8}));
     CHECK(is_disjoint({mx0, mx8}));
@@ -3790,4 +3790,23 @@ TEST_CASE(literal_test)
     CHECK(lit == result);
 }

+TEST_CASE(test_tuple)
+{
+    migraphx::module m;
+    auto s1 = migraphx::shape{migraphx::shape::float_type, {8}};
+    auto s2 = migraphx::shape{migraphx::shape::half_type, {10}};
+    auto s  = migraphx::shape{{s1, s2}};
+    auto a1 = add_alloc(m, s);
+    auto m1 = m.add_instruction(pass_op{}, a1);
+    auto a2 = add_alloc(m, {migraphx::shape::float_type, {4}});
+    m.add_instruction(pass_op{}, a2, m1);
+    run_pass(m);
+    CHECK(m.get_parameter_shape("scratch").bytes() == 68);
+    CHECK(no_allocate(m));
+    CHECK(is_disjoint({a1, a2}));
+}
+
 int main(int argc, const char* argv[]) { test::run(argc, argv); }
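The 68-byte expectation in the new test_tuple case is just the two live-together allocations packed disjointly, worked out below (a sanity check, assuming 4-byte floats and 2-byte halves):

```python
# a1 holds a tuple of float[8] and half[10]; a2 holds float[4]; the test
# requires the two allocations to be disjoint in the scratch buffer
tuple_bytes = 8 * 4 + 10 * 2   # 32 + 20 = 52 bytes for the tuple allocation
assert tuple_bytes + 4 * 4 == 68
```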
test/onnx/gathernd_dyn_test.onnx (new file, mode 100644)

File added (binary ONNX model, +0 −0).
test/onnx/gen_onnx.py

@@ -2132,6 +2132,19 @@ def gathernd_test():
     return ([node], [x, i], [y])


+@onnx_test()
+def gathernd_dyn_test():
+    x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [None, 2])
+    i = helper.make_tensor_value_info('indices', TensorProto.INT64, [2, 2])
+    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2])
+
+    node = onnx.helper.make_node('GatherND',
+                                 inputs=['data', 'indices'],
+                                 outputs=['y'])
+
+    return ([node], [x, i], [y])
+
+
 @onnx_test()
 def gathernd_batch_dims_test():
     x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [2, 2, 2])

@@ -2498,6 +2511,58 @@ def if_else_test():
     x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3])
     y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 3])
+    then_out = onnx.helper.make_tensor_value_info('then_out', onnx.TensorProto.FLOAT, [2, 3])
+    else_out = onnx.helper.make_tensor_value_info('else_out', onnx.TensorProto.FLOAT, [2, 3])
+
+    xt = np.ones((2, 3)).astype(np.float)
+    xt_tensor = helper.make_tensor(name='xt',
+                                   data_type=TensorProto.FLOAT,
+                                   dims=xt.shape,
+                                   vals=xt.flatten().astype(np.float32))
+
+    yt = np.random.randn(2, 3).astype(np.float)
+    yt_tensor = helper.make_tensor(name='yt',
+                                   data_type=TensorProto.FLOAT,
+                                   dims=yt.shape,
+                                   vals=yt.flatten().astype(np.float32))
+
+    then_add_node = onnx.helper.make_node('Add', inputs=['x', 'xt'], outputs=['then_out'])
+    else_mul_node = onnx.helper.make_node('Mul', inputs=['y', 'yt'], outputs=['else_out'])
+
+    then_body = onnx.helper.make_graph([then_add_node], 'then_body', [], [then_out])
+    else_body = onnx.helper.make_graph([else_mul_node], 'else_body', [], [else_out])
+
+    cond_tensor = onnx.helper.make_tensor_value_info("cond", onnx.TensorProto.BOOL, [1])
+
+    res = onnx.helper.make_tensor_value_info('res', TensorProto.FLOAT, [])
+
+    node = onnx.helper.make_node('If',
+                                 inputs=['cond'],
+                                 outputs=['res'],
+                                 then_branch=then_body,
+                                 else_branch=else_body)
+
+    return ([node], [x, y, cond_tensor], [res], [xt_tensor, yt_tensor])
+
+
+@onnx_test()
+def if_else_test_inlined():
+    x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3])
+    y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 3])
     then_out = onnx.helper.make_tensor_value_info('then_out',
                                                   onnx.TensorProto.FLOAT,
                                                   [2, 3])

@@ -2547,6 +2612,149 @@ def if_else_test():
     return ([node], [x, y], [res], [cond_tensor, xt_tensor, yt_tensor])


+@onnx_test()
+def if_then_else_multi_output_shapes_inlined_test():
+    x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3, 1])
+    y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 3])
+    then_out = onnx.helper.make_tensor_value_info('then_out', onnx.TensorProto.FLOAT, [2, 3, 1])
+    then_out2 = onnx.helper.make_tensor_value_info('then_out2', onnx.TensorProto.FLOAT,
+                                                   [2, 3, 1])
+    else_out = onnx.helper.make_tensor_value_info('else_out', onnx.TensorProto.FLOAT, [2, 3])
+    else_out2 = onnx.helper.make_tensor_value_info('else_out2', onnx.TensorProto.FLOAT, [2, 3])
+
+    xt = np.ones((2, 3, 1)).astype(np.float)
+    xt_tensor = helper.make_tensor(name='xt',
+                                   data_type=TensorProto.FLOAT,
+                                   dims=xt.shape,
+                                   vals=xt.flatten().astype(np.float32))
+
+    yt = np.random.randn(2, 3).astype(np.float)
+    yt_tensor = helper.make_tensor(name='yt',
+                                   data_type=TensorProto.FLOAT,
+                                   dims=yt.shape,
+                                   vals=yt.flatten().astype(np.float32))
+
+    then_add_node = onnx.helper.make_node('Add', inputs=['x', 'xt'], outputs=['then_out'])
+    then_add_node2 = onnx.helper.make_node('Add', inputs=['x', 'x'], outputs=['then_out2'])
+
+    else_mul_node = onnx.helper.make_node('Mul', inputs=['y', 'yt'], outputs=['else_out'])
+    else_sub_node = onnx.helper.make_node('Sub', inputs=['y', 'yt'], outputs=['else_out2'])
+
+    then_body = onnx.helper.make_graph([then_add_node, then_add_node2], 'then_body', [],
+                                       [then_out, then_out2])
+    else_body = onnx.helper.make_graph([else_mul_node, else_sub_node], 'else_body', [],
+                                       [else_out, else_out2])
+
+    cond = np.array([1]).astype(np.bool)
+    cond_tensor = helper.make_tensor(name="cond",
+                                     data_type=TensorProto.BOOL,
+                                     dims=cond.shape,
+                                     vals=cond.astype(bool))
+
+    res1 = onnx.helper.make_tensor_value_info('res1', TensorProto.FLOAT, [])
+    res2 = onnx.helper.make_tensor_value_info('res2', TensorProto.FLOAT, [])
+
+    node = onnx.helper.make_node('If',
+                                 inputs=['cond'],
+                                 outputs=['res1', 'res2'],
+                                 then_branch=then_body,
+                                 else_branch=else_body)
+
+    return ([node], [x, y], [res1, res2], [cond_tensor, xt_tensor, yt_tensor])
+
+
+@onnx_test()
+def if_then_else_multi_output_shapes_test():
+    x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3, 1])
+    y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 3, 1])
+    then_out = onnx.helper.make_tensor_value_info('then_out', onnx.TensorProto.FLOAT, [2, 3, 1])
+    then_out2 = onnx.helper.make_tensor_value_info('then_out2', onnx.TensorProto.FLOAT,
+                                                   [2, 3, 1])
+    else_out = onnx.helper.make_tensor_value_info('else_out', onnx.TensorProto.FLOAT, [2, 3, 1])
+    else_out2 = onnx.helper.make_tensor_value_info('else_out2', onnx.TensorProto.FLOAT,
+                                                   [2, 3, 1])
+
+    xt = np.ones((2, 3, 1)).astype(np.float)
+    xt_tensor = helper.make_tensor(name='xt',
+                                   data_type=TensorProto.FLOAT,
+                                   dims=xt.shape,
+                                   vals=xt.flatten().astype(np.float32))
+
+    yt = np.random.randn(2, 3, 1).astype(np.float)
+    yt_tensor = helper.make_tensor(name='yt',
+                                   data_type=TensorProto.FLOAT,
+                                   dims=yt.shape,
+                                   vals=yt.flatten().astype(np.float32))
+
+    then_add_node = onnx.helper.make_node('Add', inputs=['x', 'xt'], outputs=['then_out'])
+    then_add_node2 = onnx.helper.make_node('Add', inputs=['x', 'x'], outputs=['then_out2'])
+
+    else_mul_node = onnx.helper.make_node('Mul', inputs=['y', 'yt'], outputs=['else_out'])
+    else_sub_node = onnx.helper.make_node('Sub', inputs=['y', 'yt'], outputs=['else_out2'])
+
+    then_body = onnx.helper.make_graph([then_add_node, then_add_node2], 'then_body', [],
+                                       [then_out, then_out2])
+    else_body = onnx.helper.make_graph([else_mul_node, else_sub_node], 'else_body', [],
+                                       [else_out, else_out2])
+
+    cond_tensor = onnx.helper.make_tensor_value_info("cond", onnx.TensorProto.BOOL, [1])
+
+    res1 = onnx.helper.make_tensor_value_info('res1', TensorProto.FLOAT, [])
+    res2 = onnx.helper.make_tensor_value_info('res2', TensorProto.FLOAT, [])
+
+    node = onnx.helper.make_node('If',
+                                 inputs=['cond'],
+                                 outputs=['res1', 'res2'],
+                                 then_branch=then_body,
+                                 else_branch=else_body)
+
+    return ([node], [x, y, cond_tensor], [res1, res2], [xt_tensor, yt_tensor])
+
+
 @onnx_test()
 def if_literal_test():
     then_out = onnx.helper.make_tensor_value_info('then_out',

@@ -2807,6 +3015,59 @@ def if_then_test():
     x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3])
     y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 3])
+    then_out = onnx.helper.make_tensor_value_info('then_out', onnx.TensorProto.FLOAT, [2, 3])
+    else_out = onnx.helper.make_tensor_value_info('else_out', onnx.TensorProto.FLOAT, [2, 3])
+
+    xt = np.ones((2, 3)).astype(np.float)
+    xt_tensor = helper.make_tensor(name='xt',
+                                   data_type=TensorProto.FLOAT,
+                                   dims=xt.shape,
+                                   vals=xt.flatten().astype(np.float32))
+
+    yt = np.random.randn(2, 3).astype(np.float)
+    yt_tensor = helper.make_tensor(name='yt',
+                                   data_type=TensorProto.FLOAT,
+                                   dims=yt.shape,
+                                   vals=yt.flatten().astype(np.float32))
+
+    then_add_node = onnx.helper.make_node('Add', inputs=['x', 'xt'], outputs=['then_out'])
+    else_mul_node = onnx.helper.make_node('Mul', inputs=['y', 'yt'], outputs=['else_out'])
+
+    then_body = onnx.helper.make_graph([then_add_node], 'then_body', [], [then_out])
+    else_body = onnx.helper.make_graph([else_mul_node], 'else_body', [], [else_out])
+
+    cond_tensor = onnx.helper.make_tensor_value_info("cond", onnx.TensorProto.BOOL, [1])
+
+    res = onnx.helper.make_tensor_value_info('res', TensorProto.FLOAT, [])
+
+    node = onnx.helper.make_node('If',
+                                 inputs=['cond'],
+                                 outputs=['res'],
+                                 then_branch=then_body,
+                                 else_branch=else_body)
+
+    return ([node], [x, y, cond_tensor], [res], [xt_tensor, yt_tensor])
+
+
+@onnx_test()
+def if_then_test_inlined():
+    x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3])
+    y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 3])
     then_out = onnx.helper.make_tensor_value_info('then_out',
                                                   onnx.TensorProto.FLOAT,
                                                   [2, 3])

@@ -5707,6 +5968,24 @@ def scatternd_test():
     return ([node], [data, indices, updates], [output])


+@onnx_test()
+def scatternd_dyn_test():
+    data = helper.make_tensor_value_info('data', TensorProto.FLOAT, [None, 2, 2])
+    indices = helper.make_tensor_value_info('indices', TensorProto.INT64, [None, 1, 2])
+    updates = helper.make_tensor_value_info('updates', TensorProto.FLOAT, [None, 1, 2])
+    output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [None, 2, 2])
+
+    node = onnx.helper.make_node('ScatterND',
+                                 inputs=['data', 'indices', 'updates'],
+                                 outputs=['output'])
+
+    return ([node], [data, indices, updates], [output])
+
+
 @onnx_test()
 def selu_test():
     x = helper.make_tensor_value_info('x', TensorProto.DOUBLE, [2, 3])
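Note the pattern across the new and renamed If tests: the "_inlined" variants pass cond as an initializer (a compile-time constant), which is exactly what lets the new constant-condition folding in parse_if.cpp drop the untaken branch, while the non-inlined variants pass cond as a graph input, which must survive as a runtime If. A short snippet showing the two ways of declaring cond (assuming the same helper/TensorProto imports gen_onnx.py already uses):

```python
import numpy as np
from onnx import TensorProto, helper

# runtime condition: 'cond' is a graph input, so the If survives to inference
cond_input = helper.make_tensor_value_info("cond", TensorProto.BOOL, [1])

# parse-time constant: 'cond' is an initializer, so the parser can fold the If
cond_const = helper.make_tensor(name="cond",
                                data_type=TensorProto.BOOL,
                                dims=[1],
                                vals=np.array([True]))
```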