Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
9280150b
Commit
9280150b
authored
Nov 17, 2023
by
charlie
Browse files
scratch work to get bert_uncased working with dynamic batch
parent
c84b8195
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
148 additions
and
20 deletions
+148
-20
src/CMakeLists.txt
src/CMakeLists.txt
+1
-0
src/common.cpp
src/common.cpp
+32
-14
src/include/migraphx/common.hpp
src/include/migraphx/common.hpp
+5
-0
src/include/migraphx/op/dot.hpp
src/include/migraphx/op/dot.hpp
+6
-1
src/include/migraphx/op/dot_broadcast.hpp
src/include/migraphx/op/dot_broadcast.hpp
+89
-0
src/include/migraphx/op/reshape.hpp
src/include/migraphx/op/reshape.hpp
+7
-1
src/onnx/parse_matmul.cpp
src/onnx/parse_matmul.cpp
+8
-4
No files found.
src/CMakeLists.txt
View file @
9280150b
...
...
@@ -139,6 +139,7 @@ register_migraphx_ops(
dimensions_of
div
dot
dot_broadcast
elu
equal
erf
...
...
src/common.cpp
View file @
9280150b
...
...
@@ -51,21 +51,23 @@ std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
});
return
out_lens
;
}
std
::
vector
<
shape
::
dynamic_dimension
>
compute_broadcasted_dyn_dims
(
shape
s0
,
shape
s1
)
std
::
vector
<
shape
::
dynamic_dimension
>
compute_broadcasted_dyn_dims
(
std
::
vector
<
shape
::
dynamic_dimension
>
dds0
,
std
::
vector
<
shape
::
dynamic_dimension
>
dds1
)
{
// change both shapes to dynamic_dimension representation
s0
=
s0
.
to_dynamic
();
s1
=
s1
.
to_dynamic
();
if
(
s0
.
ndim
()
>
s1
.
ndim
())
if
(
dds0
.
size
()
>
dds1
.
size
())
{
std
::
swap
(
s0
,
s1
);
std
::
swap
(
dd
s0
,
dd
s1
);
}
auto
offset
=
s1
.
ndim
()
-
s0
.
ndim
();
std
::
vector
<
shape
::
dynamic_dimension
>
out_dims
(
s1
.
dyn_dims
());
std
::
transform
(
s0
.
dyn_dims
().
cbegin
(),
s0
.
dyn_dims
().
cend
(),
s1
.
dyn_dims
().
cbegin
()
+
offset
,
auto
offset
=
dds1
.
size
()
-
dds0
.
size
();
std
::
vector
<
shape
::
dynamic_dimension
>
out_dims
(
dds1
);
// If one within the range of the other
auto
dd_within_range
=
[
&
](
shape
::
dynamic_dimension
x
,
shape
::
dynamic_dimension
y
)
{
return
(
x
.
min
>=
y
.
min
and
x
.
max
<=
y
.
max
);
};
std
::
transform
(
dds0
.
cbegin
(),
dds0
.
cend
(),
dds1
.
cbegin
()
+
offset
,
out_dims
.
begin
()
+
offset
,
[
&
](
auto
a
,
auto
b
)
{
if
(
a
==
b
or
b
==
1
)
...
...
@@ -76,16 +78,32 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha
{
return
b
;
}
else
if
(
dd_within_range
(
a
,
b
))
{
return
a
;
}
else
if
(
dd_within_range
(
b
,
a
))
{
return
b
;
}
else
{
MIGRAPHX_THROW
(
"COMPUTE_BROADCASTED_DYN_DIMS: dynamic shapes {"
+
migraphx
::
to_string_range
(
s0
.
dyn_dims
()
)
+
"} and {"
+
migraphx
::
to_string_range
(
s1
.
dyn_dims
()
)
+
"} mismatch!"
);
migraphx
::
to_string_range
(
dd
s0
)
+
"} and {"
+
migraphx
::
to_string_range
(
dd
s1
)
+
"} mismatch!"
);
}
});
return
out_dims
;
}
std
::
vector
<
shape
::
dynamic_dimension
>
compute_broadcasted_dyn_dims
(
shape
s0
,
shape
s1
)
{
// change both shapes to dynamic_dimension representation
s0
=
s0
.
to_dynamic
();
s1
=
s1
.
to_dynamic
();
return
compute_broadcasted_dyn_dims
(
s0
.
dyn_dims
(),
s1
.
dyn_dims
());
}
std
::
vector
<
shape
::
dynamic_dimension
>
compute_common_dyn_dims
(
const
std
::
vector
<
shape
>&
shapes
)
{
auto
ret_shape
=
shapes
.
at
(
0
);
...
...
src/include/migraphx/common.hpp
View file @
9280150b
...
...
@@ -58,6 +58,11 @@ MIGRAPHX_EXPORT
std
::
vector
<
std
::
size_t
>
compute_broadcasted_lens
(
std
::
vector
<
std
::
size_t
>
s0
,
std
::
vector
<
std
::
size_t
>
s1
);
MIGRAPHX_EXPORT
std
::
vector
<
shape
::
dynamic_dimension
>
compute_broadcasted_dyn_dims
(
std
::
vector
<
shape
::
dynamic_dimension
>
dds0
,
std
::
vector
<
shape
::
dynamic_dimension
>
dds1
);
MIGRAPHX_EXPORT
std
::
vector
<
shape
::
dynamic_dimension
>
compute_broadcasted_dyn_dims
(
shape
s0
,
shape
s1
);
...
...
src/include/migraphx/op/dot.hpp
View file @
9280150b
...
...
@@ -63,7 +63,12 @@ struct dot
}
std
::
size_t
dim_0
=
s0
.
ndim
()
-
2
;
std
::
size_t
dim_1
=
s0
.
ndim
()
-
1
;
if
(
s0
.
dyn_dims
()[
dim_1
]
!=
s1
.
dyn_dims
()[
dim_0
])
auto
dd_within_range
=
[
&
](
shape
::
dynamic_dimension
x
,
shape
::
dynamic_dimension
y
)
{
return
(
x
.
min
>=
y
.
min
and
x
.
max
<=
y
.
max
);
};
auto
x
=
s0
.
dyn_dims
()[
dim_1
];
auto
y
=
s1
.
dyn_dims
()[
dim_0
];
if
(
not
dd_within_range
(
x
,
y
)
and
not
dd_within_range
(
y
,
x
))
{
MIGRAPHX_THROW
(
"DOT: dynamic inner dimensions do not match: {"
+
to_string_range
(
s0
.
dyn_dims
())
+
"} x {"
+
...
...
src/include/migraphx/op/dot_broadcast.hpp
0 → 100644
View file @
9280150b
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_DOT_BROADCAST_HPP
#define MIGRAPHX_GUARD_OPERATORS_DOT_BROADCAST_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp>
#include <migraphx/common.hpp>
#include <migraphx/dyn_output.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
/**
* Broadcast dimensions between two tensors for the `dot` operator.
* Essentially broadcasts between two shapes for dimensions other than the last two.
* This operator is only needed if one of the shapes are dynamic.
* Example:
* a = shape[{1, 4}, 3, 248, 248]
* b = shape[248, 365]
* dot_broadcast(a, b) => shape[{1, 4}, 3, 248, 248] (no change)
* dot_broadcast(b, a) => shape[{1, 4}, 3, 248, 365]
*/
struct
dot_broadcast
{
std
::
string
name
()
const
{
return
"dot_broadcast"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
2
);
auto
s0
=
inputs
.
at
(
0
);
auto
s1
=
inputs
.
at
(
1
);
if
(
s0
.
dynamic
()
or
s1
.
dynamic
())
{
s0
=
s0
.
to_dynamic
();
s1
=
s1
.
to_dynamic
();
auto
dds0_it
=
s0
.
dyn_dims
().
end
()
-
2
;
auto
dds1_it
=
s1
.
dyn_dims
().
end
()
-
2
;
std
::
vector
<
shape
::
dynamic_dimension
>
sliced_dds0
{
s0
.
dyn_dims
().
begin
(),
dds0_it
};
std
::
vector
<
shape
::
dynamic_dimension
>
sliced_dds1
{
s1
.
dyn_dims
().
begin
(),
dds1_it
};
auto
output_dyn_dims
=
compute_broadcasted_dyn_dims
(
sliced_dds0
,
sliced_dds1
);
output_dyn_dims
.
insert
(
output_dyn_dims
.
end
(),
dds0_it
,
s0
.
dyn_dims
().
end
());
return
{
s0
.
type
(),
output_dyn_dims
};
}
else
{
auto
l0_it
=
s0
.
lens
().
begin
()
+
s0
.
ndim
()
-
2
;
std
::
vector
<
std
::
size_t
>
l0_broadcasted_lens
(
s0
.
lens
().
begin
(),
l0_it
);
auto
l1_it
=
s1
.
lens
().
begin
()
+
s1
.
ndim
()
-
2
;
std
::
vector
<
std
::
size_t
>
l1_broadcasted_lens
(
s1
.
lens
().
begin
(),
l1_it
);
auto
output_lens
=
compute_broadcasted_lens
(
l0_broadcasted_lens
,
l1_broadcasted_lens
);
output_lens
.
insert
(
output_lens
.
end
(),
l0_it
,
s0
.
lens
().
end
());
return
{
s0
.
type
(),
output_lens
};
}
}
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
return
args
[
0
].
reshape
(
dyn_out
.
computed_shape
);
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/reshape.hpp
View file @
9280150b
...
...
@@ -69,7 +69,7 @@ struct reshape
auto
dyn_dims
=
s0
.
dyn_dims
();
auto
num_not_fixed
=
std
::
count_if
(
dyn_dims
.
cbegin
(),
dyn_dims
.
cend
(),
[](
auto
dd
)
{
return
not
dd
.
is_fixed
();
});
if
(
num_not_fixed
!
=
1
)
if
(
num_not_fixed
=
=
1
)
{
MIGRAPHX_THROW
(
"Reshape: Only supports one non-fixed dynamic_dimension"
);
}
...
...
@@ -110,6 +110,12 @@ struct reshape
return
shape
::
dynamic_dimension
{
dim
,
dim
};
});
return
{
s0
.
type
(),
output_dyn_dims
};
/*
std::size_t max_val = std::numeric_limits<std::size_t>::max();
std::vector<shape::dynamic_dimension> dds(dims.size(),
shape::dynamic_dimension{0, max_val});
return {s0.type(), dds};
*/
}
template
<
class
Iterator
>
...
...
src/onnx/parse_matmul.cpp
View file @
9280150b
...
...
@@ -71,14 +71,18 @@ struct parse_matmul : op_parser<parse_matmul>
auto
s0_dds
=
a0
->
get_shape
().
to_dynamic
().
dyn_dims
();
auto
s1_dds
=
a1
->
get_shape
().
to_dynamic
().
dyn_dims
();
// TODO: handling this case requires a new multibroadcast mode
if
(
not
std
::
equal
(
s0_dds
.
rbegin
()
+
2
,
s0_dds
.
rend
(),
s1_dds
.
rbegin
()
+
2
,
s1_dds
.
rend
()))
{
MIGRAPHX_THROW
(
"PARSE_MATMUL: dynamic shape broadcasting not supported"
);
auto
broadcasted_a0
=
info
.
add_instruction
(
make_op
(
"dot_broadcast"
),
a0
,
a1
);
auto
broadcasted_a1
=
info
.
add_instruction
(
make_op
(
"dot_broadcast"
),
a1
,
a0
);
dot_res
=
info
.
add_instruction
(
make_op
(
opd
.
op_name
),
broadcasted_a0
,
broadcasted_a1
);
}
else
{
dot_res
=
info
.
add_instruction
(
make_op
(
opd
.
op_name
),
a0
,
a1
);
}
dot_res
=
info
.
add_instruction
(
make_op
(
opd
.
op_name
),
a0
,
a1
);
}
else
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment