Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
658110e1
Commit
658110e1
authored
Nov 12, 2023
by
Paul
Browse files
Add wave reduction
parent
d8011adf
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
107 additions
and
0 deletions
+107
-0
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
+18
-0
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
+89
-0
No files found.
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
View file @
658110e1
...
@@ -134,6 +134,12 @@ struct index
...
@@ -134,6 +134,12 @@ struct index
#endif
#endif
constexpr
auto
ngroup
()
const
{
return
nglobal
()
/
max_nlocal
();
}
constexpr
auto
ngroup
()
const
{
return
nglobal
()
/
max_nlocal
();
}
constexpr
index_constant
<
__AMDGCN_WAVEFRONT_SIZE
>
nlocal_wave
()
const
{
return
{};
}
constexpr
auto
local_wave
()
const
{
return
local
%
nlocal_wave
();
}
constexpr
auto
nwave
()
const
{
return
max_nlocal
()
/
nlocal_wave
();
}
constexpr
auto
wave
()
const
{
return
local
/
nlocal_wave
();
}
template
<
class
N
,
class
Stride
>
template
<
class
N
,
class
Stride
>
static
constexpr
auto
max_stride_iterations
(
N
n
,
Stride
stride
)
static
constexpr
auto
max_stride_iterations
(
N
n
,
Stride
stride
)
{
{
...
@@ -152,6 +158,12 @@ struct index
...
@@ -152,6 +158,12 @@ struct index
return
max_stride_iterations
(
n
,
nlocal
());
return
max_stride_iterations
(
n
,
nlocal
());
}
}
template
<
class
N
>
constexpr
auto
max_local_wave_stride_iterations
(
N
n
)
const
{
return
max_stride_iterations
(
n
,
nlocal_wave
());
}
template
<
class
F
,
class
I
,
class
D
>
template
<
class
F
,
class
I
,
class
D
>
static
constexpr
auto
invoke_loop
(
F
f
,
I
i
,
D
d
)
->
decltype
(
f
(
i
,
d
))
static
constexpr
auto
invoke_loop
(
F
f
,
I
i
,
D
d
)
->
decltype
(
f
(
i
,
d
))
{
{
...
@@ -241,6 +253,12 @@ struct index
...
@@ -241,6 +253,12 @@ struct index
{
{
for_stride
<
false
>
(
group
,
n
,
ngroup
(),
f
);
for_stride
<
false
>
(
group
,
n
,
ngroup
(),
f
);
}
}
template
<
class
F
,
class
N
>
__device__
void
local_wave_stride
(
N
n
,
F
f
)
const
{
for_stride
<
false
>
(
local_wave
(),
n
,
nlocal_wave
(),
f
);
}
};
};
#ifdef MIGRAPHX_NLOCAL
#ifdef MIGRAPHX_NLOCAL
...
...
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
View file @
658110e1
...
@@ -95,10 +95,24 @@ MIGRAPHX_DPP_REDUCE(op::product, v_mul, _u)
...
@@ -95,10 +95,24 @@ MIGRAPHX_DPP_REDUCE(op::product, v_mul, _u)
MIGRAPHX_DPP_REDUCE
(
op
::
max
,
v_max
,
_i
)
MIGRAPHX_DPP_REDUCE
(
op
::
max
,
v_max
,
_i
)
MIGRAPHX_DPP_REDUCE
(
op
::
min
,
v_min
,
_i
)
MIGRAPHX_DPP_REDUCE
(
op
::
min
,
v_min
,
_i
)
template
<
class
Op
,
class
T
,
class
Index
,
class
F
>
__device__
auto
wave_reduce
(
index
idx
,
Op
op
,
T
init
,
Index
n
,
F
f
)
{
MIGRAPHX_ASSERT
(
idx
.
max_nlocal
()
==
idx
.
nlocal
());
using
type
=
decltype
(
index
::
invoke_loop
(
f
,
0
,
_c
<
0
>
));
type
x
=
init
;
idx
.
local_wave_stride
(
n
,
[
&
](
auto
i
,
auto
d
)
{
x
=
op
(
x
,
index
::
invoke_loop
(
f
,
i
,
d
));
});
dpp_reduce
(
x
,
op
);
return
x
;
}
template
<
class
Op
,
class
T
,
class
Index
,
class
F
>
template
<
class
Op
,
class
T
,
class
Index
,
class
F
>
__device__
auto
block_reduce
(
index
idx
,
Op
op
,
T
init
,
Index
n
,
F
f
)
__device__
auto
block_reduce
(
index
idx
,
Op
op
,
T
init
,
Index
n
,
F
f
)
{
{
MIGRAPHX_ASSERT
(
idx
.
max_nlocal
()
==
idx
.
nlocal
());
MIGRAPHX_ASSERT
(
idx
.
max_nlocal
()
==
idx
.
nlocal
());
if
(
idx
.
max_nlocal
()
==
idx
.
nlocal_wave
())
return
wave_reduce
(
idx
,
op
,
init
,
n
,
f
);
#if __AMDGCN_WAVEFRONT_SIZE == 32
#if __AMDGCN_WAVEFRONT_SIZE == 32
constexpr
index_int
lanes_per_thread
=
16
;
constexpr
index_int
lanes_per_thread
=
16
;
#else
#else
...
@@ -470,6 +484,81 @@ struct block_large
...
@@ -470,6 +484,81 @@ struct block_large
}
}
};
};
struct
wave
{
template
<
class
Slicer
>
struct
reducer
:
reducer_base
<
reducer
<
Slicer
>>
{
index
idx
;
Slicer
slice
;
template
<
class
T
,
index_int
N
,
class
Size
>
struct
inner_storage
:
inner_storage_tag
{
using
type
=
T
;
array
<
T
,
N
>
arr
;
constexpr
Size
rsize
()
const
{
return
{};
}
template
<
class
U
,
class
V
>
constexpr
auto
&
operator
()(
U
,
V
d
)
const
{
return
arr
[
d
];
}
template
<
class
U
,
class
V
>
constexpr
auto
&
operator
()(
U
,
V
d
)
{
return
arr
[
d
];
}
};
template
<
class
Op
,
class
T
,
class
Read
,
class
N
,
class
...
Ts
>
__device__
auto
reduce_impl
(
Op
op
,
T
init
,
Read
read
,
N
n
,
Ts
&&
...
xs
)
const
{
return
wave_reduce
(
idx
,
op
,
init
,
n
,
[
&
](
auto
j
,
auto
d
)
{
return
vec_reduce
(
read
(
xs
(
j
,
d
)...),
op
);
});
}
template
<
class
F
>
__device__
void
outer
(
F
f
)
const
{
if
(
idx
.
local
==
0
)
f
();
}
template
<
class
F
,
class
N
,
class
...
Ts
>
__device__
void
inner_void_impl
(
F
f
,
N
n
,
Ts
&&
...
xs
)
const
{
idx
.
local_wave_stride
(
n
,
[
&
](
auto
j
,
auto
d
)
{
f
(
xs
(
j
,
d
)...);
});
}
template
<
class
R
,
class
F
,
class
N
,
class
...
Ts
>
__device__
auto
inner_impl
(
F
f
,
N
n
,
Ts
&&
...
xs
)
const
{
using
max_iterations
=
decltype
(
idx
.
max_local_wave_stride_iterations
(
n
));
inner_storage
<
R
,
max_iterations
{},
N
>
storage
;
idx
.
local_wave_stride
(
n
,
[
&
](
auto
j
,
auto
d
)
{
storage
(
j
,
d
)
=
f
(
xs
(
j
,
d
)...);
});
return
storage
;
}
};
template
<
class
Slicer
>
static
__device__
auto
make
(
index
idx
,
Slicer
slicer
)
{
return
reducer
<
Slicer
>
{{},
idx
,
slicer
};
}
template
<
class
Output
,
class
F
>
static
__device__
void
run
(
F
f
)
{
auto
idx
=
make_index
();
constexpr
auto
nelements
=
get_shape_c
<
Output
>
{}.
elements
();
idx
.
global_stride
(
nelements
*
idx
.
nlocal_wave
(),
[
&
](
auto
i
)
{
const
auto
out_idx
=
get_shape_c
<
Output
>
{}.
multi
(
i
/
idx
.
nlocal_wave
());
f
(
out_idx
,
make
(
idx
,
[
&
](
auto
input
)
{
return
reduce_slice
<
Output
>
(
input
,
out_idx
);
}));
});
}
};
struct
lane
struct
lane
{
{
template
<
class
Slicer
>
template
<
class
Slicer
>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment