Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
67c92b83
Commit
67c92b83
authored
Dec 08, 2023
by
Paul
Browse files
Fix subwave implementation
parent
5be87179
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
25 additions
and
13 deletions
+25
-13
src/targets/gpu/jit/reduce.cpp
src/targets/gpu/jit/reduce.cpp
+1
-1
src/targets/gpu/kernels/include/migraphx/kernels/dpp.hpp
src/targets/gpu/kernels/include/migraphx/kernels/dpp.hpp
+9
-0
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
+4
-0
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
+11
-12
No files found.
src/targets/gpu/jit/reduce.cpp
View file @
67c92b83
...
...
@@ -126,7 +126,7 @@ static std::size_t compute_subwave_size(context& ctx, std::size_t n)
{
std
::
size_t
max_wavefront_size
=
ctx
.
get_current_device
().
get_wavefront_size
();
std
::
size_t
wavefront_size
=
1
;
while
(
wavefront_size
<
n
and
wavefront_size
<
max_wavefront_size
)
while
(
wavefront_size
<
=
n
and
wavefront_size
<
max_wavefront_size
)
wavefront_size
*=
2
;
return
wavefront_size
;
}
...
...
src/targets/gpu/kernels/include/migraphx/kernels/dpp.hpp
View file @
67c92b83
...
...
@@ -30,6 +30,8 @@
namespace
migraphx
{
// Reports whether x is an exact power of two, i.e. has exactly one bit set.
// Zero is rejected explicitly because it has no bits set.
constexpr bool is_power_of_2(unsigned int x)
{
    if(x == 0)
        return false;
    // x - 1 flips the lowest set bit and everything below it, so the AND is
    // zero only when that lowest set bit was the sole bit in x.
    return (x & (x - 1)) == 0;
}
#ifndef MIGRAPHX_HAS_DPP
#define MIGRAPHX_HAS_DPP 1
#endif
...
...
@@ -86,6 +88,13 @@ __device__ T dpp_swizzle(T& x)
return
dpp_op
(
x
,
[](
auto
i
)
{
return
__hip_ds_swizzle
(
i
,
Mask
);
});
}
// Shuffles x so that each lane receives the value supplied by lane SrcLane,
// using __shfl with a group width of Width.
//
// Template parameters:
//   SrcLane - lane index passed to __shfl as the source lane.
//   Width   - group width passed to __shfl; enforced to be a power of two.
//   T       - type of the value being shuffled.
//
// The shuffle itself is routed through dpp_op (defined elsewhere in this
// header); presumably that helper applies the callable piecewise so that T is
// not restricted to a single 32-bit word -- confirm against dpp_op.
template <unsigned int SrcLane, unsigned int Width, class T>
__device__ T dpp_readlane(T& x)
{
    static_assert(is_power_of_2(Width), "Width must be a power of 2");
    const auto read_source_lane = [](auto piece) { return __shfl(piece, SrcLane, Width); };
    return dpp_op(x, read_source_lane);
}
#endif // MIGRAPHX_HAS_DPP
}
// namespace migraphx
...
...
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
View file @
67c92b83
...
...
@@ -143,6 +143,10 @@ struct index
// Index of this thread inside its subwave: `local` reduced modulo
// nlocal_subwave<SubWaveSize>().
//
// When MIGRAPHX_HAS_CONST_LOCAL is defined and a value-initialized object of
// the type returned by nlocal() compares equal to SubWaveSize, `local` is
// returned directly and the modulo is skipped.
template <unsigned int SubWaveSize>
constexpr auto local_subwave() const
{
#ifdef MIGRAPHX_HAS_CONST_LOCAL
    using local_count = decltype(nlocal());
    if constexpr(local_count{} == SubWaveSize)
    {
        return local;
    }
#endif
    return local % nlocal_subwave<SubWaveSize>();
}
template
<
unsigned
int
SubWaveSize
>
...
...
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
View file @
67c92b83
...
...
@@ -32,8 +32,6 @@
namespace
migraphx
{
// True when x has exactly one bit set (x is a power of two); false for zero.
constexpr bool is_power_of_2(unsigned int x)
{
    // A single-bit value shares no bits with its predecessor x - 1.
    return x != 0 and (x & (x - 1)) == 0;
}
#if MIGRAPHX_HAS_DPP
template
<
unsigned
int
SubWaveSize
,
class
T
,
class
Op
>
...
...
@@ -41,42 +39,41 @@ __device__ void dpp_reduce(T& in, Op op)
{
static_assert
(
SubWaveSize
<=
__AMDGCN_WAVEFRONT_SIZE
,
"Too large subwave size"
);
static_assert
(
is_power_of_2
(
SubWaveSize
),
"SubWaveSize is not a power of 2"
);
T
out
{};
if
constexpr
(
SubWaveSize
>
1
)
{
out
=
dpp_mov
<
dpp_row_shr
(
1
)
>
(
in
);
auto
out
=
dpp_mov
<
dpp_row_shr
(
1
)
>
(
in
);
in
=
op
(
in
,
out
);
}
if
constexpr
(
SubWaveSize
>
2
)
{
out
=
dpp_mov
<
dpp_row_shr
(
2
)
>
(
in
);
auto
out
=
dpp_mov
<
dpp_row_shr
(
2
)
>
(
in
);
in
=
op
(
in
,
out
);
}
if
constexpr
(
SubWaveSize
>
4
)
{
out
=
dpp_mov
<
dpp_row_shr
(
4
),
0xf
,
0xe
>
(
in
);
auto
out
=
dpp_mov
<
dpp_row_shr
(
4
),
0xf
,
0xe
>
(
in
);
in
=
op
(
in
,
out
);
}
if
constexpr
(
SubWaveSize
>
8
)
{
out
=
dpp_mov
<
dpp_row_shr
(
8
),
0xf
,
0xc
>
(
in
);
auto
out
=
dpp_mov
<
dpp_row_shr
(
8
),
0xf
,
0xc
>
(
in
);
in
=
op
(
in
,
out
);
}
#if __AMDGCN_WAVEFRONT_SIZE == 32
if
constexpr
(
SubWaveSize
>
16
)
{
out
=
dpp_swizzle
<
0x1e0
>
(
in
);
auto
out
=
dpp_swizzle
<
0x1e0
>
(
in
);
in
=
op
(
in
,
out
);
}
#else
if
constexpr
(
SubWaveSize
>
16
)
{
out
=
dpp_mov
<
dpp_row_bcast
(
15
),
0xa
>
(
in
);
auto
out
=
dpp_mov
<
dpp_row_bcast
(
15
),
0xa
>
(
in
);
in
=
op
(
in
,
out
);
}
if
constexpr
(
SubWaveSize
>
32
)
{
out
=
dpp_mov
<
dpp_row_bcast
(
31
),
0xc
>
(
in
);
auto
out
=
dpp_mov
<
dpp_row_bcast
(
31
),
0xc
>
(
in
);
in
=
op
(
in
,
out
);
}
#endif
...
...
@@ -173,9 +170,11 @@ __device__ auto subwave_reduce(index idx, Op op, T init, Index n, F f)
using
type
=
decltype
(
index
::
invoke_loop
(
f
,
0
,
_c
<
0
>
));
type
x
=
init
;
idx
.
local_subwave_stride
<
SubWaveSize
>
(
n
,
[
&
](
auto
i
,
auto
d
)
{
x
=
op
(
x
,
index
::
invoke_loop
(
f
,
i
,
d
));
});
n
,
[
&
](
auto
i
,
auto
d
)
{
x
=
op
(
x
,
index
::
invoke_loop
(
f
,
i
,
d
));
});
dpp_reduce
<
SubWaveSize
>
(
x
,
op
);
return
x
;
return
dpp_readlane
<
SubWaveSize
-
1
,
SubWaveSize
>
(
x
)
;
}
template
<
class
Op
,
class
T
,
class
Index
,
class
F
>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment