Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
553a8d02
Commit
553a8d02
authored
Oct 19, 2022
by
charlie
Browse files
refactor again, made a compute broadcast for dyn_dims
parent
c539a7b0
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
105 additions
and
113 deletions
+105
-113
src/common.cpp
src/common.cpp
+46
-40
src/include/migraphx/common.hpp
src/include/migraphx/common.hpp
+2
-0
src/include/migraphx/op/broadcast.hpp
src/include/migraphx/op/broadcast.hpp
+24
-19
src/include/migraphx/op/multibroadcast.hpp
src/include/migraphx/op/multibroadcast.hpp
+6
-54
src/include/migraphx/shape.hpp
src/include/migraphx/shape.hpp
+9
-0
src/shape.cpp
src/shape.cpp
+18
-0
No files found.
src/common.cpp
View file @
553a8d02
...
@@ -68,19 +68,21 @@ std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
...
@@ -68,19 +68,21 @@ std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
std
::
vector
<
shape
::
dynamic_dimension
>
compute_broadcasted_dyn_dims
(
shape
s0
,
shape
s1
)
std
::
vector
<
shape
::
dynamic_dimension
>
compute_broadcasted_dyn_dims
(
shape
s0
,
shape
s1
)
{
{
if
(
s0
.
dynamic
()
or
s1
.
dynamic
())
if
(
not
s0
.
dynamic
()
and
not
s1
.
dynamic
())
{
{
MIGRAPHX_THROW
(
"COMPUTE_BROADCASTED_DYN_DIMS: given two static shapes"
);
}
// change both shapes to dynamic_dimension representation
// change both shapes to dynamic_dimension representation
if
(
not
s0
.
dynamic
())
if
(
not
s0
.
dynamic
())
s0
=
s0
.
to_dynamic
();
s0
=
s0
.
to_dynamic
();
if
(
not
s1
.
dynamic
())
if
(
not
s1
.
dynamic
())
s1
=
s1
.
to_dynamic
();
s1
=
s1
.
to_dynamic
();
if
(
s0
.
rank
()
>
s1
.
rank
())
if
(
s0
.
ndim
()
>
s1
.
ndim
())
{
{
std
::
swap
(
s0
,
s1
);
std
::
swap
(
s0
,
s1
);
}
}
auto
offset
=
s1
.
rank
()
-
s0
.
rank
();
auto
offset
=
s1
.
ndim
()
-
s0
.
ndim
();
std
::
vector
<
shape
::
dynamic_dimension
>
out_dims
(
s1
.
dyn_dims
());
std
::
vector
<
shape
::
dynamic_dimension
>
out_dims
(
s1
.
dyn_dims
());
std
::
vector
<
shape
::
dynamic_dimension
>
one_dyn_dims
{{
1
,
1
,
0
},
{
1
,
1
,
1
}};
std
::
vector
<
shape
::
dynamic_dimension
>
one_dyn_dims
{{
1
,
1
,
0
},
{
1
,
1
,
1
}};
std
::
transform
(
std
::
transform
(
...
@@ -95,8 +97,8 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha
...
@@ -95,8 +97,8 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha
}
}
else
if
(
contains
(
one_dyn_dims
,
a
)
or
contains
(
one_dyn_dims
,
b
))
else
if
(
contains
(
one_dyn_dims
,
a
)
or
contains
(
one_dyn_dims
,
b
))
{
{
return
shape
::
dynamic_dimension
{
// setting opt to 0, may need to be changed
std
::
max
(
a
.
min
,
b
.
min
),
std
::
max
(
a
.
max
,
b
.
max
),
std
::
max
(
a
.
opt
,
b
.
opt
)
};
return
shape
::
dynamic_dimension
{
std
::
max
(
a
.
min
,
b
.
min
),
std
::
max
(
a
.
max
,
b
.
max
),
0
};
}
}
else
else
{
{
...
@@ -105,11 +107,7 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha
...
@@ -105,11 +107,7 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha
migraphx
::
to_string_range
(
s1
.
dyn_dims
())
+
"} mismatch!"
);
migraphx
::
to_string_range
(
s1
.
dyn_dims
())
+
"} mismatch!"
);
}
}
});
});
}
return
out_dims
;
else
{
MIGRAPHX_THROW
(
"COMPUTE_BROADCASTED_DYN_DIMS: given two static shapes"
);
}
}
}
// Compute the common (broadcasted) dimensions of a list of fixed shapes
// Compute the common (broadcasted) dimensions of a list of fixed shapes
...
@@ -183,11 +181,19 @@ instruction_ref insert_common_op(module& m,
...
@@ -183,11 +181,19 @@ instruction_ref insert_common_op(module& m,
// multibroadcast?
// multibroadcast?
if
(
inputs
[
0
]
->
get_shape
().
dyn_dims
()
!=
c_dyn_dims
)
if
(
inputs
[
0
]
->
get_shape
().
dyn_dims
()
!=
c_dyn_dims
)
{
{
inputs
[
0
]
=
m
.
insert_instruction
(
ins
,
make_op
(
"multibroadcast"
),
inputs
[
0
],
inputs
[
1
]);
inputs
[
0
]
=
m
.
insert_instruction
(
ins
,
make_op
(
"multibroadcast"
,
{{
"out_dyn_dims"
,
c_dyn_dims
}}),
inputs
[
0
],
inputs
[
1
]);
}
}
if
(
inputs
[
1
]
->
get_shape
().
dyn_dims
()
!=
c_dyn_dims
)
if
(
inputs
[
1
]
->
get_shape
().
dyn_dims
()
!=
c_dyn_dims
)
{
{
inputs
[
1
]
=
m
.
insert_instruction
(
ins
,
make_op
(
"multibroadcast"
),
inputs
[
1
],
inputs
[
0
]);
inputs
[
1
]
=
m
.
insert_instruction
(
ins
,
make_op
(
"multibroadcast"
,
{{
"out_dyn_dims"
,
c_dyn_dims
}}),
inputs
[
1
],
inputs
[
0
]);
}
}
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
inputs
.
begin
(),
[
&
](
auto
input
)
{
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
inputs
.
begin
(),
[
&
](
auto
input
)
{
if
(
input
->
get_shape
().
type
()
!=
c_type
)
if
(
input
->
get_shape
().
type
()
!=
c_type
)
...
...
src/include/migraphx/common.hpp
View file @
553a8d02
...
@@ -37,6 +37,8 @@ struct operation;
...
@@ -37,6 +37,8 @@ struct operation;
std
::
vector
<
std
::
size_t
>
compute_broadcasted_lens
(
std
::
vector
<
std
::
size_t
>
s0
,
std
::
vector
<
std
::
size_t
>
compute_broadcasted_lens
(
std
::
vector
<
std
::
size_t
>
s0
,
std
::
vector
<
std
::
size_t
>
s1
);
std
::
vector
<
std
::
size_t
>
s1
);
std
::
vector
<
shape
::
dynamic_dimension
>
compute_broadcasted_dyn_dims
(
shape
s0
,
shape
s1
);
shape
common_shape
(
const
std
::
vector
<
shape
>&
shapes
);
shape
common_shape
(
const
std
::
vector
<
shape
>&
shapes
);
instruction_ref
insert_common_op
(
module
&
m
,
instruction_ref
insert_common_op
(
module
&
m
,
...
...
src/include/migraphx/op/broadcast.hpp
View file @
553a8d02
...
@@ -58,7 +58,8 @@ struct broadcast
...
@@ -58,7 +58,8 @@ struct broadcast
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
,
2
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
,
2
);
auto
s0
=
inputs
.
at
(
0
);
auto
s0
=
inputs
.
at
(
0
);
auto
t
=
s0
.
type
();
auto
t
=
s0
.
type
();
if
(
inputs
.
size
()
==
1
)
{
std
::
vector
<
size_t
>
bcast_strides
(
broadcast_lens
.
size
(),
0
);
std
::
vector
<
size_t
>
bcast_strides
(
broadcast_lens
.
size
(),
0
);
// the broadcast op is deprecated now, so not handling the negative
// the broadcast op is deprecated now, so not handling the negative
// value of axis anymore
// value of axis anymore
...
@@ -83,6 +84,10 @@ struct broadcast
...
@@ -83,6 +84,10 @@ struct broadcast
MIGRAPHX_THROW
(
"BROADCAST: output size must be greater than or equal to s0 size"
);
MIGRAPHX_THROW
(
"BROADCAST: output size must be greater than or equal to s0 size"
);
return
output
;
return
output
;
}
}
else
{
}
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
...
...
src/include/migraphx/op/multibroadcast.hpp
View file @
553a8d02
...
@@ -43,11 +43,12 @@ namespace op {
...
@@ -43,11 +43,12 @@ namespace op {
struct
multibroadcast
struct
multibroadcast
{
{
std
::
vector
<
std
::
size_t
>
output_lens
;
std
::
vector
<
std
::
size_t
>
output_lens
;
std
::
vector
<
shape
::
dynamic_dimension
>
output_dyn_dims
;
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
{
{
return
pack
(
f
(
self
.
output_lens
,
"out_lens"
));
return
pack
(
f
(
self
.
output_lens
,
"out_lens"
)
,
f
(
self
.
output_dyn_dims
,
"out_dyn_dims"
)
);
}
}
std
::
string
name
()
const
{
return
"multibroadcast"
;
}
std
::
string
name
()
const
{
return
"multibroadcast"
;
}
...
@@ -101,62 +102,13 @@ struct multibroadcast
...
@@ -101,62 +102,13 @@ struct multibroadcast
{
{
// two inputs
// two inputs
auto
s1
=
inputs
.
at
(
1
);
auto
s1
=
inputs
.
at
(
1
);
if
(
s0
.
dynamic
()
and
s1
.
dynamic
())
if
(
s0
.
dynamic
()
or
s1
.
dynamic
())
{
{
// TODO handle both dynamic case
if
(
not
output_dyn_dims
.
empty
())
MIGRAPHX_THROW
(
"MULTIBROADCAST_2IN: not handled; two dynamic shape inputs not handled"
);
}
else
if
(
s0
.
dynamic
()
or
s1
.
dynamic
())
{
// only handles the case when broadcasting static shape to dynamic shape
// all the dimensions in the static shape must match to a fixed dimension in the
// dynamic shape or be 1
// TODO: handling the other possibilities
if
(
s1
.
dynamic
())
{
std
::
swap
(
s0
,
s1
);
}
auto
static_rank
=
s1
.
lens
().
size
();
auto
dyn_rank
=
s0
.
max_lens
().
size
();
if
(
static_rank
>
dyn_rank
)
{
MIGRAPHX_THROW
(
"MULTIBROADCAST_2IN: not handled; static shape has a higher "
"rank than dynamic shape"
);
}
return
s0
;
auto
offset
=
dyn_rank
-
static_rank
;
std
::
vector
<
shape
::
dynamic_dimension
>
out_dims
(
s0
.
dyn_dims
());
std
::
transform
(
s0
.
dyn_dims
().
begin
(),
s0
.
dyn_dims
().
end
(),
s1
.
lens
().
begin
()
+
offset
,
out_lens
.
begin
()
+
offset
,
[
&
](
auto
a
,
auto
b
)
{
if
(
a
==
b
)
{
{
return
a
;
return
{
t
,
output_dyn_dims
}
;
}
}
else
if
((
a
==
1
or
b
==
1
)
and
a
!=
0
and
b
!=
0
)
return
{
t
,
compute_broadcasted_dyn_dims
(
s0
,
s1
)};
{
return
std
::
max
(
a
,
b
);
}
else
{
// if not matching nor 1, set to 0
return
static_cast
<
std
::
size_t
>
(
0
);
}
});
/*
auto bcast_min_lens = compute_broadcasted_lens(s0.min_lens(), s1.min_lens());
auto bcast_max_lens = compute_broadcasted_lens(s0.max_lens(), s1.max_lens());
auto bcast_opt_lens = compute_broadcasted_opt_lens(s0.opt_lens(), s1.opt_lens());
return {t,
std::move(bcast_min_lens),
std::move(bcast_max_lens),
std::move(bcast_opt_lens)};
*/
}
}
else
else
{
{
...
...
src/include/migraphx/shape.hpp
View file @
553a8d02
...
@@ -142,6 +142,12 @@ struct shape
...
@@ -142,6 +142,12 @@ struct shape
const
std
::
vector
<
std
::
size_t
>&
lens
()
const
;
const
std
::
vector
<
std
::
size_t
>&
lens
()
const
;
const
std
::
vector
<
std
::
size_t
>&
strides
()
const
;
const
std
::
vector
<
std
::
size_t
>&
strides
()
const
;
/*!
* The number of dimensions in the shape.
* Same as the number of indices required to get a data value.
*/
std
::
size_t
ndim
()
const
;
/*!
/*!
* Return the number of elements in the tensor.
* Return the number of elements in the tensor.
*/
*/
...
@@ -227,6 +233,9 @@ struct shape
...
@@ -227,6 +233,9 @@ struct shape
shape
with_type
(
type_t
t
)
const
;
shape
with_type
(
type_t
t
)
const
;
// convert the shape to an equivalent dynamic shape
shape
to_dynamic
()
const
;
friend
bool
operator
==
(
const
shape
&
x
,
const
shape
&
y
);
friend
bool
operator
==
(
const
shape
&
x
,
const
shape
&
y
);
friend
bool
operator
!=
(
const
shape
&
x
,
const
shape
&
y
);
friend
bool
operator
!=
(
const
shape
&
x
,
const
shape
&
y
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
shape
&
x
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
shape
&
x
);
...
...
src/shape.cpp
View file @
553a8d02
...
@@ -265,6 +265,15 @@ const std::vector<std::size_t>& shape::lens() const { return impl->m_lens; }
...
@@ -265,6 +265,15 @@ const std::vector<std::size_t>& shape::lens() const { return impl->m_lens; }
const
std
::
vector
<
std
::
size_t
>&
shape
::
strides
()
const
{
return
impl
->
m_strides
;
}
const
std
::
vector
<
std
::
size_t
>&
shape
::
strides
()
const
{
return
impl
->
m_strides
;
}
std
::
size_t
shape
::
ndim
()
const
{
if
(
this
->
dynamic
())
{
return
dyn_dims
().
size
();
}
return
lens
().
size
();
}
std
::
size_t
shape
::
elements
()
const
{
return
impl
->
elements
();
}
std
::
size_t
shape
::
elements
()
const
{
return
impl
->
elements
();
}
std
::
size_t
shape
::
bytes
()
const
std
::
size_t
shape
::
bytes
()
const
...
@@ -458,6 +467,15 @@ shape shape::with_type(type_t t) const
...
@@ -458,6 +467,15 @@ shape shape::with_type(type_t t) const
return
{
c
};
return
{
c
};
}
}
shape
shape
::
to_dynamic
()
const
{
if
(
this
->
dynamic
())
{
return
*
this
;
}
return
{
type
(),
lens
(),
lens
(),
lens
()};
}
std
::
size_t
shape
::
element_space
()
const
{
return
impl
->
element_space
();
}
std
::
size_t
shape
::
element_space
()
const
{
return
impl
->
element_space
();
}
std
::
string
shape
::
type_string
()
const
{
return
name
(
this
->
type
());
}
std
::
string
shape
::
type_string
()
const
{
return
name
(
this
->
type
());
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment