Commit 673ca71c, authored Jan 31, 2023 by Paul

Merge branch 'develop' into conv-add

Parents: 4bfe1662, 91cc7242

Changes: 98 files in total; this page shows 20 changed files with 608 additions and 119 deletions (+608, -119).
.github/workflows/add-to-project.yaml      +19  -0
.github/workflows/sync-onnxrt-main.yaml    +46  -0
Dockerfile                                 +28  -6
Jenkinsfile                                 +6  -0
src/CMakeLists.txt                          +1  -0
src/include/migraphx/match/layernorm.hpp    +4  -3
src/include/migraphx/op/gather.hpp         +40  -15
src/include/migraphx/op/pad.hpp            +21  -10
src/include/migraphx/op/reshape.hpp        +71  -7
src/include/migraphx/optimize_module.hpp   +48  -0
src/include/migraphx/shape.hpp              +7  -0
src/module.cpp                              +2  -1
src/onnx/onnx_parser.cpp                   +13  -3
src/onnx/parse_gemm.cpp                    +56  -39
src/onnx/parse_matmul.cpp                  +58  -34
src/onnx/parse_pad.cpp                      +6  -0
src/onnx/parse_reshape.cpp                  +1  -1
src/onnx/parse_trilu.cpp                   +90  -0
src/optimize_module.cpp                    +49  -0
src/shape.cpp                              +42  -0
.github/workflows/add-to-project.yaml (new file, mode 100644)

name: Add items to GH project

on:
  pull_request:
    types:
      - opened
  issues:
    types:
      - opened

jobs:
  add-to-project:
    name: Add PRs and issues to MIGX project
    runs-on: ubuntu-latest
    steps:
      - uses: actions/add-to-project@v0.4.0
        with:
          project-url: https://github.com/orgs/ROCmSoftwarePlatform/projects/20
          github-token: ${{ secrets.TEST_PR_WORKFLOW }}
.github/workflows/sync-onnxrt-main.yaml (new file, mode 100644)

name: Onnxruntime main weekly sync

on:
  schedule:
    - cron: "05 17 * * 1"

jobs:
  sync:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
        with:
          ref: develop
          path: ../
  get_date:
    runs-on: ubuntu-latest
    steps:
      - run: echo "::set-output name=date::$(date +'%Y-%m-%d')"
  update_file:
    runs-on: ubuntu-latest
    needs: [sync, get_date]
    steps:
      - run: git clone https://github.com/microsoft/onnxruntime.git && cd onnxruntime && git rev-parse HEAD >> ../test/onnx/.onnxrt-commit
  Add_commit:
    runs-on: ubuntu-latest
    needs: update_file
    steps:
      - name: Add & Commit
        uses: EndBug/add-and-commit@v9.1.1
        with:
          new_branch: onnxruntime-sync-${{ steps.date.outputs.date }}
          add: ../test/onnx/.onnxrt-commit
          message: Update Onnxruntime commit to latest release
          default_author: github_actions
          push: true
  PR:
    runs-on: ubuntu-latest
    needs: Add_commit
    steps:
      - name: GitHub Action for creating Pull Requests
        uses: devops-infra/action-pull-request@v0.5.3
        with:
          github_token: ${{ secrets.GITHUB_TOKEN }}
          title: Sync Onnxruntime main
          reviewer: pfultz2, causten
          assignee: TedThemistokleous
          label: automatic, onnxruntime
          target_branch: develop
Dockerfile

@@ -5,6 +5,10 @@ ARG PREFIX=/usr/local
 # Support multiarch
 RUN dpkg --add-architecture i386

 # Install rocm key
 RUN apt-get update && apt-get install -y gnupg2 --no-install-recommends curl && \
     curl -sL http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add -

 # Add rocm repository
 RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/5.3/ ubuntu main > /etc/apt/sources.list.d/rocm.list'

@@ -32,10 +36,27 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
     libnuma-dev \
     miopen-hip \
     rocblas \
     hipfft \
     rocthrust \
     rocrand \
     hipsparse \
     rccl \
     rccl-dev \
     rocm-smi-lib \
     rocm-dev \
     roctracer-dev \
     hipcub \
     hipblas \
     hipify-clang \
     half \
     libssl-dev \
     zlib1g-dev && \
     apt-get clean && \
     rm -rf /var/lib/apt/lists/*

 # add this for roctracer dependencies
 RUN pip3 install CppHeaderParser packaging==22.0

 # Workaround broken rocm packages
 RUN ln -s /opt/rocm-* /opt/rocm
 RUN echo "/opt/rocm/lib" > /etc/ld.so.conf.d/rocm.conf

@@ -72,18 +93,19 @@ RUN /download_models.sh && rm /download_models.sh
 # Install latest ccache version
 RUN cget -p $PREFIX install facebook/zstd@v1.4.5 -X subdir -DCMAKE_DIR=build/cmake
 RUN cget -p $PREFIX install ccache@v4.1 -DENABLE_TESTING=OFF
+RUN cget -p /opt/cmake install kitware/cmake@v3.24.3
-# Install newer cmake for onnx runtime
-ARG CMAKE_VERSION=3.24.2
-RUN cget -p /opt/cmake install -X binary https://github.com/Kitware/CMake/releases/download/v${CMAKE_VERSION}/cmake-${CMAKE_VERSION}-Linux-x86_64.tar.gz
+COPY ./test/onnx/.onnxrt-commit /

 ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime
 ARG ONNXRUNTIME_BRANCH=main
-ARG ONNXRUNTIME_COMMIT=24f1bd6156cf5968bbc76dfb0e801a9b9c56b9fc
+ARG ONNXRUNTIME_COMMIT
 RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime && \
     cd onnxruntime && \
-    git checkout ${ONNXRUNTIME_COMMIT} && \
-    /bin/sh dockerfiles/scripts/install_common_deps.sh
+    if [ -z "$ONNXRUNTIME_COMMIT" ]; then git checkout $(cat /.onnxrt-commit); else git checkout ${ONNXRUNTIME_COMMIT}; fi && \
+    /bin/sh /onnxruntime/dockerfiles/scripts/install_common_deps.sh
 ADD tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh
Jenkinsfile

@@ -15,11 +15,13 @@ def rocmtestnode(Map conf) {
     def compiler = bconf.get("compiler", "/opt/rocm/llvm/bin/clang++")
     def flags = bconf.get("flags", "")
     def gpu_debug = bconf.get("gpu_debug", "0")
+    def hiprtc_workarounds = bconf.get("hiprtc_workarounds", "0")
     def cmd = """
         ulimit -c unlimited
         echo "leak:dnnl::impl::malloc" > suppressions.txt
         export LSAN_OPTIONS="suppressions=\$(pwd)/suppressions.txt"
         export MIGRAPHX_GPU_DEBUG=${gpu_debug}
+        export MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS=${hiprtc_workarounds}
         export CXX=${compiler}
         export CXXFLAGS='-Werror'
         env

@@ -110,6 +112,10 @@ rocmtest clang_debug: rocmnode('vega') { cmake_build ->
         cmake_build(flags: "-DCMAKE_BUILD_TYPE=release")
         stash includes: 'build/*.deb', name: 'migraphx-package'
     }
 },
+hiprtc_gpu_debug: rocmnode('vega') { cmake_build ->
+    stage('HipRTC GPU Debug') {
+        cmake_build(flags: "-DCMAKE_BUILD_TYPE=release -DMIGRAPHX_USE_HIPRTC=On",
+                    gpu_debug: true, hiprtc_workarounds: true)
+    }
+},
 mlir_debug: rocmnode('vega') { cmake_build ->
     stage('MLIR Debug') {
         def sanitizers = "undefined"
src/CMakeLists.txt

@@ -64,6 +64,7 @@ add_library(migraphx
     normalize_ops.cpp
     op_enums.cpp
     operation.cpp
+    optimize_module.cpp
     opt/memory_coloring.cpp
     opt/memory_coloring_impl.cpp
     pad_calc.cpp
src/include/migraphx/match/layernorm.hpp

@@ -48,10 +48,11 @@ struct layernorm_matcher
     auto layernorm_onnx() const
     {
-        return f("div")(arg(0)(x_minus_mean()),
-                        arg(1)(skip_broadcasts(f("sqrt")(arg(0)(f("add")(either_arg(0, 1)(
-                            variance(), is_constant().bind("eps"))))))));
+        auto add_eps = f("add")(either_arg(0, 1)(variance(), is_constant().bind("eps")));
+        return f("div")(arg(0)(x_minus_mean()),
+                        arg(1)(skip_broadcasts(f("sqrt")(arg(0)(match::any_of(add_eps, variance()))))));
     }

     auto matcher() const { return layernorm_onnx(); }
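For reference, this matcher targets the ONNX decomposition of layer normalization; with add_eps factored out, the sqrt argument now matches either with or without the epsilon term. A sketch of the two accepted forms (a worked equation, not code from this commit):

    y = (x - mean(x)) / sqrt(var(x) + eps)    // add_eps branch
    y = (x - mean(x)) / sqrt(var(x))          // bare-variance branch, newly accepted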
src/include/migraphx/op/gather.hpp
View file @
673ca71c
...
...
@@ -26,6 +26,7 @@
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
...
...
@@ -61,35 +62,59 @@ struct gather
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
2
);
auto
lens
=
inputs
[
0
].
lens
();
auto
type
=
inputs
[
0
].
type
();
lens
.
erase
(
lens
.
begin
()
+
axis
);
if
(
not
inputs
[
1
].
scalar
())
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
2
);
shape
data
=
inputs
[
0
];
shape
indices
=
inputs
[
1
];
auto
type
=
data
.
type
();
// If index_dims is dynamic, convert the data to dynamic too.
if
(
indices
.
dynamic
())
{
auto
ind_lens
=
inputs
[
1
].
lens
();
lens
.
insert
(
lens
.
begin
()
+
axis
,
ind_lens
.
begin
(),
ind_lens
.
end
());
data
=
data
.
to_dynamic
();
}
// for scalar output
if
(
lens
.
empty
())
if
(
data
.
dynamic
())
{
return
{
type
};
auto
dims
=
data
.
dyn_dims
();
dims
.
erase
(
dims
.
begin
()
+
axis
);
if
(
not
indices
.
scalar
())
{
auto
index_dims
=
indices
.
to_dynamic
().
dyn_dims
();
dims
.
insert
(
dims
.
begin
()
+
axis
,
index_dims
.
begin
(),
index_dims
.
end
());
}
return
{
type
,
dims
};
}
else
{
// Both data and indices are static. indices may be scalar
auto
lens
=
data
.
lens
();
lens
.
erase
(
lens
.
begin
()
+
axis
);
return
{
type
,
lens
};
if
(
not
indices
.
scalar
())
{
auto
ind_lens
=
indices
.
lens
();
lens
.
insert
(
lens
.
begin
()
+
axis
,
ind_lens
.
begin
(),
ind_lens
.
end
());
}
// for scalar output
if
(
lens
.
empty
())
{
return
{
type
};
}
return
{
type
,
lens
};
}
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
out
put_shape
};
argument
result
{
dyn_out
.
com
put
ed
_shape
};
// negative axis means counting dimensions from back
auto
lens
=
args
[
0
].
get_shape
().
lens
();
std
::
size_t
axis_dim_size
=
lens
[
axis
];
// max dimension in axis
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
data
)
{
args
[
1
].
visit
([
&
](
auto
indices
)
{
if
(
out
put_shape
.
scalar
())
if
(
dyn_out
.
com
put
ed
_shape
.
scalar
())
{
auto
in_index
=
indices
.
front
();
in_index
=
(
in_index
<
0
)
?
in_index
+
axis_dim_size
:
in_index
;
...
...
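Both branches implement the same output-shape rule: the axis dimension of data is removed and, for non-scalar indices, the index dimensions are spliced in at that position. A minimal standalone sketch of that rule (hypothetical helper, not the MIGraphX API):

#include <cstddef>
#include <iostream>
#include <vector>

// gather output lens: drop the data dim at `axis`, splice in the index lens
// (for scalar indices, nothing is spliced and the axis simply disappears)
std::vector<std::size_t> gather_out_lens(std::vector<std::size_t> data_lens,
                                         const std::vector<std::size_t>& index_lens,
                                         std::size_t axis)
{
    data_lens.erase(data_lens.begin() + axis);
    data_lens.insert(data_lens.begin() + axis, index_lens.begin(), index_lens.end());
    return data_lens;
}

int main()
{
    // data {3, 4, 5} gathered on axis 1 with indices of shape {2, 6} -> {3, 2, 6, 5}
    for(auto d : gather_out_lens({3, 4, 5}, {2, 6}, 1))
        std::cout << d << ' ';
    std::cout << '\n';
}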
src/include/migraphx/op/pad.hpp

@@ -59,18 +59,29 @@ struct pad
     std::string name() const { return "pad"; }

     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(1);
-        auto&& idims = inputs.front().lens();
-        std::vector<std::size_t> rdims(idims.begin(), idims.end());
-        std::size_t num_dims = rdims.size();
-        for(std::size_t i = 0; i < num_dims; i++)
-        {
-            rdims[i] += pads[i] + pads[i + num_dims];
-        }
-        shape s{inputs.front().type(), rdims};
-        return s;
+        check_shapes{inputs, *this, true}.has(1);
+        const auto& s0 = inputs.front();
+        if(s0.dynamic())
+        {
+            auto out_dyn_dims = s0.dyn_dims();
+            for(std::size_t i = 0; i < s0.ndim(); ++i)
+            {
+                out_dyn_dims[i] += pads[i] + pads[i + s0.ndim()];
+            }
+            return {s0.type(), out_dyn_dims};
+        }
+        else
+        {
+            auto&& idims = s0.lens();
+            std::vector<std::size_t> rdims(idims.begin(), idims.end());
+            std::size_t num_dims = rdims.size();
+            for(std::size_t i = 0; i < num_dims; i++)
+            {
+                rdims[i] += pads[i] + pads[i + num_dims];
+            }
+            shape s{s0.type(), rdims};
+            return s;
+        }
     }

     std::size_t pad_ndims() const
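Both the static and the new dynamic branch apply the same arithmetic: for each dimension i of an n-dimensional input, out[i] = in[i] + pads[i] + pads[i + n], where pads stores the n leading pads followed by the n trailing pads. A standalone sketch of that rule (hypothetical helper, not the operator itself):

#include <cstddef>
#include <iostream>
#include <vector>

// grow each dimension by its leading and trailing pad amounts
std::vector<std::size_t> padded_lens(std::vector<std::size_t> lens,
                                     const std::vector<std::size_t>& pads)
{
    const std::size_t n = lens.size(); // pads holds n leading values, then n trailing values
    for(std::size_t i = 0; i < n; ++i)
        lens[i] += pads[i] + pads[i + n];
    return lens;
}

int main()
{
    // {2, 3} padded by {1, 1} before and {0, 2} after -> {3, 6}
    for(auto d : padded_lens({2, 3}, {1, 1, 0, 2}))
        std::cout << d << ' ';
    std::cout << '\n';
}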
src/include/migraphx/op/reshape.hpp

@@ -28,6 +28,7 @@
 #include <migraphx/argument.hpp>
 #include <migraphx/config.hpp>
 #include <migraphx/value.hpp>
+#include <migraphx/dyn_output.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

@@ -46,14 +47,60 @@ struct reshape
     value attributes() const { return {{"require_std_shape", true}}; }

     std::string name() const { return "reshape"; }

-    shape compute_shape(std::vector<shape> inputs) const
+    shape dyn_compute_shape(shape s0) const
+    {
+        auto dyn_dims      = s0.dyn_dims();
+        auto num_not_fixed = std::count_if(
+            dyn_dims.cbegin(), dyn_dims.cend(), [](auto dd) { return not dd.is_fixed(); });
+        if(num_not_fixed != 1)
+        {
+            MIGRAPHX_THROW("Reshape: Only supports one non-fixed dynamic_dimension");
+        }
+        // track number of fixed elements in input and output
+        std::size_t num_dims_ele = 1;
+        std::size_t num_dd_ele   = 1;
+        for(std::size_t i = 0; i < dyn_dims.size(); ++i)
+        {
+            if(dyn_dims[i].is_fixed())
+            {
+                num_dims_ele *= dims[i];
+                num_dd_ele *= dyn_dims[i].min;
+            }
+            else
+            {
+                if(dims[i] != 0 and dims[i] != -1)
+                {
+                    MIGRAPHX_THROW("Reshape: Non-fixed dynamic_dimension doesn't match with 0 or -1 "
+                                   "output dimension");
+                }
+            }
+        }
+        if(num_dims_ele != num_dd_ele)
+        {
+            MIGRAPHX_THROW("Reshape: Number of fixed elements must match. Input: " +
+                           std::to_string(num_dd_ele) + " Output: " + std::to_string(num_dims_ele));
+        }
+        // construct output dynamic shape from dims attribute
+        std::vector<shape::dynamic_dimension> output_dyn_dims(dims.size());
+        std::transform(dims.cbegin(),
+                       dims.cend(),
+                       dyn_dims.cbegin(),
+                       output_dyn_dims.begin(),
+                       [](std::size_t dim, auto dyn_dim) {
+                           if(not dyn_dim.is_fixed())
+                               return dyn_dim;
+                           return shape::dynamic_dimension{dim, dim};
+                       });
+        return {s0.type(), output_dyn_dims};
+    }
+
+    shape static_compute_shape(std::vector<shape> inputs, std::size_t n_neg_dims) const
     {
-        check_shapes{inputs, *this}.has(1).standard();
+        check_shapes{inputs, *this}.standard();
         auto&& idims = inputs.front().lens();
         std::vector<std::size_t> rdims(dims.begin(), dims.end());
-        auto n_neg_dims = std::count(dims.begin(), dims.end(), -1);
-        if(n_neg_dims > 1)
-            MIGRAPHX_THROW("Reshape: Dimensions for reshape can only have one -1 dim");
         for(std::size_t i = 0; i < dims.size(); i++)
         {

@@ -86,9 +133,26 @@ struct reshape
         return s;
     }

-    argument compute(shape output_shape, std::vector<argument> args) const
+    shape compute_shape(std::vector<shape> inputs) const
+    {
+        check_shapes{inputs, *this, true}.has(1);
+        auto n_neg_dims = std::count(dims.begin(), dims.end(), -1);
+        if(n_neg_dims > 1)
+            MIGRAPHX_THROW("Reshape: Dimensions for reshape can only have one -1 dim");
+        auto s0 = inputs[0];
+        if(s0.dynamic())
+        {
+            return dyn_compute_shape(s0);
+        }
+        else
+        {
+            return static_compute_shape(inputs, n_neg_dims);
+        }
+    }
+
+    argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
     {
-        return args[0].reshape(output_shape);
+        return args[0].reshape(dyn_out.computed_shape);
     }

     std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
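The static path keeps the familiar reshape rule: at most one -1 entry, inferred so that element counts match; the dynamic path instead requires exactly one non-fixed dimension and checks that the products of the fixed dimensions agree. A minimal sketch of the static -1 inference (hypothetical helper, assuming no 0 entries):

#include <cstdint>
#include <iostream>
#include <vector>

// infer the single -1 entry so the reshaped element count equals `elements`
std::vector<std::int64_t> infer_reshape(std::vector<std::int64_t> dims, std::int64_t elements)
{
    std::int64_t known = 1;
    for(auto d : dims)
        if(d != -1)
            known *= d;
    for(auto& d : dims)
        if(d == -1)
            d = elements / known;
    return dims;
}

int main()
{
    // 24 elements reshaped to {2, -1, 3} -> {2, 4, 3}
    for(auto d : infer_reshape({2, -1, 3}, 24))
        std::cout << d << ' ';
    std::cout << '\n';
}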
src/include/migraphx/optimize_module.hpp (new file, mode 100644)
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_OPTIMIZE_MODULE_HPP
#define MIGRAPHX_GUARD_RTGLIB_OPTIMIZE_MODULE_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct module_pass_manager;

/**
 * Runs several passes in a loop
 */
struct optimize_module
{
    std::string name() const { return "optimize_module"; }
    void apply(module_pass_manager& mpm) const;
};

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
src/include/migraphx/shape.hpp

@@ -107,6 +107,13 @@ struct shape
         friend bool operator==(const std::size_t& x, const dynamic_dimension& y);
         friend bool operator!=(const dynamic_dimension& x, const std::size_t& y);
         friend bool operator!=(const std::size_t& x, const dynamic_dimension& y);
+
+        // add and subtract fixed std::size_t dimension
+        dynamic_dimension& operator+=(const std::size_t& x);
+        dynamic_dimension& operator-=(const std::size_t& x);
+        friend dynamic_dimension operator+(const dynamic_dimension& x, const std::size_t& y);
+        friend dynamic_dimension operator+(const std::size_t& x, const dynamic_dimension& y);
+        friend dynamic_dimension operator-(const dynamic_dimension& x, const std::size_t& y);
     };

     static const std::vector<type_t>& types();
src/module.cpp

@@ -822,7 +822,8 @@ static void print_make_op(std::ostream& os, const operation& op)
 static void print_py_shape(std::ostream& os, const migraphx::shape& s)
 {
-    os << "migraphx.shape(" << s.type_string() << ", lens=" << to_json_string(s.lens());
+    os << "migraphx.shape(type=" << to_json_string(s.type_string())
+       << ", lens=" << to_json_string(s.lens());
     if(not s.standard())
         os << ", strides=" << to_json_string(s.strides());
     os << ")";
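With the change, the type string is JSON-quoted, so the printed constructor call should look roughly like this for a standard 3x4 float shape (illustrative, not captured from a run):

    migraphx.shape(type="float_type", lens=[3, 4])

Non-standard shapes additionally get a strides=[...] field.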
src/onnx/onnx_parser.cpp

@@ -110,9 +110,19 @@ instruction_ref onnx_parser::node_info::add_bias(const std::vector<instruction_r
 {
     if(args.size() == 3)
     {
-        auto bias_bcast = mod->add_instruction(
-            make_op("broadcast", {{"axis", axis}, {"out_lens", curr_ins->get_shape().lens()}}),
-            args[2]);
+        instruction_ref bias_bcast;
+        // if curr_ins has a dynamic output shape use 2 input broadcast
+        if(curr_ins->get_shape().dynamic())
+        {
+            bias_bcast =
+                mod->add_instruction(make_op("broadcast", {{"axis", axis}}), args[2], curr_ins);
+        }
+        else
+        {
+            bias_bcast = mod->add_instruction(
+                make_op("broadcast", {{"axis", axis}, {"out_lens", curr_ins->get_shape().lens()}}),
+                args[2]);
+        }
         return mod->add_instruction(make_op("add"), curr_ins, bias_bcast);
     }
     return curr_ins;
src/onnx/parse_gemm.cpp

@@ -39,10 +39,19 @@ struct parse_gemm : op_parser<parse_gemm>
                           onnx_parser::node_info info,
                           std::vector<instruction_ref> args) const
     {
-        float alpha = 1.0f;
-        float beta  = 1.0f;
-        bool transa = false;
-        bool transb = false;
+        auto a_arg = args[0];
+        auto b_arg = args[1];
+        if(a_arg->get_shape().ndim() != 2 or b_arg->get_shape().ndim() != 2)
+        {
+            MIGRAPHX_THROW("PARSE_GEMM: A and B should be rank 2, A is rank " +
+                           std::to_string(a_arg->get_shape().ndim()) + ", B is rank " +
+                           std::to_string(b_arg->get_shape().ndim()));
+        }
+        float alpha  = 1.0f;
+        float beta   = 1.0f;
+        bool trans_a = false;
+        bool trans_b = false;
         if(contains(info.attributes, "alpha"))
         {
             alpha = parser.parse_value(info.attributes.at("alpha")).at<float>();

@@ -53,65 +62,73 @@ struct parse_gemm : op_parser<parse_gemm>
         }
         if(contains(info.attributes, "transA"))
         {
-            transa = parser.parse_value(info.attributes.at("transA")).at<bool>();
+            trans_a = parser.parse_value(info.attributes.at("transA")).at<bool>();
         }
         if(contains(info.attributes, "transB"))
         {
-            transb = parser.parse_value(info.attributes.at("transB")).at<bool>();
+            trans_b = parser.parse_value(info.attributes.at("transB")).at<bool>();
         }

-        std::vector<int64_t> perm(args[0]->get_shape().lens().size());
-        std::iota(perm.begin(), perm.end(), int64_t{0});
-        // swap the last two elements
-        std::swap(*perm.rbegin(), *(perm.rbegin() + 1));
-
-        auto l1       = args[0];
-        auto dot_type = l1->get_shape().type();
+        std::vector<int64_t> perm = {1, 0};
+        auto dot_type             = a_arg->get_shape().type();
         if(alpha != 1.0f)
         {
             auto alpha_literal = info.add_literal(alpha);
-            l1 = info.add_broadcastable_binary_op("mul", alpha_literal, l1);
-            if(l1->get_shape().type() != dot_type)
+            a_arg = info.add_broadcastable_binary_op("mul", alpha_literal, a_arg);
+            if(a_arg->get_shape().type() != dot_type)
             {
-                l1 = info.add_instruction(make_op("convert", {{"target_type", dot_type}}), l1);
+                a_arg = info.add_instruction(make_op("convert", {{"target_type", dot_type}}), a_arg);
             }
         }

-        l1 = (transa) ? info.add_instruction(make_op("transpose", {{"permutation", perm}}), l1) : l1;
-        auto l2 = (transb) ? info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[1])
-                           : args[1];
+        a_arg = (trans_a) ? info.add_instruction(make_op("transpose", {{"permutation", perm}}), a_arg)
+                          : a_arg;
+        b_arg = (trans_b) ? info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[1])
+                          : args[1];

-        auto ret = info.add_instruction(make_op("dot"), l1, l2);
+        auto dot_ins = info.add_instruction(make_op("dot"), a_arg, b_arg);

         if(args.size() == 3)
         {
-            if(not float_equal(beta, 0.0f) && args[2]->get_shape().elements() > 0)
+            if(not float_equal(beta, 0.0f))
             {
-                auto out_lens   = l1->get_shape().lens();
-                out_lens.back() = l2->get_shape().lens().back();
-                auto l3         = args[2];
-                auto l3_lens    = l3->get_shape().lens();
-                if(not std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
+                auto c_arg = args[2];
+                if(dot_ins->get_shape().dynamic())
                 {
-                    l3 = info.add_instruction(make_op("multibroadcast", {{"out_lens", out_lens}}),
-                                              args[2]);
+                    c_arg = info.add_instruction(make_op("multibroadcast"), args[2], dot_ins);
                 }
-                auto beta_literal = info.add_literal(beta);
-                auto beta_l3      = info.add_broadcastable_binary_op("mul", l3, beta_literal);
-                if(beta_l3->get_shape().type() != dot_type)
+                else
                 {
-                    beta_l3 = info.add_instruction(make_op("convert", {{"target_type", dot_type}}),
-                                                   beta_l3);
+                    auto out_lens   = a_arg->get_shape().lens();
+                    out_lens.back() = b_arg->get_shape().lens().back();
+                    auto c_lens     = c_arg->get_shape().lens();
+                    if(not std::equal(out_lens.begin(), out_lens.end(), c_lens.begin(), c_lens.end()))
+                    {
+                        c_arg = info.add_instruction(
+                            make_op("multibroadcast", {{"out_lens", out_lens}}), args[2]);
+                    }
                 }
-                return info.add_instruction(make_op("add"), ret, beta_l3);
+                if(not float_equal(beta, 1.0f))
+                {
+                    auto beta_literal = info.add_literal(beta);
+                    c_arg = info.add_broadcastable_binary_op("mul", c_arg, beta_literal);
+                    if(c_arg->get_shape().type() != dot_type)
+                    {
+                        c_arg = info.add_instruction(make_op("convert", {{"target_type", dot_type}}),
+                                                     c_arg);
+                    }
+                }
+                return info.add_instruction(make_op("add"), dot_ins, c_arg);
             }
         }
-        return ret;
+        return dot_ins;
     }
 };
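The parser implements ONNX Gemm semantics, Y = alpha * op(A) * op(B) + beta * C with op() an optional transpose, as a dot followed by a scaled add. A scalar-level sketch of that formula on 2x2 row-major matrices (plain C++, not the parser):

#include <array>
#include <iostream>

int main()
{
    const float alpha = 2.0f, beta = 0.5f;
    // A, B, C are 2x2 row-major, no transpose
    std::array<float, 4> a{1, 2, 3, 4}, b{5, 6, 7, 8}, c{1, 1, 1, 1}, y{};
    for(int i = 0; i < 2; ++i)
        for(int j = 0; j < 2; ++j)
        {
            float dot = 0;
            for(int k = 0; k < 2; ++k)
                dot += a[i * 2 + k] * b[k * 2 + j];
            y[i * 2 + j] = alpha * dot + beta * c[i * 2 + j];
        }
    for(auto v : y)
        std::cout << v << ' '; // 38.5 44.5 86.5 100.5
    std::cout << '\n';
}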
src/onnx/parse_matmul.cpp

@@ -43,55 +43,79 @@ struct parse_matmul : op_parser<parse_matmul>
                           const onnx_parser::node_info& info,
                           std::vector<instruction_ref> args) const
     {
-        auto l0      = args[0];
-        auto l1      = args[1];
-        auto l0_lens = l0->get_shape().lens();
-        auto l1_lens = l1->get_shape().lens();
+        auto a0 = args[0];
+        auto a1 = args[1];
+        auto s0 = a0->get_shape();
+        auto s1 = a1->get_shape();
+
+        instruction_ref dot_res;
         // args[0] is a vector, prepend 1 to the shape
         bool is_a_prepended = false;
-        if(l0_lens.size() == 1)
+        bool is_b_appended  = false;
+        if(s0.ndim() == 1)
         {
             is_a_prepended = true;
-            l0_lens.insert(l0_lens.begin(), 1);
-            l0 = info.add_instruction(make_op("unsqueeze", {{"axes", {0}}}), args[0]);
+            a0 = info.add_instruction(make_op("unsqueeze", {{"axes", {0}}}), args[0]);
         }
-        bool is_b_appended = false;
-        if(l1_lens.size() == 1)
+        if(s1.ndim() == 1)
         {
             is_b_appended = true;
-            l1_lens.push_back(1);
-            l1 = info.add_instruction(make_op("unsqueeze", {{"axes", {1}}}), args[1]);
+            a1 = info.add_instruction(make_op("unsqueeze", {{"axes", {1}}}), args[1]);
         }
-        instruction_ref bl0 = l0;
-        instruction_ref bl1 = l1;
-        if(not std::equal(
-               l0_lens.rbegin() + 2, l0_lens.rend(), l1_lens.rbegin() + 2, l1_lens.rend()))
+        if(s0.dynamic() or s1.dynamic())
         {
-            auto l0_it = l0_lens.begin() + l0_lens.size() - 2;
-            std::vector<std::size_t> l0_broadcasted_lens(l0_lens.begin(), l0_it);
-            auto l1_it = l1_lens.begin() + l1_lens.size() - 2;
-            std::vector<std::size_t> l1_broadcasted_lens(l1_lens.begin(), l1_it);
-            auto output_lens = compute_broadcasted_lens(l0_broadcasted_lens, l1_broadcasted_lens);
-            l0_broadcasted_lens = output_lens;
-            l0_broadcasted_lens.insert(l0_broadcasted_lens.end(), l0_it, l0_lens.end());
-            l1_broadcasted_lens = output_lens;
-            l1_broadcasted_lens.insert(l1_broadcasted_lens.end(), l1_it, l1_lens.end());
-            if(l0_lens != l0_broadcasted_lens)
+            if(opd.op_name == "quant_dot")
             {
-                bl0 = info.add_instruction(
-                    make_op("multibroadcast", {{"out_lens", l0_broadcasted_lens}}), l0);
+                MIGRAPHX_THROW("PARSE_MATMUL: dynamic MatMulInteger not supported");
             }
-            if(l1_lens != l1_broadcasted_lens)
+            auto s0_dds = a0->get_shape().to_dynamic().dyn_dims();
+            auto s1_dds = a1->get_shape().to_dynamic().dyn_dims();
+            // TODO: handling this case requires a new multibroadcast mode
+            if(not std::equal(
+                   s0_dds.rbegin() + 2, s0_dds.rend(), s1_dds.rbegin() + 2, s1_dds.rend()))
             {
-                bl1 = info.add_instruction(
-                    make_op("multibroadcast", {{"out_lens", l1_broadcasted_lens}}), l1);
+                MIGRAPHX_THROW("PARSE_MATMUL: dynamic shape broadcasting not supported");
             }
+            dot_res = info.add_instruction(make_op(opd.op_name), a0, a1);
         }
+        else
+        {
+            auto s0_lens        = a0->get_shape().lens();
+            auto s1_lens        = a1->get_shape().lens();
+            instruction_ref ba0 = a0;
+            instruction_ref ba1 = a1;
+            // try broadcasting if dimensions other than last two do not match
+            if(not std::equal(
+                   s0_lens.rbegin() + 2, s0_lens.rend(), s1_lens.rbegin() + 2, s1_lens.rend()))
+            {
+                auto l0_it = s0_lens.begin() + s0_lens.size() - 2;
+                std::vector<std::size_t> l0_broadcasted_lens(s0_lens.begin(), l0_it);
+                auto l1_it = s1_lens.begin() + s1_lens.size() - 2;
+                std::vector<std::size_t> l1_broadcasted_lens(s1_lens.begin(), l1_it);
+                auto output_lens =
+                    compute_broadcasted_lens(l0_broadcasted_lens, l1_broadcasted_lens);
+                l0_broadcasted_lens = output_lens;
+                l0_broadcasted_lens.insert(l0_broadcasted_lens.end(), l0_it, s0_lens.end());
+                l1_broadcasted_lens = output_lens;
+                l1_broadcasted_lens.insert(l1_broadcasted_lens.end(), l1_it, s1_lens.end());
+                if(s0_lens != l0_broadcasted_lens)
+                {
+                    ba0 = info.add_instruction(
+                        make_op("multibroadcast", {{"out_lens", l0_broadcasted_lens}}), a0);
+                }
+                if(s1_lens != l1_broadcasted_lens)
+                {
+                    ba1 = info.add_instruction(
+                        make_op("multibroadcast", {{"out_lens", l1_broadcasted_lens}}), a1);
+                }
+            }
+            dot_res = info.add_instruction(make_op(opd.op_name), ba0, ba1);
+        }

-        instruction_ref dot_res = info.add_instruction(make_op(opd.op_name), bl0, bl1);
-        int64_t num_axis = static_cast<int64_t>(dot_res->get_shape().lens().size());
+        // squeeze the appended or prepended dimensions
+        int64_t num_axis = dot_res->get_shape().ndim();
         if(is_a_prepended)
         {
             dot_res = info.add_instruction(make_op("squeeze", {{"axes", {num_axis - 2}}}), dot_res);
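The static path broadcasts only the batch dimensions (everything except the last two), pairing them right-aligned and expanding 1s to the other operand's size. A standalone sketch of that pairing rule (hypothetical helper, analogous to compute_broadcasted_lens):

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

// right-aligned numpy-style broadcast of two batch-dim lists
std::vector<std::size_t> broadcast_lens(std::vector<std::size_t> a, std::vector<std::size_t> b)
{
    if(a.size() < b.size())
        std::swap(a, b);
    auto offset = a.size() - b.size();
    for(std::size_t i = 0; i < b.size(); ++i)
    {
        auto& x = a[i + offset];
        x = std::max(x, b[i]); // assumes dims are equal or one of them is 1
    }
    return a;
}

int main()
{
    // batch dims {8, 1} vs {4} -> {8, 4}; the trailing matrix dims are handled separately
    for(auto d : broadcast_lens({8, 1}, {4}))
        std::cout << d << ' ';
    std::cout << '\n';
}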
src/onnx/parse_pad.cpp

@@ -147,7 +147,13 @@ struct parse_pad : op_parser<parse_pad>
         {
             auto mode = info.attributes.at("mode").s();
             if(mode == "reflect")
+            {
+                if(args.front()->get_shape().dynamic())
+                {
+                    MIGRAPHX_THROW("PARSE_PAD: reflect padding with dynamic shape not supported");
+                }
                 return reflect_pad(info, pads, args.front());
+            }
             if(mode != "constant")
             {
                 MIGRAPHX_THROW(
src/onnx/parse_reshape.cpp

@@ -49,7 +49,7 @@ struct parse_reshape : op_parser<parse_reshape>
         if(args.size() == 2)
         {
             auto s = args[1]->eval();
-            check_arg_empty(s, "Reshape: dynamic shape is not supported");
+            check_arg_empty(s, "Reshape: non-constant shape input is not supported");
             s.visit([&](auto v) { copy(v, std::back_inserter(dims)); });
         }
src/onnx/parse_trilu.cpp (new file, mode 100644)
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {

struct parse_trilu : op_parser<parse_trilu>
{
    std::vector<op_desc> operators() const { return {{"Trilu"}}; }

    instruction_ref parse(const op_desc&,
                          const onnx_parser&,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        auto input_shape = args[0]->get_shape();
        assert(input_shape.ndim() >= 2);
        auto input_lens = input_shape.lens();
        size_t num_rows = *(input_lens.rbegin() + 1);
        size_t num_cols = input_lens.back();
        int k           = 0;
        bool upper      = true;

        if(args.size() > 1)
        {
            auto arg_k = args[1]->eval();
            check_arg_empty(arg_k, "PARSE_TRILU: dynamic k not supported");
            k = arg_k.at<int>();
        }

        if(k < 0)
            MIGRAPHX_THROW("PARSE_TRILU: negative k values not supported");

        if(contains(info.attributes, "upper"))
        {
            upper = static_cast<bool>(info.attributes.at("upper").i());
        }

        shape::type_t output_type = args[0]->get_shape().type();
        // when creating the mask, if upper == 1,
        // the inner triangle will have values set to 0
        std::vector<bool> mask_mat(num_rows * num_cols, upper);
        for(size_t i = 0; i < num_rows; i++)
        {
            for(size_t j = 0; j < std::min(k, static_cast<int>(num_cols)); j++)
            {
                mask_mat[i * num_cols + j] = not upper;
            }
            k++;
        }

        auto mask = info.add_literal(
            migraphx::literal{migraphx::shape{output_type, {num_rows, num_cols}}, mask_mat});
        return info.add_broadcastable_binary_op("mul", mask, args[0]);
    }
};

} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
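The parser lowers Trilu to an elementwise multiply with a constant 0/1 mask. A standalone sketch of the mask construction for a 3x3 upper triangle with k = 1, mirroring the loop above outside MIGraphX:

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    const std::size_t num_rows = 3, num_cols = 3;
    int k      = 1;    // start at the first superdiagonal
    bool upper = true; // keep the upper triangle
    std::vector<bool> mask(num_rows * num_cols, upper);
    for(std::size_t i = 0; i < num_rows; i++)
    {
        // zero out the first k entries of each row (the lower part)
        for(std::size_t j = 0; j < std::min<std::size_t>(k, num_cols); j++)
            mask[i * num_cols + j] = not upper;
        k++;
    }
    for(std::size_t i = 0; i < num_rows; i++)
    {
        for(std::size_t j = 0; j < num_cols; j++)
            std::cout << mask[i * num_cols + j] << ' ';
        std::cout << '\n'; // prints: 0 1 1 / 0 0 1 / 0 0 0
    }
}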
src/optimize_module.cpp (new file, mode 100644)
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/optimize_module.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/propagate_constant.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

void optimize_module::apply(module_pass_manager& mpm) const
{
    for(int i = 0; i < 2; i++)
    {
        mpm.run_pass(simplify_reshapes{});
        mpm.run_pass(simplify_algebra{});
        mpm.run_pass(eliminate_common_subexpression{});
        mpm.run_pass(dead_code_elimination{});
        mpm.run_pass(propagate_constant{});
        mpm.run_pass(dead_code_elimination{});
    }
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
src/shape.cpp

@@ -504,6 +504,31 @@ bool shape::dynamic_dimension::is_fixed() const { return this->min == this->max;
 bool shape::dynamic_dimension::has_optimal() const { return opt != 0; }

+shape::dynamic_dimension& shape::dynamic_dimension::operator+=(const std::size_t& x)
+{
+    this->min += x;
+    this->max += x;
+    if(this->opt != 0)
+    {
+        this->opt += x;
+    };
+    return *this;
+}
+
+shape::dynamic_dimension& shape::dynamic_dimension::operator-=(const std::size_t& x)
+{
+    assert(this->min >= x);
+    assert(this->max >= x);
+    this->min -= x;
+    this->max -= x;
+    if(this->opt != 0)
+    {
+        assert(this->opt >= x);
+        this->opt -= x;
+    }
+    return *this;
+}
+
 bool operator==(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y)
 {
     // don't check opt if both are fixed

@@ -529,6 +554,23 @@ bool operator==(const std::size_t& x, const shape::dynamic_dimension& y)
 bool operator!=(const shape::dynamic_dimension& x, const std::size_t& y) { return not(x == y); }

 bool operator!=(const std::size_t& x, const shape::dynamic_dimension& y) { return not(x == y); }

+shape::dynamic_dimension operator+(const shape::dynamic_dimension& x, const std::size_t& y)
+{
+    auto dd = x;
+    return dd += y;
+}
+
+shape::dynamic_dimension operator+(const std::size_t& x, const shape::dynamic_dimension& y)
+{
+    return y + x;
+}
+
+shape::dynamic_dimension operator-(const shape::dynamic_dimension& x, const std::size_t& y)
+{
+    auto dd = x;
+    return dd -= y;
+}
+
 bool operator==(const shape& x, const shape& y)
 {
     if(x.dynamic() and y.dynamic())
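The new operators shift a dynamic range uniformly: min, max, and (when set) the optimal hint all move by the same offset, which is what the dynamic pad shape computation relies on. A standalone sketch of those semantics (hypothetical struct, not the MIGraphX type):

#include <cstddef>
#include <iostream>

struct dyn_dim
{
    std::size_t min = 0, max = 0, opt = 0; // opt == 0 means "no optimal hint"

    dyn_dim& operator+=(std::size_t x)
    {
        min += x;
        max += x;
        if(opt != 0)
            opt += x;
        return *this;
    }
};

int main()
{
    dyn_dim d{2, 8, 4};
    d += 3; // e.g. total padding added to this dimension
    std::cout << d.min << ' ' << d.max << ' ' << d.opt << '\n'; // 5 11 7
}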