Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
1784584e
Unverified
Commit
1784584e
authored
Aug 16, 2022
by
Paul Fultz II
Committed by
GitHub
Aug 16, 2022
Browse files
Add jit layernorm fusion (#1301)
parent
18e4a2c6
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
396 additions
and
102 deletions
+396
-102
src/targets/gpu/compile_gen.cpp
src/targets/gpu/compile_gen.cpp
+65
-19
src/targets/gpu/fuse_ops.cpp
src/targets/gpu/fuse_ops.cpp
+29
-2
src/targets/gpu/include/migraphx/gpu/compile_gen.hpp
src/targets/gpu/include/migraphx/gpu/compile_gen.hpp
+5
-0
src/targets/gpu/jit/layernorm.cpp
src/targets/gpu/jit/layernorm.cpp
+133
-0
src/targets/gpu/jit/pointwise.cpp
src/targets/gpu/jit/pointwise.cpp
+8
-40
src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
...rgets/gpu/kernels/include/migraphx/kernels/functional.hpp
+3
-2
src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp
...argets/gpu/kernels/include/migraphx/kernels/layernorm.hpp
+83
-0
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
+19
-0
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp
+1
-1
src/targets/gpu/prefuse_ops.cpp
src/targets/gpu/prefuse_ops.cpp
+49
-37
test/verify/test_layernorm.cpp
test/verify/test_layernorm.cpp
+1
-1
No files found.
src/targets/gpu/compile_gen.cpp
View file @
1784584e
...
@@ -25,6 +25,13 @@
...
@@ -25,6 +25,13 @@
#include <migraphx/shape.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/module.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -75,25 +82,25 @@ std::string vectorize::str() const
...
@@ -75,25 +82,25 @@ std::string vectorize::str() const
preload
preload
::
broadcasts
(
std
::
size_t
axis
,
const
std
::
vector
<
shape
>&
inputs
)
preload
preload
::
broadcasts
(
std
::
size_t
axis
,
const
std
::
vector
<
shape
>&
inputs
)
{
{
const
std
::
size_t
max_lds_bytes
=
4096
;
const
std
::
size_t
max_lds_bytes
=
4096
;
std
::
vector
<
bool
>
result
;
std
::
vector
<
bool
>
result
(
inputs
.
size
())
;
std
::
transform
(
inputs
.
begin
(),
std
::
vector
<
std
::
size_t
>
preloaded
;
inputs
.
end
(),
auto
idxs
=
range
(
inputs
.
size
());
std
::
back_inserter
(
re
sult
),
std
::
copy_if
(
idxs
.
begin
(),
idxs
.
end
(),
std
::
back_inserter
(
p
re
loaded
),
[
&
](
auto
i
)
{
[
&
](
const
shape
&
input
)
{
return
input
.
strides
()[
axis
]
==
0
;
});
return
input
s
[
i
]
.
strides
()[
axis
]
==
0
;
auto
bytes
=
std
::
inner_product
(
inputs
.
begin
(),
});
inputs
.
end
(),
std
::
sort
(
preloaded
.
begin
(),
preloaded
.
end
(),
by
(
std
::
less
<>
{},
[
&
](
auto
i
)
{
re
sult
.
begin
()
,
re
turn
inputs
[
i
].
bytes
()
;
std
::
size_t
{
0
},
}));
std
::
plus
<>
{},
[](
const
shape
&
s
,
bool
b
)
->
std
::
size_t
{
std
::
size_t
bytes
=
0
;
if
(
b
)
for
(
auto
i
:
preloaded
)
return
s
.
bytes
();
{
return
0
;
auto
input
=
inputs
[
i
]
;
}
);
bytes
+=
input
.
bytes
(
);
if
(
bytes
<
max_lds_bytes
)
if
(
bytes
>
max_lds_bytes
)
return
{
result
}
;
break
;
// TODO: Try to partially preload items
result
[
i
]
=
true
;
std
::
fill
(
result
.
begin
(),
result
.
end
(),
false
);
}
return
{
result
};
return
{
result
};
}
}
...
@@ -125,6 +132,45 @@ std::string make_transformer_args(std::vector<std::string> transformers)
...
@@ -125,6 +132,45 @@ std::string make_transformer_args(std::vector<std::string> transformers)
return
join_strings
(
std
::
move
(
transformers
),
", "
);
return
join_strings
(
std
::
move
(
transformers
),
", "
);
}
}
std
::
string
generate_pointwise
(
const
module
&
pm
,
const
std
::
string
&
name
)
{
module
m
=
pm
;
run_passes
(
m
,
{
eliminate_common_subexpression
{},
dead_code_elimination
{}});
cpp_generator
g
;
g
.
fmap
([](
const
std
::
string
&
fname
)
{
return
"migraphx::"
+
fname
;
});
g
.
add_point_op
(
"where"
,
"${function:where}(${0}, ${1}, ${2})"
);
g
.
add_point_op
(
"prelu"
,
"${function:where}(${0} < 0, ${0} * ${1}, ${0})"
);
g
.
add_point_op
(
"sign"
,
"${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))"
);
g
.
add_point_op
(
"equal"
,
"migraphx::abs(${0} == ${1})"
);
g
.
add_point_op
(
"less"
,
"migraphx::abs(${0} < ${1})"
);
g
.
add_point_op
(
"greater"
,
"migraphx::abs(${0} > ${1})"
);
g
.
add_point_op
(
"not"
,
"migraphx::abs(not ${0})"
);
// Add explict conversions
g
.
fresult
(
[](
const
shape
&
s
)
{
return
"migraphx::convert<"
+
shape
::
cpp_type
(
s
.
type
())
+
">"
;
});
g
.
create_function
(
g
.
generate_module
(
m
).
set_attributes
({
"__device__"
}).
set_generic_types
(
m
).
set_name
(
name
));
return
g
.
str
();
}
static
std
::
vector
<
std
::
string
>
get_op_names
(
const
module
&
m
)
{
std
::
vector
<
std
::
string
>
result
;
for
(
auto
&
ins
:
m
)
{
if
(
starts_with
(
ins
.
name
(),
"@"
))
continue
;
result
.
push_back
(
ins
.
name
());
}
return
result
;
}
std
::
string
generate_name_from_ops
(
const
module
&
m
)
{
auto
op_names
=
get_op_names
(
m
);
return
join_strings
(
op_names
,
"_"
);
}
}
// namespace gen
}
// namespace gen
}
// namespace gpu
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/targets/gpu/fuse_ops.cpp
View file @
1784584e
...
@@ -827,13 +827,14 @@ void apply_conv_bias(context& ctx, module& m, const match::matcher_result& r)
...
@@ -827,13 +827,14 @@ void apply_conv_bias(context& ctx, module& m, const match::matcher_result& r)
m
.
replace_instruction
(
ins
,
cb
,
input_ins
,
weights_ins
,
old_ws_ins
,
bias_ins
,
alloc_ins
);
m
.
replace_instruction
(
ins
,
cb
,
input_ins
,
weights_ins
,
old_ws_ins
,
bias_ins
,
alloc_ins
);
}
}
inline
auto
precompile_name
(
std
::
string
s
)
// NOLINT
template
<
class
...
Strings
>
inline
auto
precompile_name
(
Strings
...
names
)
// NOLINT
{
{
return
match
::
make_basic_pred_matcher
([
=
](
instruction_ref
ins
)
{
return
match
::
make_basic_pred_matcher
([
=
](
instruction_ref
ins
)
{
if
(
ins
->
name
()
!=
"gpu::precompile_op"
)
if
(
ins
->
name
()
!=
"gpu::precompile_op"
)
return
false
;
return
false
;
auto
op
=
from_value
<
operation
>
(
ins
->
get_operator
().
to_value
().
at
(
"op"
));
auto
op
=
from_value
<
operation
>
(
ins
->
get_operator
().
to_value
().
at
(
"op"
));
return
(
op
.
name
()
==
s
);
return
(
contains
({
names
...},
op
.
name
()
)
);
});
});
}
}
...
@@ -1041,6 +1042,31 @@ struct find_contiguous_pointwise
...
@@ -1041,6 +1042,31 @@ struct find_contiguous_pointwise
}
}
};
};
struct
find_layernorm_pointwise
{
auto
matcher
()
const
{
return
precompile_name
(
"pointwise"
)(
match
::
arg
(
0
)(
precompile_name
(
"gpu::prelayernorm"
,
"gpu::preadd_layernorm"
).
bind
(
"layernorm"
)));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
layernorm
=
r
.
instructions
[
"layernorm"
];
auto
*
pm
=
ins
->
module_inputs
().
front
();
if
(
not
layernorm
->
module_inputs
().
empty
())
return
;
auto
inputs
=
layernorm
->
inputs
();
inputs
.
pop_back
();
inputs
.
insert
(
inputs
.
end
(),
ins
->
inputs
().
begin
()
+
1
,
ins
->
inputs
().
end
());
m
.
replace_instruction
(
ins
,
layernorm
->
get_operator
(),
inputs
,
{
pm
});
}
};
void
fuse_ops
::
apply
(
module
&
m
)
const
void
fuse_ops
::
apply
(
module
&
m
)
const
{
{
match
::
find_matches
(
m
,
find_contiguous_pointwise
{},
find_gelu
{},
find_gelu_new
{
fast_math
});
match
::
find_matches
(
m
,
find_contiguous_pointwise
{},
find_gelu
{},
find_gelu_new
{
fast_math
});
...
@@ -1063,6 +1089,7 @@ void fuse_ops::apply(module& m) const
...
@@ -1063,6 +1089,7 @@ void fuse_ops::apply(module& m) const
match
::
find_matches
(
m
,
match
::
find_matches
(
m
,
find_triadd_layernorm
{},
find_triadd_layernorm
{},
find_gemm_add
{},
find_gemm_add
{},
find_layernorm_pointwise
{},
find_gemm_pointwise
{},
find_gemm_pointwise
{},
find_commutative_broadcast
{});
find_commutative_broadcast
{});
match
::
find_matches
(
m
,
find_contiguous
{});
match
::
find_matches
(
m
,
find_contiguous
{});
...
...
src/targets/gpu/include/migraphx/gpu/compile_gen.hpp
View file @
1784584e
...
@@ -25,6 +25,7 @@
...
@@ -25,6 +25,7 @@
#define MIGRAPHX_GUARD_GPU_COMPILE_GEN_HPP
#define MIGRAPHX_GUARD_GPU_COMPILE_GEN_HPP
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/module_ref.hpp>
#include <string>
#include <string>
#include <unordered_map>
#include <unordered_map>
#include <vector>
#include <vector>
...
@@ -62,6 +63,10 @@ std::string make_transformer_args(Ts... xs)
...
@@ -62,6 +63,10 @@ std::string make_transformer_args(Ts... xs)
return
make_transformer_args
({
xs
.
str
()...});
return
make_transformer_args
({
xs
.
str
()...});
}
}
std
::
string
generate_pointwise
(
const
module
&
pm
,
const
std
::
string
&
name
);
std
::
string
generate_name_from_ops
(
const
module
&
m
);
}
// namespace gen
}
// namespace gen
}
// namespace gpu
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/targets/gpu/jit/layernorm.cpp
0 → 100644
View file @
1784584e
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
using
namespace
migraphx
::
gpu
::
gen
;
// NOLINT
static
const
char
*
const
layernorm_kernel
=
R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/layernorm.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <migraphx/kernels/preload.hpp>
#include <args.hpp>
namespace migraphx {
${preamble}
extern "C" {
__global__ void ${kernel}(${params})
{
auto idx = make_index();
transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto... xs) {
${layernorm}<${axis}>(${post}, xs...);
});
}
}
} // namespace migraphx
)__migraphx__"
;
struct
layernorm_compiler
:
compiler
<
layernorm_compiler
>
{
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"layernorm"
,
"gpu::prelayernorm"
,
"gpu::preadd_layernorm"
};
}
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
// TODO: Use reduce_dims
auto
axis
=
inputs
.
front
().
lens
().
size
()
-
1
;
auto
faxis
=
find_fast_axis
({
inputs
.
front
()});
vectorize
vec
{};
// Vectorize if the axis is a reduction axis
if
(
axis
==
faxis
)
{
vec
=
vectorize
::
elements
(
faxis
,
inputs
);
}
auto
preloads
=
preload
::
broadcasts
(
axis
,
inputs
);
auto
relements
=
inputs
[
0
].
lens
()[
axis
]
/
vec
.
size
;
auto
nelements
=
(
inputs
.
back
().
elements
()
/
inputs
[
0
].
lens
()[
axis
]);
auto
block_size
=
compute_block_size
(
relements
,
256
);
hip_compile_options
options
;
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
nelements
*
block_size
,
256
),
block_size
);
options
.
output
=
inputs
.
back
();
options
.
inputs
=
inputs
;
options
.
kernel_name
=
v
.
get
(
"kernel"
,
"layernorm_kernel"
);
auto
src
=
interpolate_string
(
layernorm_kernel
,
{{
"kernel"
,
options
.
kernel_name
},
{
"params"
,
enum_params
(
inputs
.
size
(),
"void * private_p"
)},
{
"args"
,
enum_params
(
inputs
.
size
(),
"private_p"
)},
{
"transformers"
,
make_transformer_args
(
preloads
,
vec
)},
{
"post"
,
v
.
get
(
"post"
,
std
::
string
{
"op::id{}"
})},
{
"preamble"
,
v
.
get
(
"preamble"
,
std
::
string
{})},
{
"layernorm"
,
v
.
get
(
"layernorm"
,
std
::
string
{
"layernorm"
})},
{
"axis"
,
to_string
(
axis
)}});
return
compile_hip_code_object
(
src
,
options
);
}
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
{
auto
v
=
op
.
to_value
();
v
[
"layernorm"
]
=
"layernorm"
;
v
[
"kernel"
]
=
"layernorm_kernel"
;
if
(
op
.
name
()
==
"gpu::preadd_layernorm"
)
{
v
[
"layernorm"
]
=
"add_layernorm"
;
v
[
"kernel"
]
=
"add_layernorm_kernel"
;
}
if
(
not
ins
->
module_inputs
().
empty
())
{
auto
*
pm
=
ins
->
module_inputs
().
front
();
v
[
"preamble"
]
=
generate_pointwise
(
*
pm
,
"post_layernorm"
);
v
[
"post"
]
=
"MIGRAPHX_LIFT(post_layernorm)"
;
v
[
"kernel"
]
=
v
[
"layernorm"
].
to
<
std
::
string
>
()
+
"_"
+
generate_name_from_ops
(
*
pm
)
+
"_kernel"
;
}
return
replace
(
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
v
));
}
};
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/jit/pointwise.cpp
View file @
1784584e
...
@@ -65,18 +65,6 @@ __global__ void ${kernel}(${params})
...
@@ -65,18 +65,6 @@ __global__ void ${kernel}(${params})
)__migraphx__"
;
)__migraphx__"
;
static
std
::
vector
<
std
::
string
>
get_op_names
(
const
module
&
m
)
{
std
::
vector
<
std
::
string
>
result
;
for
(
auto
&
ins
:
m
)
{
if
(
starts_with
(
ins
.
name
(),
"@"
))
continue
;
result
.
push_back
(
ins
.
name
());
}
return
result
;
}
struct
pointwise_compiler
:
compiler
<
pointwise_compiler
>
struct
pointwise_compiler
:
compiler
<
pointwise_compiler
>
{
{
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"pointwise"
,
"contiguous"
};
}
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"pointwise"
,
"contiguous"
};
}
...
@@ -126,34 +114,14 @@ struct pointwise_compiler : compiler<pointwise_compiler>
...
@@ -126,34 +114,14 @@ struct pointwise_compiler : compiler<pointwise_compiler>
else
else
{
{
assert
(
not
ins
->
module_inputs
().
empty
());
assert
(
not
ins
->
module_inputs
().
empty
());
auto
*
pm
=
ins
->
module_inputs
().
front
();
auto
*
pm
=
ins
->
module_inputs
().
front
();
run_passes
(
*
pm
,
{
eliminate_common_subexpression
{},
dead_code_elimination
{}});
auto
pf
=
generate_pointwise
(
*
pm
,
"inner_pointwise"
);
cpp_generator
g
;
std
::
string
lambda
=
"MIGRAPHX_LIFT(inner_pointwise)"
;
g
.
fmap
([](
const
std
::
string
&
fname
)
{
return
"migraphx::"
+
fname
;
});
auto
kernel_name
=
generate_name_from_ops
(
*
pm
)
+
"_kernel"
;
g
.
add_point_op
(
"where"
,
"${function:where}(${0}, ${1}, ${2})"
);
return
replace
(
g
.
add_point_op
(
"prelu"
,
"${function:where}(${0} < 0, ${0} * ${1}, ${0})"
);
compile_op
(
ctx
,
g
.
add_point_op
(
"sign"
,
to_shapes
(
ins
->
inputs
()),
"${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))"
);
{{
"lambda"
,
lambda
},
{
"preamble"
,
pf
},
{
"kernel"
,
kernel_name
}}));
g
.
add_point_op
(
"equal"
,
"migraphx::abs(${0} == ${1})"
);
g
.
add_point_op
(
"less"
,
"migraphx::abs(${0} < ${1})"
);
g
.
add_point_op
(
"greater"
,
"migraphx::abs(${0} > ${1})"
);
g
.
add_point_op
(
"not"
,
"migraphx::abs(not ${0})"
);
g
.
add_point_op
(
"mod"
,
"migraphx::mod(${0}, ${1})"
);
g
.
add_point_op
(
"fmod"
,
"migraphx::fmod(${0}, ${1})"
);
// Add explict conversions
g
.
fresult
([](
const
shape
&
s
)
{
return
"migraphx::convert<"
+
shape
::
cpp_type
(
s
.
type
())
+
">"
;
});
auto
name
=
g
.
create_function
(
g
.
generate_module
(
*
pm
).
set_attributes
({
"__device__"
}).
set_generic_types
(
*
pm
));
std
::
string
lambda
=
"MIGRAPHX_LIFT("
+
name
+
")"
;
auto
op_names
=
get_op_names
(
*
pm
);
op_names
.
push_back
(
"kernel"
);
auto
op_name_string
=
join_strings
(
op_names
,
"_"
);
return
replace
(
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
{{
"lambda"
,
lambda
},
{
"preamble"
,
g
.
str
()},
{
"kernel"
,
op_name_string
}}));
}
}
}
}
};
};
...
...
src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
View file @
1784584e
...
@@ -31,8 +31,9 @@
...
@@ -31,8 +31,9 @@
->decltype(__VA_ARGS__) { return __VA_ARGS__; }
->decltype(__VA_ARGS__) { return __VA_ARGS__; }
// NOLINTNEXTLINE
// NOLINTNEXTLINE
#define MIGRAPHX_LIFT(...) \
#define MIGRAPHX_LIFT(...) \
[](auto&&... xs) MIGRAPHX_RETURNS((__VA_ARGS__)(static_cast<decltype(xs)>(xs)...))
[](auto&&... private_lisft_xs) MIGRAPHX_RETURNS( \
(__VA_ARGS__)(static_cast<decltype(private_lisft_xs)>(private_lisft_xs)...))
namespace
migraphx
{
namespace
migraphx
{
...
...
src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp
0 → 100644
View file @
1784584e
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_LAYERNORM_HPP
#define MIGRAPHX_GUARD_KERNELS_LAYERNORM_HPP
#include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/print.hpp>
namespace
migraphx
{
template
<
index_int
Axis
,
class
F
,
class
BinOp
,
class
Output
,
class
Input1
,
class
Input2
,
class
...
Inputs
>
__device__
void
generic_binary_layernorm
(
F
compute
,
BinOp
op
,
Output
output
,
Input1
input1
,
Input2
input2
,
Inputs
...
inputs
)
{
using
reduce_output
=
reduce
::
with_axis
<
Input1
,
Axis
>
;
reduce
::
block
::
run
<
reduce_output
>
([
&
](
auto
,
auto
r
)
{
using
value_type
=
typename
Input1
::
type
;
constexpr
auto
relements
=
r
.
template
elements
<
Input1
>();
auto
mean
=
[
&
](
auto
f
)
{
return
r
.
reduce
(
op
::
sum
{},
0
,
[
&
](
auto
x1
,
auto
x2
)
{
return
f
(
x1
,
x2
)
/
value_type
{
relements
};
})(
input1
,
input2
);
};
// mean(x)
auto
mean_x
=
mean
(
op
);
// mean(m ^ 2)
auto
mean_m2
=
mean
([
&
](
auto
x1
,
auto
x2
)
{
auto
m
=
op
(
x1
,
x2
)
-
mean_x
;
return
m
*
m
;
});
r
.
inner
([
&
](
auto
&
y
,
auto
x1
,
auto
x2
,
auto
...
xs
)
{
auto
m
=
op
(
x1
,
x2
)
-
mean_x
;
// m * rsqrt(mean(m ^ 2) + 1e-12)
y
=
compute
(
m
*
rsqrt
(
mean_m2
+
value_type
{
1e-12
}),
xs
...);
})(
output
,
input1
,
input2
,
inputs
...);
});
}
template
<
index_int
Axis
,
class
F
,
class
Output
,
class
Input
,
class
...
Inputs
>
__device__
void
layernorm
(
F
compute
,
Output
output
,
Input
input
,
Inputs
...
inputs
)
{
generic_binary_layernorm
<
Axis
>
(
compute
,
[](
auto
x
,
auto
)
{
return
x
;
},
output
,
input
,
input
,
inputs
...);
}
template
<
index_int
Axis
,
class
F
,
class
Output
,
class
Input1
,
class
Input2
,
class
...
Inputs
>
__device__
void
add_layernorm
(
F
compute
,
Output
output
,
Input1
input1
,
Input2
input2
,
Inputs
...
inputs
)
{
generic_binary_layernorm
<
Axis
>
(
compute
,
[](
auto
x1
,
auto
x2
)
{
return
x1
+
x2
;
},
output
,
input1
,
input2
,
inputs
...);
}
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_LAYERNORM_HPP
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
View file @
1784584e
...
@@ -224,6 +224,18 @@ struct block
...
@@ -224,6 +224,18 @@ struct block
idx
.
local_stride
(
x
.
get_shape
().
elements
(),
[
&
](
auto
j
)
{
f
(
x
[
j
],
xs
[
j
]...);
});
idx
.
local_stride
(
x
.
get_shape
().
elements
(),
[
&
](
auto
j
)
{
f
(
x
[
j
],
xs
[
j
]...);
});
});
});
}
}
template
<
class
Input
>
constexpr
auto
elements
()
const
{
using
reduce_type
=
decltype
(
slicer
(
Input
{}));
using
value_type
=
typename
Input
::
type
;
constexpr
auto
relements
=
get_shape_c
<
reduce_type
>
{}.
elements
();
if
constexpr
(
vec_size
<
value_type
>
()
>
1
)
return
relements
*
vec_size
<
value_type
>
();
else
return
relements
;
}
};
};
template
<
class
Slicer
>
template
<
class
Slicer
>
...
@@ -281,6 +293,13 @@ struct lane
...
@@ -281,6 +293,13 @@ struct lane
}
}
});
});
}
}
template
<
class
Input
>
constexpr
auto
elements
()
const
{
using
reduce_type
=
decltype
(
slicer
(
Input
{}));
return
get_shape_c
<
reduce_type
>
{}.
elements
();
}
};
};
template
<
class
Slicer
>
template
<
class
Slicer
>
...
...
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp
View file @
1784584e
...
@@ -175,7 +175,7 @@ template <class T, class Op>
...
@@ -175,7 +175,7 @@ template <class T, class Op>
constexpr
auto
vec_reduce
(
T
x
,
Op
op
)
constexpr
auto
vec_reduce
(
T
x
,
Op
op
)
{
{
if
constexpr
(
vec_size
<
T
>
()
<
2
)
if
constexpr
(
vec_size
<
T
>
()
<
2
)
return
x
;
return
vec_type
<
T
>
{
x
}
;
else
else
{
{
vec_type
<
T
>
result
=
x
[
0
];
vec_type
<
T
>
result
=
x
[
0
];
...
...
src/targets/gpu/prefuse_ops.cpp
View file @
1784584e
...
@@ -24,12 +24,53 @@
...
@@ -24,12 +24,53 @@
#include <migraphx/gpu/prefuse_ops.hpp>
#include <migraphx/gpu/prefuse_ops.hpp>
#include <migraphx/match/layernorm.hpp>
#include <migraphx/match/layernorm.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
gpu
{
namespace
{
namespace
{
template
<
class
Derived
,
std
::
size_t
N
>
struct
layernorm_base
{
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
,
std
::
vector
<
module_ref
>
mods
)
const
{
std
::
size_t
nargs
=
1
;
if
(
not
mods
.
empty
())
{
auto
*
pm
=
mods
.
front
();
nargs
=
pm
->
get_parameter_names
().
size
();
}
check_shapes
{
inputs
,
static_cast
<
const
Derived
&>
(
*
this
)}.
has
(
nargs
+
N
);
auto
s
=
inputs
.
at
(
0
);
if
(
s
.
scalar
())
{
return
s
;
}
else
if
(
s
.
broadcasted
())
{
return
{
s
.
type
(),
s
.
lens
()};
}
else
{
return
s
.
with_lens
(
s
.
lens
());
}
}
};
struct
layernorm
:
layernorm_base
<
layernorm
,
0
>
{
std
::
string
name
()
const
{
return
"gpu::prelayernorm"
;
}
};
MIGRAPHX_REGISTER_OP
(
layernorm
);
struct
add_layernorm
:
layernorm_base
<
add_layernorm
,
1
>
{
std
::
string
name
()
const
{
return
"gpu::preadd_layernorm"
;
}
};
MIGRAPHX_REGISTER_OP
(
add_layernorm
);
struct
find_layernorm
struct
find_layernorm
{
{
auto
matcher
()
const
{
return
match
::
layernorm
();
}
auto
matcher
()
const
{
return
match
::
layernorm
();
}
...
@@ -39,59 +80,30 @@ struct find_layernorm
...
@@ -39,59 +80,30 @@ struct find_layernorm
auto
ins
=
r
.
result
;
auto
ins
=
r
.
result
;
auto
x_ins
=
r
.
instructions
[
"x"
];
auto
x_ins
=
r
.
instructions
[
"x"
];
if
(
not
x_ins
->
get_shape
().
standard
())
m
.
replace_instruction
(
ins
,
layernorm
{},
x_ins
);
x_ins
=
m
.
insert_instruction
(
ins
,
make_op
(
"contiguous"
),
x_ins
);
auto
relements
=
x_ins
->
get_shape
().
lens
().
back
();
if
(
relements
>
1024
or
(
relements
%
4
!=
0
and
relements
>
256
))
return
;
auto
a
=
m
.
insert_instruction
(
ins
,
make_op
(
"hip::allocate"
,
{{
"shape"
,
to_value
(
x_ins
->
get_shape
())}}));
m
.
replace_instruction
(
ins
,
make_op
(
"gpu::layernorm"
),
x_ins
,
a
);
}
}
};
};
struct
find_
tri
addlayernorm
struct
find_add
_
layernorm
{
{
auto
matcher
()
const
auto
matcher
()
const
{
{
auto
add1
=
return
match
::
layernorm
()(
match
::
var
(
"x"
)(
match
::
name
(
"add"
).
bind
(
"add"
)));
match
::
name
(
"add"
)(
match
::
none_of
(
match
::
is_constant
()),
match
::
args
(
match
::
any
().
bind
(
"z1"
),
match
::
any
().
bind
(
"z2"
)));
auto
add2
=
match
::
name
(
"add"
)(
match
::
either_arg
(
0
,
1
)(
add1
,
match
::
any
().
bind
(
"z3"
)));
return
match
::
layernorm
()(
match
::
var
(
"x"
)(
add2
));
}
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
{
auto
ins
=
r
.
result
;
auto
ins
=
r
.
result
;
auto
x_ins
=
r
.
instructions
[
"z1"
];
auto
add_ins
=
r
.
instructions
[
"add"
];
auto
y_ins
=
r
.
instructions
[
"z2"
];
auto
z_ins
=
r
.
instructions
[
"z3"
];
for
(
auto
*
pins
:
{
&
x_ins
,
&
y_ins
,
&
z_ins
})
{
if
(
not
(
*
pins
)
->
get_shape
().
standard
())
*
pins
=
m
.
insert_instruction
(
ins
,
make_op
(
"contiguous"
),
*
pins
);
}
auto
relements
=
x_ins
->
get_shape
().
lens
().
back
();
if
(
relements
>
1024
or
(
relements
%
4
!=
0
and
relements
>
256
))
return
;
auto
a
=
m
.
insert_instruction
(
m
.
replace_instruction
(
ins
,
add_layernorm
{},
add_ins
->
inputs
());
ins
,
make_op
(
"hip::allocate"
,
{{
"shape"
,
to_value
(
x_ins
->
get_shape
())}}));
m
.
replace_instruction
(
ins
,
make_op
(
"gpu::triadd_layernorm"
),
x_ins
,
y_ins
,
z_ins
,
a
);
}
}
};
};
}
// namespace
}
// namespace
void
prefuse_ops
::
apply
(
module
&
m
)
const
void
prefuse_ops
::
apply
(
module
&
m
)
const
{
{
match
::
find_matches
(
m
,
find_
tri
addlayernorm
{},
find_layernorm
{});
match
::
find_matches
(
m
,
find_add
_
layernorm
{},
find_layernorm
{});
}
}
}
// namespace gpu
}
// namespace gpu
...
...
test/verify/test_layernorm.cpp
View file @
1784584e
...
@@ -68,7 +68,7 @@ struct test_layernorm : verify_program<test_layernorm>
...
@@ -68,7 +68,7 @@ struct test_layernorm : verify_program<test_layernorm>
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
std
::
vector
<
size_t
>
dims
=
{
1
,
1
,
5
};
std
::
vector
<
size_t
>
dims
=
{
1
,
2
,
5
};
auto
x
=
mm
->
add_parameter
(
"x"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
dims
});
auto
x
=
mm
->
add_parameter
(
"x"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
dims
});
add_layernorm
(
*
mm
,
x
,
dims
);
add_layernorm
(
*
mm
,
x
,
dims
);
return
p
;
return
p
;
...
...
gaoqiong
@gaoqiong
mentioned in commit
9afce86d
·
Dec 14, 2023
mentioned in commit
9afce86d
mentioned in commit 9afce86dfa141fd93dd9b4de878e167db78e4c4f
Toggle commit list
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment