Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
b162c4ec
"...gpu/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "6ae2f087eb04b45acaff42fc20b30a6b08f24cca"
Commit
b162c4ec
authored
Oct 03, 2022
by
charlie
Browse files
More progress
parent
412c298e
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
136 additions
and
118 deletions
+136
-118
src/common.cpp
src/common.cpp
+27
-13
src/include/migraphx/common.hpp
src/include/migraphx/common.hpp
+5
-0
src/include/migraphx/op/broadcast.hpp
src/include/migraphx/op/broadcast.hpp
+63
-52
src/include/migraphx/op/multibroadcast.hpp
src/include/migraphx/op/multibroadcast.hpp
+41
-53
No files found.
src/common.cpp
View file @
b162c4ec
...
@@ -31,6 +31,22 @@
...
@@ -31,6 +31,22 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
auto
compute_broadcasting
=
[](
std
::
vector
<
std
::
size_t
>
s0
,
std
::
vector
<
std
::
size_t
>
s1
)
{
std
::
vector
<
std
::
size_t
>
out_lens
(
s1
);
auto
offset
=
s1
.
size
()
-
s0
.
size
();
std
::
transform
(
s0
.
begin
(),
s0
.
end
(),
s1
.
begin
()
+
offset
,
out_lens
.
begin
()
+
offset
,
[
&
](
auto
a
,
auto
b
)
{
if
(
a
!=
b
and
a
!=
1
and
b
!=
1
)
{
MIGRAPHX_THROW
(
"COMPUTE_BROADCASTLEN: shape {"
+
to_string_range
(
s0
)
+
"} and {"
+
to_string_range
(
s1
)
+
"} mismatch!"
);
}
return
std
::
max
(
a
,
b
);
});
return
out_lens
;
};
// Example:
// Example:
// s0 = (3,2,4,5) and s1 = (2,1,1)
// s0 = (3,2,4,5) and s1 = (2,1,1)
//
//
...
@@ -50,20 +66,17 @@ std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
...
@@ -50,20 +66,17 @@ std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
return
s0
;
return
s0
;
if
(
s0
.
size
()
>
s1
.
size
())
if
(
s0
.
size
()
>
s1
.
size
())
s0
.
swap
(
s1
);
s0
.
swap
(
s1
);
return
compute_broadcasting
(
s0
,
s1
);
}
std
::
vector
<
std
::
size_t
>
out_lens
(
s1
);
std
::
vector
<
std
::
size_t
>
broadcast_s0s1_lens
(
std
::
vector
<
std
::
size_t
>
s0
,
auto
offset
=
s1
.
size
()
-
s0
.
size
();
std
::
vector
<
std
::
size_t
>
s1
)
std
::
transform
(
{
s0
.
begin
(),
s0
.
end
(),
s1
.
begin
()
+
offset
,
out_lens
.
begin
()
+
offset
,
[
&
](
auto
a
,
auto
b
)
{
if
(
s0
==
s1
)
if
(
a
!=
b
and
a
!=
1
and
b
!=
1
)
return
s0
;
{
if
(
s0
.
size
()
>
s1
.
size
())
MIGRAPHX_THROW
(
"COMPUTE_BROADCASTLEN: shape {"
+
to_string_range
(
s0
)
+
"} and {"
+
MIGRAPHX_THROW
(
"BROADCAST_SHAPE_LENS: s0 size > s1 size and swap not allowed"
);
to_string_range
(
s1
)
+
"} mismatch!"
);
return
compute_broadcasting
(
s0
,
s1
);
}
return
std
::
max
(
a
,
b
);
});
return
out_lens
;
}
}
std
::
vector
<
std
::
size_t
>
compute_common_lens
(
const
std
::
vector
<
shape
>&
shapes
)
std
::
vector
<
std
::
size_t
>
compute_common_lens
(
const
std
::
vector
<
shape
>&
shapes
)
...
@@ -114,6 +127,7 @@ instruction_ref insert_common_op(module& m,
...
@@ -114,6 +127,7 @@ instruction_ref insert_common_op(module& m,
const
operation
&
op
,
const
operation
&
op
,
std
::
vector
<
instruction_ref
>
inputs
)
std
::
vector
<
instruction_ref
>
inputs
)
{
{
// TODO update this to handle dynamic shapes
auto
common
=
common_shape
(
to_shapes
(
inputs
));
auto
common
=
common_shape
(
to_shapes
(
inputs
));
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
inputs
.
begin
(),
[
&
](
auto
input
)
{
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
inputs
.
begin
(),
[
&
](
auto
input
)
{
if
(
input
->
get_shape
().
lens
()
!=
common
.
lens
())
if
(
input
->
get_shape
().
lens
()
!=
common
.
lens
())
...
...
src/include/migraphx/common.hpp
View file @
b162c4ec
...
@@ -36,6 +36,11 @@ struct operation;
...
@@ -36,6 +36,11 @@ struct operation;
std
::
vector
<
std
::
size_t
>
compute_broadcasted_lens
(
std
::
vector
<
std
::
size_t
>
s0
,
std
::
vector
<
std
::
size_t
>
compute_broadcasted_lens
(
std
::
vector
<
std
::
size_t
>
s0
,
std
::
vector
<
std
::
size_t
>
s1
);
std
::
vector
<
std
::
size_t
>
s1
);
// This version doesn't allow s0.size() > s1.size()
std
::
vector
<
std
::
size_t
>
broadcast_s0s1_lens
(
std
::
vector
<
std
::
size_t
>
s0
,
std
::
vector
<
std
::
size_t
>
s1
);
shape
common_shape
(
const
std
::
vector
<
shape
>&
shapes
);
shape
common_shape
(
const
std
::
vector
<
shape
>&
shapes
);
instruction_ref
insert_common_op
(
module
&
m
,
instruction_ref
insert_common_op
(
module
&
m
,
...
...
src/include/migraphx/op/broadcast.hpp
View file @
b162c4ec
...
@@ -27,6 +27,7 @@
...
@@ -27,6 +27,7 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/common.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -56,63 +57,73 @@ struct broadcast
...
@@ -56,63 +57,73 @@ struct broadcast
{
{
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
,
2
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
,
2
);
auto
s0
=
inputs
.
at
(
0
);
auto
s0
=
inputs
.
at
(
0
);
auto
t
=
s0
.
type
();
auto
t
=
s0
.
type
();
if
(
inputs
.
size
()
==
1
)
if
(
inputs
.
size
()
==
1
)
{
{
std
::
vector
<
size_t
>
bcast_strides
(
broadcast_lens
.
size
(),
0
);
std
::
vector
<
size_t
>
bcast_strides
(
broadcast_lens
.
size
(),
0
);
// the broadcast op is deprecated now, so not handling the negative
// the broadcast op is deprecated now, so not handling the negative
// value of axis anymore
// value of axis anymore
if
(
axis
>=
broadcast_lens
.
size
())
if
(
axis
>=
broadcast_lens
.
size
())
{
MIGRAPHX_THROW
(
"BROADCAST : axis is out of range"
);
}
if
(
broadcast_lens
.
size
()
-
axis
<
s0
.
lens
().
size
())
{
MIGRAPHX_THROW
(
"BROADCAST: (broadcast ndims - axis) is less than s0 ndims"
);
}
if
(
not
std
::
equal
(
s0
.
lens
().
begin
(),
s0
.
lens
().
end
(),
broadcast_lens
.
begin
()
+
axis
))
{
MIGRAPHX_THROW
(
"BROADCAST: when broadcasting, succeeding sizes must match"
);
}
std
::
copy
(
s0
.
strides
().
begin
(),
s0
.
strides
().
end
(),
bcast_strides
.
begin
()
+
axis
);
shape
output
{
t
,
broadcast_lens
,
std
::
move
(
bcast_strides
)};
if
(
output
.
elements
()
<
s0
.
elements
())
MIGRAPHX_THROW
(
"BROADCAST: output size must be greater than or equal to s0 size"
);
return
output
;
}
else
{
if
(
s0
.
dynamic
()
and
s1
.
dynamic
())
{
auto
bcast_max_lens
=
compute_broadcasted_lens
(
s0
.
max_lens
(),
s1
.
max_lens
());
auto
bcast_min_lens
=
compute_broadcasted_lens
(
s0
.
min_lens
(),
s1
.
min_lens
());
auto
bcast_opt_lens
=
compute_broadcasted_lens
(
s0
.
opt_lens
(),
s1
.
opt_lens
());
std
::
vector
<
shape
::
dynamic_dimension
>
output_dyn_dims
=
{};
for
(
size_t
i
=
0
;
i
<
bcast_max_lens
.
size
();
++
i
)
{
output_dyn_dims
.
push_back
(
shape
::
dynamic_dimension
{
min_spatial_dims
[
i
],
max_spatial_dims
[
i
],
opt_spatial_dims
[
i
]});
}
return
{
t
,
output_dyn_dims
};
}
else
if
(
not
s0
.
dynamic
()
and
not
s1
.
dynamic
())
{
{
auto
bcast_lens
=
compute_broadcasted_lens
(
s0
.
lens
(),
s1
.
lens
());
MIGRAPHX_THROW
(
"BROADCAST : axis is out of range"
);
std
::
vector
<
size_t
>
bcast_strides
(
broadcast_lens
.
size
(),
0
);
}
std
::
copy
(
s0
.
strides
().
begin
(),
s0
.
strides
().
end
(),
bcast_strides
.
begin
()
+
axis
);
return
{
t
,
std
::
move
(
bcast_lens
),
std
::
move
(
bcast_strides
)};
if
(
broadcast_lens
.
size
()
-
axis
<
s0
.
lens
().
size
())
{
MIGRAPHX_THROW
(
"BROADCAST: (broadcast ndims - axis) is less than s0 ndims"
);
}
if
(
not
std
::
equal
(
s0
.
lens
().
begin
(),
s0
.
lens
().
end
(),
broadcast_lens
.
begin
()
+
axis
))
{
MIGRAPHX_THROW
(
"BROADCAST: when broadcasting, succeeding sizes must match"
);
}
std
::
copy
(
s0
.
strides
().
begin
(),
s0
.
strides
().
end
(),
bcast_strides
.
begin
()
+
axis
);
shape
output
{
t
,
broadcast_lens
,
std
::
move
(
bcast_strides
)};
if
(
output
.
elements
()
<
s0
.
elements
())
MIGRAPHX_THROW
(
"BROADCAST: output size must be greater than or equal to s0 size"
);
return
output
;
}
else
{
auto
s1
=
inputs
.
at
(
1
);
if
(
axis
>=
s1
.
max_lens
().
size
())
{
MIGRAPHX_THROW
(
"BROADCAST_2in: axis is out of range of s1"
);
}
if
(
s1
.
max_lens
().
size
()
-
axis
<
s0
.
max_lens
().
size
())
{
MIGRAPHX_THROW
(
"BROADCAST_2in: (s1 rank - axis) is less than s0 rank"
);
}
if
(
s0
.
dynamic
()
or
s1
.
dynamic
())
{
auto
bcast_max_lens
=
broadcast_s0s1_lens
(
s0
.
max_lens
(),
s1
.
max_lens
());
auto
bcast_min_lens
=
broadcast_s0s1_lens
(
s0
.
min_lens
(),
s1
.
min_lens
());
auto
bcast_opt_lens
=
broadcast_s0s1_lens
(
s0
.
opt_lens
(),
s1
.
opt_lens
());
std
::
vector
<
shape
::
dynamic_dimension
>
output_dyn_dims
=
{};
for
(
size_t
i
=
0
;
i
<
bcast_max_lens
.
size
();
++
i
)
{
output_dyn_dims
.
push_back
(
shape
::
dynamic_dimension
{
bcast_max_lens
[
i
],
bcast_min_lens
[
i
],
bcast_opt_lens
[
i
]});
}
return
{
t
,
std
::
move
(
output_dyn_dims
)};
}
}
else
else
{
{
MIGRAPHX_THROW
(
if
(
not
std
::
equal
(
s0
.
lens
().
begin
(),
s0
.
lens
().
end
(),
s1
.
lens
().
begin
()
+
axis
))
"BROADCAST: s0 and s1 are not both dynamic or static"
);
{
MIGRAPHX_THROW
(
"BROADCAST_2in: when broadcasting, succeeding sizes must match"
);
}
auto
bcast_lens
=
compute_broadcasted_lens
(
s0
.
lens
(),
s1
.
lens
());
std
::
vector
<
size_t
>
bcast_strides
(
broadcast_lens
.
size
(),
0
);
std
::
copy
(
s0
.
strides
().
begin
(),
s0
.
strides
().
end
(),
bcast_strides
.
begin
()
+
axis
);
return
{
t
,
std
::
move
(
bcast_lens
),
std
::
move
(
bcast_strides
)};
}
}
}
}
}
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
...
...
src/include/migraphx/op/multibroadcast.hpp
View file @
b162c4ec
...
@@ -50,27 +50,25 @@ struct multibroadcast
...
@@ -50,27 +50,25 @@ struct multibroadcast
{
{
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
,
2
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
,
2
);
auto
t
=
inputs
.
at
(
0
).
type
();
auto
t
=
inputs
.
at
(
0
).
type
();
auto
s0
=
inputs
.
at
(
0
);
auto
s0
=
inputs
.
at
(
0
);
if
(
s0
.
lens
().
empty
())
if
(
s0
.
lens
().
empty
())
{
{
MIGRAPHX_THROW
(
"MULTIBROADCAST: inputs dimensions should be > 0"
);
MIGRAPHX_THROW
(
"MULTIBROADCAST: inputs dimensions should be > 0"
);
}
}
auto
make_bcast_strides
=
[
&
](
std
::
size_t
out_num_dims
,
std
::
size_t
offset
)
{
auto
make_bcast_strides
=
[
&
](
std
::
size_t
out_num_dims
,
std
::
size_t
offset
)
std
::
vector
<
size_t
>
bcast_strides
(
out_num_dims
,
0
);
{
for
(
std
::
ptrdiff_t
i
=
s0
.
lens
().
size
()
-
1
;
i
>=
0
;
i
--
)
std
::
vector
<
size_t
>
bcast_strides
(
out_num_dims
,
0
);
{
for
(
std
::
ptrdiff_t
i
=
s0
.
lens
().
size
()
-
1
;
i
>=
0
;
i
--
)
if
(
output_lens
[
i
+
offset
]
==
s0
.
lens
()[
i
])
{
{
if
(
output_lens
[
i
+
offset
]
==
s0
.
lens
()[
i
])
bcast_strides
[
i
+
offset
]
=
s0
.
strides
()[
i
];
{
}
bcast_strides
[
i
+
offset
]
=
s0
.
strides
()[
i
];
}
}
return
bcast_strides
;
}
};
return
bcast_strides
;
};
if
(
inputs
.
size
()
==
1
)
if
(
inputs
.
size
()
==
1
)
{
{
...
@@ -84,52 +82,42 @@ struct multibroadcast
...
@@ -84,52 +82,42 @@ struct multibroadcast
{
{
if
(
output_lens
[
i
+
offset
]
!=
s0
.
lens
()[
i
]
and
s0
.
lens
()[
i
]
!=
1
)
if
(
output_lens
[
i
+
offset
]
!=
s0
.
lens
()[
i
]
and
s0
.
lens
()[
i
]
!=
1
)
{
{
MIGRAPHX_THROW
(
MIGRAPHX_THROW
(
"MULTIBROADCAST: input shape {"
+
to_string_range
(
s0
.
lens
())
+
"MULTIBROADCAST: input shape
{"
+
to_string_range
(
s0
.
lens
()
)
+
"} cannot be broadcasted to
{"
+
to_string_range
(
output_
lens
)
+
"} cannot be broadcasted to {"
+
to_string_range
(
output_lens
)
+
"}!"
);
"}!"
);
}
}
}
}
auto
bcast_strides
=
make_bcast_strides
(
output_lens
.
size
(),
offset
);
auto
bcast_strides
=
make_bcast_strides
(
output_lens
.
size
(),
offset
);
return
{
t
,
output_lens
,
std
::
move
(
bcast_strides
)};
return
{
t
,
output_lens
,
std
::
move
(
bcast_strides
)};
}
}
else
else
{
{
// need both shapes when handling dynamic case
// shapes can be dynamic (at compile-time) or static (at evaluation time)
// this function will be called through compute_output_shape conversion to dyn_output
// new compute_broadcasted_lens for dynamic shapes
// do we want this to work in both broadcast directions?
// s0 and s1 as shape inputs
// always s0 -> s1 shape or allow s0 to retain the same shape?
// presuming that it's always s0 -> s1 shape, since that's closer to the current behavior
// compute_broadcasted_lens() will swap the shapes if s1.size() < s0.size(), may need to make another function
auto
s1
=
inputs
.
at
(
1
);
auto
s1
=
inputs
.
at
(
1
);
if
(
s0
.
dynamic
()
and
s1
.
dynamic
())
if
(
s0
.
max_lens
().
size
()
>
s1
.
max_lens
().
size
())
{
{
auto
bcast_max_lens
=
compute_broadcasted_lens
(
s0
.
max_lens
(),
s1
.
max_lens
());
MIGRAPHX_THROW
(
"MULTIBROADCAST: s0 rank should <= s1 rank"
);
auto
bcast_min_lens
=
compute_broadcasted_lens
(
s0
.
min_lens
(),
s1
.
min_lens
());
}
auto
bcast_opt_lens
=
compute_broadcasted_lens
(
s0
.
opt_lens
(),
s1
.
opt_lens
());
if
(
s0
.
dynamic
()
or
s1
.
dynamic
())
std
::
vector
<
shape
::
dynamic_dimension
>
output_dyn_dims
=
{};
for
(
size_t
i
=
0
;
i
<
bcast_max_lens
.
size
();
++
i
)
{
output_dyn_dims
.
push_back
(
shape
::
dynamic_dimension
{
min_spatial_dims
[
i
],
max_spatial_dims
[
i
],
opt_spatial_dims
[
i
]});
}
return
{
t
,
std
::
move
(
output_dyn_dims
)};
}
else
if
(
not
s0
.
dynamic
()
and
not
s1
.
dynamic
())
{
{
auto
bcast_lens
=
compute_broadcasted_lens
(
s0
.
lens
(),
s1
.
lens
());
auto
bcast_max_lens
=
broadcast_s0s1_lens
(
s0
.
max_lens
(),
s1
.
max_lens
());
auto
offset
=
s1
.
lens
().
size
()
-
s0
.
lens
().
size
();
auto
bcast_min_lens
=
broadcast_s0s1_lens
(
s0
.
min_lens
(),
s1
.
min_lens
());
auto
bcast_strides
=
make_bcast_strides
(
s1
.
lens
().
size
(),
offset
);
auto
bcast_opt_lens
=
broadcast_s0s1_lens
(
s0
.
opt_lens
(),
s1
.
opt_lens
());
return
{
t
,
std
::
move
(
bcast_lens
),
std
::
move
(
bcast_strides
)};
std
::
vector
<
shape
::
dynamic_dimension
>
output_dyn_dims
=
{};
for
(
size_t
i
=
0
;
i
<
bcast_max_lens
.
size
();
++
i
)
{
output_dyn_dims
.
push_back
(
shape
::
dynamic_dimension
{
bcast_max_lens
[
i
],
bcast_min_lens
[
i
],
bcast_opt_lens
[
i
]});
}
return
{
t
,
std
::
move
(
output_dyn_dims
)};
}
}
else
else
{
{
MIGRAPHX_THROW
(
auto
bcast_lens
=
compute_broadcasted_lens
(
s0
.
lens
(),
s1
.
lens
());
"MULTIBROADCAST: s0 and s1 are not both dynamic or static"
);
auto
offset
=
s1
.
lens
().
size
()
-
s0
.
lens
().
size
();
auto
bcast_strides
=
make_bcast_strides
(
s1
.
lens
().
size
(),
offset
);
return
{
t
,
std
::
move
(
bcast_lens
),
std
::
move
(
bcast_strides
)};
}
}
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment