Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
40fbef9b
"vscode:/vscode.git/clone" did not exist on "240cbda06c309bf4ba51ef808c8b075b7ae3d818"
Unverified
Commit
40fbef9b
authored
Aug 05, 2023
by
Ted Themistokleous
Committed by
GitHub
Aug 05, 2023
Browse files
Merge branch 'develop' into threaded_nms
parents
d164b151
aeb9f78c
Changes
440
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
681 additions
and
124 deletions
+681
-124
src/targets/fpga/subgraph.cpp
src/targets/fpga/subgraph.cpp
+1
-2
src/targets/fpga/vitis_ai_adapter.cpp
src/targets/fpga/vitis_ai_adapter.cpp
+1
-1
src/targets/gpu/CMakeLists.txt
src/targets/gpu/CMakeLists.txt
+35
-7
src/targets/gpu/compile_gen.cpp
src/targets/gpu/compile_gen.cpp
+13
-1
src/targets/gpu/compile_hip.cpp
src/targets/gpu/compile_hip.cpp
+10
-3
src/targets/gpu/compile_hip_code_object.cpp
src/targets/gpu/compile_hip_code_object.cpp
+10
-6
src/targets/gpu/compile_miopen.cpp
src/targets/gpu/compile_miopen.cpp
+1
-1
src/targets/gpu/compile_ops.cpp
src/targets/gpu/compile_ops.cpp
+181
-12
src/targets/gpu/compiler.cpp
src/targets/gpu/compiler.cpp
+24
-12
src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp
src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp
+10
-10
src/targets/gpu/device/multinomial.cpp
src/targets/gpu/device/multinomial.cpp
+12
-11
src/targets/gpu/device/scatter.cpp
src/targets/gpu/device/scatter.cpp
+16
-12
src/targets/gpu/device_name.cpp
src/targets/gpu/device_name.cpp
+3
-1
src/targets/gpu/driver/CMakeLists.txt
src/targets/gpu/driver/CMakeLists.txt
+1
-1
src/targets/gpu/driver/compile_op.cpp
src/targets/gpu/driver/compile_op.cpp
+1
-1
src/targets/gpu/driver/run_op.cpp
src/targets/gpu/driver/run_op.cpp
+1
-1
src/targets/gpu/fuse_ck.cpp
src/targets/gpu/fuse_ck.cpp
+175
-0
src/targets/gpu/fuse_mlir.cpp
src/targets/gpu/fuse_mlir.cpp
+180
-35
src/targets/gpu/fuse_ops.cpp
src/targets/gpu/fuse_ops.cpp
+2
-1
src/targets/gpu/gemm_impl.cpp
src/targets/gpu/gemm_impl.cpp
+4
-6
No files found.
src/targets/fpga/subgraph.cpp
View file @
40fbef9b
...
@@ -113,8 +113,7 @@ void subgraph::apply(module_pass_manager& mpm) const
...
@@ -113,8 +113,7 @@ void subgraph::apply(module_pass_manager& mpm) const
// TODO(varunsh): this code may be replaceable by code in the fuse_pointwise pass
// TODO(varunsh): this code may be replaceable by code in the fuse_pointwise pass
// assuming all FPGA instructions are in one contiguous range
// assuming all FPGA instructions are in one contiguous range
pm
->
insert_instructions
(
pm
->
end
(),
first
,
last
,
{});
pm
->
insert_instructions
(
pm
->
end
(),
first
,
std
::
next
(
last
),
{});
migraphx
::
instruction_ref
placeholder_ins
;
migraphx
::
instruction_ref
placeholder_ins
;
for
(
auto
it
:
iterator_for
(
mod
))
for
(
auto
it
:
iterator_for
(
mod
))
{
{
...
...
src/targets/fpga/vitis_ai_adapter.cpp
View file @
40fbef9b
...
@@ -33,7 +33,7 @@ migraphx::shape x_model::get_shape() const { return shape; };
...
@@ -33,7 +33,7 @@ migraphx::shape x_model::get_shape() const { return shape; };
void
x_model
::
set_shape
(
migraphx
::
shape
s
)
{
shape
=
s
;
}
void
x_model
::
set_shape
(
migraphx
::
shape
s
)
{
shape
=
s
;
}
x_model
create_xmodel
(
const
migraphx
::
module_ref
mod
)
x_model
create_xmodel
(
migraphx
::
const_
module_ref
mod
)
{
{
std
::
cout
<<
"Calling an external function: create_xmodel!
\n
"
;
std
::
cout
<<
"Calling an external function: create_xmodel!
\n
"
;
x_model
xmodel
;
x_model
xmodel
;
...
...
src/targets/gpu/CMakeLists.txt
View file @
40fbef9b
...
@@ -33,6 +33,11 @@ if(NOT TARGET MIOpen)
...
@@ -33,6 +33,11 @@ if(NOT TARGET MIOpen)
message
(
SEND_ERROR
"Cant find miopen"
)
message
(
SEND_ERROR
"Cant find miopen"
)
endif
()
endif
()
if
(
NOT WIN32
)
# TODO: re-enable when CK is ported to Windows
find_package
(
composable_kernel 1.0.0 REQUIRED COMPONENTS jit_library
)
endif
()
if
(
BUILD_DEV
)
if
(
BUILD_DEV
)
set
(
MIGRAPHX_USE_HIPRTC OFF CACHE BOOL
"Use hipRTC APIs"
)
set
(
MIGRAPHX_USE_HIPRTC OFF CACHE BOOL
"Use hipRTC APIs"
)
else
()
else
()
...
@@ -40,12 +45,12 @@ else()
...
@@ -40,12 +45,12 @@ else()
endif
()
endif
()
include
(
Embed
)
include
(
Embed
)
file
(
GLOB KERNEL_FILES
${
CONFIGURE_DEPENDS
}
file
(
GLOB KERNEL_FILES CONFIGURE_DEPENDS
${
CMAKE_CURRENT_SOURCE_DIR
}
/kernels/include/migraphx/kernels/*.hpp
)
${
CMAKE_CURRENT_SOURCE_DIR
}
/kernels/include/migraphx/kernels/*.hpp
)
message
(
STATUS
"KERNEL_FILES:
${
KERNEL_FILES
}
"
)
message
(
STATUS
"KERNEL_FILES:
${
KERNEL_FILES
}
"
)
add_embed_library
(
migraphx_kernels
${
KERNEL_FILES
}
)
add_embed_library
(
migraphx_kernels
${
KERNEL_FILES
}
RELATIVE
${
CMAKE_CURRENT_SOURCE_DIR
}
/kernels/include/
)
file
(
GLOB DEVICE_GPU_SRCS
${
CONFIGURE_DEPENDS
}
${
CMAKE_CURRENT_SOURCE_DIR
}
/device/*.cpp
)
file
(
GLOB DEVICE_GPU_SRCS CONFIGURE_DEPENDS
${
CMAKE_CURRENT_SOURCE_DIR
}
/device/*.cpp
)
add_library
(
migraphx_device
${
DEVICE_GPU_SRCS
}
)
add_library
(
migraphx_device
${
DEVICE_GPU_SRCS
}
)
add_library
(
compile_for_gpu INTERFACE
)
add_library
(
compile_for_gpu INTERFACE
)
...
@@ -65,6 +70,8 @@ target_link_libraries(migraphx_device PUBLIC migraphx)
...
@@ -65,6 +70,8 @@ target_link_libraries(migraphx_device PUBLIC migraphx)
target_link_libraries
(
migraphx_device PRIVATE compile_for_gpu
)
target_link_libraries
(
migraphx_device PRIVATE compile_for_gpu
)
target_include_directories
(
migraphx_device PUBLIC $<BUILD_INTERFACE:
${
CMAKE_CURRENT_SOURCE_DIR
}
/include>
)
target_include_directories
(
migraphx_device PUBLIC $<BUILD_INTERFACE:
${
CMAKE_CURRENT_SOURCE_DIR
}
/include>
)
target_include_directories
(
migraphx_device PRIVATE $<BUILD_INTERFACE:
${
CMAKE_CURRENT_SOURCE_DIR
}
/device/include>
)
target_include_directories
(
migraphx_device PRIVATE $<BUILD_INTERFACE:
${
CMAKE_CURRENT_SOURCE_DIR
}
/device/include>
)
target_compile_options
(
migraphx_device PRIVATE -Wno-ignored-attributes
)
migraphx_generate_export_header
(
migraphx_device DIRECTORY migraphx/gpu/device
)
add_library
(
kernel_file_check EXCLUDE_FROM_ALL
)
add_library
(
kernel_file_check EXCLUDE_FROM_ALL
)
...
@@ -80,7 +87,13 @@ target_link_libraries(kernel_file_check compile_for_gpu)
...
@@ -80,7 +87,13 @@ target_link_libraries(kernel_file_check compile_for_gpu)
rocm_clang_tidy_check
(
kernel_file_check
)
rocm_clang_tidy_check
(
kernel_file_check
)
file
(
GLOB JIT_GPU_SRCS
${
CONFIGURE_DEPENDS
}
${
CMAKE_CURRENT_SOURCE_DIR
}
/jit/*.cpp
)
file
(
GLOB JIT_GPU_SRCS CONFIGURE_DEPENDS
${
CMAKE_CURRENT_SOURCE_DIR
}
/jit/*.cpp
)
if
(
WIN32
)
# TODO: re-enable when CK is ported to Windows
list
(
REMOVE_ITEM JIT_GPU_SRCS
${
CMAKE_CURRENT_SOURCE_DIR
}
/jit/ck_gemm.cpp
)
endif
()
add_library
(
migraphx_gpu
add_library
(
migraphx_gpu
abs.cpp
abs.cpp
analyze_streams.cpp
analyze_streams.cpp
...
@@ -95,6 +108,7 @@ add_library(migraphx_gpu
...
@@ -95,6 +108,7 @@ add_library(migraphx_gpu
compile_miopen.cpp
compile_miopen.cpp
compiler.cpp
compiler.cpp
device_name.cpp
device_name.cpp
fuse_ck.cpp
fuse_mlir.cpp
fuse_mlir.cpp
fuse_ops.cpp
fuse_ops.cpp
gather.cpp
gather.cpp
...
@@ -123,11 +137,14 @@ add_library(migraphx_gpu
...
@@ -123,11 +137,14 @@ add_library(migraphx_gpu
schedule_model.cpp
schedule_model.cpp
sync_device.cpp
sync_device.cpp
target.cpp
target.cpp
time_op.cpp
topk.cpp
topk.cpp
write_literals.cpp
write_literals.cpp
${
JIT_GPU_SRCS
}
${
JIT_GPU_SRCS
}
)
)
set_target_properties
(
migraphx_gpu PROPERTIES EXPORT_NAME gpu
)
set_target_properties
(
migraphx_gpu PROPERTIES EXPORT_NAME gpu
)
migraphx_generate_export_header
(
migraphx_gpu
)
function
(
register_migraphx_gpu_ops PREFIX
)
function
(
register_migraphx_gpu_ops PREFIX
)
foreach
(
OP
${
ARGN
}
)
foreach
(
OP
${
ARGN
}
)
...
@@ -169,7 +186,7 @@ register_op(migraphx_gpu
...
@@ -169,7 +186,7 @@ register_op(migraphx_gpu
OPERATORS gpu::rocblas_gemm<op::dot> gpu::rocblas_gemm<op::quant_dot>
OPERATORS gpu::rocblas_gemm<op::dot> gpu::rocblas_gemm<op::quant_dot>
INCLUDES migraphx/gpu/context.hpp
)
INCLUDES migraphx/gpu/context.hpp
)
register_op
(
migraphx_gpu HEADER migraphx/gpu/convolution.hpp
register_op
(
migraphx_gpu HEADER migraphx/gpu/convolution.hpp
OPERATORS gpu::miopen_convolution<op::convolution> gpu::miopen_convolution<op::
de
convolution> gpu::miopen_convolution<op::quant_convolution>
OPERATORS gpu::miopen_convolution<op::convolution> gpu::miopen_convolution<op::convolution
_backwards
> gpu::miopen_convolution<op::quant_convolution>
INCLUDES migraphx/gpu/context.hpp
)
INCLUDES migraphx/gpu/context.hpp
)
rocm_set_soversion
(
migraphx_gpu
${
MIGRAPHX_SO_VERSION
}
)
rocm_set_soversion
(
migraphx_gpu
${
MIGRAPHX_SO_VERSION
}
)
rocm_clang_tidy_check
(
migraphx_gpu
)
rocm_clang_tidy_check
(
migraphx_gpu
)
...
@@ -181,7 +198,9 @@ if(MIGRAPHX_ENABLE_MLIR)
...
@@ -181,7 +198,9 @@ if(MIGRAPHX_ENABLE_MLIR)
find_package
(
rocMLIR 1.0.0 CONFIG REQUIRED
)
find_package
(
rocMLIR 1.0.0 CONFIG REQUIRED
)
message
(
STATUS
"Build with rocMLIR::rockCompiler
${
rocMLIR_VERSION
}
"
)
message
(
STATUS
"Build with rocMLIR::rockCompiler
${
rocMLIR_VERSION
}
"
)
target_compile_definitions
(
migraphx_gpu PRIVATE
"-DMIGRAPHX_MLIR"
)
target_compile_definitions
(
migraphx_gpu PRIVATE
"-DMIGRAPHX_MLIR"
)
target_link_libraries
(
migraphx_gpu PUBLIC rocMLIR::rockCompiler
)
# Make this private to avoid multiple inclusions of LLVM symbols.
# TODO: Fix rocMLIR's library to hide LLVM internals.
target_link_libraries
(
migraphx_gpu PRIVATE rocMLIR::rockCompiler
)
endif
()
endif
()
if
(
MIGRAPHX_USE_HIPRTC
)
if
(
MIGRAPHX_USE_HIPRTC
)
...
@@ -227,7 +246,12 @@ check_library_exists(MIOpen "miopenFindSolutions" "${MIOPEN_LOCATION}" HAS_FIND_
...
@@ -227,7 +246,12 @@ check_library_exists(MIOpen "miopenFindSolutions" "${MIOPEN_LOCATION}" HAS_FIND_
set
(
MIGRAPHX_USE_FIND_2_API
"
${
HAS_FIND_2_API
}
"
CACHE BOOL
""
)
set
(
MIGRAPHX_USE_FIND_2_API
"
${
HAS_FIND_2_API
}
"
CACHE BOOL
""
)
if
(
MIGRAPHX_USE_FIND_2_API
)
if
(
MIGRAPHX_USE_FIND_2_API
)
target_compile_definitions
(
migraphx_gpu PUBLIC -DMIGRAPHX_HAS_FIND_2_API
)
check_library_exists
(
MIOpen
"miopenSetFindOptionPreallocatedTensor"
"
${
MIOPEN_LOCATION
}
"
HAS_PREALLOCATION_API
)
if
(
HAS_PREALLOCATION_API
)
target_compile_definitions
(
migraphx_gpu PUBLIC -DMIGRAPHX_HAS_FIND_2_API -DMIGRAPHX_PREALLOCATE_MIOPEN_BUFFERS
)
else
()
target_compile_definitions
(
migraphx_gpu PUBLIC -DMIGRAPHX_HAS_FIND_2_API
)
endif
()
message
(
STATUS
"MIGraphx is using Find-2.0 API of MIOpen"
)
message
(
STATUS
"MIGraphx is using Find-2.0 API of MIOpen"
)
else
()
else
()
message
(
STATUS
"MIGraphx is using legacy Find API in MIOpen"
)
message
(
STATUS
"MIGraphx is using legacy Find API in MIOpen"
)
...
@@ -242,6 +266,10 @@ endif()
...
@@ -242,6 +266,10 @@ endif()
target_link_libraries
(
migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas
)
target_link_libraries
(
migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas
)
target_link_libraries
(
migraphx_gpu PRIVATE migraphx_device migraphx_kernels
)
target_link_libraries
(
migraphx_gpu PRIVATE migraphx_device migraphx_kernels
)
if
(
NOT WIN32
)
# TODO: re-enable when CK is ported to Windows
target_link_libraries
(
migraphx_gpu PRIVATE composable_kernel::jit_library
)
endif
()
add_subdirectory
(
driver
)
add_subdirectory
(
driver
)
add_subdirectory
(
hiprtc
)
add_subdirectory
(
hiprtc
)
...
...
src/targets/gpu/compile_gen.cpp
View file @
40fbef9b
...
@@ -29,6 +29,7 @@
...
@@ -29,6 +29,7 @@
#include <migraphx/module.hpp>
#include <migraphx/module.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/rewrite_quantization.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
...
@@ -171,7 +172,8 @@ std::string make_transformer_args(std::vector<std::string> transformers)
...
@@ -171,7 +172,8 @@ std::string make_transformer_args(std::vector<std::string> transformers)
void
generate_pointwise
(
cpp_generator
&
gg
,
const
module
&
pm
,
const
std
::
string
&
name
)
void
generate_pointwise
(
cpp_generator
&
gg
,
const
module
&
pm
,
const
std
::
string
&
name
)
{
{
module
m
=
pm
;
module
m
=
pm
;
run_passes
(
m
,
{
eliminate_common_subexpression
{},
dead_code_elimination
{}});
run_passes
(
m
,
{
rewrite_quantization
{},
eliminate_common_subexpression
{},
dead_code_elimination
{}});
cpp_generator
g
;
cpp_generator
g
;
g
.
fmap
([](
const
std
::
string
&
fname
)
{
return
"migraphx::"
+
fname
;
});
g
.
fmap
([](
const
std
::
string
&
fname
)
{
return
"migraphx::"
+
fname
;
});
g
.
add_point_op
(
"where"
,
"${function:where}(${0}, ${1}, ${2})"
);
g
.
add_point_op
(
"where"
,
"${function:where}(${0}, ${1}, ${2})"
);
...
@@ -280,6 +282,14 @@ std::string generate_reduce(const module& m, const std::string& name)
...
@@ -280,6 +282,14 @@ std::string generate_reduce(const module& m, const std::string& name)
not
input
->
get_shape
().
broadcasted
();
not
input
->
get_shape
().
broadcasted
();
});
});
auto
inner_names
=
names
;
auto
inner_names
=
names
;
for
(
auto
input
:
ins
->
inputs
())
{
if
(
input
->
name
()
!=
"@param"
)
continue
;
if
(
contains
(
tensors
,
input
))
continue
;
inner_names
[
input
]
+=
"[out_idx]"
;
}
for
(
auto
input
:
tensors
)
for
(
auto
input
:
tensors
)
inner_names
[
input
]
+=
"_lambda_param"
;
inner_names
[
input
]
+=
"_lambda_param"
;
auto
call_function
=
auto
call_function
=
...
@@ -308,6 +318,8 @@ std::string generate_reduce(const module& m, const std::string& name)
...
@@ -308,6 +318,8 @@ std::string generate_reduce(const module& m, const std::string& name)
});
});
f
.
set_attributes
({
"__device__"
,
"__attribute__((const))"
}).
set_generic_types
(
m
).
set_name
(
name
);
f
.
set_attributes
({
"__device__"
,
"__attribute__((const))"
}).
set_generic_types
(
m
).
set_name
(
name
);
f
.
add_generic_param
(
"r"
);
f
.
add_generic_param
(
"r"
);
f
.
add_generic_param
(
"out_idx"
);
f
.
unused_param
(
"out_idx"
);
g
.
create_function
(
f
);
g
.
create_function
(
f
);
return
g
.
str
();
return
g
.
str
();
}
}
...
...
src/targets/gpu/compile_hip.cpp
View file @
40fbef9b
...
@@ -56,9 +56,6 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_SRC);
...
@@ -56,9 +56,6 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_SRC);
#ifdef MIGRAPHX_USE_HIPRTC
#ifdef MIGRAPHX_USE_HIPRTC
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_HIPRTC
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS
);
std
::
string
hiprtc_error
(
hiprtcResult
err
,
const
std
::
string
&
msg
)
std
::
string
hiprtc_error
(
hiprtcResult
err
,
const
std
::
string
&
msg
)
{
{
return
"hiprtc: "
+
(
hiprtcGetErrorString
(
err
)
+
(
": "
+
msg
));
return
"hiprtc: "
+
(
hiprtcGetErrorString
(
err
)
+
(
": "
+
msg
));
...
@@ -194,6 +191,7 @@ std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_sr
...
@@ -194,6 +191,7 @@ std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_sr
options
.
push_back
(
"-DMIGRAPHX_HAS_DPP=0"
);
options
.
push_back
(
"-DMIGRAPHX_HAS_DPP=0"
);
options
.
push_back
(
"-DMIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS=1"
);
options
.
push_back
(
"-DMIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS=1"
);
options
.
push_back
(
"-Wno-reserved-identifier"
);
options
.
push_back
(
"-Wno-reserved-identifier"
);
options
.
push_back
(
"-Wno-unused-parameter"
);
options
.
push_back
(
"-Wno-gnu-line-marker"
);
options
.
push_back
(
"-Wno-gnu-line-marker"
);
options
.
push_back
(
"-Wno-old-style-cast"
);
options
.
push_back
(
"-Wno-old-style-cast"
);
}
}
...
@@ -216,6 +214,15 @@ std::vector<std::vector<char>>
...
@@ -216,6 +214,15 @@ std::vector<std::vector<char>>
compile_hip_src
(
const
std
::
vector
<
src_file
>&
srcs
,
std
::
string
params
,
const
std
::
string
&
arch
)
compile_hip_src
(
const
std
::
vector
<
src_file
>&
srcs
,
std
::
string
params
,
const
std
::
string
&
arch
)
{
{
std
::
vector
<
hiprtc_src_file
>
hsrcs
{
srcs
.
begin
(),
srcs
.
end
()};
std
::
vector
<
hiprtc_src_file
>
hsrcs
{
srcs
.
begin
(),
srcs
.
end
()};
if
(
enabled
(
MIGRAPHX_GPU_DUMP_SRC
{}))
{
for
(
const
auto
&
src
:
srcs
)
{
if
(
src
.
path
.
extension
()
!=
".cpp"
)
continue
;
std
::
cout
<<
std
::
string
(
src
.
content
.
first
,
src
.
len
())
<<
std
::
endl
;
}
}
auto
p
=
dynamic_loader
::
path
(
&
compile_hip_src_with_hiprtc
);
auto
p
=
dynamic_loader
::
path
(
&
compile_hip_src_with_hiprtc
);
auto
driver
=
p
.
parent_path
().
parent_path
()
/
"bin"
/
"migraphx-hiprtc-driver"
;
auto
driver
=
p
.
parent_path
().
parent_path
()
/
"bin"
/
"migraphx-hiprtc-driver"
;
...
...
src/targets/gpu/compile_hip_code_object.cpp
View file @
40fbef9b
...
@@ -135,10 +135,14 @@ compute_global_for(context& ctx, std::size_t n, std::size_t over)
...
@@ -135,10 +135,14 @@ compute_global_for(context& ctx, std::size_t n, std::size_t over)
std
::
size_t
max_global
=
ctx
.
get_current_device
().
get_cu_count
()
*
std
::
size_t
max_global
=
ctx
.
get_current_device
().
get_cu_count
()
*
ctx
.
get_current_device
().
get_max_workitems_per_cu
();
ctx
.
get_current_device
().
get_max_workitems_per_cu
();
return
[
n
,
over
,
max_global
](
std
::
size_t
local
)
{
return
[
n
,
over
,
max_global
](
std
::
size_t
local
)
{
std
::
size_t
groups
=
(
n
+
local
-
1
)
/
local
;
// hip require global workitems multiple of local workitems. It may degrade performance.
std
::
size_t
max_blocks
=
max_global
/
local
;
// [TODO]: consider adding "fno-hip-uniform-block" flag when it becomes available.
std
::
size_t
nglobal
=
std
::
min
(
max_blocks
*
over
,
groups
)
*
local
;
// https://reviews.llvm.org/D155213
return
std
::
min
(
nglobal
,
n
);
std
::
size_t
num_elements
=
((
n
+
local
-
1
)
/
local
)
*
local
;
std
::
size_t
groups
=
(
num_elements
+
local
-
1
)
/
local
;
std
::
size_t
max_blocks
=
max_global
/
local
;
std
::
size_t
nglobal
=
std
::
min
(
max_blocks
*
over
,
groups
)
*
local
;
return
std
::
min
(
nglobal
,
num_elements
);
};
};
}
}
...
@@ -156,14 +160,14 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
...
@@ -156,14 +160,14 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
assert
(
not
options
.
inputs
.
empty
());
assert
(
not
options
.
inputs
.
empty
());
assert
(
options
.
inputs
.
size
()
==
options
.
virtual_inputs
.
size
()
or
assert
(
options
.
inputs
.
size
()
==
options
.
virtual_inputs
.
size
()
or
options
.
virtual_inputs
.
empty
());
options
.
virtual_inputs
.
empty
());
std
::
vector
<
src_file
>
srcs
;
std
::
vector
<
src_file
>
srcs
=
options
.
additional_src_files
;
std
::
transform
(
migraphx_kernels
().
begin
(),
std
::
transform
(
migraphx_kernels
().
begin
(),
migraphx_kernels
().
end
(),
migraphx_kernels
().
end
(),
std
::
back_inserter
(
srcs
),
std
::
back_inserter
(
srcs
),
[](
auto
&&
p
)
{
[](
auto
&&
p
)
{
auto
&&
name
=
p
.
first
;
auto
&&
name
=
p
.
first
;
auto
&&
c
=
p
.
second
;
auto
&&
c
=
p
.
second
;
auto
path
=
fs
::
path
{
"migraphx"
}
/
"kernels"
/
name
;
auto
path
=
name
;
return
src_file
{
path
,
c
};
return
src_file
{
path
,
c
};
});
});
srcs
.
push_back
(
src_file
{
fs
::
path
{
"main.cpp"
},
srcs
.
push_back
(
src_file
{
fs
::
path
{
"main.cpp"
},
...
...
src/targets/gpu/compile_miopen.cpp
View file @
40fbef9b
...
@@ -79,7 +79,7 @@ void compile_miopen::apply(module& m) const
...
@@ -79,7 +79,7 @@ void compile_miopen::apply(module& m) const
std
::
size_t
ws
=
0
;
std
::
size_t
ws
=
0
;
try
try
{
{
// for the regular convolution and
de
convolution, this try would always succeed
// for the regular convolution and convolution
_backwards
, this try would always succeed
ws
=
compile
(
op
,
ins
,
int8_x4_format
);
ws
=
compile
(
op
,
ins
,
int8_x4_format
);
}
}
catch
(
migraphx
::
exception
&
)
catch
(
migraphx
::
exception
&
)
...
...
src/targets/gpu/compile_ops.cpp
View file @
40fbef9b
...
@@ -30,6 +30,7 @@
...
@@ -30,6 +30,7 @@
#include <migraphx/register_op.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/time_op.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -76,33 +77,201 @@ struct compiled_result
...
@@ -76,33 +77,201 @@ struct compiled_result
instruction_ref
ins
;
instruction_ref
ins
;
};
};
struct
problem_cache
{
bool
has
(
const
std
::
string
&
name
,
const
value
&
problem
)
const
{
return
contains
(
cache
,
create_key
(
name
,
problem
));
}
void
insert
(
const
std
::
string
&
name
,
const
value
&
problem
,
const
value
&
solution
)
{
assert
(
not
solution
.
is_null
());
cache
[
create_key
(
name
,
problem
)]
=
solution
;
}
void
mark
(
const
std
::
string
&
name
,
const
value
&
problem
)
{
cache
.
insert
(
std
::
make_pair
(
create_key
(
name
,
problem
),
value
{}));
}
optional
<
value
>
get
(
const
std
::
string
&
name
,
const
value
&
problem
)
const
{
auto
it
=
cache
.
find
(
create_key
(
name
,
problem
));
if
(
it
==
cache
.
end
())
return
nullopt
;
return
it
->
second
;
}
static
value
create_key
(
const
std
::
string
&
name
,
const
value
&
problem
)
{
return
{{
"name"
,
name
},
{
"problem"
,
problem
}};
}
std
::
unordered_map
<
value
,
value
>
cache
;
};
struct
compile_plan
{
context
*
ctx
;
operation
preop
;
instruction_ref
ins
;
optional
<
tuning_config
>
config
=
nullopt
;
std
::
vector
<
optional
<
compiled_result
>>
results
=
{};
void
update_config
(
bool
exhaustive
)
{
config
=
get_tuning_config
(
*
ctx
,
ins
,
preop
,
exhaustive
);
}
template
<
class
Vector
>
void
insert_compiles
(
Vector
&
compiles
,
const
value
&
solution
,
std
::
size_t
i
)
{
compiles
.
emplace_back
([
=
]
{
try
{
results
[
i
]
=
compiled_result
{
compile
(
*
ctx
,
ins
,
preop
,
solution
),
ins
};
}
catch
(...)
{
results
[
i
]
=
nullopt
;
}
});
}
template
<
class
Vector
>
void
add_compiles
(
Vector
&
compiles
,
problem_cache
&
pc
)
{
if
(
config
.
has_value
())
{
const
auto
&
problem
=
config
->
problem
;
if
(
auto
sol
=
pc
.
get
(
preop
.
name
(),
problem
))
{
auto
solution
=
sol
.
value
();
// No solution yet until benchmarked so skip for now
if
(
solution
.
is_null
())
return
;
results
.
resize
(
1
);
insert_compiles
(
compiles
,
solution
,
0
);
}
else
{
pc
.
mark
(
preop
.
name
(),
problem
);
const
auto
&
solutions
=
config
->
solutions
;
results
.
resize
(
solutions
.
size
());
for
(
auto
i
:
range
(
solutions
.
size
()))
{
auto
solution
=
solutions
[
i
];
insert_compiles
(
compiles
,
solution
,
i
);
}
}
}
else
{
results
.
resize
(
1
);
insert_compiles
(
compiles
,
value
{},
0
);
}
}
const
compiled_result
&
benchmark
(
problem_cache
&
pc
)
const
{
if
(
results
.
empty
())
MIGRAPHX_THROW
(
"No configs to tune"
);
if
(
results
.
size
()
==
1
)
{
if
(
not
results
.
front
().
has_value
())
MIGRAPHX_THROW
(
"No configs to tune"
);
return
*
results
.
front
();
}
if
(
not
config
)
MIGRAPHX_THROW
(
"Multiple kernels without config"
);
std
::
cout
<<
"Benchmarking "
<<
preop
.
name
()
<<
": "
<<
results
.
size
()
<<
" configs"
<<
std
::
endl
;
std
::
vector
<
double
>
times
;
times
.
reserve
(
results
.
size
());
std
::
transform
(
results
.
begin
(),
results
.
end
(),
std
::
back_inserter
(
times
),
[
&
](
const
auto
&
cr
)
{
if
(
not
cr
.
has_value
())
return
std
::
numeric_limits
<
double
>::
max
();
return
time_op
(
*
ctx
,
cr
->
replace
.
code_object
,
to_shapes
(
cr
->
ins
->
inputs
()),
20
)
.
first
;
});
auto
i
=
std
::
distance
(
times
.
begin
(),
std
::
min_element
(
times
.
begin
(),
times
.
end
()));
std
::
cout
<<
"Fastest solution: "
<<
config
->
solutions
.
at
(
i
)
<<
std
::
endl
;
pc
.
insert
(
preop
.
name
(),
config
->
problem
,
config
->
solutions
.
at
(
i
));
if
(
not
results
[
i
].
has_value
())
MIGRAPHX_THROW
(
"No valid tuned compilation."
);
return
*
results
[
i
];
}
void
replace
(
module
&
m
,
problem_cache
&
pc
)
const
{
const
auto
&
cr
=
benchmark
(
pc
);
cr
.
replace
.
replace
(
m
,
cr
.
ins
);
}
};
template
<
class
F
>
template
<
class
F
>
void
par_compile
(
std
::
size_t
n
,
F
f
)
void
par_compile
(
std
::
size_t
n
,
F
f
)
{
{
if
(
n
==
0
)
if
(
n
==
0
)
return
;
return
;
par_for
(
n
,
n
/
value_of
(
MIGRAPHX_GPU_COMPILE_PARALLEL
{},
n
),
f
);
auto
d
=
value_of
(
MIGRAPHX_GPU_COMPILE_PARALLEL
{});
if
(
d
==
0
)
d
=
n
;
par_for
(
n
,
n
/
d
,
f
);
}
}
void
compile_
ops
::
apply
(
module
&
m
)
const
struct
compile_
manager
{
{
std
::
vector
<
std
::
function
<
compiled_result
()
>>
compiles
;
problem_cache
pc
;
std
::
vector
<
compile_plan
>
cps
;
bool
exhaustive
=
false
;
template
<
class
...
Ts
>
void
add_plan
(
Ts
&&
...
xs
)
{
cps
.
push_back
({
std
::
forward
<
Ts
>
(
xs
)...});
}
void
update_configs
()
{
par_compile
(
cps
.
size
(),
[
&
](
auto
i
)
{
cps
[
i
].
update_config
(
exhaustive
);
});
}
void
compile
(
module
&
m
)
{
std
::
vector
<
std
::
function
<
void
()
>>
compiles
;
for
(
auto
&
cp
:
cps
)
{
cp
.
add_compiles
(
compiles
,
pc
);
}
par_compile
(
compiles
.
size
(),
[
&
](
auto
i
)
{
compiles
[
i
]();
});
// Replace and/or benchmark
for
(
const
auto
&
cp
:
cps
)
{
if
(
cp
.
results
.
empty
())
continue
;
cp
.
replace
(
m
,
pc
);
}
// Remove compile_plan already executed
cps
.
erase
(
std
::
remove_if
(
cps
.
begin
(),
cps
.
end
(),
[](
const
auto
&
cp
)
{
return
not
cp
.
results
.
empty
();
}),
cps
.
end
());
}
};
void
compile_ops
::
apply
(
module
&
m
)
const
{
compile_manager
cm
;
cm
.
exhaustive
=
exhaustive_tune
;
// Find all precompile opes
for
(
auto
ins
:
iterator_for
(
m
))
for
(
auto
ins
:
iterator_for
(
m
))
{
{
if
(
ins
->
name
()
!=
"gpu::precompile_op"
)
if
(
ins
->
name
()
!=
"gpu::precompile_op"
)
continue
;
continue
;
operation
preop
=
any_cast
<
precompile_op
>
(
ins
->
get_operator
()).
op
;
operation
preop
=
any_cast
<
precompile_op
>
(
ins
->
get_operator
()).
op
;
compiles
.
emplace_back
([
=
]()
->
compiled_result
{
cm
.
add_plan
(
ctx
,
preop
,
ins
);
return
{
compile
(
*
ctx
,
ins
,
preop
),
ins
};
});
}
std
::
vector
<
compiled_result
>
results
(
compiles
.
size
());
par_compile
(
compiles
.
size
(),
[
&
](
auto
i
)
{
results
[
i
]
=
compiles
[
i
]();
});
for
(
const
auto
&
cr
:
results
)
{
cr
.
replace
(
m
,
cr
.
ins
);
}
}
cm
.
update_configs
();
cm
.
compile
(
m
);
// Compile already tuned configs
cm
.
compile
(
m
);
assert
(
cm
.
cps
.
empty
());
}
}
}
// namespace gpu
}
// namespace gpu
...
...
src/targets/gpu/compiler.cpp
View file @
40fbef9b
...
@@ -28,33 +28,45 @@ namespace migraphx {
...
@@ -28,33 +28,45 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
gpu
{
auto
&
compiler_map
()
namespace
{
struct
compiler_handle
{
{
static
std
::
unordered_map
<
std
::
string
,
compiler_compile
>
m
;
// NOLINT
compiler_compile
compile
;
return
m
;
compiler_compile_op
compile_op
;
}
compiler_tuning_config
get_tuning_config
;
};
}
// namespace
auto
&
compiler_
op_
map
()
auto
&
compiler_map
()
{
{
static
std
::
unordered_map
<
std
::
string
,
compiler_
compile_op
>
m
;
// NOLINT
static
std
::
unordered_map
<
std
::
string
,
compiler_
handle
>
m
;
// NOLINT
return
m
;
return
m
;
}
}
void
register_compiler
(
const
std
::
string
&
name
,
compiler_compile
c
,
compiler_compile_op
cop
)
void
register_compiler
(
const
std
::
string
&
name
,
compiler_compile
c
,
compiler_compile_op
cop
,
compiler_tuning_config
ctg
)
{
{
compiler_map
()[
name
]
=
std
::
move
(
c
);
compiler_map
()[
name
]
=
{
std
::
move
(
c
),
std
::
move
(
cop
),
std
::
move
(
ctg
)};
compiler_op_map
()[
name
]
=
std
::
move
(
cop
);
}
}
bool
has_compiler_for
(
const
std
::
string
&
name
)
{
return
compiler_map
().
count
(
name
)
>
0
;
}
bool
has_compiler_for
(
const
std
::
string
&
name
)
{
return
compiler_map
().
count
(
name
)
>
0
;
}
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
,
const
value
&
solution
)
{
{
return
compiler_map
().
at
(
op
.
name
())(
ctx
,
ins
,
op
);
return
compiler_map
().
at
(
op
.
name
())
.
compile
(
ctx
,
ins
,
op
,
solution
);
}
}
operation
operation
compile_op
(
const
std
::
string
&
name
,
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
compile_op
(
const
std
::
string
&
name
,
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
{
{
return
compiler_op_map
().
at
(
name
)(
ctx
,
inputs
,
v
);
return
compiler_map
().
at
(
name
).
compile_op
(
ctx
,
inputs
,
v
);
}
optional
<
tuning_config
>
get_tuning_config
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
,
bool
exhaustive
)
{
return
compiler_map
().
at
(
op
.
name
()).
get_tuning_config
(
ctx
,
ins
,
op
,
exhaustive
);
}
}
}
// namespace gpu
}
// namespace gpu
...
...
src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp
View file @
40fbef9b
...
@@ -94,6 +94,10 @@ template <>
...
@@ -94,6 +94,10 @@ template <>
struct
is_hip_type
<
std
::
uint8_t
>
:
std
::
true_type
struct
is_hip_type
<
std
::
uint8_t
>
:
std
::
true_type
{
{
};
};
template
<
>
struct
is_hip_type
<
std
::
int32_t
>
:
std
::
true_type
{
};
template
<
class
T
,
class
V
,
MIGRAPHX_REQUIRES
(
is_hip_type
<
typename
T
::
type
>{})
>
template
<
class
T
,
class
V
,
MIGRAPHX_REQUIRES
(
is_hip_type
<
typename
T
::
type
>{})
>
void
hip_visitor_invoke
(
T
as
,
V
&&
v
)
void
hip_visitor_invoke
(
T
as
,
V
&&
v
)
...
@@ -120,12 +124,10 @@ void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
...
@@ -120,12 +124,10 @@ void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
if
(
not
std
::
all_of
(
if
(
not
std
::
all_of
(
types
.
begin
(),
types
.
end
(),
[
&
](
migraphx
::
shape
::
type_t
t
)
{
return
t
==
s
.
type
();
}))
types
.
begin
(),
types
.
end
(),
[
&
](
migraphx
::
shape
::
type_t
t
)
{
return
t
==
s
.
type
();
}))
MIGRAPHX_THROW
(
"Types must be the same"
);
MIGRAPHX_THROW
(
"Types must be the same"
);
std
::
initializer_list
<
index_int
>
ranks
=
{
std
::
initializer_list
<
index_int
>
ranks
=
{
static_cast
<
index_int
>
(
get_shape
(
xs
).
ndim
())...};
static_cast
<
index_int
>
(
get_shape
(
xs
).
lens
().
size
())...};
if
(
not
std
::
all_of
(
ranks
.
begin
(),
ranks
.
end
(),
[
&
](
index_int
r
)
{
return
r
==
s
.
ndim
();
}))
if
(
not
std
::
all_of
(
ranks
.
begin
(),
ranks
.
end
(),
[
&
](
index_int
r
)
{
return
r
==
s
.
lens
().
size
();
}))
MIGRAPHX_THROW
(
"Ranks must be the same"
);
MIGRAPHX_THROW
(
"Ranks must be the same"
);
visit_tensor_size
(
s
.
lens
().
size
(),
[
&
](
auto
ndim
)
{
visit_tensor_size
(
s
.
ndim
(),
[
&
](
auto
ndim
)
{
s
.
visit_type
(
hip_visitor
([
&
](
auto
as
)
{
v
(
f
(
xs
,
ndim
,
as
)...);
}));
s
.
visit_type
(
hip_visitor
([
&
](
auto
as
)
{
v
(
f
(
xs
,
ndim
,
as
)...);
}));
});
});
}
}
...
@@ -133,12 +135,10 @@ void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
...
@@ -133,12 +135,10 @@ void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
template
<
class
V
,
class
F
,
class
...
Ts
>
template
<
class
V
,
class
F
,
class
...
Ts
>
void
hip_visit_views_impl
(
const
shape
&
s
,
F
f
,
V
&&
v
,
Ts
&&
...
xs
)
void
hip_visit_views_impl
(
const
shape
&
s
,
F
f
,
V
&&
v
,
Ts
&&
...
xs
)
{
{
std
::
initializer_list
<
index_int
>
ranks
=
{
std
::
initializer_list
<
index_int
>
ranks
=
{
static_cast
<
index_int
>
(
get_shape
(
xs
).
ndim
())...};
static_cast
<
index_int
>
(
get_shape
(
xs
).
lens
().
size
())...};
if
(
not
std
::
all_of
(
ranks
.
begin
(),
ranks
.
end
(),
[
&
](
index_int
r
)
{
return
r
==
s
.
ndim
();
}))
if
(
not
std
::
all_of
(
ranks
.
begin
(),
ranks
.
end
(),
[
&
](
index_int
r
)
{
return
r
==
s
.
lens
().
size
();
}))
MIGRAPHX_THROW
(
"Ranks must be the same"
);
MIGRAPHX_THROW
(
"Ranks must be the same"
);
visit_tensor_size
(
s
.
lens
().
size
(),
[
&
](
auto
ndim
)
{
v
(
f
(
xs
,
ndim
)...);
});
visit_tensor_size
(
s
.
ndim
(),
[
&
](
auto
ndim
)
{
v
(
f
(
xs
,
ndim
)...);
});
}
}
template
<
class
F
>
template
<
class
F
>
...
...
src/targets/gpu/device/multinomial.cpp
View file @
40fbef9b
...
@@ -67,18 +67,19 @@ void multinomial(hipStream_t stream,
...
@@ -67,18 +67,19 @@ void multinomial(hipStream_t stream,
size_t
class_size
=
arg0
.
get_shape
().
lens
().
back
();
size_t
class_size
=
arg0
.
get_shape
().
lens
().
back
();
size_t
sample_size
=
result
.
get_shape
().
lens
().
back
();
size_t
sample_size
=
result
.
get_shape
().
lens
().
back
();
hip_visit_all
(
arg0
,
arg1
)([
&
](
auto
cdf
,
auto
dist
)
{
visit_all
(
arg0
,
arg1
)([
&
](
auto
cdf_host
,
auto
dist_host
)
{
result
.
visit
([
&
](
auto
out
)
{
result
.
visit
([
&
](
auto
output_host
)
{
hip_visit_views
(
out
)([
&
](
auto
output
)
{
hip_visit_views
(
cdf_host
,
dist_host
,
output_host
)(
gs_launch
(
stream
,
batch_size
*
sample_size
)([
=
](
auto
i
)
__device__
{
[
&
](
auto
cdf
,
auto
dist
,
auto
output
)
{
auto
idx
=
output
.
get_shape
().
multi
(
i
);
gs_launch
(
stream
,
batch_size
*
sample_size
)([
=
](
auto
i
)
__device__
{
auto
cdf_begin
=
cdf
.
begin
()
+
(
idx
.
front
()
*
class_size
);
auto
idx
=
output
.
get_shape
().
multi
(
i
);
auto
cdf_end
=
cdf_begin
+
class_size
;
auto
cdf_begin
=
cdf
.
begin
()
+
(
idx
.
front
()
*
class_size
);
auto
sample_iter
=
auto
cdf_end
=
cdf_begin
+
class_size
;
upper_bound
(
cdf_begin
,
cdf_end
,
dist
[
i
]
*
*
(
std
::
prev
(
cdf_end
)));
auto
*
sample_iter
=
output
[
i
]
=
std
::
distance
(
cdf_begin
,
sample_iter
);
upper_bound
(
cdf_begin
,
cdf_end
,
dist
[
i
]
*
*
(
std
::
prev
(
cdf_end
)));
output
[
i
]
=
std
::
distance
(
cdf_begin
,
sample_iter
);
});
});
});
});
});
});
});
});
}
}
...
...
src/targets/gpu/device/scatter.cpp
View file @
40fbef9b
...
@@ -37,22 +37,26 @@ argument scatter(
...
@@ -37,22 +37,26 @@ argument scatter(
hipStream_t
stream
,
argument
result
,
argument
arg0
,
argument
arg1
,
argument
arg2
,
int64_t
axis
)
hipStream_t
stream
,
argument
result
,
argument
arg0
,
argument
arg1
,
argument
arg2
,
int64_t
axis
)
{
{
auto
ds
=
arg0
.
get_shape
();
auto
ds
=
arg0
.
get_shape
();
auto
inds
=
arg1
.
get_shape
();
auto
s1
=
arg1
.
get_shape
();
auto
axis_dim_size
=
ds
.
lens
()[
axis
];
auto
axis_dim_size
=
ds
.
lens
()[
axis
];
hip_visit_all
(
result
,
arg0
,
inds
)([
&
](
auto
output
,
auto
data
,
auto
s1
)
{
hip_visit_all
(
result
,
arg0
,
arg2
)([
&
](
auto
output
,
auto
data
,
auto
update
)
{
auto
*
output_ptr
=
device_cast
(
output
.
data
());
auto
*
output_ptr
=
device_cast
(
output
.
data
());
const
auto
*
data_ptr
=
device_cast
(
data
.
data
());
const
auto
*
data_ptr
=
device_cast
(
data
.
data
());
gs_launch
(
stream
,
ds
.
elements
())([
=
](
auto
i
)
__device__
{
output_ptr
[
i
]
=
data_ptr
[
i
];
});
gs_launch
(
stream
,
ds
.
elements
())([
=
](
auto
i
)
__device__
{
output_ptr
[
i
]
=
data_ptr
[
i
];
});
hip_visit_all
(
arg1
,
arg2
)([
&
](
auto
indices
,
auto
update
)
{
const
auto
*
upd_ptr
=
device_cast
(
update
.
data
());
hip_visit_all
(
arg1
)([
&
](
auto
indices
)
{
const
auto
*
indices_ptr
=
device_cast
(
indices
.
data
());
if
constexpr
(
indices
.
get_shape
().
lens
.
size
()
==
output
.
get_shape
().
lens
.
size
())
gs_launch
(
stream
,
inds
.
elements
())([
=
](
auto
i
)
__device__
{
{
auto
out_idx
=
s1
.
multi
(
i
);
const
auto
*
upd_ptr
=
device_cast
(
update
.
data
());
auto
index
=
indices_ptr
[
i
];
const
auto
*
indices_ptr
=
device_cast
(
indices
.
data
());
index
=
index
<
0
?
index
+
axis_dim_size
:
index
;
gs_launch
(
stream
,
s1
.
elements
())([
=
](
auto
i
)
__device__
{
out_idx
[
axis
]
=
index
;
auto
out_idx
=
indices
.
get_shape
().
multi
(
i
);
output
[
out_idx
]
=
upd_ptr
[
i
];
auto
index
=
indices_ptr
[
i
];
});
index
=
index
<
0
?
index
+
axis_dim_size
:
index
;
out_idx
[
axis
]
=
index
;
output
[
out_idx
]
=
upd_ptr
[
i
];
});
}
});
});
});
});
...
...
src/targets/gpu/device_name.cpp
View file @
40fbef9b
...
@@ -43,6 +43,8 @@ auto get_arch_name(rank<1>, const HipDeviceProp& props) -> decltype(std::string(
...
@@ -43,6 +43,8 @@ auto get_arch_name(rank<1>, const HipDeviceProp& props) -> decltype(std::string(
return
std
::
string
(
props
.
gcnArchName
);
return
std
::
string
(
props
.
gcnArchName
);
}
}
// Convenience overload: resolves the architecture name of a device from its HIP
// property struct by forwarding to the rank-tagged get_arch_name overloads,
// starting the dispatch at rank<1>.
std::string get_arch_name(const hipDeviceProp_t& props)
{
    std::string arch_name = get_arch_name(rank<1>{}, props);
    return arch_name;
}
int
get_device_id
()
int
get_device_id
()
{
{
int
device
;
int
device
;
...
@@ -58,7 +60,7 @@ std::string get_device_name()
...
@@ -58,7 +60,7 @@ std::string get_device_name()
auto
status
=
hipGetDeviceProperties
(
&
props
,
get_device_id
());
auto
status
=
hipGetDeviceProperties
(
&
props
,
get_device_id
());
if
(
status
!=
hipSuccess
)
if
(
status
!=
hipSuccess
)
MIGRAPHX_THROW
(
"Failed to get device properties"
);
MIGRAPHX_THROW
(
"Failed to get device properties"
);
return
get_arch_name
(
rank
<
1
>
{},
props
);
return
get_arch_name
(
props
);
}
}
}
// namespace gpu
}
// namespace gpu
...
...
src/targets/gpu/driver/CMakeLists.txt
View file @
40fbef9b
...
@@ -22,7 +22,7 @@
...
@@ -22,7 +22,7 @@
# THE SOFTWARE.
# THE SOFTWARE.
#####################################################################################
#####################################################################################
file
(
GLOB GPU_DRIVER_SRCS
${
CONFIGURE_DEPENDS
}
${
CMAKE_CURRENT_SOURCE_DIR
}
/*.cpp
)
file
(
GLOB GPU_DRIVER_SRCS CONFIGURE_DEPENDS
${
CMAKE_CURRENT_SOURCE_DIR
}
/*.cpp
)
add_executable
(
gpu-driver
add_executable
(
gpu-driver
${
GPU_DRIVER_SRCS
}
${
GPU_DRIVER_SRCS
}
)
)
...
...
src/targets/gpu/driver/compile_op.cpp
View file @
40fbef9b
...
@@ -22,7 +22,7 @@
...
@@ -22,7 +22,7 @@
* THE SOFTWARE.
* THE SOFTWARE.
*/
*/
#include <migraphx/gpu/driver/action.hpp>
#include <migraphx/gpu/driver/action.hpp>
#include <migraphx/gpu/
driver/perf
.hpp>
#include <migraphx/gpu/
time_op
.hpp>
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/context.hpp>
...
...
src/targets/gpu/driver/run_op.cpp
View file @
40fbef9b
...
@@ -22,7 +22,7 @@
...
@@ -22,7 +22,7 @@
* THE SOFTWARE.
* THE SOFTWARE.
*/
*/
#include <migraphx/gpu/driver/action.hpp>
#include <migraphx/gpu/driver/action.hpp>
#include <migraphx/gpu/
driver/perf
.hpp>
#include <migraphx/gpu/
time_op
.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
...
...
src/targets/gpu/fuse_ck.cpp
0 → 100644
View file @
40fbef9b
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/fuse_ck.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
module
;
namespace
gpu
{
// Operator that stands for a GEMM to be emitted through CK. It wraps the original
// gemm operator (defaulting to "dot") and optionally carries a fused submodule.
struct ck_gemm
{
    // The wrapped gemm operator; its compute_shape determines the output dimensions.
    operation op = make_op("dot");

    // Exposes `op` as the only reflected field.
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.op, "op"));
    }

    std::string name() const { return "gpu::ck_gemm"; }

    // Throws unless one of the (up to) three innermost strides of `s` equals 1.
    void check_gemm_shape(const shape& s) const
    {
        const auto& strides = s.strides();
        // Clamp the inspected window to the number of dimensions actually present:
        // advancing rbegin() by 3 on a shape with fewer than three strides would run
        // past rend(), which is undefined behaviour.
        const int window = strides.size() < 3 ? static_cast<int>(strides.size()) : 3;
        if(not contains(range(strides.rbegin(), strides.rbegin() + window), 1))
            MIGRAPHX_THROW("Invalid shape for ck_gemm");
    }

    // Computes the output shape from the first two inputs using the wrapped operator.
    // All inputs must have the same number of dimensions and pass check_gemm_shape.
    // When a submodule is attached, the result takes the element type of the
    // submodule's first output shape.
    // Throws if fewer than two inputs are given or an input shape is rejected.
    shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
    {
        check_shapes{inputs, *this}.same_ndims();
        if(inputs.size() < 2)
            MIGRAPHX_THROW("should have at least two inputs.");
        auto a = inputs[0];
        auto b = inputs[1];
        for(const auto& input : inputs)
            check_gemm_shape(input);
        auto r = op.compute_shape({a, b});
        if(mods.empty())
            return r;
        return r.with_type(mods.front()->get_output_shapes().front().type());
    }
};
MIGRAPHX_REGISTER_OP
(
ck_gemm
);
namespace
{
// Returns true when `t` is one of the element types accepted for CK gemms here:
// half, int8 or int32.
bool is_ck_supported_type(shape::type_t t)
{
    return t == shape::half_type or t == shape::int8_type or t == shape::int32_type;
}
// Predicate matcher: accepts "dot"/"quant_dot" instructions whose result type is
// supported by CK and whose dimensions satisfy the constraints checked below.
MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
{
    const bool is_gemm = ins->name() == "dot" or ins->name() == "quant_dot";
    if(not is_gemm)
        return false;

    const auto result_type = ins->get_shape().type();
    if(not is_ck_supported_type(result_type))
        return false;

    const auto lhs = ins->inputs().front()->get_shape();
    const auto rhs = ins->inputs().back()->get_shape();
    const auto m   = lhs.lens()[lhs.lens().size() - 2];
    const auto n   = rhs.lens().back();
    const auto k   = lhs.lens().back();

    // Integer gemms must have M, N and K divisible by 4 in ck
    const bool is_integer_gemm = contains({shape::int8_type, shape::int32_type}, result_type);
    if(is_integer_gemm and (m % 4 != 0 or n % 4 != 0 or k % 4 != 0))
        return false;

    // Skipping GEMMs with a K dimension greater than 2048 is a coarse-grained strategy
    // to avoid poor-performing GEMM kernels from CK
    // To-do: Investigate a more precise strategy
    return k <= 2048;
}
// Rewrites a pointwise instruction that consumes a CK-eligible gemm into a single
// ck_gemm instruction carrying the pointwise submodule.
struct find_ck_gemm_pointwise
{
    // Find a gemm followed by a pointwise operation. The gemm may be reached through
    // "contiguous" instructions; the gemm itself is bound as "gemm" and the actual
    // pointwise input as "x".
    auto matcher() const
    {
        auto ck_dot = match::name("dot", "quant_dot")(is_ck_gemm().bind("gemm"));
        auto gemm   = match::skip(match::name("contiguous"))(ck_dot);
        return match::name("pointwise")(match::any_of[match::inputs()](gemm.bind("x")));
    }

    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto pw_ins   = r.result;
        auto dot_ins  = r.instructions["gemm"];
        auto fused_in = r.instructions["x"]; // input after contiguous
        auto* pw_mod  = pw_ins->module_inputs().front();

        auto param_names = pw_mod->get_parameter_names();
        std::sort(param_names.begin(), param_names.end());

        auto args     = pw_ins->inputs();
        auto pos      = std::find(args.begin(), args.end(), fused_in);
        auto pos_idx  = pos - args.begin();

        // The pointwise result must keep the gemm's type unless the gemm yields int32.
        const auto dot_type = dot_ins->get_shape().type();
        const bool type_ok =
            dot_type == shape::int32_type or pw_ins->get_shape().type() == dot_type;
        if(not type_ok)
            return;

        // Every input of the pointwise instruction must have a CK-supported type.
        const bool all_supported =
            std::all_of(pw_ins->inputs().begin(), pw_ins->inputs().end(), [](auto input) {
                return is_ck_supported_type(input->get_shape().type());
            });
        if(not all_supported)
            return;

        assert(pos != args.end());

        // When the gemm result is not the first pointwise input, swap the matching
        // submodule parameters so that the gemm result ends up on the first (sorted)
        // parameter name.
        if(pos_idx != 0)
        {
            auto first_param = pw_mod->get_parameter(param_names[0]);
            auto gemm_param  = pw_mod->get_parameter(param_names[pos_idx]);
            auto new_gemm_param =
                pw_mod->add_parameter(param_names[0] + "_0", gemm_param->get_shape());
            auto new_first_param =
                pw_mod->add_parameter(param_names[pos_idx] + "_0", first_param->get_shape());
            pw_mod->replace_instruction(gemm_param, new_gemm_param);
            pw_mod->replace_instruction(first_param, new_first_param);
            pw_mod->remove_instruction(first_param);
            pw_mod->remove_instruction(gemm_param);
        }

        // Replace the gemm result among the arguments with the gemm's own inputs,
        // placed at the front.
        args.erase(pos);
        args.insert(args.begin(), dot_ins->inputs().begin(), dot_ins->inputs().end());

        mpm.get_module().replace_instruction(
            pw_ins, ck_gemm{dot_ins->get_operator()}, args, {pw_mod});
    }
};
// Rewrites a stand-alone CK-eligible "dot" into a ck_gemm instruction.
struct find_ck_gemm
{
    // Matches "dot" instructions that satisfy is_ck_gemm, bound as "gemm".
    auto matcher() const
    {
        auto eligible = is_ck_gemm().bind("gemm");
        return match::name("dot")(eligible);
    }

    // Replaces the matched instruction in place, keeping its operator (wrapped in
    // ck_gemm) and its inputs.
    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto dot_ins = r.result;
        auto& mod    = mpm.get_module();
        mod.replace_instruction(dot_ins, ck_gemm{dot_ins->get_operator()}, dot_ins->inputs());
    }
};
}
// namespace
// Runs the CK fusion matchers over the module. The gemm+pointwise rewrite is run
// first, then the stand-alone gemm rewrite.
void fuse_ck::apply(module_pass_manager& mpm) const
{
    const find_ck_gemm_pointwise gemm_with_pointwise{};
    match::find_matches(mpm, gemm_with_pointwise);

    const find_ck_gemm plain_gemm{};
    match::find_matches(mpm, plain_gemm);
}
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/fuse_mlir.cpp
View file @
40fbef9b
...
@@ -38,6 +38,27 @@ namespace gpu {
...
@@ -38,6 +38,27 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_ENABLE_MLIR
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_ENABLE_MLIR
);
// Reports whether the MLIR kernel generator should be used.
// Without MIGRAPHX_MLIR compiled in this is always false. With it compiled in, the
// result follows the MIGRAPHX_ENABLE_MLIR environment variable, and a warning is
// written to stderr when the variable is not enabled.
bool mlir_enabled()
{
#ifdef MIGRAPHX_MLIR
    if(enabled(MIGRAPHX_ENABLE_MLIR{}))
        return true;

    std::cerr << "WARNING: MIGraphX built with MLIR but it is not enabled. Please set the env "
                 "var MIGRAPHX_ENABLE_MLIR to use MLIR kernel generator."
              << std::endl;
    return false;
#else
    return false;
#endif
}
#ifdef MIGRAPHX_MLIR
#ifdef MIGRAPHX_MLIR
struct
mlir_op
struct
mlir_op
...
@@ -58,8 +79,41 @@ struct mlir_op
...
@@ -58,8 +79,41 @@ struct mlir_op
MIGRAPHX_THROW
(
"should have one submodule."
);
MIGRAPHX_THROW
(
"should have one submodule."
);
if
(
inputs
.
size
()
<
2
)
if
(
inputs
.
size
()
<
2
)
MIGRAPHX_THROW
(
"should have at least two inputs."
);
MIGRAPHX_THROW
(
"should have at least two inputs."
);
auto
n
=
inputs
.
size
();
return
op
.
compute_shape
({
inputs
[
n
-
2
],
inputs
[
n
-
1
]});
module_ref
mod
=
mods
[
0
];
auto
type
=
mod
->
get_output_shapes
().
front
().
type
();
std
::
unordered_map
<
instruction_ref
,
shape
>
ins_shapes
;
size_t
param_cnt
=
0
;
std
::
vector
<
std
::
string
>
names
=
mod
->
get_parameter_names
();
std
::
sort
(
names
.
begin
(),
names
.
end
());
for
(
std
::
string
param_name
:
names
)
{
ins_shapes
[
mod
->
get_parameter
(
param_name
)]
=
inputs
[
param_cnt
++
];
}
for
(
auto
ins
:
iterator_for
(
*
mod
))
{
if
(
ins
->
name
()
==
"@param"
)
{
continue
;
}
if
(
ins
->
name
()
==
"@literal"
)
{
ins_shapes
[
ins
]
=
ins
->
get_shape
();
continue
;
}
if
(
ins
->
name
()
==
"@return"
)
{
return
ins_shapes
[
ins
->
inputs
().
at
(
0
)].
with_type
(
type
);
}
std
::
vector
<
shape
>
input_shapes
;
input_shapes
.
resize
(
ins
->
inputs
().
size
());
std
::
transform
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
input_shapes
.
begin
(),
[
&
](
auto
in
)
{
return
ins_shapes
[
in
];
});
ins_shapes
[
ins
]
=
ins
->
get_operator
().
compute_shape
(
input_shapes
);
}
MIGRAPHX_THROW
(
"No return found in the submodule"
);
}
}
};
};
MIGRAPHX_REGISTER_OP
(
mlir_op
);
MIGRAPHX_REGISTER_OP
(
mlir_op
);
...
@@ -68,7 +122,7 @@ namespace {
...
@@ -68,7 +122,7 @@ namespace {
MIGRAPHX_PRED_MATCHER
(
is_mlir_conv
,
instruction_ref
ins
)
MIGRAPHX_PRED_MATCHER
(
is_mlir_conv
,
instruction_ref
ins
)
{
{
if
(
ins
->
name
()
!=
"convolution"
)
if
(
ins
->
name
()
!=
"convolution"
and
ins
->
name
()
!=
"quant_convolution"
)
return
false
;
return
false
;
value
v
=
ins
->
get_operator
().
to_value
();
value
v
=
ins
->
get_operator
().
to_value
();
auto
group
=
v
.
at
(
"group"
).
to
<
int
>
();
auto
group
=
v
.
at
(
"group"
).
to
<
int
>
();
...
@@ -85,10 +139,121 @@ struct find_mlir_op
...
@@ -85,10 +139,121 @@ struct find_mlir_op
auto
matcher
()
const
auto
matcher
()
const
{
{
auto
dot_or_conv
=
match
::
skip
(
match
::
name
(
"contiguous"
))(
auto
dot_or_conv
=
match
::
skip
(
match
::
name
(
"contiguous"
))(
match
::
any_of
(
match
::
name
(
"dot"
),
is_mlir_conv
()).
bind
(
"gemm_based_op"
));
match
::
any_of
(
match
::
name
(
"dot"
),
match
::
name
(
"quant_dot"
),
is_mlir_conv
())
.
bind
(
"gemm_based_op"
));
return
match
::
name
(
"pointwise"
)(
match
::
any_of
[
match
::
inputs
()](
dot_or_conv
.
bind
(
"x"
)));
return
match
::
name
(
"pointwise"
)(
match
::
any_of
[
match
::
inputs
()](
dot_or_conv
.
bind
(
"x"
)));
}
}
// Copies every "@literal" instruction of `pm` into `mm`, multibroadcasts the copy to
// the dimensions of `shape`, and returns a map from the original literal instruction
// to the broadcast instruction. Instructions that are not literals are ignored.
std::unordered_map<instruction_ref, instruction_ref>
create_param_map_with_literals(module_ref mm, const module* pm, const shape& shape) const
{
    std::unordered_map<instruction_ref, instruction_ref> literal_map;
    for(auto ins : iterator_for(*pm))
    {
        if(ins->name() == "@literal")
        {
            literal value            = ins->get_literal();
            instruction_ref copied   = mm->add_literal(value);
            instruction_ref expanded = mm->add_instruction(
                make_op("multibroadcast", {{"out_lens", shape.lens()}}), copied);
            literal_map[ins] = expanded;
        }
    }
    return literal_map;
}
// Rebuilds `gemm_based_op` inside `mm` together with the chain of
// slice/transpose/contiguous/reshape instructions that feeds each of its inputs.
// For every input, the chain is walked upwards to the first instruction that is not
// one of those ops; that instruction becomes a parameter "y<index>" of `mm`, and the
// recorded ops are replayed on the parameter in their original order.
// Returns the new gemm instruction in `mm` and the list of instructions found at the
// top of each chain.
std::tuple<instruction_ref, std::vector<instruction_ref>>
fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op) const
{
    std::vector<instruction_ref> chain_roots;
    std::vector<instruction_ref> gemm_args;
    for(instruction_ref cursor : gemm_based_op->inputs())
    {
        // Ops are collected from the gemm upwards, i.e. in reverse execution order.
        std::vector<operation> collected;
        while(contains({"slice", "transpose", "contiguous", "reshape"}, cursor->name()))
        {
            collected.push_back(cursor->get_operator());
            cursor = cursor->inputs().at(0);
        }

        // The parameter index equals the number of roots recorded so far.
        const std::string param_name = "y" + std::to_string(chain_roots.size());
        chain_roots.push_back(cursor);

        instruction_ref current = mm->add_parameter(param_name, cursor->get_shape());
        for(auto it = collected.rbegin(); it != collected.rend(); ++it)
        {
            current = mm->add_instruction(*it, {current});
        }
        gemm_args.push_back(current);
    }
    instruction_ref new_gemm_based_op =
        mm->add_instruction(gemm_based_op->get_operator(), gemm_args);
    return {new_gemm_based_op, chain_roots};
}
// Whitelist supported fusion options, including imposing type constraints
// for cases where MLIR only supports an operation (usually a pointwise function)
// on particular types.
//
// Returns true when instruction `i` may be part of a module handed to MLIR:
//  - the result type must be float, half, int8, int32 or bool;
//  - "@literal", "@param" and "@return" are accepted for any of those types;
//  - the ops in `no_bool_ops` are accepted for every allowed type except bool;
//  - the ops in `fp_only_ops` are accepted for float/half results only;
//  - "convert" is accepted only when the result and all inputs are float/half.
bool is_pointwise_op_supported_by_mlir(const instruction& i) const
{
    using type_t                                      = shape::type_t;
    const auto& name                                  = i.name();
    const auto result_type                            = i.get_shape().type();
    const std::initializer_list<type_t> allowed_types = {type_t::float_type,
                                                         type_t::half_type,
                                                         type_t::int8_type,
                                                         type_t::int32_type,
                                                         type_t::bool_type};
    // Preliminary type check.
    if(not contains(allowed_types, result_type))
    {
        return false;
    }
    const std::initializer_list<std::string> any_type_ops = {"@literal", "@param", "@return"};
    const std::initializer_list<std::string> no_bool_ops  = {"convolution",
                                                             "quant_convolution",
                                                             "dot",
                                                             "quant_dot",
                                                             "add",
                                                             "clip",
                                                             "relu",
                                                             "sub",
                                                             "mul",
                                                             "div",
                                                             "pow",
                                                             "where",
                                                             "quantizelinear",
                                                             "dequantizelinear",
                                                             "abs",
                                                             "neg"};
    // Each name needs its own separating comma: adjacent string literals are
    // concatenated by the compiler, which would silently merge two entries into one
    // name that matches neither op.
    const std::initializer_list<std::string> fp_only_ops = {"ceil",
                                                            "erf",
                                                            "exp",
                                                            "floor",
                                                            "log",
                                                            "recip",
                                                            "rsqrt",
                                                            "sigmoid",
                                                            "softmax",
                                                            "tanh"};
    const bool is_float = contains({type_t::float_type, type_t::half_type}, result_type);
    if(contains(any_type_ops, name))
        return true;
    if(result_type != type_t::bool_type and contains(no_bool_ops, name))
        return true;
    if(is_float and contains(fp_only_ops, name))
        return true;
    // Only conversions between floating types are known to be unambiguously
    // supported.
    if(is_float and name == "convert")
    {
        return std::all_of(i.inputs().begin(), i.inputs().end(), [](const auto& arg) {
            return contains({type_t::float_type, type_t::half_type}, arg->get_shape().type());
        });
    }
    return false;
}
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
{
auto
ins
=
r
.
result
;
auto
ins
=
r
.
result
;
...
@@ -96,35 +261,25 @@ struct find_mlir_op
...
@@ -96,35 +261,25 @@ struct find_mlir_op
auto
x_ins
=
r
.
instructions
[
"x"
];
// input after contiguous
auto
x_ins
=
r
.
instructions
[
"x"
];
// input after contiguous
auto
*
pm
=
ins
->
module_inputs
().
front
();
auto
*
pm
=
ins
->
module_inputs
().
front
();
auto
names
=
pm
->
get_parameter_names
();
auto
names
=
pm
->
get_parameter_names
();
// Whitelist pointwise operators
// Whitelist pointwise operators.
if
(
std
::
any_of
(
pm
->
begin
(),
pm
->
end
(),
[](
const
auto
&
i
)
{
if
(
std
::
any_of
(
pm
->
begin
(),
pm
->
end
(),
[
&
](
const
auto
&
i
)
{
return
not
contains
(
return
not
is_pointwise_op_supported_by_mlir
(
i
);
{
"@literal"
,
"@param"
,
"@return"
,
"convolution"
,
"dot"
,
"add"
,
"relu"
},
i
.
name
());
}))
return
;
// Only fuse with fp32/fp16
if
(
std
::
any_of
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
[
&
](
auto
i
)
{
return
not
contains
({
shape
::
type_t
::
float_type
,
shape
::
type_t
::
half_type
},
i
->
get_shape
().
type
());
}))
}))
return
;
return
;
std
::
sort
(
names
.
begin
(),
names
.
end
());
std
::
sort
(
names
.
begin
(),
names
.
end
());
module_ref
mm
=
mpm
.
create_module
(
"mlir_"
+
pm
->
name
());
module_ref
mm
=
mpm
.
create_module
(
"mlir_"
+
pm
->
name
());
mm
->
set_bypass
();
mm
->
set_bypass
();
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
param_map
;
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
param_map
=
auto
x
=
mm
->
add_parameter
(
"x"
+
std
::
to_string
(
names
.
size
()),
create_param_map_with_literals
(
mm
,
pm
,
gemm_based_op
->
get_shape
());
gemm_based_op
->
inputs
().
at
(
0
)
->
get_shape
());
auto
[
anchor_op
,
top_inputs
]
=
fuse_input_ops_and_gemm_based_op
(
mm
,
gemm_based_op
);
auto
w
=
mm
->
add_parameter
(
"x"
+
std
::
to_string
(
names
.
size
()
+
1
),
gemm_based_op
->
inputs
().
at
(
1
)
->
get_shape
());
auto
conv
=
mm
->
add_instruction
(
gemm_based_op
->
get_operator
(),
{
x
,
w
});
std
::
transform
(
names
.
begin
(),
std
::
transform
(
names
.
begin
(),
names
.
end
(),
names
.
end
(),
ins
->
inputs
().
begin
(),
ins
->
inputs
().
begin
(),
std
::
inserter
(
param_map
,
param_map
.
end
()),
std
::
inserter
(
param_map
,
param_map
.
end
()),
[
&
](
auto
name
,
auto
input
)
{
[
&
,
&
anchor_op
=
anchor_op
](
auto
name
,
auto
input
)
{
if
(
input
==
x_ins
)
if
(
input
==
x_ins
)
return
std
::
make_pair
(
pm
->
get_parameter
(
name
),
conv
);
return
std
::
make_pair
(
pm
->
get_parameter
(
name
),
anchor_op
);
return
std
::
make_pair
(
pm
->
get_parameter
(
name
),
return
std
::
make_pair
(
pm
->
get_parameter
(
name
),
mm
->
add_parameter
(
name
,
input
->
get_shape
()));
mm
->
add_parameter
(
name
,
input
->
get_shape
()));
});
});
...
@@ -135,7 +290,7 @@ struct find_mlir_op
...
@@ -135,7 +290,7 @@ struct find_mlir_op
ins
->
inputs
().
end
(),
ins
->
inputs
().
end
(),
std
::
back_inserter
(
inputs
),
std
::
back_inserter
(
inputs
),
[
&
](
auto
input
)
{
return
input
!=
gemm_based_op
;
});
[
&
](
auto
input
)
{
return
input
!=
gemm_based_op
;
});
inputs
.
insert
(
inputs
.
end
(),
gemm_based_op
->
inputs
()
.
begin
(),
gemm_based_op
->
inputs
()
.
end
());
inputs
.
insert
(
inputs
.
end
(),
top_
inputs
.
begin
(),
top_
inputs
.
end
());
mpm
.
get_module
().
replace_instruction
(
mpm
.
get_module
().
replace_instruction
(
ins
,
mlir_op
{
gemm_based_op
->
get_operator
()},
inputs
,
{
mm
});
ins
,
mlir_op
{
gemm_based_op
->
get_operator
()},
inputs
,
{
mm
});
}
}
...
@@ -148,17 +303,7 @@ struct find_mlir_op
...
@@ -148,17 +303,7 @@ struct find_mlir_op
void
fuse_mlir
::
apply
(
module_pass_manager
&
mpm
)
const
void
fuse_mlir
::
apply
(
module_pass_manager
&
mpm
)
const
{
{
#ifdef MIGRAPHX_MLIR
#ifdef MIGRAPHX_MLIR
const
bool
mlir_enabled
=
enabled
(
MIGRAPHX_ENABLE_MLIR
{});
match
::
find_matches
(
mpm
,
find_mlir_op
{});
if
(
mlir_enabled
)
{
match
::
find_matches
(
mpm
,
find_mlir_op
{});
}
else
{
std
::
cerr
<<
"WARNING: MIGraphX built with MLIR but it is not enabled. Please set the env "
"var MIGRAPHX_ENABLE_MLIR to use MLIR kernel generator."
<<
std
::
endl
;
}
#else
#else
(
void
)
mpm
;
(
void
)
mpm
;
#endif
#endif
...
...
src/targets/gpu/fuse_ops.cpp
View file @
40fbef9b
...
@@ -165,7 +165,8 @@ struct fusion
...
@@ -165,7 +165,8 @@ struct fusion
const
std
::
unordered_set
<
std
::
string
>&
get_supported_archs
()
const
std
::
unordered_set
<
std
::
string
>&
get_supported_archs
()
{
{
static
std
::
unordered_set
<
std
::
string
>
supported_archs
{
"gfx900"
,
"gfx906"
,
"gfx908"
,
"gfx1030"
};
static
std
::
unordered_set
<
std
::
string
>
supported_archs
{
"gfx900"
,
"gfx906"
,
"gfx908"
,
"gfx1030"
,
"gfx940"
};
return
supported_archs
;
return
supported_archs
;
}
}
...
...
src/targets/gpu/gemm_impl.cpp
View file @
40fbef9b
...
@@ -140,12 +140,10 @@ void gemm_impl(context& ctx,
...
@@ -140,12 +140,10 @@ void gemm_impl(context& ctx,
compute_type
=
rocblas_datatype_f32_r
;
compute_type
=
rocblas_datatype_f32_r
;
}
}
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
rocblas_gemm_flags
flag
=
rocblas_gemm_flags_none
;
rocblas_gemm_flags
flag
=
#if ROCBLAS_VERSION_MAJOR < 3
int8_x4_format
?
rocblas_gemm_flags_pack_int8x4
:
rocblas_gemm_flags_none
;
if
(
int8_x4_format
)
#else
flag
=
rocblas_gemm_flags_pack_int8x4
;
(
void
)
int8_x4_format
;
int
flag
=
0
;
#endif
#endif
auto
a_lens
=
args
[
0
].
get_shape
().
lens
();
auto
a_lens
=
args
[
0
].
get_shape
().
lens
();
...
...
Prev
1
…
8
9
10
11
12
13
14
15
16
…
22
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment