Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
a6fa5e4b
Unverified
Commit
a6fa5e4b
authored
Oct 23, 2023
by
Chris Austen
Committed by
GitHub
Oct 23, 2023
Browse files
Merge branch 'develop' into enable_navi_32_ci
parents
b7a7cd3c
7604ecf5
Changes
247
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
472 additions
and
178 deletions
+472
-178
src/targets/gpu/CMakeLists.txt
src/targets/gpu/CMakeLists.txt
+9
-1
src/targets/gpu/argmax.cpp
src/targets/gpu/argmax.cpp
+3
-2
src/targets/gpu/argmin.cpp
src/targets/gpu/argmin.cpp
+3
-2
src/targets/gpu/compile_hip.cpp
src/targets/gpu/compile_hip.cpp
+5
-7
src/targets/gpu/compile_hip_code_object.cpp
src/targets/gpu/compile_hip_code_object.cpp
+28
-20
src/targets/gpu/compile_ops.cpp
src/targets/gpu/compile_ops.cpp
+22
-7
src/targets/gpu/device/argmax.cpp
src/targets/gpu/device/argmax.cpp
+10
-3
src/targets/gpu/device/argmin.cpp
src/targets/gpu/device/argmin.cpp
+10
-3
src/targets/gpu/device/include/migraphx/gpu/device/launch.hpp
...targets/gpu/device/include/migraphx/gpu/device/launch.hpp
+8
-0
src/targets/gpu/device/targets.hpp.in
src/targets/gpu/device/targets.hpp.in
+5
-1
src/targets/gpu/driver/compile_op.cpp
src/targets/gpu/driver/compile_op.cpp
+2
-4
src/targets/gpu/driver/run_op.cpp
src/targets/gpu/driver/run_op.cpp
+2
-2
src/targets/gpu/fuse_ck.cpp
src/targets/gpu/fuse_ck.cpp
+54
-14
src/targets/gpu/fuse_mlir.cpp
src/targets/gpu/fuse_mlir.cpp
+99
-82
src/targets/gpu/hiprtc/main.cpp
src/targets/gpu/hiprtc/main.cpp
+1
-0
src/targets/gpu/include/migraphx/gpu/ck.hpp
src/targets/gpu/include/migraphx/gpu/ck.hpp
+165
-0
src/targets/gpu/include/migraphx/gpu/compile_hip.hpp
src/targets/gpu/include/migraphx/gpu/compile_hip.hpp
+1
-4
src/targets/gpu/include/migraphx/gpu/context.hpp
src/targets/gpu/include/migraphx/gpu/context.hpp
+3
-20
src/targets/gpu/include/migraphx/gpu/convolution.hpp
src/targets/gpu/include/migraphx/gpu/convolution.hpp
+3
-3
src/targets/gpu/include/migraphx/gpu/device/arg_op.hpp
src/targets/gpu/include/migraphx/gpu/device/arg_op.hpp
+39
-3
No files found.
src/targets/gpu/CMakeLists.txt
View file @
a6fa5e4b
...
...
@@ -48,10 +48,18 @@ else()
set
(
MIGRAPHX_USE_HIPRTC ON CACHE BOOL
"Use hipRTC APIs"
)
endif
()
include
(
Embed
)
file
(
GLOB KERNEL_FILES CONFIGURE_DEPENDS
${
CMAKE_CURRENT_SOURCE_DIR
}
/kernels/include/migraphx/kernels/*.hpp
)
message
(
STATUS
"KERNEL_FILES:
${
KERNEL_FILES
}
"
)
if
(
WIN32
)
# TODO: re-enable when CK is ported to Windows
list
(
REMOVE_ITEM KERNEL_FILES
${
CMAKE_CURRENT_SOURCE_DIR
}
/kernels/include/migraphx/kernels/ck_gemm.hpp
${
CMAKE_CURRENT_SOURCE_DIR
}
/kernels/include/migraphx/kernels/ck.hpp
)
endif
()
include
(
Embed
)
add_embed_library
(
migraphx_kernels
${
KERNEL_FILES
}
RELATIVE
${
CMAKE_CURRENT_SOURCE_DIR
}
/kernels/include/
)
configure_file
(
device/targets.hpp.in include/migraphx/gpu/device/targets.hpp
)
...
...
src/targets/gpu/argmax.cpp
View file @
a6fa5e4b
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -40,7 +40,8 @@ argument hip_argmax::compute(context& ctx, const shape&, const std::vector<argum
{
auto
n_dim
=
args
.
front
().
get_shape
().
lens
().
size
();
int64_t
tuned_axis
=
tune_axis
(
n_dim
,
op
.
axis
,
op
.
name
());
device
::
argmax
(
ctx
.
get_stream
().
get
(),
args
.
back
(),
args
.
front
(),
tuned_axis
);
device
::
argmax
(
ctx
.
get_stream
().
get
(),
args
.
back
(),
args
.
front
(),
tuned_axis
,
op
.
select_last_index
);
return
args
.
back
();
}
...
...
src/targets/gpu/argmin.cpp
View file @
a6fa5e4b
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -40,7 +40,8 @@ argument hip_argmin::compute(context& ctx, const shape&, const std::vector<argum
{
auto
n_dim
=
args
.
front
().
get_shape
().
lens
().
size
();
int64_t
tuned_axis
=
tune_axis
(
n_dim
,
op
.
axis
,
op
.
name
());
device
::
argmin
(
ctx
.
get_stream
().
get
(),
args
.
back
(),
args
.
front
(),
tuned_axis
);
device
::
argmin
(
ctx
.
get_stream
().
get
(),
args
.
back
(),
args
.
front
(),
tuned_axis
,
op
.
select_last_index
);
return
args
.
back
();
}
...
...
src/targets/gpu/compile_hip.cpp
View file @
a6fa5e4b
...
...
@@ -28,6 +28,7 @@
#include <migraphx/env.hpp>
#include <cassert>
#include <iostream>
#include <deque>
#ifdef MIGRAPHX_USE_HIPRTC
#include <hip/hiprtc.h>
...
...
@@ -92,7 +93,7 @@ struct hiprtc_program
{
struct
string_array
{
std
::
vector
<
std
::
string
>
strings
{};
std
::
deque
<
std
::
string
>
strings
{};
std
::
vector
<
const
char
*>
c_strs
{};
string_array
()
{}
...
...
@@ -209,7 +210,6 @@ std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_sr
options
.
push_back
(
"-Wno-gnu-line-marker"
);
options
.
push_back
(
"-Wno-old-style-cast"
);
}
if
(
enabled
(
MIGRAPHX_GPU_DEBUG
{}))
options
.
push_back
(
"-DMIGRAPHX_DEBUG"
);
if
(
std
::
none_of
(
options
.
begin
(),
options
.
end
(),
[](
const
std
::
string
&
s
)
{
...
...
@@ -248,7 +248,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
{
if
(
src
.
path
.
extension
()
!=
".cpp"
)
continue
;
std
::
cout
<<
std
::
string
(
src
.
content
.
first
,
src
.
len
()
)
<<
std
::
endl
;
std
::
cout
<<
std
::
string
(
src
.
content
)
<<
std
::
endl
;
}
}
auto
p
=
dynamic_loader
::
path
(
&
compile_hip_src_with_hiprtc
);
...
...
@@ -338,7 +338,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
{
if
(
src
.
path
.
extension
()
!=
".cpp"
)
continue
;
std
::
cout
<<
std
::
string
(
src
.
content
.
first
,
src
.
len
()
)
<<
std
::
endl
;
std
::
cout
<<
std
::
string
(
src
.
content
)
<<
std
::
endl
;
}
}
...
...
@@ -359,9 +359,7 @@ bool hip_has_flags(const std::vector<std::string>& flags)
join_strings
(
flags
,
" "
)
+
" -x hip -c --offload-arch=gfx900 --cuda-device-only"
;
std
::
string
src
;
src_file
input
;
input
.
path
=
"main.cpp"
;
input
.
content
=
std
::
make_pair
(
src
.
data
(),
src
.
data
()
+
src
.
size
());
src_file
input
{
"main.cpp"
,
src
};
try
{
...
...
src/targets/gpu/compile_hip_code_object.cpp
View file @
a6fa5e4b
...
...
@@ -139,6 +139,12 @@ void hip_compile_options::set_launch_params(
global
=
compute_global
(
local
);
}
/// Reports whether the HIP compiler accepts the `-fno-offload-uniform-block`
/// flag. The probe (`hip_has_flags`) runs on the first call only; its answer
/// is kept in a function-local static and reused on every later call.
static bool hip_accept_non_uniform_wg()
{
    static const bool accepts_flag = hip_has_flags({"-fno-offload-uniform-block"});
    return accepts_flag;
}
std
::
function
<
std
::
size_t
(
std
::
size_t
local
)
>
compute_global_for
(
context
&
ctx
,
std
::
size_t
n
,
std
::
size_t
over
)
{
...
...
@@ -146,13 +152,14 @@ compute_global_for(context& ctx, std::size_t n, std::size_t over)
std
::
size_t
max_global
=
ctx
.
get_current_device
().
get_cu_count
()
*
ctx
.
get_current_device
().
get_max_workitems_per_cu
();
return
[
n
,
over
,
max_global
](
std
::
size_t
local
)
{
// hip require global workitems multiple of local workitems. It may degrade performance.
// [TODO]: consider adding "fno-hip-uniform-block" flag when it becomes available.
// https://reviews.llvm.org/D155213
std
::
size_t
num_elements
=
((
n
+
local
-
1
)
/
local
)
*
local
;
std
::
size_t
groups
=
(
num_elements
+
local
-
1
)
/
local
;
std
::
size_t
max_blocks
=
max_global
/
local
;
std
::
size_t
nglobal
=
std
::
min
(
max_blocks
*
over
,
groups
)
*
local
;
std
::
size_t
num_elements
=
n
;
if
(
not
hip_accept_non_uniform_wg
())
{
num_elements
=
(
1
+
(
n
-
1
)
/
local
)
*
local
;
}
std
::
size_t
groups
=
1
+
(
num_elements
-
1
)
/
local
;
std
::
size_t
max_blocks
=
max_global
/
local
;
std
::
size_t
nglobal
=
std
::
min
(
max_blocks
*
over
,
groups
)
*
local
;
return
std
::
min
(
nglobal
,
num_elements
);
};
}
...
...
@@ -172,21 +179,22 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
assert
(
options
.
inputs
.
size
()
==
options
.
virtual_inputs
.
size
()
or
options
.
virtual_inputs
.
empty
());
std
::
vector
<
src_file
>
srcs
=
options
.
additional_src_files
;
std
::
transform
(
migraphx_kernels
().
begin
(),
migraphx_kernels
().
end
(),
std
::
back_inserter
(
srcs
),
[](
auto
&&
p
)
{
auto
&&
name
=
p
.
first
;
auto
&&
c
=
p
.
second
;
auto
path
=
name
;
return
src_file
{
path
,
c
};
});
srcs
.
push_back
(
src_file
{
fs
::
path
{
"main.cpp"
},
std
::
make_pair
(
content
.
data
(),
content
.
data
()
+
content
.
size
())});
static
auto
kernels
{
::
migraphx_kernels
()};
std
::
transform
(
kernels
.
begin
(),
kernels
.
end
(),
std
::
back_inserter
(
srcs
),
[](
const
std
::
pair
<
std
::
string_view
,
std
::
string_view
>&
elem
)
{
return
src_file
{
elem
};
});
srcs
.
emplace_back
(
"main.cpp"
,
content
);
auto
args_hpp
=
generate_args_hpp
(
options
.
virtual_inputs
.
empty
()
?
options
.
inputs
:
options
.
virtual_inputs
);
srcs
.
push_back
(
src_file
{
fs
::
path
{
"args.hpp"
},
std
::
make_pair
(
args_hpp
.
data
(),
args_hpp
.
data
()
+
args_hpp
.
size
())});
srcs
.
emplace_back
(
"args.hpp"
,
args_hpp
);
if
(
options
.
global
%
options
.
local
!=
0
and
hip_accept_non_uniform_wg
())
options
.
params
+=
" -fno-offload-uniform-block"
;
else
assert
(
options
.
global
%
options
.
local
==
0
);
options
.
params
+=
" -DMIGRAPHX_NGLOBAL="
+
std
::
to_string
(
options
.
global
);
options
.
params
+=
" -DMIGRAPHX_NLOCAL="
+
std
::
to_string
(
options
.
local
);
options
.
params
+=
" "
+
join_strings
(
compiler_warnings
(),
" "
);
...
...
src/targets/gpu/compile_ops.cpp
View file @
a6fa5e4b
...
...
@@ -37,6 +37,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
gpu
{
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_GPU_COMPILE_PARALLEL
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_BENCHMARKING
);
struct
precompile_op
{
...
...
@@ -179,15 +180,29 @@ struct compile_plan
MIGRAPHX_THROW
(
"Multiple kernels without config"
);
std
::
cout
<<
"Benchmarking "
<<
preop
.
name
()
<<
": "
<<
results
.
size
()
<<
" configs"
<<
std
::
endl
;
if
(
enabled
(
MIGRAPHX_TRACE_BENCHMARKING
{}))
std
::
cout
<<
"Problem: "
<<
config
->
problem
<<
std
::
endl
;
std
::
vector
<
double
>
times
;
times
.
reserve
(
results
.
size
());
std
::
transform
(
results
.
begin
(),
results
.
end
(),
std
::
back_inserter
(
times
),
[
&
](
const
auto
&
cr
)
{
if
(
not
cr
.
has_value
())
return
std
::
numeric_limits
<
double
>::
max
();
return
time_op
(
*
ctx
,
cr
->
replace
.
code_object
,
to_shapes
(
cr
->
ins
->
inputs
()),
20
)
.
first
;
});
std
::
transform
(
results
.
begin
(),
results
.
end
(),
config
->
solutions
.
begin
(),
std
::
back_inserter
(
times
),
[
&
](
const
auto
&
cr
,
const
auto
&
solution
)
{
if
(
enabled
(
MIGRAPHX_TRACE_BENCHMARKING
{}))
std
::
cout
<<
"Benchmarking solution: "
<<
solution
<<
std
::
endl
;
if
(
not
cr
.
has_value
())
{
if
(
enabled
(
MIGRAPHX_TRACE_BENCHMARKING
{}))
std
::
cout
<<
"No binary"
<<
std
::
endl
;
return
std
::
numeric_limits
<
double
>::
max
();
}
auto
t
=
time_op
(
*
ctx
,
cr
->
replace
.
code_object
,
to_shapes
(
cr
->
ins
->
inputs
()),
20
);
if
(
enabled
(
MIGRAPHX_TRACE_BENCHMARKING
{}))
std
::
cout
<<
t
<<
"ms"
<<
std
::
endl
;
return
t
;
});
auto
i
=
std
::
distance
(
times
.
begin
(),
std
::
min_element
(
times
.
begin
(),
times
.
end
()));
std
::
cout
<<
"Fastest solution: "
<<
config
->
solutions
.
at
(
i
)
<<
std
::
endl
;
pc
.
insert
(
preop
.
name
(),
config
->
problem
,
config
->
solutions
.
at
(
i
));
...
...
src/targets/gpu/device/argmax.cpp
View file @
a6fa5e4b
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -34,9 +34,16 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
gpu
{
namespace
device
{
void
argmax
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int64_t
axis
)
void
argmax
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int64_t
axis
,
bool
select_last_index
)
{
arg_op
(
argmax_op
{},
stream
,
result
,
arg
,
axis
);
if
(
select_last_index
)
arg_op
(
argmax_op_last_index
{},
stream
,
result
,
arg
,
axis
);
else
arg_op
(
argmax_op_first_index
{},
stream
,
result
,
arg
,
axis
);
}
}
// namespace device
...
...
src/targets/gpu/device/argmin.cpp
View file @
a6fa5e4b
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -34,9 +34,16 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
gpu
{
namespace
device
{
void
argmin
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int64_t
axis
)
void
argmin
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int64_t
axis
,
bool
select_last_index
)
{
arg_op
(
argmin_op
{},
stream
,
result
,
arg
,
axis
);
if
(
select_last_index
)
arg_op
(
argmin_op_last_index
{},
stream
,
result
,
arg
,
axis
);
else
arg_op
(
argmin_op_first_index
{},
stream
,
result
,
arg
,
axis
);
}
}
// namespace device
...
...
src/targets/gpu/device/include/migraphx/gpu/device/launch.hpp
View file @
a6fa5e4b
...
...
@@ -81,6 +81,14 @@ inline auto launch(hipStream_t stream, index_int global, index_int local)
using
f_type
=
decltype
(
f
);
dim3
nblocks
(
global
/
local
);
dim3
nthreads
(
local
);
/*
hipGetLastError() returns error for the first failed HIP call that happened previously.
MIGraphX calls into various backend libraries and failed HIP calls can also happen there.
Calling hipGetLastError() would reset error code to hipSuccess, so that inside MIGraphX
failed call to hipLaunchKernelGGL() can be captured.
*/
hipError_t
flush_call
=
hipGetLastError
();
(
void
)(
flush_call
);
// cppcheck-suppress UseDeviceLaunch
hipLaunchKernelGGL
((
launcher
<
f_type
>
),
nblocks
,
nthreads
,
0
,
stream
,
f
);
hipError_t
kernel_launch_status
=
hipGetLastError
();
...
...
src/targets/gpu/device/targets.hpp.in
View file @
a6fa5e4b
...
...
@@ -24,7 +24,7 @@
#ifndef MIGRAPHX_GUARD_DEVICE_TARGETS_CPP
#define MIGRAPHX_GUARD_DEVICE_TARGETS_CPP
#include <migraphx/config.hpp>
#include <migraphx/
gpu/device/
config.hpp>
#include <string>
#include <vector>
...
...
@@ -34,9 +34,13 @@ namespace gpu {
namespace device {
#define MIGRAPHX_GPU_TARGETS "@GPU_TARGETS@" // NOLINT
MIGRAPHX_DEVICE_EXPORT
const std::vector<std::string>& get_targets();
MIGRAPHX_DEVICE_EXPORT
std::string get_targets_as_string();
MIGRAPHX_DEVICE_EXPORT
std::string get_device_name();
} // namespace device
...
...
src/targets/gpu/driver/compile_op.cpp
View file @
a6fa5e4b
...
...
@@ -38,10 +38,8 @@ struct compile_op : action<compile_op>
context
ctx
;
auto
inputs
=
p
.
parse_shapes
(
v
.
at
(
"inputs"
));
auto
op
=
gpu
::
compile_op
(
v
.
at
(
"name"
).
to
<
std
::
string
>
(),
ctx
,
inputs
,
v
);
auto
[
host_time
,
device_time
]
=
time_op
(
ctx
,
op
,
inputs
,
p
.
get
(
v
,
"iterations"
,
100
));
std
::
cout
<<
op
<<
": "
<<
host_time
<<
"ms"
;
if
(
device_time
>
0
)
std
::
cout
<<
", "
<<
device_time
<<
"ms"
;
auto
t
=
time_op
(
ctx
,
op
,
inputs
,
p
.
get
(
v
,
"iterations"
,
100
));
std
::
cout
<<
op
<<
": "
<<
t
<<
"ms"
;
std
::
cout
<<
std
::
endl
;
}
};
...
...
src/targets/gpu/driver/run_op.cpp
View file @
a6fa5e4b
...
...
@@ -43,8 +43,8 @@ struct run_op : action<run_op>
auto
op
=
make_op
(
name
);
if
(
v
.
contains
(
"fields"
))
op
.
from_value
(
v
.
at
(
"fields"
));
auto
[
host_time
,
device_time
]
=
time_op
(
ctx
,
op
,
inputs
,
p
.
get
(
v
,
"iterations"
,
100
));
std
::
cout
<<
op
<<
": "
<<
host_time
<<
"ms"
<<
std
::
endl
;
auto
t
=
time_op
(
ctx
,
op
,
inputs
,
p
.
get
(
v
,
"iterations"
,
100
));
std
::
cout
<<
op
<<
": "
<<
t
<<
"ms"
<<
std
::
endl
;
}
};
...
...
src/targets/gpu/fuse_ck.cpp
View file @
a6fa5e4b
...
...
@@ -22,10 +22,11 @@
* THE SOFTWARE.
*/
#include <migraphx/gpu/fuse_ck.hpp>
#include <migraphx/gpu/gemm_softmax_gemm.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/gpu/device_name.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -55,7 +56,7 @@ struct ck_gemm
{
check_shapes
{
inputs
,
*
this
}.
same_ndims
();
if
(
inputs
.
size
()
<
2
)
MIGRAPHX_THROW
(
"
should have at least two inputs."
);
MIGRAPHX_THROW
(
name
()
+
":
should have at least two inputs."
);
auto
a
=
inputs
[
0
];
auto
b
=
inputs
[
1
];
for
(
const
auto
&
input
:
inputs
)
...
...
@@ -65,27 +66,35 @@ struct ck_gemm
return
r
;
return
r
.
with_type
(
mods
.
front
()
->
get_output_shapes
().
front
().
type
());
}
static
bool
is_ck_supported_type
(
shape
::
type_t
t
)
{
return
contains
({
shape
::
half_type
,
shape
::
int8_type
,
shape
::
int32_type
},
t
);
}
};
MIGRAPHX_REGISTER_OP
(
ck_gemm
);
namespace
{
bool
is_ck_supported_type
(
shape
::
type_t
t
)
struct
ck_gemm_softmax_gemm
:
gemm_softmax_gemm
{
return
contains
({
shape
::
half_type
,
shape
::
int8_type
,
shape
::
int32_type
},
t
);
}
std
::
string
name
()
const
{
return
"gpu::ck_gemm_softmax_gemm"
;
}
};
MIGRAPHX_REGISTER_OP
(
ck_gemm_softmax_gemm
);
namespace
{
MIGRAPHX_PRED_MATCHER
(
is_ck_gemm
,
instruction_ref
ins
)
{
if
(
ins
->
name
()
!=
"dot"
and
ins
->
name
()
!=
"quant_dot"
)
return
false
;
if
(
not
is_ck_supported_type
(
ins
->
get_shape
().
type
()))
if
(
not
ck_gemm
::
is_ck_supported_type
(
ins
->
get_shape
().
type
()))
return
false
;
auto
a
=
ins
->
inputs
().
front
()
->
get_shape
();
auto
b
=
ins
->
inputs
().
back
()
->
get_shape
();
auto
m
=
a
.
lens
()[
a
.
lens
().
size
()
-
2
];
auto
n
=
b
.
lens
().
back
();
auto
k
=
a
.
lens
().
back
();
auto
batch_size
=
std
::
accumulate
(
a
.
lens
().
rbegin
()
+
2
,
a
.
lens
().
rend
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
// Integer gemms must be divisible by 4 in ck
if
(
contains
({
shape
::
int8_type
,
shape
::
int32_type
},
ins
->
get_shape
().
type
()))
{
...
...
@@ -96,9 +105,17 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
if
(
k
%
4
!=
0
)
return
false
;
}
// Skipping GEMMs with a K dimension greater than 2048 is a coarse-grained strategy
// to avoid poor-performing GEMM kernels from CK
// To-do: Investigate a more precise strategy
auto
device_name
=
trim
(
split_string
(
get_device_name
(),
':'
).
front
());
if
(
device_name
==
"gfx940"
)
{
if
(
ins
->
get_shape
().
type
()
==
shape
::
half_type
)
{
if
(
batch_size
>=
64
)
return
m
<
2048
or
k
<=
64
or
n
<=
384
or
n
>=
2048
;
return
true
;
}
return
true
;
}
return
k
<=
2048
;
}
...
...
@@ -127,7 +144,15 @@ struct find_ck_gemm_pointwise
ins
->
get_shape
().
type
()
!=
gemm_ins
->
get_shape
().
type
())
return
;
if
(
std
::
any_of
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
[](
auto
input
)
{
return
not
is_ck_supported_type
(
input
->
get_shape
().
type
());
return
not
ck_gemm
::
is_ck_supported_type
(
input
->
get_shape
().
type
());
}))
return
;
if
(
std
::
any_of
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
[](
auto
input
)
{
return
not
input
->
inputs
().
empty
()
and
input
->
inputs
().
front
()
->
name
()
==
"capture"
;
}))
return
;
if
(
std
::
any_of
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
[](
auto
input
)
{
return
not
input
->
inputs
().
empty
()
and
input
->
inputs
().
front
()
->
name
()
==
"capture"
;
}))
return
;
assert
(
gemm_it
!=
inputs
.
end
());
...
...
@@ -152,7 +177,7 @@ struct find_ck_gemm_pointwise
struct
find_ck_gemm
{
auto
matcher
()
const
{
return
match
::
name
(
"dot"
)(
is_ck_gemm
().
bind
(
"gemm"
));
}
auto
matcher
()
const
{
return
match
::
name
(
"dot"
,
"quant_dot"
)(
is_ck_gemm
().
bind
(
"gemm"
));
}
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
...
...
@@ -161,11 +186,26 @@ struct find_ck_gemm
}
};
/// Rewrites every "gpu::pre_gemm_softmax_gemm" instruction into a
/// `ck_gemm_softmax_gemm` instruction that keeps the same inputs.
struct find_ck_gemm_softmax_gemm
{
    /// Matches instructions named "gpu::pre_gemm_softmax_gemm".
    auto matcher() const { return match::name("gpu::pre_gemm_softmax_gemm"); }

    /// Replaces the matched instruction. The "scale" field is read from the
    /// matched operator's serialized value and forwarded to the new operator,
    /// which is built around a plain "dot" op.
    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto matched   = r.result;
        auto op_fields = matched->get_operator().to_value();
        // The matched operator is expected to always carry a "scale" field.
        assert(op_fields.contains("scale"));
        auto scale_factor = op_fields.at("scale").to<float>();
        mpm.get_module().replace_instruction(
            matched,
            ck_gemm_softmax_gemm{migraphx::make_op("dot"), scale_factor},
            matched->inputs());
    }
};
}
// namespace
void
fuse_ck
::
apply
(
module_pass_manager
&
mpm
)
const
{
match
::
find_matches
(
mpm
,
find_ck_gemm_pointwise
{});
match
::
find_matches
(
mpm
,
find_ck_gemm_softmax_gemm
{},
find_ck_gemm_pointwise
{});
match
::
find_matches
(
mpm
,
find_ck_gemm
{});
}
...
...
src/targets/gpu/fuse_mlir.cpp
View file @
a6fa5e4b
...
...
@@ -36,24 +36,14 @@ struct module;
namespace
gpu
{
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_ENABLE_MLIR
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_ENABLE_EXTRA_MLIR
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_DISABLE_MLIR
);
bool
mlir_enabled
()
{
#ifdef MIGRAPHX_MLIR
const
bool
mlir_enabled
=
enabled
(
MIGRAPHX_ENABLE_MLIR
{});
if
(
mlir_enabled
)
{
return
true
;
}
else
{
std
::
cerr
<<
"WARNING: MIGraphX built with MLIR but it is not enabled. Please set the env "
"var MIGRAPHX_ENABLE_MLIR to use MLIR kernel generator."
<<
std
::
endl
;
return
false
;
}
const
bool
mlir_disabled
=
enabled
(
MIGRAPHX_DISABLE_MLIR
{});
return
not
mlir_disabled
;
#else
return
false
;
#endif
...
...
@@ -131,9 +121,16 @@ fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op)
for
(
instruction_ref
input
:
gemm_based_op
->
inputs
())
{
std
::
vector
<
operation
>
op_stream
;
while
(
contains
({
"slice"
,
"transpose"
,
"contiguous"
,
"reshape"
},
input
->
name
()))
while
(
contains
(
{
"slice"
,
"transpose"
,
"contiguous"
,
"reshape"
,
"squeeze"
,
"flatten"
,
"unsqueeze"
},
input
->
name
()))
{
op_stream
.
push_back
(
input
->
get_operator
());
operation
op
=
input
->
get_operator
();
if
(
contains
({
"squeeze"
,
"flatten"
,
"unsqueeze"
},
input
->
name
()))
{
op
=
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
input
->
get_shape
().
lens
()}});
}
op_stream
.
push_back
(
op
);
input
=
input
->
inputs
().
at
(
0
);
}
top_inputs
.
push_back
(
input
);
...
...
@@ -150,27 +147,72 @@ fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op)
return
{
new_gemm_based_op
,
top_inputs
};
}
MIGRAPHX_PRED_MATCHER
(
is_mlir_conv
,
instruction_ref
ins
)
enum
class
mlir_mode
{
if
(
ins
->
name
()
!=
"convolution"
and
ins
->
name
()
!=
"quant_convolution"
)
return
false
;
value
v
=
ins
->
get_operator
().
to_value
();
auto
group
=
v
.
at
(
"group"
).
to
<
int
>
();
if
(
group
!=
1
)
return
false
;
// Avoid MLIR assertion: Index < Length && "Invalid index!"
if
(
ins
->
get_shape
().
lens
().
size
()
!=
4
)
return
false
;
return
true
;
all
,
fast
,
int8
,
none
};
/// Builds a matcher predicate selecting the GEMM instructions ("dot" and
/// "quant_dot") that should be handled by MLIR under the given mode.
///
/// - `mlir_mode::none`: nothing matches.
/// - `mlir_mode::fast`: a GEMM matches only when its K dimension (the last
///   dimension of the first input) is at most 2048.
/// - any other mode: every "dot"/"quant_dot" instruction matches.
auto is_mlir_dot(mlir_mode mode)
{
    return match::make_basic_pred_matcher([=](instruction_ref ins) {
        if(mode == mlir_mode::none)
            return false;
        if(ins->name() != "dot" and ins->name() != "quant_dot")
            return false;
        if(mode != mlir_mode::fast)
            return true;
        // Only K is needed for the heuristic below, so read it directly from
        // the first input instead of copying both operand shapes.
        const auto k = ins->inputs().front()->get_shape().lens().back();
        // Skipping GEMMs with a K dimension greater than 2048 is a coarse-grained strategy
        // to avoid poor-performing GEMM kernels from MLIR
        // To-do: Investigate a more precise strategy
        return k <= 2048;
    });
}
/// Builds a matcher predicate selecting the convolution instructions
/// ("convolution" and "quant_convolution") that should be handled by MLIR
/// under the given mode.
///
/// Regardless of mode, a convolution is rejected when `mode` is
/// `mlir_mode::none`, when its "group" attribute is not 1, or when its output
/// is not 4-dimensional. int8 outputs are then always accepted. For other
/// types: `mlir_mode::int8` rejects, `mlir_mode::all` accepts, and the
/// remaining mode accepts unless the weights (input 1) are 4-D with equal
/// last two dimensions whose size is a multiple of 3.
auto is_mlir_conv(mlir_mode mode)
{
    return match::make_basic_pred_matcher([=](instruction_ref ins) {
        if(mode == mlir_mode::none)
            return false;
        const bool is_conv_op =
            ins->name() == "convolution" or ins->name() == "quant_convolution";
        if(not is_conv_op)
            return false;
        value attrs = ins->get_operator().to_value();
        if(attrs.at("group").to<int>() != 1)
            return false;
        // Avoid MLIR assertion: Index < Length && "Invalid index!"
        if(ins->get_shape().lens().size() != 4)
            return false;
        if(ins->get_shape().type() == shape::int8_type)
            return true;
        if(mode == mlir_mode::int8)
            return false;
        if(mode == mlir_mode::all)
            return true;
        // Remaining mode: decide from the weight shape.
        auto weights           = ins->inputs().at(1)->get_shape();
        const auto& weight_dim = weights.lens();
        if(weight_dim.size() != 4)
            return true;
        return weight_dim[2] != weight_dim[3] or (weight_dim[3] % 3) != 0;
    });
}
struct
find_mlir_fused_ops
{
mlir_mode
conv_mode
=
mlir_mode
::
none
;
mlir_mode
dot_mode
=
mlir_mode
::
none
;
auto
matcher
()
const
{
auto
dot_or_conv
=
match
::
skip
(
match
::
name
(
"contiguous"
))(
match
::
any_of
(
match
::
name
(
"dot"
),
match
::
name
(
"quant_dot"
),
is_mlir_conv
())
.
bind
(
"gemm_based_op"
));
match
::
any_of
(
is_mlir_dot
(
dot_mode
),
is_mlir_conv
(
conv_mode
)).
bind
(
"gemm_based_op"
));
return
match
::
name
(
"pointwise"
)(
match
::
any_of
[
match
::
inputs
()](
dot_or_conv
.
bind
(
"x"
)));
}
...
...
@@ -302,8 +344,11 @@ struct find_mlir_fused_ops
}
};
template
<
auto
Matcher
>
struct
find_mlir_standalone_op
{
mlir_mode
mode
=
mlir_mode
::
none
;
auto
matcher
()
const
{
return
Matcher
(
mode
);
}
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
auto
conv_based_op
=
r
.
result
;
...
...
@@ -325,15 +370,8 @@ struct find_mlir_standalone_op
}
};
struct
find_mlir_standalone_convolution_op
:
find_mlir_standalone_op
{
auto
matcher
()
const
{
return
is_mlir_conv
;
}
};
struct
find_mlir_standalone_dot_op
:
find_mlir_standalone_op
{
auto
matcher
()
const
{
return
match
::
any_of
(
match
::
name
(
"dot"
),
match
::
name
(
"quant_dot"
));
}
};
using
find_mlir_standalone_convolution_op
=
find_mlir_standalone_op
<&
is_mlir_conv
>
;
using
find_mlir_standalone_dot_op
=
find_mlir_standalone_op
<&
is_mlir_dot
>
;
/**
* @brief Declares a new MIGraphX environment variable which forces to generate
...
...
@@ -347,44 +385,15 @@ struct find_mlir_standalone_dot_op : find_mlir_standalone_op
* intended to be primarily used by rocMLIR developers.
*/
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_MLIR_USE_SPECIFIC_OPS
);
bool
is_self_decide
()
{
return
string_value_of
(
MIGRAPHX_MLIR_USE_SPECIFIC_OPS
{},
""
).
empty
();
}
bool
is_requested
(
std
::
string_view
option
)
bool
is_requested
(
std
::
string_view
option
,
bool
fallback
=
false
)
{
assert
(
not
is_self_decide
());
auto
string_value
=
string_value_of
(
MIGRAPHX_MLIR_USE_SPECIFIC_OPS
{},
""
);
if
(
string_value
.
empty
())
return
fallback
;
const
auto
options
=
split_string
(
string_value
,
','
);
return
contains
(
options
,
option
);
}
bool
is_enabled
(
std
::
string_view
op_name
,
context
*
ctx
)
{
if
(
is_self_decide
())
{
if
(
op_name
==
"fused"
)
{
return
true
;
}
else
if
(
op_name
==
"convolution"
or
op_name
==
"quant_convolution"
)
{
if
(
ctx
==
nullptr
)
{
return
false
;
}
else
{
const
auto
&
device
=
ctx
->
get_current_device
();
const
std
::
string
navi_family
{
"gfx110"
};
return
starts_with
(
device
.
get_gfx_name
(),
navi_family
);
}
}
else
{
return
false
;
}
}
return
is_requested
(
op_name
);
}
}
// namespace
#endif // MIGRAPHX_MLIR
...
...
@@ -392,20 +401,28 @@ bool is_enabled(std::string_view op_name, context* ctx)
void
fuse_mlir
::
apply
(
module_pass_manager
&
mpm
)
const
{
#ifdef MIGRAPHX_MLIR
if
(
is_enabled
(
"fused"
,
this
->
ctx
))
{
match
::
find_matches
(
mpm
,
find_mlir_fused_ops
{});
}
const
auto
&
device_name
=
ctx
==
nullptr
?
""
:
ctx
->
get_current_device
().
get_gfx_name
();
const
bool
is_navi
=
starts_with
(
device_name
,
"gfx110"
);
if
(
is_enabled
(
"convolution"
,
this
->
ctx
))
{
match
::
find_matches
(
mpm
,
find_mlir_standalone_convolution_op
{});
}
auto
get_mode
=
[
&
](
std
::
string_view
option
,
mlir_mode
m1
,
mlir_mode
m2
=
mlir_mode
::
fast
)
{
if
(
is_requested
(
option
))
return
mlir_mode
::
all
;
if
(
is_navi
)
return
mlir_mode
::
all
;
return
std
::
max
(
m1
,
m2
);
};
if
(
is_enabled
(
"dot"
,
this
->
ctx
))
{
match
::
find_matches
(
mpm
,
find_mlir_standalone_dot_op
{});
}
mlir_mode
mode
=
(
enabled
(
MIGRAPHX_ENABLE_EXTRA_MLIR
{})
or
enable_extra
)
?
mlir_mode
::
fast
:
mlir_mode
::
none
;
match
::
find_matches
(
mpm
,
find_mlir_fused_ops
{.
conv_mode
=
get_mode
(
"fused"
,
mlir_mode
::
fast
),
.
dot_mode
=
get_mode
(
"fused"
,
mode
)});
match
::
find_matches
(
mpm
,
find_mlir_standalone_convolution_op
{
get_mode
(
"convolution"
,
mlir_mode
::
int8
)},
find_mlir_standalone_dot_op
{
get_mode
(
"dot"
,
mlir_mode
::
none
)});
#else
(
void
)
mpm
;
#endif
...
...
src/targets/gpu/hiprtc/main.cpp
View file @
a6fa5e4b
...
...
@@ -27,6 +27,7 @@
#include <migraphx/msgpack.hpp>
#include <migraphx/file_buffer.hpp>
#include <migraphx/ranges.hpp>
#include <array>
#include <iostream>
#include <cstring>
...
...
src/targets/gpu/include/migraphx/gpu/ck.hpp
0 → 100644
View file @
a6fa5e4b
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_GPU_CK_HPP
#define MIGRAPHX_GUARD_GPU_CK_HPP
#include <migraphx/compile_src.hpp>
#include <migraphx/env.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/stringutils.hpp>
#include <string_view>
#include "ck/host/device_gemm_multiple_d.hpp"
#include "ck/host/device_batched_gemm_softmax_gemm.hpp"
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
#ifndef _WIN32
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_ENABLE_CK
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_LOG_CK_GEMM
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_CK_DEBUG
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TUNE_CK
);
#endif
// Text template that wraps a piece of source in clang diagnostic push/pop
// pragmas with "-Weverything" ignored. The ${content} placeholder is filled
// in by ck_disable_warnings() through interpolate_string().
// NOLINTNEXTLINE
const
char
*
const
disable_warning_pragma
=
R"__migraphx__(
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
${content}
#pragma clang diagnostic pop
)__migraphx__"
;
/// Wraps a piece of source text in the warning-suppression template.
///
/// @tparam P any type exposing `data()` and `size()` describing a character
///           range (the bytes are copied into an owning std::string).
/// @param p  the text to wrap.
/// @return   `disable_warning_pragma` with `${content}` replaced by the text.
template <class P>
std::string ck_disable_warnings(P p)
{
    return interpolate_string(disable_warning_pragma,
                              {{"content", std::string(p.data(), p.size())}});
}
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
create_ck_header_strings
()
{
std
::
unordered_map
<
std
::
string
,
std
::
string
>
result
;
auto
ck_headers
=
ck
::
host
::
GetHeaders
();
std
::
transform
(
ck_headers
.
begin
(),
ck_headers
.
end
(),
std
::
inserter
(
result
,
result
.
begin
()),
[
&
](
auto
&
p
)
{
return
std
::
pair
<
std
::
string
,
std
::
string
>
(
p
.
first
,
ck_disable_warnings
(
p
.
second
));
});
return
result
;
}
static
std
::
vector
<
src_file
>
create_ck_headers
()
{
static
const
auto
&
header_strings
=
create_ck_header_strings
();
std
::
vector
<
src_file
>
srcs
;
std
::
transform
(
header_strings
.
begin
(),
header_strings
.
end
(),
std
::
back_inserter
(
srcs
),
[
&
](
auto
&
p
)
{
return
src_file
{
p
};
});
return
srcs
;
}
/// Returns the CK headers as `src_file` entries, built on first call and
/// cached afterwards.
///
/// The function has internal linkage (`static`), so every translation unit
/// including this header gets its own cached copy.
static inline const std::vector<src_file>& ck_headers()
{
    static const auto cached = create_ck_headers();
    return cached;
}
/// Reports whether the innermost stride of `s` differs from 1.
///
/// @param s shape to inspect; its stride list must not be empty because the
///          last element is read unconditionally.
inline bool transposed_matrix(const shape& s)
{
    const auto& strides = s.strides();
    return strides.back() != 1;
}
/// Maps the element type of `s` onto the matching `ck::host::DataType`.
///
/// Supported types: half, float, int8 and int32.
/// @throws via MIGRAPHX_THROW ("Unsupported ck type") for any other type.
inline ck::host::DataType get_type(const shape& s)
{
    const auto type = s.type();
    if(type == shape::half_type)
        return ck::host::DataType::Half;
    if(type == shape::float_type)
        return ck::host::DataType::Float;
    if(type == shape::int8_type)
        return ck::host::DataType::Int8;
    if(type == shape::int32_type)
        return ck::host::DataType::Int32;
    MIGRAPHX_THROW("Unsupported ck type");
}
/// Returns the product of every dimension of `s` except the two innermost
/// ones, i.e. the number of stacked matrices.
///
/// Shapes of rank two or lower have no leading dimensions and yield 1. The
/// explicit rank check is required: advancing `rbegin()` by two on a vector
/// with fewer than two elements moves past `rend()`, which is undefined
/// behaviour.
///
/// @param s shape whose leading dimensions are multiplied.
/// @return  product of `lens[0 .. rank-3]`, or 1 when there are none.
inline std::size_t get_batch_count(const shape& s)
{
    const auto& lens = s.lens();
    if(lens.size() <= 2)
        return 1;
    return std::accumulate(
        lens.rbegin() + 2, lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
}
/// Collapses all leading (batch) dimensions of `s` into one of its two
/// innermost dimensions, in place.
///
/// Shapes of rank two or lower are left untouched. Otherwise the batch count
/// is multiplied into the last dimension when transposed_matrix(s) is true,
/// and into the second-to-last dimension when it is false. The replacement
/// shape is built from the type and the two new lengths only.
inline void fold_batch_dims(shape& s)
{
    const auto dims = s.lens();
    const auto rank = dims.size();
    if(rank <= 2)
        return;
    const auto batches = get_batch_count(s);
    const auto rows    = dims[rank - 2];
    const auto cols    = dims[rank - 1];
    s = transposed_matrix(s) ? shape{s.type(), {rows, cols * batches}}
                             : shape{s.type(), {rows * batches, cols}};
}
/// Drops all leading (batch) dimensions of `s`, in place, keeping only the
/// two innermost lengths.
///
/// Shapes of rank two or lower are left untouched. The replacement shape is
/// built from the type and the two lengths only; the original strides are
/// not passed along.
inline void remove_batch_dims(shape& s)
{
    const auto dims = s.lens();
    const auto rank = dims.size();
    if(rank > 2)
        s = shape{s.type(), {dims[rank - 2], dims[rank - 1]}};
}
/// Checks whether the batch dimensions of `s` are laid out in standard order
/// relative to the matrix they index.
///
/// Shapes with fewer than three dimensions trivially pass. Otherwise the
/// leading lengths and strides (everything but the last two entries) are
/// extracted, each stride is divided by the product of the two innermost
/// lengths, and the resulting shape is asked whether it is `standard()`.
inline bool standard_batch(const shape& s)
{
    const auto& all_lens = s.lens();
    const auto rank      = all_lens.size();
    if(rank < 3)
        return true;

    const auto& all_strides = s.strides();
    // Number of elements in one innermost matrix.
    const auto matrix_elements = all_lens[rank - 2] * all_lens[rank - 1];

    std::vector<std::size_t> batch_lens(all_lens.begin(), all_lens.end() - 2);
    std::vector<std::size_t> batch_strides;
    for(auto it = all_strides.begin(); it != all_strides.end() - 2; ++it)
        batch_strides.push_back(*it / matrix_elements);

    return shape{s.type(), batch_lens, batch_strides}.standard();
}
/// Decides whether the batch dimensions of a set of input shapes can be
/// folded.
///
/// Two conditions must hold:
///  - every shape from index 2 up to (but excluding) the last one satisfies
///    standard_batch();
///  - every stride of `inputs[1]` except its last two is zero.
///
/// Preconditions (not checked here): `inputs` holds at least three shapes and
/// `inputs[1]` has at least two strides, since the iterator ranges below are
/// formed with `begin() + 2`, `end() - 1` and `end() - 2`.
///
/// @param inputs shapes to examine; `inputs[0]` is not inspected.
/// @return true when both conditions hold.
inline bool can_fold_batch(const std::vector<shape>& inputs)
{
    const auto& b_shape = inputs[1];
    // Take each shape by const reference: the predicate only reads it, so
    // copying every element is unnecessary.
    if(std::any_of(inputs.begin() + 2, inputs.end() - 1, [](const auto& input) {
           return not standard_batch(input);
       }))
        return false;
    const auto& b_strides = b_shape.strides();
    return std::all_of(
        b_strides.begin(), b_strides.end() - 2, [](auto stride) { return stride == 0; });
}
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_CK_HPP
src/targets/gpu/include/migraphx/gpu/compile_hip.hpp
View file @
a6fa5e4b
...
...
@@ -45,10 +45,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS);
struct
hiprtc_src_file
{
hiprtc_src_file
()
=
default
;
hiprtc_src_file
(
const
src_file
&
s
)
:
path
(
s
.
path
.
string
()),
content
(
s
.
content
.
first
,
s
.
content
.
second
)
{
}
hiprtc_src_file
(
const
src_file
&
s
)
:
path
(
s
.
path
.
string
()),
content
(
s
.
content
)
{}
std
::
string
path
;
std
::
string
content
;
template
<
class
Self
,
class
F
>
...
...
src/targets/gpu/include/migraphx/gpu/context.hpp
View file @
a6fa5e4b
...
...
@@ -299,23 +299,6 @@ struct context
any_ptr
get_queue
()
{
return
get_stream
().
get
();
}
void
enable_perf_measurement
(
bool
b
=
true
)
{
if
(
b
)
{
start_event
=
create_event_for_timing
();
stop_event
=
create_event_for_timing
();
get_stream
().
record
(
start_event
.
get
());
get_stream
().
record
(
stop_event
.
get
());
}
else
{
start_event
=
nullptr
;
stop_event
=
nullptr
;
}
measure_perf
=
b
;
}
std
::
pair
<
hipEvent_t
,
hipEvent_t
>
get_perf_events
()
const
{
if
(
measure_perf
)
...
...
@@ -323,12 +306,12 @@ struct context
return
std
::
make_pair
(
nullptr
,
nullptr
);
}
float
get_elapsed_ms
(
)
const
static
float
get_elapsed_ms
(
hipEvent_t
start
,
hipEvent_t
stop
)
{
float
result
=
0
;
if
(
start
_event
!=
nullptr
and
stop
_event
!=
nullptr
)
if
(
start
!=
nullptr
and
stop
!=
nullptr
)
{
auto
status
=
hipEventElapsedTime
(
&
result
,
start
_event
.
get
(),
stop_event
.
get
()
);
auto
status
=
hipEventElapsedTime
(
&
result
,
start
,
stop
);
if
(
status
!=
hipSuccess
)
MIGRAPHX_THROW
(
"Failed hipEventElapsedTime: "
+
hip_error
(
status
));
}
...
...
src/targets/gpu/include/migraphx/gpu/convolution.hpp
View file @
a6fa5e4b
...
...
@@ -199,9 +199,9 @@ struct miopen_convolution
// MIOpen has APIs to pass pre-allocated buffers starting from rocm-5.6
preallocate
=
true
;
#endif
auto
x
=
preallocate
?
to_gpu
(
generate_argument
(
x_shape
))
:
inputs
[
0
];
auto
w
=
preallocate
?
to_gpu
(
generate_argument
(
w_shape
))
:
inputs
[
1
];
auto
y
=
preallocate
?
allocate_gpu
(
output_shape
)
:
inputs
[
2
];
auto
x
=
preallocate
?
to_gpu
(
generate_argument
(
x_shape
))
:
argument
{
inputs
[
0
]
}
;
auto
w
=
preallocate
?
to_gpu
(
generate_argument
(
w_shape
))
:
argument
{
inputs
[
1
]
}
;
auto
y
=
preallocate
?
allocate_gpu
(
output_shape
)
:
argument
{
inputs
[
2
]
}
;
auto
workspace
=
preallocate
?
allocate_gpu
(
workspace_shape
)
:
migraphx
::
argument
(
workspace_shape
);
...
...
src/targets/gpu/include/migraphx/gpu/device/arg_op.hpp
View file @
a6fa5e4b
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -55,7 +55,7 @@ MIGRAPHX_DEVICE_CONSTEXPR val_index<T> make_val_index(T v, int64_t i)
return
{
v
,
i
};
}
struct
argmax_op
struct
argmax_op
_first_index
{
template
<
class
T
>
MIGRAPHX_DEVICE_CONSTEXPR
val_index
<
T
>
operator
()(
val_index
<
T
>
x
,
val_index
<
T
>
y
)
const
...
...
@@ -73,7 +73,25 @@ struct argmax_op
MIGRAPHX_DEVICE_CONSTEXPR
auto
init
()
const
{
return
lowest
();
}
};
struct
argmin_op
// Reduction functor that keeps the larger of two value/index pairs; when
// neither value compares greater than the other, the pair with the larger
// index is kept.
struct argmax_op_last_index
{
    template <class T>
    MIGRAPHX_DEVICE_CONSTEXPR val_index<T> operator()(val_index<T> x, val_index<T> y) const
    {
        if(x.val > y.val)
            return x;
        if(x.val < y.val)
            return y;
        // Neither value is greater: prefer the later index.
        if(x.index > y.index)
            return x;
        return y;
    }

    // Starting value for the reduction.
    MIGRAPHX_DEVICE_CONSTEXPR auto init() const { return lowest(); }
};
struct
argmin_op_first_index
{
template
<
class
T
>
MIGRAPHX_DEVICE_CONSTEXPR
val_index
<
T
>
operator
()(
val_index
<
T
>
x
,
val_index
<
T
>
y
)
const
...
...
@@ -91,6 +109,24 @@ struct argmin_op
MIGRAPHX_DEVICE_CONSTEXPR
auto
init
()
const
{
return
highest
();
}
};
// Reduction functor that keeps the smaller of two value/index pairs; when
// neither value compares less than the other, the pair with the larger
// index is kept.
struct argmin_op_last_index
{
    template <class T>
    MIGRAPHX_DEVICE_CONSTEXPR val_index<T> operator()(val_index<T> x, val_index<T> y) const
    {
        if(x.val < y.val)
            return x;
        if(x.val > y.val)
            return y;
        // Neither value is smaller: prefer the later index.
        if(x.index > y.index)
            return x;
        return y;
    }

    // Starting value for the reduction.
    MIGRAPHX_DEVICE_CONSTEXPR auto init() const { return highest(); }
};
template
<
class
Op
>
void
arg_op
(
Op
op
,
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int64_t
axis
)
{
...
...
Prev
1
2
3
4
5
6
7
8
9
…
13
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment