Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
8bfea5f7
Commit
8bfea5f7
authored
Aug 03, 2022
by
turneram
Browse files
Merge remote-tracking branch 'origin/jit-layernorm' into debug-bert
parents
7e316254
a1663694
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
332 additions
and
100 deletions
+332
-100
src/targets/gpu/compile_gen.cpp
src/targets/gpu/compile_gen.cpp
+66
-19
src/targets/gpu/fuse_ops.cpp
src/targets/gpu/fuse_ops.cpp
+29
-2
src/targets/gpu/include/migraphx/gpu/compile_gen.hpp
src/targets/gpu/include/migraphx/gpu/compile_gen.hpp
+5
-0
src/targets/gpu/jit/layernorm.cpp
src/targets/gpu/jit/layernorm.cpp
+110
-0
src/targets/gpu/jit/pointwise.cpp
src/targets/gpu/jit/pointwise.cpp
+8
-40
src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
...rgets/gpu/kernels/include/migraphx/kernels/functional.hpp
+3
-2
src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp
...argets/gpu/kernels/include/migraphx/kernels/layernorm.hpp
+62
-0
src/targets/gpu/prefuse_ops.cpp
src/targets/gpu/prefuse_ops.cpp
+49
-37
No files found.
src/targets/gpu/compile_gen.cpp
View file @
8bfea5f7
...
@@ -25,6 +25,13 @@
...
@@ -25,6 +25,13 @@
#include <migraphx/shape.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/module.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -75,25 +82,26 @@ std::string vectorize::str() const
...
@@ -75,25 +82,26 @@ std::string vectorize::str() const
preload
preload
::
broadcasts
(
std
::
size_t
axis
,
const
std
::
vector
<
shape
>&
inputs
)
preload
preload
::
broadcasts
(
std
::
size_t
axis
,
const
std
::
vector
<
shape
>&
inputs
)
{
{
const
std
::
size_t
max_lds_bytes
=
4096
;
const
std
::
size_t
max_lds_bytes
=
4096
;
std
::
vector
<
bool
>
result
;
std
::
vector
<
bool
>
result
(
inputs
.
size
());
std
::
transform
(
inputs
.
begin
(),
std
::
vector
<
std
::
size_t
>
preloaded
;
inputs
.
end
(),
for
(
auto
i
:
range
(
inputs
.
size
()))
std
::
back_inserter
(
result
),
{
[
&
](
const
shape
&
input
)
{
return
input
.
strides
()[
axis
]
==
0
;
});
if
(
inputs
[
i
].
strides
()[
axis
]
==
0
)
auto
bytes
=
std
::
inner_product
(
inputs
.
begin
(),
preloaded
.
push_back
(
i
);
inputs
.
end
(),
}
result
.
begin
(),
std
::
sort
(
preloaded
.
begin
(),
preloaded
.
end
(),
by
(
std
::
less
<>
{},
[
&
](
auto
i
)
{
std
::
size_t
{
0
},
return
inputs
[
i
].
bytes
();
std
::
plus
<>
{},
}));
[](
const
shape
&
s
,
bool
b
)
->
std
::
size_t
{
if
(
b
)
std
::
size_t
bytes
=
0
;
return
s
.
bytes
();
for
(
auto
i
:
preloaded
)
return
0
;
{
});
auto
input
=
inputs
[
i
];
if
(
bytes
<
max_lds_bytes
)
bytes
+=
input
.
bytes
();
return
{
result
};
if
(
bytes
>
max_lds_bytes
)
// TODO: Try to partially preload items
break
;
std
::
fill
(
result
.
begin
(),
result
.
end
(),
false
);
result
[
i
]
=
true
;
}
return
{
result
};
return
{
result
};
}
}
...
@@ -125,6 +133,45 @@ std::string make_transformer_args(std::vector<std::string> transformers)
...
@@ -125,6 +133,45 @@ std::string make_transformer_args(std::vector<std::string> transformers)
return
join_strings
(
std
::
move
(
transformers
),
", "
);
return
join_strings
(
std
::
move
(
transformers
),
", "
);
}
}
std
::
string
generate_pointwise
(
const
module
&
pm
,
const
std
::
string
&
name
)
{
module
m
=
pm
;
run_passes
(
m
,
{
eliminate_common_subexpression
{},
dead_code_elimination
{}});
cpp_generator
g
;
g
.
fmap
([](
const
std
::
string
&
fname
)
{
return
"migraphx::"
+
fname
;
});
g
.
add_point_op
(
"where"
,
"${function:where}(${0}, ${1}, ${2})"
);
g
.
add_point_op
(
"prelu"
,
"${function:where}(${0} < 0, ${0} * ${1}, ${0})"
);
g
.
add_point_op
(
"sign"
,
"${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))"
);
g
.
add_point_op
(
"equal"
,
"migraphx::abs(${0} == ${1})"
);
g
.
add_point_op
(
"less"
,
"migraphx::abs(${0} < ${1})"
);
g
.
add_point_op
(
"greater"
,
"migraphx::abs(${0} > ${1})"
);
g
.
add_point_op
(
"not"
,
"migraphx::abs(not ${0})"
);
// Add explict conversions
g
.
fresult
(
[](
const
shape
&
s
)
{
return
"migraphx::convert<"
+
shape
::
cpp_type
(
s
.
type
())
+
">"
;
});
g
.
create_function
(
g
.
generate_module
(
m
).
set_attributes
({
"__device__"
}).
set_generic_types
(
m
).
set_name
(
name
));
return
g
.
str
();
}
static
std
::
vector
<
std
::
string
>
get_op_names
(
const
module
&
m
)
{
std
::
vector
<
std
::
string
>
result
;
for
(
auto
&
ins
:
m
)
{
if
(
starts_with
(
ins
.
name
(),
"@"
))
continue
;
result
.
push_back
(
ins
.
name
());
}
return
result
;
}
std
::
string
generate_name_from_ops
(
const
module
&
m
)
{
auto
op_names
=
get_op_names
(
m
);
return
join_strings
(
op_names
,
"_"
);
}
}
// namespace gen
}
// namespace gen
}
// namespace gpu
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/targets/gpu/fuse_ops.cpp
View file @
8bfea5f7
...
@@ -827,13 +827,14 @@ void apply_conv_bias(context& ctx, module& m, const match::matcher_result& r)
...
@@ -827,13 +827,14 @@ void apply_conv_bias(context& ctx, module& m, const match::matcher_result& r)
m
.
replace_instruction
(
ins
,
cb
,
input_ins
,
weights_ins
,
old_ws_ins
,
bias_ins
,
alloc_ins
);
m
.
replace_instruction
(
ins
,
cb
,
input_ins
,
weights_ins
,
old_ws_ins
,
bias_ins
,
alloc_ins
);
}
}
inline
auto
precompile_name
(
std
::
string
s
)
// NOLINT
template
<
class
...
Strings
>
inline
auto
precompile_name
(
Strings
...
names
)
// NOLINT
{
{
return
match
::
make_basic_pred_matcher
([
=
](
instruction_ref
ins
)
{
return
match
::
make_basic_pred_matcher
([
=
](
instruction_ref
ins
)
{
if
(
ins
->
name
()
!=
"gpu::precompile_op"
)
if
(
ins
->
name
()
!=
"gpu::precompile_op"
)
return
false
;
return
false
;
auto
op
=
from_value
<
operation
>
(
ins
->
get_operator
().
to_value
().
at
(
"op"
));
auto
op
=
from_value
<
operation
>
(
ins
->
get_operator
().
to_value
().
at
(
"op"
));
return
(
op
.
name
()
==
s
);
return
(
contains
({
names
...},
op
.
name
()
)
);
});
});
}
}
...
@@ -1041,6 +1042,31 @@ struct find_contiguous_pointwise
...
@@ -1041,6 +1042,31 @@ struct find_contiguous_pointwise
}
}
};
};
struct
find_layernorm_pointwise
{
auto
matcher
()
const
{
return
precompile_name
(
"pointwise"
)(
match
::
arg
(
0
)(
precompile_name
(
"gpu::prelayernorm"
,
"gpu::preadd_layernorm"
).
bind
(
"layernorm"
)));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
layernorm
=
r
.
instructions
[
"layernorm"
];
auto
*
pm
=
ins
->
module_inputs
().
front
();
if
(
not
layernorm
->
module_inputs
().
empty
())
return
;
auto
inputs
=
layernorm
->
inputs
();
inputs
.
pop_back
();
inputs
.
insert
(
inputs
.
end
(),
ins
->
inputs
().
begin
()
+
1
,
ins
->
inputs
().
end
());
m
.
replace_instruction
(
ins
,
layernorm
->
get_operator
(),
inputs
,
{
pm
});
}
};
void
fuse_ops
::
apply
(
module
&
m
)
const
void
fuse_ops
::
apply
(
module
&
m
)
const
{
{
match
::
find_matches
(
m
,
find_contiguous_pointwise
{},
find_gelu
{},
find_gelu_new
{
fast_math
});
match
::
find_matches
(
m
,
find_contiguous_pointwise
{},
find_gelu
{},
find_gelu_new
{
fast_math
});
...
@@ -1063,6 +1089,7 @@ void fuse_ops::apply(module& m) const
...
@@ -1063,6 +1089,7 @@ void fuse_ops::apply(module& m) const
match
::
find_matches
(
m
,
match
::
find_matches
(
m
,
find_triadd_layernorm
{},
find_triadd_layernorm
{},
find_gemm_add
{},
find_gemm_add
{},
find_layernorm_pointwise
{},
find_gemm_pointwise
{},
find_gemm_pointwise
{},
find_commutative_broadcast
{});
find_commutative_broadcast
{});
match
::
find_matches
(
m
,
find_contiguous
{});
match
::
find_matches
(
m
,
find_contiguous
{});
...
...
src/targets/gpu/include/migraphx/gpu/compile_gen.hpp
View file @
8bfea5f7
...
@@ -25,6 +25,7 @@
...
@@ -25,6 +25,7 @@
#define MIGRAPHX_GUARD_GPU_COMPILE_GEN_HPP
#define MIGRAPHX_GUARD_GPU_COMPILE_GEN_HPP
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/module_ref.hpp>
#include <string>
#include <string>
#include <unordered_map>
#include <unordered_map>
#include <vector>
#include <vector>
...
@@ -62,6 +63,10 @@ std::string make_transformer_args(Ts... xs)
...
@@ -62,6 +63,10 @@ std::string make_transformer_args(Ts... xs)
return
make_transformer_args
({
xs
.
str
()...});
return
make_transformer_args
({
xs
.
str
()...});
}
}
std
::
string
generate_pointwise
(
const
module
&
pm
,
const
std
::
string
&
name
);
std
::
string
generate_name_from_ops
(
const
module
&
m
);
}
// namespace gen
}
// namespace gen
}
// namespace gpu
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/targets/gpu/jit/layernorm.cpp
0 → 100644
View file @
8bfea5f7
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
using
namespace
migraphx
::
gpu
::
gen
;
// NOLINT
static
const
char
*
const
layernorm_kernel
=
R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/layernorm.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <migraphx/kernels/preload.hpp>
#include <args.hpp>
namespace migraphx {
${preamble}
extern "C" {
__global__ void ${kernel}(${params})
{
auto idx = make_index();
transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto... xs) {
${layernorm}<${axis}>(${post}, xs...);
});
}
}
} // namespace migraphx
)__migraphx__"
;
struct
layernorm_compiler
:
compiler
<
layernorm_compiler
>
{
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"layernorm"
,
"gpu::prelayernorm"
,
"gpu::preadd_layernorm"
};
}
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
// TODO: Use reduce_dims
auto
axis
=
inputs
.
front
().
lens
().
size
()
-
1
;
auto
faxis
=
find_fast_axis
({
inputs
.
front
()});
vectorize
vec
{};
// Vectorize if the axis is a reduction axis
if
(
axis
==
faxis
)
{
vec
=
vectorize
::
elements
(
faxis
,
inputs
);
}
auto
preloads
=
preload
::
broadcasts
(
axis
,
inputs
);
auto
relements
=
inputs
[
0
].
lens
()[
axis
]
/
vec
.
size
;
auto
nelements
=
(
inputs
.
back
().
elements
()
/
inputs
[
0
].
lens
()[
axis
]);
auto
block_size
=
compute_block_size
(
relements
,
256
);
hip_compile_options
options
;
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
nelements
*
block_size
,
256
),
block_size
);
options
.
output
=
inputs
.
back
();
options
.
inputs
=
inputs
;
options
.
kernel_name
=
v
.
get
(
"kernel"
,
"layernorm_kernel"
);
auto
src
=
interpolate_string
(
layernorm_kernel
,
{{
"kernel"
,
options
.
kernel_name
},
{
"params"
,
enum_params
(
inputs
.
size
(),
"void * private_p"
)},
{
"args"
,
enum_params
(
inputs
.
size
(),
"private_p"
)},
{
"transformers"
,
make_transformer_args
(
preloads
,
vec
)},
{
"post"
,
v
.
get
(
"post"
,
std
::
string
{
"op::id{}"
})},
{
"preamble"
,
v
.
get
(
"preamble"
,
std
::
string
{})},
{
"layernorm"
,
v
.
get
(
"layernorm"
,
std
::
string
{
"layernorm"
})},
{
"axis"
,
to_string
(
axis
)}});
return
compile_hip_code_object
(
src
,
options
);
}
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
{
auto
v
=
op
.
to_value
();
v
[
"layernorm"
]
=
"layernorm"
;
v
[
"kernel"
]
=
"layernorm_kernel"
;
if
(
op
.
name
()
==
"gpu::preadd_layernorm"
)
{
v
[
"layernorm"
]
=
"add_layernorm"
;
v
[
"kernel"
]
=
"add_layernorm_kernel"
;
}
if
(
not
ins
->
module_inputs
().
empty
())
{
auto
*
pm
=
ins
->
module_inputs
().
front
();
v
[
"preamble"
]
=
generate_pointwise
(
*
pm
,
"post_layernorm"
);
v
[
"post"
]
=
"MIGRAPHX_LIFT(post_layernorm)"
;
v
[
"kernel"
]
=
v
[
"layernorm"
].
to
<
std
::
string
>
()
+
"_"
+
generate_name_from_ops
(
*
pm
)
+
"_kernel"
;
}
return
replace
(
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
v
));
}
};
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/jit/pointwise.cpp
View file @
8bfea5f7
...
@@ -65,18 +65,6 @@ __global__ void ${kernel}(${params})
...
@@ -65,18 +65,6 @@ __global__ void ${kernel}(${params})
)__migraphx__"
;
)__migraphx__"
;
static
std
::
vector
<
std
::
string
>
get_op_names
(
const
module
&
m
)
{
std
::
vector
<
std
::
string
>
result
;
for
(
auto
&
ins
:
m
)
{
if
(
starts_with
(
ins
.
name
(),
"@"
))
continue
;
result
.
push_back
(
ins
.
name
());
}
return
result
;
}
struct
pointwise_compiler
:
compiler
<
pointwise_compiler
>
struct
pointwise_compiler
:
compiler
<
pointwise_compiler
>
{
{
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"pointwise"
,
"contiguous"
};
}
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"pointwise"
,
"contiguous"
};
}
...
@@ -126,34 +114,14 @@ struct pointwise_compiler : compiler<pointwise_compiler>
...
@@ -126,34 +114,14 @@ struct pointwise_compiler : compiler<pointwise_compiler>
else
else
{
{
assert
(
not
ins
->
module_inputs
().
empty
());
assert
(
not
ins
->
module_inputs
().
empty
());
auto
*
pm
=
ins
->
module_inputs
().
front
();
auto
*
pm
=
ins
->
module_inputs
().
front
();
run_passes
(
*
pm
,
{
eliminate_common_subexpression
{},
dead_code_elimination
{}});
auto
pf
=
generate_pointwise
(
*
pm
,
"inner_pointwise"
);
cpp_generator
g
;
std
::
string
lambda
=
"MIGRAPHX_LIFT(inner_pointwise)"
;
g
.
fmap
([](
const
std
::
string
&
fname
)
{
return
"migraphx::"
+
fname
;
});
auto
kernel_name
=
generate_name_from_ops
(
*
pm
)
+
"_kernel"
;
g
.
add_point_op
(
"where"
,
"${function:where}(${0}, ${1}, ${2})"
);
return
replace
(
g
.
add_point_op
(
"prelu"
,
"${function:where}(${0} < 0, ${0} * ${1}, ${0})"
);
compile_op
(
ctx
,
g
.
add_point_op
(
"sign"
,
to_shapes
(
ins
->
inputs
()),
"${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))"
);
{{
"lambda"
,
lambda
},
{
"preamble"
,
pf
},
{
"kernel"
,
kernel_name
}}));
g
.
add_point_op
(
"equal"
,
"migraphx::abs(${0} == ${1})"
);
g
.
add_point_op
(
"less"
,
"migraphx::abs(${0} < ${1})"
);
g
.
add_point_op
(
"greater"
,
"migraphx::abs(${0} > ${1})"
);
g
.
add_point_op
(
"not"
,
"migraphx::abs(not ${0})"
);
g
.
add_point_op
(
"mod"
,
"migraphx::mod(${0}, ${1})"
);
g
.
add_point_op
(
"fmod"
,
"migraphx::fmod(${0}, ${1})"
);
// Add explict conversions
g
.
fresult
([](
const
shape
&
s
)
{
return
"migraphx::convert<"
+
shape
::
cpp_type
(
s
.
type
())
+
">"
;
});
auto
name
=
g
.
create_function
(
g
.
generate_module
(
*
pm
).
set_attributes
({
"__device__"
}).
set_generic_types
(
*
pm
));
std
::
string
lambda
=
"MIGRAPHX_LIFT("
+
name
+
")"
;
auto
op_names
=
get_op_names
(
*
pm
);
op_names
.
push_back
(
"kernel"
);
auto
op_name_string
=
join_strings
(
op_names
,
"_"
);
return
replace
(
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
{{
"lambda"
,
lambda
},
{
"preamble"
,
g
.
str
()},
{
"kernel"
,
op_name_string
}}));
}
}
}
}
};
};
...
...
src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
View file @
8bfea5f7
...
@@ -31,8 +31,9 @@
...
@@ -31,8 +31,9 @@
->decltype(__VA_ARGS__) { return __VA_ARGS__; }
->decltype(__VA_ARGS__) { return __VA_ARGS__; }
// NOLINTNEXTLINE
// NOLINTNEXTLINE
#define MIGRAPHX_LIFT(...) \
#define MIGRAPHX_LIFT(...) \
[](auto&&... xs) MIGRAPHX_RETURNS((__VA_ARGS__)(static_cast<decltype(xs)>(xs)...))
[](auto&&... private_lisft_xs) MIGRAPHX_RETURNS( \
(__VA_ARGS__)(static_cast<decltype(private_lisft_xs)>(private_lisft_xs)...))
namespace
migraphx
{
namespace
migraphx
{
...
...
src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp
0 → 100644
View file @
8bfea5f7
#ifndef MIGRAPHX_GUARD_KERNELS_LAYERNORM_HPP
#define MIGRAPHX_GUARD_KERNELS_LAYERNORM_HPP
#include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/print.hpp>
namespace
migraphx
{
template
<
index_int
Axis
,
class
F
,
class
BinOp
,
class
Output
,
class
Input1
,
class
Input2
,
class
...
Inputs
>
__device__
void
generic_binary_layernorm
(
F
compute
,
BinOp
op
,
Output
output
,
Input1
input1
,
Input2
input2
,
Inputs
...
inputs
)
{
using
reduce_output
=
reduce
::
with_axis
<
Input1
,
Axis
>
;
constexpr
auto
relements
=
get_shape_c
<
Input1
>
{}.
elements
()
/
get_shape_c
<
reduce_output
>
{}.
elements
();
MIGRAPHX_ASSERT
(
relements
>
0
);
reduce
::
block
::
run
<
reduce_output
>
([
&
](
auto
,
auto
r
)
{
using
value_type
=
typename
Input1
::
type
;
auto
mean
=
[
&
](
auto
f
)
{
return
r
.
reduce
(
op
::
sum
{},
0
,
[
&
](
auto
x1
,
auto
x2
)
{
return
f
(
x1
,
x2
)
/
value_type
{
relements
};
})(
input1
,
input2
);
};
// mean(x)
auto
mean_x
=
mean
(
op
);
// mean(m ^ 2)
auto
mean_m2
=
mean
([
&
](
auto
x1
,
auto
x2
)
{
auto
m
=
op
(
x1
,
x2
)
-
mean_x
;
return
m
*
m
;
});
r
.
inner
([
&
](
auto
&
y
,
auto
x1
,
auto
x2
,
auto
...
xs
)
{
auto
m
=
op
(
x1
,
x2
)
-
mean_x
;
// m * rsqrt(mean(m ^ 2) + 1e-12)
y
=
compute
(
m
*
rsqrt
(
mean_m2
+
value_type
{
1e-12
}),
xs
...);
})(
output
,
input1
,
input2
,
inputs
...);
});
}
template
<
index_int
Axis
,
class
F
,
class
Output
,
class
Input
,
class
...
Inputs
>
__device__
void
layernorm
(
F
compute
,
Output
output
,
Input
input
,
Inputs
...
inputs
)
{
generic_binary_layernorm
<
Axis
>
(
compute
,
[](
auto
x
,
auto
)
{
return
x
;
},
output
,
input
,
input
,
inputs
...);
}
template
<
index_int
Axis
,
class
F
,
class
Output
,
class
Input1
,
class
Input2
,
class
...
Inputs
>
__device__
void
add_layernorm
(
F
compute
,
Output
output
,
Input1
input1
,
Input2
input2
,
Inputs
...
inputs
)
{
generic_binary_layernorm
<
Axis
>
(
compute
,
[](
auto
x1
,
auto
x2
)
{
return
x1
+
x2
;
},
output
,
input1
,
input2
,
inputs
...);
}
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_LAYERNORM_HPP
src/targets/gpu/prefuse_ops.cpp
View file @
8bfea5f7
...
@@ -24,12 +24,53 @@
...
@@ -24,12 +24,53 @@
#include <migraphx/gpu/prefuse_ops.hpp>
#include <migraphx/gpu/prefuse_ops.hpp>
#include <migraphx/match/layernorm.hpp>
#include <migraphx/match/layernorm.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
gpu
{
namespace
{
namespace
{
template
<
class
Derived
,
std
::
size_t
N
>
struct
layernorm_base
{
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
,
std
::
vector
<
module_ref
>
mods
)
const
{
std
::
size_t
nargs
=
1
;
if
(
not
mods
.
empty
())
{
auto
*
pm
=
mods
.
front
();
nargs
=
pm
->
get_parameter_names
().
size
();
}
check_shapes
{
inputs
,
static_cast
<
const
Derived
&>
(
*
this
)}.
has
(
nargs
+
N
);
auto
s
=
inputs
.
at
(
0
);
if
(
s
.
scalar
())
{
return
s
;
}
else
if
(
s
.
broadcasted
())
{
return
{
s
.
type
(),
s
.
lens
()};
}
else
{
return
s
.
with_lens
(
s
.
lens
());
}
}
};
struct
layernorm
:
layernorm_base
<
layernorm
,
0
>
{
std
::
string
name
()
const
{
return
"gpu::prelayernorm"
;
}
};
MIGRAPHX_REGISTER_OP
(
layernorm
);
struct
add_layernorm
:
layernorm_base
<
add_layernorm
,
1
>
{
std
::
string
name
()
const
{
return
"gpu::preadd_layernorm"
;
}
};
MIGRAPHX_REGISTER_OP
(
add_layernorm
);
struct
find_layernorm
struct
find_layernorm
{
{
auto
matcher
()
const
{
return
match
::
layernorm
();
}
auto
matcher
()
const
{
return
match
::
layernorm
();
}
...
@@ -39,59 +80,30 @@ struct find_layernorm
...
@@ -39,59 +80,30 @@ struct find_layernorm
auto
ins
=
r
.
result
;
auto
ins
=
r
.
result
;
auto
x_ins
=
r
.
instructions
[
"x"
];
auto
x_ins
=
r
.
instructions
[
"x"
];
if
(
not
x_ins
->
get_shape
().
standard
())
m
.
replace_instruction
(
ins
,
layernorm
{},
x_ins
);
x_ins
=
m
.
insert_instruction
(
ins
,
make_op
(
"contiguous"
),
x_ins
);
auto
relements
=
x_ins
->
get_shape
().
lens
().
back
();
if
(
relements
>
1024
or
(
relements
%
4
!=
0
and
relements
>
256
))
return
;
auto
a
=
m
.
insert_instruction
(
ins
,
make_op
(
"hip::allocate"
,
{{
"shape"
,
to_value
(
x_ins
->
get_shape
())}}));
m
.
replace_instruction
(
ins
,
make_op
(
"gpu::layernorm"
),
x_ins
,
a
);
}
}
};
};
struct
find_
tri
addlayernorm
struct
find_add
_
layernorm
{
{
auto
matcher
()
const
auto
matcher
()
const
{
{
auto
add1
=
return
match
::
layernorm
()(
match
::
var
(
"x"
)(
match
::
name
(
"add"
).
bind
(
"add"
)));
match
::
name
(
"add"
)(
match
::
none_of
(
match
::
is_constant
()),
match
::
args
(
match
::
any
().
bind
(
"z1"
),
match
::
any
().
bind
(
"z2"
)));
auto
add2
=
match
::
name
(
"add"
)(
match
::
either_arg
(
0
,
1
)(
add1
,
match
::
any
().
bind
(
"z3"
)));
return
match
::
layernorm
()(
match
::
var
(
"x"
)(
add2
));
}
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
{
auto
ins
=
r
.
result
;
auto
ins
=
r
.
result
;
auto
x_ins
=
r
.
instructions
[
"z1"
];
auto
add_ins
=
r
.
instructions
[
"add"
];
auto
y_ins
=
r
.
instructions
[
"z2"
];
auto
z_ins
=
r
.
instructions
[
"z3"
];
for
(
auto
*
pins
:
{
&
x_ins
,
&
y_ins
,
&
z_ins
})
{
if
(
not
(
*
pins
)
->
get_shape
().
standard
())
*
pins
=
m
.
insert_instruction
(
ins
,
make_op
(
"contiguous"
),
*
pins
);
}
auto
relements
=
x_ins
->
get_shape
().
lens
().
back
();
if
(
relements
>
1024
or
(
relements
%
4
!=
0
and
relements
>
256
))
return
;
auto
a
=
m
.
insert_instruction
(
m
.
replace_instruction
(
ins
,
add_layernorm
{},
add_ins
->
inputs
());
ins
,
make_op
(
"hip::allocate"
,
{{
"shape"
,
to_value
(
x_ins
->
get_shape
())}}));
m
.
replace_instruction
(
ins
,
make_op
(
"gpu::triadd_layernorm"
),
x_ins
,
y_ins
,
z_ins
,
a
);
}
}
};
};
}
// namespace
}
// namespace
void
prefuse_ops
::
apply
(
module
&
m
)
const
void
prefuse_ops
::
apply
(
module
&
m
)
const
{
{
match
::
find_matches
(
m
,
find_
tri
addlayernorm
{},
find_layernorm
{});
match
::
find_matches
(
m
,
find_add
_
layernorm
{},
find_layernorm
{});
}
}
}
// namespace gpu
}
// namespace gpu
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment