Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
412c298e
Commit
412c298e
authored
Oct 02, 2022
by
charlie
Browse files
Progress on changing ops
multibroadcast and broadcast take two inputs
parent
78c799c5
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
111 additions
and
49 deletions
+111
-49
src/include/migraphx/op/broadcast.hpp
src/include/migraphx/op/broadcast.hpp
+57
-24
src/include/migraphx/op/multibroadcast.hpp
src/include/migraphx/op/multibroadcast.hpp
+54
-25
No files found.
src/include/migraphx/op/broadcast.hpp
View file @
412c298e
...
@@ -44,7 +44,6 @@ struct broadcast
...
@@ -44,7 +44,6 @@ struct broadcast
{
{
uint64_t
axis
=
0
;
uint64_t
axis
=
0
;
std
::
vector
<
std
::
size_t
>
broadcast_lens
;
std
::
vector
<
std
::
size_t
>
broadcast_lens
;
std
::
vector
<
shape
::
dynamic_dimension
>
broadcast_dyn_dims
;
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
...
@@ -55,38 +54,72 @@ struct broadcast
...
@@ -55,38 +54,72 @@ struct broadcast
std
::
string
name
()
const
{
return
"broadcast"
;
}
std
::
string
name
()
const
{
return
"broadcast"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
,
2
);
auto
input
=
inputs
.
at
(
0
);
auto
s0
=
inputs
.
at
(
0
);
auto
t
=
input
.
type
();
auto
t
=
s0
.
type
();
std
::
vector
<
size_t
>
bcast_strides
(
broadcast_lens
.
size
(),
0
);
if
(
inputs
.
size
()
==
1
)
// the broadcast op is deprecated now, so not handling the negative
{
// value of axis anymore
std
::
vector
<
size_t
>
bcast_strides
(
broadcast_lens
.
size
(),
0
);
if
(
axis
>=
broadcast_lens
.
size
())
// the broadcast op is deprecated now, so not handling the negative
{
// value of axis anymore
MIGRAPHX_THROW
(
"BROADCAST : axis is out of range"
);
if
(
axis
>=
broadcast_lens
.
size
())
}
{
MIGRAPHX_THROW
(
"BROADCAST : axis is out of range"
);
}
if
(
broadcast_lens
.
size
()
-
axis
<
input
.
lens
().
size
())
if
(
broadcast_lens
.
size
()
-
axis
<
s0
.
lens
().
size
())
{
{
MIGRAPHX_THROW
(
"BROADCAST: (broadcast ndims - axis) is less than
input
ndims"
);
MIGRAPHX_THROW
(
"BROADCAST: (broadcast ndims - axis) is less than
s0
ndims"
);
}
}
if
(
not
std
::
equal
(
input
.
lens
().
begin
(),
input
.
lens
().
end
(),
broadcast_lens
.
begin
()
+
axis
))
if
(
not
std
::
equal
(
s0
.
lens
().
begin
(),
s0
.
lens
().
end
(),
broadcast_lens
.
begin
()
+
axis
))
{
{
MIGRAPHX_THROW
(
"BROADCAST: when broadcasting, succeeding sizes must match"
);
MIGRAPHX_THROW
(
"BROADCAST: when broadcasting, succeeding sizes must match"
);
}
}
std
::
copy
(
input
.
strides
().
begin
(),
input
.
strides
().
end
(),
bcast_strides
.
begin
()
+
axis
);
std
::
copy
(
s0
.
strides
().
begin
(),
s0
.
strides
().
end
(),
bcast_strides
.
begin
()
+
axis
);
shape
output
{
t
,
broadcast_lens
,
std
::
move
(
bcast_strides
)};
shape
output
{
t
,
broadcast_lens
,
std
::
move
(
bcast_strides
)};
if
(
output
.
elements
()
<
input
.
elements
())
if
(
output
.
elements
()
<
s0
.
elements
())
MIGRAPHX_THROW
(
"BROADCAST: output size must be greater than or equal to input size"
);
MIGRAPHX_THROW
(
"BROADCAST: output size must be greater than or equal to s0 size"
);
return
output
;
return
output
;
}
else
{
if
(
s0
.
dynamic
()
and
s1
.
dynamic
())
{
auto
bcast_max_lens
=
compute_broadcasted_lens
(
s0
.
max_lens
(),
s1
.
max_lens
());
auto
bcast_min_lens
=
compute_broadcasted_lens
(
s0
.
min_lens
(),
s1
.
min_lens
());
auto
bcast_opt_lens
=
compute_broadcasted_lens
(
s0
.
opt_lens
(),
s1
.
opt_lens
());
std
::
vector
<
shape
::
dynamic_dimension
>
output_dyn_dims
=
{};
for
(
size_t
i
=
0
;
i
<
bcast_max_lens
.
size
();
++
i
)
{
output_dyn_dims
.
push_back
(
shape
::
dynamic_dimension
{
min_spatial_dims
[
i
],
max_spatial_dims
[
i
],
opt_spatial_dims
[
i
]});
}
return
{
t
,
output_dyn_dims
};
}
else
if
(
not
s0
.
dynamic
()
and
not
s1
.
dynamic
())
{
auto
bcast_lens
=
compute_broadcasted_lens
(
s0
.
lens
(),
s1
.
lens
());
std
::
vector
<
size_t
>
bcast_strides
(
broadcast_lens
.
size
(),
0
);
std
::
copy
(
s0
.
strides
().
begin
(),
s0
.
strides
().
end
(),
bcast_strides
.
begin
()
+
axis
);
return
{
t
,
std
::
move
(
bcast_lens
),
std
::
move
(
bcast_strides
)};
}
else
{
MIGRAPHX_THROW
(
"BROADCAST: s0 and s1 are not both dynamic or static"
);
}
}
}
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
return
args
[
0
].
reshape
(
output_shape
);
return
args
[
0
].
reshape
(
output_shape
);
}
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
};
...
...
src/include/migraphx/op/multibroadcast.hpp
View file @
412c298e
...
@@ -51,39 +51,47 @@ struct multibroadcast
...
@@ -51,39 +51,47 @@ struct multibroadcast
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
,
2
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
,
2
);
auto
t
=
inputs
.
at
(
0
).
type
();
auto
t
=
inputs
.
at
(
0
).
type
();
auto
input_shape
=
inputs
.
at
(
0
);
auto
s0
=
inputs
.
at
(
0
);
if
(
s0
.
lens
().
empty
())
{
MIGRAPHX_THROW
(
"MULTIBROADCAST: inputs dimensions should be > 0"
);
}
auto
make_bcast_strides
=
[
&
](
std
::
size_t
out_num_dims
,
std
::
size_t
offset
)
{
std
::
vector
<
size_t
>
bcast_strides
(
out_num_dims
,
0
);
for
(
std
::
ptrdiff_t
i
=
s0
.
lens
().
size
()
-
1
;
i
>=
0
;
i
--
)
{
if
(
output_lens
[
i
+
offset
]
==
s0
.
lens
()[
i
])
{
bcast_strides
[
i
+
offset
]
=
s0
.
strides
()[
i
];
}
}
return
bcast_strides
;
};
if
(
inputs
.
size
()
==
1
)
if
(
inputs
.
size
()
==
1
)
{
{
if
(
input_shape
.
lens
().
empty
())
if
(
s0
.
lens
().
size
()
>
output_lens
.
size
())
{
MIGRAPHX_THROW
(
"MULTIBROADCAST: inputs dimensions should be > 0"
);
}
if
(
input_shape
.
lens
().
size
()
>
output_lens
.
size
())
{
{
MIGRAPHX_THROW
(
"MULTIBROADCAST: inputs dimensions should <= output size"
);
MIGRAPHX_THROW
(
"MULTIBROADCAST: inputs dimensions should <= output size"
);
}
}
auto
offset
=
output_lens
.
size
()
-
input_shape
.
lens
().
size
();
auto
offset
=
output_lens
.
size
()
-
s0
.
lens
().
size
();
for
(
std
::
ptrdiff_t
i
=
input_shape
.
lens
().
size
()
-
1
;
i
>=
0
;
i
--
)
for
(
std
::
ptrdiff_t
i
=
s0
.
lens
().
size
()
-
1
;
i
>=
0
;
i
--
)
{
{
if
(
output_lens
[
i
+
offset
]
!=
input_shape
.
lens
()[
i
]
and
input_shape
.
lens
()[
i
]
!=
1
)
if
(
output_lens
[
i
+
offset
]
!=
s0
.
lens
()[
i
]
and
s0
.
lens
()[
i
]
!=
1
)
{
{
MIGRAPHX_THROW
(
MIGRAPHX_THROW
(
"MULTIBROADCAST: input shape {"
+
to_string_range
(
input_shape
.
lens
())
+
"MULTIBROADCAST: input shape {"
+
to_string_range
(
s0
.
lens
())
+
"} cannot be broadcasted to {"
+
to_string_range
(
output_lens
)
+
"}!"
);
"} cannot be broadcasted to {"
+
to_string_range
(
output_lens
)
+
"}!"
);
}
}
}
}
std
::
vector
<
size_t
>
bcast_strides
(
output_lens
.
size
(),
0
);
auto
bcast_strides
=
make_bcast_strides
(
output_lens
.
size
(),
offset
);
for
(
std
::
ptrdiff_t
i
=
input_shape
.
lens
().
size
()
-
1
;
i
>=
0
;
i
--
)
return
{
t
,
output_lens
,
std
::
move
(
bcast_strides
)};
{
if
(
output_lens
[
i
+
offset
]
==
input_shape
.
lens
()[
i
])
{
bcast_strides
[
i
+
offset
]
=
input_shape
.
strides
()[
i
];
}
}
return
{
t
,
output_lens
,
bcast_strides
};
}
}
else
else
{
{
...
@@ -91,16 +99,37 @@ struct multibroadcast
...
@@ -91,16 +99,37 @@ struct multibroadcast
// shapes can be dynamic (at compile-time) or static (at evaluation time)
// shapes can be dynamic (at compile-time) or static (at evaluation time)
// this function will be called through compute_output_shape conversion to dyn_output
// this function will be called through compute_output_shape conversion to dyn_output
// new compute_broadcasted_lens for dynamic shapes
// new compute_broadcasted_lens for dynamic shapes
auto
other_shape
=
inputs
.
at
(
1
);
// do we want this to work in both broadcast directions?
if
(
input_shape
.
dynamic
()
and
other_shape
.
dynamic
())
{}
// s0 and s1 as shape inputs
else
if
(
not
input_shape
.
dynamic
()
and
not
other_shape
.
dynamic
())
// always s0 -> s1 shape or allow s0 to retain the same shape?
// presuming that it's always s0 -> s1 shape, since that's closer to the current behavior
// compute_broadcasted_lens() will swap the shapes if s1.size() < s0.size(), may need to make another function
auto
s1
=
inputs
.
at
(
1
);
if
(
s0
.
dynamic
()
and
s1
.
dynamic
())
{
auto
bcast_max_lens
=
compute_broadcasted_lens
(
s0
.
max_lens
(),
s1
.
max_lens
());
auto
bcast_min_lens
=
compute_broadcasted_lens
(
s0
.
min_lens
(),
s1
.
min_lens
());
auto
bcast_opt_lens
=
compute_broadcasted_lens
(
s0
.
opt_lens
(),
s1
.
opt_lens
());
std
::
vector
<
shape
::
dynamic_dimension
>
output_dyn_dims
=
{};
for
(
size_t
i
=
0
;
i
<
bcast_max_lens
.
size
();
++
i
)
{
output_dyn_dims
.
push_back
(
shape
::
dynamic_dimension
{
min_spatial_dims
[
i
],
max_spatial_dims
[
i
],
opt_spatial_dims
[
i
]});
}
return
{
t
,
std
::
move
(
output_dyn_dims
)};
}
else
if
(
not
s0
.
dynamic
()
and
not
s1
.
dynamic
())
{
{
auto
output_lens
=
compute_broadcasted_lens
(
input_shape
.
lens
(),
other_shape
.
lens
());
auto
bcast_lens
=
compute_broadcasted_lens
(
s0
.
lens
(),
s1
.
lens
());
auto
offset
=
s1
.
lens
().
size
()
-
s0
.
lens
().
size
();
auto
bcast_strides
=
make_bcast_strides
(
s1
.
lens
().
size
(),
offset
);
return
{
t
,
std
::
move
(
bcast_lens
),
std
::
move
(
bcast_strides
)};
}
}
else
else
{
{
MIGRAPHX_THROW
(
MIGRAPHX_THROW
(
"MULTIBROADCAST:
input_shape and other_shape
are not both dynamic or static"
);
"MULTIBROADCAST:
s0 and s1
are not both dynamic or static"
);
}
}
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment