Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
33189188
Commit
33189188
authored
Jun 07, 2022
by
Paul
Browse files
Merge branch 'dot-add' into bert-opt
parents
5d6880e5
1dda8943
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
125 additions
and
15 deletions
+125
-15
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+59
-13
src/targets/gpu/fuse_ops.cpp
src/targets/gpu/fuse_ops.cpp
+66
-2
No files found.
src/simplify_algebra.cpp
View file @
33189188
...
@@ -185,6 +185,42 @@ struct find_mul_add
...
@@ -185,6 +185,42 @@ struct find_mul_add
}
}
};
};
struct
find_dot_add
{
auto
matcher
()
const
{
return
match
::
name
(
"dot"
)(
match
::
either_arg
(
0
,
1
)(
match
::
name
(
"add"
)(
match
::
either_arg
(
0
,
1
)(
match
::
any
().
bind
(
"x"
),
match
::
any_of
(
match
::
is_constant
()).
bind
(
"b"
)),
match
::
none_of
(
match
::
args
(
match
::
is_constant
(),
match
::
is_constant
())),
match
::
used_once
()),
match
::
is_constant
().
bind
(
"a"
)));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
a_ins
=
r
.
instructions
[
"a"
];
auto
b_ins
=
r
.
instructions
[
"b"
];
auto
x_ins
=
r
.
instructions
[
"x"
];
assert
(
x_ins
!=
b_ins
);
const
bool
flipped
=
a_ins
==
ins
->
inputs
().
back
();
auto
insert_dot
=
[
&
](
auto
x
,
auto
y
)
{
if
(
flipped
)
return
m
.
insert_instruction
(
ins
,
make_op
(
"dot"
),
y
,
x
);
else
return
m
.
insert_instruction
(
ins
,
make_op
(
"dot"
),
x
,
y
);
};
auto
ax_ins
=
insert_dot
(
a_ins
,
x_ins
);
auto
ab_ins
=
insert_dot
(
a_ins
,
b_ins
);
m
.
replace_instruction
(
ins
,
make_op
(
"add"
),
ax_ins
,
ab_ins
);
}
};
struct
find_add_lit_broadcast
struct
find_add_lit_broadcast
{
{
auto
matcher
()
const
auto
matcher
()
const
...
@@ -246,26 +282,35 @@ struct find_inner_broadcast
...
@@ -246,26 +282,35 @@ struct find_inner_broadcast
{
{
auto
matcher
()
const
auto
matcher
()
const
{
{
return
pointwise
(
return
pointwise
(
match
::
all_of
[
match
::
inputs
()](
match
::
nargs
(
2
),
match
::
broadcast_shape
(),
match
::
name
(
"broadcast"
,
"multibroadcast"
)));
match
::
args
(
match
::
name
(
"broadcast"
).
bind
(
"x"
),
match
::
name
(
"broadcast"
).
bind
(
"y"
)));
}
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
{
auto
ins
=
r
.
result
;
auto
ins
=
r
.
result
;
auto
x_ins
=
r
.
instructions
[
"x"
];
auto
inputs
=
ins
->
inputs
();
auto
y_ins
=
r
.
instructions
[
"y"
];
if
(
inputs
.
empty
())
return
;
auto
xbroadcast
=
any_cast
<
op
::
broadcast
>
(
x_ins
->
get_operator
());
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
inputs
.
begin
(),
[
&
](
auto
i
)
{
auto
ybroadcast
=
any_cast
<
op
::
broadcast
>
(
y_ins
->
get_operator
());
if
(
contains
({
"broadcast"
,
"multibroadcast"
},
i
->
name
()))
return
i
->
inputs
().
front
();
else
return
i
;
});
if
(
xbroadcast
.
axis
!=
ybroadcast
.
axis
)
if
(
not
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
auto
&
x
)
{
return
x
->
get_shape
()
==
inputs
.
front
()
->
get_shape
();
}))
return
;
return
;
auto
op
=
m
.
insert_instruction
(
auto
op
=
m
.
insert_instruction
(
ins
,
ins
->
get_operator
(),
inputs
);
ins
,
ins
->
get_operator
(),
x_ins
->
inputs
().
front
(),
y_ins
->
inputs
().
front
());
auto
bop
=
std
::
find_if
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
[
&
](
auto
i
)
{
m
.
replace_instruction
(
ins
,
xbroadcast
,
op
);
return
contains
({
"broadcast"
,
"multibroadcast"
},
i
->
name
());
});
assert
(
bop
!=
ins
->
inputs
().
end
());
m
.
replace_instruction
(
ins
,
(
*
bop
)
->
get_operator
(),
op
);
}
}
};
};
...
@@ -1025,6 +1070,7 @@ void simplify_algebra::apply(module& m) const
...
@@ -1025,6 +1070,7 @@ void simplify_algebra::apply(module& m) const
find_mul_conv
{},
find_mul_conv
{},
find_mul_slice_conv
{},
find_mul_slice_conv
{},
find_mul_add
{},
find_mul_add
{},
find_dot_add
{},
find_div_const
{},
find_div_const
{},
find_sub_const
{},
find_sub_const
{},
find_rsqrt
{},
find_rsqrt
{},
...
...
src/targets/gpu/fuse_ops.cpp
View file @
33189188
...
@@ -256,6 +256,11 @@ MIGRAPHX_REGISTER_OP(hip_layernorm)
...
@@ -256,6 +256,11 @@ MIGRAPHX_REGISTER_OP(hip_layernorm)
struct
hip_triadd_layernorm
:
ternary_device
<
hip_triadd_layernorm
,
&
device
::
triadd_layernorm
>
struct
hip_triadd_layernorm
:
ternary_device
<
hip_triadd_layernorm
,
&
device
::
triadd_layernorm
>
{
{
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
4
).
standard
();
return
inputs
[
0
];
}
// Empty finalize to skip dimension reduction
// Empty finalize to skip dimension reduction
void
finalize
(
context
&
,
const
shape
&
,
const
std
::
vector
<
shape
>&
)
{}
void
finalize
(
context
&
,
const
shape
&
,
const
std
::
vector
<
shape
>&
)
{}
};
};
...
@@ -943,13 +948,68 @@ struct find_gemm_pointwise
...
@@ -943,13 +948,68 @@ struct find_gemm_pointwise
{
{
auto
matcher
()
const
auto
matcher
()
const
{
{
return
p
ointwise_name
(
"add
"
)(
return
p
recompile_name
(
"pointwise
"
)(
match
::
nargs
(
3
),
match
::
nargs
(
3
),
match
::
all_of
[
match
::
inputs
()](
match
::
standard_shape
()),
match
::
all_of
[
match
::
inputs
()](
match
::
standard_shape
()),
match
::
either_arg
(
0
,
1
)(
match
::
used_once
().
bind
(
"c"
),
match
::
either_arg
(
0
,
1
)(
match
::
used_once
().
bind
(
"c"
),
match
::
name
(
"gpu::gemm"
)(
match
::
nargs
(
3
)).
bind
(
"gemm"
)));
match
::
name
(
"gpu::gemm"
)(
match
::
nargs
(
3
)).
bind
(
"gemm"
)));
}
}
// TODO: Move to matcher.hpp
static
auto
match_param
(
const
std
::
string
&
name
)
{
return
match
::
make_basic_pred_matcher
([
=
](
auto
ins
)
{
if
(
ins
->
name
()
!=
"@param"
)
return
false
;
auto
p
=
any_cast
<
builtin
::
param
>
(
ins
->
get_operator
());
return
p
.
parameter
==
name
;
});
}
template
<
class
M
>
static
auto
match_mul_const
(
M
m
,
const
std
::
string
&
var
)
{
return
match
::
name
(
"mul"
)(
match
::
either_arg
(
0
,
1
)(
match
::
name
(
"@literal"
).
bind
(
var
),
m
))
.
bind
(
var
+
"_mul"
);
}
static
auto
match_add
(
const
std
::
string
&
input
,
const
std
::
string
&
output
)
{
auto
param
=
match
::
name
(
"@param"
);
auto
add
=
match
::
name
(
"add"
)(
match
::
args
(
param
,
param
));
auto
inner_mul
=
match
::
any_of
(
match_mul_const
(
match_param
(
input
),
"alpha"
),
match_mul_const
(
match_param
(
output
),
"beta"
));
auto
mul_add
=
match
::
name
(
"add"
)(
match
::
either_arg
(
0
,
1
)(
inner_mul
,
param
));
auto
add_mul
=
match_mul_const
(
add
,
"gamma"
);
return
match
::
name
(
"@return"
)(
match
::
args
(
match
::
any_of
(
add
,
mul_add
,
add_mul
)));
}
static
float
get_float
(
instruction_ref
ins
)
{
return
ins
->
get_literal
().
at
<
float
>
();
}
template
<
class
Gemm
>
static
bool
update_gemm
(
Gemm
&
gemm
,
module_ref
pm
,
unsigned
input
)
{
auto
names
=
pm
->
get_parameter_names
();
if
(
names
.
size
()
!=
2
)
return
false
;
std
::
sort
(
names
.
begin
(),
names
.
end
());
unsigned
output
=
input
==
0
?
1
:
0
;
auto
mr
=
match
::
match_instruction
(
*
pm
,
std
::
prev
(
pm
->
end
()),
match_add
(
names
[
input
],
names
[
output
]));
if
(
mr
.
result
==
pm
->
end
())
return
false
;
if
(
contains
(
mr
.
instructions
,
"alpha_mul"
))
gemm
.
alpha
*=
get_float
(
mr
.
instructions
[
"alpha"
]);
else
if
(
contains
(
mr
.
instructions
,
"beta_mul"
))
gemm
.
beta
*=
get_float
(
mr
.
instructions
[
"beta"
]);
else
if
(
contains
(
mr
.
instructions
,
"gamma_mul"
))
{
gemm
.
alpha
*=
get_float
(
mr
.
instructions
[
"gamma"
]);
gemm
.
beta
*=
get_float
(
mr
.
instructions
[
"gamma"
]);
}
return
true
;
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
{
auto
ins
=
r
.
result
;
auto
ins
=
r
.
result
;
...
@@ -961,6 +1021,11 @@ struct find_gemm_pointwise
...
@@ -961,6 +1021,11 @@ struct find_gemm_pointwise
// Already fused gemm
// Already fused gemm
if
(
not
float_equal
(
gemm
.
beta
,
0
))
if
(
not
float_equal
(
gemm
.
beta
,
0
))
return
;
return
;
gemm
.
beta
=
1
;
if
(
not
update_gemm
(
gemm
,
ins
->
module_inputs
().
front
(),
ins
->
inputs
().
front
()
==
gemm_ins
?
0
:
1
))
return
;
auto
inputs
=
gemm_ins
->
inputs
();
auto
inputs
=
gemm_ins
->
inputs
();
inputs
.
pop_back
();
inputs
.
pop_back
();
...
@@ -968,7 +1033,6 @@ struct find_gemm_pointwise
...
@@ -968,7 +1033,6 @@ struct find_gemm_pointwise
inputs
.
push_back
(
c_ins
);
inputs
.
push_back
(
c_ins
);
inputs
.
push_back
(
ins
->
inputs
().
back
());
inputs
.
push_back
(
ins
->
inputs
().
back
());
gemm
.
beta
=
1
;
m
.
replace_instruction
(
ins
,
gemm
,
inputs
);
m
.
replace_instruction
(
ins
,
gemm
,
inputs
);
}
}
};
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment