Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
ecfb0b72
Commit
ecfb0b72
authored
Jun 29, 2022
by
Paul
Browse files
Merge branch 'dot-add' into bert-opt
parents
98229c34
d6523b49
Changes
31
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
313 additions
and
99 deletions
+313
-99
src/include/migraphx/stringutils.hpp
src/include/migraphx/stringutils.hpp
+8
-2
src/module.cpp
src/module.cpp
+52
-51
src/pass_manager.cpp
src/pass_manager.cpp
+25
-9
src/program.cpp
src/program.cpp
+26
-2
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+16
-26
src/targets/gpu/fuse_ops.cpp
src/targets/gpu/fuse_ops.cpp
+1
-0
src/targets/gpu/gemm_impl.cpp
src/targets/gpu/gemm_impl.cpp
+11
-8
src/targets/gpu/include/migraphx/gpu/gemm.hpp
src/targets/gpu/include/migraphx/gpu/gemm.hpp
+2
-1
test/ref_ops_test.cpp
test/ref_ops_test.cpp
+74
-0
test/verify/gemm_add_broadcast1.cpp
test/verify/gemm_add_broadcast1.cpp
+49
-0
test/verify/gemm_add_broadcast2.cpp
test/verify/gemm_add_broadcast2.cpp
+49
-0
No files found.
src/include/migraphx/stringutils.hpp
View file @
ecfb0b72
...
@@ -44,8 +44,8 @@ auto with_char(F f)
...
@@ -44,8 +44,8 @@ auto with_char(F f)
return
[
=
](
unsigned
char
c
)
->
bool
{
return
f
(
c
);
};
return
[
=
](
unsigned
char
c
)
->
bool
{
return
f
(
c
);
};
}
}
inline
std
::
string
inline
void
replace_string
(
std
::
string
subject
,
const
std
::
string
&
search
,
const
std
::
string
&
replace
)
replace_string
_inplace
(
std
::
string
&
subject
,
const
std
::
string
&
search
,
const
std
::
string
&
replace
)
{
{
size_t
pos
=
0
;
size_t
pos
=
0
;
while
((
pos
=
subject
.
find
(
search
,
pos
))
!=
std
::
string
::
npos
)
while
((
pos
=
subject
.
find
(
search
,
pos
))
!=
std
::
string
::
npos
)
...
@@ -53,6 +53,12 @@ replace_string(std::string subject, const std::string& search, const std::string
...
@@ -53,6 +53,12 @@ replace_string(std::string subject, const std::string& search, const std::string
subject
.
replace
(
pos
,
search
.
length
(),
replace
);
subject
.
replace
(
pos
,
search
.
length
(),
replace
);
pos
+=
replace
.
length
();
pos
+=
replace
.
length
();
}
}
}
inline
std
::
string
replace_string
(
std
::
string
subject
,
const
std
::
string
&
search
,
const
std
::
string
&
replace
)
{
replace_string_inplace
(
subject
,
search
,
replace
);
return
subject
;
return
subject
;
}
}
...
...
src/module.cpp
View file @
ecfb0b72
...
@@ -35,6 +35,7 @@
...
@@ -35,6 +35,7 @@
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/json.hpp>
#include <iostream>
#include <iostream>
#include <sstream>
#include <sstream>
#include <algorithm>
#include <algorithm>
...
@@ -706,44 +707,33 @@ void module::print_graph(std::ostream& os, bool brief) const
...
@@ -706,44 +707,33 @@ void module::print_graph(std::ostream& os, bool brief) const
os
<<
"}"
<<
std
::
endl
;
os
<<
"}"
<<
std
::
endl
;
}
}
static
std
::
string
cpp_var_name
(
const
std
::
string
&
name
)
static
std
::
string
to_c_id
(
const
std
::
string
&
name
,
char
rep
=
'_'
)
{
{
return
"m"
+
replace_string
(
name
,
"@"
,
"x"
);
std
::
string
id
=
transform_string
(
name
,
[
&
](
auto
c
)
{
if
(
with_char
(
::
isalnum
)(
c
)
or
c
==
'_'
)
return
c
;
return
rep
;
});
while
(
contains
(
id
,
"__"
))
replace_string_inplace
(
id
,
"__"
,
"_"
);
return
id
;
}
}
static
std
::
string
cpp_
op_
var
(
const
std
::
string
&
name
,
instruction_ref
ins
)
static
std
::
string
cpp_var
_name
(
const
std
::
string
&
name
)
{
{
return
replace_string
(
name
,
"
@
"
,
ins
->
name
(
));
return
to_c_id
(
"x_"
+
replace_string
(
name
,
"
:
"
,
"_module_"
));
}
}
static
void
print_
op_attributes
(
std
::
ostream
&
os
,
const
std
::
string
&
name
,
const
operation
&
op
)
static
void
print_
make_op
(
std
::
ostream
&
os
,
const
operation
&
op
)
{
{
std
::
string
x
=
to_string
(
op
);
os
<<
"migraphx::make_op("
<<
enclose_name
(
op
.
name
());
if
(
contains
(
x
,
"["
))
auto
v
=
op
.
to_value
();
if
(
not
v
.
empty
())
{
{
auto
start
=
x
.
find
(
'['
);
os
<<
", "
auto
end
=
x
.
find
(
']'
);
<<
"migraphx::from_json_string("
<<
enclose_name
(
to_json_string
(
v
))
<<
")"
;
std
::
string
attribute_text
=
x
.
substr
(
start
+
1
,
end
-
start
-
1
);
std
::
vector
<
std
::
string
>
attributes
;
for
(
auto
&&
attribute
:
split_string
(
attribute_text
,
','
))
{
if
(
contains
(
attribute
,
'='
))
attributes
.
push_back
(
attribute
);
else
attributes
.
back
()
+=
","
+
attribute
;
}
for
(
auto
&&
attribute
:
attributes
)
{
auto
p
=
split_string
(
attribute
,
'='
);
auto
key
=
p
.
front
();
auto
value
=
p
.
back
();
if
(
contains
({
"bn_mode"
,
"padding_mode"
},
key
))
continue
;
if
(
key
==
"mode"
)
value
=
enclose_name
(
trim
(
value
));
os
<<
name
<<
"."
<<
key
<<
" = "
<<
value
<<
";"
<<
std
::
endl
;
}
}
}
os
<<
")"
;
}
}
static
void
print_cpp_shape
(
std
::
ostream
&
os
,
const
migraphx
::
shape
&
s
)
static
void
print_cpp_shape
(
std
::
ostream
&
os
,
const
migraphx
::
shape
&
s
)
...
@@ -756,22 +746,25 @@ static void print_cpp_shape(std::ostream& os, const migraphx::shape& s)
...
@@ -756,22 +746,25 @@ static void print_cpp_shape(std::ostream& os, const migraphx::shape& s)
}
}
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
module
::
print_cpp
(
std
::
ostream
&
os
,
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
names
)
const
module
::
print_cpp
(
std
::
ostream
&
os
,
const
std
::
string
&
mname
,
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
names
)
const
{
{
os
<<
"migraphx::module p;"
<<
std
::
endl
;
// cppcheck-suppress variableScope
unsigned
long
seed
=
0
;
unsigned
long
seed
=
names
.
size
();
auto
last
=
std
::
prev
(
this
->
end
());
names
=
this
->
print
(
names
=
this
->
print
(
[
&
](
auto
ins
,
auto
ins_names
)
{
[
&
](
auto
ins
,
auto
ins_names
)
{
auto
op
=
cpp_op_var
(
ins_names
.
at
(
ins
),
ins
)
;
std
::
vector
<
std
::
string
>
input_vars
;
if
(
ins
->
name
().
front
()
!=
'@'
)
std
::
transform
(
ins
->
inputs
().
begin
(),
{
ins
->
inputs
().
end
(),
os
<<
"migraphx::op::"
<<
ins
->
name
()
<<
" "
<<
op
<<
";"
<<
std
::
endl
;
std
::
back_inserter
(
input_vars
),
print_op_attributes
(
os
,
op
,
ins
->
get_operator
()
);
[
&
](
auto
input
)
{
return
cpp_var_name
(
ins_names
.
at
(
input
));
}
);
}
if
(
ins
!=
last
)
os
<<
"auto "
<<
cpp_var_name
(
ins_names
.
at
(
ins
))
<<
" = "
;
os
<<
"auto "
<<
cpp_var_name
(
ins_names
.
at
(
ins
))
<<
" = "
;
if
(
ins
->
name
()
==
"@literal"
)
if
(
ins
->
name
()
==
"@literal"
)
{
{
os
<<
"p.
add_literal("
;
os
<<
mname
<<
"->
add_literal("
;
bool
use_abs
=
false
;
bool
use_abs
=
false
;
ins
->
get_literal
().
visit
([
&
](
auto
v
)
{
ins
->
get_literal
().
visit
([
&
](
auto
v
)
{
use_abs
=
std
::
none_of
(
v
.
begin
(),
v
.
end
(),
[](
auto
x
)
{
return
x
<
0
;
});
use_abs
=
std
::
none_of
(
v
.
begin
(),
v
.
end
(),
[](
auto
x
)
{
return
x
<
0
;
});
...
@@ -789,17 +782,22 @@ module::print_cpp(std::ostream& os, std::unordered_map<instruction_ref, std::str
...
@@ -789,17 +782,22 @@ module::print_cpp(std::ostream& os, std::unordered_map<instruction_ref, std::str
else
if
(
ins
->
name
()
==
"@param"
)
else
if
(
ins
->
name
()
==
"@param"
)
{
{
std
::
string
name
=
any_cast
<
builtin
::
param
>
(
ins
->
get_operator
()).
parameter
;
std
::
string
name
=
any_cast
<
builtin
::
param
>
(
ins
->
get_operator
()).
parameter
;
os
<<
"p.
add_parameter("
<<
enclose_name
(
name
)
<<
","
;
os
<<
mname
<<
"->
add_parameter("
<<
enclose_name
(
name
)
<<
","
;
print_cpp_shape
(
os
,
ins
->
get_shape
());
print_cpp_shape
(
os
,
ins
->
get_shape
());
os
<<
");"
<<
std
::
endl
;
os
<<
");"
<<
std
::
endl
;
}
}
else
if
(
ins
->
name
()
==
"@return"
)
{
os
<<
mname
<<
"->add_return({"
;
os
<<
join_strings
(
input_vars
,
", "
);
os
<<
"});"
<<
std
::
endl
;
}
else
else
{
{
os
<<
"p.add_instruction("
<<
op
;
assert
(
ins
->
name
().
front
()
!=
'@'
);
for
(
auto
input
:
ins
->
inputs
())
os
<<
mname
<<
"->add_instruction("
;
{
print_make_op
(
os
,
ins
->
get_operator
());
os
<<
", "
<<
cpp_var_name
(
ins_names
.
at
(
input
));
os
<<
", "
<<
join_strings
(
input_vars
,
", "
);
}
os
<<
");"
<<
std
::
endl
;
os
<<
");"
<<
std
::
endl
;
}
}
},
},
...
@@ -808,7 +806,7 @@ module::print_cpp(std::ostream& os, std::unordered_map<instruction_ref, std::str
...
@@ -808,7 +806,7 @@ module::print_cpp(std::ostream& os, std::unordered_map<instruction_ref, std::str
return
names
;
return
names
;
}
}
void
module
::
print_cpp
(
std
::
ostream
&
os
)
const
{
this
->
print_cpp
(
os
,
{});
}
void
module
::
print_cpp
(
std
::
ostream
&
os
)
const
{
this
->
print_cpp
(
os
,
this
->
name
(),
{});
}
void
module
::
annotate
(
std
::
ostream
&
os
,
std
::
function
<
void
(
instruction_ref
)
>
a
)
const
void
module
::
annotate
(
std
::
ostream
&
os
,
std
::
function
<
void
(
instruction_ref
)
>
a
)
const
{
{
...
@@ -819,17 +817,20 @@ void module::annotate(std::ostream& os, std::function<void(instruction_ref)> a)
...
@@ -819,17 +817,20 @@ void module::annotate(std::ostream& os, std::function<void(instruction_ref)> a)
});
});
}
}
std
::
vector
<
module_ref
>
module
::
get_sub_modules
()
const
std
::
vector
<
module_ref
>
module
::
get_sub_modules
(
bool
shallow
)
const
{
{
std
::
vector
<
module_ref
>
vec_modules
;
std
::
vector
<
module_ref
>
vec_modules
;
for
(
auto
ins
:
iterator_for
(
*
this
))
for
(
auto
ins
:
iterator_for
(
*
this
))
{
{
const
auto
&
mod_args
=
ins
->
module_inputs
();
const
auto
&
mod_args
=
ins
->
module_inputs
();
vec_modules
.
insert
(
vec_modules
.
end
(),
mod_args
.
begin
(),
mod_args
.
end
());
vec_modules
.
insert
(
vec_modules
.
end
(),
mod_args
.
begin
(),
mod_args
.
end
());
for
(
const
auto
&
smod
:
mod_args
)
if
(
not
shallow
)
{
{
auto
sub_mods
=
smod
->
get_sub_modules
();
for
(
const
auto
&
smod
:
mod_args
)
vec_modules
.
insert
(
vec_modules
.
end
(),
sub_mods
.
begin
(),
sub_mods
.
end
());
{
auto
sub_mods
=
smod
->
get_sub_modules
();
vec_modules
.
insert
(
vec_modules
.
end
(),
sub_mods
.
begin
(),
sub_mods
.
end
());
}
}
}
}
}
...
...
src/pass_manager.cpp
View file @
ecfb0b72
...
@@ -66,14 +66,12 @@ void run_pass(program& prog, const pass& p, tracer trace)
...
@@ -66,14 +66,12 @@ void run_pass(program& prog, const pass& p, tracer trace)
struct
module_pm
:
module_pass_manager
struct
module_pm
:
module_pass_manager
{
{
module
*
mod
;
module
*
mod
=
nullptr
;
program
*
prog
;
tracer
*
t
=
nullptr
;
tracer
*
t
;
module
*
common_parent
=
nullptr
;
program
*
prog
=
nullptr
;
module_pm
(
module
*
pmod
=
nullptr
,
program
*
pprog
=
nullptr
,
tracer
*
pt
=
nullptr
)
module_pm
(
module
*
pmod
=
nullptr
,
tracer
*
pt
=
nullptr
)
:
mod
(
pmod
),
t
(
pt
)
{}
:
mod
(
pmod
),
prog
(
pprog
),
t
(
pt
)
{
}
template
<
class
...
Ts
>
template
<
class
...
Ts
>
void
trace
(
Ts
&&
...
xs
)
const
void
trace
(
Ts
&&
...
xs
)
const
...
@@ -92,6 +90,7 @@ struct module_pm : module_pass_manager
...
@@ -92,6 +90,7 @@ struct module_pm : module_pass_manager
assert
(
prog
);
assert
(
prog
);
return
prog
->
create_module
(
name
);
return
prog
->
create_module
(
name
);
}
}
virtual
module
*
get_common_parent
()
override
{
return
common_parent
;
}
virtual
void
run_pass
(
const
pass
&
p
)
override
virtual
void
run_pass
(
const
pass
&
p
)
override
{
{
assert
(
mod
);
assert
(
mod
);
...
@@ -111,7 +110,7 @@ void run_passes(module& mod, const std::vector<pass>& passes, tracer trace)
...
@@ -111,7 +110,7 @@ void run_passes(module& mod, const std::vector<pass>& passes, tracer trace)
trace
=
tracer
{
std
::
cout
};
trace
=
tracer
{
std
::
cout
};
for
(
const
auto
&
p
:
passes
)
for
(
const
auto
&
p
:
passes
)
{
{
module_pm
{
&
mod
,
nullptr
,
&
trace
}.
run_pass
(
p
);
module_pm
{
&
mod
,
&
trace
}.
run_pass
(
p
);
}
}
}
}
...
@@ -119,14 +118,31 @@ void run_passes(program& prog, const std::vector<pass>& passes, tracer trace)
...
@@ -119,14 +118,31 @@ void run_passes(program& prog, const std::vector<pass>& passes, tracer trace)
{
{
if
(
enabled
(
MIGRAPHX_TRACE_PASSES
{}))
if
(
enabled
(
MIGRAPHX_TRACE_PASSES
{}))
trace
=
tracer
{
std
::
cout
};
trace
=
tracer
{
std
::
cout
};
std
::
unordered_set
<
module_ref
>
visited
;
for
(
const
auto
&
p
:
passes
)
for
(
const
auto
&
p
:
passes
)
{
{
auto
mods
=
prog
.
get_modules
();
auto
mods
=
prog
.
get_modules
();
auto
tree
=
prog
.
get_module_tree
();
visited
.
clear
();
for
(
const
auto
&
mod
:
reverse
(
mods
))
for
(
const
auto
&
mod
:
reverse
(
mods
))
{
{
if
(
mod
->
bypass
())
if
(
mod
->
bypass
())
continue
;
continue
;
module_pm
{
mod
,
&
prog
,
&
trace
}.
run_pass
(
p
);
if
(
not
visited
.
insert
(
mod
).
second
)
continue
;
module_pm
mpm
{
mod
,
&
trace
};
mpm
.
prog
=
&
prog
;
auto
parents
=
range
(
tree
.
equal_range
(
mod
));
auto
nparents
=
distance
(
parents
);
if
(
nparents
==
0
)
mpm
.
common_parent
=
nullptr
;
else
if
(
nparents
==
1
)
mpm
.
common_parent
=
parents
.
begin
()
->
second
;
else
// Just set common parent to main module when there is muliple parents for now
// TODO: Compute the common parent
mpm
.
common_parent
=
prog
.
get_main_module
();
mpm
.
run_pass
(
p
);
}
}
run_pass
(
prog
,
p
,
trace
);
run_pass
(
prog
,
p
,
trace
);
}
}
...
...
src/program.cpp
View file @
ecfb0b72
...
@@ -790,10 +790,17 @@ void program::print_cpp(std::ostream& os) const
...
@@ -790,10 +790,17 @@ void program::print_cpp(std::ostream& os) const
{
{
auto
vec_modules
=
this
->
get_modules
();
auto
vec_modules
=
this
->
get_modules
();
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
names
;
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
names
;
os
<<
"migraphx::program p;
\n
"
;
for
(
auto
&
mod
:
vec_modules
)
for
(
auto
&
mod
:
vec_modules
)
{
{
os
<<
"module:
\"
"
<<
mod
->
name
()
<<
"
\"
"
<<
std
::
endl
;
std
::
string
var_name
=
"m"
+
mod
->
name
();
names
=
mod
->
print_cpp
(
os
,
names
);
os
<<
"migraphx::module_ref "
<<
var_name
<<
" = "
;
if
(
mod
->
name
()
==
"main"
)
os
<<
"p.get_main_module();"
;
else
os
<<
"p.create_module(
\"
"
<<
mod
->
name
()
<<
"
\"
);"
;
os
<<
std
::
endl
;
names
=
mod
->
print_cpp
(
os
,
var_name
,
names
);
os
<<
std
::
endl
;
os
<<
std
::
endl
;
}
}
}
}
...
@@ -869,6 +876,23 @@ std::vector<module*> program::get_modules()
...
@@ -869,6 +876,23 @@ std::vector<module*> program::get_modules()
return
result
;
return
result
;
}
}
template
<
class
Module
,
class
Map
>
void
generic_insert_module_tree
(
Module
*
pm
,
Map
&
m
)
{
for
(
auto
*
sm
:
pm
->
get_sub_modules
(
true
))
{
m
.
insert
(
std
::
make_pair
(
sm
,
pm
));
generic_insert_module_tree
(
sm
,
m
);
}
}
std
::
unordered_multimap
<
module_ref
,
module_ref
>
program
::
get_module_tree
()
{
std
::
unordered_multimap
<
module_ref
,
module_ref
>
result
;
generic_insert_module_tree
(
this
->
get_main_module
(),
result
);
return
result
;
}
template
<
class
Map
,
class
T
>
template
<
class
Map
,
class
T
>
bool
is_unused_module
(
Map
&
m
,
const
std
::
vector
<
T
*>&
mods
,
const
std
::
string
&
name
)
bool
is_unused_module
(
Map
&
m
,
const
std
::
vector
<
T
*>&
mods
,
const
std
::
string
&
name
)
{
{
...
...
src/simplify_algebra.cpp
View file @
ecfb0b72
...
@@ -303,37 +303,26 @@ struct find_double_add_lit_broadcast
...
@@ -303,37 +303,26 @@ struct find_double_add_lit_broadcast
struct
find_inner_broadcast
struct
find_inner_broadcast
{
{
auto
matcher
()
const
auto
matcher
()
const
{
return
pointwise
(
match
::
all_of
[
match
::
inputs
()](
match
::
broadcast
()));
}
{
return
pointwise
(
match
::
all_of
[
match
::
inputs
()](
match
::
broadcast_shape
(),
match
::
name
(
"broadcast"
,
"multibroadcast"
)));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
{
auto
ins
=
r
.
result
;
auto
ins
=
r
.
result
;
auto
inpu
ts
=
ins
->
inputs
();
auto
broadcas
ts
=
ins
->
inputs
();
if
(
inpu
ts
.
empty
())
if
(
broadcas
ts
.
empty
())
return
;
return
;
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
inputs
.
begin
(),
[
&
](
auto
i
)
{
std
::
vector
<
instruction_ref
>
inputs
;
if
(
contains
({
"broadcast"
,
"multibroadcast"
},
i
->
name
()))
std
::
transform
(
broadcasts
.
begin
(),
return
i
->
inputs
().
front
();
broadcasts
.
end
(),
else
std
::
back_inserter
(
inputs
),
return
i
;
[](
auto
i
)
{
return
i
->
inputs
().
front
();
});
});
if
(
std
::
any_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
auto
i
)
{
return
i
->
get_shape
()
!=
inputs
.
front
()
->
get_shape
();
if
(
not
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
auto
&
x
)
{
return
x
->
get_shape
()
==
inputs
.
front
()
->
get_shape
();
}))
}))
return
;
return
;
auto
op
=
m
.
insert_instruction
(
ins
,
ins
->
get_operator
(),
inputs
);
auto
op
=
m
.
insert_instruction
(
ins
,
ins
->
get_operator
(),
inputs
);
auto
bop
=
std
::
find_if
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
[
&
](
auto
i
)
{
m
.
replace_instruction
(
ins
,
broadcasts
.
front
()
->
get_operator
(),
op
);
return
contains
({
"broadcast"
,
"multibroadcast"
},
i
->
name
());
});
assert
(
bop
!=
ins
->
inputs
().
end
());
m
.
replace_instruction
(
ins
,
(
*
bop
)
->
get_operator
(),
op
);
}
}
};
};
...
@@ -461,8 +450,9 @@ struct find_splits
...
@@ -461,8 +450,9 @@ struct find_splits
{
{
auto
matcher
()
const
auto
matcher
()
const
{
{
return
match
::
any
(
match
::
any_of
[
match
::
outputs
()](
match
::
name
(
"slice"
)(
return
match
::
any
(
match
::
any_of
[
match
::
outputs
()](
match
::
pointwise
(),
reduction
()))));
match
::
any_of
[
match
::
outputs
()](
match
::
name
(
"slice"
)(
match
::
any_of
[
match
::
outputs
()](
match
::
pointwise
(
match
::
any_of
(
match
::
nargs
(
1
),
match
::
nargs
(
2
))),
reduction
()))));
}
}
static
bool
is_dependent
(
const
module
&
m
,
instruction_ref
ins1
,
instruction_ref
ins2
)
static
bool
is_dependent
(
const
module
&
m
,
instruction_ref
ins1
,
instruction_ref
ins2
)
...
...
src/targets/gpu/fuse_ops.cpp
View file @
ecfb0b72
...
@@ -706,6 +706,7 @@ struct miopen_fusion
...
@@ -706,6 +706,7 @@ struct miopen_fusion
return
args
.
back
();
return
args
.
back
();
}
}
};
};
MIGRAPHX_REGISTER_OP
(
miopen_fusion
)
struct
miopen_conv_bias
struct
miopen_conv_bias
{
{
...
...
src/targets/gpu/gemm_impl.cpp
View file @
ecfb0b72
...
@@ -97,6 +97,12 @@ void gemm_impl(context& ctx,
...
@@ -97,6 +97,12 @@ void gemm_impl(context& ctx,
bool
int8_x4_format
,
bool
int8_x4_format
,
bool
compute_fp32
)
bool
compute_fp32
)
{
{
const
bool
is_3inputs
=
(
args
.
size
()
==
4
);
if
(
!
is_3inputs
)
{
beta
=
0
;
}
bool
transa
=
is_transposed
(
args
[
0
].
get_shape
());
bool
transa
=
is_transposed
(
args
[
0
].
get_shape
());
bool
transb
=
is_transposed
(
args
[
1
].
get_shape
());
bool
transb
=
is_transposed
(
args
[
1
].
get_shape
());
auto
n_dim
=
output_shape
.
lens
().
size
();
auto
n_dim
=
output_shape
.
lens
().
size
();
...
@@ -105,12 +111,8 @@ void gemm_impl(context& ctx,
...
@@ -105,12 +111,8 @@ void gemm_impl(context& ctx,
rocblas_int
lda
=
args
[
0
].
get_shape
().
strides
()[
transa
?
dim_1
:
dim_0
];
rocblas_int
lda
=
args
[
0
].
get_shape
().
strides
()[
transa
?
dim_1
:
dim_0
];
rocblas_int
ldb
=
args
[
1
].
get_shape
().
strides
()[
transb
?
dim_1
:
dim_0
];
rocblas_int
ldb
=
args
[
1
].
get_shape
().
strides
()[
transb
?
dim_1
:
dim_0
];
rocblas_int
ldc
=
args
[
2
].
get_shape
().
strides
()[
dim_0
];
rocblas_int
ldc
=
args
[
2
].
get_shape
().
strides
()[
dim_0
];
rocblas_int
ldd
=
is_3inputs
?
args
[
3
].
get_shape
().
strides
()[
dim_0
]
:
ldc
;
bool
is_3inputs
=
(
args
.
size
()
==
4
);
if
(
!
is_3inputs
)
{
beta
=
0
;
}
rocblas_datatype
arg_type
=
get_type
(
args
[
0
].
get_shape
().
type
());
rocblas_datatype
arg_type
=
get_type
(
args
[
0
].
get_shape
().
type
());
auto
output_type
=
arg_type
;
auto
output_type
=
arg_type
;
if
(
output_type
==
rocblas_datatype_i8_r
)
if
(
output_type
==
rocblas_datatype_i8_r
)
...
@@ -186,7 +188,7 @@ void gemm_impl(context& ctx,
...
@@ -186,7 +188,7 @@ void gemm_impl(context& ctx,
ldc
,
ldc
,
is_3inputs
?
to_pointer
(
args
[
3
])
:
to_pointer
(
args
[
2
]),
is_3inputs
?
to_pointer
(
args
[
3
])
:
to_pointer
(
args
[
2
]),
output_type
,
output_type
,
ld
c
,
ld
d
,
compute_type
,
compute_type
,
rocblas_gemm_algo_standard
,
rocblas_gemm_algo_standard
,
0
,
0
,
...
@@ -197,6 +199,7 @@ void gemm_impl(context& ctx,
...
@@ -197,6 +199,7 @@ void gemm_impl(context& ctx,
auto
a_stride
=
get_batch_stride
(
args
[
0
]);
auto
a_stride
=
get_batch_stride
(
args
[
0
]);
auto
b_stride
=
get_batch_stride
(
args
[
1
]);
auto
b_stride
=
get_batch_stride
(
args
[
1
]);
auto
c_stride
=
get_batch_stride
(
args
[
2
]);
auto
c_stride
=
get_batch_stride
(
args
[
2
]);
auto
d_stride
=
is_3inputs
?
get_batch_stride
(
args
[
3
])
:
c_stride
;
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex
,
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex
,
ctx
.
get_stream
().
get_rocblas
(),
ctx
.
get_stream
().
get_rocblas
(),
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
...
@@ -220,8 +223,8 @@ void gemm_impl(context& ctx,
...
@@ -220,8 +223,8 @@ void gemm_impl(context& ctx,
c_stride
,
c_stride
,
is_3inputs
?
to_pointer
(
args
[
3
])
:
to_pointer
(
args
[
2
]),
is_3inputs
?
to_pointer
(
args
[
3
])
:
to_pointer
(
args
[
2
]),
output_type
,
output_type
,
ld
c
,
ld
d
,
c
_stride
,
d
_stride
,
num_matrices
,
num_matrices
,
compute_type
,
compute_type
,
rocblas_gemm_algo_standard
,
rocblas_gemm_algo_standard
,
...
...
src/targets/gpu/include/migraphx/gpu/gemm.hpp
View file @
ecfb0b72
...
@@ -74,13 +74,14 @@ struct rocblas_gemm
...
@@ -74,13 +74,14 @@ struct rocblas_gemm
{
{
std
::
vector
<
shape
>
in_shapes
(
inputs
);
std
::
vector
<
shape
>
in_shapes
(
inputs
);
in_shapes
.
pop_back
();
in_shapes
.
pop_back
();
check_shapes
{
in_shapes
,
*
this
}
.
not_broadcasted
()
;
check_shapes
{
in_shapes
,
*
this
};
blas_shape
(
inputs
[
0
]);
blas_shape
(
inputs
[
0
]);
blas_shape
(
inputs
[
1
]);
blas_shape
(
inputs
[
1
]);
// if gemm and add are fused
// if gemm and add are fused
if
(
in_shapes
.
size
()
>
2
)
if
(
in_shapes
.
size
()
>
2
)
{
{
auto
cmat_shape
=
in_shapes
.
back
();
auto
cmat_shape
=
in_shapes
.
back
();
check_shapes
{{
cmat_shape
},
*
this
}.
not_transposed
().
not_broadcasted
();
in_shapes
.
pop_back
();
in_shapes
.
pop_back
();
blas_shape
(
cmat_shape
);
blas_shape
(
cmat_shape
);
auto
op_out_shape
=
op
.
compute_shape
(
in_shapes
);
auto
op_out_shape
=
op
.
compute_shape
(
in_shapes
);
...
...
test/ref_ops_test.cpp
View file @
ecfb0b72
...
@@ -3187,6 +3187,80 @@ TEST_CASE(nms_test)
...
@@ -3187,6 +3187,80 @@ TEST_CASE(nms_test)
EXPECT
(
migraphx
::
verify_range
(
result
,
gold
));
EXPECT
(
migraphx
::
verify_range
(
result
,
gold
));
}
}
TEST_CASE
(
nms_transpose1_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
boxes_s
{
migraphx
::
shape
::
float_type
,
{
1
,
4
,
6
}};
std
::
vector
<
float
>
boxes_vec
=
{
0.5
,
0.5
,
0.5
,
0.5
,
0.5
,
0.5
,
0.5
,
0.6
,
0.4
,
10.5
,
10.6
,
100.5
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
};
migraphx
::
shape
scores_s
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
6
}};
std
::
vector
<
float
>
scores_vec
=
{
0.9
,
0.75
,
0.6
,
0.95
,
0.5
,
0.3
};
auto
t_boxes_l
=
mm
->
add_literal
(
migraphx
::
literal
(
boxes_s
,
boxes_vec
));
auto
scores_l
=
mm
->
add_literal
(
migraphx
::
literal
(
scores_s
,
scores_vec
));
auto
max_out_l
=
mm
->
add_literal
(
int64_t
{
4
});
auto
iou_threshold
=
mm
->
add_literal
(
0.5
f
);
auto
score_threshold
=
mm
->
add_literal
(
0.0
f
);
auto
transpose_boxes
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
1
}}}),
t_boxes_l
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"nonmaxsuppression"
,
{{
"center_point_box"
,
1
}}),
transpose_boxes
,
scores_l
,
max_out_l
,
iou_threshold
,
score_threshold
);
mm
->
add_return
({
r
});
p
.
compile
(
migraphx
::
ref
::
target
{});
auto
output
=
p
.
eval
({}).
back
();
std
::
vector
<
int64_t
>
result
;
output
.
visit
([
&
](
auto
out
)
{
result
.
assign
(
out
.
begin
(),
out
.
end
());
});
std
::
vector
<
int64_t
>
gold
=
{
0
,
0
,
3
,
0
,
0
,
0
,
0
,
0
,
5
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
};
EXPECT
(
migraphx
::
verify_range
(
result
,
gold
));
}
TEST_CASE
(
nms_transpose2_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
boxes_s
{
migraphx
::
shape
::
float_type
,
{
4
,
1
,
6
}};
std
::
vector
<
float
>
boxes_vec
=
{
0.5
,
0.5
,
0.5
,
0.5
,
0.5
,
0.5
,
0.5
,
0.6
,
0.4
,
10.5
,
10.6
,
100.5
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
};
migraphx
::
shape
scores_s
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
6
}};
std
::
vector
<
float
>
scores_vec
=
{
0.9
,
0.75
,
0.6
,
0.95
,
0.5
,
0.3
};
auto
t_boxes_l
=
mm
->
add_literal
(
migraphx
::
literal
(
boxes_s
,
boxes_vec
));
auto
scores_l
=
mm
->
add_literal
(
migraphx
::
literal
(
scores_s
,
scores_vec
));
auto
max_out_l
=
mm
->
add_literal
(
int64_t
{
4
});
auto
iou_threshold
=
mm
->
add_literal
(
0.5
f
);
auto
score_threshold
=
mm
->
add_literal
(
0.0
f
);
auto
transpose_boxes
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
2
,
0
}}}),
t_boxes_l
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"nonmaxsuppression"
,
{{
"center_point_box"
,
1
}}),
transpose_boxes
,
scores_l
,
max_out_l
,
iou_threshold
,
score_threshold
);
mm
->
add_return
({
r
});
p
.
compile
(
migraphx
::
ref
::
target
{});
auto
output
=
p
.
eval
({}).
back
();
std
::
vector
<
int64_t
>
result
;
output
.
visit
([
&
](
auto
out
)
{
result
.
assign
(
out
.
begin
(),
out
.
end
());
});
std
::
vector
<
int64_t
>
gold
=
{
0
,
0
,
3
,
0
,
0
,
0
,
0
,
0
,
5
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
};
EXPECT
(
migraphx
::
verify_range
(
result
,
gold
));
}
TEST_CASE
(
nonzero_test
)
TEST_CASE
(
nonzero_test
)
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
...
...
test/verify/gemm_add_broadcast1.cpp
0 → 100644
View file @
ecfb0b72
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/apply_alpha_beta.hpp>
struct
gemm_add_broadcast1
:
verify_program
<
gemm_add_broadcast1
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
3
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_type
,
{
1
,
3
,
4
}};
migraphx
::
shape
m3_shape
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
4
}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
auto
l3
=
mm
->
add_parameter
(
"3"
,
m3_shape
);
auto
l3_b
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
1
,
2
,
4
}}}),
l3
);
auto
dot
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
l1
,
l2
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
dot
,
l3_b
);
return
p
;
}
};
test/verify/gemm_add_broadcast2.cpp
0 → 100644
View file @
ecfb0b72
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/apply_alpha_beta.hpp>
struct
gemm_add_broadcast2
:
verify_program
<
gemm_add_broadcast2
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
3
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_type
,
{
1
,
3
,
4
}};
migraphx
::
shape
m3_shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
1
}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
auto
l3
=
mm
->
add_parameter
(
"3"
,
m3_shape
);
auto
l3_b
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
1
,
2
,
4
}}}),
l3
);
auto
dot
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
l1
,
l2
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
dot
,
l3_b
);
return
p
;
}
};
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment