Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
d673e0c4
Commit
d673e0c4
authored
Oct 12, 2018
by
Paul
Browse files
Fix fusion for triadd
parent
270194c4
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
73 additions
and
6 deletions
+73
-6
src/targets/gpu/fuse_ops.cpp
src/targets/gpu/fuse_ops.cpp
+69
-4
src/targets/gpu/include/migraph/gpu/device/add_relu.hpp
src/targets/gpu/include/migraph/gpu/device/add_relu.hpp
+2
-0
src/targets/gpu/target.cpp
src/targets/gpu/target.cpp
+2
-2
No files found.
src/targets/gpu/fuse_ops.cpp
View file @
d673e0c4
...
...
@@ -136,6 +136,36 @@ MIGRAPH_PRED_MATCHER(fusable_conv, instruction_ref ins)
op
.
dilation
==
make_array
<
size_t
>
(
1
,
1
);
}
struct
hip_triadd
{
std
::
string
name
()
const
{
return
"hip::triadd"
;
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
4
);
return
inputs
.
front
();
}
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
args
)
const
{
device
::
add
(
args
.
at
(
3
),
args
.
at
(
0
),
args
.
at
(
1
),
args
.
at
(
2
));
return
args
.
at
(
3
);
}
};
struct
hip_triadd_relu
{
std
::
string
name
()
const
{
return
"hip::triadd_relu"
;
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
4
);
return
inputs
.
front
();
}
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
args
)
const
{
device
::
add_relu
(
args
.
at
(
3
),
args
.
at
(
0
),
args
.
at
(
1
),
args
.
at
(
2
));
return
args
.
at
(
3
);
}
};
struct
hip_add_relu
{
std
::
string
name
()
const
{
return
"hip::add_relu"
;
}
...
...
@@ -155,7 +185,7 @@ struct match_add_relu
{
auto
matcher
()
const
{
return
match
::
name
(
"gpu::relu"
)(
match
::
arg
(
0
)(
match
::
name
(
"gpu::add"
).
bind
(
"add"
)));
return
match
::
name
(
"gpu::relu"
)(
match
::
arg
(
0
)(
match
::
any_of
(
match
::
name
(
"gpu::add"
)
,
match
::
name
(
"hip::triadd"
))
.
bind
(
"add"
)));
}
void
apply
(
program
&
p
,
match
::
matcher_result
r
)
const
...
...
@@ -165,7 +195,36 @@ struct match_add_relu
auto
args
=
add_ins
->
inputs
();
// Use the allocation from the relu operator
args
.
back
()
=
ins
->
inputs
().
back
();
p
.
replace_instruction
(
ins
,
hip_add_relu
{},
args
);
if
(
add_ins
->
name
()
==
"gpu::add"
)
p
.
replace_instruction
(
ins
,
hip_add_relu
{},
args
);
else
if
(
add_ins
->
name
()
==
"hip::triadd"
)
p
.
replace_instruction
(
ins
,
hip_triadd_relu
{},
args
);
}
};
struct
match_triadd
{
auto
matcher
()
const
{
return
match
::
name
(
"gpu::add"
)(
match
::
either_arg
(
0
,
1
)(
match
::
name
(
"gpu::add"
).
bind
(
"add"
),
match
::
any
().
bind
(
"input"
)));
}
void
apply
(
program
&
p
,
match
::
matcher_result
r
)
const
{
auto
add_ins
=
r
.
instructions
[
"add"
];
auto
input_ins
=
r
.
instructions
[
"input"
];
auto
ins
=
r
.
result
;
auto
args
=
add_ins
->
inputs
();
auto
is_broadcasted
=
[](
auto
arg
)
{
return
arg
->
get_shape
().
broadcasted
();
};
if
(
std
::
count_if
(
args
.
begin
(),
args
.
end
(),
is_broadcasted
)
>
1
)
return
;
args
.
insert
(
args
.
begin
(),
input_ins
);
// Ensure the last arguments is the broadcasted one
auto
it
=
std
::
find_if
(
args
.
begin
(),
args
.
end
(),
is_broadcasted
);
if
(
it
!=
args
.
end
())
std
::
swap
(
*
it
,
*
std
::
prev
(
args
.
end
(),
2
));
args
.
back
()
=
ins
->
inputs
().
back
();
p
.
replace_instruction
(
ins
,
hip_triadd
{},
args
);
}
};
...
...
@@ -305,8 +364,14 @@ struct match_conv_bias_relu
void
fuse_ops
::
apply
(
program
&
p
)
const
{
// match::find_matches(p, match_conv_bias_relu{ctx}, match_conv_bias{ctx}, match_add_relu{});
match
::
find_matches
(
p
,
match_conv_bias
{
ctx
},
match_add_relu
{});
// clang-format off
match
::
find_matches
(
p
,
match_triadd
{},
match_conv_bias_relu
{
ctx
},
match_conv_bias
{
ctx
}
match_add_relu
{}
);
// clang-format on
}
}
// namespace gpu
...
...
src/targets/gpu/include/migraph/gpu/device/add_relu.hpp
View file @
d673e0c4
...
...
@@ -10,6 +10,8 @@ namespace device {
void
add_relu
(
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg2
);
void
add_relu
(
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg2
,
const
argument
&
arg3
);
}
// namespace device
}
// namespace gpu
}
// namespace migraph
...
...
src/targets/gpu/target.cpp
View file @
d673e0c4
...
...
@@ -35,10 +35,10 @@ std::vector<pass> target::get_passes(migraph::context& gctx) const
simplify_reshapes
{},
dead_code_elimination
{},
lowering
{
ctx
},
fuse_ops
{
&
ctx
},
dead_code_elimination
{},
eliminate_contiguous
{},
dead_code_elimination
{},
fuse_ops
{
&
ctx
},
dead_code_elimination
{},
write_literals
{
&
ctx
},
memory_coloring
{
"hip::allocate"
},
eliminate_workspace
{},
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment