Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
25e8cf0b
"src/vscode:/vscode.git/clone" did not exist on "65ef35cd978b2389d41e2d324f06784e01e2af07"
Unverified
Commit
25e8cf0b
authored
Jan 27, 2023
by
Ted Themistokleous
Committed by
GitHub
Jan 27, 2023
Browse files
Merge branch 'develop' into test_onnx_zoo
parents
a313a68e
635502be
Changes
138
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
441 additions
and
148 deletions
+441
-148
.github/workflows/add-to-project.yaml
.github/workflows/add-to-project.yaml
+19
-0
.github/workflows/sync-onnxrt-main.yaml
.github/workflows/sync-onnxrt-main.yaml
+45
-0
Dockerfile
Dockerfile
+31
-5
Jenkinsfile
Jenkinsfile
+18
-8
examples/nlp/python_bert_squad/requirements_bertsquad.txt
examples/nlp/python_bert_squad/requirements_bertsquad.txt
+1
-1
src/common.cpp
src/common.cpp
+1
-2
src/dead_code_elimination.cpp
src/dead_code_elimination.cpp
+2
-2
src/driver/main.cpp
src/driver/main.cpp
+8
-2
src/file_buffer.cpp
src/file_buffer.cpp
+16
-8
src/include/migraphx/check_shapes.hpp
src/include/migraphx/check_shapes.hpp
+1
-1
src/include/migraphx/file_buffer.hpp
src/include/migraphx/file_buffer.hpp
+1
-1
src/include/migraphx/instruction.hpp
src/include/migraphx/instruction.hpp
+2
-0
src/include/migraphx/literal.hpp
src/include/migraphx/literal.hpp
+6
-15
src/include/migraphx/module.hpp
src/include/migraphx/module.hpp
+6
-0
src/include/migraphx/op/argmax.hpp
src/include/migraphx/op/argmax.hpp
+19
-11
src/include/migraphx/op/dot.hpp
src/include/migraphx/op/dot.hpp
+51
-22
src/include/migraphx/op/flatten.hpp
src/include/migraphx/op/flatten.hpp
+39
-9
src/include/migraphx/op/pad.hpp
src/include/migraphx/op/pad.hpp
+21
-10
src/include/migraphx/op/pooling.hpp
src/include/migraphx/op/pooling.hpp
+115
-36
src/include/migraphx/op/reduce_op.hpp
src/include/migraphx/op/reduce_op.hpp
+39
-15
No files found.
.github/workflows/add-to-project.yaml
0 → 100644
View file @
25e8cf0b
name
:
Add items to GH project
on
:
pull_request
:
types
:
-
opened
issues
:
types
:
-
opened
jobs
:
add-to-project
:
name
:
Add PRs and issues to MIGX project
runs-on
:
ubuntu-latest
steps
:
-
uses
:
actions/add-to-project@v0.4.0
with
:
project-url
:
https://github.com/orgs/ROCmSoftwarePlatform/projects/20
github-token
:
${{ secrets.TEST_PR_WORKFLOW }}
.github/workflows/sync-onnxrt-main.yaml
0 → 100644
View file @
25e8cf0b
name
:
Onnxruntime main weekly sync
on
:
schedule
:
-
cron
:
"
05
09
*
*
5"
jobs
:
sync
:
steps
:
-
uses
:
actions/checkout@v3
with
:
ref
:
develop
path
:
../
get_date
:
steps
:
-
run
:
echo "::set-output name=date::$(date +'%Y-%m-%d')"
update_file
:
needs
:
[
sync get_date
]
steps
:
-
run
:
git clone https://github.com/microsoft/onnxruntime.git && cd onnxruntime && git rev-parse HEAD >> ../test/onnx/.onnxrt-commit
Add_commit
:
needs
:
update_file
steps
:
-
name
:
Add & Commit
uses
:
EndBug/add-and-commit@v9.1.1
with
:
new_branch
:
onnxruntime-sync-${{ steps.date.outputs.date }}
add
:
../test/onnx/.onnxrt-commit
message
:
Update Onnxruntime commit to latest release
default_author
:
github_actions
push
:
true
PR
:
needs
:
Add_commit
steps
:
-
name
:
GitHub Action for creating Pull Requests
uses
:
devops-infra/action-pull-request@v0.5.3
with
:
github_token
:
${{ secrets.GITHUB_TOKEN }}
title
:
Sync Onnxruntime main
reviewer
:
pfultz2, causten
assignee
:
TedThemistokleous
label
:
automatic, onnxruntime
target_branch
:
develop
Dockerfile
View file @
25e8cf0b
...
...
@@ -5,6 +5,10 @@ ARG PREFIX=/usr/local
# Support multiarch
RUN
dpkg
--add-architecture
i386
# Install rocm key
RUN
apt-get update
&&
apt-get
install
-y
gnupg2
--no-install-recommends
curl
&&
\
curl
-sL
http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add -
# Add rocm repository
RUN
sh
-c
'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/5.3/ ubuntu main > /etc/apt/sources.list.d/rocm.list'
...
...
@@ -32,10 +36,27 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
libnuma-dev
\
miopen-hip
\
rocblas
\
hipfft
\
rocthrust
\
rocrand
\
hipsparse
\
rccl
\
rccl-dev
\
rocm-smi-lib
\
rocm-dev
\
roctracer-dev
\
hipcub
\
hipblas
\
hipify-clang
\
half
\
libssl-dev
\
zlib1g-dev
&&
\
apt-get clean
&&
\
rm
-rf
/var/lib/apt/lists/
*
# add this for roctracer dependancies
RUN
pip3
install
CppHeaderParser
packaging
==
22.0
# Workaround broken rocm packages
RUN
ln
-s
/opt/rocm-
*
/opt/rocm
RUN
echo
"/opt/rocm/lib"
>
/etc/ld.so.conf.d/rocm.conf
...
...
@@ -72,22 +93,27 @@ RUN /download_models.sh && rm /download_models.sh
# Install latest ccache version
RUN
cget
-p
$PREFIX
install
facebook/zstd@v1.4.5
-X
subdir
-DCMAKE_DIR
=
build/cmake
RUN
cget
-p
$PREFIX
install
ccache@v4.1
-DENABLE_TESTING
=
OFF
RUN
cget
-p
/opt/cmake
install
kitware/cmake@v3.24.3
# Install newer cmake for onnx runtime
ARG
CMAKE_VERSION=3.24.2
RUN
cget
-p
/opt/cmake
install
-X
binary https://github.com/Kitware/CMake/releases/download/v
${
CMAKE_VERSION
}
/cmake-
${
CMAKE_VERSION
}
-Linux-x86_64
.tar.gz
RUN
export
ONNXRT_COMMIT
=
$(
cat test
/onnx/.onnxrt-commit
)
ARG
ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime
ARG
ONNXRUNTIME_BRANCH=main
ARG
ONNXRUNTIME_COMMIT=24f1bd6156cf5968bbc76dfb0e801a9b9c56b9fc
ARG
ONNXRUNTIME_COMMIT=$ONNXRT_COMMIT
# Let us know which commit where're using for CI
RUN
echo
"Onnxruntime Commit:"
&&
echo
$ONNXRUNTIME_COMMIT
RUN
git clone
--single-branch
--branch
${
ONNXRUNTIME_BRANCH
}
--recursive
${
ONNXRUNTIME_REPO
}
onnxruntime
&&
\
cd
onnxruntime
&&
\
git checkout
${
ONNXRUNTIME_COMMIT
}
&&
\
/bin/sh dockerfiles/scripts/install_common_deps.sh
ADD
tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh
RUN
cget
-p
/usr/local
install
ROCmSoftwarePlatform/
llvm-project-mlir@c0723a7e50043d973cb73ae51dc30d36679ee7e5
-DBUILD_MIXR_TARGET
=
On
RUN
cget
-p
/usr/local
install
ROCmSoftwarePlatform/
rocMLIR@78b706fe9879587ab98b6614ae539265374a3fae
-DBUILD_MIXR_TARGET
=
On
-DLLVM_ENABLE_ZSTD
=
Off
-DLLVM_ENABLE_THREADS
=
Off
ENV
MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db
ENV
MIOPEN_USER_DB_PATH=/tmp/miopen/user-db
...
...
Jenkinsfile
View file @
25e8cf0b
...
...
@@ -11,16 +11,22 @@ def rocmtestnode(Map conf) {
def
image
=
'migraphxlib'
env
.
CCACHE_COMPRESSLEVEL
=
7
env
.
CCACHE_DIR
=
ccache
def
cmake_build
=
{
compiler
,
flags
->
def
cmake_build
=
{
bconf
->
def
compiler
=
bconf
.
get
(
"compiler"
,
"/opt/rocm/llvm/bin/clang++"
)
def
flags
=
bconf
.
get
(
"flags"
,
""
)
def
gpu_debug
=
bconf
.
get
(
"gpu_debug"
,
"0"
)
def
cmd
=
"""
env
ulimit -c unlimited
echo "leak:dnnl::impl::malloc" > suppressions.txt
export LSAN_OPTIONS="suppressions=\$(pwd)/suppressions.txt"
export MIGRAPHX_GPU_DEBUG=${gpu_debug}
export CXX=${compiler}
export CXXFLAGS='-Werror'
env
rm -rf build
mkdir build
cd build
CXX=${compiler} CXXFLAGS='-Werror'
cmake -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache ${flags} ..
cmake -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache ${flags} ..
make -j\$(nproc) generate all doc package check VERBOSE=1
"""
echo
cmd
...
...
@@ -93,28 +99,32 @@ rocmtest clang_debug: rocmnode('vega') { cmake_build ->
stage
(
'Hip Clang Debug'
)
{
def
sanitizers
=
"undefined"
def
debug_flags
=
"-g -O2 -fsanitize=${sanitizers} -fno-sanitize-recover=${sanitizers}"
cmake_build
(
"/opt/rocm/llvm/bin/clang++"
,
"-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}' -DCMAKE_C_FLAGS_DEBUG='${debug_flags}'"
)
cmake_build
(
flags:
"-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}' -DCMAKE_C_FLAGS_DEBUG='${debug_flags}'"
)
}
},
clang_gpu_debug:
rocmnode
(
'vega'
)
{
cmake_build
->
stage
(
'Hip Clang GPU Debug'
)
{
cmake_build
(
flags:
"-DCMAKE_BUILD_TYPE=release"
,
gpu_debug:
true
)
}
},
clang_release:
rocmnode
(
'vega'
)
{
cmake_build
->
stage
(
'Hip Clang Release'
)
{
cmake_build
(
"/opt/rocm/llvm/bin/clang++"
,
"-DCMAKE_BUILD_TYPE=release"
)
cmake_build
(
flags:
"-DCMAKE_BUILD_TYPE=release"
)
stash
includes:
'build/*.deb'
,
name:
'migraphx-package'
}
},
mlir_debug:
rocmnode
(
'vega'
)
{
cmake_build
->
stage
(
'MLIR Debug'
)
{
def
sanitizers
=
"undefined"
def
debug_flags
=
"-g -O2 -fsanitize=${sanitizers} -fno-sanitize-recover=${sanitizers}"
cmake_build
(
"/opt/rocm/llvm/bin/clang++"
,
"-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off -DMIGRAPHX_ENABLE_MLIR=On -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}' -DCMAKE_C_FLAGS_DEBUG='${debug_flags}'"
)
cmake_build
(
flags:
"-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off -DMIGRAPHX_ENABLE_MLIR=On -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}' -DCMAKE_C_FLAGS_DEBUG='${debug_flags}'"
)
}
},
clang_asan:
rocmnode
(
'nogpu'
)
{
cmake_build
->
stage
(
'Clang ASAN'
)
{
def
sanitizers
=
"undefined,address"
def
debug_flags
=
"-g -O2 -fno-omit-frame-pointer -fsanitize=${sanitizers} -fno-sanitize-recover=${sanitizers}"
cmake_build
(
"/opt/rocm/llvm/bin/clang++"
,
"-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off -DMIGRAPHX_ENABLE_GPU=Off -DMIGRAPHX_ENABLE_CPU=On -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}' -DCMAKE_C_FLAGS_DEBUG='${debug_flags}'"
)
cmake_build
(
flags:
"-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off -DMIGRAPHX_ENABLE_GPU=Off -DMIGRAPHX_ENABLE_CPU=On -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}' -DCMAKE_C_FLAGS_DEBUG='${debug_flags}'"
)
}
}
//, clang_release_navi: rocmnode('navi21') { cmake_build ->
// stage('HIP Clang Release Navi') {
// cmake_build(
"/opt/rocm/llvm/bin/clang++",
"-DCMAKE_BUILD_TYPE=release")
// cmake_build(
flags:
"-DCMAKE_BUILD_TYPE=release")
// }
//}
...
...
examples/nlp/python_bert_squad/requirements_bertsquad.txt
View file @
25e8cf0b
...
...
@@ -21,6 +21,6 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
tensorflow==2.
7.2
tensorflow==2.
9.3
onnxruntime
tokenizers
\ No newline at end of file
src/common.cpp
View file @
25e8cf0b
...
...
@@ -77,7 +77,6 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha
}
auto
offset
=
s1
.
ndim
()
-
s0
.
ndim
();
std
::
vector
<
shape
::
dynamic_dimension
>
out_dims
(
s1
.
dyn_dims
());
shape
::
dynamic_dimension
one_dyn_dim
{
1
,
1
,
0
};
std
::
transform
(
s0
.
dyn_dims
().
cbegin
(),
s0
.
dyn_dims
().
cend
(),
...
...
@@ -88,7 +87,7 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha
{
return
a
;
}
else
if
(
a
==
one_dyn_dim
or
b
==
one_dyn_dim
)
else
if
(
a
==
1
or
b
==
1
)
{
// setting opt to 0, may need to be changed
return
shape
::
dynamic_dimension
{
std
::
max
(
a
.
min
,
b
.
min
),
std
::
max
(
a
.
max
,
b
.
max
),
0
};
...
...
src/dead_code_elimination.cpp
View file @
25e8cf0b
...
...
@@ -51,8 +51,8 @@ void dead_code_elimination::apply(module& m) const
// Skip instruction with empty shape as output unless its [dynamic, builtin, undefined,
// identity, allocate]
if
((
not
i
->
get_shape
().
dynamic
()
and
i
->
get_shape
().
elements
()
==
0
)
and
i
->
name
().
front
()
!
=
'@'
and
not
contains
({
"undefined"
,
"identity"
,
"allocate"
},
i
->
name
()
))
not
(
i
->
name
().
front
()
=
=
'@'
)
and
not
contains
({
"identity"
,
"allocate"
},
i
->
name
())
and
not
i
->
is_undefined
(
))
continue
;
assert
(
std
::
distance
(
m
.
begin
(),
i
)
<=
std
::
distance
(
m
.
begin
(),
last
));
std
::
unordered_set
<
instruction_ref
>
visited
;
...
...
src/driver/main.cpp
View file @
25e8cf0b
...
...
@@ -109,8 +109,12 @@ struct loader
ap
(
brief
,
{
"--brief"
},
ap
.
help
(
"Make the output brief."
),
ap
.
set_value
(
true
));
ap
(
output_type
,
{
"--cpp"
},
ap
.
help
(
"Print out the program as
cpp
program."
),
ap
.
help
(
"Print out the program as
C++
program."
),
ap
.
set_value
(
"cpp"
));
ap
(
output_type
,
{
"--python"
,
"--py"
},
ap
.
help
(
"Print out the program as python program."
),
ap
.
set_value
(
"py"
));
ap
(
output_type
,
{
"--json"
},
ap
.
help
(
"Print out program as json."
),
ap
.
set_value
(
"json"
));
ap
(
output_type
,
{
"--text"
},
...
...
@@ -259,7 +263,9 @@ struct loader
type
=
"binary"
;
}
if
(
type
==
"cpp"
)
if
(
type
==
"py"
)
p
.
print_py
(
*
os
);
else
if
(
type
==
"cpp"
)
p
.
print_cpp
(
*
os
);
else
if
(
type
==
"graphviz"
)
p
.
print_graph
(
*
os
,
brief
);
...
...
src/file_buffer.cpp
View file @
25e8cf0b
...
...
@@ -30,23 +30,31 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
template
<
class
T
>
T
generic_read_file
(
const
std
::
string
&
filename
)
T
generic_read_file
(
const
std
::
string
&
filename
,
size_t
offset
=
0
,
size_t
nbytes
=
0
)
{
std
::
ifstream
is
(
filename
,
std
::
ios
::
binary
|
std
::
ios
::
ate
);
std
::
streamsize
size
=
is
.
tellg
();
if
(
size
<
1
)
if
(
nbytes
==
0
)
{
// if there is a non-zero offset and nbytes is not set,
// calculate size of remaining bytes to read
nbytes
=
is
.
tellg
();
if
(
offset
>
nbytes
)
MIGRAPHX_THROW
(
"offset is larger than file size"
);
nbytes
-=
offset
;
}
if
(
nbytes
<
1
)
MIGRAPHX_THROW
(
"Invalid size for: "
+
filename
);
is
.
seekg
(
0
,
std
::
ios
::
beg
);
is
.
seekg
(
offset
,
std
::
ios
::
beg
);
T
buffer
(
size
,
0
);
if
(
not
is
.
read
(
&
buffer
[
0
],
size
))
T
buffer
(
nbytes
,
0
);
if
(
not
is
.
read
(
&
buffer
[
0
],
nbytes
))
MIGRAPHX_THROW
(
"Error reading file: "
+
filename
);
return
buffer
;
}
std
::
vector
<
char
>
read_buffer
(
const
std
::
string
&
filename
)
std
::
vector
<
char
>
read_buffer
(
const
std
::
string
&
filename
,
size_t
offset
,
size_t
nbytes
)
{
return
generic_read_file
<
std
::
vector
<
char
>>
(
filename
);
return
generic_read_file
<
std
::
vector
<
char
>>
(
filename
,
offset
,
nbytes
);
}
std
::
string
read_string
(
const
std
::
string
&
filename
)
...
...
src/include/migraphx/check_shapes.hpp
View file @
25e8cf0b
...
...
@@ -198,7 +198,7 @@ struct check_shapes
*/
const
check_shapes
&
same_ndims
()
const
{
if
(
not
this
->
same
([](
const
shape
&
s
)
{
return
s
.
max_lens
().
size
();
}))
if
(
not
this
->
same
([](
const
shape
&
s
)
{
return
s
.
ndim
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Number of dimensions do not match"
);
return
*
this
;
}
...
...
src/include/migraphx/file_buffer.hpp
View file @
25e8cf0b
...
...
@@ -31,7 +31,7 @@
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
std
::
vector
<
char
>
read_buffer
(
const
std
::
string
&
filename
);
std
::
vector
<
char
>
read_buffer
(
const
std
::
string
&
filename
,
size_t
offset
=
0
,
size_t
nbytes
=
0
);
std
::
string
read_string
(
const
std
::
string
&
filename
);
void
write_buffer
(
const
std
::
string
&
filename
,
const
char
*
buffer
,
std
::
size_t
size
);
...
...
src/include/migraphx/instruction.hpp
View file @
25e8cf0b
...
...
@@ -121,6 +121,8 @@ struct instruction
bool
can_eval
()
const
;
bool
is_undefined
()
const
;
argument
eval
(
bool
check_eval
=
true
)
const
;
void
finalize
(
context
&
ctx
);
...
...
src/include/migraphx/literal.hpp
View file @
25e8cf0b
...
...
@@ -80,6 +80,7 @@ struct literal : raw_data<literal>
fill
(
start
,
end
);
}
// Directly copies buffer of x
template
<
class
T
,
MIGRAPHX_REQUIRES
(
sizeof
(
T
)
==
1
)>
literal
(
const
shape
&
s
,
T
*
x
)
:
buffer
(
make_shared_array
<
char
>
(
s
.
bytes
())),
m_shape
(
s
)
{
...
...
@@ -107,25 +108,15 @@ struct literal : raw_data<literal>
std
::
shared_ptr
<
char
>
buffer
;
shape
m_shape
;
// Keeps the same data ordering as the given container
template
<
class
Iterator
>
void
fill
(
Iterator
start
,
Iterator
end
)
{
assert
(
std
::
distance
(
start
,
end
)
==
m_shape
.
elements
());
if
(
m_shape
.
standard
())
{
m_shape
.
visit_type
([
&
](
auto
as
)
{
std
::
copy
(
start
,
end
,
as
.
from
(
buffer
.
get
()));
});
}
else
{
auto
it
=
start
;
m_shape
.
visit_type
([
&
](
auto
as
)
{
auto
output
=
make_view
(
m_shape
,
as
.
from
(
buffer
.
get
()));
shape_for_each
(
output
.
get_shape
(),
[
&
](
const
auto
&
idx
)
{
output
(
idx
.
begin
(),
idx
.
end
())
=
*
it
;
// NOLINT(bugprone-signed-char-misuse)
it
++
;
std
::
copy
(
start
,
end
,
output
.
begin
());
});
});
}
}
};
...
...
src/include/migraphx/module.hpp
View file @
25e8cf0b
...
...
@@ -205,6 +205,12 @@ struct module
void
print_graph
(
std
::
ostream
&
os
,
bool
brief
=
false
)
const
;
void
print_py
(
std
::
ostream
&
os
)
const
;
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
print_py
(
std
::
ostream
&
os
,
const
std
::
string
&
mname
,
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
names
)
const
;
void
print_cpp
(
std
::
ostream
&
os
)
const
;
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
print_cpp
(
std
::
ostream
&
os
,
...
...
src/include/migraphx/op/argmax.hpp
View file @
25e8cf0b
...
...
@@ -30,6 +30,7 @@
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/dyn_output.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -56,13 +57,21 @@ struct argmax
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
auto
lens
=
inputs
[
0
].
lens
();
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
);
const
auto
&
s0
=
inputs
[
0
];
if
(
s0
.
dynamic
())
{
auto
dyn_dims
=
s0
.
dyn_dims
();
dyn_dims
[
axis
]
=
{
1
,
1
,
0
};
return
{
shape
::
int64_type
,
dyn_dims
};
}
else
{
auto
lens
=
s0
.
lens
();
lens
[
axis
]
=
1
;
return
{
shape
::
int64_type
,
lens
};
}
}
template
<
class
T
>
int64_t
calc_argmax
(
T
&
input
,
std
::
vector
<
std
::
size_t
>&
indices
,
size_t
item_num
)
const
...
...
@@ -79,19 +88,18 @@ struct argmax
max_index
=
i
;
}
}
return
max_index
;
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
out
put_shape
};
argument
result
{
dyn_out
.
com
put
ed
_shape
};
auto
batch_item_num
=
args
.
front
().
get_shape
().
lens
()[
axis
];
result
.
visit
([
&
](
auto
output
)
{
args
[
0
].
visit
([
&
](
auto
input
)
{
par_for
(
out
put_shape
.
elements
(),
[
&
](
auto
i
)
{
auto
data_idx
=
out
put_shape
.
multi
(
i
);
par_for
(
dyn_out
.
com
put
ed
_shape
.
elements
(),
[
&
](
auto
i
)
{
auto
data_idx
=
dyn_out
.
com
put
ed
_shape
.
multi
(
i
);
output
[
i
]
=
this
->
calc_argmax
(
input
,
data_idx
,
batch_item_num
);
});
});
...
...
src/include/migraphx/op/dot.hpp
View file @
25e8cf0b
...
...
@@ -28,6 +28,7 @@
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/gemm.hpp>
#include <migraphx/dyn_output.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -38,41 +39,69 @@ struct dot
std
::
string
name
()
const
{
return
"dot"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
same_type
().
has
(
2
);
check_shapes
{
inputs
,
*
this
,
true
}.
same_type
().
same_ndims
().
has
(
2
);
const
shape
&
a
=
inputs
.
at
(
0
);
const
shape
&
b
=
inputs
.
at
(
1
);
auto
t
=
a
.
type
();
if
(
not
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[](
auto
s
)
{
return
s
.
lens
().
size
()
>=
2
;
}))
if
(
not
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[](
auto
s
)
{
return
s
.
ndim
()
>=
2
;
}))
{
MIGRAPHX_THROW
(
"DOT: dot only accept 2 or more dim
s operands
"
);
MIGRAPHX_THROW
(
"DOT: dot only accept
s operands with
2 or more dim
ensions
"
);
}
// only handle the case that the batch size of a and b are the same
if
(
a
.
dynamic
()
or
b
.
dynamic
())
{
auto
s0
=
a
.
to_dynamic
();
auto
s1
=
b
.
to_dynamic
();
if
(
not
std
::
equal
(
s0
.
dyn_dims
().
rbegin
()
+
2
,
s0
.
dyn_dims
().
rend
(),
s1
.
dyn_dims
().
rbegin
()
+
2
,
s1
.
dyn_dims
().
rend
()))
{
MIGRAPHX_THROW
(
"DOT: dynamic outer dimensions of A and B mismatch: {"
+
to_string_range
(
s0
.
dyn_dims
())
+
"} x {"
+
to_string_range
(
s1
.
dyn_dims
())
+
"}"
);
}
std
::
size_t
dim_0
=
s0
.
ndim
()
-
2
;
std
::
size_t
dim_1
=
s0
.
ndim
()
-
1
;
if
(
s0
.
dyn_dims
()[
dim_1
]
!=
s1
.
dyn_dims
()[
dim_0
])
{
MIGRAPHX_THROW
(
"DOT: dynamic inner dimensions do not match: {"
+
to_string_range
(
s0
.
dyn_dims
())
+
"} x {"
+
to_string_range
(
s1
.
dyn_dims
())
+
"}"
);
}
auto
out_dyn_dims
=
s0
.
dyn_dims
();
out_dyn_dims
[
dim_1
]
=
s1
.
dyn_dims
()[
dim_1
];
return
{
t
,
out_dyn_dims
};
}
else
{
// only handle the case that all the dimensions except the last two are the same
if
(
not
std
::
equal
(
a
.
lens
().
rbegin
()
+
2
,
a
.
lens
().
rend
(),
b
.
lens
().
rbegin
()
+
2
,
b
.
lens
().
rend
()))
{
MIGRAPHX_THROW
(
"DOT: batch size of A and B mismatch: {"
+
to_string_range
(
a
.
lens
())
+
"} x {"
+
to_string_range
(
b
.
lens
())
+
"}"
);
MIGRAPHX_THROW
(
"DOT: static outer dimensions of A and B mismatch: {"
+
to_string_range
(
a
.
lens
())
+
"} x {"
+
to_string_range
(
b
.
lens
())
+
"}"
);
}
std
::
size_t
dim_0
=
a
.
lens
().
size
()
-
2
;
std
::
size_t
dim_1
=
a
.
lens
().
size
()
-
1
;
std
::
size_t
dim_0
=
a
.
ndim
()
-
2
;
std
::
size_t
dim_1
=
a
.
ndim
()
-
1
;
if
(
a
.
lens
()[
dim_1
]
!=
b
.
lens
()[
dim_0
])
{
MIGRAPHX_THROW
(
"DOT: inner dimensions do not match: {"
+
to_string_range
(
a
.
lens
())
+
"} x {"
+
to_string_range
(
b
.
lens
())
+
"}"
);
MIGRAPHX_THROW
(
"DOT: static inner dimensions do not match: {"
+
to_string_range
(
a
.
lens
())
+
"} x {"
+
to_string_range
(
b
.
lens
())
+
"}"
);
}
auto
out_lens
=
a
.
lens
();
out_lens
[
dim_1
]
=
b
.
lens
()[
dim_1
];
return
{
t
,
out_lens
};
}
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
=
argument
{
out
put_shape
};
argument
result
=
argument
{
dyn_out
.
com
put
ed
_shape
};
visit_all
(
result
,
args
[
0
],
args
[
1
])(
[
&
](
auto
cmat
,
auto
amat
,
auto
bmat
)
{
gemm
(
cmat
,
amat
,
bmat
,
1.0
f
,
0.0
f
);
});
return
result
;
...
...
src/include/migraphx/op/flatten.hpp
View file @
25e8cf0b
...
...
@@ -55,17 +55,47 @@ struct flatten
std
::
string
name
()
const
{
return
"flatten"
;
}
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
).
standard
();
auto
&&
lens
=
inputs
.
front
().
lens
();
auto
x
=
std
::
accumulate
(
lens
.
begin
(),
lens
.
begin
()
+
axis
,
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
auto
y
=
std
::
accumulate
(
lens
.
begin
()
+
axis
,
lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
return
{
inputs
.
at
(
0
).
type
(),
{
x
,
y
}};
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
);
auto
s
=
inputs
[
0
];
if
(
s
.
dynamic
())
{
auto
min_lens
=
s
.
min_lens
();
auto
max_lens
=
s
.
max_lens
();
auto
opt_lens
=
s
.
opt_lens
();
// If any of the opt values is 0, output opt will be 0
shape
::
dynamic_dimension
x
=
{
std
::
accumulate
(
min_lens
.
begin
(),
min_lens
.
begin
()
+
axis
,
std
::
size_t
{
1
},
std
::
multiplies
<>
{}),
std
::
accumulate
(
max_lens
.
begin
(),
max_lens
.
begin
()
+
axis
,
std
::
size_t
{
1
},
std
::
multiplies
<>
{}),
std
::
accumulate
(
opt_lens
.
begin
(),
opt_lens
.
begin
()
+
axis
,
std
::
size_t
{
1
},
std
::
multiplies
<>
{})};
shape
::
dynamic_dimension
y
=
{
std
::
accumulate
(
min_lens
.
begin
()
+
axis
,
min_lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{}),
std
::
accumulate
(
max_lens
.
begin
()
+
axis
,
max_lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{}),
std
::
accumulate
(
opt_lens
.
begin
()
+
axis
,
opt_lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{}),
};
return
{
s
.
type
(),
{
x
,
y
}};
}
else
{
check_shapes
{
inputs
,
*
this
}.
standard
();
auto
&&
lens
=
s
.
lens
();
auto
x
=
std
::
accumulate
(
lens
.
begin
(),
lens
.
begin
()
+
axis
,
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
auto
y
=
std
::
accumulate
(
lens
.
begin
()
+
axis
,
lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
return
{
s
.
type
(),
{
x
,
y
}};
}
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
return
args
[
0
].
reshape
(
out
put_shape
);
return
args
[
0
].
reshape
(
dyn_out
.
com
put
ed
_shape
);
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
...
...
src/include/migraphx/op/pad.hpp
View file @
25e8cf0b
...
...
@@ -59,19 +59,30 @@ struct pad
std
::
string
name
()
const
{
return
"pad"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
auto
&&
idims
=
inputs
.
front
().
lens
();
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
);
const
auto
&
s0
=
inputs
.
front
();
if
(
s0
.
dynamic
())
{
auto
out_dyn_dims
=
s0
.
dyn_dims
();
for
(
std
::
size_t
i
=
0
;
i
<
s0
.
ndim
();
++
i
)
{
out_dyn_dims
[
i
]
+=
pads
[
i
]
+
pads
[
i
+
s0
.
ndim
()];
}
return
{
s0
.
type
(),
out_dyn_dims
};
}
else
{
auto
&&
idims
=
s0
.
lens
();
std
::
vector
<
std
::
size_t
>
rdims
(
idims
.
begin
(),
idims
.
end
());
std
::
size_t
num_dims
=
rdims
.
size
();
for
(
std
::
size_t
i
=
0
;
i
<
num_dims
;
i
++
)
{
rdims
[
i
]
+=
pads
[
i
]
+
pads
[
i
+
num_dims
];
}
shape
s
{
inputs
.
front
().
type
(),
rdims
};
shape
s
{
s0
.
type
(),
rdims
};
return
s
;
}
}
std
::
size_t
pad_ndims
()
const
{
...
...
src/include/migraphx/op/pooling.hpp
View file @
25e8cf0b
...
...
@@ -31,7 +31,7 @@
#include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/
int_divide
.hpp>
#include <migraphx/
dyn_output
.hpp>
#include <cmath>
#include <utility>
...
...
@@ -49,6 +49,9 @@ struct pooling
bool
ceil_mode
=
false
;
int
lp_order
=
2
;
// Global pooling with dynamic shape input
bool
dyn_global
=
false
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
...
...
@@ -57,7 +60,8 @@ struct pooling
f
(
self
.
stride
,
"stride"
),
f
(
self
.
lengths
,
"lengths"
),
f
(
self
.
ceil_mode
,
"ceil_mode"
),
f
(
self
.
lp_order
,
"lp_order"
));
f
(
self
.
lp_order
,
"lp_order"
),
f
(
self
.
dyn_global
,
"dyn_global"
));
}
std
::
string
name
()
const
{
return
"pooling"
;
}
...
...
@@ -65,51 +69,111 @@ struct pooling
void
check_attribute_size
()
const
{
if
((
padding
.
size
()
!=
stride
.
size
()
and
(
padding
.
size
()
/
2
)
!=
stride
.
size
())
or
stride
.
size
()
!=
lengths
.
size
())
(
not
dyn_global
and
stride
.
size
()
!=
lengths
.
size
())
)
{
MIGRAPHX_THROW
(
"POOLING: inconsistent attribute sizes"
);
}
}
size_t
kdims
()
const
{
check_attribute_size
();
return
stride
.
size
();
}
value
attributes
()
const
{
return
{{
"normalize_padding"
,
"padding"
}};
}
std
::
vector
<
std
::
size_t
>
calc_spatial_dim_out
(
const
std
::
vector
<
std
::
size_t
>&
input_lens
,
std
::
size_t
kdims
)
const
{
std
::
vector
<
std
::
size_t
>
output_lens
{};
for
(
size_t
i
=
0
;
i
<
kdims
;
++
i
)
{
if
(
input_lens
[
i
+
2
]
==
0
)
{
// handle opt = 0
output_lens
.
push_back
(
0
);
}
else
{
std
::
size_t
padding_factor
=
2
*
padding
[
i
];
if
(
padding
.
size
()
==
2
*
kdims
)
padding_factor
=
padding
[
i
]
+
padding
[
i
+
kdims
];
assert
(
input_lens
[
i
+
2
]
+
padding_factor
>=
lengths
[
i
]);
std
::
size_t
dim_size
=
input_lens
[
i
+
2
]
+
padding_factor
-
lengths
[
i
];
std
::
size_t
len
=
(
ceil_mode
)
?
dim_size
/
stride
[
i
]
+
static_cast
<
std
::
size_t
>
((
dim_size
%
stride
[
i
]
!=
0
))
// ceil uint divide
:
dim_size
/
stride
[
i
];
// floor divide
output_lens
.
push_back
(
len
+
1
);
}
}
return
output_lens
;
}
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
);
check_attribute_size
();
const
shape
&
input
=
inputs
.
at
(
0
);
auto
input_lens
=
input
.
lens
();
size_t
kdims
=
input_lens
.
size
()
-
2
;
auto
input_size
=
inputs
[
0
].
lens
().
size
();
auto
padding_size
=
padding
.
size
();
if
(
input_size
!=
padding_size
/
2
+
2
and
input_size
!=
padding_size
+
2
)
size_t
kdims
=
input
.
ndim
()
-
2
;
if
(
input
.
ndim
()
!=
padding_size
/
2
+
2
and
input
.
ndim
()
!=
padding_size
+
2
)
{
MIGRAPHX_THROW
(
"POOLING: input and attribute size mismatch!"
);
}
std
::
vector
<
std
::
size_t
>
output_lens
(
input_lens
.
begin
(),
input_lens
.
begin
()
+
2
);
for
(
size_t
i
=
0
;
i
<
kdims
;
i
++
)
if
(
input
.
dynamic
())
{
std
::
ptrdiff_t
dim_size
;
auto
padding_factor
=
2
*
padding
[
i
];
if
(
padding_size
==
2
*
kdims
)
padding_factor
=
padding
[
i
]
+
padding
[
i
+
kdims
];
dim_size
=
input_lens
[
i
+
2
]
+
padding_factor
-
lengths
[
i
];
assert
(
dim_size
>=
0
);
std
::
size_t
len
=
(
ceil_mode
)
?
ceil_divide
<
std
::
ptrdiff_t
>
(
dim_size
,
stride
[
i
])
:
floor_divide
<
std
::
ptrdiff_t
>
(
dim_size
,
stride
[
i
]);
output_lens
.
push_back
(
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
len
+
1
)));
auto
input_dyn_dims
=
input
.
dyn_dims
();
std
::
vector
<
shape
::
dynamic_dimension
>
output_dyn_dims
(
input_dyn_dims
.
begin
(),
input_dyn_dims
.
begin
()
+
2
);
if
(
dyn_global
)
{
for
(
size_t
i
=
0
;
i
<
kdims
;
++
i
)
{
output_dyn_dims
.
push_back
(
shape
::
dynamic_dimension
{
1
,
1
,
1
});
}
return
inputs
[
0
].
with_lens
(
output_lens
);
return
{
input
.
type
(),
output_dyn_dims
};
}
else
{
auto
min_spatial_dims
=
calc_spatial_dim_out
(
input
.
min_lens
(),
kdims
);
auto
max_spatial_dims
=
calc_spatial_dim_out
(
input
.
max_lens
(),
kdims
);
auto
opt_spatial_dims
=
calc_spatial_dim_out
(
input
.
opt_lens
(),
kdims
);
for
(
size_t
i
=
0
;
i
<
kdims
;
++
i
)
{
output_dyn_dims
.
push_back
(
shape
::
dynamic_dimension
{
min_spatial_dims
[
i
],
max_spatial_dims
[
i
],
opt_spatial_dims
[
i
]});
}
return
{
input
.
type
(),
output_dyn_dims
};
}
}
else
{
auto
input_lens
=
input
.
lens
();
size_t
kdims
()
const
std
::
vector
<
std
::
size_t
>
output_lens
(
input_lens
.
begin
(),
input_lens
.
begin
()
+
2
);
// Used for when normalize_compute_shape() is called again at model eval time
// for an originally dynamic shape. Since kernel shape is not used with dyn_global.
if
(
dyn_global
)
{
check_attribute_size
();
return
stride
.
size
();
for
(
size_t
i
=
0
;
i
<
kdims
;
++
i
)
{
output_lens
.
push_back
(
1
);
}
return
{
input
.
type
(),
output_lens
};
}
else
{
auto
output_spatial_lens
=
calc_spatial_dim_out
(
input_lens
,
kdims
);
output_lens
.
insert
(
output_lens
.
end
(),
output_spatial_lens
.
begin
(),
output_spatial_lens
.
end
());
return
inputs
[
0
].
with_lens
(
output_lens
);
}
}
}
struct
lpnorm_pool
...
...
@@ -158,7 +222,11 @@ struct pooling
};
template
<
class
Type
,
class
Out
,
class
In
,
class
Op
>
void
calc_pooling
(
const
shape
&
output_shape
,
Out
&
output
,
const
In
&
input
,
Op
op
)
const
void
calc_pooling
(
const
shape
&
output_shape
,
Out
&
output
,
const
In
&
input
,
const
std
::
vector
<
std
::
size_t
>&
kernel_dims
,
Op
op
)
const
{
auto
in_s
=
input
.
get_shape
();
auto
in_lens
=
in_s
.
lens
();
...
...
@@ -172,7 +240,7 @@ struct pooling
auto
d_2
=
dim
-
2
;
int
start
=
static_cast
<
int
>
(
idx_o
[
dim
]
*
stride
[
d_2
])
-
static_cast
<
int
>
(
padding
[
d_2
]);
int
end
=
std
::
min
(
start
+
length
s
[
d_2
],
in_lens
[
dim
]);
int
end
=
std
::
min
(
start
+
kernel_dim
s
[
d_2
],
in_lens
[
dim
]);
start
=
std
::
max
(
start
,
0
);
win_start
.
push_back
(
start
);
win_size
.
push_back
(
end
-
start
);
...
...
@@ -198,21 +266,32 @@ struct pooling
});
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
dyn_out
.
computed_shape
};
auto
input_lens
=
args
[
0
].
get_shape
().
lens
();
std
::
vector
<
std
::
size_t
>
kernel_dims
;
if
(
dyn_global
)
{
argument
result
{
output_shape
};
kernel_dims
.
insert
(
kernel_dims
.
end
(),
input_lens
.
begin
()
+
2
,
input_lens
.
end
());
}
else
{
kernel_dims
=
this
->
lengths
;
}
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
input
)
{
using
type
=
typename
decltype
(
output
)
::
value_type
;
switch
(
mode
)
{
case
migraphx
::
op
::
pooling_mode
::
average
:
calc_pooling
<
type
>
(
out
put_shape
,
output
,
input
,
avg_pool
{});
calc_pooling
<
type
>
(
dyn_out
.
com
put
ed
_shape
,
output
,
input
,
kernel_dims
,
avg_pool
{});
break
;
case
migraphx
::
op
::
pooling_mode
::
max
:
calc_pooling
<
type
>
(
out
put_shape
,
output
,
input
,
max_pool
{});
calc_pooling
<
type
>
(
dyn_out
.
com
put
ed
_shape
,
output
,
input
,
kernel_dims
,
max_pool
{});
break
;
case
migraphx
::
op
::
pooling_mode
::
lpnorm
:
calc_pooling
<
type
>
(
output_shape
,
output
,
input
,
lpnorm_pool
{
lp_order
});
calc_pooling
<
type
>
(
dyn_out
.
computed_shape
,
output
,
input
,
kernel_dims
,
lpnorm_pool
{
lp_order
});
break
;
}
});
...
...
src/include/migraphx/op/reduce_op.hpp
View file @
25e8cf0b
...
...
@@ -26,6 +26,7 @@
#include <migraphx/op/name.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/tensor_view.hpp>
#include <migraphx/shape_for_each.hpp>
...
...
@@ -105,26 +106,49 @@ struct reduce_op : op_name<Derived>
return
tuned_axes
;
}
/**
* @brief returns a shape in which the axis or axes named
* for reduction by this op are set, to size 1.
*
* @param inputs list of input shapes
* @return shape
*/
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
);
auto
s
=
inputs
.
at
(
0
);
if
(
s
.
dynamic
())
{
auto
output_dyn_dims
=
s
.
dyn_dims
();
auto
tuned_axes
=
tune_axes
(
output_dyn_dims
.
size
());
for
(
const
auto
&
axis
:
tuned_axes
)
{
// At the time of writing, there's no functional difference between
// optimum of 0 (no opt) or 1.
output_dyn_dims
[
axis
]
=
{
1
,
1
,
0
};
}
return
shape
{
s
.
type
(),
output_dyn_dims
};
}
else
{
auto
lens
=
s
.
lens
();
auto
tuned_axes
=
tune_axes
(
lens
.
size
());
for
(
auto
axis
:
tuned_axes
)
for
(
const
auto
&
axis
:
tuned_axes
)
{
lens
[
axis
]
=
1
;
}
return
inputs
[
0
].
with_lens
(
lens
);
}
}
template
<
class
T
>
void
tune_dims
(
const
std
::
vector
<
int64_t
>&
tuned_axes
,
const
std
::
vector
<
T
>&
in_lens
,
std
::
vector
<
T
>&
out_lens
)
const
{
for
(
auto
axis
:
tuned_axes
)
for
(
const
auto
&
axis
:
tuned_axes
)
{
out_lens
[
axis
]
=
in_lens
[
axis
];
}
...
...
@@ -151,17 +175,17 @@ struct reduce_op : op_name<Derived>
static_cast
<
const
Derived
&>
(
*
this
).
output
(
batch_shape
)(
val
);
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
out
put_shape
};
argument
result
{
dyn_out
.
com
put
ed
_shape
};
auto
arg_lens
=
args
.
front
().
get_shape
().
lens
();
auto
tuned_axes
=
tune_axes
(
arg_lens
.
size
());
std
::
vector
<
std
::
size_t
>
batch_lens
(
out
put_shape
.
lens
().
size
(),
1
);
std
::
vector
<
std
::
size_t
>
batch_lens
(
dyn_out
.
com
put
ed
_shape
.
lens
().
size
(),
1
);
tune_dims
(
tuned_axes
,
arg_lens
,
batch_lens
);
shape
batch_shape
{
out
put_shape
.
type
(),
batch_lens
};
shape
batch_shape
{
dyn_out
.
com
put
ed
_shape
.
type
(),
batch_lens
};
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
input
)
{
par_for
(
out
put_shape
.
elements
(),
[
&
](
auto
i
)
{
auto
out_idx
=
out
put_shape
.
multi
(
i
);
par_for
(
dyn_out
.
com
put
ed
_shape
.
elements
(),
[
&
](
auto
i
)
{
auto
out_idx
=
dyn_out
.
com
put
ed
_shape
.
multi
(
i
);
this
->
reduce
(
input
,
batch_shape
,
tuned_axes
,
out_idx
,
output
);
});
});
...
...
Prev
1
2
3
4
5
…
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment