Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
47a07c3a
Commit
47a07c3a
authored
Nov 27, 2023
by
charlie
Browse files
add dynamic_dimension.within_range()
parent
0ef0d0bb
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
43 additions
and
25 deletions
+43
-25
src/common.cpp
src/common.cpp
+2
-6
src/include/migraphx/op/dot.hpp
src/include/migraphx/op/dot.hpp
+36
-19
src/include/migraphx/shape.hpp
src/include/migraphx/shape.hpp
+5
-0
No files found.
src/common.cpp
View file @
47a07c3a
...
...
@@ -61,10 +61,6 @@ compute_broadcasted_dyn_dims(std::vector<shape::dynamic_dimension> dds0,
}
auto
offset
=
dds1
.
size
()
-
dds0
.
size
();
std
::
vector
<
shape
::
dynamic_dimension
>
out_dims
(
dds1
);
// If one within the range of the other
auto
dd_within_range
=
[
&
](
shape
::
dynamic_dimension
x
,
shape
::
dynamic_dimension
y
)
{
return
(
x
.
min
>=
y
.
min
and
x
.
max
<=
y
.
max
);
};
std
::
transform
(
dds0
.
cbegin
(),
dds0
.
cend
(),
dds1
.
cbegin
()
+
offset
,
...
...
@@ -78,11 +74,11 @@ compute_broadcasted_dyn_dims(std::vector<shape::dynamic_dimension> dds0,
{
return
b
;
}
else
if
(
dd_
within_range
(
a
,
b
))
else
if
(
a
.
within_range
(
b
))
{
return
a
;
}
else
if
(
dd_
within_range
(
b
,
a
))
else
if
(
b
.
within_range
(
a
))
{
return
b
;
}
...
...
src/include/migraphx/op/dot.hpp
View file @
47a07c3a
...
...
@@ -53,38 +53,55 @@ struct dot
}
if
(
a
.
dynamic
()
or
b
.
dynamic
())
{
auto
dd_within_range
=
[
&
](
shape
::
dynamic_dimension
x
,
shape
::
dynamic_dimension
y
)
{
return
(
x
.
min
>=
y
.
min
and
x
.
max
<=
y
.
max
);
};
auto
s0
=
a
.
to_dynamic
();
auto
s1
=
b
.
to_dynamic
();
if
(
not
std
::
equal
(
s0
.
dyn_dims
().
rbegin
()
+
2
,
s0
.
dyn_dims
().
rend
(),
s1
.
dyn_dims
().
rbegin
()
+
2
,
s1
.
dyn_dims
().
rend
(),
[
&
](
auto
x
,
auto
y
)
{
return
(
dd_within_range
(
x
,
y
)
or
dd_within_range
(
y
,
x
));
}))
std
::
vector
<
shape
::
dynamic_dimension
>
out_dyn_dims
;
// check outer dimensions are within range
// put within range dynamic_dimensions into the out_dyn_dims
bool
outers_within_range
=
std
::
equal
(
s0
.
dyn_dims
().
rbegin
()
+
2
,
s0
.
dyn_dims
().
rend
(),
s1
.
dyn_dims
().
rbegin
()
+
2
,
s1
.
dyn_dims
().
rend
(),
[
&
](
auto
x
,
auto
y
)
{
if
(
x
.
within_range
(
y
))
{
out_dyn_dims
.
push_back
(
x
);
return
true
;
}
else
if
(
y
.
within_range
(
x
))
{
out_dyn_dims
.
push_back
(
y
);
return
true
;
}
else
{
return
false
;
}
});
if
(
not
outers_within_range
)
{
MIGRAPHX_THROW
(
"DOT: dynamic outer dimensions of A and B mismatch or not within "
"dynamic_dimension range: {"
+
to_string_range
(
s0
.
dyn_dims
())
+
"} x {"
+
to_string_range
(
s1
.
dyn_dims
())
+
"}"
);
}
std
::
size_t
dim_0
=
s0
.
ndim
()
-
2
;
std
::
size_t
dim_1
=
s0
.
ndim
()
-
1
;
auto
x
=
s0
.
dyn_dims
()[
dim_1
];
auto
y
=
s1
.
dyn_dims
()[
dim_0
];
if
(
not
dd_within_range
(
x
,
y
)
and
not
dd_within_range
(
y
,
x
))
std
::
size_t
dim_i
=
s0
.
ndim
()
-
2
;
std
::
size_t
dim_j
=
s0
.
ndim
()
-
1
;
auto
x
=
s0
.
dyn_dims
()[
dim_i
];
auto
y
=
s1
.
dyn_dims
()[
dim_j
];
// check inner dimensions are within range
if
(
not
x
.
within_range
(
y
)
and
not
y
.
within_range
(
x
))
{
MIGRAPHX_THROW
(
"DOT: dynamic inner dimensions do not match: {"
+
to_string_range
(
s0
.
dyn_dims
())
+
"} x {"
+
to_string_range
(
s1
.
dyn_dims
())
+
"}"
);
}
// NOTE could make this compute_shape more precise by using outer dimensions of the
// shape that's dd_within_range. currently this just uses the outer dimensions of s0.
auto
out_dyn_dims
=
s0
.
dyn_dims
();
out_dyn_dims
[
dim_1
]
=
s1
.
dyn_dims
()[
dim_1
];
out_dyn_dims
.
push_back
(
s0
.
dyn_dims
()[
dim_i
]);
out_dyn_dims
.
push_back
(
s1
.
dyn_dims
()[
dim_j
]);
return
{
t
,
out_dyn_dims
};
}
else
...
...
src/include/migraphx/shape.hpp
View file @
47a07c3a
...
...
@@ -102,6 +102,11 @@ struct MIGRAPHX_EXPORT shape
bool
is_fixed
()
const
;
bool
has_optimal
()
const
;
bool
within_range
(
const
dynamic_dimension
&
other
)
{
return
(
this
->
min
>=
other
.
min
and
this
->
max
<=
other
.
max
);
}
MIGRAPHX_EXPORT
friend
bool
operator
==
(
const
dynamic_dimension
&
x
,
const
dynamic_dimension
&
y
);
MIGRAPHX_EXPORT
friend
bool
operator
!=
(
const
dynamic_dimension
&
x
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment