Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
78a1dc1e
Commit
78a1dc1e
authored
Jan 27, 2023
by
Paul
Browse files
Add a quick groupnorm op
parent
bf0a7d92
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
150 additions
and
0 deletions
+150
-0
src/targets/gpu/jit/groupnorm.cpp
src/targets/gpu/jit/groupnorm.cpp
+123
-0
src/targets/gpu/kernels/include/migraphx/kernels/groupnorm.hpp
...argets/gpu/kernels/include/migraphx/kernels/groupnorm.hpp
+27
-0
No files found.
src/targets/gpu/jit/groupnorm.cpp
0 → 100644
View file @
78a1dc1e
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
using
namespace
migraphx
::
gpu
::
gen
;
// NOLINT
// Source text of the JIT-compiled groupnorm kernel. The ${params}, ${args} and
// ${transformers} placeholders are substituted by interpolate_string() in
// groupnorm_compiler::compile_op before the text is compiled. The entry point
// name is written literally as groupnorm_kernel; the text contains no
// ${kernel} placeholder.
static const char* const groupnorm_kernel = R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/groupnorm.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <migraphx/kernels/preload.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
__global__ void groupnorm_kernel(${params})
{
transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto... xs) {
groupnorm(xs...);
});
}
}
} // namespace migraphx
)__migraphx__";
/// JIT compiler for the "groupnorm" operator: it fills in the groupnorm_kernel
/// source template and hands the result to compile_hip_code_object.
struct groupnorm_compiler : compiler<groupnorm_compiler>
{
    /// Names of the operators this compiler handles.
    std::vector<std::string> names() const { return {"groupnorm"}; }

    /// Generates and compiles the kernel for the given argument shapes.
    ///
    /// @param ctx    Context forwarded to the vectorization and launch-size helpers.
    /// @param inputs Shapes of the kernel arguments. The first shape drives the
    ///               axis/vectorization choice; the last shape is used as the output.
    /// @param v      Value forwarded to hip_compile_options::set_launch_params.
    /// @return The operation returned by compile_hip_code_object.
    operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
    {
        // TODO: Use reduce_dims
        const auto& data_shape = inputs.front();
        const auto last_axis   = data_shape.lens().size() - 1;
        const auto fast_axis   = find_fast_axis({data_shape});

        vectorize vec{};
        // Vectorize if the axis is a reduction axis
        if(last_axis == fast_axis)
        {
            vec = vectorize::elements(ctx, fast_axis, inputs);
        }

        // Length of the last axis after dividing by the vector width.
        const auto reduce_elements = data_shape.lens()[last_axis] / vec.size;
        const auto output_elements = inputs.back().elements();
        const auto block_size      = compute_block_size(reduce_elements, 256);

        hip_compile_options options;
        // The global size is computed from block_size work-items per output element.
        options.set_launch_params(
            v, compute_global_for(ctx, output_elements * block_size, 256), block_size);
        options.kernel_name = "groupnorm_kernel";
        options.inputs      = inputs;
        options.output      = inputs.back();

        const auto arg_count = inputs.size();
        auto src             = interpolate_string(
            groupnorm_kernel,
            {{"kernel", options.kernel_name},
             {"params", enum_params(arg_count, "void * private_p")},
             {"args", enum_params(arg_count, "private_p")},
             {"transformers", make_transformer_args(vec)}});
        return compile_hip_code_object(src, options);
    }

    /// Builds the option value for an instruction and delegates to compile_op.
    ///
    /// Starts from op.to_value(), stores the "groupnorm" and "kernel" names
    /// (with the add_ variants when the operator is "gpu::preadd_groupnorm"),
    /// and, when the instruction carries module inputs, stores a generated
    /// pointwise "preamble", a "post" expression and a kernel name derived
    /// from the first module's ops.
    ///
    /// NOTE(review): compile_op above sets kernel_name to the literal
    /// "groupnorm_kernel" and only passes the value on to set_launch_params; it
    /// does not itself read "groupnorm", "preamble", "post" or "kernel". Whether
    /// set_launch_params consumes any of them is not visible here -- confirm
    /// whether these entries are meant to reach the kernel template.
    ///
    /// @param ctx Context forwarded to compile_op.
    /// @param ins Instruction whose inputs and module inputs are inspected.
    /// @param op  Operator being compiled.
    /// @return replace() applied to the operation produced by compile_op.
    compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
    {
        auto v = op.to_value();
        if(op.name() == "gpu::preadd_groupnorm")
        {
            v["groupnorm"] = "add_groupnorm";
            v["kernel"]    = "add_groupnorm_kernel";
        }
        else
        {
            v["groupnorm"] = "groupnorm";
            v["kernel"]    = "groupnorm_kernel";
        }

        const auto& mods = ins->module_inputs();
        if(not mods.empty())
        {
            auto* pm      = mods.front();
            v["preamble"] = generate_pointwise(*pm, "post_groupnorm");
            v["post"]     = "MIGRAPHX_LIFT(post_groupnorm)";
            v["kernel"] =
                v["groupnorm"].to<std::string>() + "_" + generate_name_from_ops(*pm) + "_kernel";
        }

        return replace(compile_op(ctx, to_shapes(ins->inputs()), v));
    }
};
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/kernels/include/migraphx/kernels/groupnorm.hpp
0 → 100644
View file @
78a1dc1e
#ifndef GUARD_AMDMIGRAPHX_GROUP_NORM_HPP
#define GUARD_AMDMIGRAPHX_GROUP_NORM_HPP
#include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/vec.hpp>
#include <migraphx/kernels/print.hpp>
namespace
migraphx
{
/// Block-reduction kernel body used by the groupnorm JIT kernel.
///
/// For each invocation made by reduce::block::run<Output>, performs two sum
/// reductions over x0 and writes rsqrt(second_reduction + 1e-12) to
/// out[out_idx]. Only this reciprocal square root is stored; x0 itself is not
/// normalized or written back here.
///
/// @param out Tensor indexed by the out_idx supplied by reduce::block::run.
/// @param x0  Input passed to both reductions.
template <class Output, class T>
__device__ void groupnorm(Output out, T x0)
{
    reduce::block::run<Output>([&](auto out_idx, auto r) {
        constexpr auto relements = r.template elements<T>();

        // Sum-reduce x0 using op::mean<relements> as the per-element transform
        // (presumably x / relements, i.e. the mean -- confirm in ops.hpp).
        auto mean_val = r.reduce(op::sum{}, 0, op::mean<relements>{})(x0);

        // Per-element term: squared difference from mean_val, divided by relements.
        auto squared_deviation = [&](auto x) {
            auto delta = x - mean_val;
            return (delta * delta) / vec_type<decltype(delta)>{relements};
        };
        auto var_val = r.reduce(op::sum{}, 0, squared_deviation)(x0);

        // NOTE(review): 1e-12 is a double literal; if var_val has a narrower
        // floating-point type the addition is evaluated in double -- confirm
        // this is intended.
        r.outer([&] { out[out_idx] = migraphx::rsqrt(var_val + 1e-12); });
    });
}
}
// namespace migraphx
#endif // GUARD_AMDMIGRAPHX_GROUP_NORM_HPP
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment