gaoqiong / MIGraphX

Commit d94c54f0, authored May 31, 2022 by Paul

Add layernorm kernel

parent 6d8e4c53

Showing 3 changed files with 137 additions and 0 deletions:

  src/targets/gpu/jit/layernorm.cpp                                 +80  -0
  src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp    +34  -0
  src/targets/gpu/prefuse_ops.cpp                                   +23  -0
src/targets/gpu/jit/layernorm.cpp   0 → 100644
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

using namespace migraphx::gpu::gen; // NOLINT

static const char* const layernorm_kernel = R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/layernorm.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <args.hpp>

namespace migraphx {

extern "C" {
__global__ void layernorm_kernel(void* input_p, void* output_p)
{
    transform_args(make_tensors(), ${transformers})(input_p, output_p)([](auto input, auto output) {
        layernorm<${axis}>(input, output);
    });
}
}

} // namespace migraphx

)__migraphx__";

struct layernorm_compiler : compiler<layernorm_compiler>
{
    std::vector<std::string> names() const { return {"layernorm"}; }

    operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
    {
        // TODO: Use reduce_dims
        auto axis  = inputs.front().lens().size() - 1;
        auto faxis = find_fast_axis({inputs.front()});
        vectorize vec{};
        // Vectorize if the axis is a reduction axis
        if(inputs.back().lens()[faxis] == 1)
        {
            vec = vectorize::elements(faxis, inputs);
        }
        auto relements  = inputs[0].lens()[axis] / vec.size;
        auto nelements  = inputs.back().elements() / relements;
        auto block_size = compute_block_size(relements, 256);

        hip_compile_options options;
        options.set_launch_params(
            v, compute_global_for(ctx, nelements * block_size, 256), block_size);
        options.output      = inputs.back();
        options.inputs      = inputs;
        options.kernel_name = "layernorm_kernel";

        auto src = interpolate_string(layernorm_kernel,
                                      {{"transformers", make_transformer_args(vec)},
                                       {"axis", to_string(axis)}});

        return compile_hip_code_object(src, options);
    }

    compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
    {
        return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value()));
    }
};

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
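
Note on the template above: compile_op produces the final HIP source by letting interpolate_string replace the ${transformers} and ${axis} placeholders with make_transformer_args(vec) and to_string(axis) before passing the result to compile_hip_code_object. The sketch below is not MIGraphX code; it is a minimal, self-contained stand-in for that substitution step, and the placeholder value "2" is chosen purely as an example.

#include <iostream>
#include <map>
#include <string>

// Simplified stand-in for the ${name} substitution that interpolate_string
// performs on the embedded kernel source. Illustration only.
std::string substitute(std::string text, const std::map<std::string, std::string>& vars)
{
    for(const auto& [name, value] : vars)
    {
        const std::string key = "${" + name + "}";
        for(std::size_t pos = text.find(key); pos != std::string::npos; pos = text.find(key, pos))
        {
            text.replace(pos, key.size(), value);
            pos += value.size();
        }
    }
    return text;
}

int main()
{
    // Example placeholder value; in compile_op it comes from to_string(axis).
    std::string body = "layernorm<${axis}>(input, output);";
    std::cout << substitute(body, {{"axis", "2"}}) << "\n"; // layernorm<2>(input, output);
}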
src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp   0 → 100644
#ifndef MIGRAPHX_GUARD_KERNELS_LAYERNORM_HPP
#define MIGRAPHX_GUARD_KERNELS_LAYERNORM_HPP
#include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/ops.hpp>

namespace migraphx {

template <index_int Axis, class Input, class Output>
__device__ void layernorm(Input input, Output output)
{
    constexpr auto relements = get_shape_c<reduce::with_axis<Input, Axis>>{}.elements() /
                               get_shape_c<Input>{}.elements();
    reduce::block::run<reduce::with_axis<Input, Axis>>([&](auto, auto r) {
        using value_type = typename Input::type;
        auto mean        = [&](auto f) {
            return r.reduce(op::sum{}, 0, f)(input) / value_type{relements};
        };
        // mean(x)
        auto mean_x = mean(op::id{});
        // mean(m ^ 2)
        auto mean_m2 = mean([&](auto x) {
            auto m = x - mean_x;
            return m * m;
        });
        r.inner([&](auto& y, auto x) {
            auto m = x - mean_x;
            // m * rsqrt(mean(m ^ 2) + 1e-12)
            y = m * rsqrt(mean_m2 + value_type{1e-12});
        })(output, input);
    });
}

} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_LAYERNORM_HPP
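
For reference, the device function above normalizes each slice along Axis as y = (x - mean(x)) * rsqrt(mean((x - mean(x))^2) + 1e-12); no scale or bias is applied inside this kernel. The host-side sketch below is not part of the commit: it assumes a packed [rows x relements] float buffer and only reproduces the same arithmetic as a sanity reference.

#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

// Host-side reference for the device kernel's math: per-row zero mean,
// unit variance, epsilon 1e-12, no scale/bias.
void layernorm_reference(const std::vector<float>& input,
                         std::vector<float>& output,
                         std::size_t relements)
{
    const std::size_t rows = input.size() / relements;
    for(std::size_t i = 0; i < rows; ++i)
    {
        const float* x = input.data() + i * relements;
        float* y       = output.data() + i * relements;

        float mean_x = 0.0f;
        for(std::size_t j = 0; j < relements; ++j)
            mean_x += x[j];
        mean_x /= relements;

        float mean_m2 = 0.0f; // mean of squared deviations
        for(std::size_t j = 0; j < relements; ++j)
        {
            const float m = x[j] - mean_x;
            mean_m2 += m * m;
        }
        mean_m2 /= relements;

        const float inv = 1.0f / std::sqrt(mean_m2 + 1e-12f);
        for(std::size_t j = 0; j < relements; ++j)
            y[j] = (x[j] - mean_x) * inv;
    }
}

int main()
{
    std::vector<float> in{1, 2, 3, 4, 5, 6}; // example values
    std::vector<float> out(in.size());
    layernorm_reference(in, out, 3); // two rows of three elements
    for(float v : out)
        std::cout << v << " ";
    std::cout << "\n";
}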
src/targets/gpu/prefuse_ops.cpp
@@ -6,6 +6,29 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {

+struct layernorm
+{
+    std::string name() const { return "gpu::prelayernorm"; }
+
+    shape compute_shape(std::vector<shape> inputs) const
+    {
+        check_shapes{inputs, *this}.has(1);
+        auto s = inputs.at(0);
+        if(s.scalar())
+        {
+            return s;
+        }
+        else if(s.broadcasted())
+        {
+            return {s.type(), s.lens()};
+        }
+        else
+        {
+            return s.with_lens(s.lens());
+        }
+    }
+};
+
 namespace {

 struct find_layernorm
 {
...
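
The new gpu::prelayernorm operator's compute_shape checks for exactly one input and returns a shape with the same type and lengths: scalars pass through unchanged, broadcast inputs are repacked to {type, lens}, and everything else goes through with_lens. The sketch below uses a simplified toy_shape struct rather than migraphx::shape, so it only models the branching; treating with_lens as plain row-major packing is an assumption made for the illustration.

#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

// Toy model of the three branches in layernorm::compute_shape above.
struct toy_shape
{
    std::string type;
    std::vector<std::size_t> lens;
    std::vector<std::size_t> strides;

    bool scalar() const
    {
        // All strides zero: every index maps to the same element.
        for(auto s : strides)
            if(s != 0)
                return false;
        return true;
    }
    bool broadcasted() const
    {
        // Any zero stride means the data is repeated along that axis.
        for(auto s : strides)
            if(s == 0)
                return true;
        return false;
    }
    static toy_shape standard(std::string type, std::vector<std::size_t> lens)
    {
        // Packed row-major strides.
        std::vector<std::size_t> strides(lens.size(), 1);
        for(std::size_t i = lens.size(); i-- > 1;)
            strides[i - 1] = strides[i] * lens[i];
        return {std::move(type), std::move(lens), std::move(strides)};
    }
};

toy_shape compute_output_shape(const toy_shape& s)
{
    if(s.scalar())
        return s;                                   // scalar passes through
    else if(s.broadcasted())
        return toy_shape::standard(s.type, s.lens); // repack to {type, lens}
    else
        return toy_shape::standard(s.type, s.lens); // stand-in for s.with_lens(s.lens())
}

int main()
{
    // A shape broadcast along its first axis (stride 0) comes out packed.
    toy_shape bcast{"float", {2, 3, 4}, {0, 4, 1}};
    auto out = compute_output_shape(bcast);
    for(auto s : out.strides)
        std::cout << s << " "; // 12 4 1
    std::cout << "\n";
}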