Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
b162c4ec
Commit
b162c4ec
authored
Oct 03, 2022
by
charlie
Browse files
More progress
parent
412c298e
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
136 additions
and
118 deletions
+136
-118
src/common.cpp
src/common.cpp
+27
-13
src/include/migraphx/common.hpp
src/include/migraphx/common.hpp
+5
-0
src/include/migraphx/op/broadcast.hpp
src/include/migraphx/op/broadcast.hpp
+63
-52
src/include/migraphx/op/multibroadcast.hpp
src/include/migraphx/op/multibroadcast.hpp
+41
-53
No files found.
src/common.cpp
View file @
b162c4ec
...
@@ -31,6 +31,22 @@
...
@@ -31,6 +31,22 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
auto
compute_broadcasting
=
[](
std
::
vector
<
std
::
size_t
>
s0
,
std
::
vector
<
std
::
size_t
>
s1
)
{
std
::
vector
<
std
::
size_t
>
out_lens
(
s1
);
auto
offset
=
s1
.
size
()
-
s0
.
size
();
std
::
transform
(
s0
.
begin
(),
s0
.
end
(),
s1
.
begin
()
+
offset
,
out_lens
.
begin
()
+
offset
,
[
&
](
auto
a
,
auto
b
)
{
if
(
a
!=
b
and
a
!=
1
and
b
!=
1
)
{
MIGRAPHX_THROW
(
"COMPUTE_BROADCASTLEN: shape {"
+
to_string_range
(
s0
)
+
"} and {"
+
to_string_range
(
s1
)
+
"} mismatch!"
);
}
return
std
::
max
(
a
,
b
);
});
return
out_lens
;
};
// Example:
// Example:
// s0 = (3,2,4,5) and s1 = (2,1,1)
// s0 = (3,2,4,5) and s1 = (2,1,1)
//
//
...
@@ -50,20 +66,17 @@ std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
...
@@ -50,20 +66,17 @@ std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
return
s0
;
return
s0
;
if
(
s0
.
size
()
>
s1
.
size
())
if
(
s0
.
size
()
>
s1
.
size
())
s0
.
swap
(
s1
);
s0
.
swap
(
s1
);
return
compute_broadcasting
(
s0
,
s1
);
}
std
::
vector
<
std
::
size_t
>
out_lens
(
s1
);
std
::
vector
<
std
::
size_t
>
broadcast_s0s1_lens
(
std
::
vector
<
std
::
size_t
>
s0
,
auto
offset
=
s1
.
size
()
-
s0
.
size
();
std
::
vector
<
std
::
size_t
>
s1
)
std
::
transform
(
{
s0
.
begin
(),
s0
.
end
(),
s1
.
begin
()
+
offset
,
out_lens
.
begin
()
+
offset
,
[
&
](
auto
a
,
auto
b
)
{
if
(
s0
==
s1
)
if
(
a
!=
b
and
a
!=
1
and
b
!=
1
)
return
s0
;
{
if
(
s0
.
size
()
>
s1
.
size
())
MIGRAPHX_THROW
(
"COMPUTE_BROADCASTLEN: shape {"
+
to_string_range
(
s0
)
+
"} and {"
+
MIGRAPHX_THROW
(
"BROADCAST_SHAPE_LENS: s0 size > s1 size and swap not allowed"
);
to_string_range
(
s1
)
+
"} mismatch!"
);
return
compute_broadcasting
(
s0
,
s1
);
}
return
std
::
max
(
a
,
b
);
});
return
out_lens
;
}
}
std
::
vector
<
std
::
size_t
>
compute_common_lens
(
const
std
::
vector
<
shape
>&
shapes
)
std
::
vector
<
std
::
size_t
>
compute_common_lens
(
const
std
::
vector
<
shape
>&
shapes
)
...
@@ -114,6 +127,7 @@ instruction_ref insert_common_op(module& m,
...
@@ -114,6 +127,7 @@ instruction_ref insert_common_op(module& m,
const
operation
&
op
,
const
operation
&
op
,
std
::
vector
<
instruction_ref
>
inputs
)
std
::
vector
<
instruction_ref
>
inputs
)
{
{
// TODO update this to handle dynamic shapes
auto
common
=
common_shape
(
to_shapes
(
inputs
));
auto
common
=
common_shape
(
to_shapes
(
inputs
));
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
inputs
.
begin
(),
[
&
](
auto
input
)
{
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
inputs
.
begin
(),
[
&
](
auto
input
)
{
if
(
input
->
get_shape
().
lens
()
!=
common
.
lens
())
if
(
input
->
get_shape
().
lens
()
!=
common
.
lens
())
...
...
src/include/migraphx/common.hpp
View file @
b162c4ec
...
@@ -36,6 +36,11 @@ struct operation;
...
@@ -36,6 +36,11 @@ struct operation;
std
::
vector
<
std
::
size_t
>
compute_broadcasted_lens
(
std
::
vector
<
std
::
size_t
>
s0
,
std
::
vector
<
std
::
size_t
>
compute_broadcasted_lens
(
std
::
vector
<
std
::
size_t
>
s0
,
std
::
vector
<
std
::
size_t
>
s1
);
std
::
vector
<
std
::
size_t
>
s1
);
// This version doesn't allow s0.size() > s1.size()
std
::
vector
<
std
::
size_t
>
broadcast_s0s1_lens
(
std
::
vector
<
std
::
size_t
>
s0
,
std
::
vector
<
std
::
size_t
>
s1
);
shape
common_shape
(
const
std
::
vector
<
shape
>&
shapes
);
shape
common_shape
(
const
std
::
vector
<
shape
>&
shapes
);
instruction_ref
insert_common_op
(
module
&
m
,
instruction_ref
insert_common_op
(
module
&
m
,
...
...
src/include/migraphx/op/broadcast.hpp
View file @
b162c4ec
...
@@ -27,6 +27,7 @@
...
@@ -27,6 +27,7 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/common.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -56,63 +57,73 @@ struct broadcast
...
@@ -56,63 +57,73 @@ struct broadcast
{
{
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
,
2
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
,
2
);
auto
s0
=
inputs
.
at
(
0
);
auto
s0
=
inputs
.
at
(
0
);
auto
t
=
s0
.
type
();
auto
t
=
s0
.
type
();
if
(
inputs
.
size
()
==
1
)
if
(
inputs
.
size
()
==
1
)
{
{
std
::
vector
<
size_t
>
bcast_strides
(
broadcast_lens
.
size
(),
0
);
std
::
vector
<
size_t
>
bcast_strides
(
broadcast_lens
.
size
(),
0
);
// the broadcast op is deprecated now, so not handling the negative
// the broadcast op is deprecated now, so not handling the negative
// value of axis anymore
// value of axis anymore
if
(
axis
>=
broadcast_lens
.
size
())
if
(
axis
>=
broadcast_lens
.
size
())
{
MIGRAPHX_THROW
(
"BROADCAST : axis is out of range"
);
}
if
(
broadcast_lens
.
size
()
-
axis
<
s0
.
lens
().
size
())
{
MIGRAPHX_THROW
(
"BROADCAST: (broadcast ndims - axis) is less than s0 ndims"
);
}
if
(
not
std
::
equal
(
s0
.
lens
().
begin
(),
s0
.
lens
().
end
(),
broadcast_lens
.
begin
()
+
axis
))
{
MIGRAPHX_THROW
(
"BROADCAST: when broadcasting, succeeding sizes must match"
);
}
std
::
copy
(
s0
.
strides
().
begin
(),
s0
.
strides
().
end
(),
bcast_strides
.
begin
()
+
axis
);
shape
output
{
t
,
broadcast_lens
,
std
::
move
(
bcast_strides
)};
if
(
output
.
elements
()
<
s0
.
elements
())
MIGRAPHX_THROW
(
"BROADCAST: output size must be greater than or equal to s0 size"
);
return
output
;
}
else
{
if
(
s0
.
dynamic
()
and
s1
.
dynamic
())
{
auto
bcast_max_lens
=
compute_broadcasted_lens
(
s0
.
max_lens
(),
s1
.
max_lens
());
auto
bcast_min_lens
=
compute_broadcasted_lens
(
s0
.
min_lens
(),
s1
.
min_lens
());
auto
bcast_opt_lens
=
compute_broadcasted_lens
(
s0
.
opt_lens
(),
s1
.
opt_lens
());
std
::
vector
<
shape
::
dynamic_dimension
>
output_dyn_dims
=
{};
for
(
size_t
i
=
0
;
i
<
bcast_max_lens
.
size
();
++
i
)
{
output_dyn_dims
.
push_back
(
shape
::
dynamic_dimension
{
min_spatial_dims
[
i
],
max_spatial_dims
[
i
],
opt_spatial_dims
[
i
]});
}
return
{
t
,
output_dyn_dims
};
}
else
if
(
not
s0
.
dynamic
()
and
not
s1
.
dynamic
())
{
{
auto
bcast_lens
=
compute_broadcasted_lens
(
s0
.
lens
(),
s1
.
lens
());
MIGRAPHX_THROW
(
"BROADCAST : axis is out of range"
);
std
::
vector
<
size_t
>
bcast_strides
(
broadcast_lens
.
size
(),
0
);
}
std
::
copy
(
s0
.
strides
().
begin
(),
s0
.
strides
().
end
(),
bcast_strides
.
begin
()
+
axis
);
return
{
t
,
std
::
move
(
bcast_lens
),
std
::
move
(
bcast_strides
)};
if
(
broadcast_lens
.
size
()
-
axis
<
s0
.
lens
().
size
())
{
MIGRAPHX_THROW
(
"BROADCAST: (broadcast ndims - axis) is less than s0 ndims"
);
}
if
(
not
std
::
equal
(
s0
.
lens
().
begin
(),
s0
.
lens
().
end
(),
broadcast_lens
.
begin
()
+
axis
))
{
MIGRAPHX_THROW
(
"BROADCAST: when broadcasting, succeeding sizes must match"
);
}
std
::
copy
(
s0
.
strides
().
begin
(),
s0
.
strides
().
end
(),
bcast_strides
.
begin
()
+
axis
);
shape
output
{
t
,
broadcast_lens
,
std
::
move
(
bcast_strides
)};
if
(
output
.
elements
()
<
s0
.
elements
())
MIGRAPHX_THROW
(
"BROADCAST: output size must be greater than or equal to s0 size"
);
return
output
;
}
else
{
auto
s1
=
inputs
.
at
(
1
);
if
(
axis
>=
s1
.
max_lens
().
size
())
{
MIGRAPHX_THROW
(
"BROADCAST_2in: axis is out of range of s1"
);
}
if
(
s1
.
max_lens
().
size
()
-
axis
<
s0
.
max_lens
().
size
())
{
MIGRAPHX_THROW
(
"BROADCAST_2in: (s1 rank - axis) is less than s0 rank"
);
}
if
(
s0
.
dynamic
()
or
s1
.
dynamic
())
{
auto
bcast_max_lens
=
broadcast_s0s1_lens
(
s0
.
max_lens
(),
s1
.
max_lens
());
auto
bcast_min_lens
=
broadcast_s0s1_lens
(
s0
.
min_lens
(),
s1
.
min_lens
());
auto
bcast_opt_lens
=
broadcast_s0s1_lens
(
s0
.
opt_lens
(),
s1
.
opt_lens
());
std
::
vector
<
shape
::
dynamic_dimension
>
output_dyn_dims
=
{};
for
(
size_t
i
=
0
;
i
<
bcast_max_lens
.
size
();
++
i
)
{
output_dyn_dims
.
push_back
(
shape
::
dynamic_dimension
{
bcast_max_lens
[
i
],
bcast_min_lens
[
i
],
bcast_opt_lens
[
i
]});
}
return
{
t
,
std
::
move
(
output_dyn_dims
)};
}
}
else
else
{
{
MIGRAPHX_THROW
(
if
(
not
std
::
equal
(
s0
.
lens
().
begin
(),
s0
.
lens
().
end
(),
s1
.
lens
().
begin
()
+
axis
))
"BROADCAST: s0 and s1 are not both dynamic or static"
);
{
MIGRAPHX_THROW
(
"BROADCAST_2in: when broadcasting, succeeding sizes must match"
);
}
auto
bcast_lens
=
compute_broadcasted_lens
(
s0
.
lens
(),
s1
.
lens
());
std
::
vector
<
size_t
>
bcast_strides
(
broadcast_lens
.
size
(),
0
);
std
::
copy
(
s0
.
strides
().
begin
(),
s0
.
strides
().
end
(),
bcast_strides
.
begin
()
+
axis
);
return
{
t
,
std
::
move
(
bcast_lens
),
std
::
move
(
bcast_strides
)};
}
}
}
}
}
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
...
...
src/include/migraphx/op/multibroadcast.hpp
View file @
b162c4ec
...
@@ -50,27 +50,25 @@ struct multibroadcast
...
@@ -50,27 +50,25 @@ struct multibroadcast
{
{
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
,
2
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
,
2
);
auto
t
=
inputs
.
at
(
0
).
type
();
auto
t
=
inputs
.
at
(
0
).
type
();
auto
s0
=
inputs
.
at
(
0
);
auto
s0
=
inputs
.
at
(
0
);
if
(
s0
.
lens
().
empty
())
if
(
s0
.
lens
().
empty
())
{
{
MIGRAPHX_THROW
(
"MULTIBROADCAST: inputs dimensions should be > 0"
);
MIGRAPHX_THROW
(
"MULTIBROADCAST: inputs dimensions should be > 0"
);
}
}
auto
make_bcast_strides
=
[
&
](
std
::
size_t
out_num_dims
,
std
::
size_t
offset
)
{
auto
make_bcast_strides
=
[
&
](
std
::
size_t
out_num_dims
,
std
::
size_t
offset
)
std
::
vector
<
size_t
>
bcast_strides
(
out_num_dims
,
0
);
{
for
(
std
::
ptrdiff_t
i
=
s0
.
lens
().
size
()
-
1
;
i
>=
0
;
i
--
)
std
::
vector
<
size_t
>
bcast_strides
(
out_num_dims
,
0
);
{
for
(
std
::
ptrdiff_t
i
=
s0
.
lens
().
size
()
-
1
;
i
>=
0
;
i
--
)
if
(
output_lens
[
i
+
offset
]
==
s0
.
lens
()[
i
])
{
{
if
(
output_lens
[
i
+
offset
]
==
s0
.
lens
()[
i
])
bcast_strides
[
i
+
offset
]
=
s0
.
strides
()[
i
];
{
}
bcast_strides
[
i
+
offset
]
=
s0
.
strides
()[
i
];
}
}
return
bcast_strides
;
}
};
return
bcast_strides
;
};
if
(
inputs
.
size
()
==
1
)
if
(
inputs
.
size
()
==
1
)
{
{
...
@@ -84,52 +82,42 @@ struct multibroadcast
...
@@ -84,52 +82,42 @@ struct multibroadcast
{
{
if
(
output_lens
[
i
+
offset
]
!=
s0
.
lens
()[
i
]
and
s0
.
lens
()[
i
]
!=
1
)
if
(
output_lens
[
i
+
offset
]
!=
s0
.
lens
()[
i
]
and
s0
.
lens
()[
i
]
!=
1
)
{
{
MIGRAPHX_THROW
(
MIGRAPHX_THROW
(
"MULTIBROADCAST: input shape {"
+
to_string_range
(
s0
.
lens
())
+
"MULTIBROADCAST: input shape
{"
+
to_string_range
(
s0
.
lens
()
)
+
"} cannot be broadcasted to
{"
+
to_string_range
(
output_
lens
)
+
"} cannot be broadcasted to {"
+
to_string_range
(
output_lens
)
+
"}!"
);
"}!"
);
}
}
}
}
auto
bcast_strides
=
make_bcast_strides
(
output_lens
.
size
(),
offset
);
auto
bcast_strides
=
make_bcast_strides
(
output_lens
.
size
(),
offset
);
return
{
t
,
output_lens
,
std
::
move
(
bcast_strides
)};
return
{
t
,
output_lens
,
std
::
move
(
bcast_strides
)};
}
}
else
else
{
{
// need both shapes when handling dynamic case
// shapes can be dynamic (at compile-time) or static (at evaluation time)
// this function will be called through compute_output_shape conversion to dyn_output
// new compute_broadcasted_lens for dynamic shapes
// do we want this to work in both broadcast directions?
// s0 and s1 as shape inputs
// always s0 -> s1 shape or allow s0 to retain the same shape?
// presuming that it's always s0 -> s1 shape, since that's closer to the current behavior
// compute_broadcasted_lens() will swap the shapes if s1.size() < s0.size(), may need to make another function
auto
s1
=
inputs
.
at
(
1
);
auto
s1
=
inputs
.
at
(
1
);
if
(
s0
.
dynamic
()
and
s1
.
dynamic
())
if
(
s0
.
max_lens
().
size
()
>
s1
.
max_lens
().
size
())
{
{
auto
bcast_max_lens
=
compute_broadcasted_lens
(
s0
.
max_lens
(),
s1
.
max_lens
());
MIGRAPHX_THROW
(
"MULTIBROADCAST: s0 rank should <= s1 rank"
);
auto
bcast_min_lens
=
compute_broadcasted_lens
(
s0
.
min_lens
(),
s1
.
min_lens
());
}
auto
bcast_opt_lens
=
compute_broadcasted_lens
(
s0
.
opt_lens
(),
s1
.
opt_lens
());
if
(
s0
.
dynamic
()
or
s1
.
dynamic
())
std
::
vector
<
shape
::
dynamic_dimension
>
output_dyn_dims
=
{};
for
(
size_t
i
=
0
;
i
<
bcast_max_lens
.
size
();
++
i
)
{
output_dyn_dims
.
push_back
(
shape
::
dynamic_dimension
{
min_spatial_dims
[
i
],
max_spatial_dims
[
i
],
opt_spatial_dims
[
i
]});
}
return
{
t
,
std
::
move
(
output_dyn_dims
)};
}
else
if
(
not
s0
.
dynamic
()
and
not
s1
.
dynamic
())
{
{
auto
bcast_lens
=
compute_broadcasted_lens
(
s0
.
lens
(),
s1
.
lens
());
auto
bcast_max_lens
=
broadcast_s0s1_lens
(
s0
.
max_lens
(),
s1
.
max_lens
());
auto
offset
=
s1
.
lens
().
size
()
-
s0
.
lens
().
size
();
auto
bcast_min_lens
=
broadcast_s0s1_lens
(
s0
.
min_lens
(),
s1
.
min_lens
());
auto
bcast_strides
=
make_bcast_strides
(
s1
.
lens
().
size
(),
offset
);
auto
bcast_opt_lens
=
broadcast_s0s1_lens
(
s0
.
opt_lens
(),
s1
.
opt_lens
());
return
{
t
,
std
::
move
(
bcast_lens
),
std
::
move
(
bcast_strides
)};
std
::
vector
<
shape
::
dynamic_dimension
>
output_dyn_dims
=
{};
for
(
size_t
i
=
0
;
i
<
bcast_max_lens
.
size
();
++
i
)
{
output_dyn_dims
.
push_back
(
shape
::
dynamic_dimension
{
bcast_max_lens
[
i
],
bcast_min_lens
[
i
],
bcast_opt_lens
[
i
]});
}
return
{
t
,
std
::
move
(
output_dyn_dims
)};
}
}
else
else
{
{
MIGRAPHX_THROW
(
auto
bcast_lens
=
compute_broadcasted_lens
(
s0
.
lens
(),
s1
.
lens
());
"MULTIBROADCAST: s0 and s1 are not both dynamic or static"
);
auto
offset
=
s1
.
lens
().
size
()
-
s0
.
lens
().
size
();
auto
bcast_strides
=
make_bcast_strides
(
s1
.
lens
().
size
(),
offset
);
return
{
t
,
std
::
move
(
bcast_lens
),
std
::
move
(
bcast_strides
)};
}
}
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment