Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
4ea39116
Commit
4ea39116
authored
Nov 10, 2023
by
Khalique Ahmed
Browse files
manual merge
parents
20128cae
d8011adf
Changes
315
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
282 additions
and
223 deletions
+282
-223
src/quantization.cpp
src/quantization.cpp
+5
-4
src/rewrite_quantization.cpp
src/rewrite_quantization.cpp
+19
-4
src/simplify_dyn_ops.cpp
src/simplify_dyn_ops.cpp
+46
-2
src/targets/cpu/include/migraphx/cpu/dnnl.hpp
src/targets/cpu/include/migraphx/cpu/dnnl.hpp
+15
-2
src/targets/cpu/include/migraphx/cpu/fuse_ops.hpp
src/targets/cpu/include/migraphx/cpu/fuse_ops.hpp
+2
-4
src/targets/cpu/include/migraphx/cpu/pointwise.hpp
src/targets/cpu/include/migraphx/cpu/pointwise.hpp
+1
-0
src/targets/gpu/CMakeLists.txt
src/targets/gpu/CMakeLists.txt
+31
-21
src/targets/gpu/argmax.cpp
src/targets/gpu/argmax.cpp
+3
-2
src/targets/gpu/argmin.cpp
src/targets/gpu/argmin.cpp
+3
-2
src/targets/gpu/compile_hip.cpp
src/targets/gpu/compile_hip.cpp
+14
-13
src/targets/gpu/compile_hip_code_object.cpp
src/targets/gpu/compile_hip_code_object.cpp
+28
-20
src/targets/gpu/compile_miopen.cpp
src/targets/gpu/compile_miopen.cpp
+4
-15
src/targets/gpu/compile_ops.cpp
src/targets/gpu/compile_ops.cpp
+28
-10
src/targets/gpu/device/argmax.cpp
src/targets/gpu/device/argmax.cpp
+10
-3
src/targets/gpu/device/argmin.cpp
src/targets/gpu/device/argmin.cpp
+10
-3
src/targets/gpu/device/int8_gemm_pack.cpp
src/targets/gpu/device/int8_gemm_pack.cpp
+0
-97
src/targets/gpu/device/targets.hpp.in
src/targets/gpu/device/targets.hpp.in
+5
-1
src/targets/gpu/driver/compile_op.cpp
src/targets/gpu/driver/compile_op.cpp
+2
-4
src/targets/gpu/driver/run_op.cpp
src/targets/gpu/driver/run_op.cpp
+2
-2
src/targets/gpu/fuse_ck.cpp
src/targets/gpu/fuse_ck.cpp
+54
-14
No files found.
src/quantization.cpp
View file @
4ea39116
...
...
@@ -70,6 +70,10 @@ void quantize_int8(program& prog,
MIGRAPHX_THROW
(
"QUANTIZE_INT8: only support DOT and CONVOLUTION operation"
);
}
// Run optimize_module() before converting to int8 to const eval and fold in FP32 to
// avoid loss of precision.
run_passes
(
prog
,
{
optimize_module
{}});
std
::
shared_ptr
<
std
::
vector
<
std
::
pair
<
float
,
float
>>>
int8_quant_params
=
std
::
make_shared
<
std
::
vector
<
std
::
pair
<
float
,
float
>>>
();
std
::
shared_ptr
<
std
::
vector
<
float
>>
max_abs_vals
=
std
::
make_shared
<
std
::
vector
<
float
>>
();
...
...
@@ -143,11 +147,8 @@ void quantize_int8(program& prog,
run_passes
(
prog
,
{
quantize_int8_pass
{
ins_names
,
*
int8_quant_params
},
eliminate_common_subexpression
{},
dead_code_elimination
{},
simplify_reshapes
{},
dead_code_elimination
{},
simplify_qdq
{},
optimize_module
{},
dead_code_elimination
{}});
}
...
...
src/rewrite_quantization.cpp
View file @
4ea39116
...
...
@@ -33,6 +33,8 @@
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_ENABLE_CK_WORKAROUNDS
);
void
apply_quantizelinear
(
module
&
m
,
instruction_ref
ins
)
{
assert
(
ins
->
name
()
==
"quantizelinear"
);
...
...
@@ -45,7 +47,7 @@ void apply_quantizelinear(module& m, instruction_ref ins)
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
y_scale
->
get_shape
().
type
()}}),
x
);
}
auto
div
=
m
.
insert_instruction
(
ins
,
make_op
(
"div"
),
x
,
y_scale
);
auto
add_zero_point
=
m
.
insert_instruction
(
ins
,
make_op
(
"
round
"
),
div
);
auto
add_zero_point
=
m
.
insert_instruction
(
ins
,
make_op
(
"
nearbyint
"
),
div
);
if
(
ins
->
inputs
().
size
()
==
3
)
{
...
...
@@ -62,9 +64,22 @@ void apply_quantizelinear(module& m, instruction_ref ins)
max_quant
=
qt
.
max
();
min_quant
=
qt
.
min
();
});
auto
s
=
add_zero_point
->
get_shape
();
auto
min_arg
=
m
.
add_literal
(
literal
{
shape
{
s
.
type
()},
{
min_quant
}});
auto
max_arg
=
m
.
add_literal
(
literal
{
shape
{
s
.
type
()},
{
max_quant
}});
auto
s
=
add_zero_point
->
get_shape
();
instruction_ref
min_arg
;
instruction_ref
max_arg
;
if
(
enabled
(
MIGRAPHX_ENABLE_CK_WORKAROUNDS
{}))
{
std
::
vector
<
int
>
min_data
(
s
.
elements
(),
min_quant
);
std
::
vector
<
int
>
max_data
(
s
.
elements
(),
max_quant
);
min_arg
=
m
.
add_literal
(
literal
(
s
,
min_data
));
max_arg
=
m
.
add_literal
(
literal
(
s
,
max_data
));
}
else
{
min_arg
=
m
.
add_literal
(
literal
{
shape
{
s
.
type
()},
{
min_quant
}});
max_arg
=
m
.
add_literal
(
literal
{
shape
{
s
.
type
()},
{
max_quant
}});
}
auto
saturate
=
insert_common_op
(
m
,
ins
,
make_op
(
"clip"
),
{
add_zero_point
,
min_arg
,
max_arg
});
m
.
replace_instruction
(
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
ins
->
get_shape
().
type
()}}),
saturate
);
...
...
src/simplify_dyn_ops.cpp
View file @
4ea39116
...
...
@@ -24,6 +24,7 @@
#include <migraphx/simplify_dyn_ops.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/literal.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -131,10 +132,53 @@ struct find_const_4in_slice
}
};
/**
* Simplify dimensions_of to a literal when the input arugment has a static shape
* or the dynamic dimensions from `start` to `end` are fixed.
*/
struct
find_static_dimensions_of
{
auto
matcher
()
const
{
return
match
::
name
(
"dimensions_of"
)();
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
mr
)
const
{
auto
ins
=
mr
.
result
;
auto
input
=
ins
->
inputs
().
at
(
0
);
auto
dimensions_of_value
=
ins
->
get_operator
().
to_value
();
auto
start
=
dimensions_of_value
.
at
(
"start"
).
to
<
std
::
size_t
>
();
auto
end
=
dimensions_of_value
.
at
(
"end"
).
to
<
std
::
size_t
>
();
if
(
input
->
get_shape
().
dynamic
())
{
// check if dynamic dimensions from start to end are fixed
auto
dds
=
input
->
get_shape
().
dyn_dims
();
if
(
std
::
any_of
(
dds
.
begin
()
+
start
,
dds
.
begin
()
+
end
,
[](
auto
dd
)
{
return
not
dd
.
is_fixed
();
}))
{
return
;
}
}
std
::
size_t
output_ndim
=
end
-
start
;
std
::
vector
<
int64_t
>
vec_shape
(
output_ndim
);
migraphx
::
shape
s
(
migraphx
::
shape
::
int64_type
,
{
output_ndim
});
std
::
vector
<
std
::
size_t
>
input_lens
=
input
->
get_shape
().
to_static
(
1
).
lens
();
std
::
transform
(
input_lens
.
begin
()
+
start
,
input_lens
.
begin
()
+
end
,
vec_shape
.
begin
(),
[](
auto
i
)
{
return
int64_t
(
i
);
});
migraphx
::
shape
output_shape
{
migraphx
::
shape
::
int64_type
,
{
end
-
start
}};
auto
lit_ins
=
m
.
add_literal
(
migraphx
::
literal
{
output_shape
,
vec_shape
});
m
.
replace_instruction
(
ins
,
lit_ins
);
}
};
void
simplify_dyn_ops
::
apply
(
module
&
m
)
const
{
match
::
find_matches
(
m
,
find_static_2in_broadcasts
{},
find_const_3in_slice
{},
find_const_4in_slice
{});
match
::
find_matches
(
m
,
find_static_2in_broadcasts
{},
find_static_dimensions_of
{},
find_const_3in_slice
{},
find_const_4in_slice
{});
}
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/targets/cpu/include/migraphx/cpu/dnnl.hpp
View file @
4ea39116
...
...
@@ -91,6 +91,19 @@ struct post_op : reflect_equality<post_op>, reflect_stream<post_op>
}
};
template
<
class
F
>
struct
execute_wrapper
{
F
f
;
argument
operator
()(
context
&
,
const
std
::
vector
<
argument
>&
args
)
const
{
return
f
(
args
);
}
};
template
<
class
F
>
execute_wrapper
<
F
>
make_execute_wrapper
(
F
f
)
{
return
{
std
::
move
(
f
)};
}
template
<
class
Derived
,
class
Primitive
>
struct
dnnl_op
:
auto_register_op
<
Derived
>
{
...
...
@@ -308,7 +321,7 @@ struct dnnl_op : auto_register_op<Derived>
#ifndef NDEBUG
auto
prim_attr
=
get_primitive_attr
(
md
);
#endif
execute
=
[
=
](
context
&
,
const
std
::
vector
<
argument
>&
args
)
{
execute
=
make_execute_wrapper
([
=
](
const
std
::
vector
<
argument
>&
args
)
{
#ifndef NDEBUG
// Check that the memory descriptors have not changed
auto
debug_args
=
args
;
...
...
@@ -379,7 +392,7 @@ struct dnnl_op : auto_register_op<Derived>
m
[
arg_lookup
[
i
]]
=
to_dnnl_memory
(
md
.
at
(
arg_lookup
[
i
]),
args
[
i
]);
prim
.
execute
(
get_dnnl_context
().
stream
,
m
);
return
args
.
back
();
};
}
)
;
}
std
::
vector
<
shape
>
trim_post_op_inputs
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
...
...
src/targets/cpu/include/migraphx/cpu/fuse_ops.hpp
View file @
4ea39116
...
...
@@ -24,7 +24,7 @@
#ifndef MIGRAPHX_GUARD_CPU_FUSE_OPS_HPP
#define MIGRAPHX_GUARD_CPU_FUSE_OPS_HPP
#include <migraphx/c
onfig
.hpp>
#include <migraphx/c
pu/context
.hpp>
#include <string>
namespace
migraphx
{
...
...
@@ -34,9 +34,7 @@ struct module;
namespace
cpu
{
struct
context
;
struct
fuse_ops
struct
MIGRAPHX_CPU_EXPORT
fuse_ops
{
context
*
ctx
=
nullptr
;
std
::
string
name
()
const
{
return
"cpu::fuse_ops"
;
}
...
...
src/targets/cpu/include/migraphx/cpu/pointwise.hpp
View file @
4ea39116
...
...
@@ -24,6 +24,7 @@
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_CPU_POINTWISE_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_CPU_POINTWISE_HPP
#include <array>
#include <migraphx/config.hpp>
#include <migraphx/context.hpp>
#include <migraphx/check_shapes.hpp>
...
...
src/targets/gpu/CMakeLists.txt
View file @
4ea39116
# ####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
...
...
@@ -37,8 +37,7 @@ if(NOT TARGET MIOpen)
message
(
SEND_ERROR
"Cant find miopen"
)
endif
()
if
(
NOT WIN32
)
# TODO: re-enable when CK is ported to Windows
if
(
MIGRAPHX_USE_COMPOSABLEKERNEL
)
find_package
(
composable_kernel 1.0.0 REQUIRED COMPONENTS jit_library
)
endif
()
...
...
@@ -48,10 +47,18 @@ else()
set
(
MIGRAPHX_USE_HIPRTC ON CACHE BOOL
"Use hipRTC APIs"
)
endif
()
include
(
Embed
)
file
(
GLOB KERNEL_FILES CONFIGURE_DEPENDS
${
CMAKE_CURRENT_SOURCE_DIR
}
/kernels/include/migraphx/kernels/*.hpp
)
message
(
STATUS
"KERNEL_FILES:
${
KERNEL_FILES
}
"
)
if
(
NOT MIGRAPHX_USE_COMPOSABLEKERNEL
)
list
(
REMOVE_ITEM KERNEL_FILES
${
CMAKE_CURRENT_SOURCE_DIR
}
/kernels/include/migraphx/kernels/ck_gemm.hpp
${
CMAKE_CURRENT_SOURCE_DIR
}
/kernels/include/migraphx/kernels/ck_gemm_softmax_gemm.hpp
${
CMAKE_CURRENT_SOURCE_DIR
}
/kernels/include/migraphx/kernels/ck.hpp
)
endif
()
include
(
Embed
)
add_embed_library
(
migraphx_kernels
${
KERNEL_FILES
}
RELATIVE
${
CMAKE_CURRENT_SOURCE_DIR
}
/kernels/include/
)
configure_file
(
device/targets.hpp.in include/migraphx/gpu/device/targets.hpp
)
...
...
@@ -95,9 +102,10 @@ rocm_clang_tidy_check(kernel_file_check)
file
(
GLOB JIT_GPU_SRCS CONFIGURE_DEPENDS
${
CMAKE_CURRENT_SOURCE_DIR
}
/jit/*.cpp
)
if
(
WIN32
)
# TODO: re-enable when CK is ported to Windows
list
(
REMOVE_ITEM JIT_GPU_SRCS
${
CMAKE_CURRENT_SOURCE_DIR
}
/jit/ck_gemm.cpp
)
if
(
NOT MIGRAPHX_USE_COMPOSABLEKERNEL
)
list
(
REMOVE_ITEM JIT_GPU_SRCS
${
CMAKE_CURRENT_SOURCE_DIR
}
/jit/ck_gemm.cpp
${
CMAKE_CURRENT_SOURCE_DIR
}
/jit/ck_gemm_softmax_gemm.cpp
)
endif
()
add_library
(
migraphx_gpu
...
...
@@ -120,8 +128,6 @@ add_library(migraphx_gpu
gather.cpp
gemm_impl.cpp
hip.cpp
int8_conv_pack.cpp
int8_gemm_pack.cpp
kernel.cpp
lowering.cpp
logsoftmax.cpp
...
...
@@ -132,7 +138,6 @@ add_library(migraphx_gpu
no_device.cpp
nonzero.cpp
pack_args.cpp
pack_int8_args.cpp
prefuse_ops.cpp
pad.cpp
perfdb.cpp
...
...
@@ -176,7 +181,6 @@ register_migraphx_gpu_ops(hip_
register_migraphx_gpu_ops
(
miopen_
abs
contiguous
int8_conv_pack
lrn
pooling
)
...
...
@@ -184,10 +188,6 @@ register_op(migraphx_gpu
HEADER migraphx/gpu/rnn_variable_seq_lens.hpp
OPERATORS gpu::hip_rnn_var_sl_shift_sequence gpu::hip_rnn_var_sl_shift_output gpu::hip_rnn_var_sl_last_output
INCLUDES migraphx/gpu/context.hpp
)
register_op
(
migraphx_gpu
HEADER migraphx/gpu/int8_gemm_pack.hpp
OPERATORS gpu::hip_int8_gemm_pack_a gpu::hip_int8_gemm_pack_b
INCLUDES migraphx/gpu/context.hpp
)
register_op
(
migraphx_gpu
HEADER migraphx/gpu/gemm.hpp
OPERATORS gpu::rocblas_gemm<op::dot> gpu::rocblas_gemm<op::quant_dot>
...
...
@@ -231,24 +231,28 @@ else()
string
(
REGEX REPLACE
" /[^ ]+
\\
.(a|so) "
" "
HIP_COMPILER_FLAGS
"
${
HIP_COMPILER_FLAGS
}
"
)
endforeach
()
message
(
STATUS
"Hip compiler flags:
${
HIP_COMPILER_FLAGS
}
"
)
message
(
STATUS
"Hip compiler flags:
\"
${
HIP_COMPILER_FLAGS
}
\"
"
)
target_compile_definitions
(
migraphx_gpu PRIVATE
"
-DMIGRAPHX_HIP_COMPILER=
${
CMAKE_CXX_COMPILER
}
"
"
-DMIGRAPHX_HIP_COMPILER_FLAGS=
${
HIP_COMPILER_FLAGS
}
"
-DMIGRAPHX_HIP_COMPILER=
"
${
CMAKE_CXX_COMPILER
}
"
-DMIGRAPHX_HIP_COMPILER_FLAGS=
"
${
HIP_COMPILER_FLAGS
}
"
)
if
(
DEFINED CMAKE_CXX_COMPILER_LAUNCHER
)
execute_process
(
COMMAND which
${
CMAKE_CXX_COMPILER_LAUNCHER
}
OUTPUT_VARIABLE MIGRAPHX_HIP_COMPILER_LAUNCHER
)
string
(
STRIP
"
${
MIGRAPHX_HIP_COMPILER_LAUNCHER
}
"
MIGRAPHX_HIP_COMPILER_LAUNCHER
)
target_compile_definitions
(
migraphx_gpu PRIVATE
"
-DMIGRAPHX_HIP_COMPILER_LAUNCHER=
${
MIGRAPHX_HIP_COMPILER_LAUNCHER
}
"
)
target_compile_definitions
(
migraphx_gpu PRIVATE -DMIGRAPHX_HIP_COMPILER_LAUNCHER=
"
${
MIGRAPHX_HIP_COMPILER_LAUNCHER
}
"
)
endif
()
endif
()
# Check miopen find mode api
include
(
CheckLibraryExists
)
get_target_property
(
MIOPEN_LOCATION MIOpen LOCATION
)
get_target_property
(
ROCBLAS_LOCATION roc::rocblas LOCATION
)
check_library_exists
(
MIOpen
"miopenHiddenSetConvolutionFindMode"
"
${
MIOPEN_LOCATION
}
"
HAS_FIND_MODE_API
)
check_library_exists
(
MIOpen
"miopenFindSolutions"
"
${
MIOPEN_LOCATION
}
"
HAS_FIND_2_API
)
# Beta API for automated GEMM tuning
check_library_exists
(
roc::rocblas
"rocblas_gemm_ex_get_solutions"
"
${
ROCBLAS_LOCATION
}
"
HAS_ROCBLAS_TUNING_BETA_FEATURE_API
)
set
(
MIGRAPHX_USE_FIND_2_API
"
${
HAS_FIND_2_API
}
"
CACHE BOOL
""
)
...
...
@@ -271,10 +275,16 @@ else()
message
(
STATUS
"MIOpen does not have find mode api"
)
endif
()
if
(
HAS_ROCBLAS_TUNING_BETA_FEATURE_API
)
target_compile_definitions
(
migraphx_gpu PUBLIC -DMIGRAPHX_USE_ROCBLAS_TUNING_API -DROCBLAS_BETA_FEATURES_API -DROCBLAS_NO_DEPRECATED_WARNINGS
)
message
(
STATUS
"MIGraphx is using Beta API of rocBLAS"
)
else
()
message
(
STATUS
"rocBLAS does not have User Tuning Beta API"
)
endif
()
target_link_libraries
(
migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas
)
target_link_libraries
(
migraphx_gpu PRIVATE migraphx_device migraphx_kernels
)
if
(
NOT WIN32
)
# TODO: re-enable when CK is ported to Windows
if
(
MIGRAPHX_USE_COMPOSABLEKERNEL
)
target_link_libraries
(
migraphx_gpu PRIVATE composable_kernel::jit_library
)
endif
()
...
...
src/targets/gpu/argmax.cpp
View file @
4ea39116
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -40,7 +40,8 @@ argument hip_argmax::compute(context& ctx, const shape&, const std::vector<argum
{
auto
n_dim
=
args
.
front
().
get_shape
().
lens
().
size
();
int64_t
tuned_axis
=
tune_axis
(
n_dim
,
op
.
axis
,
op
.
name
());
device
::
argmax
(
ctx
.
get_stream
().
get
(),
args
.
back
(),
args
.
front
(),
tuned_axis
);
device
::
argmax
(
ctx
.
get_stream
().
get
(),
args
.
back
(),
args
.
front
(),
tuned_axis
,
op
.
select_last_index
);
return
args
.
back
();
}
...
...
src/targets/gpu/argmin.cpp
View file @
4ea39116
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -40,7 +40,8 @@ argument hip_argmin::compute(context& ctx, const shape&, const std::vector<argum
{
auto
n_dim
=
args
.
front
().
get_shape
().
lens
().
size
();
int64_t
tuned_axis
=
tune_axis
(
n_dim
,
op
.
axis
,
op
.
name
());
device
::
argmin
(
ctx
.
get_stream
().
get
(),
args
.
back
(),
args
.
front
(),
tuned_axis
);
device
::
argmin
(
ctx
.
get_stream
().
get
(),
args
.
back
(),
args
.
front
(),
tuned_axis
,
op
.
select_last_index
);
return
args
.
back
();
}
...
...
src/targets/gpu/compile_hip.cpp
View file @
4ea39116
...
...
@@ -248,7 +248,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
{
if
(
src
.
path
.
extension
()
!=
".cpp"
)
continue
;
std
::
cout
<<
std
::
string
(
src
.
content
.
first
,
src
.
len
()
)
<<
std
::
endl
;
std
::
cout
<<
std
::
string
(
src
.
content
)
<<
std
::
endl
;
}
}
auto
p
=
dynamic_loader
::
path
(
&
compile_hip_src_with_hiprtc
);
...
...
@@ -284,16 +284,20 @@ std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_sr
bool
is_hip_clang_compiler
()
{
static
const
auto
result
=
ends_with
(
MIGRAPHX_STRINGIZE
(
MIGRAPHX_HIP_COMPILER
),
"clang++"
)
;
static
const
auto
result
=
fs
::
path
{
MIGRAPHX_HIP_COMPILER
}.
stem
()
==
"clang++"
;
return
result
;
}
#ifdef MIGRAPHX_HIP_COMPILER_LAUNCHER
bool
has_compiler_launcher
()
{
static
const
auto
result
=
fs
::
exists
(
MIGRAPHX_
STRINGIZE
(
MIGRAPHX_
HIP_COMPILER_LAUNCHER
)
)
;
static
const
auto
result
=
fs
::
exists
(
MIGRAPHX_HIP_COMPILER_LAUNCHER
);
return
result
;
}
#endif
src_compiler
assemble
(
src_compiler
compiler
)
{
compiler
.
out_ext
=
".S"
;
...
...
@@ -306,8 +310,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
{
assert
(
not
srcs
.
empty
());
if
(
not
is_hip_clang_compiler
())
MIGRAPHX_THROW
(
"Unknown hip compiler: "
+
std
::
string
(
MIGRAPHX_STRINGIZE
(
MIGRAPHX_HIP_COMPILER
)));
MIGRAPHX_THROW
(
"Unknown hip compiler: "
MIGRAPHX_HIP_COMPILER
);
if
(
params
.
find
(
"-std="
)
==
std
::
string
::
npos
)
params
+=
" --std=c++17"
;
...
...
@@ -323,14 +326,14 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
params
+=
" -DMIGRAPHX_DEBUG"
;
params
+=
" -Wno-unused-command-line-argument -Wno-cuda-compat "
;
params
+=
MIGRAPHX_STRINGIZE
(
MIGRAPHX_HIP_COMPILER_FLAGS
)
;
params
+=
MIGRAPHX_HIP_COMPILER_FLAGS
;
src_compiler
compiler
;
compiler
.
flags
=
params
;
compiler
.
compiler
=
MIGRAPHX_STRINGIZE
(
MIGRAPHX_HIP_COMPILER
)
;
compiler
.
compiler
=
MIGRAPHX_HIP_COMPILER
;
#ifdef MIGRAPHX_HIP_COMPILER_LAUNCHER
if
(
has_compiler_launcher
())
compiler
.
launcher
=
MIGRAPHX_STRINGIZE
(
MIGRAPHX_HIP_COMPILER_LAUNCHER
)
;
compiler
.
launcher
=
MIGRAPHX_HIP_COMPILER_LAUNCHER
;
#endif
if
(
enabled
(
MIGRAPHX_GPU_DUMP_SRC
{}))
{
...
...
@@ -338,7 +341,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
{
if
(
src
.
path
.
extension
()
!=
".cpp"
)
continue
;
std
::
cout
<<
std
::
string
(
src
.
content
.
first
,
src
.
len
()
)
<<
std
::
endl
;
std
::
cout
<<
std
::
string
(
src
.
content
)
<<
std
::
endl
;
}
}
...
...
@@ -354,14 +357,12 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
bool
hip_has_flags
(
const
std
::
vector
<
std
::
string
>&
flags
)
{
src_compiler
compiler
;
compiler
.
compiler
=
MIGRAPHX_STRINGIZE
(
MIGRAPHX_HIP_COMPILER
)
;
compiler
.
compiler
=
MIGRAPHX_HIP_COMPILER
;
compiler
.
flags
=
join_strings
(
flags
,
" "
)
+
" -x hip -c --offload-arch=gfx900 --cuda-device-only"
;
std
::
string
src
;
src_file
input
;
input
.
path
=
"main.cpp"
;
input
.
content
=
std
::
make_pair
(
src
.
data
(),
src
.
data
()
+
src
.
size
());
src_file
input
{
"main.cpp"
,
src
};
try
{
...
...
src/targets/gpu/compile_hip_code_object.cpp
View file @
4ea39116
...
...
@@ -139,6 +139,12 @@ void hip_compile_options::set_launch_params(
global
=
compute_global
(
local
);
}
static
bool
hip_accept_non_uniform_wg
()
{
static
bool
non_uniform_wg
=
hip_has_flags
({
"-fno-offload-uniform-block"
});
return
non_uniform_wg
;
}
std
::
function
<
std
::
size_t
(
std
::
size_t
local
)
>
compute_global_for
(
context
&
ctx
,
std
::
size_t
n
,
std
::
size_t
over
)
{
...
...
@@ -146,13 +152,14 @@ compute_global_for(context& ctx, std::size_t n, std::size_t over)
std
::
size_t
max_global
=
ctx
.
get_current_device
().
get_cu_count
()
*
ctx
.
get_current_device
().
get_max_workitems_per_cu
();
return
[
n
,
over
,
max_global
](
std
::
size_t
local
)
{
// hip require global workitems multiple of local workitems. It may degrade performance.
// [TODO]: consider adding "fno-hip-uniform-block" flag when it becomes available.
// https://reviews.llvm.org/D155213
std
::
size_t
num_elements
=
((
n
+
local
-
1
)
/
local
)
*
local
;
std
::
size_t
groups
=
(
num_elements
+
local
-
1
)
/
local
;
std
::
size_t
max_blocks
=
max_global
/
local
;
std
::
size_t
nglobal
=
std
::
min
(
max_blocks
*
over
,
groups
)
*
local
;
std
::
size_t
num_elements
=
n
;
if
(
not
hip_accept_non_uniform_wg
())
{
num_elements
=
(
1
+
(
n
-
1
)
/
local
)
*
local
;
}
std
::
size_t
groups
=
1
+
(
num_elements
-
1
)
/
local
;
std
::
size_t
max_blocks
=
max_global
/
local
;
std
::
size_t
nglobal
=
std
::
min
(
max_blocks
*
over
,
groups
)
*
local
;
return
std
::
min
(
nglobal
,
num_elements
);
};
}
...
...
@@ -172,21 +179,22 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
assert
(
options
.
inputs
.
size
()
==
options
.
virtual_inputs
.
size
()
or
options
.
virtual_inputs
.
empty
());
std
::
vector
<
src_file
>
srcs
=
options
.
additional_src_files
;
std
::
transform
(
migraphx_kernels
().
begin
(),
migraphx_kernels
().
end
(),
std
::
back_inserter
(
srcs
),
[](
auto
&&
p
)
{
auto
&&
name
=
p
.
first
;
auto
&&
c
=
p
.
second
;
auto
path
=
name
;
return
src_file
{
path
,
c
};
});
srcs
.
push_back
(
src_file
{
fs
::
path
{
"main.cpp"
},
std
::
make_pair
(
content
.
data
(),
content
.
data
()
+
content
.
size
())});
static
auto
kernels
{
::
migraphx_kernels
()};
std
::
transform
(
kernels
.
begin
(),
kernels
.
end
(),
std
::
back_inserter
(
srcs
),
[](
const
std
::
pair
<
std
::
string_view
,
std
::
string_view
>&
elem
)
{
return
src_file
{
elem
};
});
srcs
.
emplace_back
(
"main.cpp"
,
content
);
auto
args_hpp
=
generate_args_hpp
(
options
.
virtual_inputs
.
empty
()
?
options
.
inputs
:
options
.
virtual_inputs
);
srcs
.
push_back
(
src_file
{
fs
::
path
{
"args.hpp"
},
std
::
make_pair
(
args_hpp
.
data
(),
args_hpp
.
data
()
+
args_hpp
.
size
())});
srcs
.
emplace_back
(
"args.hpp"
,
args_hpp
);
if
(
options
.
global
%
options
.
local
!=
0
and
hip_accept_non_uniform_wg
())
options
.
params
+=
" -fno-offload-uniform-block"
;
else
assert
(
options
.
global
%
options
.
local
==
0
);
options
.
params
+=
" -DMIGRAPHX_NGLOBAL="
+
std
::
to_string
(
options
.
global
);
options
.
params
+=
" -DMIGRAPHX_NLOCAL="
+
std
::
to_string
(
options
.
local
);
options
.
params
+=
" "
+
join_strings
(
compiler_warnings
(),
" "
);
...
...
src/targets/gpu/compile_miopen.cpp
View file @
4ea39116
...
...
@@ -60,9 +60,8 @@ struct miopen_op
};
MIGRAPHX_REGISTER_OP
(
miopen_op
);
std
::
size_t
compile_miopen
::
compile
(
operation
&
op
,
instruction_ref
ins
,
bool
format
)
const
std
::
size_t
compile_miopen
::
compile
(
operation
&
op
,
instruction_ref
ins
)
const
{
op
.
from_value
({{
"int8_x4_format"
,
format
}});
auto
v
=
op
.
compile
(
*
ctx
,
ins
->
get_shape
(),
to_shapes
(
ins
->
inputs
()));
return
v
.
get
<
std
::
size_t
>
(
"workspace"
,
0
);
}
...
...
@@ -70,25 +69,15 @@ std::size_t compile_miopen::compile(operation& op, instruction_ref ins, bool for
void
compile_miopen
::
apply
(
module
&
m
)
const
{
assert
(
ctx
);
const
bool
int8_x4_format
=
get_int8_x4_format
(
any_cast
<
migraphx
::
gpu
::
context
>
(
*
ctx
));
for
(
auto
ins
:
iterator_for
(
m
))
{
if
(
ins
->
name
()
!=
"gpu::miopen_op"
)
continue
;
auto
op
=
any_cast
<
miopen_op
>
(
ins
->
get_operator
()).
op
;
std
::
size_t
ws
=
0
;
try
{
// for the regular convolution and convolution_backwards, this try would always succeed
ws
=
compile
(
op
,
ins
,
int8_x4_format
);
}
catch
(
migraphx
::
exception
&
)
{
// In case no solver supports the default format, retry using the other format.
ws
=
compile
(
op
,
ins
,
not
int8_x4_format
);
}
auto
inputs
=
ins
->
inputs
();
auto
alloc
=
m
.
insert_instruction
(
ws
=
compile
(
op
,
ins
);
auto
inputs
=
ins
->
inputs
();
auto
alloc
=
m
.
insert_instruction
(
ins
,
make_op
(
"allocate"
,
{{
"shape"
,
to_value
(
shape
{
shape
::
int8_type
,
{
ws
}})}}));
inputs
.
insert
(
std
::
prev
(
inputs
.
end
()),
alloc
);
...
...
src/targets/gpu/compile_ops.cpp
View file @
4ea39116
...
...
@@ -37,6 +37,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
gpu
{
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_GPU_COMPILE_PARALLEL
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_BENCHMARKING
);
struct
precompile_op
{
...
...
@@ -167,6 +168,7 @@ struct compile_plan
}
const
compiled_result
&
benchmark
(
problem_cache
&
pc
)
const
{
const
auto
trace_level
=
value_of
(
MIGRAPHX_TRACE_BENCHMARKING
{});
if
(
results
.
empty
())
MIGRAPHX_THROW
(
"No configs to tune"
);
if
(
results
.
size
()
==
1
)
...
...
@@ -177,19 +179,35 @@ struct compile_plan
}
if
(
not
config
)
MIGRAPHX_THROW
(
"Multiple kernels without config"
);
std
::
cout
<<
"Benchmarking "
<<
preop
.
name
()
<<
": "
<<
results
.
size
()
<<
" configs"
<<
std
::
endl
;
if
(
trace_level
>
0
)
std
::
cout
<<
"Benchmarking "
<<
preop
.
name
()
<<
": "
<<
results
.
size
()
<<
" configs"
<<
std
::
endl
;
if
(
trace_level
>
1
)
std
::
cout
<<
"Problem: "
<<
config
->
problem
<<
std
::
endl
;
std
::
vector
<
double
>
times
;
times
.
reserve
(
results
.
size
());
std
::
transform
(
results
.
begin
(),
results
.
end
(),
std
::
back_inserter
(
times
),
[
&
](
const
auto
&
cr
)
{
if
(
not
cr
.
has_value
())
return
std
::
numeric_limits
<
double
>::
max
();
return
time_op
(
*
ctx
,
cr
->
replace
.
code_object
,
to_shapes
(
cr
->
ins
->
inputs
()),
20
)
.
first
;
});
std
::
transform
(
results
.
begin
(),
results
.
end
(),
config
->
solutions
.
begin
(),
std
::
back_inserter
(
times
),
[
&
](
const
auto
&
cr
,
const
auto
&
solution
)
{
if
(
trace_level
>
1
)
std
::
cout
<<
"Benchmarking solution: "
<<
solution
<<
std
::
endl
;
if
(
not
cr
.
has_value
())
{
if
(
trace_level
>
1
)
std
::
cout
<<
"No binary"
<<
std
::
endl
;
return
std
::
numeric_limits
<
double
>::
max
();
}
auto
t
=
time_op
(
*
ctx
,
cr
->
replace
.
code_object
,
to_shapes
(
cr
->
ins
->
inputs
()),
20
);
if
(
trace_level
>
1
)
std
::
cout
<<
t
<<
"ms"
<<
std
::
endl
;
return
t
;
});
auto
i
=
std
::
distance
(
times
.
begin
(),
std
::
min_element
(
times
.
begin
(),
times
.
end
()));
std
::
cout
<<
"Fastest solution: "
<<
config
->
solutions
.
at
(
i
)
<<
std
::
endl
;
if
(
trace_level
>
0
)
std
::
cout
<<
"Fastest solution: "
<<
config
->
solutions
.
at
(
i
)
<<
std
::
endl
;
pc
.
insert
(
preop
.
name
(),
config
->
problem
,
config
->
solutions
.
at
(
i
));
if
(
not
results
[
i
].
has_value
())
MIGRAPHX_THROW
(
"No valid tuned compilation."
);
...
...
src/targets/gpu/device/argmax.cpp
View file @
4ea39116
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -34,9 +34,16 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
gpu
{
namespace
device
{
void
argmax
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int64_t
axis
)
void
argmax
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int64_t
axis
,
bool
select_last_index
)
{
arg_op
(
argmax_op
{},
stream
,
result
,
arg
,
axis
);
if
(
select_last_index
)
arg_op
(
argmax_op_last_index
{},
stream
,
result
,
arg
,
axis
);
else
arg_op
(
argmax_op_first_index
{},
stream
,
result
,
arg
,
axis
);
}
}
// namespace device
...
...
src/targets/gpu/device/argmin.cpp
View file @
4ea39116
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -34,9 +34,16 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
gpu
{
namespace
device
{
void
argmin
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int64_t
axis
)
void
argmin
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int64_t
axis
,
bool
select_last_index
)
{
arg_op
(
argmin_op
{},
stream
,
result
,
arg
,
axis
);
if
(
select_last_index
)
arg_op
(
argmin_op_last_index
{},
stream
,
result
,
arg
,
axis
);
else
arg_op
(
argmin_op_first_index
{},
stream
,
result
,
arg
,
axis
);
}
}
// namespace device
...
...
src/targets/gpu/device/int8_gemm_pack.cpp
deleted
100644 → 0
View file @
20128cae
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/int8_gemm_pack.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/device/tensor.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
device
{
void
int8_gemm_pack_a
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
)
{
auto
comp_shape
=
arg
.
get_shape
();
auto
out_lens
=
comp_shape
.
lens
();
auto
dim_0
=
out_lens
.
size
()
-
2
;
auto
dim_1
=
out_lens
.
size
()
-
1
;
std
::
size_t
lda
=
comp_shape
.
strides
()[
dim_0
];
std
::
size_t
m_size
=
out_lens
[
dim_0
]
*
out_lens
[
dim_1
];
visit_all
(
result
,
arg
)([
&
](
auto
output
,
auto
input
)
{
std
::
size_t
nelements
=
comp_shape
.
elements
();
auto
*
out_ptr
=
device_cast
(
output
.
data
());
auto
*
in_ptr
=
device_cast
(
input
.
data
());
visit_tensor_size
(
out_lens
.
size
(),
[
&
](
auto
out_dim
)
{
hip_tensor_descriptor
<
out_dim
>
desc
(
comp_shape
);
gs_launch
(
stream
,
nelements
,
256
)([
=
](
auto
ii
)
__device__
{
const
size_t
nb
=
4
;
auto
idx
=
desc
.
multi
(
ii
);
std
::
size_t
i_m
=
idx
[
dim_1
];
std
::
size_t
i_k
=
idx
[
dim_0
];
std
::
size_t
offset
=
ii
/
m_size
*
m_size
;
out_ptr
[
i_k
%
nb
+
(
i_m
+
(
i_k
/
nb
)
*
lda
)
*
nb
+
offset
]
=
in_ptr
[
i_m
+
i_k
*
lda
+
offset
];
});
});
});
}
void
int8_gemm_pack_b
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
)
{
auto
trans_shape
=
arg
.
get_shape
();
auto
out_lens
=
trans_shape
.
lens
();
auto
dim_0
=
trans_shape
.
lens
().
size
()
-
2
;
auto
dim_1
=
trans_shape
.
lens
().
size
()
-
1
;
std
::
size_t
ldb
=
trans_shape
.
strides
()[
dim_1
];
auto
wrap_lens
=
out_lens
;
std
::
swap
(
wrap_lens
[
dim_0
],
wrap_lens
[
dim_1
]);
shape
comp_shape
{
trans_shape
.
type
(),
wrap_lens
};
std
::
size_t
m_size
=
out_lens
[
dim_0
]
*
out_lens
[
dim_1
];
visit_all
(
result
,
arg
)([
&
](
auto
output
,
auto
input
)
{
std
::
size_t
nelements
=
comp_shape
.
elements
();
auto
*
out_ptr
=
device_cast
(
output
.
data
());
auto
*
in_ptr
=
device_cast
(
input
.
data
());
visit_tensor_size
(
out_lens
.
size
(),
[
&
](
auto
out_dim
)
{
hip_tensor_descriptor
<
out_dim
>
desc
(
comp_shape
);
gs_launch
(
stream
,
nelements
,
256
)([
=
](
auto
ii
)
__device__
{
const
size_t
nb
=
4
;
auto
idx
=
desc
.
multi
(
ii
);
std
::
size_t
i_n
=
idx
[
dim_1
];
std
::
size_t
i_k
=
idx
[
dim_0
];
std
::
size_t
offset
=
ii
/
m_size
*
m_size
;
out_ptr
[
i_k
%
nb
+
(
i_n
+
(
i_k
/
nb
)
*
ldb
)
*
nb
+
offset
]
=
in_ptr
[
i_n
+
i_k
*
ldb
+
offset
];
});
});
});
}
}
// namespace device
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/device/targets.hpp.in
View file @
4ea39116
...
...
@@ -24,7 +24,7 @@
#ifndef MIGRAPHX_GUARD_DEVICE_TARGETS_CPP
#define MIGRAPHX_GUARD_DEVICE_TARGETS_CPP
#include <migraphx/config.hpp>
#include <migraphx/
gpu/device/
config.hpp>
#include <string>
#include <vector>
...
...
@@ -34,9 +34,13 @@ namespace gpu {
namespace device {
#define MIGRAPHX_GPU_TARGETS "@GPU_TARGETS@" // NOLINT
MIGRAPHX_DEVICE_EXPORT
const std::vector<std::string>& get_targets();
MIGRAPHX_DEVICE_EXPORT
std::string get_targets_as_string();
MIGRAPHX_DEVICE_EXPORT
std::string get_device_name();
} // namespace device
...
...
src/targets/gpu/driver/compile_op.cpp
View file @
4ea39116
...
...
@@ -38,10 +38,8 @@ struct compile_op : action<compile_op>
context
ctx
;
auto
inputs
=
p
.
parse_shapes
(
v
.
at
(
"inputs"
));
auto
op
=
gpu
::
compile_op
(
v
.
at
(
"name"
).
to
<
std
::
string
>
(),
ctx
,
inputs
,
v
);
auto
[
host_time
,
device_time
]
=
time_op
(
ctx
,
op
,
inputs
,
p
.
get
(
v
,
"iterations"
,
100
));
std
::
cout
<<
op
<<
": "
<<
host_time
<<
"ms"
;
if
(
device_time
>
0
)
std
::
cout
<<
", "
<<
device_time
<<
"ms"
;
auto
t
=
time_op
(
ctx
,
op
,
inputs
,
p
.
get
(
v
,
"iterations"
,
100
));
std
::
cout
<<
op
<<
": "
<<
t
<<
"ms"
;
std
::
cout
<<
std
::
endl
;
}
};
...
...
src/targets/gpu/driver/run_op.cpp
View file @
4ea39116
...
...
@@ -43,8 +43,8 @@ struct run_op : action<run_op>
auto
op
=
make_op
(
name
);
if
(
v
.
contains
(
"fields"
))
op
.
from_value
(
v
.
at
(
"fields"
));
auto
[
host_time
,
device_time
]
=
time_op
(
ctx
,
op
,
inputs
,
p
.
get
(
v
,
"iterations"
,
100
));
std
::
cout
<<
op
<<
": "
<<
host_time
<<
"ms"
<<
std
::
endl
;
auto
t
=
time_op
(
ctx
,
op
,
inputs
,
p
.
get
(
v
,
"iterations"
,
100
));
std
::
cout
<<
op
<<
": "
<<
t
<<
"ms"
<<
std
::
endl
;
}
};
...
...
src/targets/gpu/fuse_ck.cpp
View file @
4ea39116
...
...
@@ -22,10 +22,11 @@
* THE SOFTWARE.
*/
#include <migraphx/gpu/fuse_ck.hpp>
#include <migraphx/gpu/gemm_softmax_gemm.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/gpu/device_name.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -55,7 +56,7 @@ struct ck_gemm
{
check_shapes
{
inputs
,
*
this
}.
same_ndims
();
if
(
inputs
.
size
()
<
2
)
MIGRAPHX_THROW
(
"
should have at least two inputs."
);
MIGRAPHX_THROW
(
name
()
+
":
should have at least two inputs."
);
auto
a
=
inputs
[
0
];
auto
b
=
inputs
[
1
];
for
(
const
auto
&
input
:
inputs
)
...
...
@@ -65,27 +66,35 @@ struct ck_gemm
return
r
;
return
r
.
with_type
(
mods
.
front
()
->
get_output_shapes
().
front
().
type
());
}
static
bool
is_ck_supported_type
(
shape
::
type_t
t
)
{
return
contains
({
shape
::
half_type
,
shape
::
int8_type
,
shape
::
int32_type
},
t
);
}
};
MIGRAPHX_REGISTER_OP
(
ck_gemm
);
namespace
{
bool
is_ck_supported_type
(
shape
::
type_t
t
)
struct
ck_gemm_softmax_gemm
:
gemm_softmax_gemm
{
return
contains
({
shape
::
half_type
,
shape
::
int8_type
,
shape
::
int32_type
},
t
);
}
std
::
string
name
()
const
{
return
"gpu::ck_gemm_softmax_gemm"
;
}
};
MIGRAPHX_REGISTER_OP
(
ck_gemm_softmax_gemm
);
namespace
{
MIGRAPHX_PRED_MATCHER
(
is_ck_gemm
,
instruction_ref
ins
)
{
if
(
ins
->
name
()
!=
"dot"
and
ins
->
name
()
!=
"quant_dot"
)
return
false
;
if
(
not
is_ck_supported_type
(
ins
->
get_shape
().
type
()))
if
(
not
ck_gemm
::
is_ck_supported_type
(
ins
->
get_shape
().
type
()))
return
false
;
auto
a
=
ins
->
inputs
().
front
()
->
get_shape
();
auto
b
=
ins
->
inputs
().
back
()
->
get_shape
();
auto
m
=
a
.
lens
()[
a
.
lens
().
size
()
-
2
];
auto
n
=
b
.
lens
().
back
();
auto
k
=
a
.
lens
().
back
();
auto
batch_size
=
std
::
accumulate
(
a
.
lens
().
rbegin
()
+
2
,
a
.
lens
().
rend
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
// Integer gemms must be divisible by 4 in ck
if
(
contains
({
shape
::
int8_type
,
shape
::
int32_type
},
ins
->
get_shape
().
type
()))
{
...
...
@@ -96,9 +105,17 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
if
(
k
%
4
!=
0
)
return
false
;
}
// Skipping GEMMs with a K dimension greater than 2048 is a course-grained strategy
// to avoid poor-performing GEMM kernels from CK
// To-do: Investigate a more precise strategy
auto
device_name
=
trim
(
split_string
(
get_device_name
(),
':'
).
front
());
if
(
device_name
==
"gfx940"
)
{
if
(
ins
->
get_shape
().
type
()
==
shape
::
half_type
)
{
if
(
batch_size
>=
64
)
return
m
<
2048
or
k
<=
64
or
n
<=
384
or
n
>=
2048
;
return
true
;
}
return
true
;
}
return
k
<=
2048
;
}
...
...
@@ -127,7 +144,15 @@ struct find_ck_gemm_pointwise
ins
->
get_shape
().
type
()
!=
gemm_ins
->
get_shape
().
type
())
return
;
if
(
std
::
any_of
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
[](
auto
input
)
{
return
not
is_ck_supported_type
(
input
->
get_shape
().
type
());
return
not
ck_gemm
::
is_ck_supported_type
(
input
->
get_shape
().
type
());
}))
return
;
if
(
std
::
any_of
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
[](
auto
input
)
{
return
not
input
->
inputs
().
empty
()
and
input
->
inputs
().
front
()
->
name
()
==
"capture"
;
}))
return
;
if
(
std
::
any_of
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
[](
auto
input
)
{
return
not
input
->
inputs
().
empty
()
and
input
->
inputs
().
front
()
->
name
()
==
"capture"
;
}))
return
;
assert
(
gemm_it
!=
inputs
.
end
());
...
...
@@ -152,7 +177,7 @@ struct find_ck_gemm_pointwise
struct
find_ck_gemm
{
auto
matcher
()
const
{
return
match
::
name
(
"dot"
)(
is_ck_gemm
().
bind
(
"gemm"
));
}
auto
matcher
()
const
{
return
match
::
name
(
"dot"
,
"quant_dot"
)(
is_ck_gemm
().
bind
(
"gemm"
));
}
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
...
...
@@ -161,11 +186,26 @@ struct find_ck_gemm
}
};
struct
find_ck_gemm_softmax_gemm
{
auto
matcher
()
const
{
return
match
::
name
(
"gpu::pre_gemm_softmax_gemm"
);
}
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
v
=
ins
->
get_operator
().
to_value
();
assert
(
v
.
contains
(
"scale"
));
auto
scale
=
v
.
at
(
"scale"
).
to
<
float
>
();
mpm
.
get_module
().
replace_instruction
(
ins
,
ck_gemm_softmax_gemm
{
migraphx
::
make_op
(
"dot"
),
scale
},
ins
->
inputs
());
}
};
}
// namespace
void
fuse_ck
::
apply
(
module_pass_manager
&
mpm
)
const
{
match
::
find_matches
(
mpm
,
find_ck_gemm_pointwise
{});
match
::
find_matches
(
mpm
,
find_ck_gemm_softmax_gemm
{},
find_ck_gemm_pointwise
{});
match
::
find_matches
(
mpm
,
find_ck_gemm
{});
}
...
...
Prev
1
2
3
4
5
6
7
8
9
10
…
16
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment