Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
308db690
Commit
308db690
authored
Nov 13, 2023
by
Paul
Browse files
Format
parent
2dc6894c
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
37 additions
and
27 deletions
+37
-27
src/targets/gpu/jit/reduce.cpp
src/targets/gpu/jit/reduce.cpp
+14
-11
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
+15
-6
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
+8
-10
No files found.
src/targets/gpu/jit/reduce.cpp
View file @
308db690
...
@@ -111,7 +111,7 @@ static std::string get_reduce_algo(context& ctx, const std::vector<shape>& input
...
@@ -111,7 +111,7 @@ static std::string get_reduce_algo(context& ctx, const std::vector<shape>& input
[](
auto
len
,
auto
stride
)
{
return
len
==
1
?
init
:
stride
;
});
[](
auto
len
,
auto
stride
)
{
return
len
==
1
?
init
:
stride
;
});
if
(
min_stride
>
2
)
if
(
min_stride
>
2
)
return
"lane"
;
return
"lane"
;
if
(
relements
<=
ctx
.
get_current_device
().
get_wavefront_size
())
if
(
relements
<=
ctx
.
get_current_device
().
get_wavefront_size
())
return
"wave"
;
return
"wave"
;
return
"block"
;
return
"block"
;
}
}
...
@@ -176,8 +176,9 @@ struct simple_reduce_compiler : compiler<simple_reduce_compiler>
...
@@ -176,8 +176,9 @@ struct simple_reduce_compiler : compiler<simple_reduce_compiler>
{
{
auto
subwave_size
=
compute_subwave_size
(
ctx
,
relements
);
auto
subwave_size
=
compute_subwave_size
(
ctx
,
relements
);
algo
=
"subwave<"
+
std
::
to_string
(
subwave_size
)
+
">"
;
algo
=
"subwave<"
+
std
::
to_string
(
subwave_size
)
+
">"
;
options
.
set_launch_params
(
options
.
set_launch_params
(
v
,
v
,
compute_global_for
(
ctx
,
nelements
*
subwave_size
,
256
),
ctx
.
get_current_device
().
get_wavefront_size
());
compute_global_for
(
ctx
,
nelements
*
subwave_size
,
256
),
ctx
.
get_current_device
().
get_wavefront_size
());
}
}
}
}
else
if
(
algo
==
"lane"
)
else
if
(
algo
==
"lane"
)
...
@@ -263,14 +264,15 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler>
...
@@ -263,14 +264,15 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler>
auto
faxis
=
find_fast_axis
({
options
.
virtual_inputs
.
front
()});
auto
faxis
=
find_fast_axis
({
options
.
virtual_inputs
.
front
()});
vectorize
vec
{};
vectorize
vec
{};
auto
nelements
=
reduce_output_shape
.
elements
();
auto
nelements
=
reduce_output_shape
.
elements
();
auto
algo
=
v
.
get
(
"algo"
,
get_reduce_algo
(
ctx
,
options
.
virtual_inputs
,
reduction_shape
.
lens
()));
auto
algo
=
v
.
get
(
"algo"
,
get_reduce_algo
(
ctx
,
options
.
virtual_inputs
,
reduction_shape
.
lens
()));
if
(
algo
==
"block"
or
algo
==
"wave"
)
if
(
algo
==
"block"
or
algo
==
"wave"
)
{
{
// Vectorize if the axis is a reduction axis
// Vectorize if the axis is a reduction axis
if
(
reduce_output_shape
.
lens
()[
faxis
]
==
1
)
if
(
reduce_output_shape
.
lens
()[
faxis
]
==
1
)
vec
=
vectorize
::
elements
(
ctx
,
faxis
,
options
.
virtual_inputs
);
vec
=
vectorize
::
elements
(
ctx
,
faxis
,
options
.
virtual_inputs
);
auto
relements
=
reduction_shape
.
elements
()
/
vec
.
size
;
auto
relements
=
reduction_shape
.
elements
()
/
vec
.
size
;
if
(
algo
==
"block"
)
if
(
algo
==
"block"
)
{
{
auto
block_size
=
compute_block_size
(
relements
,
256
);
auto
block_size
=
compute_block_size
(
relements
,
256
);
if
(
relements
>=
block_size
*
256
)
if
(
relements
>=
block_size
*
256
)
...
@@ -282,8 +284,9 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler>
...
@@ -282,8 +284,9 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler>
{
{
auto
subwave_size
=
compute_subwave_size
(
ctx
,
relements
);
auto
subwave_size
=
compute_subwave_size
(
ctx
,
relements
);
algo
=
"subwave<"
+
std
::
to_string
(
subwave_size
)
+
">"
;
algo
=
"subwave<"
+
std
::
to_string
(
subwave_size
)
+
">"
;
options
.
set_launch_params
(
options
.
set_launch_params
(
v
,
v
,
compute_global_for
(
ctx
,
nelements
*
subwave_size
,
256
),
ctx
.
get_current_device
().
get_wavefront_size
());
compute_global_for
(
ctx
,
nelements
*
subwave_size
,
256
),
ctx
.
get_current_device
().
get_wavefront_size
());
}
}
}
}
else
if
(
algo
==
"lane"
)
else
if
(
algo
==
"lane"
)
...
...
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
View file @
308db690
...
@@ -135,12 +135,21 @@ struct index
...
@@ -135,12 +135,21 @@ struct index
constexpr
auto
ngroup
()
const
{
return
nglobal
()
/
max_nlocal
();
}
constexpr
auto
ngroup
()
const
{
return
nglobal
()
/
max_nlocal
();
}
template
<
unsigned
int
SubWaveSize
>
template
<
unsigned
int
SubWaveSize
>
constexpr
index_constant
<
SubWaveSize
>
nlocal_subwave
()
const
{
return
{};
}
constexpr
index_constant
<
SubWaveSize
>
nlocal_subwave
()
const
template
<
unsigned
int
SubWaveSize
>
{
constexpr
auto
local_subwave
()
const
{
return
local
%
nlocal_subwave
<
SubWaveSize
>
();
}
return
{};
template
<
unsigned
int
SubWaveSize
>
}
constexpr
auto
nwave
()
const
{
return
max_nlocal
()
/
nlocal_subwave
<
SubWaveSize
>
();
}
template
<
unsigned
int
SubWaveSize
>
constexpr
auto
local_subwave
()
const
{
return
local
%
nlocal_subwave
<
SubWaveSize
>
();
}
template
<
unsigned
int
SubWaveSize
>
constexpr
auto
nwave
()
const
{
return
max_nlocal
()
/
nlocal_subwave
<
SubWaveSize
>
();
}
constexpr
index_constant
<
__AMDGCN_WAVEFRONT_SIZE
>
nlocal_wave
()
const
{
return
{};
}
constexpr
index_constant
<
__AMDGCN_WAVEFRONT_SIZE
>
nlocal_wave
()
const
{
return
{};
}
constexpr
auto
local_wave
()
const
{
return
local
%
nlocal_wave
();
}
constexpr
auto
local_wave
()
const
{
return
local
%
nlocal_wave
();
}
...
...
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
View file @
308db690
...
@@ -31,11 +31,7 @@
...
@@ -31,11 +31,7 @@
namespace
migraphx
{
namespace
migraphx
{
constexpr
bool
is_power_of_2
(
unsigned
int
x
)
{
return
x
>
0
&&
!
(
x
&
(
x
-
1
));
}
constexpr
bool
is_power_of_2
(
unsigned
int
x
)
{
return
x
>
0
&&
!
(
x
&
(
x
-
1
));
}
#if MIGRAPHX_HAS_DPP
#if MIGRAPHX_HAS_DPP
...
@@ -134,14 +130,14 @@ MIGRAPHX_DPP_REDUCE(op::product, v_mul, _u)
...
@@ -134,14 +130,14 @@ MIGRAPHX_DPP_REDUCE(op::product, v_mul, _u)
MIGRAPHX_DPP_REDUCE
(
op
::
max
,
v_max
,
_i
)
MIGRAPHX_DPP_REDUCE
(
op
::
max
,
v_max
,
_i
)
MIGRAPHX_DPP_REDUCE
(
op
::
min
,
v_min
,
_i
)
MIGRAPHX_DPP_REDUCE
(
op
::
min
,
v_min
,
_i
)
template
<
unsigned
int
SubWaveSize
,
class
Op
,
class
T
,
class
Index
,
class
F
>
template
<
unsigned
int
SubWaveSize
,
class
Op
,
class
T
,
class
Index
,
class
F
>
__device__
auto
subwave_reduce
(
index
idx
,
Op
op
,
T
init
,
Index
n
,
F
f
)
__device__
auto
subwave_reduce
(
index
idx
,
Op
op
,
T
init
,
Index
n
,
F
f
)
{
{
MIGRAPHX_ASSERT
(
idx
.
max_nlocal
()
==
idx
.
nlocal
());
MIGRAPHX_ASSERT
(
idx
.
max_nlocal
()
==
idx
.
nlocal
());
using
type
=
decltype
(
index
::
invoke_loop
(
f
,
0
,
_c
<
0
>
));
using
type
=
decltype
(
index
::
invoke_loop
(
f
,
0
,
_c
<
0
>
));
type
x
=
init
;
type
x
=
init
;
idx
.
local_subwave_stride
<
SubWaveSize
>
(
n
,
[
&
](
auto
i
,
auto
d
)
{
x
=
op
(
x
,
index
::
invoke_loop
(
f
,
i
,
d
));
});
idx
.
local_subwave_stride
<
SubWaveSize
>
(
n
,
[
&
](
auto
i
,
auto
d
)
{
x
=
op
(
x
,
index
::
invoke_loop
(
f
,
i
,
d
));
});
dpp_reduce
<
SubWaveSize
>
(
x
,
op
);
dpp_reduce
<
SubWaveSize
>
(
x
,
op
);
return
x
;
return
x
;
}
}
...
@@ -529,7 +525,7 @@ struct block_large
...
@@ -529,7 +525,7 @@ struct block_large
}
}
};
};
template
<
unsigned
int
SubWaveSize
>
template
<
unsigned
int
SubWaveSize
>
struct
subwave
struct
subwave
{
{
template
<
class
Slicer
>
template
<
class
Slicer
>
...
@@ -580,9 +576,11 @@ struct subwave
...
@@ -580,9 +576,11 @@ struct subwave
template
<
class
R
,
class
F
,
class
N
,
class
...
Ts
>
template
<
class
R
,
class
F
,
class
N
,
class
...
Ts
>
__device__
auto
inner_impl
(
F
f
,
N
n
,
Ts
&&
...
xs
)
const
__device__
auto
inner_impl
(
F
f
,
N
n
,
Ts
&&
...
xs
)
const
{
{
using
max_iterations
=
decltype
(
idx
.
max_local_subwave_stride_iterations
<
SubWaveSize
>
(
n
));
using
max_iterations
=
decltype
(
idx
.
max_local_subwave_stride_iterations
<
SubWaveSize
>
(
n
));
inner_storage
<
R
,
max_iterations
{},
N
>
storage
;
inner_storage
<
R
,
max_iterations
{},
N
>
storage
;
idx
.
local_subwave_stride
<
SubWaveSize
>
(
n
,
[
&
](
auto
j
,
auto
d
)
{
storage
(
j
,
d
)
=
f
(
xs
(
j
,
d
)...);
});
idx
.
local_subwave_stride
<
SubWaveSize
>
(
n
,
[
&
](
auto
j
,
auto
d
)
{
storage
(
j
,
d
)
=
f
(
xs
(
j
,
d
)...);
});
return
storage
;
return
storage
;
}
}
};
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment