Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
a5c87ec5
Commit
a5c87ec5
authored
Jan 17, 2023
by
Paul
Browse files
Merge branch 'develop' into jit-reduce-reg
parents
dae94657
3f49f8eb
Changes
57
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
515 additions
and
102 deletions
+515
-102
.github/workflows/add-to-project.yaml
.github/workflows/add-to-project.yaml
+19
-0
Dockerfile
Dockerfile
+1
-1
Jenkinsfile
Jenkinsfile
+18
-8
src/include/migraphx/op/pad.hpp
src/include/migraphx/op/pad.hpp
+21
-10
src/include/migraphx/op/reduce_op.hpp
src/include/migraphx/op/reduce_op.hpp
+39
-15
src/include/migraphx/op/reshape.hpp
src/include/migraphx/op/reshape.hpp
+71
-7
src/include/migraphx/shape.hpp
src/include/migraphx/shape.hpp
+7
-0
src/onnx/onnx_parser.cpp
src/onnx/onnx_parser.cpp
+13
-3
src/onnx/parse_matmul.cpp
src/onnx/parse_matmul.cpp
+58
-34
src/onnx/parse_pad.cpp
src/onnx/parse_pad.cpp
+6
-0
src/onnx/parse_reduce_op.cpp
src/onnx/parse_reduce_op.cpp
+1
-2
src/onnx/parse_reshape.cpp
src/onnx/parse_reshape.cpp
+1
-1
src/shape.cpp
src/shape.cpp
+42
-0
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+14
-2
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+11
-4
src/targets/gpu/jit/gather.cpp
src/targets/gpu/jit/gather.cpp
+89
-0
src/targets/gpu/jit/reduce.cpp
src/targets/gpu/jit/reduce.cpp
+13
-4
src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
...rgets/gpu/kernels/include/migraphx/kernels/functional.hpp
+8
-0
src/targets/gpu/kernels/include/migraphx/kernels/gather.hpp
src/targets/gpu/kernels/include/migraphx/kernels/gather.hpp
+64
-0
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
+19
-11
No files found.
.github/workflows/add-to-project.yaml
0 → 100644
View file @
a5c87ec5
name
:
Add items to GH project
on
:
pull_request
:
types
:
-
opened
issues
:
types
:
-
opened
jobs
:
add-to-project
:
name
:
Add PRs and issues to MIGX project
runs-on
:
ubuntu-latest
steps
:
-
uses
:
actions/add-to-project@v0.4.0
with
:
project-url
:
https://github.com/orgs/ROCmSoftwarePlatform/projects/20
github-token
:
${{ secrets.TEST_PR_WORKFLOW }}
Dockerfile
View file @
a5c87ec5
...
@@ -87,7 +87,7 @@ RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXR
...
@@ -87,7 +87,7 @@ RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXR
ADD
tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh
ADD
tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh
RUN
cget
-p
/usr/local
install
ROCmSoftwarePlatform/rocMLIR@
0f38fb33f518b53b94b541feb9b079668c5518e8
-DBUILD_MIXR_TARGET
=
On
-DLLVM_ENABLE_ZSTD
=
Off
-DLLVM_ENABLE_THREADS
=
Off
RUN
cget
-p
/usr/local
install
ROCmSoftwarePlatform/rocMLIR@
78b706fe9879587ab98b6614ae539265374a3fae
-DBUILD_MIXR_TARGET
=
On
-DLLVM_ENABLE_ZSTD
=
Off
-DLLVM_ENABLE_THREADS
=
Off
ENV
MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db
ENV
MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db
ENV
MIOPEN_USER_DB_PATH=/tmp/miopen/user-db
ENV
MIOPEN_USER_DB_PATH=/tmp/miopen/user-db
...
...
Jenkinsfile
View file @
a5c87ec5
...
@@ -11,16 +11,22 @@ def rocmtestnode(Map conf) {
...
@@ -11,16 +11,22 @@ def rocmtestnode(Map conf) {
def
image
=
'migraphxlib'
def
image
=
'migraphxlib'
env
.
CCACHE_COMPRESSLEVEL
=
7
env
.
CCACHE_COMPRESSLEVEL
=
7
env
.
CCACHE_DIR
=
ccache
env
.
CCACHE_DIR
=
ccache
def
cmake_build
=
{
compiler
,
flags
->
def
cmake_build
=
{
bconf
->
def
compiler
=
bconf
.
get
(
"compiler"
,
"/opt/rocm/llvm/bin/clang++"
)
def
flags
=
bconf
.
get
(
"flags"
,
""
)
def
gpu_debug
=
bconf
.
get
(
"gpu_debug"
,
"0"
)
def
cmd
=
"""
def
cmd
=
"""
env
ulimit -c unlimited
ulimit -c unlimited
echo "leak:dnnl::impl::malloc" > suppressions.txt
echo "leak:dnnl::impl::malloc" > suppressions.txt
export LSAN_OPTIONS="suppressions=\$(pwd)/suppressions.txt"
export LSAN_OPTIONS="suppressions=\$(pwd)/suppressions.txt"
export MIGRAPHX_GPU_DEBUG=${gpu_debug}
export CXX=${compiler}
export CXXFLAGS='-Werror'
env
rm -rf build
rm -rf build
mkdir build
mkdir build
cd build
cd build
CXX=${compiler} CXXFLAGS='-Werror'
cmake -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache ${flags} ..
cmake -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache ${flags} ..
make -j\$(nproc) generate all doc package check VERBOSE=1
make -j\$(nproc) generate all doc package check VERBOSE=1
"""
"""
echo
cmd
echo
cmd
...
@@ -93,28 +99,32 @@ rocmtest clang_debug: rocmnode('vega') { cmake_build ->
...
@@ -93,28 +99,32 @@ rocmtest clang_debug: rocmnode('vega') { cmake_build ->
stage
(
'Hip Clang Debug'
)
{
stage
(
'Hip Clang Debug'
)
{
def
sanitizers
=
"undefined"
def
sanitizers
=
"undefined"
def
debug_flags
=
"-g -O2 -fsanitize=${sanitizers} -fno-sanitize-recover=${sanitizers}"
def
debug_flags
=
"-g -O2 -fsanitize=${sanitizers} -fno-sanitize-recover=${sanitizers}"
cmake_build
(
"/opt/rocm/llvm/bin/clang++"
,
"-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}' -DCMAKE_C_FLAGS_DEBUG='${debug_flags}'"
)
cmake_build
(
flags:
"-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}' -DCMAKE_C_FLAGS_DEBUG='${debug_flags}'"
)
}
},
clang_gpu_debug:
rocmnode
(
'vega'
)
{
cmake_build
->
stage
(
'Hip Clang GPU Debug'
)
{
cmake_build
(
flags:
"-DCMAKE_BUILD_TYPE=release"
,
gpu_debug:
true
)
}
}
},
clang_release:
rocmnode
(
'vega'
)
{
cmake_build
->
},
clang_release:
rocmnode
(
'vega'
)
{
cmake_build
->
stage
(
'Hip Clang Release'
)
{
stage
(
'Hip Clang Release'
)
{
cmake_build
(
"/opt/rocm/llvm/bin/clang++"
,
"-DCMAKE_BUILD_TYPE=release"
)
cmake_build
(
flags:
"-DCMAKE_BUILD_TYPE=release"
)
stash
includes:
'build/*.deb'
,
name:
'migraphx-package'
stash
includes:
'build/*.deb'
,
name:
'migraphx-package'
}
}
},
mlir_debug:
rocmnode
(
'vega'
)
{
cmake_build
->
},
mlir_debug:
rocmnode
(
'vega'
)
{
cmake_build
->
stage
(
'MLIR Debug'
)
{
stage
(
'MLIR Debug'
)
{
def
sanitizers
=
"undefined"
def
sanitizers
=
"undefined"
def
debug_flags
=
"-g -O2 -fsanitize=${sanitizers} -fno-sanitize-recover=${sanitizers}"
def
debug_flags
=
"-g -O2 -fsanitize=${sanitizers} -fno-sanitize-recover=${sanitizers}"
cmake_build
(
"/opt/rocm/llvm/bin/clang++"
,
"-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off -DMIGRAPHX_ENABLE_MLIR=On -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}' -DCMAKE_C_FLAGS_DEBUG='${debug_flags}'"
)
cmake_build
(
flags:
"-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off -DMIGRAPHX_ENABLE_MLIR=On -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}' -DCMAKE_C_FLAGS_DEBUG='${debug_flags}'"
)
}
}
},
clang_asan:
rocmnode
(
'nogpu'
)
{
cmake_build
->
},
clang_asan:
rocmnode
(
'nogpu'
)
{
cmake_build
->
stage
(
'Clang ASAN'
)
{
stage
(
'Clang ASAN'
)
{
def
sanitizers
=
"undefined,address"
def
sanitizers
=
"undefined,address"
def
debug_flags
=
"-g -O2 -fno-omit-frame-pointer -fsanitize=${sanitizers} -fno-sanitize-recover=${sanitizers}"
def
debug_flags
=
"-g -O2 -fno-omit-frame-pointer -fsanitize=${sanitizers} -fno-sanitize-recover=${sanitizers}"
cmake_build
(
"/opt/rocm/llvm/bin/clang++"
,
"-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off -DMIGRAPHX_ENABLE_GPU=Off -DMIGRAPHX_ENABLE_CPU=On -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}' -DCMAKE_C_FLAGS_DEBUG='${debug_flags}'"
)
cmake_build
(
flags:
"-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off -DMIGRAPHX_ENABLE_GPU=Off -DMIGRAPHX_ENABLE_CPU=On -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}' -DCMAKE_C_FLAGS_DEBUG='${debug_flags}'"
)
}
}
}
//, clang_release_navi: rocmnode('navi21') { cmake_build ->
}
//, clang_release_navi: rocmnode('navi21') { cmake_build ->
// stage('HIP Clang Release Navi') {
// stage('HIP Clang Release Navi') {
// cmake_build(
"/opt/rocm/llvm/bin/clang++",
"-DCMAKE_BUILD_TYPE=release")
// cmake_build(
flags:
"-DCMAKE_BUILD_TYPE=release")
// }
// }
//}
//}
...
...
src/include/migraphx/op/pad.hpp
View file @
a5c87ec5
...
@@ -59,18 +59,29 @@ struct pad
...
@@ -59,18 +59,29 @@ struct pad
std
::
string
name
()
const
{
return
"pad"
;
}
std
::
string
name
()
const
{
return
"pad"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
);
auto
&&
idims
=
inputs
.
front
().
lens
();
const
auto
&
s0
=
inputs
.
front
();
std
::
vector
<
std
::
size_t
>
rdims
(
idims
.
begin
(),
idims
.
end
());
if
(
s0
.
dynamic
())
std
::
size_t
num_dims
=
rdims
.
size
();
for
(
std
::
size_t
i
=
0
;
i
<
num_dims
;
i
++
)
{
{
rdims
[
i
]
+=
pads
[
i
]
+
pads
[
i
+
num_dims
];
auto
out_dyn_dims
=
s0
.
dyn_dims
();
for
(
std
::
size_t
i
=
0
;
i
<
s0
.
ndim
();
++
i
)
{
out_dyn_dims
[
i
]
+=
pads
[
i
]
+
pads
[
i
+
s0
.
ndim
()];
}
return
{
s0
.
type
(),
out_dyn_dims
};
}
else
{
auto
&&
idims
=
s0
.
lens
();
std
::
vector
<
std
::
size_t
>
rdims
(
idims
.
begin
(),
idims
.
end
());
std
::
size_t
num_dims
=
rdims
.
size
();
for
(
std
::
size_t
i
=
0
;
i
<
num_dims
;
i
++
)
{
rdims
[
i
]
+=
pads
[
i
]
+
pads
[
i
+
num_dims
];
}
shape
s
{
s0
.
type
(),
rdims
};
return
s
;
}
}
shape
s
{
inputs
.
front
().
type
(),
rdims
};
return
s
;
}
}
std
::
size_t
pad_ndims
()
const
std
::
size_t
pad_ndims
()
const
...
...
src/include/migraphx/op/reduce_op.hpp
View file @
a5c87ec5
...
@@ -26,6 +26,7 @@
...
@@ -26,6 +26,7 @@
#include <migraphx/op/name.hpp>
#include <migraphx/op/name.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/tensor_view.hpp>
#include <migraphx/tensor_view.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/shape_for_each.hpp>
...
@@ -105,18 +106,41 @@ struct reduce_op : op_name<Derived>
...
@@ -105,18 +106,41 @@ struct reduce_op : op_name<Derived>
return
tuned_axes
;
return
tuned_axes
;
}
}
/**
* @brief returns a shape in which the axis or axes named
* for reduction by this op are set, to size 1.
*
* @param inputs list of input shapes
* @return shape
*/
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
);
auto
s
=
inputs
.
at
(
0
);
auto
s
=
inputs
.
at
(
0
);
auto
lens
=
s
.
lens
();
if
(
s
.
dynamic
())
auto
tuned_axes
=
tune_axes
(
lens
.
size
());
for
(
auto
axis
:
tuned_axes
)
{
{
lens
[
axis
]
=
1
;
auto
output_dyn_dims
=
s
.
dyn_dims
();
auto
tuned_axes
=
tune_axes
(
output_dyn_dims
.
size
());
for
(
const
auto
&
axis
:
tuned_axes
)
{
// At the time of writing, there's no functional difference between
// optimum of 0 (no opt) or 1.
output_dyn_dims
[
axis
]
=
{
1
,
1
,
0
};
}
return
shape
{
s
.
type
(),
output_dyn_dims
};
}
else
{
auto
lens
=
s
.
lens
();
auto
tuned_axes
=
tune_axes
(
lens
.
size
());
for
(
const
auto
&
axis
:
tuned_axes
)
{
lens
[
axis
]
=
1
;
}
return
inputs
[
0
].
with_lens
(
lens
);
}
}
return
inputs
[
0
].
with_lens
(
lens
);
}
}
template
<
class
T
>
template
<
class
T
>
...
@@ -124,7 +148,7 @@ struct reduce_op : op_name<Derived>
...
@@ -124,7 +148,7 @@ struct reduce_op : op_name<Derived>
const
std
::
vector
<
T
>&
in_lens
,
const
std
::
vector
<
T
>&
in_lens
,
std
::
vector
<
T
>&
out_lens
)
const
std
::
vector
<
T
>&
out_lens
)
const
{
{
for
(
auto
axis
:
tuned_axes
)
for
(
const
auto
&
axis
:
tuned_axes
)
{
{
out_lens
[
axis
]
=
in_lens
[
axis
];
out_lens
[
axis
]
=
in_lens
[
axis
];
}
}
...
@@ -151,17 +175,17 @@ struct reduce_op : op_name<Derived>
...
@@ -151,17 +175,17 @@ struct reduce_op : op_name<Derived>
static_cast
<
const
Derived
&>
(
*
this
).
output
(
batch_shape
)(
val
);
static_cast
<
const
Derived
&>
(
*
this
).
output
(
batch_shape
)(
val
);
}
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
{
argument
result
{
out
put_shape
};
argument
result
{
dyn_out
.
com
put
ed
_shape
};
auto
arg_lens
=
args
.
front
().
get_shape
().
lens
();
auto
arg_lens
=
args
.
front
().
get_shape
().
lens
();
auto
tuned_axes
=
tune_axes
(
arg_lens
.
size
());
auto
tuned_axes
=
tune_axes
(
arg_lens
.
size
());
std
::
vector
<
std
::
size_t
>
batch_lens
(
out
put_shape
.
lens
().
size
(),
1
);
std
::
vector
<
std
::
size_t
>
batch_lens
(
dyn_out
.
com
put
ed
_shape
.
lens
().
size
(),
1
);
tune_dims
(
tuned_axes
,
arg_lens
,
batch_lens
);
tune_dims
(
tuned_axes
,
arg_lens
,
batch_lens
);
shape
batch_shape
{
out
put_shape
.
type
(),
batch_lens
};
shape
batch_shape
{
dyn_out
.
com
put
ed
_shape
.
type
(),
batch_lens
};
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
input
)
{
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
input
)
{
par_for
(
out
put_shape
.
elements
(),
[
&
](
auto
i
)
{
par_for
(
dyn_out
.
com
put
ed
_shape
.
elements
(),
[
&
](
auto
i
)
{
auto
out_idx
=
out
put_shape
.
multi
(
i
);
auto
out_idx
=
dyn_out
.
com
put
ed
_shape
.
multi
(
i
);
this
->
reduce
(
input
,
batch_shape
,
tuned_axes
,
out_idx
,
output
);
this
->
reduce
(
input
,
batch_shape
,
tuned_axes
,
out_idx
,
output
);
});
});
});
});
...
...
src/include/migraphx/op/reshape.hpp
View file @
a5c87ec5
...
@@ -28,6 +28,7 @@
...
@@ -28,6 +28,7 @@
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/value.hpp>
#include <migraphx/dyn_output.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -46,14 +47,60 @@ struct reshape
...
@@ -46,14 +47,60 @@ struct reshape
value
attributes
()
const
{
return
{{
"require_std_shape"
,
true
}};
}
value
attributes
()
const
{
return
{{
"require_std_shape"
,
true
}};
}
std
::
string
name
()
const
{
return
"reshape"
;
}
std
::
string
name
()
const
{
return
"reshape"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
dyn_compute_shape
(
shape
s0
)
const
{
auto
dyn_dims
=
s0
.
dyn_dims
();
auto
num_not_fixed
=
std
::
count_if
(
dyn_dims
.
cbegin
(),
dyn_dims
.
cend
(),
[](
auto
dd
)
{
return
not
dd
.
is_fixed
();
});
if
(
num_not_fixed
!=
1
)
{
MIGRAPHX_THROW
(
"Reshape: Only supports one non-fixed dynamic_dimension"
);
}
// track number of fixed elements in input and output
std
::
size_t
num_dims_ele
=
1
;
std
::
size_t
num_dd_ele
=
1
;
for
(
std
::
size_t
i
=
0
;
i
<
dyn_dims
.
size
();
++
i
)
{
if
(
dyn_dims
[
i
].
is_fixed
())
{
num_dims_ele
*=
dims
[
i
];
num_dd_ele
*=
dyn_dims
[
i
].
min
;
}
else
{
if
(
dims
[
i
]
!=
0
and
dims
[
i
]
!=
-
1
)
{
MIGRAPHX_THROW
(
"Reshape: Non-fixed dynamic_dimension doesn't match with 0 or -1 "
"output dimension"
);
}
}
}
if
(
num_dims_ele
!=
num_dd_ele
)
{
MIGRAPHX_THROW
(
"Reshape: Number of fixed elements must match. Input: "
+
std
::
to_string
(
num_dd_ele
)
+
" Output: "
+
std
::
to_string
(
num_dims_ele
));
}
// construct output dynamic shape from dims attribute
std
::
vector
<
shape
::
dynamic_dimension
>
output_dyn_dims
(
dims
.
size
());
std
::
transform
(
dims
.
cbegin
(),
dims
.
cend
(),
dyn_dims
.
cbegin
(),
output_dyn_dims
.
begin
(),
[](
std
::
size_t
dim
,
auto
dyn_dim
)
{
if
(
not
dyn_dim
.
is_fixed
())
return
dyn_dim
;
return
shape
::
dynamic_dimension
{
dim
,
dim
};
});
return
{
s0
.
type
(),
output_dyn_dims
};
}
shape
static_compute_shape
(
std
::
vector
<
shape
>
inputs
,
std
::
size_t
n_neg_dims
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
).
standard
();
check_shapes
{
inputs
,
*
this
}.
standard
();
auto
&&
idims
=
inputs
.
front
().
lens
();
auto
&&
idims
=
inputs
.
front
().
lens
();
std
::
vector
<
std
::
size_t
>
rdims
(
dims
.
begin
(),
dims
.
end
());
std
::
vector
<
std
::
size_t
>
rdims
(
dims
.
begin
(),
dims
.
end
());
auto
n_neg_dims
=
std
::
count
(
dims
.
begin
(),
dims
.
end
(),
-
1
);
if
(
n_neg_dims
>
1
)
MIGRAPHX_THROW
(
"Reshape: Dimensions for reshape can only have one -1 dim"
);
for
(
std
::
size_t
i
=
0
;
i
<
dims
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
dims
.
size
();
i
++
)
{
{
...
@@ -86,9 +133,26 @@ struct reshape
...
@@ -86,9 +133,26 @@ struct reshape
return
s
;
return
s
;
}
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
);
auto
n_neg_dims
=
std
::
count
(
dims
.
begin
(),
dims
.
end
(),
-
1
);
if
(
n_neg_dims
>
1
)
MIGRAPHX_THROW
(
"Reshape: Dimensions for reshape can only have one -1 dim"
);
auto
s0
=
inputs
[
0
];
if
(
s0
.
dynamic
())
{
return
dyn_compute_shape
(
s0
);
}
else
{
return
static_compute_shape
(
inputs
,
n_neg_dims
);
}
}
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
{
return
args
[
0
].
reshape
(
out
put_shape
);
return
args
[
0
].
reshape
(
dyn_out
.
com
put
ed
_shape
);
}
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
...
...
src/include/migraphx/shape.hpp
View file @
a5c87ec5
...
@@ -107,6 +107,13 @@ struct shape
...
@@ -107,6 +107,13 @@ struct shape
friend
bool
operator
==
(
const
std
::
size_t
&
x
,
const
dynamic_dimension
&
y
);
friend
bool
operator
==
(
const
std
::
size_t
&
x
,
const
dynamic_dimension
&
y
);
friend
bool
operator
!=
(
const
dynamic_dimension
&
x
,
const
std
::
size_t
&
y
);
friend
bool
operator
!=
(
const
dynamic_dimension
&
x
,
const
std
::
size_t
&
y
);
friend
bool
operator
!=
(
const
std
::
size_t
&
x
,
const
dynamic_dimension
&
y
);
friend
bool
operator
!=
(
const
std
::
size_t
&
x
,
const
dynamic_dimension
&
y
);
// add and subtract fixed std::size_t dimension
dynamic_dimension
&
operator
+=
(
const
std
::
size_t
&
x
);
dynamic_dimension
&
operator
-=
(
const
std
::
size_t
&
x
);
friend
dynamic_dimension
operator
+
(
const
dynamic_dimension
&
x
,
const
std
::
size_t
&
y
);
friend
dynamic_dimension
operator
+
(
const
std
::
size_t
&
x
,
const
dynamic_dimension
&
y
);
friend
dynamic_dimension
operator
-
(
const
dynamic_dimension
&
x
,
const
std
::
size_t
&
y
);
};
};
static
const
std
::
vector
<
type_t
>&
types
();
static
const
std
::
vector
<
type_t
>&
types
();
...
...
src/onnx/onnx_parser.cpp
View file @
a5c87ec5
...
@@ -110,9 +110,19 @@ instruction_ref onnx_parser::node_info::add_bias(const std::vector<instruction_r
...
@@ -110,9 +110,19 @@ instruction_ref onnx_parser::node_info::add_bias(const std::vector<instruction_r
{
{
if
(
args
.
size
()
==
3
)
if
(
args
.
size
()
==
3
)
{
{
auto
bias_bcast
=
mod
->
add_instruction
(
instruction_ref
bias_bcast
;
make_op
(
"broadcast"
,
{{
"axis"
,
axis
},
{
"out_lens"
,
curr_ins
->
get_shape
().
lens
()}}),
// if curr_ins has a dynamic output shape use 2 input broadcast
args
[
2
]);
if
(
curr_ins
->
get_shape
().
dynamic
())
{
bias_bcast
=
mod
->
add_instruction
(
make_op
(
"broadcast"
,
{{
"axis"
,
axis
}}),
args
[
2
],
curr_ins
);
}
else
{
bias_bcast
=
mod
->
add_instruction
(
make_op
(
"broadcast"
,
{{
"axis"
,
axis
},
{
"out_lens"
,
curr_ins
->
get_shape
().
lens
()}}),
args
[
2
]);
}
return
mod
->
add_instruction
(
make_op
(
"add"
),
curr_ins
,
bias_bcast
);
return
mod
->
add_instruction
(
make_op
(
"add"
),
curr_ins
,
bias_bcast
);
}
}
return
curr_ins
;
return
curr_ins
;
...
...
src/onnx/parse_matmul.cpp
View file @
a5c87ec5
...
@@ -43,55 +43,79 @@ struct parse_matmul : op_parser<parse_matmul>
...
@@ -43,55 +43,79 @@ struct parse_matmul : op_parser<parse_matmul>
const
onnx_parser
::
node_info
&
info
,
const
onnx_parser
::
node_info
&
info
,
std
::
vector
<
instruction_ref
>
args
)
const
std
::
vector
<
instruction_ref
>
args
)
const
{
{
auto
l0
=
args
[
0
];
auto
a0
=
args
[
0
];
auto
l1
=
args
[
1
];
auto
a1
=
args
[
1
];
auto
l0_len
s
=
l
0
->
get_shape
()
.
lens
()
;
auto
s
0
=
a
0
->
get_shape
();
auto
l1_len
s
=
l
1
->
get_shape
()
.
lens
()
;
auto
s
1
=
a
1
->
get_shape
();
// args[0] is a vector, prepend 1 to the shape
instruction_ref
dot_res
;
bool
is_a_prepended
=
false
;
bool
is_a_prepended
=
false
;
if
(
l0_lens
.
size
()
==
1
)
bool
is_b_appended
=
false
;
if
(
s0
.
ndim
()
==
1
)
{
{
is_a_prepended
=
true
;
is_a_prepended
=
true
;
l0_lens
.
insert
(
l0_lens
.
begin
(),
1
);
a0
=
info
.
add_instruction
(
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
0
}}}),
args
[
0
]);
l0
=
info
.
add_instruction
(
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
0
}}}),
args
[
0
]);
}
}
if
(
s1
.
ndim
()
==
1
)
bool
is_b_appended
=
false
;
if
(
l1_lens
.
size
()
==
1
)
{
{
is_b_appended
=
true
;
is_b_appended
=
true
;
l1_lens
.
push_back
(
1
);
a1
=
info
.
add_instruction
(
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
1
}}}),
args
[
1
]);
l1
=
info
.
add_instruction
(
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
1
}}}),
args
[
1
]);
}
}
instruction_ref
bl0
=
l0
;
if
(
s0
.
dynamic
()
or
s1
.
dynamic
())
instruction_ref
bl1
=
l1
;
if
(
not
std
::
equal
(
l0_lens
.
rbegin
()
+
2
,
l0_lens
.
rend
(),
l1_lens
.
rbegin
()
+
2
,
l1_lens
.
rend
()))
{
{
auto
l0_it
=
l0_lens
.
begin
()
+
l0_lens
.
size
()
-
2
;
if
(
opd
.
op_name
==
"quant_dot"
)
std
::
vector
<
std
::
size_t
>
l0_broadcasted_lens
(
l0_lens
.
begin
(),
l0_it
);
{
auto
l1_it
=
l1_lens
.
begin
()
+
l1_lens
.
size
()
-
2
;
MIGRAPHX_THROW
(
"PARSE_MATMUL: dynamic MatMulInteger not supported"
)
;
std
::
vector
<
std
::
size_t
>
l1_broadcasted_lens
(
l1_lens
.
begin
(),
l1_it
);
}
auto
output_lens
=
compute_broadcasted_lens
(
l0_broadcasted_lens
,
l1_broadcasted_lens
);
auto
s0_dds
=
a0
->
get_shape
().
to_dynamic
().
dyn_dims
(
);
l0_broadcasted_lens
=
output_lens
;
auto
s1_dds
=
a1
->
get_shape
().
to_dynamic
().
dyn_dims
()
;
l0_broadcasted_lens
.
insert
(
l0_broadcasted_lens
.
end
(),
l0_it
,
l0_lens
.
end
());
l1_broadcasted_lens
=
output_lens
;
// TODO: handling this case requires a new multibroadcast mode
l1_broadcasted_lens
.
insert
(
l1_broadcasted_lens
.
end
(),
l1_it
,
l1_lens
.
end
());
if
(
not
std
::
equal
(
if
(
l0_lens
!=
l0_broadcasted_lens
)
s0_dds
.
rbegin
()
+
2
,
s0_dds
.
rend
(),
s1_dds
.
rbegin
()
+
2
,
s1_dds
.
rend
())
)
{
{
bl0
=
info
.
add_instruction
(
MIGRAPHX_THROW
(
"PARSE_MATMUL: dynamic shape broadcasting not supported"
);
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
l0_broadcasted_lens
}}),
l0
);
}
}
if
(
l1_lens
!=
l1_broadcasted_lens
)
dot_res
=
info
.
add_instruction
(
make_op
(
opd
.
op_name
),
a0
,
a1
);
}
else
{
auto
s0_lens
=
a0
->
get_shape
().
lens
();
auto
s1_lens
=
a1
->
get_shape
().
lens
();
instruction_ref
ba0
=
a0
;
instruction_ref
ba1
=
a1
;
// try broadcasting if dimensions other than last two do not match
if
(
not
std
::
equal
(
s0_lens
.
rbegin
()
+
2
,
s0_lens
.
rend
(),
s1_lens
.
rbegin
()
+
2
,
s1_lens
.
rend
()))
{
{
bl1
=
info
.
add_instruction
(
auto
l0_it
=
s0_lens
.
begin
()
+
s0_lens
.
size
()
-
2
;
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
l1_broadcasted_lens
}}),
l1
);
std
::
vector
<
std
::
size_t
>
l0_broadcasted_lens
(
s0_lens
.
begin
(),
l0_it
);
auto
l1_it
=
s1_lens
.
begin
()
+
s1_lens
.
size
()
-
2
;
std
::
vector
<
std
::
size_t
>
l1_broadcasted_lens
(
s1_lens
.
begin
(),
l1_it
);
auto
output_lens
=
compute_broadcasted_lens
(
l0_broadcasted_lens
,
l1_broadcasted_lens
);
l0_broadcasted_lens
=
output_lens
;
l0_broadcasted_lens
.
insert
(
l0_broadcasted_lens
.
end
(),
l0_it
,
s0_lens
.
end
());
l1_broadcasted_lens
=
output_lens
;
l1_broadcasted_lens
.
insert
(
l1_broadcasted_lens
.
end
(),
l1_it
,
s1_lens
.
end
());
if
(
s0_lens
!=
l0_broadcasted_lens
)
{
ba0
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
l0_broadcasted_lens
}}),
a0
);
}
if
(
s1_lens
!=
l1_broadcasted_lens
)
{
ba1
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
l1_broadcasted_lens
}}),
a1
);
}
}
}
dot_res
=
info
.
add_instruction
(
make_op
(
opd
.
op_name
),
ba0
,
ba1
);
}
}
instruction_ref
dot_res
=
info
.
add_instruction
(
make_op
(
opd
.
op_name
),
bl0
,
bl1
);
int64_t
num_axis
=
static_cast
<
int64_t
>
(
dot_res
->
get_shape
().
lens
().
size
());
// squeeze the appended or prepended dimensions
int64_t
num_axis
=
dot_res
->
get_shape
().
ndim
();
if
(
is_a_prepended
)
if
(
is_a_prepended
)
{
{
dot_res
=
info
.
add_instruction
(
make_op
(
"squeeze"
,
{{
"axes"
,
{
num_axis
-
2
}}}),
dot_res
);
dot_res
=
info
.
add_instruction
(
make_op
(
"squeeze"
,
{{
"axes"
,
{
num_axis
-
2
}}}),
dot_res
);
...
...
src/onnx/parse_pad.cpp
View file @
a5c87ec5
...
@@ -147,7 +147,13 @@ struct parse_pad : op_parser<parse_pad>
...
@@ -147,7 +147,13 @@ struct parse_pad : op_parser<parse_pad>
{
{
auto
mode
=
info
.
attributes
.
at
(
"mode"
).
s
();
auto
mode
=
info
.
attributes
.
at
(
"mode"
).
s
();
if
(
mode
==
"reflect"
)
if
(
mode
==
"reflect"
)
{
if
(
args
.
front
()
->
get_shape
().
dynamic
())
{
MIGRAPHX_THROW
(
"PARSE_PAD: reflect padding with dynamic shape not supported"
);
}
return
reflect_pad
(
info
,
pads
,
args
.
front
());
return
reflect_pad
(
info
,
pads
,
args
.
front
());
}
if
(
mode
!=
"constant"
)
if
(
mode
!=
"constant"
)
{
{
MIGRAPHX_THROW
(
MIGRAPHX_THROW
(
...
...
src/onnx/parse_reduce_op.cpp
View file @
a5c87ec5
...
@@ -68,8 +68,7 @@ instruction_ref parse_reduce_oper(const std::string& op_name,
...
@@ -68,8 +68,7 @@ instruction_ref parse_reduce_oper(const std::string& op_name,
}
}
else
else
{
{
std
::
size_t
n_dim
=
args
.
front
()
->
get_shape
().
lens
().
size
();
axes
.
resize
(
args
.
front
()
->
get_shape
().
ndim
());
axes
.
resize
(
n_dim
);
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
0
);
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
0
);
}
}
}
}
...
...
src/onnx/parse_reshape.cpp
View file @
a5c87ec5
...
@@ -49,7 +49,7 @@ struct parse_reshape : op_parser<parse_reshape>
...
@@ -49,7 +49,7 @@ struct parse_reshape : op_parser<parse_reshape>
if
(
args
.
size
()
==
2
)
if
(
args
.
size
()
==
2
)
{
{
auto
s
=
args
[
1
]
->
eval
();
auto
s
=
args
[
1
]
->
eval
();
check_arg_empty
(
s
,
"Reshape:
dynamic shape
is not supported"
);
check_arg_empty
(
s
,
"Reshape:
non-constant shape input
is not supported"
);
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
dims
));
});
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
dims
));
});
}
}
...
...
src/shape.cpp
View file @
a5c87ec5
...
@@ -504,6 +504,31 @@ bool shape::dynamic_dimension::is_fixed() const { return this->min == this->max;
...
@@ -504,6 +504,31 @@ bool shape::dynamic_dimension::is_fixed() const { return this->min == this->max;
bool
shape
::
dynamic_dimension
::
has_optimal
()
const
{
return
opt
!=
0
;
}
bool
shape
::
dynamic_dimension
::
has_optimal
()
const
{
return
opt
!=
0
;
}
shape
::
dynamic_dimension
&
shape
::
dynamic_dimension
::
operator
+=
(
const
std
::
size_t
&
x
)
{
this
->
min
+=
x
;
this
->
max
+=
x
;
if
(
this
->
opt
!=
0
)
{
this
->
opt
+=
x
;
};
return
*
this
;
}
shape
::
dynamic_dimension
&
shape
::
dynamic_dimension
::
operator
-=
(
const
std
::
size_t
&
x
)
{
assert
(
this
->
min
>=
x
);
assert
(
this
->
max
>=
x
);
this
->
min
-=
x
;
this
->
max
-=
x
;
if
(
this
->
opt
!=
0
)
{
assert
(
this
->
opt
>=
x
);
this
->
opt
-=
x
;
}
return
*
this
;
}
bool
operator
==
(
const
shape
::
dynamic_dimension
&
x
,
const
shape
::
dynamic_dimension
&
y
)
bool
operator
==
(
const
shape
::
dynamic_dimension
&
x
,
const
shape
::
dynamic_dimension
&
y
)
{
{
// don't check opt if both are fixed
// don't check opt if both are fixed
...
@@ -529,6 +554,23 @@ bool operator==(const std::size_t& x, const shape::dynamic_dimension& y) { retur
...
@@ -529,6 +554,23 @@ bool operator==(const std::size_t& x, const shape::dynamic_dimension& y) { retur
bool
operator
!=
(
const
shape
::
dynamic_dimension
&
x
,
const
std
::
size_t
&
y
)
{
return
not
(
x
==
y
);
}
bool
operator
!=
(
const
shape
::
dynamic_dimension
&
x
,
const
std
::
size_t
&
y
)
{
return
not
(
x
==
y
);
}
bool
operator
!=
(
const
std
::
size_t
&
x
,
const
shape
::
dynamic_dimension
&
y
)
{
return
not
(
x
==
y
);
}
bool
operator
!=
(
const
std
::
size_t
&
x
,
const
shape
::
dynamic_dimension
&
y
)
{
return
not
(
x
==
y
);
}
shape
::
dynamic_dimension
operator
+
(
const
shape
::
dynamic_dimension
&
x
,
const
std
::
size_t
&
y
)
{
auto
dd
=
x
;
return
dd
+=
y
;
}
shape
::
dynamic_dimension
operator
+
(
const
std
::
size_t
&
x
,
const
shape
::
dynamic_dimension
&
y
)
{
return
y
+
x
;
}
shape
::
dynamic_dimension
operator
-
(
const
shape
::
dynamic_dimension
&
x
,
const
std
::
size_t
&
y
)
{
auto
dd
=
x
;
return
dd
-=
y
;
}
bool
operator
==
(
const
shape
&
x
,
const
shape
&
y
)
bool
operator
==
(
const
shape
&
x
,
const
shape
&
y
)
{
{
if
(
x
.
dynamic
()
and
y
.
dynamic
())
if
(
x
.
dynamic
()
and
y
.
dynamic
())
...
...
src/simplify_algebra.cpp
View file @
a5c87ec5
...
@@ -1065,11 +1065,23 @@ struct find_split_reshape
...
@@ -1065,11 +1065,23 @@ struct find_split_reshape
return
;
return
;
}
}
// Only want to apply this optimization if each split output is followed by
// a contiguous op and a reshape
if
(
std
::
any_of
(
split_outputs
.
begin
(),
split_outputs
.
end
(),
[](
auto
i
)
{
if
(
i
->
outputs
().
size
()
==
1
)
{
auto
cont
=
i
->
outputs
().
front
();
return
cont
->
outputs
().
size
()
!=
1
;
}
return
false
;
}))
{
return
;
}
std
::
vector
<
instruction_ref
>
vec_rsp
(
split_outputs
.
size
());
std
::
vector
<
instruction_ref
>
vec_rsp
(
split_outputs
.
size
());
std
::
transform
(
split_outputs
.
begin
(),
split_outputs
.
end
(),
vec_rsp
.
begin
(),
[](
auto
i
)
{
std
::
transform
(
split_outputs
.
begin
(),
split_outputs
.
end
(),
vec_rsp
.
begin
(),
[](
auto
i
)
{
assert
(
i
->
outputs
().
size
()
==
1
);
auto
cont
=
i
->
outputs
().
front
();
auto
cont
=
i
->
outputs
().
front
();
assert
(
cont
->
outputs
().
size
()
==
1
);
return
cont
->
outputs
().
front
();
return
cont
->
outputs
().
front
();
});
});
...
...
src/simplify_reshapes.cpp
View file @
a5c87ec5
...
@@ -763,16 +763,23 @@ struct find_transpose_slice
...
@@ -763,16 +763,23 @@ struct find_transpose_slice
// Compute axis before transpose to use for unsqueeze
// Compute axis before transpose to use for unsqueeze
auto
perm
=
ins
->
get_operator
().
to_value
()[
"permutation"
].
to_vector
<
int64_t
>
();
auto
perm
=
ins
->
get_operator
().
to_value
()[
"permutation"
].
to_vector
<
int64_t
>
();
auto
preaxis
=
std
::
find
(
perm
.
begin
(),
perm
.
end
(),
axis
)
-
perm
.
begin
();
auto
preaxis
=
std
::
find
(
perm
.
begin
(),
perm
.
end
(),
axis
)
-
perm
.
begin
();
// Make unsqeeze
// Make unsqueeze
std
::
vector
<
int64_t
>
steps
(
sdistance
.
size
());
std
::
transform
(
slice
.
axes
.
begin
(),
slice
.
axes
.
end
(),
sdistance
.
begin
(),
steps
.
begin
(),
[
&
](
const
auto
ax
,
const
auto
sdis
)
{
return
ins
->
get_shape
().
lens
().
at
(
ax
)
/
sdis
;
});
auto
unsqueeze
=
m
.
insert_instruction
(
auto
unsqueeze
=
m
.
insert_instruction
(
ins
,
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
preaxis
}},
{
"steps"
,
s
distance
}}),
ins
->
inputs
());
ins
,
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
preaxis
}},
{
"steps"
,
s
teps
}}),
ins
->
inputs
());
// Make transpose
// Make transpose
std
::
transform
(
perm
.
begin
(),
perm
.
end
(),
perm
.
begin
(),
[
&
](
auto
i
)
{
std
::
transform
(
perm
.
begin
(),
perm
.
end
(),
perm
.
begin
(),
[
&
](
auto
i
)
{
if
(
i
>
preaxis
)
if
(
i
>
=
preaxis
)
return
i
+
1
;
return
i
+
1
;
return
i
;
return
i
;
});
});
perm
.
insert
(
perm
.
begin
(),
preaxis
+
1
);
perm
.
insert
(
perm
.
begin
(),
preaxis
);
auto
transpose
=
auto
transpose
=
m
.
insert_instruction
(
ins
,
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
unsqueeze
);
m
.
insert_instruction
(
ins
,
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
unsqueeze
);
// Slice and squeeze
// Slice and squeeze
...
...
src/targets/gpu/jit/gather.cpp
0 → 100644
View file @
a5c87ec5
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
// NOLINTNEXTLINE
static
const
char
*
const
gather_kernel
=
R"__migraphx__(
#include <migraphx/kernels/gather.hpp>
#include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
__global__ void gather_kernel(void* in_data, void* in_indices, void* output)
{
make_tensors()(in_data, in_indices, output)([](auto&&... xs) {
gather<${axis}>(xs...);
});
}
}
} // namespace migraphx
)__migraphx__"
;
struct
gather_compiler
:
compiler
<
gather_compiler
>
{
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"gather"
};
}
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
hip_compile_options
options
;
const
auto
&
out_s
=
inputs
.
back
();
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
out_s
.
elements
()));
options
.
inputs
=
inputs
;
options
.
output
=
out_s
;
options
.
kernel_name
=
"gather_kernel"
;
options
.
virtual_inputs
=
inputs
;
auto
axis
=
v
.
at
(
"axis"
).
to
<
std
::
string
>
();
auto
src
=
interpolate_string
(
gather_kernel
,
{{
"axis"
,
axis
}});
return
compile_hip_code_object
(
src
,
options
);
}
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
{
return
replace
(
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
op
.
to_value
()));
}
};
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/jit/reduce.cpp
View file @
a5c87ec5
...
@@ -156,16 +156,25 @@ struct reduce_compiler : compiler<reduce_compiler>
...
@@ -156,16 +156,25 @@ struct reduce_compiler : compiler<reduce_compiler>
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
{
{
value
v
=
value
::
object
{};
value
v
=
value
::
object
{};
auto
reduce_elements
=
get_reduce_elements
(
ins
->
inputs
());
if
(
op
.
name
()
==
"reduce_sum"
)
if
(
op
.
name
()
==
"reduce_sum"
)
{
{
v
[
"reduction"
]
=
"op::sum{}"
;
v
[
"reduction"
]
=
"op::sum{}"
;
}
}
else
if
(
op
.
name
()
==
"reduce_mean"
)
else
if
(
op
.
name
()
==
"reduce_mean"
)
{
{
v
[
"reduction"
]
=
"op::sum{}"
;
auto
reduce_elements
=
get_reduce_elements
(
ins
->
inputs
());
v
[
"write"
]
=
"op::mean{"
+
std
::
to_string
(
reduce_elements
)
+
"}"
;
auto
reduce_type
=
ins
->
inputs
().
front
()
->
get_shape
().
type
();
v
[
"reduction"
]
=
"op::sum{}"
;
std
::
string
mean
=
"op::mean{"
+
std
::
to_string
(
reduce_elements
)
+
"}"
;
// Use float accumulator when reduction size is too large for half
if
(
reduce_type
==
shape
::
half_type
and
reduce_elements
>
16384
)
v
[
"read"
]
=
"compose("
+
mean
+
", op::convert_to<float>{})"
;
else
if
(
contains
({
shape
::
float_type
,
shape
::
half_type
,
shape
::
double_type
},
reduce_type
))
v
[
"read"
]
=
mean
;
else
v
[
"write"
]
=
mean
;
}
}
else
if
(
op
.
name
()
==
"reduce_max"
)
else
if
(
op
.
name
()
==
"reduce_max"
)
{
{
...
...
src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
View file @
a5c87ec5
...
@@ -187,6 +187,14 @@ constexpr auto fold(F f)
...
@@ -187,6 +187,14 @@ constexpr auto fold(F f)
return
[
=
](
auto
&&
...
xs
)
{
return
fold_impl
(
f
,
static_cast
<
decltype
(
xs
)
&&>
(
xs
)...);
};
return
[
=
](
auto
&&
...
xs
)
{
return
fold_impl
(
f
,
static_cast
<
decltype
(
xs
)
&&>
(
xs
)...);
};
}
}
template
<
class
...
Fs
>
constexpr
auto
compose
(
Fs
...
fs
)
{
return
fold
([](
auto
f
,
auto
g
)
{
return
[
=
](
auto
&&
...
xs
)
{
return
f
(
g
(
static_cast
<
decltype
(
xs
)
>
(
xs
)...));
};
})(
fs
...);
}
template
<
class
...
Ts
>
template
<
class
...
Ts
>
constexpr
auto
pack
(
Ts
...
xs
)
constexpr
auto
pack
(
Ts
...
xs
)
{
{
...
...
src/targets/gpu/kernels/include/migraphx/kernels/gather.hpp
0 → 100644
View file @
a5c87ec5
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_GATHER_HPP
#define MIGRAPHX_GUARD_KERNELS_GATHER_HPP
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/shape.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/tensor_view.hpp>
namespace
migraphx
{
template
<
int
Axis
,
class
Input
,
class
Indices
>
constexpr
auto
gather_shape
(
Input
input
,
Indices
indices
)
{
auto
lengths
=
input
.
lens
;
lengths
[
Axis
]
=
indices
.
elements
();
return
make_shape
(
lengths
,
input
.
strides
);
}
template
<
int
Axis
,
class
Input
,
class
Indices
,
class
Output
>
__device__
void
gather
(
Input
input
,
Indices
indices
,
Output
output
)
{
auto
ind
=
make_index
();
auto
axis_dim_size
=
input
.
get_shape
().
lens
[
Axis
];
constexpr
auto
out_comp
=
gather_shape
<
Axis
>
(
get_shape_c
<
Input
>
{},
get_shape_c
<
Indices
>
{});
ind
.
global_stride
(
output
.
get_shape
().
elements
(),
[
&
](
auto
i
)
{
auto
idx
=
out_comp
.
multi
(
i
);
auto
in_index
=
indices
[
idx
[
Axis
]];
auto
new_in_index
=
(
in_index
<
0
)
?
in_index
+
axis_dim_size
:
in_index
;
idx
[
Axis
]
=
new_in_index
;
output
[
i
]
=
input
[
idx
];
});
}
}
// namespace migraphx
#endif
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
View file @
a5c87ec5
...
@@ -132,9 +132,14 @@ MIGRAPHX_DEVICE_MATH_FOR(float, fmod, ::fmodf)
...
@@ -132,9 +132,14 @@ MIGRAPHX_DEVICE_MATH_FOR(float, fmod, ::fmodf)
// Builtin half functions
// Builtin half functions
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
abs
,
::
__habs
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
abs
,
::
__habs
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
ceil
,
::
hceil
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
cos
,
::
hcos
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
exp
,
::
hexp
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
exp
,
::
hexp
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
floor
,
::
hfloor
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
isnan
,
::
__hisnan
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
log
,
::
hlog
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
log
,
::
hlog
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
rsqrt
,
::
hrsqrt
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
rsqrt
,
::
hrsqrt
)
// MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, sin, ::hsin)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
sqrt
,
::
hsqrt
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
sqrt
,
::
hsqrt
)
// Use float to compute half overload
// Use float to compute half overload
...
@@ -144,16 +149,11 @@ MIGRAPHX_DEVICE_MATH_HALF(asin, ::asin)
...
@@ -144,16 +149,11 @@ MIGRAPHX_DEVICE_MATH_HALF(asin, ::asin)
MIGRAPHX_DEVICE_MATH_HALF
(
asinh
,
::
asinh
)
MIGRAPHX_DEVICE_MATH_HALF
(
asinh
,
::
asinh
)
MIGRAPHX_DEVICE_MATH_HALF
(
atan
,
::
atan
)
MIGRAPHX_DEVICE_MATH_HALF
(
atan
,
::
atan
)
MIGRAPHX_DEVICE_MATH_HALF
(
atanh
,
::
atanh
)
MIGRAPHX_DEVICE_MATH_HALF
(
atanh
,
::
atanh
)
MIGRAPHX_DEVICE_MATH_HALF
(
ceil
,
::
ceil
)
MIGRAPHX_DEVICE_MATH_HALF
(
cos
,
::
cos
)
MIGRAPHX_DEVICE_MATH_HALF
(
cosh
,
::
cosh
)
MIGRAPHX_DEVICE_MATH_HALF
(
cosh
,
::
cosh
)
MIGRAPHX_DEVICE_MATH_HALF
(
erf
,
::
erf
)
MIGRAPHX_DEVICE_MATH_HALF
(
erf
,
::
erf
)
MIGRAPHX_DEVICE_MATH_HALF
(
floor
,
::
floor
)
MIGRAPHX_DEVICE_MATH_HALF
(
isnan
,
::
isnan
)
MIGRAPHX_DEVICE_MATH_HALF
(
pow
,
::
pow
)
MIGRAPHX_DEVICE_MATH_HALF
(
pow
,
::
pow
)
MIGRAPHX_DEVICE_MATH_HALF
(
remainder
,
::
remainder
)
MIGRAPHX_DEVICE_MATH_HALF
(
remainder
,
::
remainder
)
MIGRAPHX_DEVICE_MATH_HALF
(
round
,
::
round
)
MIGRAPHX_DEVICE_MATH_HALF
(
round
,
::
round
)
MIGRAPHX_DEVICE_MATH_HALF
(
sin
,
::
sin
)
MIGRAPHX_DEVICE_MATH_HALF
(
sinh
,
::
sinh
)
MIGRAPHX_DEVICE_MATH_HALF
(
sinh
,
::
sinh
)
MIGRAPHX_DEVICE_MATH_HALF
(
tan
,
::
tan
)
MIGRAPHX_DEVICE_MATH_HALF
(
tan
,
::
tan
)
MIGRAPHX_DEVICE_MATH_HALF
(
tanh
,
::
tanh
)
MIGRAPHX_DEVICE_MATH_HALF
(
tanh
,
::
tanh
)
...
@@ -166,19 +166,19 @@ MIGRAPHX_DEVICE_MATH_HALF(fmod, ::fmod)
...
@@ -166,19 +166,19 @@ MIGRAPHX_DEVICE_MATH_HALF(fmod, ::fmod)
// at this time are: exp2, exp10, log2, log10, isinf
// at this time are: exp2, exp10, log2, log10, isinf
MIGRAPHX_DEVICE_MATH_HALF2
(
abs
,
::
__habs2
)
MIGRAPHX_DEVICE_MATH_HALF2
(
abs
,
::
__habs2
)
MIGRAPHX_DEVICE_MATH_HALF2
(
ceil
,
::
h2ceil
)
MIGRAPHX_DEVICE_MATH_HALF2
(
ceil
,
::
h2ceil
)
MIGRAPHX_DEVICE_MATH_HALF2
(
floor
,
::
h2floor
)
MIGRAPHX_DEVICE_MATH_HALF2
(
sin
,
::
h2sin
)
MIGRAPHX_DEVICE_MATH_HALF2
(
cos
,
::
h2cos
)
MIGRAPHX_DEVICE_MATH_HALF2
(
cos
,
::
h2cos
)
MIGRAPHX_DEVICE_MATH_HALF2
(
exp
,
::
h2exp
)
MIGRAPHX_DEVICE_MATH_HALF2
(
exp
,
::
h2exp
)
MIGRAPHX_DEVICE_MATH_HALF2
(
exp2
,
::
h2exp2
)
MIGRAPHX_DEVICE_MATH_HALF2
(
exp10
,
::
h2exp10
)
MIGRAPHX_DEVICE_MATH_HALF2
(
exp10
,
::
h2exp10
)
MIGRAPHX_DEVICE_MATH_HALF2
(
log2
,
::
h2log2
)
MIGRAPHX_DEVICE_MATH_HALF2
(
exp2
,
::
h2exp2
)
MIGRAPHX_DEVICE_MATH_HALF2
(
floor
,
::
h2floor
)
MIGRAPHX_DEVICE_MATH_HALF2
(
isinf
,
::
__hisinf2
)
MIGRAPHX_DEVICE_MATH_HALF2
(
isnan
,
::
__hisnan2
)
MIGRAPHX_DEVICE_MATH_HALF2
(
log
,
::
h2log
)
MIGRAPHX_DEVICE_MATH_HALF2
(
log
,
::
h2log
)
MIGRAPHX_DEVICE_MATH_HALF2
(
log10
,
::
h2log10
)
MIGRAPHX_DEVICE_MATH_HALF2
(
log10
,
::
h2log10
)
MIGRAPHX_DEVICE_MATH_HALF2
(
log2
,
::
h2log2
)
MIGRAPHX_DEVICE_MATH_HALF2
(
rsqrt
,
::
h2rsqrt
)
MIGRAPHX_DEVICE_MATH_HALF2
(
rsqrt
,
::
h2rsqrt
)
// MIGRAPHX_DEVICE_MATH_HALF2(sin, ::h2sin)
MIGRAPHX_DEVICE_MATH_HALF2
(
sqrt
,
::
h2sqrt
)
MIGRAPHX_DEVICE_MATH_HALF2
(
sqrt
,
::
h2sqrt
)
MIGRAPHX_DEVICE_MATH_HALF2
(
isinf
,
::
__hisinf2
)
MIGRAPHX_DEVICE_MATH_HALF2
(
isnan
,
::
__hisnan2
)
template
<
class
T
,
class
U
>
template
<
class
T
,
class
U
>
constexpr
auto
where
(
bool
cond
,
const
T
&
a
,
const
U
&
b
)
constexpr
auto
where
(
bool
cond
,
const
T
&
a
,
const
U
&
b
)
...
@@ -218,6 +218,14 @@ constexpr auto min(const T& a, const U& b)
...
@@ -218,6 +218,14 @@ constexpr auto min(const T& a, const U& b)
return
min
<
common_type_t
<
T
,
U
>>
(
a
,
b
);
return
min
<
common_type_t
<
T
,
U
>>
(
a
,
b
);
}
}
// Sin for half is broken on hip, so use cos instead
template
<
class
T
,
MIGRAPHX_REQUIRES
(
is_same
<
vec_type
<
T
>,
half
>
{})
>
constexpr
T
sin
(
T
x
)
{
constexpr
const
T
shift
=
M_PI_2
;
return
migraphx
::
cos
(
shift
-
x
);
}
MIGRAPHX_DEVICE_MATH_VEC
(
abs
)
MIGRAPHX_DEVICE_MATH_VEC
(
abs
)
MIGRAPHX_DEVICE_MATH_VEC
(
acos
)
MIGRAPHX_DEVICE_MATH_VEC
(
acos
)
MIGRAPHX_DEVICE_MATH_VEC
(
acosh
)
MIGRAPHX_DEVICE_MATH_VEC
(
acosh
)
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment