Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
673ca71c
Commit
673ca71c
authored
Jan 31, 2023
by
Paul
Browse files
Merge branch 'develop' into conv-add
parents
4bfe1662
91cc7242
Changes
98
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
434 additions
and
158 deletions
+434
-158
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+14
-2
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+11
-4
src/targets/gpu/CMakeLists.txt
src/targets/gpu/CMakeLists.txt
+55
-53
src/targets/gpu/compile_hip.cpp
src/targets/gpu/compile_hip.cpp
+30
-15
src/targets/gpu/compile_hip_code_object.cpp
src/targets/gpu/compile_hip_code_object.cpp
+1
-1
src/targets/gpu/device/include/migraphx/gpu/device/reduce.hpp
...targets/gpu/device/include/migraphx/gpu/device/reduce.hpp
+3
-5
src/targets/gpu/fuse_ops.cpp
src/targets/gpu/fuse_ops.cpp
+53
-28
src/targets/gpu/jit/gather.cpp
src/targets/gpu/jit/gather.cpp
+89
-0
src/targets/gpu/jit/reduce.cpp
src/targets/gpu/jit/reduce.cpp
+13
-4
src/targets/gpu/kernels/include/migraphx/kernels/array.hpp
src/targets/gpu/kernels/include/migraphx/kernels/array.hpp
+1
-1
src/targets/gpu/kernels/include/migraphx/kernels/dpp.hpp
src/targets/gpu/kernels/include/migraphx/kernels/dpp.hpp
+1
-1
src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
...rgets/gpu/kernels/include/migraphx/kernels/functional.hpp
+8
-0
src/targets/gpu/kernels/include/migraphx/kernels/gather.hpp
src/targets/gpu/kernels/include/migraphx/kernels/gather.hpp
+64
-0
src/targets/gpu/kernels/include/migraphx/kernels/gathernd.hpp
...targets/gpu/kernels/include/migraphx/kernels/gathernd.hpp
+13
-19
src/targets/gpu/kernels/include/migraphx/kernels/hip.hpp
src/targets/gpu/kernels/include/migraphx/kernels/hip.hpp
+8
-1
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
+20
-13
src/targets/gpu/kernels/include/migraphx/kernels/ops.hpp
src/targets/gpu/kernels/include/migraphx/kernels/ops.hpp
+10
-0
src/targets/gpu/kernels/include/migraphx/kernels/shape.hpp
src/targets/gpu/kernels/include/migraphx/kernels/shape.hpp
+1
-8
src/targets/gpu/kernels/include/migraphx/kernels/types.hpp
src/targets/gpu/kernels/include/migraphx/kernels/types.hpp
+39
-2
src/targets/gpu/lowering.cpp
src/targets/gpu/lowering.cpp
+0
-1
No files found.
src/simplify_algebra.cpp
View file @
673ca71c
...
...
@@ -1092,11 +1092,23 @@ struct find_split_reshape
return
;
}
// Only want to apply this optimization if each split output is followed by
// a contiguous op and a reshape
if
(
std
::
any_of
(
split_outputs
.
begin
(),
split_outputs
.
end
(),
[](
auto
i
)
{
if
(
i
->
outputs
().
size
()
==
1
)
{
auto
cont
=
i
->
outputs
().
front
();
return
cont
->
outputs
().
size
()
!=
1
;
}
return
false
;
}))
{
return
;
}
std
::
vector
<
instruction_ref
>
vec_rsp
(
split_outputs
.
size
());
std
::
transform
(
split_outputs
.
begin
(),
split_outputs
.
end
(),
vec_rsp
.
begin
(),
[](
auto
i
)
{
assert
(
i
->
outputs
().
size
()
==
1
);
auto
cont
=
i
->
outputs
().
front
();
assert
(
cont
->
outputs
().
size
()
==
1
);
return
cont
->
outputs
().
front
();
});
...
...
src/simplify_reshapes.cpp
View file @
673ca71c
...
...
@@ -763,16 +763,23 @@ struct find_transpose_slice
// Compute axis before transpose to use for unsqueeze
auto
perm
=
ins
->
get_operator
().
to_value
()[
"permutation"
].
to_vector
<
int64_t
>
();
auto
preaxis
=
std
::
find
(
perm
.
begin
(),
perm
.
end
(),
axis
)
-
perm
.
begin
();
// Make unsqeeze
// Make unsqueeze
std
::
vector
<
int64_t
>
steps
(
sdistance
.
size
());
std
::
transform
(
slice
.
axes
.
begin
(),
slice
.
axes
.
end
(),
sdistance
.
begin
(),
steps
.
begin
(),
[
&
](
const
auto
ax
,
const
auto
sdis
)
{
return
ins
->
get_shape
().
lens
().
at
(
ax
)
/
sdis
;
});
auto
unsqueeze
=
m
.
insert_instruction
(
ins
,
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
preaxis
}},
{
"steps"
,
s
distance
}}),
ins
->
inputs
());
ins
,
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
preaxis
}},
{
"steps"
,
s
teps
}}),
ins
->
inputs
());
// Make transpose
std
::
transform
(
perm
.
begin
(),
perm
.
end
(),
perm
.
begin
(),
[
&
](
auto
i
)
{
if
(
i
>
preaxis
)
if
(
i
>
=
preaxis
)
return
i
+
1
;
return
i
;
});
perm
.
insert
(
perm
.
begin
(),
preaxis
+
1
);
perm
.
insert
(
perm
.
begin
(),
preaxis
);
auto
transpose
=
m
.
insert_instruction
(
ins
,
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
unsqueeze
);
// Slice and squeeze
...
...
src/targets/gpu/CMakeLists.txt
View file @
673ca71c
#####################################################################################
#
####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
...
...
@@ -20,7 +20,7 @@
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
#
####################################################################################
list
(
APPEND CMAKE_PREFIX_PATH /opt/rocm /opt/rocm/hip /opt/rocm/hcc
)
find_package
(
miopen
)
...
...
@@ -33,6 +33,8 @@ if(NOT TARGET MIOpen)
message
(
SEND_ERROR
"Cant find miopen"
)
endif
()
set
(
MIGRAPHX_USE_HIPRTC OFF CACHE BOOL
"Use hipRTC APIs"
)
include
(
Embed
)
file
(
GLOB KERNEL_FILES
${
CONFIGURE_DEPENDS
}
${
CMAKE_CURRENT_SOURCE_DIR
}
/kernels/include/migraphx/kernels/*.hpp
)
...
...
@@ -46,9 +48,10 @@ add_library(compile_for_gpu INTERFACE)
target_compile_options
(
compile_for_gpu INTERFACE -std=c++17 -fno-gpu-rdc -Wno-cuda-compat -Wno-unused-command-line-argument -Xclang -fallow-half-arguments-and-returns
)
target_link_libraries
(
compile_for_gpu INTERFACE hip::device -fno-gpu-rdc -Wno-invalid-command-line-argument -Wno-unused-command-line-argument -Wno-option-ignored
)
check_cxx_compiler_flag
(
"--cuda-host-only -fhip-lambda-host-device -x hip"
HAS_HIP_LAMBDA_HOST_DEVICE
)
if
(
HAS_HIP_LAMBDA_HOST_DEVICE
)
message
(
STATUS
"Enable -fhip-lambda-host-device"
)
target_compile_options
(
compile_for_gpu INTERFACE -fhip-lambda-host-device
)
message
(
STATUS
"Enable -fhip-lambda-host-device"
)
target_compile_options
(
compile_for_gpu INTERFACE -fhip-lambda-host-device
)
endif
()
set_target_properties
(
migraphx_device PROPERTIES EXPORT_NAME device
)
...
...
@@ -60,11 +63,13 @@ target_include_directories(migraphx_device PUBLIC $<BUILD_INTERFACE:${CMAKE_CURR
target_include_directories
(
migraphx_device PRIVATE $<BUILD_INTERFACE:
${
CMAKE_CURRENT_SOURCE_DIR
}
/device/include>
)
add_library
(
kernel_file_check EXCLUDE_FROM_ALL
)
foreach
(
KERNEL_FILE
${
KERNEL_FILES
}
)
get_filename_component
(
KERNEL_BASE_FILE
${
KERNEL_FILE
}
NAME_WE
)
file
(
WRITE
${
CMAKE_CURRENT_BINARY_DIR
}
/kernels/include/migraphx/kernels/
${
KERNEL_BASE_FILE
}
.cpp
"#include <migraphx/kernels/
${
KERNEL_BASE_FILE
}
.hpp>
\n
"
)
target_sources
(
kernel_file_check PRIVATE
${
CMAKE_CURRENT_BINARY_DIR
}
/kernels/include/migraphx/kernels/
${
KERNEL_BASE_FILE
}
.cpp
)
endforeach
()
target_compile_definitions
(
kernel_file_check PRIVATE -DMIGRAPHX_NLOCAL=256
)
target_include_directories
(
kernel_file_check PRIVATE $<BUILD_INTERFACE:
${
CMAKE_CURRENT_SOURCE_DIR
}
/kernels/include/>
)
target_link_libraries
(
kernel_file_check compile_for_gpu
)
...
...
@@ -125,6 +130,7 @@ function(register_migraphx_gpu_ops PREFIX)
register_op
(
migraphx_gpu HEADER migraphx/gpu/
${
OP
}
.hpp OPERATORS gpu::
${
PREFIX
}${
OP
}
INCLUDES migraphx/gpu/context.hpp
)
endforeach
()
endfunction
()
register_migraphx_gpu_ops
(
hip_
argmax
argmin
...
...
@@ -146,47 +152,41 @@ register_migraphx_gpu_ops(miopen_
lrn
pooling
)
register_op
(
migraphx_gpu
HEADER migraphx/gpu/rnn_variable_seq_lens.hpp
register_op
(
migraphx_gpu
HEADER migraphx/gpu/rnn_variable_seq_lens.hpp
OPERATORS gpu::hip_rnn_var_sl_shift_sequence gpu::hip_rnn_var_sl_shift_output gpu::hip_rnn_var_sl_last_output
INCLUDES migraphx/gpu/context.hpp
)
register_op
(
migraphx_gpu
HEADER migraphx/gpu/int8_gemm_pack.hpp
register_op
(
migraphx_gpu
HEADER migraphx/gpu/int8_gemm_pack.hpp
OPERATORS gpu::hip_int8_gemm_pack_a gpu::hip_int8_gemm_pack_b
INCLUDES migraphx/gpu/context.hpp
)
register_op
(
migraphx_gpu
HEADER migraphx/gpu/gemm.hpp
register_op
(
migraphx_gpu
HEADER migraphx/gpu/gemm.hpp
OPERATORS gpu::rocblas_gemm<op::dot> gpu::rocblas_gemm<op::quant_dot>
INCLUDES migraphx/gpu/context.hpp
)
register_op
(
migraphx_gpu HEADER migraphx/gpu/convolution.hpp
register_op
(
migraphx_gpu HEADER migraphx/gpu/convolution.hpp
OPERATORS gpu::miopen_convolution<op::convolution> gpu::miopen_convolution<op::deconvolution> gpu::miopen_convolution<op::quant_convolution>
INCLUDES migraphx/gpu/context.hpp
)
rocm_set_soversion
(
migraphx_gpu
${
MIGRAPHX_SO_VERSION
}
)
rocm_clang_tidy_check
(
migraphx_gpu
)
# look for offload bundler
get_filename_component
(
CMAKE_CXX_COMPILER_PATH
"
${
CMAKE_CXX_COMPILER
}
"
PATH
)
if
(
CMAKE_CXX_COMPILER MATCHES
".*clang
\\
+
\\
+$"
)
find_program
(
MIGRAPHX_OFFLOADBUNDLER_BIN clang-offload-bundler
HINTS
${
CMAKE_CXX_COMPILER_PATH
}
PATH_SUFFIXES bin
PATHS /opt/rocm/llvm
)
else
()
if
(
NOT CMAKE_CXX_COMPILER MATCHES
".*clang
\\
+
\\
+$"
)
find_program
(
MIGRAPHX_EXTRACT_KERNEL extractkernel
PATH_SUFFIXES bin
HINTS
${
CMAKE_CXX_COMPILER_PATH
}
PATHS
/opt/rocm/hip
/opt/rocm/hcc
/opt/rocm
/opt/rocm/hip
/opt/rocm/hcc
/opt/rocm
)
endif
()
message
(
STATUS
"clang-offload-bundler:
${
MIGRAPHX_OFFLOADBUNDLER_BIN
}
"
)
message
(
STATUS
"extractkernel:
${
MIGRAPHX_EXTRACT_KERNEL
}
"
)
set
(
MIGRAPHX_ENABLE_MLIR OFF CACHE BOOL
""
)
if
(
MIGRAPHX_ENABLE_MLIR
)
# Find package rocMLIR
find_package
(
rocMLIR 1.0.0 CONFIG REQUIRED
)
...
...
@@ -195,36 +195,39 @@ if(MIGRAPHX_ENABLE_MLIR)
target_link_libraries
(
migraphx_gpu PUBLIC rocMLIR::rockCompiler
)
endif
()
set
(
MIGRAPHX_USE_HIPRTC OFF CACHE BOOL
""
)
if
(
MIGRAPHX_USE_HIPRTC
)
target_compile_definitions
(
migraphx_gpu PRIVATE -DMIGRAPHX_USE_HIPRTC=1
)
message
(
STATUS
"MIGraphX is using hipRTC"
)
target_compile_definitions
(
migraphx_gpu PRIVATE -DMIGRAPHX_USE_HIPRTC=1
)
else
()
# Get flags needed to compile hip
include
(
TargetFlags
)
target_flags
(
HIP_COMPILER_FLAGS hip::device
)
# Remove cuda arch flags
string
(
REGEX REPLACE --cuda-gpu-arch=[a-z0-9]+
""
HIP_COMPILER_FLAGS
"
${
HIP_COMPILER_FLAGS
}
"
)
string
(
REGEX REPLACE --offload-arch=[a-z0-9:+-]+
""
HIP_COMPILER_FLAGS
"
${
HIP_COMPILER_FLAGS
}
"
)
# Skip library paths since hip will incorrectly treat it as a source file
string
(
APPEND HIP_COMPILER_FLAGS
" "
)
foreach
(
_unused RANGE 2
)
string
(
REGEX REPLACE
" /[^ ]+
\\
.(a|so) "
" "
HIP_COMPILER_FLAGS
"
${
HIP_COMPILER_FLAGS
}
"
)
endforeach
()
message
(
STATUS
"MIGraphX is using HIP Clang"
)
message
(
STATUS
"Hip compiler flags:
${
HIP_COMPILER_FLAGS
}
"
)
target_compile_definitions
(
migraphx_gpu PRIVATE
"-DMIGRAPHX_HIP_COMPILER=
${
CMAKE_CXX_COMPILER
}
"
"-DMIGRAPHX_HIP_COMPILER_FLAGS=
${
HIP_COMPILER_FLAGS
}
"
"-DMIGRAPHX_OFFLOADBUNDLER_BIN=
${
MIGRAPHX_OFFLOADBUNDLER_BIN
}
"
"-DMIGRAPHX_EXTRACT_KERNEL=
${
MIGRAPHX_EXTRACT_KERNEL
}
"
"-DMIGRAPHX_USE_HIPRTC=0"
)
if
(
DEFINED CMAKE_CXX_COMPILER_LAUNCHER
)
execute_process
(
COMMAND which
${
CMAKE_CXX_COMPILER_LAUNCHER
}
OUTPUT_VARIABLE MIGRAPHX_HIP_COMPILER_LAUNCHER
)
string
(
STRIP
"
${
MIGRAPHX_HIP_COMPILER_LAUNCHER
}
"
MIGRAPHX_HIP_COMPILER_LAUNCHER
)
target_compile_definitions
(
migraphx_gpu PRIVATE
"-DMIGRAPHX_HIP_COMPILER_LAUNCHER=
${
MIGRAPHX_HIP_COMPILER_LAUNCHER
}
"
)
endif
()
# Get flags needed to compile hip
include
(
TargetFlags
)
target_flags
(
HIP_COMPILER_FLAGS hip::device
)
# Remove cuda arch flags
string
(
REGEX REPLACE --cuda-gpu-arch=[a-z0-9]+
""
HIP_COMPILER_FLAGS
"
${
HIP_COMPILER_FLAGS
}
"
)
string
(
REGEX REPLACE --offload-arch=[a-z0-9:+-]+
""
HIP_COMPILER_FLAGS
"
${
HIP_COMPILER_FLAGS
}
"
)
# Skip library paths since hip will incorrectly treat it as a source file
string
(
APPEND HIP_COMPILER_FLAGS
" "
)
foreach
(
_unused RANGE 2
)
string
(
REGEX REPLACE
" /[^ ]+
\\
.(a|so) "
" "
HIP_COMPILER_FLAGS
"
${
HIP_COMPILER_FLAGS
}
"
)
endforeach
()
message
(
STATUS
"Hip compiler flags:
${
HIP_COMPILER_FLAGS
}
"
)
target_compile_definitions
(
migraphx_gpu PRIVATE
"-DMIGRAPHX_HIP_COMPILER=
${
CMAKE_CXX_COMPILER
}
"
"-DMIGRAPHX_HIP_COMPILER_FLAGS=
${
HIP_COMPILER_FLAGS
}
"
"-DMIGRAPHX_EXTRACT_KERNEL=
${
MIGRAPHX_EXTRACT_KERNEL
}
"
)
if
(
DEFINED CMAKE_CXX_COMPILER_LAUNCHER
)
execute_process
(
COMMAND which
${
CMAKE_CXX_COMPILER_LAUNCHER
}
OUTPUT_VARIABLE MIGRAPHX_HIP_COMPILER_LAUNCHER
)
string
(
STRIP
"
${
MIGRAPHX_HIP_COMPILER_LAUNCHER
}
"
MIGRAPHX_HIP_COMPILER_LAUNCHER
)
target_compile_definitions
(
migraphx_gpu PRIVATE
"-DMIGRAPHX_HIP_COMPILER_LAUNCHER=
${
MIGRAPHX_HIP_COMPILER_LAUNCHER
}
"
)
endif
()
endif
()
# Check miopen find mode api
...
...
@@ -236,7 +239,7 @@ check_library_exists(MIOpen "miopenFindSolutions" "${MIOPEN_LOCATION}" HAS_FIND_
# TODO: Set default to HAS_FIND_2_API
set
(
MIGRAPHX_USE_FIND_2_API OFF CACHE BOOL
""
)
if
(
MIGRAPHX_USE_FIND_2_API
)
if
(
MIGRAPHX_USE_FIND_2_API
)
target_compile_definitions
(
migraphx_gpu PUBLIC -DMIGRAPHX_HAS_FIND_2_API
)
message
(
STATUS
"MIGraphx is using Find-2.0 API of MIOpen"
)
else
()
...
...
@@ -258,8 +261,7 @@ target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels)
add_subdirectory
(
driver
)
rocm_install_targets
(
TARGETS migraphx_gpu migraphx_device compile_for_gpu
INCLUDE
TARGETS migraphx_gpu migraphx_device compile_for_gpu
INCLUDE
${
CMAKE_CURRENT_SOURCE_DIR
}
/include
)
src/targets/gpu/compile_hip.cpp
View file @
673ca71c
...
...
@@ -29,10 +29,9 @@
#include <cassert>
#include <iostream>
#if MIGRAPHX_USE_HIPRTC
#if
def
MIGRAPHX_USE_HIPRTC
#include <hip/hiprtc.h>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/env.hpp>
#else
#include <migraphx/compile_src.hpp>
#include <migraphx/process.hpp>
...
...
@@ -48,9 +47,10 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_OPTIMIZE);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_GPU_DUMP_ASM
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_GPU_DUMP_SRC
);
#if MIGRAPHX_USE_HIPRTC
#if
def
MIGRAPHX_USE_HIPRTC
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_HIPRTC
)
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_HIPRTC
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS
);
std
::
string
hiprtc_error
(
hiprtcResult
err
,
const
std
::
string
&
msg
)
{
...
...
@@ -143,25 +143,29 @@ struct hiprtc_program
options
.
end
(),
std
::
back_inserter
(
c_options
),
[](
const
std
::
string
&
s
)
{
return
s
.
c_str
();
});
auto
result
=
hiprtcCompileProgram
(
prog
.
get
(),
c_options
.
size
(),
c_options
.
data
());
std
::
cerr
<<
log
()
<<
std
::
endl
;
auto
result
=
hiprtcCompileProgram
(
prog
.
get
(),
c_options
.
size
(),
c_options
.
data
());
auto
prog_log
=
log
();
if
(
not
prog_log
.
empty
())
{
std
::
cerr
<<
prog_log
<<
std
::
endl
;
}
if
(
result
!=
HIPRTC_SUCCESS
)
MIGRAPHX_HIPRTC_THROW
(
result
,
"Compilation failed."
);
}
std
::
string
log
()
std
::
string
log
()
const
{
std
::
size_t
n
=
0
;
MIGRAPHX_HIPRTC
(
hiprtcGetProgramLogSize
(
prog
.
get
(),
&
n
));
if
(
n
<
2
)
if
(
n
==
0
)
return
{};
std
::
vector
<
char
>
buffer
(
n
);
std
::
string
buffer
(
n
,
'\0'
);
MIGRAPHX_HIPRTC
(
hiprtcGetProgramLog
(
prog
.
get
(),
buffer
.
data
()));
assert
(
buffer
.
back
()
=
=
0
);
return
{
buffer
.
begin
(),
buffer
.
end
()
-
1
}
;
assert
(
buffer
.
back
()
!
=
0
);
return
buffer
;
}
std
::
vector
<
char
>
get_code_obj
()
std
::
vector
<
char
>
get_code_obj
()
const
{
std
::
size_t
n
=
0
;
MIGRAPHX_HIPRTC
(
hiprtcGetCodeSize
(
prog
.
get
(),
&
n
));
...
...
@@ -176,6 +180,17 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
{
hiprtc_program
prog
(
srcs
);
auto
options
=
split_string
(
params
,
' '
);
options
.
push_back
(
"-DMIGRAPHX_USE_HIPRTC=1"
);
// remove following three compilation flags for HIPRTC once fixes from hipRTC are available in
if
(
enabled
(
MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS
{}))
{
options
.
push_back
(
"-DMIGRAPHX_HAS_DPP=0"
);
options
.
push_back
(
"-DMIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS=1"
);
options
.
push_back
(
"-Wno-reserved-identifier"
);
options
.
push_back
(
"-Wno-gnu-line-marker"
);
options
.
push_back
(
"-Wno-old-style-cast"
);
}
if
(
enabled
(
MIGRAPHX_GPU_DEBUG
{}))
options
.
push_back
(
"-DMIGRAPHX_DEBUG"
);
if
(
std
::
none_of
(
options
.
begin
(),
options
.
end
(),
[](
const
std
::
string
&
s
)
{
...
...
@@ -183,7 +198,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
}))
options
.
push_back
(
"-std=c++17"
);
options
.
push_back
(
"-fno-gpu-rdc"
);
options
.
push_back
(
"
-O"
+
string_value_of
(
MIGRAPHX_GPU_OPTIMIZE
{},
"3"
));
options
.
push_back
(
"-O"
+
string_value_of
(
MIGRAPHX_GPU_OPTIMIZE
{},
"3"
));
options
.
push_back
(
"-Wno-cuda-compat"
);
options
.
push_back
(
"--offload-arch="
+
arch
);
prog
.
compile
(
options
);
...
...
@@ -292,6 +307,8 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
return
{
compiler
.
compile
(
srcs
)};
}
#endif // MIGRAPHX_USE_HIPRTC
std
::
string
enum_params
(
std
::
size_t
count
,
std
::
string
param
)
{
std
::
vector
<
std
::
string
>
items
(
count
);
...
...
@@ -299,8 +316,6 @@ std::string enum_params(std::size_t count, std::string param)
return
join_strings
(
items
,
","
);
}
#endif // MIGRAPHX_USE_HIPRTC
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/compile_hip_code_object.cpp
View file @
673ca71c
...
...
@@ -29,7 +29,6 @@
#include <migraphx/context.hpp>
#include <migraphx_kernels.hpp>
#include <migraphx/stringutils.hpp>
#include <hip/hip_runtime_api.h>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -80,6 +79,7 @@ std::string generate_args_hpp(const std::vector<shape>& inputs)
#include <migraphx/kernels/args.hpp>
#include <migraphx/kernels/tensor_view.hpp>
#include <migraphx/kernels/types.hpp>
namespace migraphx {
...
...
src/targets/gpu/device/include/migraphx/gpu/device/reduce.hpp
View file @
673ca71c
...
...
@@ -36,6 +36,7 @@ namespace gpu {
namespace
device
{
#ifdef MIGRAPHX_NO_DPP
template
<
index_int
N
,
class
Op
,
class
T
,
...
...
@@ -62,6 +63,7 @@ __device__ auto block_reduce(index idx, Op op, T init, ForStride fs, F f)
}
return
buffer
[
0
];
}
#else
constexpr
unsigned
int
dpp_row_shr
(
unsigned
int
x
)
{
return
0x110u
|
x
;
}
...
...
@@ -96,11 +98,7 @@ __device__ T dpp_mov(T& x)
input
.
data
=
x
;
for
(
index_int
i
=
0
;
i
<
n
;
i
++
)
{
#if defined(__HCC__)
output
.
reg
[
i
]
=
__llvm_amdgcn_move_dpp
(
input
.
reg
[
i
],
DppCtrl
,
RowMask
,
BankMask
,
BoundCtrl
);
#else
output
.
reg
[
i
]
=
__hip_move_dpp
(
input
.
reg
[
i
],
DppCtrl
,
RowMask
,
BankMask
,
BoundCtrl
);
#endif
}
return
output
.
data
;
}
...
...
@@ -310,4 +308,4 @@ void reduce(hipStream_t stream,
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
#endif
// MIGRAPHX_NO_DPP
src/targets/gpu/fuse_ops.cpp
View file @
673ca71c
...
...
@@ -553,11 +553,13 @@ struct find_gemm_pointwise
{
auto
matcher
()
const
{
return
precompile_name
(
"pointwise"
)(
auto
gemm_op
=
match
::
name
(
"gpu::gemm"
)(
match
::
nargs
(
3
),
match
::
used_once
()).
bind
(
"gemm"
);
auto
binary_op
=
match
::
all_of
(
match
::
nargs
(
3
),
match
::
either_arg
(
0
,
1
)(
match
::
any_of
(
match
::
standard_shape
(),
match
::
is_constant
()).
bind
(
"c"
),
match
::
name
(
"gpu::gemm"
)(
match
::
nargs
(
3
),
match
::
used_once
()).
bind
(
"gemm"
)));
match
::
any_of
(
match
::
standard_shape
(),
match
::
is_constant
()).
bind
(
"c"
),
gemm_op
));
auto
unary_op
=
match
::
all_of
(
match
::
nargs
(
2
),
match
::
arg
(
0
)(
gemm_op
));
return
precompile_name
(
"pointwise"
)(
match
::
any_of
(
binary_op
,
unary_op
));
}
// TODO: Move to matcher.hpp
...
...
@@ -589,61 +591,84 @@ struct find_gemm_pointwise
return
match
::
name
(
"@return"
)(
match
::
args
(
match
::
any_of
(
add
,
mul_add
,
add_mul
)));
}
static
auto
match_mul
(
const
std
::
string
&
input
)
{
auto
mul
=
match_mul_const
(
match_param
(
input
),
"alpha"
);
return
match
::
name
(
"@return"
)(
match
::
args
(
mul
));
}
static
float
get_float
(
instruction_ref
ins
)
{
return
ins
->
get_literal
().
at
<
float
>
();
}
template
<
class
Gemm
>
static
bool
update_gemm
(
Gemm
&
gemm
,
module_ref
pm
,
unsigned
input
)
{
auto
names
=
pm
->
get_parameter_names
();
if
(
names
.
size
()
!=
2
)
return
false
;
std
::
sort
(
names
.
begin
(),
names
.
end
());
unsigned
output
=
input
==
0
?
1
:
0
;
auto
mr
=
match
::
match_instruction
(
*
pm
,
std
::
prev
(
pm
->
end
()),
match_add
(
names
[
input
],
names
[
output
]));
if
(
mr
.
result
==
pm
->
end
())
return
false
;
if
(
contains
(
mr
.
instructions
,
"alpha_mul"
))
if
(
names
.
size
()
==
1
)
{
auto
mr
=
match
::
match_instruction
(
*
pm
,
std
::
prev
(
pm
->
end
()),
match_mul
(
names
[
input
]));
if
(
mr
.
result
==
pm
->
end
())
return
false
;
gemm
.
alpha
*=
get_float
(
mr
.
instructions
[
"alpha"
]);
else
if
(
contains
(
mr
.
instructions
,
"beta_mul"
))
gemm
.
beta
*=
get_float
(
mr
.
instructions
[
"beta"
]);
else
if
(
contains
(
mr
.
instructions
,
"gamma_mul"
)
)
return
true
;
}
else
if
(
names
.
size
()
==
2
)
{
gemm
.
alpha
*=
get_float
(
mr
.
instructions
[
"gamma"
]);
gemm
.
beta
*=
get_float
(
mr
.
instructions
[
"gamma"
]);
unsigned
output
=
input
==
0
?
1
:
0
;
auto
mr
=
match
::
match_instruction
(
*
pm
,
std
::
prev
(
pm
->
end
()),
match_add
(
names
[
input
],
names
[
output
]));
if
(
mr
.
result
==
pm
->
end
())
return
false
;
if
(
contains
(
mr
.
instructions
,
"alpha_mul"
))
gemm
.
alpha
*=
get_float
(
mr
.
instructions
[
"alpha"
]);
else
if
(
contains
(
mr
.
instructions
,
"beta_mul"
))
gemm
.
beta
*=
get_float
(
mr
.
instructions
[
"beta"
]);
else
if
(
contains
(
mr
.
instructions
,
"gamma_mul"
))
{
gemm
.
alpha
*=
get_float
(
mr
.
instructions
[
"gamma"
]);
gemm
.
beta
*=
get_float
(
mr
.
instructions
[
"gamma"
]);
}
return
true
;
}
else
{
return
false
;
}
return
true
;
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
gemm_ins
=
r
.
instructions
[
"gemm"
];
auto
c_ins
=
r
.
instructions
[
"c"
];
auto
gemm
=
any_cast
<
rocblas_gemm
<
op
::
dot
>>
(
gemm_ins
->
get_operator
());
// Already fused gemm
if
(
not
float_equal
(
gemm
.
beta
,
0
))
return
;
gemm
.
beta
=
1
;
if
(
ins
->
inputs
().
size
()
==
3
)
gemm
.
beta
=
1
;
if
(
not
update_gemm
(
gemm
,
ins
->
module_inputs
().
front
(),
ins
->
inputs
().
front
()
==
gemm_ins
?
0
:
1
))
return
;
// const-fold input if not standard shape since rocblas can't handle it
if
(
not
c_ins
->
get_shape
().
standard
())
{
auto
c
=
make_op
(
"contiguous"
);
auto
l
=
c
.
compute
(
c
.
compute_shape
({
c_ins
->
get_shape
()}),
{
c_ins
->
eval
()});
c_ins
=
m
.
add_literal
(
l
.
get_shape
(),
l
.
data
());
}
auto
inputs
=
gemm_ins
->
inputs
();
inputs
.
pop_back
();
inputs
.
push_back
(
c_ins
);
if
(
ins
->
inputs
().
size
()
==
3
)
{
auto
c_ins
=
r
.
instructions
[
"c"
];
// const-fold input if not standard shape since rocblas can't handle it
if
(
not
c_ins
->
get_shape
().
standard
())
{
auto
c
=
make_op
(
"contiguous"
);
auto
l
=
c
.
compute
(
c
.
compute_shape
({
c_ins
->
get_shape
()}),
{
c_ins
->
eval
()});
c_ins
=
m
.
add_literal
(
l
.
get_shape
(),
l
.
data
());
}
inputs
.
push_back
(
c_ins
);
}
inputs
.
push_back
(
ins
->
inputs
().
back
());
m
.
replace_instruction
(
ins
,
gemm
,
inputs
);
...
...
src/targets/gpu/jit/gather.cpp
0 → 100644
View file @
673ca71c
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
// NOLINTNEXTLINE
static
const
char
*
const
gather_kernel
=
R"__migraphx__(
#include <migraphx/kernels/gather.hpp>
#include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
__global__ void gather_kernel(void* in_data, void* in_indices, void* output)
{
make_tensors()(in_data, in_indices, output)([](auto&&... xs) {
gather<${axis}>(xs...);
});
}
}
} // namespace migraphx
)__migraphx__"
;
struct
gather_compiler
:
compiler
<
gather_compiler
>
{
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"gather"
};
}
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
hip_compile_options
options
;
const
auto
&
out_s
=
inputs
.
back
();
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
out_s
.
elements
()));
options
.
inputs
=
inputs
;
options
.
output
=
out_s
;
options
.
kernel_name
=
"gather_kernel"
;
options
.
virtual_inputs
=
inputs
;
auto
axis
=
v
.
at
(
"axis"
).
to
<
std
::
string
>
();
auto
src
=
interpolate_string
(
gather_kernel
,
{{
"axis"
,
axis
}});
return
compile_hip_code_object
(
src
,
options
);
}
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
{
return
replace
(
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
op
.
to_value
()));
}
};
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/jit/reduce.cpp
View file @
673ca71c
...
...
@@ -156,16 +156,25 @@ struct reduce_compiler : compiler<reduce_compiler>
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
{
value
v
=
value
::
object
{};
auto
reduce_elements
=
get_reduce_elements
(
ins
->
inputs
());
value
v
=
value
::
object
{};
if
(
op
.
name
()
==
"reduce_sum"
)
{
v
[
"reduction"
]
=
"op::sum{}"
;
}
else
if
(
op
.
name
()
==
"reduce_mean"
)
{
v
[
"reduction"
]
=
"op::sum{}"
;
v
[
"write"
]
=
"op::mean{"
+
std
::
to_string
(
reduce_elements
)
+
"}"
;
auto
reduce_elements
=
get_reduce_elements
(
ins
->
inputs
());
auto
reduce_type
=
ins
->
inputs
().
front
()
->
get_shape
().
type
();
v
[
"reduction"
]
=
"op::sum{}"
;
std
::
string
mean
=
"op::mean{"
+
std
::
to_string
(
reduce_elements
)
+
"}"
;
// Use float accumulator when reduction size is too large for half
if
(
reduce_type
==
shape
::
half_type
and
reduce_elements
>
16384
)
v
[
"read"
]
=
"compose("
+
mean
+
", op::convert_to<float>{})"
;
else
if
(
contains
({
shape
::
float_type
,
shape
::
half_type
,
shape
::
double_type
},
reduce_type
))
v
[
"read"
]
=
mean
;
else
v
[
"write"
]
=
mean
;
}
else
if
(
op
.
name
()
==
"reduce_max"
)
{
...
...
src/targets/gpu/kernels/include/migraphx/kernels/array.hpp
View file @
673ca71c
...
...
@@ -105,7 +105,7 @@ constexpr auto array_for_each(T& x, Ts&... xs)
}
else
{
using
vec_type
=
std
::
remove_reference_t
<
decltype
(
array2vec
(
x
))
>
;
using
vec_type
=
remove_reference_t
<
decltype
(
array2vec
(
x
))
>
;
f
(
array2vec
(
x
),
__builtin_convertvector
(
array2vec
(
xs
),
vec_type
)...);
}
}
...
...
src/targets/gpu/kernels/include/migraphx/kernels/dpp.hpp
View file @
673ca71c
...
...
@@ -72,7 +72,7 @@ __device__ T dpp_mov(T& x)
}
return
output
.
data
;
}
#endif
#endif
// MIGRAPHX_HAS_DPP
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_DPP_HPP
src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
View file @
673ca71c
...
...
@@ -187,6 +187,14 @@ constexpr auto fold(F f)
return
[
=
](
auto
&&
...
xs
)
{
return
fold_impl
(
f
,
static_cast
<
decltype
(
xs
)
&&>
(
xs
)...);
};
}
template
<
class
...
Fs
>
constexpr
auto
compose
(
Fs
...
fs
)
{
return
fold
([](
auto
f
,
auto
g
)
{
return
[
=
](
auto
&&
...
xs
)
{
return
f
(
g
(
static_cast
<
decltype
(
xs
)
>
(
xs
)...));
};
})(
fs
...);
}
template
<
class
...
Ts
>
constexpr
auto
pack
(
Ts
...
xs
)
{
...
...
src/targets/gpu/kernels/include/migraphx/kernels/gather.hpp
0 → 100644
View file @
673ca71c
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_GATHER_HPP
#define MIGRAPHX_GUARD_KERNELS_GATHER_HPP
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/shape.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/tensor_view.hpp>
namespace
migraphx
{
template
<
int
Axis
,
class
Input
,
class
Indices
>
constexpr
auto
gather_shape
(
Input
input
,
Indices
indices
)
{
auto
lengths
=
input
.
lens
;
lengths
[
Axis
]
=
indices
.
elements
();
return
make_shape
(
lengths
,
input
.
strides
);
}
template
<
int
Axis
,
class
Input
,
class
Indices
,
class
Output
>
__device__
void
gather
(
Input
input
,
Indices
indices
,
Output
output
)
{
auto
ind
=
make_index
();
auto
axis_dim_size
=
input
.
get_shape
().
lens
[
Axis
];
constexpr
auto
out_comp
=
gather_shape
<
Axis
>
(
get_shape_c
<
Input
>
{},
get_shape_c
<
Indices
>
{});
ind
.
global_stride
(
output
.
get_shape
().
elements
(),
[
&
](
auto
i
)
{
auto
idx
=
out_comp
.
multi
(
i
);
auto
in_index
=
indices
[
idx
[
Axis
]];
auto
new_in_index
=
(
in_index
<
0
)
?
in_index
+
axis_dim_size
:
in_index
;
idx
[
Axis
]
=
new_in_index
;
output
[
i
]
=
input
[
idx
];
});
}
}
// namespace migraphx
#endif
src/targets/gpu/kernels/include/migraphx/kernels/gathernd.hpp
View file @
673ca71c
...
...
@@ -26,7 +26,7 @@
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/ops.hpp>
namespace
migraphx
{
template
<
class
T
>
...
...
@@ -53,23 +53,17 @@ __device__ void gathernd(const T& data_t, const U& indices_t, const V& output_t,
auto
indices_shape_lens
=
indices_shape
.
lens
;
auto
data_shape_lens
=
data_shape
.
lens
;
auto
num_slice_dims
=
indices_shape_lens
.
back
();
std
::
size_t
num_slices
=
accumulate
(
indices_shape_lens
.
begin
(),
indices_shape_lens
.
end
()
-
1
,
1
,
std
::
multiplies
<
std
::
size_t
>
());
std
::
size_t
slice_size
=
accumulate
(
data_shape_lens
.
begin
()
+
num_slice_dims
+
batch_dims
,
std
::
size_t
num_slices
=
accumulate
(
indices_shape_lens
.
begin
(),
indices_shape_lens
.
end
()
-
1
,
1
,
op
::
product
{});
std
::
size_t
slice_size
=
accumulate
(
data_shape_lens
.
begin
()
+
num_slice_dims
+
batch_dims
,
data_shape_lens
.
end
(),
1
,
std
::
multiplies
<
std
::
size_t
>
());
const
std
::
size_t
num_batches
=
accumulate
(
data_shape_lens
.
begin
(),
data_shape_lens
.
begin
()
+
batch_dims
,
1
,
std
::
multiplies
<
std
::
size_t
>
());
const
std
::
size_t
data_batch_stride
=
accumulate
(
data_shape_lens
.
begin
()
+
batch_dims
,
data_shape_lens
.
end
(),
1
,
std
::
multiplies
<
std
::
size_t
>
());
const
auto
num_slices_per_batch
=
num_slices
/
num_batches
;
op
::
product
{});
const
std
::
size_t
num_batches
=
accumulate
(
data_shape_lens
.
begin
(),
data_shape_lens
.
begin
()
+
batch_dims
,
1
,
op
::
product
{});
const
std
::
size_t
data_batch_stride
=
accumulate
(
data_shape_lens
.
begin
()
+
batch_dims
,
data_shape_lens
.
end
(),
1
,
op
::
product
{});
const
auto
num_slices_per_batch
=
num_slices
/
num_batches
;
ind
.
global_stride
(
output_shape
.
elements
(),
[
&
](
auto
i
)
{
const
auto
*
indices_ptr
=
indices_t
.
data
();
...
...
@@ -83,15 +77,15 @@ __device__ void gathernd(const T& data_t, const U& indices_t, const V& output_t,
int64_t
index
=
slice_indices
[
idx
];
const
std
::
size_t
input_dim_idx
=
batch_dims
+
idx
;
const
auto
input_dim
=
data_shape_lens
[
input_dim_idx
];
assert
(
index
>=
-
static_cast
<
int64_t
>
(
input_dim
)
and
index
<
static_cast
<
int64_t
>
(
input_dim
));
MIGRAPHX_ASSERT
(
index
>=
-
static_cast
<
int64_t
>
(
input_dim
)
and
index
<
static_cast
<
int64_t
>
(
input_dim
));
if
(
index
<
0
)
index
+=
input_dim
;
std
::
size_t
size_from_slice_dims
=
accumulate
(
data_shape_lens
.
begin
()
+
batch_dims
+
idx
+
1
,
data_shape_lens
.
begin
()
+
batch_dims
+
num_slice_dims
,
slice_size
,
std
::
multiplies
<
std
::
size_t
>
()
);
op
::
product
{}
);
relative_slice_offset
+=
index
*
size_from_slice_dims
;
}
...
...
src/targets/gpu/kernels/include/migraphx/kernels/hip.hpp
View file @
673ca71c
...
...
@@ -24,11 +24,18 @@
#ifndef MIGRAPHX_GUARD_KERNELS_HIP_HPP
#define MIGRAPHX_GUARD_KERNELS_HIP_HPP
#ifndef MIGRAPHX_USE_HIPRTC
// Workaround macro redefinition issue with clang tidy
#if defined(__HIP_PLATFORM_HCC__) && defined(MIGRAPHX_USE_CLANG_TIDY)
#undef __HIP_PLATFORM_HCC__ // NOLINT
#endif
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <hip/math_functions.h>
#include <hip/hip_math_constants.h>
#elif defined(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS)
#include <hip/hip_common.h>
#include <hip/hip_math_constants.h>
#endif
#endif // MIGRAPHX_GUARD_KERNELS_HIP_HPP
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
View file @
673ca71c
...
...
@@ -28,8 +28,7 @@
#include <migraphx/kernels/vec.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/type_traits.hpp>
#include <hip/hip_fp16.h>
#include <hip/math_functions.h>
#include <migraphx/kernels/hip.hpp>
namespace
migraphx
{
...
...
@@ -132,9 +131,14 @@ MIGRAPHX_DEVICE_MATH_FOR(float, fmod, ::fmodf)
// Builtin half functions
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
abs
,
::
__habs
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
ceil
,
::
hceil
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
cos
,
::
hcos
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
exp
,
::
hexp
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
floor
,
::
hfloor
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
isnan
,
::
__hisnan
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
log
,
::
hlog
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
rsqrt
,
::
hrsqrt
)
// MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, sin, ::hsin)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
sqrt
,
::
hsqrt
)
// Use float to compute half overload
...
...
@@ -144,16 +148,11 @@ MIGRAPHX_DEVICE_MATH_HALF(asin, ::asin)
MIGRAPHX_DEVICE_MATH_HALF
(
asinh
,
::
asinh
)
MIGRAPHX_DEVICE_MATH_HALF
(
atan
,
::
atan
)
MIGRAPHX_DEVICE_MATH_HALF
(
atanh
,
::
atanh
)
MIGRAPHX_DEVICE_MATH_HALF
(
ceil
,
::
ceil
)
MIGRAPHX_DEVICE_MATH_HALF
(
cos
,
::
cos
)
MIGRAPHX_DEVICE_MATH_HALF
(
cosh
,
::
cosh
)
MIGRAPHX_DEVICE_MATH_HALF
(
erf
,
::
erf
)
MIGRAPHX_DEVICE_MATH_HALF
(
floor
,
::
floor
)
MIGRAPHX_DEVICE_MATH_HALF
(
isnan
,
::
isnan
)
MIGRAPHX_DEVICE_MATH_HALF
(
pow
,
::
pow
)
MIGRAPHX_DEVICE_MATH_HALF
(
remainder
,
::
remainder
)
MIGRAPHX_DEVICE_MATH_HALF
(
round
,
::
round
)
MIGRAPHX_DEVICE_MATH_HALF
(
sin
,
::
sin
)
MIGRAPHX_DEVICE_MATH_HALF
(
sinh
,
::
sinh
)
MIGRAPHX_DEVICE_MATH_HALF
(
tan
,
::
tan
)
MIGRAPHX_DEVICE_MATH_HALF
(
tanh
,
::
tanh
)
...
...
@@ -166,19 +165,19 @@ MIGRAPHX_DEVICE_MATH_HALF(fmod, ::fmod)
// at this time are: exp2, exp10, log2, log10, isinf
MIGRAPHX_DEVICE_MATH_HALF2
(
abs
,
::
__habs2
)
MIGRAPHX_DEVICE_MATH_HALF2
(
ceil
,
::
h2ceil
)
MIGRAPHX_DEVICE_MATH_HALF2
(
floor
,
::
h2floor
)
MIGRAPHX_DEVICE_MATH_HALF2
(
sin
,
::
h2sin
)
MIGRAPHX_DEVICE_MATH_HALF2
(
cos
,
::
h2cos
)
MIGRAPHX_DEVICE_MATH_HALF2
(
exp
,
::
h2exp
)
MIGRAPHX_DEVICE_MATH_HALF2
(
exp2
,
::
h2exp2
)
MIGRAPHX_DEVICE_MATH_HALF2
(
exp10
,
::
h2exp10
)
MIGRAPHX_DEVICE_MATH_HALF2
(
log2
,
::
h2log2
)
MIGRAPHX_DEVICE_MATH_HALF2
(
exp2
,
::
h2exp2
)
MIGRAPHX_DEVICE_MATH_HALF2
(
floor
,
::
h2floor
)
MIGRAPHX_DEVICE_MATH_HALF2
(
isinf
,
::
__hisinf2
)
MIGRAPHX_DEVICE_MATH_HALF2
(
isnan
,
::
__hisnan2
)
MIGRAPHX_DEVICE_MATH_HALF2
(
log
,
::
h2log
)
MIGRAPHX_DEVICE_MATH_HALF2
(
log10
,
::
h2log10
)
MIGRAPHX_DEVICE_MATH_HALF2
(
log2
,
::
h2log2
)
MIGRAPHX_DEVICE_MATH_HALF2
(
rsqrt
,
::
h2rsqrt
)
// MIGRAPHX_DEVICE_MATH_HALF2(sin, ::h2sin)
MIGRAPHX_DEVICE_MATH_HALF2
(
sqrt
,
::
h2sqrt
)
MIGRAPHX_DEVICE_MATH_HALF2
(
isinf
,
::
__hisinf2
)
MIGRAPHX_DEVICE_MATH_HALF2
(
isnan
,
::
__hisnan2
)
template
<
class
T
,
class
U
>
constexpr
auto
where
(
bool
cond
,
const
T
&
a
,
const
U
&
b
)
...
...
@@ -218,6 +217,14 @@ constexpr auto min(const T& a, const U& b)
return
min
<
common_type_t
<
T
,
U
>>
(
a
,
b
);
}
// Sin for half is broken on hip, so use cos instead
template
<
class
T
,
MIGRAPHX_REQUIRES
(
is_same
<
vec_type
<
T
>,
half
>
{})
>
constexpr
T
sin
(
T
x
)
{
constexpr
const
T
shift
=
HIP_PIO2_F
;
return
migraphx
::
cos
(
shift
-
x
);
}
MIGRAPHX_DEVICE_MATH_VEC
(
abs
)
MIGRAPHX_DEVICE_MATH_VEC
(
acos
)
MIGRAPHX_DEVICE_MATH_VEC
(
acosh
)
...
...
src/targets/gpu/kernels/include/migraphx/kernels/ops.hpp
View file @
673ca71c
...
...
@@ -56,6 +56,16 @@ struct id
}
};
template
<
class
T
>
struct
convert_to
{
template
<
class
U
>
MIGRAPHX_DEVICE_CONSTEXPR
auto
operator
()(
U
x
)
const
{
return
convert
<
T
>
(
x
);
}
};
struct
mean
{
index_int
item_num
=
1
;
...
...
src/targets/gpu/kernels/include/migraphx/kernels/shape.hpp
View file @
673ca71c
...
...
@@ -76,14 +76,6 @@ struct shape
constexpr
index_int
index
(
index_array
x
)
const
{
return
x
.
dot
(
strides
);
}
constexpr
index_int
index
(
std
::
initializer_list
<
index_int
>
x
)
const
{
index_int
idx
=
0
;
for
(
index_int
i
=
0
;
i
<
x
.
size
();
i
++
)
idx
+=
*
(
x
.
begin
()
+
i
)
*
strides
[
i
];
return
idx
;
}
constexpr
index_int
index
(
index_int
i
)
const
{
if
(
this
->
standard
())
...
...
@@ -128,6 +120,7 @@ struct shape
result
[
0
]
=
tidx
;
return
result
;
}
/// Convert multi-index into a single index
constexpr
index_int
single
(
index_array
idx
)
const
{
...
...
src/targets/gpu/kernels/include/migraphx/kernels/types.hpp
View file @
673ca71c
...
...
@@ -28,8 +28,45 @@
namespace
migraphx
{
using
index_int
=
std
::
uint32_t
;
using
diff_int
=
std
::
int32_t
;
#if defined(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS) and defined(MIGRAPHX_USE_HIPRTC)
using
int8_t
=
signed
char
;
using
uint8_t
=
unsigned
char
;
using
int16_t
=
signed
short
;
using
uint16_t
=
unsigned
short
;
using
int32_t
=
signed
int
;
using
uint32_t
=
unsigned
int
;
using
int64_t
=
signed
long
long
;
using
uint64_t
=
unsigned
long
long
;
#elif defined(MIGRAPHX_USE_HIPRTC)
using
int8_t
=
__hip_int8_t
;
using
uint8_t
=
__hip_uint8_t
;
using
int16_t
=
__hip_int16_t
;
using
uint16_t
=
__hip_uint16_t
;
using
int32_t
=
__hip_int32_t
;
using
uint32_t
=
__hip_uint32_t
;
using
int64_t
=
__hip_int64_t
;
using
uint64_t
=
__hip_uint64_t
;
#else
using
int8_t
=
std
::
int8_t
;
using
uint8_t
=
std
::
uint8_t
;
using
int16_t
=
std
::
int16_t
;
using
uint16_t
=
std
::
uint16_t
;
using
int32_t
=
std
::
int32_t
;
using
uint32_t
=
std
::
uint32_t
;
using
int64_t
=
std
::
int64_t
;
using
uint64_t
=
std
::
uint64_t
;
#endif // MIGRAPHX_USE_HIPRTC
using
index_int
=
uint32_t
;
using
diff_int
=
int32_t
;
static_assert
(
sizeof
(
int8_t
)
==
1
,
"int8_t must be 1 bytes"
);
static_assert
(
sizeof
(
uint8_t
)
==
1
,
"uint8_t must be 1 bytes"
);
static_assert
(
sizeof
(
int16_t
)
==
2
,
"int16_t must be 2 bytes"
);
static_assert
(
sizeof
(
uint16_t
)
==
2
,
"uint16_t must be 2 bytes"
);
static_assert
(
sizeof
(
int32_t
)
==
4
,
"int32_t must be 4 bytes"
);
static_assert
(
sizeof
(
uint32_t
)
==
4
,
"uint32_t must be 4 bytes"
);
static_assert
(
sizeof
(
int64_t
)
==
8
,
"int64_t must be 8 bytes"
);
static_assert
(
sizeof
(
uint64_t
)
==
8
,
"uint64_t must be 8 bytes"
);
#define MIGRAPHX_DEVICE_CONSTEXPR constexpr __device__ __host__ // NOLINT
...
...
src/targets/gpu/lowering.cpp
View file @
673ca71c
...
...
@@ -90,7 +90,6 @@ struct miopen_apply
add_extend_op
(
"argmax"
);
add_extend_op
(
"argmin"
);
add_extend_op
(
"gather"
);
add_extend_op
(
"logsoftmax"
);
add_extend_op
(
"lrn"
);
add_extend_op
(
"multinomial"
);
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment