Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
b72ad090
Commit
b72ad090
authored
Nov 27, 2023
by
charlie
Browse files
initial
parent
57f734a5
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
194 additions
and
36 deletions
+194
-36
src/common.cpp
src/common.cpp
+40
-28
src/include/migraphx/common.hpp
src/include/migraphx/common.hpp
+5
-0
src/include/migraphx/op/dot.hpp
src/include/migraphx/op/dot.hpp
+17
-3
src/include/migraphx/op/dot_broadcast.hpp
src/include/migraphx/op/dot_broadcast.hpp
+89
-0
src/onnx/parse_matmul.cpp
src/onnx/parse_matmul.cpp
+8
-4
src/simplify_dyn_ops.cpp
src/simplify_dyn_ops.cpp
+35
-1
No files found.
src/common.cpp
View file @
b72ad090
...
@@ -51,21 +51,23 @@ std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
...
@@ -51,21 +51,23 @@ std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
});
});
return
out_lens
;
return
out_lens
;
}
}
std
::
vector
<
shape
::
dynamic_dimension
>
std
::
vector
<
shape
::
dynamic_dimension
>
compute_broadcasted_dyn_dims
(
shape
s0
,
shape
s1
)
compute_broadcasted_dyn_dims
(
std
::
vector
<
shape
::
dynamic_dimension
>
dds0
,
std
::
vector
<
shape
::
dynamic_dimension
>
dds1
)
{
{
// change both shapes to dynamic_dimension representation
if
(
dds0
.
size
()
>
dds1
.
size
())
s0
=
s0
.
to_dynamic
();
s1
=
s1
.
to_dynamic
();
if
(
s0
.
ndim
()
>
s1
.
ndim
())
{
{
std
::
swap
(
s0
,
s1
);
std
::
swap
(
dd
s0
,
dd
s1
);
}
}
auto
offset
=
s1
.
ndim
()
-
s0
.
ndim
();
auto
offset
=
dds1
.
size
()
-
dds0
.
size
();
std
::
vector
<
shape
::
dynamic_dimension
>
out_dims
(
s1
.
dyn_dims
());
std
::
vector
<
shape
::
dynamic_dimension
>
out_dims
(
dds1
);
std
::
transform
(
s0
.
dyn_dims
().
cbegin
(),
// If one within the range of the other
s0
.
dyn_dims
().
cend
(),
auto
dd_within_range
=
[
&
](
shape
::
dynamic_dimension
x
,
shape
::
dynamic_dimension
y
)
{
s1
.
dyn_dims
().
cbegin
()
+
offset
,
return
(
x
.
min
>=
y
.
min
and
x
.
max
<=
y
.
max
);
};
std
::
transform
(
dds0
.
cbegin
(),
dds0
.
cend
(),
dds1
.
cbegin
()
+
offset
,
out_dims
.
begin
()
+
offset
,
out_dims
.
begin
()
+
offset
,
[
&
](
auto
a
,
auto
b
)
{
[
&
](
auto
a
,
auto
b
)
{
if
(
a
==
b
or
b
==
1
)
if
(
a
==
b
or
b
==
1
)
...
@@ -76,16 +78,32 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha
...
@@ -76,16 +78,32 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha
{
{
return
b
;
return
b
;
}
}
else
if
(
dd_within_range
(
a
,
b
))
{
return
a
;
}
else
if
(
dd_within_range
(
b
,
a
))
{
return
b
;
}
else
else
{
{
MIGRAPHX_THROW
(
"COMPUTE_BROADCASTED_DYN_DIMS: dynamic shapes {"
+
MIGRAPHX_THROW
(
"COMPUTE_BROADCASTED_DYN_DIMS: dynamic shapes {"
+
migraphx
::
to_string_range
(
s0
.
dyn_dims
()
)
+
"} and {"
+
migraphx
::
to_string_range
(
dd
s0
)
+
"} and {"
+
migraphx
::
to_string_range
(
s1
.
dyn_dims
()
)
+
"} mismatch!"
);
migraphx
::
to_string_range
(
dd
s1
)
+
"} mismatch!"
);
}
}
});
});
return
out_dims
;
return
out_dims
;
}
}
std
::
vector
<
shape
::
dynamic_dimension
>
compute_broadcasted_dyn_dims
(
shape
s0
,
shape
s1
)
{
// change both shapes to dynamic_dimension representation
s0
=
s0
.
to_dynamic
();
s1
=
s1
.
to_dynamic
();
return
compute_broadcasted_dyn_dims
(
s0
.
dyn_dims
(),
s1
.
dyn_dims
());
}
std
::
vector
<
shape
::
dynamic_dimension
>
compute_common_dyn_dims
(
const
std
::
vector
<
shape
>&
shapes
)
std
::
vector
<
shape
::
dynamic_dimension
>
compute_common_dyn_dims
(
const
std
::
vector
<
shape
>&
shapes
)
{
{
auto
ret_shape
=
shapes
.
at
(
0
);
auto
ret_shape
=
shapes
.
at
(
0
);
...
@@ -151,24 +169,18 @@ insert_common_args(module& m, instruction_ref ins, std::vector<instruction_ref>
...
@@ -151,24 +169,18 @@ insert_common_args(module& m, instruction_ref ins, std::vector<instruction_ref>
auto
c_dyn_dims
=
compute_common_dyn_dims
(
input_shapes
);
auto
c_dyn_dims
=
compute_common_dyn_dims
(
input_shapes
);
auto
s0
=
inputs
[
0
]
->
get_shape
();
auto
s0
=
inputs
[
0
]
->
get_shape
();
if
(
not
s0
.
dynamic
()
or
s0
.
dyn_dims
()
!=
c_dyn_dims
)
// changed to always add the multibroadcast to handle the cases from split_single_dyn_dim
{
inputs
[
0
]
=
m
.
insert_instruction
(
inputs
[
0
]
=
m
.
insert_instruction
(
ins
,
make_op
(
"multibroadcast"
,
{{
"out_dyn_dims"
,
to_value
(
c_dyn_dims
)}}),
inputs
);
ins
,
make_op
(
"multibroadcast"
,
{{
"out_dyn_dims"
,
to_value
(
c_dyn_dims
)}}),
inputs
);
}
std
::
transform
(
inputs
.
begin
()
+
1
,
inputs
.
end
(),
inputs
.
begin
()
+
1
,
[
&
](
auto
input
)
{
std
::
transform
(
inputs
.
begin
()
+
1
,
inputs
.
end
(),
inputs
.
begin
()
+
1
,
[
&
](
auto
input
)
{
// uses previous input to avoid recalculating the common shape from the
// uses previous input to avoid recalculating the common shape from the
// full set of input shapes at runtime
// full set of input shapes at runtime
auto
s
=
input
->
get_shape
();
auto
s
=
input
->
get_shape
();
if
(
not
s
.
dynamic
()
or
s
.
dyn_dims
()
!=
c_dyn_dims
)
return
m
.
insert_instruction
(
{
ins
,
return
m
.
insert_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_dyn_dims"
,
to_value
(
c_dyn_dims
)}}),
ins
,
input
,
make_op
(
"multibroadcast"
,
{{
"out_dyn_dims"
,
to_value
(
c_dyn_dims
)}}),
inputs
[
0
]);
input
,
inputs
[
0
]);
}
return
input
;
});
});
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
inputs
.
begin
(),
[
&
](
auto
input
)
{
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
inputs
.
begin
(),
[
&
](
auto
input
)
{
if
(
input
->
get_shape
().
type
()
!=
c_type
)
if
(
input
->
get_shape
().
type
()
!=
c_type
)
...
...
src/include/migraphx/common.hpp
View file @
b72ad090
...
@@ -58,6 +58,11 @@ MIGRAPHX_EXPORT
...
@@ -58,6 +58,11 @@ MIGRAPHX_EXPORT
std
::
vector
<
std
::
size_t
>
compute_broadcasted_lens
(
std
::
vector
<
std
::
size_t
>
s0
,
std
::
vector
<
std
::
size_t
>
compute_broadcasted_lens
(
std
::
vector
<
std
::
size_t
>
s0
,
std
::
vector
<
std
::
size_t
>
s1
);
std
::
vector
<
std
::
size_t
>
s1
);
MIGRAPHX_EXPORT
std
::
vector
<
shape
::
dynamic_dimension
>
compute_broadcasted_dyn_dims
(
std
::
vector
<
shape
::
dynamic_dimension
>
dds0
,
std
::
vector
<
shape
::
dynamic_dimension
>
dds1
);
MIGRAPHX_EXPORT
MIGRAPHX_EXPORT
std
::
vector
<
shape
::
dynamic_dimension
>
compute_broadcasted_dyn_dims
(
shape
s0
,
shape
s1
);
std
::
vector
<
shape
::
dynamic_dimension
>
compute_broadcasted_dyn_dims
(
shape
s0
,
shape
s1
);
...
...
src/include/migraphx/op/dot.hpp
View file @
b72ad090
...
@@ -34,6 +34,9 @@ namespace migraphx {
...
@@ -34,6 +34,9 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
namespace
op
{
/**
* Matrix multiplication of two tensors.
*/
struct
dot
struct
dot
{
{
std
::
string
name
()
const
{
return
"dot"
;
}
std
::
string
name
()
const
{
return
"dot"
;
}
...
@@ -50,25 +53,36 @@ struct dot
...
@@ -50,25 +53,36 @@ struct dot
}
}
if
(
a
.
dynamic
()
or
b
.
dynamic
())
if
(
a
.
dynamic
()
or
b
.
dynamic
())
{
{
auto
dd_within_range
=
[
&
](
shape
::
dynamic_dimension
x
,
shape
::
dynamic_dimension
y
)
{
return
(
x
.
min
>=
y
.
min
and
x
.
max
<=
y
.
max
);
};
auto
s0
=
a
.
to_dynamic
();
auto
s0
=
a
.
to_dynamic
();
auto
s1
=
b
.
to_dynamic
();
auto
s1
=
b
.
to_dynamic
();
if
(
not
std
::
equal
(
s0
.
dyn_dims
().
rbegin
()
+
2
,
if
(
not
std
::
equal
(
s0
.
dyn_dims
().
rbegin
()
+
2
,
s0
.
dyn_dims
().
rend
(),
s0
.
dyn_dims
().
rend
(),
s1
.
dyn_dims
().
rbegin
()
+
2
,
s1
.
dyn_dims
().
rbegin
()
+
2
,
s1
.
dyn_dims
().
rend
()))
s1
.
dyn_dims
().
rend
(),
[
&
](
auto
x
,
auto
y
)
{
return
(
dd_within_range
(
x
,
y
)
or
dd_within_range
(
y
,
x
));
}))
{
{
MIGRAPHX_THROW
(
"DOT: dynamic outer dimensions of A and B mismatch: {"
+
MIGRAPHX_THROW
(
"DOT: dynamic outer dimensions of A and B mismatch or not within "
"dynamic_dimension range: {"
+
to_string_range
(
s0
.
dyn_dims
())
+
"} x {"
+
to_string_range
(
s0
.
dyn_dims
())
+
"} x {"
+
to_string_range
(
s1
.
dyn_dims
())
+
"}"
);
to_string_range
(
s1
.
dyn_dims
())
+
"}"
);
}
}
std
::
size_t
dim_0
=
s0
.
ndim
()
-
2
;
std
::
size_t
dim_0
=
s0
.
ndim
()
-
2
;
std
::
size_t
dim_1
=
s0
.
ndim
()
-
1
;
std
::
size_t
dim_1
=
s0
.
ndim
()
-
1
;
if
(
s0
.
dyn_dims
()[
dim_1
]
!=
s1
.
dyn_dims
()[
dim_0
])
auto
x
=
s0
.
dyn_dims
()[
dim_1
];
auto
y
=
s1
.
dyn_dims
()[
dim_0
];
if
(
not
dd_within_range
(
x
,
y
)
and
not
dd_within_range
(
y
,
x
))
{
{
MIGRAPHX_THROW
(
"DOT: dynamic inner dimensions do not match: {"
+
MIGRAPHX_THROW
(
"DOT: dynamic inner dimensions do not match: {"
+
to_string_range
(
s0
.
dyn_dims
())
+
"} x {"
+
to_string_range
(
s0
.
dyn_dims
())
+
"} x {"
+
to_string_range
(
s1
.
dyn_dims
())
+
"}"
);
to_string_range
(
s1
.
dyn_dims
())
+
"}"
);
}
}
// NOTE could make this compute_shape more precise by using outer dimensions of the
// shape that's dd_within_range. currently this just uses the outer dimensions of s0.
auto
out_dyn_dims
=
s0
.
dyn_dims
();
auto
out_dyn_dims
=
s0
.
dyn_dims
();
out_dyn_dims
[
dim_1
]
=
s1
.
dyn_dims
()[
dim_1
];
out_dyn_dims
[
dim_1
]
=
s1
.
dyn_dims
()[
dim_1
];
return
{
t
,
out_dyn_dims
};
return
{
t
,
out_dyn_dims
};
...
...
src/include/migraphx/op/dot_broadcast.hpp
0 → 100644
View file @
b72ad090
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_DOT_BROADCAST_HPP
#define MIGRAPHX_GUARD_OPERATORS_DOT_BROADCAST_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp>
#include <migraphx/common.hpp>
#include <migraphx/dyn_output.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
/**
* Broadcast dimensions between two tensors for the `dot` operator.
* Essentially broadcasts between two shapes for dimensions other than the last two.
* This operator is only needed if one of the shapes are dynamic.
* Example:
* a = shape[{1, 4}, 3, 248, 248]
* b = shape[248, 365]
* dot_broadcast(a, b) => shape[{1, 4}, 3, 248, 248] (no change)
* dot_broadcast(b, a) => shape[{1, 4}, 3, 248, 365]
*/
struct
dot_broadcast
{
std
::
string
name
()
const
{
return
"dot_broadcast"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
2
);
auto
s0
=
inputs
.
at
(
0
);
auto
s1
=
inputs
.
at
(
1
);
if
(
s0
.
dynamic
()
or
s1
.
dynamic
())
{
s0
=
s0
.
to_dynamic
();
s1
=
s1
.
to_dynamic
();
auto
dds0_it
=
s0
.
dyn_dims
().
end
()
-
2
;
auto
dds1_it
=
s1
.
dyn_dims
().
end
()
-
2
;
std
::
vector
<
shape
::
dynamic_dimension
>
sliced_dds0
{
s0
.
dyn_dims
().
begin
(),
dds0_it
};
std
::
vector
<
shape
::
dynamic_dimension
>
sliced_dds1
{
s1
.
dyn_dims
().
begin
(),
dds1_it
};
auto
output_dyn_dims
=
compute_broadcasted_dyn_dims
(
sliced_dds0
,
sliced_dds1
);
output_dyn_dims
.
insert
(
output_dyn_dims
.
end
(),
dds0_it
,
s0
.
dyn_dims
().
end
());
return
{
s0
.
type
(),
output_dyn_dims
};
}
else
{
auto
l0_it
=
s0
.
lens
().
begin
()
+
s0
.
ndim
()
-
2
;
std
::
vector
<
std
::
size_t
>
l0_broadcasted_lens
(
s0
.
lens
().
begin
(),
l0_it
);
auto
l1_it
=
s1
.
lens
().
begin
()
+
s1
.
ndim
()
-
2
;
std
::
vector
<
std
::
size_t
>
l1_broadcasted_lens
(
s1
.
lens
().
begin
(),
l1_it
);
auto
output_lens
=
compute_broadcasted_lens
(
l0_broadcasted_lens
,
l1_broadcasted_lens
);
output_lens
.
insert
(
output_lens
.
end
(),
l0_it
,
s0
.
lens
().
end
());
return
{
s0
.
type
(),
output_lens
};
}
}
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
return
args
[
0
].
reshape
(
dyn_out
.
computed_shape
);
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/onnx/parse_matmul.cpp
View file @
b72ad090
...
@@ -71,14 +71,18 @@ struct parse_matmul : op_parser<parse_matmul>
...
@@ -71,14 +71,18 @@ struct parse_matmul : op_parser<parse_matmul>
auto
s0_dds
=
a0
->
get_shape
().
to_dynamic
().
dyn_dims
();
auto
s0_dds
=
a0
->
get_shape
().
to_dynamic
().
dyn_dims
();
auto
s1_dds
=
a1
->
get_shape
().
to_dynamic
().
dyn_dims
();
auto
s1_dds
=
a1
->
get_shape
().
to_dynamic
().
dyn_dims
();
// TODO: handling this case requires a new multibroadcast mode
if
(
not
std
::
equal
(
if
(
not
std
::
equal
(
s0_dds
.
rbegin
()
+
2
,
s0_dds
.
rend
(),
s1_dds
.
rbegin
()
+
2
,
s1_dds
.
rend
()))
s0_dds
.
rbegin
()
+
2
,
s0_dds
.
rend
(),
s1_dds
.
rbegin
()
+
2
,
s1_dds
.
rend
()))
{
{
MIGRAPHX_THROW
(
"PARSE_MATMUL: dynamic shape broadcasting not supported"
);
auto
broadcasted_a0
=
info
.
add_instruction
(
make_op
(
"dot_broadcast"
),
a0
,
a1
);
auto
broadcasted_a1
=
info
.
add_instruction
(
make_op
(
"dot_broadcast"
),
a1
,
a0
);
dot_res
=
info
.
add_instruction
(
make_op
(
opd
.
op_name
),
broadcasted_a0
,
broadcasted_a1
);
}
else
{
dot_res
=
info
.
add_instruction
(
make_op
(
opd
.
op_name
),
a0
,
a1
);
}
}
dot_res
=
info
.
add_instruction
(
make_op
(
opd
.
op_name
),
a0
,
a1
);
}
}
else
else
{
{
...
...
src/simplify_dyn_ops.cpp
View file @
b72ad090
...
@@ -318,6 +318,39 @@ struct find_const_alloc_fill
...
@@ -318,6 +318,39 @@ struct find_const_alloc_fill
}
}
};
};
/**
* Simplify dot_broadcast instructions with two static shaped arguments
* From:
* dot_broadcast(static_shape_arg, static_shape_arg)
* To:
* multibroadcast(static_shape_arg); output_lens = static_dot_broadcasted_shape
*/
struct
find_static_dot_broadcast
{
auto
matcher
()
const
{
return
match
::
name
(
"dot_broadcast"
)(
match
::
arg
(
0
)(
match
::
static_shape
()),
match
::
arg
(
1
)(
match
::
static_shape
()));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
mr
)
const
{
auto
dot_broadcast_ins
=
mr
.
result
;
auto
inputs
=
dot_broadcast_ins
->
inputs
();
auto
s0
=
inputs
.
at
(
0
)
->
get_shape
();
auto
s1
=
inputs
.
at
(
1
)
->
get_shape
();
auto
l0_it
=
s0
.
lens
().
begin
()
+
s0
.
ndim
()
-
2
;
std
::
vector
<
std
::
size_t
>
l0_broadcasted_lens
(
s0
.
lens
().
begin
(),
l0_it
);
auto
l1_it
=
s1
.
lens
().
begin
()
+
s1
.
ndim
()
-
2
;
std
::
vector
<
std
::
size_t
>
l1_broadcasted_lens
(
s1
.
lens
().
begin
(),
l1_it
);
auto
output_lens
=
compute_broadcasted_lens
(
l0_broadcasted_lens
,
l1_broadcasted_lens
);
output_lens
.
insert
(
output_lens
.
end
(),
l0_it
,
s0
.
lens
().
end
());
m
.
replace_instruction
(
dot_broadcast_ins
,
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
output_lens
}}),
inputs
.
at
(
0
));
}
};
void
simplify_dyn_ops
::
apply
(
module
&
m
)
const
void
simplify_dyn_ops
::
apply
(
module
&
m
)
const
{
{
match
::
find_matches
(
m
,
match
::
find_matches
(
m
,
...
@@ -327,7 +360,8 @@ void simplify_dyn_ops::apply(module& m) const
...
@@ -327,7 +360,8 @@ void simplify_dyn_ops::apply(module& m) const
find_const_2in_slice
{},
find_const_2in_slice
{},
find_const_3in_slice
{},
find_const_3in_slice
{},
find_const_4in_slice
{},
find_const_4in_slice
{},
find_const_alloc_fill
{});
find_const_alloc_fill
{},
find_static_dot_broadcast
{});
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment