Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
c97080ce
Commit
c97080ce
authored
Dec 03, 2022
by
Paul
Browse files
Fuse transpose
parent
f17d6246
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
80 additions
and
19 deletions
+80
-19
src/include/migraphx/serialize.hpp
src/include/migraphx/serialize.hpp
+32
-15
src/include/migraphx/streamutils.hpp
src/include/migraphx/streamutils.hpp
+11
-1
src/targets/gpu/compile_ops.cpp
src/targets/gpu/compile_ops.cpp
+10
-3
src/targets/gpu/fuse_ops.cpp
src/targets/gpu/fuse_ops.cpp
+27
-0
No files found.
src/include/migraphx/serialize.hpp
View file @
c97080ce
...
...
@@ -28,6 +28,7 @@
#include <migraphx/value.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/requires.hpp>
#include <migraphx/optional.hpp>
#include <migraphx/rank.hpp>
#include <type_traits>
...
...
@@ -87,46 +88,55 @@ value to_value_impl(rank<3>, const T& x)
return
result
;
}
template
<
class
T
>
auto
to_value_impl
(
rank
<
4
>
,
const
optional
<
T
>&
x
)
{
value
result
{};
if
(
x
.
has_value
())
to_value
(
*
x
);
return
result
;
}
template
<
class
T
,
MIGRAPHX_REQUIRES
(
std
::
is_signed
<
T
>{})
>
value
to_value_impl
(
rank
<
4
>
,
const
T
&
x
)
value
to_value_impl
(
rank
<
5
>
,
const
T
&
x
)
{
return
std
::
int64_t
{
x
};
}
template
<
class
T
,
MIGRAPHX_REQUIRES
(
std
::
is_unsigned
<
T
>{})
>
value
to_value_impl
(
rank
<
5
>
,
const
T
&
x
)
value
to_value_impl
(
rank
<
6
>
,
const
T
&
x
)
{
return
std
::
uint64_t
{
x
};
}
template
<
class
T
,
MIGRAPHX_REQUIRES
(
std
::
is_floating_point
<
T
>{})
>
value
to_value_impl
(
rank
<
6
>
,
const
T
&
x
)
value
to_value_impl
(
rank
<
7
>
,
const
T
&
x
)
{
return
double
{
x
};
}
template
<
class
T
,
MIGRAPHX_REQUIRES
(
std
::
is_enum
<
T
>{})
>
value
to_value_impl
(
rank
<
7
>
,
const
T
&
x
)
value
to_value_impl
(
rank
<
8
>
,
const
T
&
x
)
{
return
x
;
}
inline
value
to_value_impl
(
rank
<
8
>
,
const
std
::
string
&
x
)
{
return
x
;
}
inline
value
to_value_impl
(
rank
<
9
>
,
const
std
::
string
&
x
)
{
return
x
;
}
template
<
class
T
>
auto
to_value_impl
(
rank
<
9
>
,
const
T
&
x
)
->
decltype
(
migraphx_to_value
(
x
))
auto
to_value_impl
(
rank
<
10
>
,
const
T
&
x
)
->
decltype
(
migraphx_to_value
(
x
))
{
return
migraphx_to_value
(
x
);
}
template
<
class
T
>
auto
to_value_impl
(
rank
<
1
0
>
,
const
T
&
x
)
->
decltype
(
x
.
to_value
())
auto
to_value_impl
(
rank
<
1
1
>
,
const
T
&
x
)
->
decltype
(
x
.
to_value
())
{
return
x
.
to_value
();
}
template
<
class
T
>
auto
to_value_impl
(
rank
<
1
1
>
,
const
T
&
x
)
auto
to_value_impl
(
rank
<
1
2
>
,
const
T
&
x
)
->
decltype
(
migraphx_to_value
(
std
::
declval
<
value
&>
(),
x
),
value
{})
{
value
v
;
...
...
@@ -195,28 +205,35 @@ void from_value_impl(rank<5>, const value& v, T& x)
});
}
template
<
class
T
>
void
from_value_impl
(
rank
<
6
>
,
const
value
&
v
,
optional
<
T
>&
x
)
{
if
(
not
v
.
is_null
())
x
=
from_value
<
T
>
(
v
);
}
template
<
class
T
,
MIGRAPHX_REQUIRES
(
std
::
is_arithmetic
<
T
>{})
>
void
from_value_impl
(
rank
<
6
>
,
const
value
&
v
,
T
&
x
)
void
from_value_impl
(
rank
<
7
>
,
const
value
&
v
,
T
&
x
)
{
x
=
v
.
to
<
T
>
();
}
template
<
class
T
,
MIGRAPHX_REQUIRES
(
std
::
is_enum
<
T
>{})
>
void
from_value_impl
(
rank
<
7
>
,
const
value
&
v
,
T
&
x
)
void
from_value_impl
(
rank
<
8
>
,
const
value
&
v
,
T
&
x
)
{
x
=
v
.
to
<
T
>
();
}
inline
void
from_value_impl
(
rank
<
8
>
,
const
value
&
v
,
std
::
string
&
x
)
{
x
=
v
.
to
<
std
::
string
>
();
}
inline
void
from_value_impl
(
rank
<
9
>
,
const
value
&
v
,
std
::
string
&
x
)
{
x
=
v
.
to
<
std
::
string
>
();
}
template
<
class
T
>
auto
from_value_impl
(
rank
<
9
>
,
const
value
&
v
,
T
&
x
)
->
decltype
(
x
.
from_value
(
v
),
void
())
auto
from_value_impl
(
rank
<
10
>
,
const
value
&
v
,
T
&
x
)
->
decltype
(
x
.
from_value
(
v
),
void
())
{
x
.
from_value
(
v
);
}
template
<
class
T
>
auto
from_value_impl
(
rank
<
1
0
>
,
const
value
&
v
,
T
&
x
)
->
decltype
(
migraphx_from_value
(
v
,
x
),
void
())
auto
from_value_impl
(
rank
<
1
1
>
,
const
value
&
v
,
T
&
x
)
->
decltype
(
migraphx_from_value
(
v
,
x
),
void
())
{
migraphx_from_value
(
v
,
x
);
}
...
...
@@ -226,13 +243,13 @@ auto from_value_impl(rank<10>, const value& v, T& x) -> decltype(migraphx_from_v
template
<
class
T
>
value
to_value
(
const
T
&
x
)
{
return
detail
::
to_value_impl
(
rank
<
1
1
>
{},
x
);
return
detail
::
to_value_impl
(
rank
<
1
2
>
{},
x
);
}
template
<
class
T
>
void
from_value
(
const
value
&
v
,
T
&
x
)
{
detail
::
from_value_impl
(
rank
<
1
0
>
{},
v
,
x
);
detail
::
from_value_impl
(
rank
<
1
1
>
{},
v
,
x
);
}
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/streamutils.hpp
View file @
c97080ce
...
...
@@ -29,6 +29,7 @@
#include <migraphx/reflect.hpp>
#include <migraphx/rank.hpp>
#include <migraphx/requires.hpp>
#include <migraphx/optional.hpp>
#include <migraphx/config.hpp>
#include <vector>
...
...
@@ -99,12 +100,21 @@ void stream_write_value_impl(rank<0>, std::ostream& os, const T& x)
os
<<
"}"
;
}
template
<
class
T
>
void
stream_write_value_impl
(
rank
<
0
>
,
std
::
ostream
&
os
,
const
optional
<
T
>&
x
)
{
if
(
x
.
has_value
())
stream_write_value_impl
(
rank
<
2
>
{},
os
,
*
x
);
else
os
<<
"none"
;
}
}
// namespace detail
template
<
class
T
>
void
stream_write_value
(
std
::
ostream
&
os
,
const
T
&
x
)
{
detail
::
stream_write_value_impl
(
rank
<
1
>
{},
os
,
x
);
detail
::
stream_write_value_impl
(
rank
<
2
>
{},
os
,
x
);
}
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/targets/gpu/compile_ops.cpp
View file @
c97080ce
...
...
@@ -42,13 +42,15 @@ struct precompile_op
operation
op
=
op
::
identity
{};
std
::
size_t
additional_args
=
1
;
bool
ignore_modules
=
false
;
optional
<
shape
>
output_shape
=
{};
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
op
,
"op"
),
f
(
self
.
additional_args
,
"additional_args"
),
f
(
self
.
ignore_modules
,
"ignore_modules"
));
f
(
self
.
ignore_modules
,
"ignore_modules"
),
f
(
self
.
output_shape
,
"output_shape"
));
}
std
::
string
name
()
const
{
return
"gpu::precompile_op"
;
}
...
...
@@ -57,9 +59,14 @@ struct precompile_op
{
// Pop off additional args
inputs
.
resize
(
inputs
.
size
()
-
additional_args
);
shape
r
{};
if
(
ignore_modules
)
return
op
.
compute_shape
(
inputs
);
return
op
.
compute_shape
(
inputs
,
mods
);
r
=
op
.
compute_shape
(
inputs
);
else
r
=
op
.
compute_shape
(
inputs
,
mods
);
if
(
output_shape
.
has_value
())
r
=
*
output_shape
;
return
r
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
shapes
)
const
...
...
src/targets/gpu/fuse_ops.cpp
View file @
c97080ce
...
...
@@ -650,6 +650,32 @@ struct find_gemm_pointwise
}
};
struct
find_contiguous_tranpose_precompile
{
auto
matcher
()
const
{
return
match
::
name
(
"gpu::contiguous"
)(
match
::
arg
(
0
)(
match
::
name
(
"transpose"
)(
match
::
arg
(
0
)(
match
::
name
(
"gpu::precompile_op"
)(
match
::
used_once
()).
bind
(
"op"
)))
.
bind
(
"transpose"
)));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
op_ins
=
r
.
instructions
[
"op"
];
auto
alloc
=
op_ins
->
inputs
().
back
();
auto
transpose
=
r
.
instructions
[
"transpose"
];
auto
perm
=
transpose
->
get_operator
().
to_value
()[
"permutation"
].
to_vector
<
int64_t
>
();
auto
iperm
=
invert_permutation
(
perm
);
auto
s
=
shape
::
from_permutation
(
op_ins
->
get_shape
().
type
(),
op_ins
->
get_shape
().
lens
(),
iperm
);
auto
v
=
op_ins
->
get_operator
().
to_value
();
v
[
"output_shape"
]
=
to_value
(
s
);
auto
new_op
=
make_op
(
"gpu::precompile_op"
,
v
);
m
.
replace_instruction
(
op_ins
,
new_op
,
op_ins
->
inputs
(),
op_ins
->
module_inputs
());
}
};
struct
find_contiguous_tranpose_gemm
{
auto
matcher
()
const
...
...
@@ -825,6 +851,7 @@ void fuse_ops::apply(module& m) const
find_concat_pointwise
{},
find_gemm_pointwise
{},
find_contiguous_tranpose_gemm
{},
find_contiguous_tranpose_precompile
{},
find_commutative_broadcast
{});
match
::
find_matches
(
m
,
find_contiguous
{});
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment