Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
15fd8205
Commit
15fd8205
authored
May 10, 2022
by
Paul
Browse files
Add vectorization to reduction
parent
8a6ae079
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
41 additions
and
11 deletions
+41
-11
src/targets/gpu/include/migraphx/gpu/compile_gen.hpp
src/targets/gpu/include/migraphx/gpu/compile_gen.hpp
+3
-3
src/targets/gpu/jit/reduce.cpp
src/targets/gpu/jit/reduce.cpp
+19
-5
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
+2
-2
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp
+14
-0
src/targets/gpu/kernels/include/migraphx/kernels/vectorize.hpp
...argets/gpu/kernels/include/migraphx/kernels/vectorize.hpp
+3
-1
No files found.
src/targets/gpu/include/migraphx/gpu/compile_gen.hpp
View file @
15fd8205
...
@@ -16,14 +16,14 @@ namespace gen {
...
@@ -16,14 +16,14 @@ namespace gen {
struct
vectorize
struct
vectorize
{
{
std
::
size_t
size
;
std
::
size_t
size
=
0
;
std
::
size_t
axis
;
std
::
size_t
axis
=
0
;
static
vectorize
elements
(
std
::
size_t
axis
,
const
std
::
vector
<
shape
>&
inputs
);
static
vectorize
elements
(
std
::
size_t
axis
,
const
std
::
vector
<
shape
>&
inputs
);
std
::
string
str
()
const
;
std
::
string
str
()
const
;
};
};
struct
preload
struct
preload
{
{
std
::
vector
<
bool
>
args
;
std
::
vector
<
bool
>
args
=
{}
;
static
preload
broadcasts
(
std
::
size_t
axis
,
const
std
::
vector
<
shape
>&
inputs
);
static
preload
broadcasts
(
std
::
size_t
axis
,
const
std
::
vector
<
shape
>&
inputs
);
bool
is_preloading
()
const
;
bool
is_preloading
()
const
;
std
::
string
str
()
const
;
std
::
string
str
()
const
;
...
...
src/targets/gpu/jit/reduce.cpp
View file @
15fd8205
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/ranges.hpp>
...
@@ -16,9 +17,12 @@ namespace migraphx {
...
@@ -16,9 +17,12 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
gpu
{
using
namespace
migraphx
::
gpu
::
gen
;
static
const
char
*
const
simple_reduce_kernel
=
R"__migraphx__(
static
const
char
*
const
simple_reduce_kernel
=
R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <args.hpp>
#include <args.hpp>
namespace migraphx {
namespace migraphx {
...
@@ -26,9 +30,10 @@ namespace migraphx {
...
@@ -26,9 +30,10 @@ namespace migraphx {
${preamble}
${preamble}
extern "C" {
extern "C" {
__global__ void kernel(void* input_p, void* output_p)
__global__ void
reduce_
kernel(void* input_p, void* output_p)
{
{
make_tensors()(input_p, output_p)([](auto input, auto output) {
transform_args(make_tensors(), ${transformers})(input_p, output_p)([](auto input, auto output) {
simple_reduce<reduce::${algo}>(${reduction}, ${init}, input, output, ${read}, ${write});
simple_reduce<reduce::${algo}>(${reduction}, ${init}, input, output, ${read}, ${write});
});
});
...
@@ -93,17 +98,24 @@ struct reduce_compiler : compiler<reduce_compiler>
...
@@ -93,17 +98,24 @@ struct reduce_compiler : compiler<reduce_compiler>
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
{
hip_compile_options
options
;
hip_compile_options
options
;
auto
reduce_elements
=
get_reduce_elements
(
inputs
);
auto
faxis
=
find_fast_axis
({
inputs
.
front
()});
vectorize
vec
{};
// Vectorize if the axis is a reduction axis
if
(
inputs
.
back
().
lens
()[
faxis
]
==
1
)
{
vec
=
vectorize
::
elements
(
faxis
,
inputs
);
}
auto
reduce_elements
=
get_reduce_elements
(
inputs
)
/
vec
.
size
;
auto
algo
=
v
.
get
(
"algo"
,
get_reduce_algo
(
inputs
));
auto
algo
=
v
.
get
(
"algo"
,
get_reduce_algo
(
inputs
));
if
(
algo
==
"block"
)
if
(
algo
==
"block"
)
{
{
auto
block_size
=
compute_block_size
(
reduce_elements
,
256
);
auto
block_size
=
compute_block_size
(
reduce_elements
,
256
);
options
.
set_launch_params
(
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
inputs
.
back
().
elements
()
*
block_size
,
256
),
block_size
);
v
,
compute_global_for
(
ctx
,
inputs
.
back
().
elements
()
*
block_size
/
vec
.
size
,
256
),
block_size
);
}
}
else
if
(
algo
==
"lane"
)
else
if
(
algo
==
"lane"
)
{
{
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
inputs
.
back
().
elements
(),
256
));
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
inputs
.
back
().
elements
()
/
vec
.
size
,
256
));
}
}
else
else
{
{
...
@@ -112,6 +124,7 @@ struct reduce_compiler : compiler<reduce_compiler>
...
@@ -112,6 +124,7 @@ struct reduce_compiler : compiler<reduce_compiler>
options
.
inputs
=
inputs
;
options
.
inputs
=
inputs
;
options
.
output
=
inputs
.
back
();
options
.
output
=
inputs
.
back
();
options
.
virtual_inputs
=
reduce_dims
(
inputs
);
options
.
virtual_inputs
=
reduce_dims
(
inputs
);
options
.
kernel_name
=
"reduce_kernel"
;
std
::
string
identity
=
"[](auto x) { return x; }"
;
std
::
string
identity
=
"[](auto x) { return x; }"
;
auto
src
=
interpolate_string
(
simple_reduce_kernel
,
auto
src
=
interpolate_string
(
simple_reduce_kernel
,
{{
"reduction"
,
v
.
at
(
"reduction"
).
to
<
std
::
string
>
()},
{{
"reduction"
,
v
.
at
(
"reduction"
).
to
<
std
::
string
>
()},
...
@@ -119,6 +132,7 @@ struct reduce_compiler : compiler<reduce_compiler>
...
@@ -119,6 +132,7 @@ struct reduce_compiler : compiler<reduce_compiler>
{
"read"
,
v
.
get
(
"read"
,
identity
)},
{
"read"
,
v
.
get
(
"read"
,
identity
)},
{
"write"
,
v
.
get
(
"write"
,
identity
)},
{
"write"
,
v
.
get
(
"write"
,
identity
)},
{
"algo"
,
algo
},
{
"algo"
,
algo
},
{
"transformers"
,
make_transformer_args
(
vec
)},
{
"preamble"
,
v
.
get
(
"preamble"
,
std
::
string
{})}});
{
"preamble"
,
v
.
get
(
"preamble"
,
std
::
string
{})}});
options
.
params
+=
"-Wno-float-equal"
;
options
.
params
+=
"-Wno-float-equal"
;
return
compile_hip_code_object
(
src
,
options
);
return
compile_hip_code_object
(
src
,
options
);
...
...
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
View file @
15fd8205
...
@@ -163,9 +163,9 @@ struct block
...
@@ -163,9 +163,9 @@ struct block
__device__
auto
reduce
(
Op
op
,
T
init
,
Read
read
)
const
__device__
auto
reduce
(
Op
op
,
T
init
,
Read
read
)
const
{
{
return
sliced
(
slicer
,
[
=
](
auto
x
,
auto
...
xs
)
{
return
sliced
(
slicer
,
[
=
](
auto
x
,
auto
...
xs
)
{
return
block_reduce
(
idx
,
op
,
init
,
x
.
get_shape
().
elements
(),
[
&
](
auto
j
)
{
return
vec_reduce
(
block_reduce
(
idx
,
op
,
init
,
x
.
get_shape
().
elements
(),
[
&
](
auto
j
)
{
return
read
(
x
[
j
],
xs
[
j
]...);
return
read
(
x
[
j
],
xs
[
j
]...);
});
})
,
op
)
;
});
});
}
}
...
...
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp
View file @
15fd8205
...
@@ -146,5 +146,19 @@ constexpr auto vec_packed_transform(Ts... xs)
...
@@ -146,5 +146,19 @@ constexpr auto vec_packed_transform(Ts... xs)
};
};
}
}
template
<
class
T
,
class
Op
>
constexpr
auto
vec_reduce
(
T
x
,
Op
op
)
{
if
constexpr
(
vec_size
<
T
>
()
<
2
)
return
x
;
else
{
vec_type
<
T
>
result
;
for
(
int
i
=
1
;
i
<
vec_size
<
T
>
();
i
++
)
result
=
op
(
result
[
i
-
1
],
result
[
i
]);
return
result
;
}
}
}
// namespace migraphx
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_VEC_HPP
#endif // MIGRAPHX_GUARD_KERNELS_VEC_HPP
src/targets/gpu/kernels/include/migraphx/kernels/vectorize.hpp
View file @
15fd8205
...
@@ -213,7 +213,9 @@ template <index_int N, index_int Axis, class T>
...
@@ -213,7 +213,9 @@ template <index_int N, index_int Axis, class T>
__device__
__host__
auto
vectorize_tensor
(
T
x
)
__device__
__host__
auto
vectorize_tensor
(
T
x
)
{
{
constexpr
auto
shape
=
get_shape_c
<
T
>
{};
constexpr
auto
shape
=
get_shape_c
<
T
>
{};
if
constexpr
(
shape
.
strides
[
Axis
]
==
0
)
if
constexpr
(
shape
.
lens
[
Axis
]
==
1
)
return
x
;
else
if
constexpr
(
shape
.
strides
[
Axis
]
==
0
)
return
tensor_step
<
N
>
(
x
,
_c
<
Axis
>
);
return
tensor_step
<
N
>
(
x
,
_c
<
Axis
>
);
else
else
return
as_vec
<
N
>
(
x
,
_c
<
Axis
>
);
return
as_vec
<
N
>
(
x
,
_c
<
Axis
>
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment