Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
8076d7f7
Commit
8076d7f7
authored
Jun 25, 2022
by
Paul
Browse files
Merge branch 'jit-layernorm' into bert-opt2
parents
aabb14ff
817543c7
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
204 additions
and
70 deletions
+204
-70
src/targets/gpu/compile_gen.cpp
src/targets/gpu/compile_gen.cpp
+45
-0
src/targets/gpu/fuse_ops.cpp
src/targets/gpu/fuse_ops.cpp
+29
-2
src/targets/gpu/include/migraphx/gpu/compile_gen.hpp
src/targets/gpu/include/migraphx/gpu/compile_gen.hpp
+5
-0
src/targets/gpu/jit/layernorm.cpp
src/targets/gpu/jit/layernorm.cpp
+39
-9
src/targets/gpu/jit/pointwise.cpp
src/targets/gpu/jit/pointwise.cpp
+8
-38
src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
...rgets/gpu/kernels/include/migraphx/kernels/functional.hpp
+3
-2
src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp
...argets/gpu/kernels/include/migraphx/kernels/layernorm.hpp
+36
-13
src/targets/gpu/prefuse_ops.cpp
src/targets/gpu/prefuse_ops.cpp
+39
-6
No files found.
src/targets/gpu/compile_gen.cpp
View file @
8076d7f7
...
...
@@ -25,6 +25,12 @@
#include <migraphx/shape.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/module.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/instruction.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -125,6 +131,45 @@ std::string make_transformer_args(std::vector<std::string> transformers)
return
join_strings
(
std
::
move
(
transformers
),
", "
);
}
std
::
string
generate_pointwise
(
const
module
&
pm
,
const
std
::
string
&
name
)
{
module
m
=
pm
;
run_passes
(
m
,
{
eliminate_common_subexpression
{},
dead_code_elimination
{}});
cpp_generator
g
;
g
.
fmap
([](
const
std
::
string
&
fname
)
{
return
"migraphx::"
+
fname
;
});
g
.
add_point_op
(
"where"
,
"${function:where}(${0}, ${1}, ${2})"
);
g
.
add_point_op
(
"prelu"
,
"${function:where}(${0} < 0, ${0} * ${1}, ${0})"
);
g
.
add_point_op
(
"sign"
,
"${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))"
);
g
.
add_point_op
(
"equal"
,
"migraphx::abs(${0} == ${1})"
);
g
.
add_point_op
(
"less"
,
"migraphx::abs(${0} < ${1})"
);
g
.
add_point_op
(
"greater"
,
"migraphx::abs(${0} > ${1})"
);
g
.
add_point_op
(
"not"
,
"migraphx::abs(not ${0})"
);
// Add explict conversions
g
.
fresult
(
[](
const
shape
&
s
)
{
return
"migraphx::convert<"
+
shape
::
cpp_type
(
s
.
type
())
+
">"
;
});
g
.
create_function
(
g
.
generate_module
(
m
).
set_attributes
({
"__device__"
}).
set_generic_types
(
m
).
set_name
(
name
));
return
g
.
str
();
}
static
std
::
vector
<
std
::
string
>
get_op_names
(
const
module
&
m
)
{
std
::
vector
<
std
::
string
>
result
;
for
(
auto
&
ins
:
m
)
{
if
(
starts_with
(
ins
.
name
(),
"@"
))
continue
;
result
.
push_back
(
ins
.
name
());
}
return
result
;
}
std
::
string
generate_name_from_ops
(
const
module
&
m
)
{
auto
op_names
=
get_op_names
(
m
);
return
join_strings
(
op_names
,
"_"
);
}
}
// namespace gen
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/targets/gpu/fuse_ops.cpp
View file @
8076d7f7
...
...
@@ -830,13 +830,14 @@ void apply_conv_bias(context& ctx, module& m, const match::matcher_result& r)
m
.
replace_instruction
(
ins
,
cb
,
input_ins
,
weights_ins
,
old_ws_ins
,
bias_ins
,
alloc_ins
);
}
inline
auto
precompile_name
(
std
::
string
s
)
// NOLINT
template
<
class
...
Strings
>
inline
auto
precompile_name
(
Strings
...
names
)
// NOLINT
{
return
match
::
make_basic_pred_matcher
([
=
](
instruction_ref
ins
)
{
if
(
ins
->
name
()
!=
"gpu::precompile_op"
)
return
false
;
auto
op
=
from_value
<
operation
>
(
ins
->
get_operator
().
to_value
().
at
(
"op"
));
return
(
op
.
name
()
==
s
);
return
(
contains
({
names
...},
op
.
name
()
)
);
});
}
...
...
@@ -1111,6 +1112,31 @@ struct find_contiguous_pointwise
}
};
struct
find_layernorm_pointwise
{
auto
matcher
()
const
{
return
precompile_name
(
"pointwise"
)(
match
::
arg
(
0
)(
precompile_name
(
"gpu::prelayernorm"
,
"gpu::preadd_layernorm"
).
bind
(
"layernorm"
)));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
layernorm
=
r
.
instructions
[
"layernorm"
];
auto
*
pm
=
ins
->
module_inputs
().
front
();
if
(
not
layernorm
->
module_inputs
().
empty
())
return
;
auto
inputs
=
layernorm
->
inputs
();
inputs
.
pop_back
();
inputs
.
insert
(
inputs
.
end
(),
ins
->
inputs
().
begin
()
+
1
,
ins
->
inputs
().
end
());
m
.
replace_instruction
(
ins
,
layernorm
->
get_operator
(),
inputs
,
{
pm
});
}
};
void
fuse_ops
::
apply
(
module
&
m
)
const
{
match
::
find_matches
(
m
,
find_contiguous_pointwise
{},
find_gelu
{},
find_gelu_new
{
fast_math
});
...
...
@@ -1133,6 +1159,7 @@ void fuse_ops::apply(module& m) const
match
::
find_matches
(
m
,
find_triadd_layernorm
{},
find_gemm_add
{},
find_layernorm_pointwise
{},
find_gemm_pointwise
{},
find_commutative_broadcast
{});
match
::
find_matches
(
m
,
find_contiguous
{});
...
...
src/targets/gpu/include/migraphx/gpu/compile_gen.hpp
View file @
8076d7f7
...
...
@@ -25,6 +25,7 @@
#define MIGRAPHX_GUARD_GPU_COMPILE_GEN_HPP
#include <migraphx/config.hpp>
#include <migraphx/module_ref.hpp>
#include <string>
#include <unordered_map>
#include <vector>
...
...
@@ -62,6 +63,10 @@ std::string make_transformer_args(Ts... xs)
return
make_transformer_args
({
xs
.
str
()...});
}
std
::
string
generate_pointwise
(
const
module
&
pm
,
const
std
::
string
&
name
);
std
::
string
generate_name_from_ops
(
const
module
&
m
);
}
// namespace gen
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/targets/gpu/jit/layernorm.cpp
View file @
8076d7f7
...
...
@@ -19,15 +19,19 @@ static const char* const layernorm_kernel = R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/layernorm.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <migraphx/kernels/preload.hpp>
#include <args.hpp>
namespace migraphx {
${preamble}
extern "C" {
__global__ void
layernorm_kernel(void* input_p, void* output_p
)
__global__ void
${kernel}(${params}
)
{
transform_args(make_tensors(), rotate_last(), ${transformers})(input_p, output_p)([](auto... xs) {
layernorm<${axis}>(op::id{}, xs...);
auto idx = make_index();
transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto... xs) {
${layernorm}<${axis}>(${post}, xs...);
});
}
...
...
@@ -39,7 +43,10 @@ __global__ void layernorm_kernel(void* input_p, void* output_p)
struct
layernorm_compiler
:
compiler
<
layernorm_compiler
>
{
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"layernorm"
,
"gpu::prelayernorm"
};
}
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"layernorm"
,
"gpu::prelayernorm"
,
"gpu::preadd_layernorm"
};
}
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
...
...
@@ -52,6 +59,7 @@ struct layernorm_compiler : compiler<layernorm_compiler>
{
vec
=
vectorize
::
elements
(
faxis
,
inputs
);
}
auto
preloads
=
preload
::
broadcasts
(
axis
,
inputs
);
auto
relements
=
inputs
[
0
].
lens
()[
axis
]
/
vec
.
size
;
auto
nelements
=
inputs
.
back
().
elements
()
/
relements
;
auto
block_size
=
compute_block_size
(
relements
,
256
);
...
...
@@ -60,18 +68,40 @@ struct layernorm_compiler : compiler<layernorm_compiler>
v
,
compute_global_for
(
ctx
,
nelements
*
block_size
,
256
),
block_size
);
options
.
output
=
inputs
.
back
();
options
.
inputs
=
inputs
;
options
.
kernel_name
=
"layernorm_kernel"
;
options
.
kernel_name
=
v
.
get
(
"kernel"
,
"layernorm_kernel"
)
;
auto
src
=
interpolate_string
(
layernorm_kernel
,
{{
"transformers"
,
make_transformer_args
(
vec
)},
{
"axis"
,
to_string
(
axis
)}});
auto
src
=
interpolate_string
(
layernorm_kernel
,
{{
"kernel"
,
options
.
kernel_name
},
{
"params"
,
enum_params
(
inputs
.
size
(),
"void * private_p"
)},
{
"args"
,
enum_params
(
inputs
.
size
(),
"private_p"
)},
{
"transformers"
,
make_transformer_args
(
preloads
,
vec
)},
{
"post"
,
v
.
get
(
"post"
,
std
::
string
{
"op::id{}"
})},
{
"preamble"
,
v
.
get
(
"preamble"
,
std
::
string
{})},
{
"layernorm"
,
v
.
get
(
"layernorm"
,
std
::
string
{
"layernorm"
})},
{
"axis"
,
to_string
(
axis
)}});
return
compile_hip_code_object
(
src
,
options
);
}
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
{
return
replace
(
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
op
.
to_value
()));
auto
v
=
op
.
to_value
();
v
[
"layernorm"
]
=
"layernorm"
;
v
[
"kernel"
]
=
"layernorm_kernel"
;
if
(
op
.
name
()
==
"gpu::preadd_layernorm"
)
{
v
[
"layernorm"
]
=
"add_layernorm"
;
v
[
"kernel"
]
=
"add_layernorm_kernel"
;
}
if
(
not
ins
->
module_inputs
().
empty
())
{
auto
*
pm
=
ins
->
module_inputs
().
front
();
v
[
"preamble"
]
=
generate_pointwise
(
*
pm
,
"post_layernorm"
);
v
[
"post"
]
=
"MIGRAPHX_LIFT(post_layernorm)"
;
v
[
"kernel"
]
=
v
[
"layernorm"
].
to
<
std
::
string
>
()
+
"_"
+
generate_name_from_ops
(
*
pm
)
+
"_kernel"
;
}
return
replace
(
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
v
));
}
};
...
...
src/targets/gpu/jit/pointwise.cpp
View file @
8076d7f7
...
...
@@ -65,18 +65,6 @@ __global__ void ${kernel}(${params})
)__migraphx__"
;
static
std
::
vector
<
std
::
string
>
get_op_names
(
const
module
&
m
)
{
std
::
vector
<
std
::
string
>
result
;
for
(
auto
&
ins
:
m
)
{
if
(
starts_with
(
ins
.
name
(),
"@"
))
continue
;
result
.
push_back
(
ins
.
name
());
}
return
result
;
}
struct
pointwise_compiler
:
compiler
<
pointwise_compiler
>
{
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"pointwise"
,
"contiguous"
};
}
...
...
@@ -126,32 +114,14 @@ struct pointwise_compiler : compiler<pointwise_compiler>
else
{
assert
(
not
ins
->
module_inputs
().
empty
());
auto
*
pm
=
ins
->
module_inputs
().
front
();
run_passes
(
*
pm
,
{
eliminate_common_subexpression
{},
dead_code_elimination
{}});
cpp_generator
g
;
g
.
fmap
([](
const
std
::
string
&
fname
)
{
return
"migraphx::"
+
fname
;
});
g
.
add_point_op
(
"where"
,
"${function:where}(${0}, ${1}, ${2})"
);
g
.
add_point_op
(
"prelu"
,
"${function:where}(${0} < 0, ${0} * ${1}, ${0})"
);
g
.
add_point_op
(
"sign"
,
"${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))"
);
g
.
add_point_op
(
"equal"
,
"migraphx::abs(${0} == ${1})"
);
g
.
add_point_op
(
"less"
,
"migraphx::abs(${0} < ${1})"
);
g
.
add_point_op
(
"greater"
,
"migraphx::abs(${0} > ${1})"
);
g
.
add_point_op
(
"not"
,
"migraphx::abs(not ${0})"
);
// Add explict conversions
g
.
fresult
([](
const
shape
&
s
)
{
return
"migraphx::convert<"
+
shape
::
cpp_type
(
s
.
type
())
+
">"
;
});
auto
name
=
g
.
create_function
(
g
.
generate_module
(
*
pm
).
set_attributes
({
"__device__"
}).
set_generic_types
(
*
pm
));
std
::
string
lambda
=
"MIGRAPHX_LIFT("
+
name
+
")"
;
auto
op_names
=
get_op_names
(
*
pm
);
op_names
.
push_back
(
"kernel"
);
auto
op_name_string
=
join_strings
(
op_names
,
"_"
);
return
replace
(
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
{{
"lambda"
,
lambda
},
{
"preamble"
,
g
.
str
()},
{
"kernel"
,
op_name_string
}}));
auto
*
pm
=
ins
->
module_inputs
().
front
();
auto
pf
=
generate_pointwise
(
*
pm
,
"inner_pointwise"
);
std
::
string
lambda
=
"MIGRAPHX_LIFT(inner_pointwise)"
;
auto
kernel_name
=
generate_name_from_ops
(
*
pm
)
+
"_kernel"
;
return
replace
(
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
{{
"lambda"
,
lambda
},
{
"preamble"
,
pf
},
{
"kernel"
,
kernel_name
}}));
}
}
};
...
...
src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
View file @
8076d7f7
...
...
@@ -31,8 +31,9 @@
->decltype(__VA_ARGS__) { return __VA_ARGS__; }
// NOLINTNEXTLINE
#define MIGRAPHX_LIFT(...) \
[](auto&&... xs) MIGRAPHX_RETURNS((__VA_ARGS__)(static_cast<decltype(xs)>(xs)...))
#define MIGRAPHX_LIFT(...) \
[](auto&&... private_lisft_xs) MIGRAPHX_RETURNS( \
(__VA_ARGS__)(static_cast<decltype(private_lisft_xs)>(private_lisft_xs)...))
namespace
migraphx
{
...
...
src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp
View file @
8076d7f7
...
...
@@ -6,34 +6,57 @@
namespace
migraphx
{
template
<
index_int
Axis
,
class
F
,
class
Output
,
class
Input
,
class
...
Inputs
>
__device__
void
layernorm
(
F
compute
,
Output
output
,
Input
input
,
Inputs
...
inputs
)
template
<
index_int
Axis
,
class
F
,
class
BinOp
,
class
Output
,
class
Input1
,
class
Input2
,
class
...
Inputs
>
__device__
void
generic_binary_layernorm
(
F
compute
,
BinOp
op
,
Output
output
,
Input1
input1
,
Input2
input2
,
Inputs
...
inputs
)
{
using
reduce_output
=
reduce
::
with_axis
<
Input
,
Axis
>
;
using
reduce_output
=
reduce
::
with_axis
<
Input
1
,
Axis
>
;
constexpr
auto
relements
=
get_shape_c
<
Input
>
{}.
elements
()
/
get_shape_c
<
reduce_output
>
{}.
elements
();
get_shape_c
<
Input
1
>
{}.
elements
()
/
get_shape_c
<
reduce_output
>
{}.
elements
();
MIGRAPHX_ASSERT
(
relements
>
0
);
reduce
::
block
::
run
<
reduce_output
>
([
&
](
auto
,
auto
r
)
{
using
value_type
=
typename
Input
::
type
;
using
value_type
=
typename
Input
1
::
type
;
auto
mean
=
[
&
](
auto
f
)
{
return
r
.
reduce
(
op
::
sum
{},
0
,
[
&
](
auto
x
)
{
return
f
(
x
)
/
value_type
{
relements
};
})(
input
);
return
r
.
reduce
(
op
::
sum
{},
0
,
[
&
](
auto
x1
,
auto
x2
)
{
return
f
(
x1
,
x2
)
/
value_type
{
relements
};
})(
input1
,
input2
);
};
// mean(x)
auto
mean_x
=
mean
(
op
::
id
{}
);
auto
mean_x
=
mean
(
op
);
// mean(m ^ 2)
auto
mean_m2
=
mean
([
&
](
auto
x
)
{
auto
m
=
x
-
mean_x
;
auto
mean_m2
=
mean
([
&
](
auto
x
1
,
auto
x2
)
{
auto
m
=
op
(
x1
,
x2
)
-
mean_x
;
return
m
*
m
;
});
r
.
inner
([
&
](
auto
&
y
,
auto
x
,
auto
...
xs
)
{
auto
m
=
x
-
mean_x
;
r
.
inner
([
&
](
auto
&
y
,
auto
x
1
,
auto
x2
,
auto
...
xs
)
{
auto
m
=
op
(
x1
,
x2
)
-
mean_x
;
// m * rsqrt(mean(m ^ 2) + 1e-12)
y
=
compute
(
m
*
rsqrt
(
mean_m2
+
value_type
{
1e-12
}),
xs
...);
})(
output
,
input
,
inputs
...);
})(
output
,
input
1
,
input2
,
inputs
...);
});
}
template
<
index_int
Axis
,
class
F
,
class
Output
,
class
Input
,
class
...
Inputs
>
__device__
void
layernorm
(
F
compute
,
Output
output
,
Input
input
,
Inputs
...
inputs
)
{
generic_binary_layernorm
<
Axis
>
(
compute
,
[](
auto
x
,
auto
)
{
return
x
;
},
output
,
input
,
input
,
inputs
...);
}
template
<
index_int
Axis
,
class
F
,
class
Output
,
class
Input1
,
class
Input2
,
class
...
Inputs
>
__device__
void
add_layernorm
(
F
compute
,
Output
output
,
Input1
input1
,
Input2
input2
,
Inputs
...
inputs
)
{
generic_binary_layernorm
<
Axis
>
(
compute
,
[](
auto
x1
,
auto
x2
)
{
return
x1
+
x2
;
},
output
,
input1
,
input2
,
inputs
...);
}
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_LAYERNORM_HPP
src/targets/gpu/prefuse_ops.cpp
View file @
8076d7f7
...
...
@@ -30,13 +30,19 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
{
struct
layernorm
{
std
::
string
name
()
const
{
return
"gpu::prelayernorm"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
template
<
class
Derived
,
std
::
size_t
N
>
struct
layernorm_base
{
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
,
std
::
vector
<
module_ref
>
mods
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
std
::
size_t
nargs
=
1
;
if
(
not
mods
.
empty
())
{
auto
*
pm
=
mods
.
front
();
nargs
=
pm
->
get_parameter_names
().
size
();
}
check_shapes
{
inputs
,
static_cast
<
const
Derived
&>
(
*
this
)}.
has
(
nargs
+
N
);
auto
s
=
inputs
.
at
(
0
);
if
(
s
.
scalar
())
{
...
...
@@ -52,8 +58,19 @@ struct layernorm
}
}
};
struct
layernorm
:
layernorm_base
<
layernorm
,
0
>
{
std
::
string
name
()
const
{
return
"gpu::prelayernorm"
;
}
};
MIGRAPHX_REGISTER_OP
(
layernorm
);
struct
add_layernorm
:
layernorm_base
<
add_layernorm
,
1
>
{
std
::
string
name
()
const
{
return
"gpu::preadd_layernorm"
;
}
};
MIGRAPHX_REGISTER_OP
(
add_layernorm
);
struct
find_layernorm
{
auto
matcher
()
const
{
return
match
::
layernorm
();
}
...
...
@@ -67,6 +84,22 @@ struct find_layernorm
}
};
struct
find_add_layernorm
{
auto
matcher
()
const
{
return
match
::
layernorm
()(
match
::
var
(
"x"
)(
match
::
name
(
"add"
).
bind
(
"add"
)));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
add_ins
=
r
.
instructions
[
"add"
];
m
.
replace_instruction
(
ins
,
add_layernorm
{},
add_ins
->
inputs
());
}
};
struct
find_gpulayernorm
{
auto
matcher
()
const
{
return
match
::
layernorm
();
}
...
...
@@ -128,7 +161,7 @@ struct find_gputriaddlayernorm
void
prefuse_ops
::
apply
(
module
&
m
)
const
{
match
::
find_matches
(
m
,
find_layernorm
{});
match
::
find_matches
(
m
,
find_add_layernorm
{},
find_layernorm
{});
// match::find_matches(m, find_gputriaddlayernorm{}, find_gpulayernorm{});
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment