Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
8520e0b8
Unverified
Commit
8520e0b8
authored
Jul 05, 2022
by
Paul Fultz II
Committed by
GitHub
Jul 05, 2022
Browse files
Add jit softmax (#1243)
* Add softmax kernel
parent
27e980c4
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
201 additions
and
2 deletions
+201
-2
src/targets/gpu/compile_gen.cpp
src/targets/gpu/compile_gen.cpp
+3
-0
src/targets/gpu/jit/softmax.cpp
src/targets/gpu/jit/softmax.cpp
+107
-0
src/targets/gpu/kernels/include/migraphx/kernels/array.hpp
src/targets/gpu/kernels/include/migraphx/kernels/array.hpp
+8
-0
src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
...rgets/gpu/kernels/include/migraphx/kernels/functional.hpp
+1
-1
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
+34
-0
src/targets/gpu/kernels/include/migraphx/kernels/shape.hpp
src/targets/gpu/kernels/include/migraphx/kernels/shape.hpp
+1
-0
src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp
src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp
+45
-0
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp
+2
-0
src/targets/gpu/lowering.cpp
src/targets/gpu/lowering.cpp
+0
-1
No files found.
src/targets/gpu/compile_gen.cpp
View file @
8520e0b8
...
...
@@ -43,6 +43,9 @@ static std::vector<std::size_t> vector_sizes(const std::vector<shape>& inputs)
vectorize
vectorize
::
elements
(
std
::
size_t
axis
,
const
std
::
vector
<
shape
>&
inputs
)
{
if
(
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
const
auto
&
s
)
{
return
s
.
lens
()[
axis
]
==
1
;
}))
return
{
1
,
axis
};
auto
sizes
=
vector_sizes
(
inputs
);
std
::
vector
<
std
::
size_t
>
max_vec_size
;
std
::
transform
(
inputs
.
begin
(),
...
...
src/targets/gpu/jit/softmax.cpp
0 → 100644
View file @
8520e0b8
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
using
namespace
migraphx
::
gpu
::
gen
;
// NOLINT
// Source template for the JIT-compiled softmax kernel. The ${transformers}
// and ${axis} placeholders are filled in by interpolate_string() inside
// softmax_compiler::compile_op before the text is passed to
// compile_hip_code_object(). The generated entry point is the extern "C"
// function softmax_kernel(void* input_p, void* output_p), which forwards the
// transformed arguments to softmax<axis>(input, output).
static
const
char
*
const
softmax_kernel
=
R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/softmax.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
__global__ void softmax_kernel(void* input_p, void* output_p)
{
transform_args(make_tensors(), ${transformers})(input_p, output_p)([](auto input, auto output) {
softmax<${axis}>(input, output);
});
}
}
} // namespace migraphx
)__migraphx__"
;
// JIT compiler registration for the "softmax" operator. compile_op() fills in
// the softmax_kernel source template and builds a HIP code object from it;
// compile() adapts an instruction into the arguments compile_op() expects.
struct softmax_compiler : compiler<softmax_compiler>
{
    // Operator names handled by this compiler.
    std::vector<std::string> names() const { return {"softmax"}; }

    // Builds the compiled softmax operation.
    //   ctx    - GPU context, forwarded to compute_global_for().
    //   inputs - shapes of the kernel arguments; the last one is used as the
    //            output shape.
    //   v      - operator attributes; "axis" is read from it and the whole
    //            value is forwarded to set_launch_params().
    operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
    {
        // TODO: Use reduce_dims
        const auto axis      = v.at("axis").to<int64_t>();
        const auto fast_axis = find_fast_axis({inputs.front()});

        // Vectorization is only requested when the softmax axis is also the
        // fast axis of the first input; otherwise the default-constructed
        // vectorize is kept.
        vectorize vec{};
        if(fast_axis == axis)
            vec = vectorize::elements(fast_axis, inputs);

        // Length of the softmax axis, used both for the per-reduction element
        // count and for the number of independent reductions.
        const auto axis_len   = inputs[0].lens()[axis];
        const auto relements  = axis_len / vec.size;
        const auto nelements  = inputs.back().elements() / axis_len;
        const auto block_size = compute_block_size(relements, 256);

        hip_compile_options options;
        options.set_launch_params(
            v, compute_global_for(ctx, nelements * block_size, 256), block_size);
        options.output      = inputs.back();
        options.inputs      = inputs;
        options.kernel_name = "softmax_kernel";

        // Substitute the ${transformers} and ${axis} placeholders of the template.
        auto src = interpolate_string(
            softmax_kernel,
            {{"transformers", make_transformer_args(vec)}, {"axis", to_string(axis)}});

        return compile_hip_code_object(src, options);
    }

    // Compiles the kernel for instruction `ins` using the shapes of its inputs
    // and the attribute value of `op`, and wraps the result with replace().
    compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
    {
        const auto input_shapes = to_shapes(ins->inputs());
        return replace(compile_op(ctx, input_shapes, op.to_value()));
    }
};
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/kernels/include/migraphx/kernels/array.hpp
View file @
8520e0b8
...
...
@@ -27,6 +27,7 @@
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/type_traits.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/debug.hpp>
namespace
migraphx
{
...
...
@@ -213,6 +214,13 @@ constexpr auto transform(integral_const_array<T, Xs...>, F f)
return
integral_const_array
<
T
,
f
(
Xs
)...
>
{};
}
// Produces integral_const_array<T, f(Xs, i)...>: each element of the input
// pack is combined with the matching argument that sequence_c passes to the
// callback.
// NOTE(review): sequence_c is declared elsewhere; presumably it supplies the
// positions 0..sizeof...(Xs)-1 as compile-time constants -- confirm.
template <class T, T... Xs, class F>
constexpr auto transform_i(integral_const_array<T, Xs...>, F f)
{
    return sequence_c<sizeof...(Xs)>(
        [=](auto... indices) { return integral_const_array<T, f(Xs, indices)...>{}; });
}
template
<
class
T
,
T
...
Xs
,
class
U
,
U
...
Ys
,
class
F
>
constexpr
auto
transform
(
integral_const_array
<
T
,
Xs
...
>
,
integral_const_array
<
U
,
Ys
...
>
,
F
f
)
{
...
...
src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
View file @
8520e0b8
...
...
@@ -24,7 +24,7 @@
#ifndef MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP
#define MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP
#include <migraphx/kernels/
array
.hpp>
#include <migraphx/kernels/
integral_constant
.hpp>
// NOLINTNEXTLINE
#define MIGRAPHX_RETURNS(...) \
...
...
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
View file @
8520e0b8
...
...
@@ -175,6 +175,21 @@ constexpr auto sliced(Slicer slicer, F f)
};
}
// Builds a shape from Input's compile-time shape in which the length at
// position Axis is replaced by 1, while every other length and all strides
// are taken unchanged from the input shape.
template <class Input, index_int Axis>
constexpr auto compute_reduce_axis()
{
    constexpr auto reduced_lens =
        transform_i(get_shape_c<Input>{}.lens, [](index_int len, index_int dim) -> index_int {
            // Collapse only the selected axis; keep the rest as-is.
            return (dim == Axis) ? 1 : len;
        });
    return make_shape(reduced_lens, get_shape_c<Input>{}.strides);
}
// Type alias for the shape type returned by compute_reduce_axis<Input, Axis>(),
// i.e. Input's shape with the length at Axis set to 1.
template
<
class
Input
,
index_int
Axis
>
using
with_axis
=
decltype
(
compute_reduce_axis
<
Input
,
Axis
>
());
struct
block
{
template
<
class
Slicer
>
...
...
@@ -201,6 +216,14 @@ struct block
if
(
idx
.
local
==
0
)
f
();
}
// Returns the callable built by sliced(slicer, ...). When it is applied to
// tensors, f is called with element i of the first tensor followed by element
// i of each remaining tensor, for every index i that idx.local_stride visits
// up to the first tensor's element count.
// NOTE(review): presumably local_stride spreads the indices over the threads
// of the block -- confirm against the index implementation.
template <class F>
__device__ auto inner(F f) const
{
    return sliced(slicer, [=](auto first, auto... rest) {
        idx.local_stride(first.get_shape().elements(),
                         [&](auto i) { f(first[i], rest[i]...); });
    });
}
};
template
<
class
Slicer
>
...
...
@@ -247,6 +270,17 @@ struct lane
{
f
();
}
// Returns the callable built by sliced(slicer, ...). When it is applied to
// tensors, a plain sequential loop calls f with element i of the first tensor
// followed by element i of each remaining tensor, for every index below the
// first tensor's element count.
template <class F>
__device__ auto inner(F f) const
{
    return sliced(slicer, [=](auto first, auto... rest) {
        for(index_int i = 0; i < first.get_shape().elements(); ++i)
        {
            f(first[i], rest[i]...);
        }
    });
}
};
template
<
class
Slicer
>
...
...
src/targets/gpu/kernels/include/migraphx/kernels/shape.hpp
View file @
8520e0b8
...
...
@@ -32,6 +32,7 @@ namespace migraphx {
template
<
class
Lens
,
class
Strides
>
struct
shape
{
using
shape_type
=
shape
;
using
index_array
=
typename
Lens
::
base_array
;
Lens
lens
=
{};
Strides
strides
=
{};
...
...
src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp
0 → 100644
View file @
8520e0b8
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_SOFTMAX_HPP
#define MIGRAPHX_GUARD_KERNELS_SOFTMAX_HPP
#include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/ops.hpp>
namespace
migraphx
{
// Computes softmax of `input` along Axis and stores the result in `output`.
// For each reduction handled by reduce::block::run over with_axis<Input, Axis>:
//   1. batch_max = max-reduction of the input,
//   2. batch_sum = sum-reduction of exp(x - batch_max),
//   3. every output element is set to exp(x - batch_max) / batch_sum.
// Subtracting the maximum before exponentiating keeps the exp argument <= 0.
template <index_int Axis, class Input, class Output>
__device__ void softmax(Input input, Output output)
{
    using reduced_shape = reduce::with_axis<Input, Axis>;
    reduce::block::run<reduced_shape>([&](auto, auto r) {
        auto batch_max = r.reduce(op::max{}, lowest{}, op::id{})(input);

        // Shared by the sum reduction and the final normalization.
        auto shifted_exp = [&](auto x) { return migraphx::exp(x - batch_max); };

        auto batch_sum = r.reduce(op::sum{}, 0, shifted_exp)(input);

        r.inner([&](auto& y, auto x) { y = shifted_exp(x) / batch_sum; })(output, input);
    });
}
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_SOFTMAX_HPP
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp
View file @
8520e0b8
...
...
@@ -27,6 +27,8 @@
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/type_traits.hpp>
#include <migraphx/kernels/debug.hpp>
namespace
migraphx
{
...
...
src/targets/gpu/lowering.cpp
View file @
8520e0b8
...
...
@@ -186,7 +186,6 @@ struct miopen_apply
add_extend_op
(
"rnn_var_sl_shift_output"
);
add_extend_op
(
"rnn_var_sl_shift_sequence"
);
add_extend_op
(
"scatter_none"
);
add_extend_op
(
"softmax"
);
add_extend_op
(
"topk"
);
add_batch_norm_inference_op
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment