Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
c97080ce
Commit
c97080ce
authored
Dec 03, 2022
by
Paul
Browse files
Fuse transpose
parent
f17d6246
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
80 additions
and
19 deletions
+80
-19
src/include/migraphx/serialize.hpp
src/include/migraphx/serialize.hpp
+32
-15
src/include/migraphx/streamutils.hpp
src/include/migraphx/streamutils.hpp
+11
-1
src/targets/gpu/compile_ops.cpp
src/targets/gpu/compile_ops.cpp
+10
-3
src/targets/gpu/fuse_ops.cpp
src/targets/gpu/fuse_ops.cpp
+27
-0
No files found.
src/include/migraphx/serialize.hpp
View file @
c97080ce
...
@@ -28,6 +28,7 @@
...
@@ -28,6 +28,7 @@
#include <migraphx/value.hpp>
#include <migraphx/value.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/requires.hpp>
#include <migraphx/requires.hpp>
#include <migraphx/optional.hpp>
#include <migraphx/rank.hpp>
#include <migraphx/rank.hpp>
#include <type_traits>
#include <type_traits>
...
@@ -87,46 +88,55 @@ value to_value_impl(rank<3>, const T& x)
...
@@ -87,46 +88,55 @@ value to_value_impl(rank<3>, const T& x)
return
result
;
return
result
;
}
}
template
<
class
T
>
auto
to_value_impl
(
rank
<
4
>
,
const
optional
<
T
>&
x
)
{
value
result
{};
if
(
x
.
has_value
())
to_value
(
*
x
);
return
result
;
}
template
<
class
T
,
MIGRAPHX_REQUIRES
(
std
::
is_signed
<
T
>{})
>
template
<
class
T
,
MIGRAPHX_REQUIRES
(
std
::
is_signed
<
T
>{})
>
value
to_value_impl
(
rank
<
4
>
,
const
T
&
x
)
value
to_value_impl
(
rank
<
5
>
,
const
T
&
x
)
{
{
return
std
::
int64_t
{
x
};
return
std
::
int64_t
{
x
};
}
}
template
<
class
T
,
MIGRAPHX_REQUIRES
(
std
::
is_unsigned
<
T
>{})
>
template
<
class
T
,
MIGRAPHX_REQUIRES
(
std
::
is_unsigned
<
T
>{})
>
value
to_value_impl
(
rank
<
5
>
,
const
T
&
x
)
value
to_value_impl
(
rank
<
6
>
,
const
T
&
x
)
{
{
return
std
::
uint64_t
{
x
};
return
std
::
uint64_t
{
x
};
}
}
template
<
class
T
,
MIGRAPHX_REQUIRES
(
std
::
is_floating_point
<
T
>{})
>
template
<
class
T
,
MIGRAPHX_REQUIRES
(
std
::
is_floating_point
<
T
>{})
>
value
to_value_impl
(
rank
<
6
>
,
const
T
&
x
)
value
to_value_impl
(
rank
<
7
>
,
const
T
&
x
)
{
{
return
double
{
x
};
return
double
{
x
};
}
}
template
<
class
T
,
MIGRAPHX_REQUIRES
(
std
::
is_enum
<
T
>{})
>
template
<
class
T
,
MIGRAPHX_REQUIRES
(
std
::
is_enum
<
T
>{})
>
value
to_value_impl
(
rank
<
7
>
,
const
T
&
x
)
value
to_value_impl
(
rank
<
8
>
,
const
T
&
x
)
{
{
return
x
;
return
x
;
}
}
inline
value
to_value_impl
(
rank
<
8
>
,
const
std
::
string
&
x
)
{
return
x
;
}
inline
value
to_value_impl
(
rank
<
9
>
,
const
std
::
string
&
x
)
{
return
x
;
}
template
<
class
T
>
template
<
class
T
>
auto
to_value_impl
(
rank
<
9
>
,
const
T
&
x
)
->
decltype
(
migraphx_to_value
(
x
))
auto
to_value_impl
(
rank
<
10
>
,
const
T
&
x
)
->
decltype
(
migraphx_to_value
(
x
))
{
{
return
migraphx_to_value
(
x
);
return
migraphx_to_value
(
x
);
}
}
template
<
class
T
>
template
<
class
T
>
auto
to_value_impl
(
rank
<
1
0
>
,
const
T
&
x
)
->
decltype
(
x
.
to_value
())
auto
to_value_impl
(
rank
<
1
1
>
,
const
T
&
x
)
->
decltype
(
x
.
to_value
())
{
{
return
x
.
to_value
();
return
x
.
to_value
();
}
}
template
<
class
T
>
template
<
class
T
>
auto
to_value_impl
(
rank
<
1
1
>
,
const
T
&
x
)
auto
to_value_impl
(
rank
<
1
2
>
,
const
T
&
x
)
->
decltype
(
migraphx_to_value
(
std
::
declval
<
value
&>
(),
x
),
value
{})
->
decltype
(
migraphx_to_value
(
std
::
declval
<
value
&>
(),
x
),
value
{})
{
{
value
v
;
value
v
;
...
@@ -195,28 +205,35 @@ void from_value_impl(rank<5>, const value& v, T& x)
...
@@ -195,28 +205,35 @@ void from_value_impl(rank<5>, const value& v, T& x)
});
});
}
}
template
<
class
T
>
void
from_value_impl
(
rank
<
6
>
,
const
value
&
v
,
optional
<
T
>&
x
)
{
if
(
not
v
.
is_null
())
x
=
from_value
<
T
>
(
v
);
}
template
<
class
T
,
MIGRAPHX_REQUIRES
(
std
::
is_arithmetic
<
T
>{})
>
template
<
class
T
,
MIGRAPHX_REQUIRES
(
std
::
is_arithmetic
<
T
>{})
>
void
from_value_impl
(
rank
<
6
>
,
const
value
&
v
,
T
&
x
)
void
from_value_impl
(
rank
<
7
>
,
const
value
&
v
,
T
&
x
)
{
{
x
=
v
.
to
<
T
>
();
x
=
v
.
to
<
T
>
();
}
}
template
<
class
T
,
MIGRAPHX_REQUIRES
(
std
::
is_enum
<
T
>{})
>
template
<
class
T
,
MIGRAPHX_REQUIRES
(
std
::
is_enum
<
T
>{})
>
void
from_value_impl
(
rank
<
7
>
,
const
value
&
v
,
T
&
x
)
void
from_value_impl
(
rank
<
8
>
,
const
value
&
v
,
T
&
x
)
{
{
x
=
v
.
to
<
T
>
();
x
=
v
.
to
<
T
>
();
}
}
inline
void
from_value_impl
(
rank
<
8
>
,
const
value
&
v
,
std
::
string
&
x
)
{
x
=
v
.
to
<
std
::
string
>
();
}
inline
void
from_value_impl
(
rank
<
9
>
,
const
value
&
v
,
std
::
string
&
x
)
{
x
=
v
.
to
<
std
::
string
>
();
}
template
<
class
T
>
template
<
class
T
>
auto
from_value_impl
(
rank
<
9
>
,
const
value
&
v
,
T
&
x
)
->
decltype
(
x
.
from_value
(
v
),
void
())
auto
from_value_impl
(
rank
<
10
>
,
const
value
&
v
,
T
&
x
)
->
decltype
(
x
.
from_value
(
v
),
void
())
{
{
x
.
from_value
(
v
);
x
.
from_value
(
v
);
}
}
template
<
class
T
>
template
<
class
T
>
auto
from_value_impl
(
rank
<
1
0
>
,
const
value
&
v
,
T
&
x
)
->
decltype
(
migraphx_from_value
(
v
,
x
),
void
())
auto
from_value_impl
(
rank
<
1
1
>
,
const
value
&
v
,
T
&
x
)
->
decltype
(
migraphx_from_value
(
v
,
x
),
void
())
{
{
migraphx_from_value
(
v
,
x
);
migraphx_from_value
(
v
,
x
);
}
}
...
@@ -226,13 +243,13 @@ auto from_value_impl(rank<10>, const value& v, T& x) -> decltype(migraphx_from_v
...
@@ -226,13 +243,13 @@ auto from_value_impl(rank<10>, const value& v, T& x) -> decltype(migraphx_from_v
template
<
class
T
>
template
<
class
T
>
value
to_value
(
const
T
&
x
)
value
to_value
(
const
T
&
x
)
{
{
return
detail
::
to_value_impl
(
rank
<
1
1
>
{},
x
);
return
detail
::
to_value_impl
(
rank
<
1
2
>
{},
x
);
}
}
template
<
class
T
>
template
<
class
T
>
void
from_value
(
const
value
&
v
,
T
&
x
)
void
from_value
(
const
value
&
v
,
T
&
x
)
{
{
detail
::
from_value_impl
(
rank
<
1
0
>
{},
v
,
x
);
detail
::
from_value_impl
(
rank
<
1
1
>
{},
v
,
x
);
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/streamutils.hpp
View file @
c97080ce
...
@@ -29,6 +29,7 @@
...
@@ -29,6 +29,7 @@
#include <migraphx/reflect.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/rank.hpp>
#include <migraphx/rank.hpp>
#include <migraphx/requires.hpp>
#include <migraphx/requires.hpp>
#include <migraphx/optional.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <vector>
#include <vector>
...
@@ -99,12 +100,21 @@ void stream_write_value_impl(rank<0>, std::ostream& os, const T& x)
...
@@ -99,12 +100,21 @@ void stream_write_value_impl(rank<0>, std::ostream& os, const T& x)
os
<<
"}"
;
os
<<
"}"
;
}
}
template
<
class
T
>
void
stream_write_value_impl
(
rank
<
0
>
,
std
::
ostream
&
os
,
const
optional
<
T
>&
x
)
{
if
(
x
.
has_value
())
stream_write_value_impl
(
rank
<
2
>
{},
os
,
*
x
);
else
os
<<
"none"
;
}
}
// namespace detail
}
// namespace detail
template
<
class
T
>
template
<
class
T
>
void
stream_write_value
(
std
::
ostream
&
os
,
const
T
&
x
)
void
stream_write_value
(
std
::
ostream
&
os
,
const
T
&
x
)
{
{
detail
::
stream_write_value_impl
(
rank
<
1
>
{},
os
,
x
);
detail
::
stream_write_value_impl
(
rank
<
2
>
{},
os
,
x
);
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/targets/gpu/compile_ops.cpp
View file @
c97080ce
...
@@ -42,13 +42,15 @@ struct precompile_op
...
@@ -42,13 +42,15 @@ struct precompile_op
operation
op
=
op
::
identity
{};
operation
op
=
op
::
identity
{};
std
::
size_t
additional_args
=
1
;
std
::
size_t
additional_args
=
1
;
bool
ignore_modules
=
false
;
bool
ignore_modules
=
false
;
optional
<
shape
>
output_shape
=
{};
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
{
{
return
pack
(
f
(
self
.
op
,
"op"
),
return
pack
(
f
(
self
.
op
,
"op"
),
f
(
self
.
additional_args
,
"additional_args"
),
f
(
self
.
additional_args
,
"additional_args"
),
f
(
self
.
ignore_modules
,
"ignore_modules"
));
f
(
self
.
ignore_modules
,
"ignore_modules"
),
f
(
self
.
output_shape
,
"output_shape"
));
}
}
std
::
string
name
()
const
{
return
"gpu::precompile_op"
;
}
std
::
string
name
()
const
{
return
"gpu::precompile_op"
;
}
...
@@ -57,9 +59,14 @@ struct precompile_op
...
@@ -57,9 +59,14 @@ struct precompile_op
{
{
// Pop off additional args
// Pop off additional args
inputs
.
resize
(
inputs
.
size
()
-
additional_args
);
inputs
.
resize
(
inputs
.
size
()
-
additional_args
);
shape
r
{};
if
(
ignore_modules
)
if
(
ignore_modules
)
return
op
.
compute_shape
(
inputs
);
r
=
op
.
compute_shape
(
inputs
);
return
op
.
compute_shape
(
inputs
,
mods
);
else
r
=
op
.
compute_shape
(
inputs
,
mods
);
if
(
output_shape
.
has_value
())
r
=
*
output_shape
;
return
r
;
}
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
shapes
)
const
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
shapes
)
const
...
...
src/targets/gpu/fuse_ops.cpp
View file @
c97080ce
...
@@ -650,6 +650,32 @@ struct find_gemm_pointwise
...
@@ -650,6 +650,32 @@ struct find_gemm_pointwise
}
}
};
};
struct
find_contiguous_tranpose_precompile
{
auto
matcher
()
const
{
return
match
::
name
(
"gpu::contiguous"
)(
match
::
arg
(
0
)(
match
::
name
(
"transpose"
)(
match
::
arg
(
0
)(
match
::
name
(
"gpu::precompile_op"
)(
match
::
used_once
()).
bind
(
"op"
)))
.
bind
(
"transpose"
)));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
op_ins
=
r
.
instructions
[
"op"
];
auto
alloc
=
op_ins
->
inputs
().
back
();
auto
transpose
=
r
.
instructions
[
"transpose"
];
auto
perm
=
transpose
->
get_operator
().
to_value
()[
"permutation"
].
to_vector
<
int64_t
>
();
auto
iperm
=
invert_permutation
(
perm
);
auto
s
=
shape
::
from_permutation
(
op_ins
->
get_shape
().
type
(),
op_ins
->
get_shape
().
lens
(),
iperm
);
auto
v
=
op_ins
->
get_operator
().
to_value
();
v
[
"output_shape"
]
=
to_value
(
s
);
auto
new_op
=
make_op
(
"gpu::precompile_op"
,
v
);
m
.
replace_instruction
(
op_ins
,
new_op
,
op_ins
->
inputs
(),
op_ins
->
module_inputs
());
}
};
struct
find_contiguous_tranpose_gemm
struct
find_contiguous_tranpose_gemm
{
{
auto
matcher
()
const
auto
matcher
()
const
...
@@ -825,6 +851,7 @@ void fuse_ops::apply(module& m) const
...
@@ -825,6 +851,7 @@ void fuse_ops::apply(module& m) const
find_concat_pointwise
{},
find_concat_pointwise
{},
find_gemm_pointwise
{},
find_gemm_pointwise
{},
find_contiguous_tranpose_gemm
{},
find_contiguous_tranpose_gemm
{},
find_contiguous_tranpose_precompile
{},
find_commutative_broadcast
{});
find_commutative_broadcast
{});
match
::
find_matches
(
m
,
find_contiguous
{});
match
::
find_matches
(
m
,
find_contiguous
{});
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment