Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
efa5dcce
Commit
efa5dcce
authored
May 03, 2022
by
Paul
Browse files
Add softmax kernel
parent
4a5a23a4
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
148 additions
and
1 deletion
+148
-1
src/targets/gpu/jit/softmax.cpp
src/targets/gpu/jit/softmax.cpp
+76
-0
src/targets/gpu/kernels/include/migraphx/kernels/array.hpp
src/targets/gpu/kernels/include/migraphx/kernels/array.hpp
+9
-0
src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
...rgets/gpu/kernels/include/migraphx/kernels/functional.hpp
+1
-1
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
+37
-0
src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp
src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp
+24
-0
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp
+1
-0
No files found.
src/targets/gpu/jit/softmax.cpp
0 → 100644
View file @
efa5dcce
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
static
const
char
*
const
softmax_kernel
=
R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/softmax.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
__global__ void softmax_kernel(void* input_p, void* output_p)
{
make_tensors()(input_p, output_p)([](auto input, auto output) {
softmax<${axis}>(input, output);
});
}
}
} // namespace migraphx
)__migraphx__"
;
constexpr
std
::
size_t
compute_block_size
(
std
::
size_t
n
,
std
::
size_t
max_block_size
=
1024
)
{
size_t
block_size
=
128
;
while
(
block_size
<=
max_block_size
and
block_size
<=
n
)
block_size
*=
2
;
return
block_size
/
2
;
}
struct
softmax_compiler
:
compiler
<
softmax_compiler
>
{
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"softmax"
};
}
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
auto
axis
=
v
.
at
(
"axis"
).
to
<
int64_t
>
();
auto
block_size
=
compute_block_size
(
inputs
[
0
].
lens
()[
axis
],
256
);
hip_compile_options
options
;
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
inputs
.
back
().
elements
(),
block_size
),
256
);
options
.
output
=
inputs
.
back
();
options
.
inputs
=
inputs
;
options
.
kernel_name
=
"softmax_kernel"
;
auto
src
=
interpolate_string
(
softmax_kernel
,
{{
"axis"
,
to_string
(
axis
)}});
return
compile_hip_code_object
(
src
,
options
);
}
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
{
return
replace
(
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
op
.
to_value
()));
}
};
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/kernels/include/migraphx/kernels/array.hpp
View file @
efa5dcce
...
...
@@ -4,6 +4,7 @@
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/type_traits.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/debug.hpp>
namespace
migraphx
{
...
...
@@ -190,6 +191,14 @@ constexpr auto transform(integral_const_array<T, Xs...>, F f)
return
integral_const_array
<
T
,
f
(
Xs
)...
>
{};
}
template
<
class
T
,
T
...
Xs
,
class
F
>
constexpr
auto
transform_i
(
integral_const_array
<
T
,
Xs
...
>
,
F
f
)
{
return
sequence_c
<
sizeof
...(
Xs
)
>
([
=
](
auto
...
is
)
{
return
integral_const_array
<
T
,
f
(
Xs
,
is
)...
>
{};
});
}
template
<
class
T
,
T
...
Xs
,
class
U
,
U
...
Ys
,
class
F
>
constexpr
auto
transform
(
integral_const_array
<
T
,
Xs
...
>
,
integral_const_array
<
U
,
Ys
...
>
,
F
f
)
{
...
...
src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
View file @
efa5dcce
#ifndef MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP
#define MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP
#include <migraphx/kernels/
array
.hpp>
#include <migraphx/kernels/
integral_constant
.hpp>
namespace
migraphx
{
...
...
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
View file @
efa5dcce
...
...
@@ -152,6 +152,21 @@ constexpr auto sliced(Slicer slicer, F f)
};
}
template
<
class
Input
,
index_int
Axis
>
constexpr
auto
compute_reduce_axis
()
{
constexpr
auto
lens
=
transform_i
(
get_shape_c
<
Input
>
{}.
lens
,
[](
index_int
x
,
index_int
i
)
->
index_int
{
if
(
i
==
Axis
)
return
1
;
return
x
;
});
return
make_shape
(
lens
,
get_shape_c
<
Input
>
{}.
strides
);
}
template
<
class
Input
,
index_int
Axis
>
using
with_axis
=
decltype
(
compute_reduce_axis
<
Input
,
Axis
>
());
struct
block
{
template
<
class
Slicer
>
...
...
@@ -175,6 +190,17 @@ struct block
if
(
idx
.
local
==
0
)
f
();
}
template
<
class
T
,
class
...
Ts
>
__device__
auto
inner
(
T
x
,
Ts
...
xs
)
const
{
return
[
=
](
auto
f
)
{
// TODO: Assert same elements
idx
.
local_stride
(
x
.
elements
(),
[
&
](
auto
j
)
{
f
(
x
[
j
],
xs
[
j
]...);
});
};
}
};
template
<
class
Slicer
>
...
...
@@ -221,6 +247,17 @@ struct lane
{
f
();
}
template
<
class
T
,
class
...
Ts
>
__device__
auto
inner
(
T
x
,
Ts
...
xs
)
const
{
return
[
=
](
auto
f
)
{
for
(
index_int
j
=
0
;
j
<
x
.
get_shape
().
elements
();
j
++
)
{
f
(
x
[
j
],
xs
[
j
]...);
}
};
}
};
template
<
class
Slicer
>
...
...
src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp
0 → 100644
View file @
efa5dcce
#ifndef MIGRAPHX_GUARD_KERNELS_SOFTMAX_HPP
#define MIGRAPHX_GUARD_KERNELS_SOFTMAX_HPP
#include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/basic_ops.hpp>
namespace
migraphx
{
template
<
index_int
Axis
,
class
Input
,
class
Output
>
void
softmax
(
Input
input
,
Output
output
)
{
reduce
::
block
::
run
<
reduce
::
with_axis
<
Input
,
Axis
>>
([
&
](
auto
,
auto
r
)
{
auto
batch_max
=
r
.
reduce
(
op
::
max
{},
lowest
{},
op
::
id
{})(
input
);
auto
batch_sum
=
r
.
reduce
(
op
::
sum
{},
0
,
[
&
](
auto
x
)
{
return
migraphx
::
exp
(
x
-
batch_max
);
})(
input
);
r
.
outer
(
output
,
input
)([
&
](
auto
&
y
,
auto
x
)
{
y
=
migraphx
::
exp
(
x
-
batch_max
)
/
batch_sum
;
});
});
}
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_SOFTMAX_HPP
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp
View file @
efa5dcce
...
...
@@ -4,6 +4,7 @@
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/debug.hpp>
namespace
migraphx
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment