Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
40fbef9b
Unverified
Commit
40fbef9b
authored
Aug 05, 2023
by
Ted Themistokleous
Committed by
GitHub
Aug 05, 2023
Browse files
Merge branch 'develop' into threaded_nms
parents
d164b151
aeb9f78c
Changes
440
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
681 additions
and
124 deletions
+681
-124
src/targets/fpga/subgraph.cpp
src/targets/fpga/subgraph.cpp
+1
-2
src/targets/fpga/vitis_ai_adapter.cpp
src/targets/fpga/vitis_ai_adapter.cpp
+1
-1
src/targets/gpu/CMakeLists.txt
src/targets/gpu/CMakeLists.txt
+35
-7
src/targets/gpu/compile_gen.cpp
src/targets/gpu/compile_gen.cpp
+13
-1
src/targets/gpu/compile_hip.cpp
src/targets/gpu/compile_hip.cpp
+10
-3
src/targets/gpu/compile_hip_code_object.cpp
src/targets/gpu/compile_hip_code_object.cpp
+10
-6
src/targets/gpu/compile_miopen.cpp
src/targets/gpu/compile_miopen.cpp
+1
-1
src/targets/gpu/compile_ops.cpp
src/targets/gpu/compile_ops.cpp
+181
-12
src/targets/gpu/compiler.cpp
src/targets/gpu/compiler.cpp
+24
-12
src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp
src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp
+10
-10
src/targets/gpu/device/multinomial.cpp
src/targets/gpu/device/multinomial.cpp
+12
-11
src/targets/gpu/device/scatter.cpp
src/targets/gpu/device/scatter.cpp
+16
-12
src/targets/gpu/device_name.cpp
src/targets/gpu/device_name.cpp
+3
-1
src/targets/gpu/driver/CMakeLists.txt
src/targets/gpu/driver/CMakeLists.txt
+1
-1
src/targets/gpu/driver/compile_op.cpp
src/targets/gpu/driver/compile_op.cpp
+1
-1
src/targets/gpu/driver/run_op.cpp
src/targets/gpu/driver/run_op.cpp
+1
-1
src/targets/gpu/fuse_ck.cpp
src/targets/gpu/fuse_ck.cpp
+175
-0
src/targets/gpu/fuse_mlir.cpp
src/targets/gpu/fuse_mlir.cpp
+180
-35
src/targets/gpu/fuse_ops.cpp
src/targets/gpu/fuse_ops.cpp
+2
-1
src/targets/gpu/gemm_impl.cpp
src/targets/gpu/gemm_impl.cpp
+4
-6
No files found.
src/targets/fpga/subgraph.cpp
View file @
40fbef9b
...
...
@@ -113,8 +113,7 @@ void subgraph::apply(module_pass_manager& mpm) const
// TODO(varunsh): this code may be replaceable by code in the fuse_pointwise pass
// assuming all FPGA instructions are in one contiguous range
pm
->
insert_instructions
(
pm
->
end
(),
first
,
last
,
{});
pm
->
insert_instructions
(
pm
->
end
(),
first
,
std
::
next
(
last
),
{});
migraphx
::
instruction_ref
placeholder_ins
;
for
(
auto
it
:
iterator_for
(
mod
))
{
...
...
src/targets/fpga/vitis_ai_adapter.cpp
View file @
40fbef9b
...
...
@@ -33,7 +33,7 @@ migraphx::shape x_model::get_shape() const { return shape; };
void
x_model
::
set_shape
(
migraphx
::
shape
s
)
{
shape
=
s
;
}
x_model
create_xmodel
(
const
migraphx
::
module_ref
mod
)
x_model
create_xmodel
(
migraphx
::
const_
module_ref
mod
)
{
std
::
cout
<<
"Calling an external function: create_xmodel!
\n
"
;
x_model
xmodel
;
...
...
src/targets/gpu/CMakeLists.txt
View file @
40fbef9b
...
...
@@ -33,6 +33,11 @@ if(NOT TARGET MIOpen)
message
(
SEND_ERROR
"Cant find miopen"
)
endif
()
if
(
NOT WIN32
)
# TODO: re-enable when CK is ported to Windows
find_package
(
composable_kernel 1.0.0 REQUIRED COMPONENTS jit_library
)
endif
()
if
(
BUILD_DEV
)
set
(
MIGRAPHX_USE_HIPRTC OFF CACHE BOOL
"Use hipRTC APIs"
)
else
()
...
...
@@ -40,12 +45,12 @@ else()
endif
()
include
(
Embed
)
file
(
GLOB KERNEL_FILES
${
CONFIGURE_DEPENDS
}
file
(
GLOB KERNEL_FILES CONFIGURE_DEPENDS
${
CMAKE_CURRENT_SOURCE_DIR
}
/kernels/include/migraphx/kernels/*.hpp
)
message
(
STATUS
"KERNEL_FILES:
${
KERNEL_FILES
}
"
)
add_embed_library
(
migraphx_kernels
${
KERNEL_FILES
}
)
add_embed_library
(
migraphx_kernels
${
KERNEL_FILES
}
RELATIVE
${
CMAKE_CURRENT_SOURCE_DIR
}
/kernels/include/
)
file
(
GLOB DEVICE_GPU_SRCS
${
CONFIGURE_DEPENDS
}
${
CMAKE_CURRENT_SOURCE_DIR
}
/device/*.cpp
)
file
(
GLOB DEVICE_GPU_SRCS CONFIGURE_DEPENDS
${
CMAKE_CURRENT_SOURCE_DIR
}
/device/*.cpp
)
add_library
(
migraphx_device
${
DEVICE_GPU_SRCS
}
)
add_library
(
compile_for_gpu INTERFACE
)
...
...
@@ -65,6 +70,8 @@ target_link_libraries(migraphx_device PUBLIC migraphx)
target_link_libraries
(
migraphx_device PRIVATE compile_for_gpu
)
target_include_directories
(
migraphx_device PUBLIC $<BUILD_INTERFACE:
${
CMAKE_CURRENT_SOURCE_DIR
}
/include>
)
target_include_directories
(
migraphx_device PRIVATE $<BUILD_INTERFACE:
${
CMAKE_CURRENT_SOURCE_DIR
}
/device/include>
)
target_compile_options
(
migraphx_device PRIVATE -Wno-ignored-attributes
)
migraphx_generate_export_header
(
migraphx_device DIRECTORY migraphx/gpu/device
)
add_library
(
kernel_file_check EXCLUDE_FROM_ALL
)
...
...
@@ -80,7 +87,13 @@ target_link_libraries(kernel_file_check compile_for_gpu)
rocm_clang_tidy_check
(
kernel_file_check
)
file
(
GLOB JIT_GPU_SRCS
${
CONFIGURE_DEPENDS
}
${
CMAKE_CURRENT_SOURCE_DIR
}
/jit/*.cpp
)
file
(
GLOB JIT_GPU_SRCS CONFIGURE_DEPENDS
${
CMAKE_CURRENT_SOURCE_DIR
}
/jit/*.cpp
)
if
(
WIN32
)
# TODO: re-enable when CK is ported to Windows
list
(
REMOVE_ITEM JIT_GPU_SRCS
${
CMAKE_CURRENT_SOURCE_DIR
}
/jit/ck_gemm.cpp
)
endif
()
add_library
(
migraphx_gpu
abs.cpp
analyze_streams.cpp
...
...
@@ -95,6 +108,7 @@ add_library(migraphx_gpu
compile_miopen.cpp
compiler.cpp
device_name.cpp
fuse_ck.cpp
fuse_mlir.cpp
fuse_ops.cpp
gather.cpp
...
...
@@ -123,11 +137,14 @@ add_library(migraphx_gpu
schedule_model.cpp
sync_device.cpp
target.cpp
time_op.cpp
topk.cpp
write_literals.cpp
${
JIT_GPU_SRCS
}
)
set_target_properties
(
migraphx_gpu PROPERTIES EXPORT_NAME gpu
)
migraphx_generate_export_header
(
migraphx_gpu
)
function
(
register_migraphx_gpu_ops PREFIX
)
foreach
(
OP
${
ARGN
}
)
...
...
@@ -169,7 +186,7 @@ register_op(migraphx_gpu
OPERATORS gpu::rocblas_gemm<op::dot> gpu::rocblas_gemm<op::quant_dot>
INCLUDES migraphx/gpu/context.hpp
)
register_op
(
migraphx_gpu HEADER migraphx/gpu/convolution.hpp
OPERATORS gpu::miopen_convolution<op::convolution> gpu::miopen_convolution<op::
de
convolution> gpu::miopen_convolution<op::quant_convolution>
OPERATORS gpu::miopen_convolution<op::convolution> gpu::miopen_convolution<op::convolution
_backwards
> gpu::miopen_convolution<op::quant_convolution>
INCLUDES migraphx/gpu/context.hpp
)
rocm_set_soversion
(
migraphx_gpu
${
MIGRAPHX_SO_VERSION
}
)
rocm_clang_tidy_check
(
migraphx_gpu
)
...
...
@@ -181,7 +198,9 @@ if(MIGRAPHX_ENABLE_MLIR)
find_package
(
rocMLIR 1.0.0 CONFIG REQUIRED
)
message
(
STATUS
"Build with rocMLIR::rockCompiler
${
rocMLIR_VERSION
}
"
)
target_compile_definitions
(
migraphx_gpu PRIVATE
"-DMIGRAPHX_MLIR"
)
target_link_libraries
(
migraphx_gpu PUBLIC rocMLIR::rockCompiler
)
# Make this private to avoid multiple inclusions of LLVM symbols.
# TODO: Fix rocMLIR's library to hide LLVM internals.
target_link_libraries
(
migraphx_gpu PRIVATE rocMLIR::rockCompiler
)
endif
()
if
(
MIGRAPHX_USE_HIPRTC
)
...
...
@@ -227,7 +246,12 @@ check_library_exists(MIOpen "miopenFindSolutions" "${MIOPEN_LOCATION}" HAS_FIND_
set
(
MIGRAPHX_USE_FIND_2_API
"
${
HAS_FIND_2_API
}
"
CACHE BOOL
""
)
if
(
MIGRAPHX_USE_FIND_2_API
)
target_compile_definitions
(
migraphx_gpu PUBLIC -DMIGRAPHX_HAS_FIND_2_API
)
check_library_exists
(
MIOpen
"miopenSetFindOptionPreallocatedTensor"
"
${
MIOPEN_LOCATION
}
"
HAS_PREALLOCATION_API
)
if
(
HAS_PREALLOCATION_API
)
target_compile_definitions
(
migraphx_gpu PUBLIC -DMIGRAPHX_HAS_FIND_2_API -DMIGRAPHX_PREALLOCATE_MIOPEN_BUFFERS
)
else
()
target_compile_definitions
(
migraphx_gpu PUBLIC -DMIGRAPHX_HAS_FIND_2_API
)
endif
()
message
(
STATUS
"MIGraphx is using Find-2.0 API of MIOpen"
)
else
()
message
(
STATUS
"MIGraphx is using legacy Find API in MIOpen"
)
...
...
@@ -242,6 +266,10 @@ endif()
target_link_libraries
(
migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas
)
target_link_libraries
(
migraphx_gpu PRIVATE migraphx_device migraphx_kernels
)
if
(
NOT WIN32
)
# TODO: re-enable when CK is ported to Windows
target_link_libraries
(
migraphx_gpu PRIVATE composable_kernel::jit_library
)
endif
()
add_subdirectory
(
driver
)
add_subdirectory
(
hiprtc
)
...
...
src/targets/gpu/compile_gen.cpp
View file @
40fbef9b
...
...
@@ -29,6 +29,7 @@
#include <migraphx/module.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/rewrite_quantization.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/instruction.hpp>
...
...
@@ -171,7 +172,8 @@ std::string make_transformer_args(std::vector<std::string> transformers)
void
generate_pointwise
(
cpp_generator
&
gg
,
const
module
&
pm
,
const
std
::
string
&
name
)
{
module
m
=
pm
;
run_passes
(
m
,
{
eliminate_common_subexpression
{},
dead_code_elimination
{}});
run_passes
(
m
,
{
rewrite_quantization
{},
eliminate_common_subexpression
{},
dead_code_elimination
{}});
cpp_generator
g
;
g
.
fmap
([](
const
std
::
string
&
fname
)
{
return
"migraphx::"
+
fname
;
});
g
.
add_point_op
(
"where"
,
"${function:where}(${0}, ${1}, ${2})"
);
...
...
@@ -280,6 +282,14 @@ std::string generate_reduce(const module& m, const std::string& name)
not
input
->
get_shape
().
broadcasted
();
});
auto
inner_names
=
names
;
for
(
auto
input
:
ins
->
inputs
())
{
if
(
input
->
name
()
!=
"@param"
)
continue
;
if
(
contains
(
tensors
,
input
))
continue
;
inner_names
[
input
]
+=
"[out_idx]"
;
}
for
(
auto
input
:
tensors
)
inner_names
[
input
]
+=
"_lambda_param"
;
auto
call_function
=
...
...
@@ -308,6 +318,8 @@ std::string generate_reduce(const module& m, const std::string& name)
});
f
.
set_attributes
({
"__device__"
,
"__attribute__((const))"
}).
set_generic_types
(
m
).
set_name
(
name
);
f
.
add_generic_param
(
"r"
);
f
.
add_generic_param
(
"out_idx"
);
f
.
unused_param
(
"out_idx"
);
g
.
create_function
(
f
);
return
g
.
str
();
}
...
...
src/targets/gpu/compile_hip.cpp
View file @
40fbef9b
...
...
@@ -56,9 +56,6 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_SRC);
#ifdef MIGRAPHX_USE_HIPRTC
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_HIPRTC
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS
);
std
::
string
hiprtc_error
(
hiprtcResult
err
,
const
std
::
string
&
msg
)
{
return
"hiprtc: "
+
(
hiprtcGetErrorString
(
err
)
+
(
": "
+
msg
));
...
...
@@ -194,6 +191,7 @@ std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_sr
options
.
push_back
(
"-DMIGRAPHX_HAS_DPP=0"
);
options
.
push_back
(
"-DMIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS=1"
);
options
.
push_back
(
"-Wno-reserved-identifier"
);
options
.
push_back
(
"-Wno-unused-parameter"
);
options
.
push_back
(
"-Wno-gnu-line-marker"
);
options
.
push_back
(
"-Wno-old-style-cast"
);
}
...
...
@@ -216,6 +214,15 @@ std::vector<std::vector<char>>
compile_hip_src
(
const
std
::
vector
<
src_file
>&
srcs
,
std
::
string
params
,
const
std
::
string
&
arch
)
{
std
::
vector
<
hiprtc_src_file
>
hsrcs
{
srcs
.
begin
(),
srcs
.
end
()};
if
(
enabled
(
MIGRAPHX_GPU_DUMP_SRC
{}))
{
for
(
const
auto
&
src
:
srcs
)
{
if
(
src
.
path
.
extension
()
!=
".cpp"
)
continue
;
std
::
cout
<<
std
::
string
(
src
.
content
.
first
,
src
.
len
())
<<
std
::
endl
;
}
}
auto
p
=
dynamic_loader
::
path
(
&
compile_hip_src_with_hiprtc
);
auto
driver
=
p
.
parent_path
().
parent_path
()
/
"bin"
/
"migraphx-hiprtc-driver"
;
...
...
src/targets/gpu/compile_hip_code_object.cpp
View file @
40fbef9b
...
...
@@ -135,10 +135,14 @@ compute_global_for(context& ctx, std::size_t n, std::size_t over)
std
::
size_t
max_global
=
ctx
.
get_current_device
().
get_cu_count
()
*
ctx
.
get_current_device
().
get_max_workitems_per_cu
();
return
[
n
,
over
,
max_global
](
std
::
size_t
local
)
{
std
::
size_t
groups
=
(
n
+
local
-
1
)
/
local
;
std
::
size_t
max_blocks
=
max_global
/
local
;
std
::
size_t
nglobal
=
std
::
min
(
max_blocks
*
over
,
groups
)
*
local
;
return
std
::
min
(
nglobal
,
n
);
// hip require global workitems multiple of local workitems. It may degrade performance.
// [TODO]: consider adding "fno-hip-uniform-block" flag when it becomes available.
// https://reviews.llvm.org/D155213
std
::
size_t
num_elements
=
((
n
+
local
-
1
)
/
local
)
*
local
;
std
::
size_t
groups
=
(
num_elements
+
local
-
1
)
/
local
;
std
::
size_t
max_blocks
=
max_global
/
local
;
std
::
size_t
nglobal
=
std
::
min
(
max_blocks
*
over
,
groups
)
*
local
;
return
std
::
min
(
nglobal
,
num_elements
);
};
}
...
...
@@ -156,14 +160,14 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
assert
(
not
options
.
inputs
.
empty
());
assert
(
options
.
inputs
.
size
()
==
options
.
virtual_inputs
.
size
()
or
options
.
virtual_inputs
.
empty
());
std
::
vector
<
src_file
>
srcs
;
std
::
vector
<
src_file
>
srcs
=
options
.
additional_src_files
;
std
::
transform
(
migraphx_kernels
().
begin
(),
migraphx_kernels
().
end
(),
std
::
back_inserter
(
srcs
),
[](
auto
&&
p
)
{
auto
&&
name
=
p
.
first
;
auto
&&
c
=
p
.
second
;
auto
path
=
fs
::
path
{
"migraphx"
}
/
"kernels"
/
name
;
auto
path
=
name
;
return
src_file
{
path
,
c
};
});
srcs
.
push_back
(
src_file
{
fs
::
path
{
"main.cpp"
},
...
...
src/targets/gpu/compile_miopen.cpp
View file @
40fbef9b
...
...
@@ -79,7 +79,7 @@ void compile_miopen::apply(module& m) const
std
::
size_t
ws
=
0
;
try
{
// for the regular convolution and
de
convolution, this try would always succeed
// for the regular convolution and convolution
_backwards
, this try would always succeed
ws
=
compile
(
op
,
ins
,
int8_x4_format
);
}
catch
(
migraphx
::
exception
&
)
...
...
src/targets/gpu/compile_ops.cpp
View file @
40fbef9b
...
...
@@ -30,6 +30,7 @@
#include <migraphx/register_op.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/time_op.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -76,33 +77,201 @@ struct compiled_result
instruction_ref
ins
;
};
struct
problem_cache
{
bool
has
(
const
std
::
string
&
name
,
const
value
&
problem
)
const
{
return
contains
(
cache
,
create_key
(
name
,
problem
));
}
void
insert
(
const
std
::
string
&
name
,
const
value
&
problem
,
const
value
&
solution
)
{
assert
(
not
solution
.
is_null
());
cache
[
create_key
(
name
,
problem
)]
=
solution
;
}
void
mark
(
const
std
::
string
&
name
,
const
value
&
problem
)
{
cache
.
insert
(
std
::
make_pair
(
create_key
(
name
,
problem
),
value
{}));
}
optional
<
value
>
get
(
const
std
::
string
&
name
,
const
value
&
problem
)
const
{
auto
it
=
cache
.
find
(
create_key
(
name
,
problem
));
if
(
it
==
cache
.
end
())
return
nullopt
;
return
it
->
second
;
}
static
value
create_key
(
const
std
::
string
&
name
,
const
value
&
problem
)
{
return
{{
"name"
,
name
},
{
"problem"
,
problem
}};
}
std
::
unordered_map
<
value
,
value
>
cache
;
};
struct
compile_plan
{
context
*
ctx
;
operation
preop
;
instruction_ref
ins
;
optional
<
tuning_config
>
config
=
nullopt
;
std
::
vector
<
optional
<
compiled_result
>>
results
=
{};
void
update_config
(
bool
exhaustive
)
{
config
=
get_tuning_config
(
*
ctx
,
ins
,
preop
,
exhaustive
);
}
template
<
class
Vector
>
void
insert_compiles
(
Vector
&
compiles
,
const
value
&
solution
,
std
::
size_t
i
)
{
compiles
.
emplace_back
([
=
]
{
try
{
results
[
i
]
=
compiled_result
{
compile
(
*
ctx
,
ins
,
preop
,
solution
),
ins
};
}
catch
(...)
{
results
[
i
]
=
nullopt
;
}
});
}
template
<
class
Vector
>
void
add_compiles
(
Vector
&
compiles
,
problem_cache
&
pc
)
{
if
(
config
.
has_value
())
{
const
auto
&
problem
=
config
->
problem
;
if
(
auto
sol
=
pc
.
get
(
preop
.
name
(),
problem
))
{
auto
solution
=
sol
.
value
();
// No solution yet until benchmarked so skip for now
if
(
solution
.
is_null
())
return
;
results
.
resize
(
1
);
insert_compiles
(
compiles
,
solution
,
0
);
}
else
{
pc
.
mark
(
preop
.
name
(),
problem
);
const
auto
&
solutions
=
config
->
solutions
;
results
.
resize
(
solutions
.
size
());
for
(
auto
i
:
range
(
solutions
.
size
()))
{
auto
solution
=
solutions
[
i
];
insert_compiles
(
compiles
,
solution
,
i
);
}
}
}
else
{
results
.
resize
(
1
);
insert_compiles
(
compiles
,
value
{},
0
);
}
}
const
compiled_result
&
benchmark
(
problem_cache
&
pc
)
const
{
if
(
results
.
empty
())
MIGRAPHX_THROW
(
"No configs to tune"
);
if
(
results
.
size
()
==
1
)
{
if
(
not
results
.
front
().
has_value
())
MIGRAPHX_THROW
(
"No configs to tune"
);
return
*
results
.
front
();
}
if
(
not
config
)
MIGRAPHX_THROW
(
"Multiple kernels without config"
);
std
::
cout
<<
"Benchmarking "
<<
preop
.
name
()
<<
": "
<<
results
.
size
()
<<
" configs"
<<
std
::
endl
;
std
::
vector
<
double
>
times
;
times
.
reserve
(
results
.
size
());
std
::
transform
(
results
.
begin
(),
results
.
end
(),
std
::
back_inserter
(
times
),
[
&
](
const
auto
&
cr
)
{
if
(
not
cr
.
has_value
())
return
std
::
numeric_limits
<
double
>::
max
();
return
time_op
(
*
ctx
,
cr
->
replace
.
code_object
,
to_shapes
(
cr
->
ins
->
inputs
()),
20
)
.
first
;
});
auto
i
=
std
::
distance
(
times
.
begin
(),
std
::
min_element
(
times
.
begin
(),
times
.
end
()));
std
::
cout
<<
"Fastest solution: "
<<
config
->
solutions
.
at
(
i
)
<<
std
::
endl
;
pc
.
insert
(
preop
.
name
(),
config
->
problem
,
config
->
solutions
.
at
(
i
));
if
(
not
results
[
i
].
has_value
())
MIGRAPHX_THROW
(
"No valid tuned compilation."
);
return
*
results
[
i
];
}
void
replace
(
module
&
m
,
problem_cache
&
pc
)
const
{
const
auto
&
cr
=
benchmark
(
pc
);
cr
.
replace
.
replace
(
m
,
cr
.
ins
);
}
};
template
<
class
F
>
void
par_compile
(
std
::
size_t
n
,
F
f
)
{
if
(
n
==
0
)
return
;
par_for
(
n
,
n
/
value_of
(
MIGRAPHX_GPU_COMPILE_PARALLEL
{},
n
),
f
);
auto
d
=
value_of
(
MIGRAPHX_GPU_COMPILE_PARALLEL
{});
if
(
d
==
0
)
d
=
n
;
par_for
(
n
,
n
/
d
,
f
);
}
void
compile_
ops
::
apply
(
module
&
m
)
const
struct
compile_
manager
{
std
::
vector
<
std
::
function
<
compiled_result
()
>>
compiles
;
problem_cache
pc
;
std
::
vector
<
compile_plan
>
cps
;
bool
exhaustive
=
false
;
template
<
class
...
Ts
>
void
add_plan
(
Ts
&&
...
xs
)
{
cps
.
push_back
({
std
::
forward
<
Ts
>
(
xs
)...});
}
void
update_configs
()
{
par_compile
(
cps
.
size
(),
[
&
](
auto
i
)
{
cps
[
i
].
update_config
(
exhaustive
);
});
}
void
compile
(
module
&
m
)
{
std
::
vector
<
std
::
function
<
void
()
>>
compiles
;
for
(
auto
&
cp
:
cps
)
{
cp
.
add_compiles
(
compiles
,
pc
);
}
par_compile
(
compiles
.
size
(),
[
&
](
auto
i
)
{
compiles
[
i
]();
});
// Replace and/or benchmark
for
(
const
auto
&
cp
:
cps
)
{
if
(
cp
.
results
.
empty
())
continue
;
cp
.
replace
(
m
,
pc
);
}
// Remove compile_plan already executed
cps
.
erase
(
std
::
remove_if
(
cps
.
begin
(),
cps
.
end
(),
[](
const
auto
&
cp
)
{
return
not
cp
.
results
.
empty
();
}),
cps
.
end
());
}
};
void
compile_ops
::
apply
(
module
&
m
)
const
{
compile_manager
cm
;
cm
.
exhaustive
=
exhaustive_tune
;
// Find all precompile opes
for
(
auto
ins
:
iterator_for
(
m
))
{
if
(
ins
->
name
()
!=
"gpu::precompile_op"
)
continue
;
operation
preop
=
any_cast
<
precompile_op
>
(
ins
->
get_operator
()).
op
;
compiles
.
emplace_back
([
=
]()
->
compiled_result
{
return
{
compile
(
*
ctx
,
ins
,
preop
),
ins
};
});
}
std
::
vector
<
compiled_result
>
results
(
compiles
.
size
());
par_compile
(
compiles
.
size
(),
[
&
](
auto
i
)
{
results
[
i
]
=
compiles
[
i
]();
});
for
(
const
auto
&
cr
:
results
)
{
cr
.
replace
(
m
,
cr
.
ins
);
cm
.
add_plan
(
ctx
,
preop
,
ins
);
}
cm
.
update_configs
();
cm
.
compile
(
m
);
// Compile already tuned configs
cm
.
compile
(
m
);
assert
(
cm
.
cps
.
empty
());
}
}
// namespace gpu
...
...
src/targets/gpu/compiler.cpp
View file @
40fbef9b
...
...
@@ -28,33 +28,45 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
auto
&
compiler_map
()
namespace
{
struct
compiler_handle
{
static
std
::
unordered_map
<
std
::
string
,
compiler_compile
>
m
;
// NOLINT
return
m
;
}
compiler_compile
compile
;
compiler_compile_op
compile_op
;
compiler_tuning_config
get_tuning_config
;
};
}
// namespace
auto
&
compiler_
op_
map
()
auto
&
compiler_map
()
{
static
std
::
unordered_map
<
std
::
string
,
compiler_
compile_op
>
m
;
// NOLINT
static
std
::
unordered_map
<
std
::
string
,
compiler_
handle
>
m
;
// NOLINT
return
m
;
}
void
register_compiler
(
const
std
::
string
&
name
,
compiler_compile
c
,
compiler_compile_op
cop
)
void
register_compiler
(
const
std
::
string
&
name
,
compiler_compile
c
,
compiler_compile_op
cop
,
compiler_tuning_config
ctg
)
{
compiler_map
()[
name
]
=
std
::
move
(
c
);
compiler_op_map
()[
name
]
=
std
::
move
(
cop
);
compiler_map
()[
name
]
=
{
std
::
move
(
c
),
std
::
move
(
cop
),
std
::
move
(
ctg
)};
}
bool
has_compiler_for
(
const
std
::
string
&
name
)
{
return
compiler_map
().
count
(
name
)
>
0
;
}
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
,
const
value
&
solution
)
{
return
compiler_map
().
at
(
op
.
name
())(
ctx
,
ins
,
op
);
return
compiler_map
().
at
(
op
.
name
())
.
compile
(
ctx
,
ins
,
op
,
solution
);
}
operation
compile_op
(
const
std
::
string
&
name
,
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
{
return
compiler_op_map
().
at
(
name
)(
ctx
,
inputs
,
v
);
return
compiler_map
().
at
(
name
).
compile_op
(
ctx
,
inputs
,
v
);
}
optional
<
tuning_config
>
get_tuning_config
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
,
bool
exhaustive
)
{
return
compiler_map
().
at
(
op
.
name
()).
get_tuning_config
(
ctx
,
ins
,
op
,
exhaustive
);
}
}
// namespace gpu
...
...
src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp
View file @
40fbef9b
...
...
@@ -94,6 +94,10 @@ template <>
struct
is_hip_type
<
std
::
uint8_t
>
:
std
::
true_type
{
};
template
<
>
struct
is_hip_type
<
std
::
int32_t
>
:
std
::
true_type
{
};
template
<
class
T
,
class
V
,
MIGRAPHX_REQUIRES
(
is_hip_type
<
typename
T
::
type
>{})
>
void
hip_visitor_invoke
(
T
as
,
V
&&
v
)
...
...
@@ -120,12 +124,10 @@ void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
if
(
not
std
::
all_of
(
types
.
begin
(),
types
.
end
(),
[
&
](
migraphx
::
shape
::
type_t
t
)
{
return
t
==
s
.
type
();
}))
MIGRAPHX_THROW
(
"Types must be the same"
);
std
::
initializer_list
<
index_int
>
ranks
=
{
static_cast
<
index_int
>
(
get_shape
(
xs
).
lens
().
size
())...};
if
(
not
std
::
all_of
(
ranks
.
begin
(),
ranks
.
end
(),
[
&
](
index_int
r
)
{
return
r
==
s
.
lens
().
size
();
}))
std
::
initializer_list
<
index_int
>
ranks
=
{
static_cast
<
index_int
>
(
get_shape
(
xs
).
ndim
())...};
if
(
not
std
::
all_of
(
ranks
.
begin
(),
ranks
.
end
(),
[
&
](
index_int
r
)
{
return
r
==
s
.
ndim
();
}))
MIGRAPHX_THROW
(
"Ranks must be the same"
);
visit_tensor_size
(
s
.
lens
().
size
(),
[
&
](
auto
ndim
)
{
visit_tensor_size
(
s
.
ndim
(),
[
&
](
auto
ndim
)
{
s
.
visit_type
(
hip_visitor
([
&
](
auto
as
)
{
v
(
f
(
xs
,
ndim
,
as
)...);
}));
});
}
...
...
@@ -133,12 +135,10 @@ void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
template
<
class
V
,
class
F
,
class
...
Ts
>
void
hip_visit_views_impl
(
const
shape
&
s
,
F
f
,
V
&&
v
,
Ts
&&
...
xs
)
{
std
::
initializer_list
<
index_int
>
ranks
=
{
static_cast
<
index_int
>
(
get_shape
(
xs
).
lens
().
size
())...};
if
(
not
std
::
all_of
(
ranks
.
begin
(),
ranks
.
end
(),
[
&
](
index_int
r
)
{
return
r
==
s
.
lens
().
size
();
}))
std
::
initializer_list
<
index_int
>
ranks
=
{
static_cast
<
index_int
>
(
get_shape
(
xs
).
ndim
())...};
if
(
not
std
::
all_of
(
ranks
.
begin
(),
ranks
.
end
(),
[
&
](
index_int
r
)
{
return
r
==
s
.
ndim
();
}))
MIGRAPHX_THROW
(
"Ranks must be the same"
);
visit_tensor_size
(
s
.
lens
().
size
(),
[
&
](
auto
ndim
)
{
v
(
f
(
xs
,
ndim
)...);
});
visit_tensor_size
(
s
.
ndim
(),
[
&
](
auto
ndim
)
{
v
(
f
(
xs
,
ndim
)...);
});
}
template
<
class
F
>
...
...
src/targets/gpu/device/multinomial.cpp
View file @
40fbef9b
...
...
@@ -67,18 +67,19 @@ void multinomial(hipStream_t stream,
size_t
class_size
=
arg0
.
get_shape
().
lens
().
back
();
size_t
sample_size
=
result
.
get_shape
().
lens
().
back
();
hip_visit_all
(
arg0
,
arg1
)([
&
](
auto
cdf
,
auto
dist
)
{
result
.
visit
([
&
](
auto
out
)
{
hip_visit_views
(
out
)([
&
](
auto
output
)
{
gs_launch
(
stream
,
batch_size
*
sample_size
)([
=
](
auto
i
)
__device__
{
auto
idx
=
output
.
get_shape
().
multi
(
i
);
auto
cdf_begin
=
cdf
.
begin
()
+
(
idx
.
front
()
*
class_size
);
auto
cdf_end
=
cdf_begin
+
class_size
;
auto
sample_iter
=
upper_bound
(
cdf_begin
,
cdf_end
,
dist
[
i
]
*
*
(
std
::
prev
(
cdf_end
)));
output
[
i
]
=
std
::
distance
(
cdf_begin
,
sample_iter
);
visit_all
(
arg0
,
arg1
)([
&
](
auto
cdf_host
,
auto
dist_host
)
{
result
.
visit
([
&
](
auto
output_host
)
{
hip_visit_views
(
cdf_host
,
dist_host
,
output_host
)(
[
&
](
auto
cdf
,
auto
dist
,
auto
output
)
{
gs_launch
(
stream
,
batch_size
*
sample_size
)([
=
](
auto
i
)
__device__
{
auto
idx
=
output
.
get_shape
().
multi
(
i
);
auto
cdf_begin
=
cdf
.
begin
()
+
(
idx
.
front
()
*
class_size
);
auto
cdf_end
=
cdf_begin
+
class_size
;
auto
*
sample_iter
=
upper_bound
(
cdf_begin
,
cdf_end
,
dist
[
i
]
*
*
(
std
::
prev
(
cdf_end
)));
output
[
i
]
=
std
::
distance
(
cdf_begin
,
sample_iter
);
});
});
});
});
});
}
...
...
src/targets/gpu/device/scatter.cpp
View file @
40fbef9b
...
...
@@ -37,22 +37,26 @@ argument scatter(
hipStream_t
stream
,
argument
result
,
argument
arg0
,
argument
arg1
,
argument
arg2
,
int64_t
axis
)
{
auto
ds
=
arg0
.
get_shape
();
auto
inds
=
arg1
.
get_shape
();
auto
s1
=
arg1
.
get_shape
();
auto
axis_dim_size
=
ds
.
lens
()[
axis
];
hip_visit_all
(
result
,
arg0
,
inds
)([
&
](
auto
output
,
auto
data
,
auto
s1
)
{
hip_visit_all
(
result
,
arg0
,
arg2
)([
&
](
auto
output
,
auto
data
,
auto
update
)
{
auto
*
output_ptr
=
device_cast
(
output
.
data
());
const
auto
*
data_ptr
=
device_cast
(
data
.
data
());
gs_launch
(
stream
,
ds
.
elements
())([
=
](
auto
i
)
__device__
{
output_ptr
[
i
]
=
data_ptr
[
i
];
});
hip_visit_all
(
arg1
,
arg2
)([
&
](
auto
indices
,
auto
update
)
{
const
auto
*
upd_ptr
=
device_cast
(
update
.
data
());
const
auto
*
indices_ptr
=
device_cast
(
indices
.
data
());
gs_launch
(
stream
,
inds
.
elements
())([
=
](
auto
i
)
__device__
{
auto
out_idx
=
s1
.
multi
(
i
);
auto
index
=
indices_ptr
[
i
];
index
=
index
<
0
?
index
+
axis_dim_size
:
index
;
out_idx
[
axis
]
=
index
;
output
[
out_idx
]
=
upd_ptr
[
i
];
});
hip_visit_all
(
arg1
)([
&
](
auto
indices
)
{
if
constexpr
(
indices
.
get_shape
().
lens
.
size
()
==
output
.
get_shape
().
lens
.
size
())
{
const
auto
*
upd_ptr
=
device_cast
(
update
.
data
());
const
auto
*
indices_ptr
=
device_cast
(
indices
.
data
());
gs_launch
(
stream
,
s1
.
elements
())([
=
](
auto
i
)
__device__
{
auto
out_idx
=
indices
.
get_shape
().
multi
(
i
);
auto
index
=
indices_ptr
[
i
];
index
=
index
<
0
?
index
+
axis_dim_size
:
index
;
out_idx
[
axis
]
=
index
;
output
[
out_idx
]
=
upd_ptr
[
i
];
});
}
});
});
...
...
src/targets/gpu/device_name.cpp
View file @
40fbef9b
...
...
@@ -43,6 +43,8 @@ auto get_arch_name(rank<1>, const HipDeviceProp& props) -> decltype(std::string(
return
std
::
string
(
props
.
gcnArchName
);
}
std
::
string
get_arch_name
(
const
hipDeviceProp_t
&
props
)
{
return
get_arch_name
(
rank
<
1
>
{},
props
);
}
int
get_device_id
()
{
int
device
;
...
...
@@ -58,7 +60,7 @@ std::string get_device_name()
auto
status
=
hipGetDeviceProperties
(
&
props
,
get_device_id
());
if
(
status
!=
hipSuccess
)
MIGRAPHX_THROW
(
"Failed to get device properties"
);
return
get_arch_name
(
rank
<
1
>
{},
props
);
return
get_arch_name
(
props
);
}
}
// namespace gpu
...
...
src/targets/gpu/driver/CMakeLists.txt
View file @
40fbef9b
...
...
@@ -22,7 +22,7 @@
# THE SOFTWARE.
#####################################################################################
file
(
GLOB GPU_DRIVER_SRCS
${
CONFIGURE_DEPENDS
}
${
CMAKE_CURRENT_SOURCE_DIR
}
/*.cpp
)
file
(
GLOB GPU_DRIVER_SRCS CONFIGURE_DEPENDS
${
CMAKE_CURRENT_SOURCE_DIR
}
/*.cpp
)
add_executable
(
gpu-driver
${
GPU_DRIVER_SRCS
}
)
...
...
src/targets/gpu/driver/compile_op.cpp
View file @
40fbef9b
...
...
@@ -22,7 +22,7 @@
* THE SOFTWARE.
*/
#include <migraphx/gpu/driver/action.hpp>
#include <migraphx/gpu/
driver/perf
.hpp>
#include <migraphx/gpu/
time_op
.hpp>
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>
...
...
src/targets/gpu/driver/run_op.cpp
View file @
40fbef9b
...
...
@@ -22,7 +22,7 @@
* THE SOFTWARE.
*/
#include <migraphx/gpu/driver/action.hpp>
#include <migraphx/gpu/
driver/perf
.hpp>
#include <migraphx/gpu/
time_op
.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/make_op.hpp>
...
...
src/targets/gpu/fuse_ck.cpp
0 → 100644
View file @
40fbef9b
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/fuse_ck.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
module
;
namespace
gpu
{
struct
ck_gemm
{
operation
op
=
make_op
(
"dot"
);
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
op
,
"op"
));
}
std
::
string
name
()
const
{
return
"gpu::ck_gemm"
;
}
void
check_gemm_shape
(
const
shape
&
s
)
const
{
if
(
not
contains
(
range
(
s
.
strides
().
rbegin
(),
s
.
strides
().
rbegin
()
+
3
),
1
))
MIGRAPHX_THROW
(
"Invalid shape for ck_gemm"
);
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
,
const
std
::
vector
<
module_ref
>&
mods
)
const
{
check_shapes
{
inputs
,
*
this
}.
same_ndims
();
if
(
inputs
.
size
()
<
2
)
MIGRAPHX_THROW
(
"should have at least two inputs."
);
auto
a
=
inputs
[
0
];
auto
b
=
inputs
[
1
];
for
(
const
auto
&
input
:
inputs
)
check_gemm_shape
(
input
);
auto
r
=
op
.
compute_shape
({
a
,
b
});
if
(
mods
.
empty
())
return
r
;
return
r
.
with_type
(
mods
.
front
()
->
get_output_shapes
().
front
().
type
());
}
};
MIGRAPHX_REGISTER_OP
(
ck_gemm
);
namespace
{
bool
is_ck_supported_type
(
shape
::
type_t
t
)
{
return
contains
({
shape
::
half_type
,
shape
::
int8_type
,
shape
::
int32_type
},
t
);
}
MIGRAPHX_PRED_MATCHER
(
is_ck_gemm
,
instruction_ref
ins
)
{
if
(
ins
->
name
()
!=
"dot"
and
ins
->
name
()
!=
"quant_dot"
)
return
false
;
if
(
not
is_ck_supported_type
(
ins
->
get_shape
().
type
()))
return
false
;
auto
a
=
ins
->
inputs
().
front
()
->
get_shape
();
auto
b
=
ins
->
inputs
().
back
()
->
get_shape
();
auto
m
=
a
.
lens
()[
a
.
lens
().
size
()
-
2
];
auto
n
=
b
.
lens
().
back
();
auto
k
=
a
.
lens
().
back
();
// Integer gemms must be divisible by 4 in ck
if
(
contains
({
shape
::
int8_type
,
shape
::
int32_type
},
ins
->
get_shape
().
type
()))
{
if
(
m
%
4
!=
0
)
return
false
;
if
(
n
%
4
!=
0
)
return
false
;
if
(
k
%
4
!=
0
)
return
false
;
}
// Skipping GEMMs with a K dimension greater than 2048 is a course-grained strategy
// to avoid poor-performing GEMM kernels from CK
// To-do: Investigate a more precise strategy
return
k
<=
2048
;
}
struct
find_ck_gemm_pointwise
{
// Find a gemm followed by a pointwise operation.
auto
matcher
()
const
{
auto
gemm
=
match
::
skip
(
match
::
name
(
"contiguous"
))(
match
::
name
(
"dot"
,
"quant_dot"
)(
is_ck_gemm
().
bind
(
"gemm"
)));
return
match
::
name
(
"pointwise"
)(
match
::
any_of
[
match
::
inputs
()](
gemm
.
bind
(
"x"
)));
}
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
gemm_ins
=
r
.
instructions
[
"gemm"
];
auto
x_ins
=
r
.
instructions
[
"x"
];
// input after contiguous
auto
*
pm
=
ins
->
module_inputs
().
front
();
auto
names
=
pm
->
get_parameter_names
();
std
::
sort
(
names
.
begin
(),
names
.
end
());
auto
inputs
=
ins
->
inputs
();
auto
gemm_it
=
std
::
find
(
inputs
.
begin
(),
inputs
.
end
(),
x_ins
);
auto
gemm_idx
=
gemm_it
-
inputs
.
begin
();
if
(
gemm_ins
->
get_shape
().
type
()
!=
shape
::
int32_type
and
ins
->
get_shape
().
type
()
!=
gemm_ins
->
get_shape
().
type
())
return
;
if
(
std
::
any_of
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
[](
auto
input
)
{
return
not
is_ck_supported_type
(
input
->
get_shape
().
type
());
}))
return
;
assert
(
gemm_it
!=
inputs
.
end
());
if
(
gemm_idx
!=
0
)
{
auto
first_param
=
pm
->
get_parameter
(
names
[
0
]);
auto
gemm_param
=
pm
->
get_parameter
(
names
[
gemm_idx
]);
auto
new_gemm_param
=
pm
->
add_parameter
(
names
[
0
]
+
"_0"
,
gemm_param
->
get_shape
());
auto
new_first_param
=
pm
->
add_parameter
(
names
[
gemm_idx
]
+
"_0"
,
first_param
->
get_shape
());
pm
->
replace_instruction
(
gemm_param
,
new_gemm_param
);
pm
->
replace_instruction
(
first_param
,
new_first_param
);
pm
->
remove_instruction
(
first_param
);
pm
->
remove_instruction
(
gemm_param
);
}
inputs
.
erase
(
gemm_it
);
inputs
.
insert
(
inputs
.
begin
(),
gemm_ins
->
inputs
().
begin
(),
gemm_ins
->
inputs
().
end
());
mpm
.
get_module
().
replace_instruction
(
ins
,
ck_gemm
{
gemm_ins
->
get_operator
()},
inputs
,
{
pm
});
}
};
struct
find_ck_gemm
{
auto
matcher
()
const
{
return
match
::
name
(
"dot"
)(
is_ck_gemm
().
bind
(
"gemm"
));
}
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
mpm
.
get_module
().
replace_instruction
(
ins
,
ck_gemm
{
ins
->
get_operator
()},
ins
->
inputs
());
}
};
}
// namespace
void
fuse_ck
::
apply
(
module_pass_manager
&
mpm
)
const
{
match
::
find_matches
(
mpm
,
find_ck_gemm_pointwise
{});
match
::
find_matches
(
mpm
,
find_ck_gemm
{});
}
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/fuse_mlir.cpp
View file @
40fbef9b
...
...
@@ -38,6 +38,27 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_ENABLE_MLIR
);
bool
mlir_enabled
()
{
#ifdef MIGRAPHX_MLIR
const
bool
mlir_enabled
=
enabled
(
MIGRAPHX_ENABLE_MLIR
{});
if
(
mlir_enabled
)
{
return
true
;
}
else
{
std
::
cerr
<<
"WARNING: MIGraphX built with MLIR but it is not enabled. Please set the env "
"var MIGRAPHX_ENABLE_MLIR to use MLIR kernel generator."
<<
std
::
endl
;
return
false
;
}
#else
return
false
;
#endif
}
#ifdef MIGRAPHX_MLIR
struct
mlir_op
...
...
@@ -58,8 +79,41 @@ struct mlir_op
MIGRAPHX_THROW
(
"should have one submodule."
);
if
(
inputs
.
size
()
<
2
)
MIGRAPHX_THROW
(
"should have at least two inputs."
);
auto
n
=
inputs
.
size
();
return
op
.
compute_shape
({
inputs
[
n
-
2
],
inputs
[
n
-
1
]});
module_ref
mod
=
mods
[
0
];
auto
type
=
mod
->
get_output_shapes
().
front
().
type
();
std
::
unordered_map
<
instruction_ref
,
shape
>
ins_shapes
;
size_t
param_cnt
=
0
;
std
::
vector
<
std
::
string
>
names
=
mod
->
get_parameter_names
();
std
::
sort
(
names
.
begin
(),
names
.
end
());
for
(
std
::
string
param_name
:
names
)
{
ins_shapes
[
mod
->
get_parameter
(
param_name
)]
=
inputs
[
param_cnt
++
];
}
for
(
auto
ins
:
iterator_for
(
*
mod
))
{
if
(
ins
->
name
()
==
"@param"
)
{
continue
;
}
if
(
ins
->
name
()
==
"@literal"
)
{
ins_shapes
[
ins
]
=
ins
->
get_shape
();
continue
;
}
if
(
ins
->
name
()
==
"@return"
)
{
return
ins_shapes
[
ins
->
inputs
().
at
(
0
)].
with_type
(
type
);
}
std
::
vector
<
shape
>
input_shapes
;
input_shapes
.
resize
(
ins
->
inputs
().
size
());
std
::
transform
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
input_shapes
.
begin
(),
[
&
](
auto
in
)
{
return
ins_shapes
[
in
];
});
ins_shapes
[
ins
]
=
ins
->
get_operator
().
compute_shape
(
input_shapes
);
}
MIGRAPHX_THROW
(
"No return found in the submodule"
);
}
};
MIGRAPHX_REGISTER_OP
(
mlir_op
);
...
...
@@ -68,7 +122,7 @@ namespace {
MIGRAPHX_PRED_MATCHER
(
is_mlir_conv
,
instruction_ref
ins
)
{
if
(
ins
->
name
()
!=
"convolution"
)
if
(
ins
->
name
()
!=
"convolution"
and
ins
->
name
()
!=
"quant_convolution"
)
return
false
;
value
v
=
ins
->
get_operator
().
to_value
();
auto
group
=
v
.
at
(
"group"
).
to
<
int
>
();
...
...
@@ -85,10 +139,121 @@ struct find_mlir_op
auto
matcher
()
const
{
auto
dot_or_conv
=
match
::
skip
(
match
::
name
(
"contiguous"
))(
match
::
any_of
(
match
::
name
(
"dot"
),
is_mlir_conv
()).
bind
(
"gemm_based_op"
));
match
::
any_of
(
match
::
name
(
"dot"
),
match
::
name
(
"quant_dot"
),
is_mlir_conv
())
.
bind
(
"gemm_based_op"
));
return
match
::
name
(
"pointwise"
)(
match
::
any_of
[
match
::
inputs
()](
dot_or_conv
.
bind
(
"x"
)));
}
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
create_param_map_with_literals
(
module_ref
mm
,
const
module
*
pm
,
const
shape
&
shape
)
const
{
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
ins_map
;
for
(
auto
ins
:
iterator_for
(
*
pm
))
{
if
(
ins
->
name
()
!=
"@literal"
)
{
continue
;
}
literal
r
=
ins
->
get_literal
();
instruction_ref
literal
=
mm
->
add_literal
(
r
);
instruction_ref
mbcast
=
mm
->
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
shape
.
lens
()}}),
literal
);
ins_map
[
ins
]
=
mbcast
;
}
return
ins_map
;
}
std
::
tuple
<
instruction_ref
,
std
::
vector
<
instruction_ref
>>
fuse_input_ops_and_gemm_based_op
(
module_ref
mm
,
instruction_ref
gemm_based_op
)
const
{
std
::
vector
<
instruction_ref
>
top_inputs
;
std
::
vector
<
instruction_ref
>
imm_inputs
;
size_t
input_cnt
=
0
;
for
(
instruction_ref
input
:
gemm_based_op
->
inputs
())
{
std
::
vector
<
operation
>
op_stream
;
while
(
contains
({
"slice"
,
"transpose"
,
"contiguous"
,
"reshape"
},
input
->
name
()))
{
op_stream
.
push_back
(
input
->
get_operator
());
input
=
input
->
inputs
().
at
(
0
);
}
top_inputs
.
push_back
(
input
);
instruction_ref
prev_input
=
mm
->
add_parameter
(
"y"
+
std
::
to_string
(
input_cnt
++
),
input
->
get_shape
());
for
(
const
auto
&
op
:
reverse
(
op_stream
))
{
prev_input
=
mm
->
add_instruction
(
op
,
{
prev_input
});
}
imm_inputs
.
push_back
(
prev_input
);
}
instruction_ref
new_gemm_based_op
=
mm
->
add_instruction
(
gemm_based_op
->
get_operator
(),
imm_inputs
);
return
{
new_gemm_based_op
,
top_inputs
};
}
// Whitelist supported fusion options, including imposing type constraints
// for cases where MLIR only supports an operation (usually a pointwise function)
// on particular types.
bool
is_pointwise_op_supported_by_mlir
(
const
instruction
&
i
)
const
{
using
type_t
=
shape
::
type_t
;
const
auto
&
name
=
i
.
name
();
const
auto
result_type
=
i
.
get_shape
().
type
();
const
std
::
initializer_list
<
type_t
>
allowed_types
=
{
type_t
::
float_type
,
type_t
::
half_type
,
type_t
::
int8_type
,
type_t
::
int32_type
,
type_t
::
bool_type
};
// Preliminary type check.
if
(
not
contains
(
allowed_types
,
result_type
))
{
return
false
;
}
const
std
::
initializer_list
<
std
::
string
>
any_type_ops
=
{
"@literal"
,
"@param"
,
"@return"
};
const
std
::
initializer_list
<
std
::
string
>
no_bool_ops
=
{
"convolution"
,
"quant_convolution"
,
"dot"
,
"quant_dot"
,
"add"
,
"clip"
,
"relu"
,
"sub"
,
"mul"
,
"div"
,
"pow"
,
"where"
,
"quantizelinear"
,
"dequantizelinear"
,
"abs"
,
"neg"
};
const
std
::
initializer_list
<
std
::
string
>
fp_only_ops
=
{
"ceil"
,
"erf"
,
"exp"
,
"floor"
,
"log"
,
"recip"
,
"rsqrt"
,
"sigmoid"
"softmax"
,
"tanh"
};
bool
is_float
=
contains
({
type_t
::
float_type
,
type_t
::
half_type
},
result_type
);
if
(
contains
(
any_type_ops
,
name
))
return
true
;
if
(
result_type
!=
type_t
::
bool_type
&&
contains
(
no_bool_ops
,
name
))
return
true
;
if
(
is_float
&&
contains
(
fp_only_ops
,
name
))
return
true
;
// Only conversions between floating types are known to be unambigiously
// supported.
if
(
is_float
&&
name
==
"convert"
)
{
return
std
::
all_of
(
i
.
inputs
().
begin
(),
i
.
inputs
().
end
(),
[](
const
auto
&
arg
)
{
return
contains
({
type_t
::
float_type
,
type_t
::
half_type
},
arg
->
get_shape
().
type
());
});
}
return
false
;
}
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
...
...
@@ -96,35 +261,25 @@ struct find_mlir_op
auto
x_ins
=
r
.
instructions
[
"x"
];
// input after contiguous
auto
*
pm
=
ins
->
module_inputs
().
front
();
auto
names
=
pm
->
get_parameter_names
();
// Whitelist pointwise operators
if
(
std
::
any_of
(
pm
->
begin
(),
pm
->
end
(),
[](
const
auto
&
i
)
{
return
not
contains
(
{
"@literal"
,
"@param"
,
"@return"
,
"convolution"
,
"dot"
,
"add"
,
"relu"
},
i
.
name
());
}))
return
;
// Only fuse with fp32/fp16
if
(
std
::
any_of
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
[
&
](
auto
i
)
{
return
not
contains
({
shape
::
type_t
::
float_type
,
shape
::
type_t
::
half_type
},
i
->
get_shape
().
type
());
// Whitelist pointwise operators.
if
(
std
::
any_of
(
pm
->
begin
(),
pm
->
end
(),
[
&
](
const
auto
&
i
)
{
return
not
is_pointwise_op_supported_by_mlir
(
i
);
}))
return
;
std
::
sort
(
names
.
begin
(),
names
.
end
());
module_ref
mm
=
mpm
.
create_module
(
"mlir_"
+
pm
->
name
());
mm
->
set_bypass
();
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
param_map
;
auto
x
=
mm
->
add_parameter
(
"x"
+
std
::
to_string
(
names
.
size
()),
gemm_based_op
->
inputs
().
at
(
0
)
->
get_shape
());
auto
w
=
mm
->
add_parameter
(
"x"
+
std
::
to_string
(
names
.
size
()
+
1
),
gemm_based_op
->
inputs
().
at
(
1
)
->
get_shape
());
auto
conv
=
mm
->
add_instruction
(
gemm_based_op
->
get_operator
(),
{
x
,
w
});
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
param_map
=
create_param_map_with_literals
(
mm
,
pm
,
gemm_based_op
->
get_shape
());
auto
[
anchor_op
,
top_inputs
]
=
fuse_input_ops_and_gemm_based_op
(
mm
,
gemm_based_op
);
std
::
transform
(
names
.
begin
(),
names
.
end
(),
ins
->
inputs
().
begin
(),
std
::
inserter
(
param_map
,
param_map
.
end
()),
[
&
](
auto
name
,
auto
input
)
{
[
&
,
&
anchor_op
=
anchor_op
](
auto
name
,
auto
input
)
{
if
(
input
==
x_ins
)
return
std
::
make_pair
(
pm
->
get_parameter
(
name
),
conv
);
return
std
::
make_pair
(
pm
->
get_parameter
(
name
),
anchor_op
);
return
std
::
make_pair
(
pm
->
get_parameter
(
name
),
mm
->
add_parameter
(
name
,
input
->
get_shape
()));
});
...
...
@@ -135,7 +290,7 @@ struct find_mlir_op
ins
->
inputs
().
end
(),
std
::
back_inserter
(
inputs
),
[
&
](
auto
input
)
{
return
input
!=
gemm_based_op
;
});
inputs
.
insert
(
inputs
.
end
(),
gemm_based_op
->
inputs
()
.
begin
(),
gemm_based_op
->
inputs
()
.
end
());
inputs
.
insert
(
inputs
.
end
(),
top_
inputs
.
begin
(),
top_
inputs
.
end
());
mpm
.
get_module
().
replace_instruction
(
ins
,
mlir_op
{
gemm_based_op
->
get_operator
()},
inputs
,
{
mm
});
}
...
...
@@ -148,17 +303,7 @@ struct find_mlir_op
void
fuse_mlir
::
apply
(
module_pass_manager
&
mpm
)
const
{
#ifdef MIGRAPHX_MLIR
const
bool
mlir_enabled
=
enabled
(
MIGRAPHX_ENABLE_MLIR
{});
if
(
mlir_enabled
)
{
match
::
find_matches
(
mpm
,
find_mlir_op
{});
}
else
{
std
::
cerr
<<
"WARNING: MIGraphX built with MLIR but it is not enabled. Please set the env "
"var MIGRAPHX_ENABLE_MLIR to use MLIR kernel generator."
<<
std
::
endl
;
}
match
::
find_matches
(
mpm
,
find_mlir_op
{});
#else
(
void
)
mpm
;
#endif
...
...
src/targets/gpu/fuse_ops.cpp
View file @
40fbef9b
...
...
@@ -165,7 +165,8 @@ struct fusion
const
std
::
unordered_set
<
std
::
string
>&
get_supported_archs
()
{
static
std
::
unordered_set
<
std
::
string
>
supported_archs
{
"gfx900"
,
"gfx906"
,
"gfx908"
,
"gfx1030"
};
static
std
::
unordered_set
<
std
::
string
>
supported_archs
{
"gfx900"
,
"gfx906"
,
"gfx908"
,
"gfx1030"
,
"gfx940"
};
return
supported_archs
;
}
...
...
src/targets/gpu/gemm_impl.cpp
View file @
40fbef9b
...
...
@@ -140,12 +140,10 @@ void gemm_impl(context& ctx,
compute_type
=
rocblas_datatype_f32_r
;
}
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
rocblas_gemm_flags
flag
=
int8_x4_format
?
rocblas_gemm_flags_pack_int8x4
:
rocblas_gemm_flags_none
;
#else
(
void
)
int8_x4_format
;
int
flag
=
0
;
rocblas_gemm_flags
flag
=
rocblas_gemm_flags_none
;
#if ROCBLAS_VERSION_MAJOR < 3
if
(
int8_x4_format
)
flag
=
rocblas_gemm_flags_pack_int8x4
;
#endif
auto
a_lens
=
args
[
0
].
get_shape
().
lens
();
...
...
Prev
1
…
8
9
10
11
12
13
14
15
16
…
22
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment