Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
412c298e
Commit
412c298e
authored
Oct 02, 2022
by
charlie
Browse files
Progress on changing ops
multibroadcast and broadcast take two inputs
parent
78c799c5
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
111 additions
and
49 deletions
+111
-49
src/include/migraphx/op/broadcast.hpp
src/include/migraphx/op/broadcast.hpp
+57
-24
src/include/migraphx/op/multibroadcast.hpp
src/include/migraphx/op/multibroadcast.hpp
+54
-25
No files found.
src/include/migraphx/op/broadcast.hpp
View file @
412c298e
...
...
@@ -44,7 +44,6 @@ struct broadcast
{
uint64_t
axis
=
0
;
std
::
vector
<
std
::
size_t
>
broadcast_lens
;
std
::
vector
<
shape
::
dynamic_dimension
>
broadcast_dyn_dims
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
...
...
@@ -55,38 +54,72 @@ struct broadcast
std
::
string
name
()
const
{
return
"broadcast"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
auto
input
=
inputs
.
at
(
0
);
auto
t
=
input
.
type
();
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
,
2
);
auto
s0
=
inputs
.
at
(
0
);
auto
t
=
s0
.
type
();
std
::
vector
<
size_t
>
bcast_strides
(
broadcast_lens
.
size
(),
0
);
// the broadcast op is deprecated now, so not handling the negative
// value of axis anymore
if
(
axis
>=
broadcast_lens
.
size
())
{
MIGRAPHX_THROW
(
"BROADCAST : axis is out of range"
);
}
if
(
inputs
.
size
()
==
1
)
{
std
::
vector
<
size_t
>
bcast_strides
(
broadcast_lens
.
size
(),
0
);
// the broadcast op is deprecated now, so not handling the negative
// value of axis anymore
if
(
axis
>=
broadcast_lens
.
size
())
{
MIGRAPHX_THROW
(
"BROADCAST : axis is out of range"
);
}
if
(
broadcast_lens
.
size
()
-
axis
<
input
.
lens
().
size
())
{
MIGRAPHX_THROW
(
"BROADCAST: (broadcast ndims - axis) is less than
input
ndims"
);
}
if
(
broadcast_lens
.
size
()
-
axis
<
s0
.
lens
().
size
())
{
MIGRAPHX_THROW
(
"BROADCAST: (broadcast ndims - axis) is less than
s0
ndims"
);
}
if
(
not
std
::
equal
(
input
.
lens
().
begin
(),
input
.
lens
().
end
(),
broadcast_lens
.
begin
()
+
axis
))
{
MIGRAPHX_THROW
(
"BROADCAST: when broadcasting, succeeding sizes must match"
);
}
std
::
copy
(
input
.
strides
().
begin
(),
input
.
strides
().
end
(),
bcast_strides
.
begin
()
+
axis
);
if
(
not
std
::
equal
(
s0
.
lens
().
begin
(),
s0
.
lens
().
end
(),
broadcast_lens
.
begin
()
+
axis
))
{
MIGRAPHX_THROW
(
"BROADCAST: when broadcasting, succeeding sizes must match"
);
}
std
::
copy
(
s0
.
strides
().
begin
(),
s0
.
strides
().
end
(),
bcast_strides
.
begin
()
+
axis
);
shape
output
{
t
,
broadcast_lens
,
std
::
move
(
bcast_strides
)};
if
(
output
.
elements
()
<
input
.
elements
())
MIGRAPHX_THROW
(
"BROADCAST: output size must be greater than or equal to input size"
);
return
output
;
shape
output
{
t
,
broadcast_lens
,
std
::
move
(
bcast_strides
)};
if
(
output
.
elements
()
<
s0
.
elements
())
MIGRAPHX_THROW
(
"BROADCAST: output size must be greater than or equal to s0 size"
);
return
output
;
}
else
{
if
(
s0
.
dynamic
()
and
s1
.
dynamic
())
{
auto
bcast_max_lens
=
compute_broadcasted_lens
(
s0
.
max_lens
(),
s1
.
max_lens
());
auto
bcast_min_lens
=
compute_broadcasted_lens
(
s0
.
min_lens
(),
s1
.
min_lens
());
auto
bcast_opt_lens
=
compute_broadcasted_lens
(
s0
.
opt_lens
(),
s1
.
opt_lens
());
std
::
vector
<
shape
::
dynamic_dimension
>
output_dyn_dims
=
{};
for
(
size_t
i
=
0
;
i
<
bcast_max_lens
.
size
();
++
i
)
{
output_dyn_dims
.
push_back
(
shape
::
dynamic_dimension
{
min_spatial_dims
[
i
],
max_spatial_dims
[
i
],
opt_spatial_dims
[
i
]});
}
return
{
t
,
output_dyn_dims
};
}
else
if
(
not
s0
.
dynamic
()
and
not
s1
.
dynamic
())
{
auto
bcast_lens
=
compute_broadcasted_lens
(
s0
.
lens
(),
s1
.
lens
());
std
::
vector
<
size_t
>
bcast_strides
(
broadcast_lens
.
size
(),
0
);
std
::
copy
(
s0
.
strides
().
begin
(),
s0
.
strides
().
end
(),
bcast_strides
.
begin
()
+
axis
);
return
{
t
,
std
::
move
(
bcast_lens
),
std
::
move
(
bcast_strides
)};
}
else
{
MIGRAPHX_THROW
(
"BROADCAST: s0 and s1 are not both dynamic or static"
);
}
}
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
return
args
[
0
].
reshape
(
output_shape
);
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
...
...
src/include/migraphx/op/multibroadcast.hpp
View file @
412c298e
...
...
@@ -51,39 +51,47 @@ struct multibroadcast
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
,
2
);
auto
t
=
inputs
.
at
(
0
).
type
();
auto
input_shape
=
inputs
.
at
(
0
);
auto
s0
=
inputs
.
at
(
0
);
if
(
s0
.
lens
().
empty
())
{
MIGRAPHX_THROW
(
"MULTIBROADCAST: inputs dimensions should be > 0"
);
}
auto
make_bcast_strides
=
[
&
](
std
::
size_t
out_num_dims
,
std
::
size_t
offset
)
{
std
::
vector
<
size_t
>
bcast_strides
(
out_num_dims
,
0
);
for
(
std
::
ptrdiff_t
i
=
s0
.
lens
().
size
()
-
1
;
i
>=
0
;
i
--
)
{
if
(
output_lens
[
i
+
offset
]
==
s0
.
lens
()[
i
])
{
bcast_strides
[
i
+
offset
]
=
s0
.
strides
()[
i
];
}
}
return
bcast_strides
;
};
if
(
inputs
.
size
()
==
1
)
{
if
(
input_shape
.
lens
().
empty
())
{
MIGRAPHX_THROW
(
"MULTIBROADCAST: inputs dimensions should be > 0"
);
}
if
(
input_shape
.
lens
().
size
()
>
output_lens
.
size
())
if
(
s0
.
lens
().
size
()
>
output_lens
.
size
())
{
MIGRAPHX_THROW
(
"MULTIBROADCAST: inputs dimensions should <= output size"
);
}
auto
offset
=
output_lens
.
size
()
-
input_shape
.
lens
().
size
();
for
(
std
::
ptrdiff_t
i
=
input_shape
.
lens
().
size
()
-
1
;
i
>=
0
;
i
--
)
auto
offset
=
output_lens
.
size
()
-
s0
.
lens
().
size
();
for
(
std
::
ptrdiff_t
i
=
s0
.
lens
().
size
()
-
1
;
i
>=
0
;
i
--
)
{
if
(
output_lens
[
i
+
offset
]
!=
input_shape
.
lens
()[
i
]
and
input_shape
.
lens
()[
i
]
!=
1
)
if
(
output_lens
[
i
+
offset
]
!=
s0
.
lens
()[
i
]
and
s0
.
lens
()[
i
]
!=
1
)
{
MIGRAPHX_THROW
(
"MULTIBROADCAST: input shape {"
+
to_string_range
(
input_shape
.
lens
())
+
"MULTIBROADCAST: input shape {"
+
to_string_range
(
s0
.
lens
())
+
"} cannot be broadcasted to {"
+
to_string_range
(
output_lens
)
+
"}!"
);
}
}
std
::
vector
<
size_t
>
bcast_strides
(
output_lens
.
size
(),
0
);
for
(
std
::
ptrdiff_t
i
=
input_shape
.
lens
().
size
()
-
1
;
i
>=
0
;
i
--
)
{
if
(
output_lens
[
i
+
offset
]
==
input_shape
.
lens
()[
i
])
{
bcast_strides
[
i
+
offset
]
=
input_shape
.
strides
()[
i
];
}
}
return
{
t
,
output_lens
,
bcast_strides
};
auto
bcast_strides
=
make_bcast_strides
(
output_lens
.
size
(),
offset
);
return
{
t
,
output_lens
,
std
::
move
(
bcast_strides
)};
}
else
{
...
...
@@ -91,16 +99,37 @@ struct multibroadcast
// shapes can be dynamic (at compile-time) or static (at evaluation time)
// this function will be called through compute_output_shape conversion to dyn_output
// new compute_broadcasted_lens for dynamic shapes
auto
other_shape
=
inputs
.
at
(
1
);
if
(
input_shape
.
dynamic
()
and
other_shape
.
dynamic
())
{}
else
if
(
not
input_shape
.
dynamic
()
and
not
other_shape
.
dynamic
())
// do we want this to work in both broadcast directions?
// s0 and s1 as shape inputs
// always s0 -> s1 shape or allow s0 to retain the same shape?
// presuming that it's always s0 -> s1 shape, since that's closer to the current behavior
// compute_broadcasted_lens() will swap the shapes if s1.size() < s0.size(), may need to make another function
auto
s1
=
inputs
.
at
(
1
);
if
(
s0
.
dynamic
()
and
s1
.
dynamic
())
{
auto
bcast_max_lens
=
compute_broadcasted_lens
(
s0
.
max_lens
(),
s1
.
max_lens
());
auto
bcast_min_lens
=
compute_broadcasted_lens
(
s0
.
min_lens
(),
s1
.
min_lens
());
auto
bcast_opt_lens
=
compute_broadcasted_lens
(
s0
.
opt_lens
(),
s1
.
opt_lens
());
std
::
vector
<
shape
::
dynamic_dimension
>
output_dyn_dims
=
{};
for
(
size_t
i
=
0
;
i
<
bcast_max_lens
.
size
();
++
i
)
{
output_dyn_dims
.
push_back
(
shape
::
dynamic_dimension
{
min_spatial_dims
[
i
],
max_spatial_dims
[
i
],
opt_spatial_dims
[
i
]});
}
return
{
t
,
std
::
move
(
output_dyn_dims
)};
}
else
if
(
not
s0
.
dynamic
()
and
not
s1
.
dynamic
())
{
auto
output_lens
=
compute_broadcasted_lens
(
input_shape
.
lens
(),
other_shape
.
lens
());
auto
bcast_lens
=
compute_broadcasted_lens
(
s0
.
lens
(),
s1
.
lens
());
auto
offset
=
s1
.
lens
().
size
()
-
s0
.
lens
().
size
();
auto
bcast_strides
=
make_bcast_strides
(
s1
.
lens
().
size
(),
offset
);
return
{
t
,
std
::
move
(
bcast_lens
),
std
::
move
(
bcast_strides
)};
}
else
{
MIGRAPHX_THROW
(
"MULTIBROADCAST:
input_shape and other_shape
are not both dynamic or static"
);
"MULTIBROADCAST:
s0 and s1
are not both dynamic or static"
);
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment