gaoqiong / MIGraphX

Commit 2dc6894c, authored Nov 13, 2023 by Paul
Parent commit: df869fd8

Add subwave reductions

Showing 4 changed files with 143 additions and 44 deletions (+143 / -44)
src/targets/gpu/include/migraphx/gpu/context.hpp              +2   -0
src/targets/gpu/jit/reduce.cpp                                +49  -17
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp    +20  -1
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp   +72  -26
src/targets/gpu/include/migraphx/gpu/context.hpp

@@ -182,6 +182,8 @@ struct hip_device
     std::size_t get_max_workitems_per_block() const { return device_props.maxThreadsPerBlock; }
+    std::size_t get_wavefront_size() const { return device_props.warpSize; }
+
     private:
     std::size_t device_id      = 0;
     std::size_t current_stream = 0;
...
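The new get_wavefront_size() accessor simply surfaces hipDeviceProp_t::warpSize, which the HIP runtime reports per device (typically 64 on CDNA GPUs, 32 on RDNA). A minimal standalone sketch of the same query, outside MIGraphX's hip_device wrapper (device index 0 and the error handling here are illustrative assumptions):

```cpp
// Sketch: query the wavefront (warp) size directly through the HIP runtime.
#include <hip/hip_runtime.h>
#include <cstdio>

int main()
{
    hipDeviceProp_t props{};
    // Device 0 is assumed here; MIGraphX tracks the device id inside hip_device.
    if(hipGetDeviceProperties(&props, 0) != hipSuccess)
    {
        std::printf("failed to query device properties\n");
        return 1;
    }
    std::printf("maxThreadsPerBlock: %d\n", props.maxThreadsPerBlock);
    std::printf("warpSize (wavefront size): %d\n", props.warpSize);
    return 0;
}
```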
src/targets/gpu/jit/reduce.cpp

@@ -97,9 +97,10 @@ static shape get_output_shape(const shape& s, const std::vector<T>& axes)
 }
 
 template <class ReduceLens>
-static std::string get_reduce_algo(const std::vector<shape>& inputs, ReduceLens rlens)
+static std::string get_reduce_algo(context& ctx, const std::vector<shape>& inputs, ReduceLens rlens)
 {
     const auto init = std::numeric_limits<std::size_t>::max();
+    auto relements  = std::accumulate(rlens.begin(), rlens.end(), 1, std::multiplies<>{});
     // The minimum stride
     auto min_stride = std::inner_product(rlens.begin(),
...
@@ -110,13 +111,24 @@ static std::string get_reduce_algo(const std::vector<shape>& inputs, ReduceLens
                                          [](auto len, auto stride) {
                                              return len == 1 ? init : stride;
                                          });
     if(min_stride > 2)
         return "lane";
+    if(relements <= ctx.get_current_device().get_wavefront_size())
+        return "wave";
     return "block";
 }
 
-static std::string get_reduce_algo(const std::vector<shape>& inputs)
+static std::string get_reduce_algo(context& ctx, const std::vector<shape>& inputs)
 {
     auto rlens = get_reduce_lens(inputs.front().lens(), inputs.back().lens());
-    return get_reduce_algo(inputs, rlens);
+    return get_reduce_algo(ctx, inputs, rlens);
 }
 
+static std::size_t compute_subwave_size(context& ctx, std::size_t n)
+{
+    std::size_t max_wavefront_size = ctx.get_current_device().get_wavefront_size();
+    std::size_t wavefront_size     = 1;
+    while(wavefront_size < n and wavefront_size < max_wavefront_size)
+        wavefront_size *= 2;
+    return wavefront_size;
+}
+
 struct simple_reduce_compiler : compiler<simple_reduce_compiler>
...
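The algorithm selection now has three tiers: strided reductions still pick "lane", reductions that fit within one wavefront pick the new "wave" path, and everything else stays on "block". compute_subwave_size rounds the reduction length up to the next power of two, capped at the device wavefront size. A self-contained sketch of that rounding (the free function and the sample sizes below are illustrative, not part of the commit):

```cpp
// Sketch: round a reduction size up to the next power of two, capped at the
// wavefront size, mirroring compute_subwave_size in reduce.cpp.
#include <cstddef>
#include <cstdio>

static std::size_t subwave_size_for(std::size_t n, std::size_t max_wavefront_size)
{
    std::size_t wavefront_size = 1;
    while(wavefront_size < n and wavefront_size < max_wavefront_size)
        wavefront_size *= 2;
    return wavefront_size;
}

int main()
{
    // On a 64-wide device: 3 -> 4, 17 -> 32, 64 -> 64, 200 -> capped at 64.
    for(std::size_t n : {3, 17, 64, 200})
        std::printf("n=%zu -> subwave %zu\n", n, subwave_size_for(n, 64));
    return 0;
}
```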
@@ -145,18 +157,28 @@ struct simple_reduce_compiler : compiler<simple_reduce_compiler>
         auto faxis = find_fast_axis({options.virtual_inputs.front()});
         vectorize vec{};
         auto nelements = options.virtual_inputs.back().elements();
-        auto algo      = v.get("algo", get_reduce_algo(options.virtual_inputs));
-        if(algo == "block")
+        auto algo      = v.get("algo", get_reduce_algo(ctx, options.virtual_inputs));
+        if(algo == "block" or algo == "wave")
         {
             // Vectorize if the axis is a reduction axis
             if(options.virtual_inputs.back().lens()[faxis] == 1)
                 vec = vectorize::elements(ctx, faxis, options.virtual_inputs);
             auto relements  = get_reduce_elements(options.virtual_inputs) / vec.size;
-            auto block_size = compute_block_size(relements, 256);
-            if(relements >= block_size * 256)
-                algo = "block_large";
-            options.set_launch_params(
-                v, compute_global_for(ctx, nelements * block_size, 256), block_size);
+            if(algo == "block")
+            {
+                auto block_size = compute_block_size(relements, 256);
+                if(relements >= block_size * 256)
+                    algo = "block_large";
+                options.set_launch_params(
+                    v, compute_global_for(ctx, nelements * block_size, 256), block_size);
+            }
+            else
+            {
+                auto subwave_size = compute_subwave_size(ctx, relements);
+                algo              = "subwave<" + std::to_string(subwave_size) + ">";
+                options.set_launch_params(
+                    v,
+                    compute_global_for(ctx, nelements * subwave_size, 256),
+                    ctx.get_current_device().get_wavefront_size());
+            }
         }
         else if(algo == "lane")
         {
...
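For the "wave" branch the launch configuration changes: the global size becomes nelements * subwave_size (one subwave per output element) while the workgroup size stays at a full wavefront, so several small reductions share one wavefront. A small sketch of that arithmetic under assumed sizes (nothing here is MIGraphX API, just the calculation shown above; compute_global_for may additionally round the global size):

```cpp
// Sketch: launch-size arithmetic for the subwave path, assuming a 64-wide wavefront.
#include <cstddef>
#include <cstdio>

int main()
{
    const std::size_t nelements = 1024; // reduction outputs, assumed
    const std::size_t relements = 8;    // elements reduced per output, assumed

    // compute_subwave_size: next power of two >= relements, capped at the wavefront.
    std::size_t subwave_size = 1;
    while(subwave_size < relements and subwave_size < 64)
        subwave_size *= 2;

    const std::size_t global = nelements * subwave_size; // work-items requested
    const std::size_t local  = 64;                       // one full wavefront per workgroup
    std::printf("subwave=%zu global=%zu local=%zu subwaves/wavefront=%zu\n",
                subwave_size, global, local, local / subwave_size);
    return 0;
}
```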
@@ -241,18 +263,28 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler>
         auto faxis = find_fast_axis({options.virtual_inputs.front()});
         vectorize vec{};
         auto nelements = reduce_output_shape.elements();
-        auto algo = v.get("algo", get_reduce_algo(options.virtual_inputs, reduction_shape.lens()));
-        if(algo == "block")
+        auto algo =
+            v.get("algo", get_reduce_algo(ctx, options.virtual_inputs, reduction_shape.lens()));
+        if(algo == "block" or algo == "wave")
         {
             // Vectorize if the axis is a reduction axis
             if(reduce_output_shape.lens()[faxis] == 1)
                 vec = vectorize::elements(ctx, faxis, options.virtual_inputs);
             auto relements  = reduction_shape.elements() / vec.size;
-            auto block_size = compute_block_size(relements, 256);
-            if(relements >= block_size * 256)
-                algo = "block_large";
-            options.set_launch_params(
-                v, compute_global_for(ctx, nelements * block_size, 256), block_size);
+            if(algo == "block")
+            {
+                auto block_size = compute_block_size(relements, 256);
+                if(relements >= block_size * 256)
+                    algo = "block_large";
+                options.set_launch_params(
+                    v, compute_global_for(ctx, nelements * block_size, 256), block_size);
+            }
+            else
+            {
+                auto subwave_size = compute_subwave_size(ctx, relements);
+                algo              = "subwave<" + std::to_string(subwave_size) + ">";
+                options.set_launch_params(
+                    v,
+                    compute_global_for(ctx, nelements * subwave_size, 256),
+                    ctx.get_current_device().get_wavefront_size());
+            }
         }
         else if(algo == "lane")
         {
...
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp

@@ -135,6 +135,13 @@ struct index
     constexpr auto ngroup() const { return nglobal() / max_nlocal(); }
 
+    template <unsigned int SubWaveSize>
+    constexpr index_constant<SubWaveSize> nlocal_subwave() const { return {}; }
+    template <unsigned int SubWaveSize>
+    constexpr auto local_subwave() const { return local % nlocal_subwave<SubWaveSize>(); }
+    template <unsigned int SubWaveSize>
+    constexpr auto nwave() const { return max_nlocal() / nlocal_subwave<SubWaveSize>(); }
+
     constexpr index_constant<__AMDGCN_WAVEFRONT_SIZE> nlocal_wave() const { return {}; }
     constexpr auto local_wave() const { return local % nlocal_wave(); }
     constexpr auto nwave() const { return max_nlocal() / nlocal_wave(); }
...
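These helpers partition a workgroup into subwaves: local_subwave<S>() is the lane's offset within its subwave and nwave<S>() is the number of subwaves per workgroup. A plain host-side sketch of that mapping (the workgroup and subwave sizes below are assumptions for illustration):

```cpp
// Sketch: how workgroup lanes map onto subwaves of size SubWaveSize.
// Mirrors local_subwave() = local % SubWaveSize and nwave() = max_nlocal() / SubWaveSize.
#include <cstdio>

int main()
{
    const unsigned int nlocal       = 64; // workgroup size, assumed
    const unsigned int subwave_size = 16; // SubWaveSize, assumed

    std::printf("subwaves per workgroup: %u\n", nlocal / subwave_size);
    for(unsigned int local = 0; local < nlocal; ++local)
    {
        unsigned int lane_in_subwave = local % subwave_size; // local_subwave<16>()
        unsigned int subwave_id      = local / subwave_size;
        if(lane_in_subwave == 0)
            std::printf("lane %2u starts subwave %u\n", local, subwave_id);
    }
    return 0;
}
```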
@@ -164,6 +171,12 @@ struct index
         return max_stride_iterations(n, nlocal_wave());
     }
 
+    template <unsigned int SubWaveSize, class N>
+    constexpr auto max_local_subwave_stride_iterations(N n) const
+    {
+        return max_stride_iterations(n, nlocal_subwave<SubWaveSize>());
+    }
+
     template <class F, class I, class D>
     static constexpr auto invoke_loop(F f, I i, D d) -> decltype(f(i, d))
     {
...
@@ -254,10 +267,16 @@ struct index
         for_stride<false>(group, n, ngroup(), f);
     }
 
+    template <unsigned int SubWaveSize, class F, class N>
+    __device__ void local_subwave_stride(N n, F f) const
+    {
+        for_stride<true>(local_subwave<SubWaveSize>(), n, nlocal_subwave<SubWaveSize>(), f);
+    }
+
     template <class F, class N>
     __device__ void local_wave_stride(N n, F f) const
     {
-        for_stride<false>(local_wave(), n, nlocal_wave(), f);
+        for_stride<true>(local_wave(), n, nlocal_wave(), f);
     }
 };
...
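local_subwave_stride is the subwave analogue of local_wave_stride: each lane walks the reduction with a stride equal to the subwave size, starting at its offset within the subwave. A plain-C++ sketch of that iteration pattern (for_stride<true> itself is MIGraphX-internal; this only illustrates the index arithmetic, with assumed sizes):

```cpp
// Sketch: the strided iteration behind local_subwave_stride<SubWaveSize>(n, f).
// Each lane visits indices local_subwave, local_subwave + SubWaveSize, ... < n.
#include <cstdio>

int main()
{
    const unsigned int subwave_size = 8;  // SubWaveSize, assumed
    const unsigned int n            = 20; // reduction length, assumed

    for(unsigned int lane = 0; lane < subwave_size; ++lane)
    {
        std::printf("lane %u reads:", lane);
        for(unsigned int i = lane; i < n; i += subwave_size) // for_stride pattern
            std::printf(" %u", i);
        std::printf("\n");
    }
    return 0;
}
```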
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp

@@ -31,30 +31,66 @@
 namespace migraphx {
 
+constexpr bool is_power_of_2(unsigned int x) { return x > 0 && !(x & (x - 1)); }
+
 #if MIGRAPHX_HAS_DPP
-template <class T, class Op>
+template <unsigned int SubWaveSize, class T, class Op>
 __device__ void dpp_reduce(T& in, Op op)
 {
+    static_assert(SubWaveSize <= __AMDGCN_WAVEFRONT_SIZE, "Too large subwave size");
+    static_assert(is_power_of_2(SubWaveSize), "SubWaveSize is not a power of 2");
     T out{};
-    out = dpp_mov<dpp_row_shr(1)>(in);
-    in  = op(in, out);
-    out = dpp_mov<dpp_row_shr(2)>(in);
-    in  = op(in, out);
-    out = dpp_mov<dpp_row_shr(4), 0xf, 0xe>(in);
-    in  = op(in, out);
-    out = dpp_mov<dpp_row_shr(8), 0xf, 0xc>(in);
-    in  = op(in, out);
+    if constexpr(SubWaveSize > 1)
+    {
+        out = dpp_mov<dpp_row_shr(1)>(in);
+        in  = op(in, out);
+    }
+    if constexpr(SubWaveSize > 2)
+    {
+        out = dpp_mov<dpp_row_shr(2)>(in);
+        in  = op(in, out);
+    }
+    if constexpr(SubWaveSize > 4)
+    {
+        out = dpp_mov<dpp_row_shr(4), 0xf, 0xe>(in);
+        in  = op(in, out);
+    }
+    if constexpr(SubWaveSize > 8)
+    {
+        out = dpp_mov<dpp_row_shr(8), 0xf, 0xc>(in);
+        in  = op(in, out);
+    }
 #if __AMDGCN_WAVEFRONT_SIZE == 32
-    out = dpp_swizzle<dpp_row_bcast(15)>(in);
-    in  = op(in, out);
+    if constexpr(SubWaveSize > 16)
+    {
+        out = dpp_swizzle<dpp_row_bcast(15)>(in);
+        in  = op(in, out);
+    }
 #else
-    out = dpp_mov<dpp_row_bcast(15), 0xa>(in);
-    in  = op(in, out);
-    out = dpp_mov<dpp_row_bcast(31), 0xc>(in);
-    in  = op(in, out);
+    if constexpr(SubWaveSize > 16)
+    {
+        out = dpp_mov<dpp_row_bcast(15), 0xa>(in);
+        in  = op(in, out);
+    }
+    if constexpr(SubWaveSize > 32)
+    {
+        out = dpp_mov<dpp_row_bcast(31), 0xc>(in);
+        in  = op(in, out);
+    }
 #endif
 }
+
+template <class T, class Op>
+__device__ void dpp_reduce(T& in, Op op)
+{
+    dpp_reduce<__AMDGCN_WAVEFRONT_SIZE>(in, op);
+}
+
 #if defined(MIGRAPHX_USE_CLANG_TIDY) || defined(CPPCHECK)
 // NOLINTNEXTLINE
 #define MIGRAPHX_DPP_REDUCE_ASM(x, ins) x = 1
...
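The templated dpp_reduce keeps the original row-shift/broadcast cascade but guards each step with `if constexpr(SubWaveSize > k)`, so a subwave of size S only executes the log2(S) combine steps it needs and skips the cross-row broadcasts when the subwave fits inside a DPP row. A host-side sketch of the same log-step idea on a plain array (array indexing stands in for the cross-lane DPP instructions; sizes below are assumptions):

```cpp
// Sketch: log2-step combine within each subwave, simulated on the host.
// After log2(SubWaveSize) shift-and-combine steps, lane 0 of each subwave holds
// that subwave's total, analogous to the dpp_row_shr cascade in dpp_reduce<SubWaveSize>.
#include <cstdio>
#include <vector>

int main()
{
    const unsigned int wavefront    = 64;
    const unsigned int subwave_size = 8; // SubWaveSize, assumed power of two

    std::vector<int> lane(wavefront, 1); // every lane contributes 1, so each subwave sums to 8

    for(unsigned int offset = 1; offset < subwave_size; offset *= 2)
    {
        std::vector<int> shifted = lane; // snapshot before this combine step
        for(unsigned int i = 0; i < wavefront; ++i)
        {
            // Combine with the lane `offset` positions higher, if it is in the same subwave.
            unsigned int j = i + offset;
            if(j < wavefront && j / subwave_size == i / subwave_size)
                lane[i] = shifted[i] + shifted[j];
        }
    }
    for(unsigned int s = 0; s < wavefront / subwave_size; ++s)
        std::printf("subwave %u total: %d\n", s, lane[s * subwave_size]);
    return 0;
}
```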
@@ -98,17 +134,24 @@ MIGRAPHX_DPP_REDUCE(op::product, v_mul, _u)
 MIGRAPHX_DPP_REDUCE(op::max, v_max, _i)
 MIGRAPHX_DPP_REDUCE(op::min, v_min, _i)
 
-template <class Op, class T, class Index, class F>
-__device__ auto wave_reduce(index idx, Op op, T init, Index n, F f)
+template <unsigned int SubWaveSize, class Op, class T, class Index, class F>
+__device__ auto subwave_reduce(index idx, Op op, T init, Index n, F f)
 {
     MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal());
     using type = decltype(index::invoke_loop(f, 0, _c<0>));
     type x     = init;
-    idx.local_wave_stride(n, [&](auto i, auto d) { x = op(x, index::invoke_loop(f, i, d)); });
-    dpp_reduce(x, op);
+    idx.local_subwave_stride<SubWaveSize>(
+        n, [&](auto i, auto d) { x = op(x, index::invoke_loop(f, i, d)); });
+    dpp_reduce<SubWaveSize>(x, op);
     return x;
 }
 
+template <class Op, class T, class Index, class F>
+__device__ auto wave_reduce(index idx, Op op, T init, Index n, F f)
+{
+    return subwave_reduce<__AMDGCN_WAVEFRONT_SIZE>(idx, op, init, n, f);
+}
+
 template <class Op, class T, class Index, class F>
 __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
 {
...
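subwave_reduce keeps the shape of the old wave_reduce: each lane accumulates its strided slice of the input, then the DPP cascade combines the per-lane partials within the subwave; wave_reduce is now just the SubWaveSize == wavefront special case. A host-side reference of what one subwave computes for a sum (plain loops stand in for local_subwave_stride and dpp_reduce; the sizes and data are assumptions):

```cpp
// Sketch: reference semantics of subwave_reduce for a sum reduction.
// Each subwave of SubWaveSize lanes reduces one length-n slice of the input.
#include <cstdio>
#include <vector>

int main()
{
    const unsigned int subwave_size = 4;  // SubWaveSize, assumed
    const unsigned int n            = 10; // reduction length per output, assumed
    const unsigned int nsubwaves    = 3;  // reductions handled by this workgroup, assumed

    // One slice of n values per subwave; values chosen arbitrarily.
    std::vector<std::vector<int>> slices(nsubwaves, std::vector<int>(n, 1));

    for(unsigned int s = 0; s < nsubwaves; ++s)
    {
        // Step 1: each lane accumulates its strided elements (local_subwave_stride).
        std::vector<int> partial(subwave_size, 0);
        for(unsigned int lane = 0; lane < subwave_size; ++lane)
            for(unsigned int i = lane; i < n; i += subwave_size)
                partial[lane] += slices[s][i];

        // Step 2: combine the per-lane partials (the dpp_reduce<SubWaveSize> cascade).
        int total = 0;
        for(int p : partial)
            total += p;
        std::printf("subwave %u reduces %u elements -> %d\n", s, n, total);
    }
    return 0;
}
```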
@@ -486,7 +529,8 @@ struct block_large
     }
 };
 
-struct wave
+template <unsigned int SubWaveSize>
+struct subwave
 {
     template <class Slicer>
     struct reducer : reducer_base<reducer<Slicer>>
...
@@ -515,7 +559,7 @@ struct wave
         template <class Op, class T, class Read, class N, class... Ts>
         __device__ auto reduce_impl(Op op, T init, Read read, N n, Ts&&... xs) const
         {
-            return wave_reduce(idx, op, init, n, [&](auto j, auto d) {
+            return subwave_reduce<SubWaveSize>(idx, op, init, n, [&](auto j, auto d) {
                 return vec_reduce(read(xs(j, d)...), op);
             });
         }
...
@@ -523,7 +567,7 @@ struct wave
         template <class F>
         __device__ void outer(F f) const
         {
-            if(idx.local_wave() == 0)
+            if(idx.local_subwave<SubWaveSize>() == 0)
                 f();
         }
...
@@ -536,9 +580,9 @@ struct wave
         template <class R, class F, class N, class... Ts>
         __device__ auto inner_impl(F f, N n, Ts&&... xs) const
         {
-            using max_iterations = decltype(idx.max_local_wave_stride_iterations(n));
+            using max_iterations =
+                decltype(idx.max_local_subwave_stride_iterations<SubWaveSize>(n));
             inner_storage<R, max_iterations{}, N> storage;
-            idx.local_wave_stride(n, [&](auto j, auto d) { storage(j, d) = f(xs(j, d)...); });
+            idx.local_subwave_stride<SubWaveSize>(
+                n, [&](auto j, auto d) { storage(j, d) = f(xs(j, d)...); });
             return storage;
         }
     };
...
@@ -554,13 +598,15 @@ struct wave
     {
         auto idx                 = make_index();
         constexpr auto nelements = get_shape_c<Output>{}.elements();
-        idx.global_stride(nelements * idx.nlocal_wave(), [&](auto i) {
-            const auto out_idx = get_shape_c<Output>{}.multi(i / idx.nlocal_wave());
+        idx.global_stride(nelements * idx.nlocal_subwave<SubWaveSize>(), [&](auto i) {
+            const auto out_idx =
+                get_shape_c<Output>{}.multi(i / idx.nlocal_subwave<SubWaveSize>());
             f(out_idx, make(idx, [&](auto input) {
                   return reduce_slice<Output>(input, out_idx);
               }));
         });
     }
 };
 
+using wave = subwave<__AMDGCN_WAVEFRONT_SIZE>;
+
 struct lane
 {
     template <class Slicer>
...
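With this change the `wave` reducer is no longer a separate type: it is an alias for `subwave` instantiated at the full wavefront width, and the compiler-side algo string "subwave<N>" selects narrower instantiations. A toy illustration of that alias pattern outside the kernel headers (the names and numbers here are made up for the example, not MIGraphX code):

```cpp
// Sketch: a size-parameterized policy struct with the full-width case as an alias,
// echoing "template <unsigned int SubWaveSize> struct subwave" and
// "using wave = subwave<__AMDGCN_WAVEFRONT_SIZE>".
#include <cstdio>

constexpr unsigned int wavefront_size = 64; // stand-in for __AMDGCN_WAVEFRONT_SIZE

template <unsigned int SubWaveSize>
struct subwave_policy
{
    static constexpr unsigned int lanes = SubWaveSize;
    static constexpr unsigned int per_wavefront() { return wavefront_size / SubWaveSize; }
};

using wave_policy = subwave_policy<wavefront_size>; // full-width special case

int main()
{
    std::printf("subwave<8>: %u reductions per wavefront\n", subwave_policy<8>::per_wavefront());
    std::printf("wave:       %u reduction per wavefront\n", wave_policy::per_wavefront());
    return 0;
}
```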