gaoqiong / MIGraphX · Commits · 9b929d4e

Commit 9b929d4e, authored Dec 29, 2022 by charlie
Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_model_test
Parents: c4b1102e, 4394e9b3
Changes: 111 · Showing 20 changed files with 695 additions and 165 deletions (+695 −165)
src/include/migraphx/op/argmax.hpp         +19 −11
src/include/migraphx/op/broadcast.hpp      +2 −2
src/include/migraphx/op/contiguous.hpp     +18 −9
src/include/migraphx/op/dot.hpp            +51 −22
src/include/migraphx/op/flatten.hpp        +39 −9
src/include/migraphx/op/layout.hpp         +68 −0
src/include/migraphx/op/pooling.hpp        +115 −36
src/include/migraphx/op/softmax.hpp        +5 −5
src/include/migraphx/op/squeeze.hpp        +5 −10
src/include/migraphx/op/transpose.hpp      +31 −15
src/include/migraphx/program.hpp           +1 −0
src/include/migraphx/shape.hpp             +6 −0
src/include/migraphx/shape_for_each.hpp    +3 −1
src/insert_pad.cpp                         +2 −2
src/instruction.cpp                        +18 −0
src/layout_nhwc.cpp                        +118 −0
src/module.cpp                             +88 −0
src/onnx/onnx_parser.cpp                   +18 −5
src/onnx/parse_binary_op.cpp               +6 −0
src/onnx/parse_pooling.cpp                 +82 −38
src/include/migraphx/op/argmax.hpp — view file @ 9b929d4e

@@ -30,6 +30,7 @@
 #include <migraphx/config.hpp>
 #include <migraphx/value.hpp>
 #include <migraphx/op/normalize_attribute.hpp>
+#include <migraphx/dyn_output.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

@@ -56,13 +57,21 @@ struct argmax
     shape normalize_compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(1);
-        auto lens  = inputs[0].lens();
-        lens[axis] = 1;
-        return {shape::int64_type, lens};
+        check_shapes{inputs, *this, true}.has(1);
+        const auto& s0 = inputs[0];
+        if(s0.dynamic())
+        {
+            auto dyn_dims  = s0.dyn_dims();
+            dyn_dims[axis] = {1, 1, 0};
+            return {shape::int64_type, dyn_dims};
+        }
+        else
+        {
+            auto lens  = s0.lens();
+            lens[axis] = 1;
+            return {shape::int64_type, lens};
+        }
     }

     template <class T>
     int64_t calc_argmax(T& input, std::vector<std::size_t>& indices, size_t item_num) const

@@ -79,19 +88,18 @@ struct argmax
                 max_index = i;
             }
         }
         return max_index;
     }

-    argument compute(const shape& output_shape, std::vector<argument> args) const
+    argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
     {
-        argument result{output_shape};
+        argument result{dyn_out.computed_shape};
         auto batch_item_num = args.front().get_shape().lens()[axis];
         result.visit([&](auto output) {
             args[0].visit([&](auto input) {
-                par_for(output_shape.elements(), [&](auto i) {
-                    auto data_idx = output_shape.multi(i);
+                par_for(dyn_out.computed_shape.elements(), [&](auto i) {
+                    auto data_idx = dyn_out.computed_shape.multi(i);
                     output[i] = this->calc_argmax(input, data_idx, batch_item_num);
                 });
             });
         });
...
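The shape logic above simply collapses the reduced axis to length 1 (or to the dynamic range {1, 1, 0}). A minimal standalone C++ sketch of that axis-collapse, using a plain vector of lengths rather than the migraphx::shape class (an assumption made only for illustration):

#include <cstddef>
#include <vector>

// Given the input lengths and the reduction axis, the argmax output keeps
// every dimension except the reduced one, which becomes 1.
std::vector<std::size_t> argmax_output_lens(std::vector<std::size_t> lens, std::size_t axis)
{
    lens[axis] = 1; // indices along `axis` collapse to a single int64 value
    return lens;
}
// e.g. argmax_output_lens({2, 3, 4}, 1) == {2, 1, 4}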
src/include/migraphx/op/broadcast.hpp — view file @ 9b929d4e

@@ -128,8 +128,8 @@ struct broadcast
         {
             MIGRAPHX_THROW("BROADCAST_2in: s0 length doesn't match with static s1 axis "
                            "dimension length (" +
-                           migraphx::to_string(s0.dyn_dims()[0]) +
-                           " != " + migraphx::to_string(s1.dyn_dims()[axis]) + ")");
+                           migraphx::to_string(s0.lens()[0]) +
+                           " != " + migraphx::to_string(s1.lens()[axis]) + ")");
         }
         std::vector<size_t> bcast_strides(s1.ndim(), 0);
         std::copy(s0.strides().begin(), s0.strides().end(), bcast_strides.begin() + axis);
...
src/include/migraphx/op/contiguous.hpp — view file @ 9b929d4e

@@ -28,6 +28,7 @@
 #include <migraphx/argument.hpp>
 #include <migraphx/shape_for_each.hpp>
 #include <migraphx/config.hpp>
+#include <migraphx/dyn_output.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

@@ -42,19 +43,27 @@ namespace op {
 struct contiguous
 {
     std::string name() const { return "contiguous"; }

     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(1);
-        if(inputs.front().standard())
-            return inputs.front();
-        auto lens = inputs.at(0).lens();
-        auto t    = inputs.at(0).type();
-        return {t, lens};
+        check_shapes{inputs, *this, true}.has(1);
+        auto s0 = inputs.front();
+        if(s0.dynamic() or s0.standard())
+        {
+            return s0;
+        }
+        else
+        {
+            const auto& lens = s0.lens();
+            auto t           = s0.type();
+            return {t, lens};
+        }
     }

-    argument compute(const shape& output_shape, std::vector<argument> args) const
+    argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
     {
-        assert(output_shape.standard());
-        argument result{output_shape};
+        assert(dyn_out.computed_shape.standard());
+        argument result{dyn_out.computed_shape};
         visit_all(result, args[0])([&](auto output, auto input) {
             shape_for_each(output.get_shape(), [&](const auto& idx) {
                 output(idx.begin(), idx.end()) = input(idx.begin(), idx.end());
...
src/include/migraphx/op/dot.hpp — view file @ 9b929d4e

@@ -28,6 +28,7 @@
 #include <migraphx/argument.hpp>
 #include <migraphx/config.hpp>
 #include <migraphx/gemm.hpp>
+#include <migraphx/dyn_output.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

@@ -38,41 +39,69 @@ struct dot
     std::string name() const { return "dot"; }

     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.same_type().has(2);
+        check_shapes{inputs, *this, true}.same_type().same_ndims().has(2);
         const shape& a = inputs.at(0);
         const shape& b = inputs.at(1);
         auto t         = a.type();

         if(not std::all_of(
-               inputs.begin(), inputs.end(), [](auto s) { return s.lens().size() >= 2; }))
+               inputs.begin(), inputs.end(), [](auto s) { return s.ndim() >= 2; }))
         {
-            MIGRAPHX_THROW("DOT: dot only accept 2 or more dims operands");
+            MIGRAPHX_THROW("DOT: dot only accepts operands with 2 or more dimensions");
         }

-        // only handle the case that the batch size of a and b are the same
-        if(not std::equal(
-               a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2, b.lens().rend()))
+        if(a.dynamic() or b.dynamic())
         {
-            MIGRAPHX_THROW("DOT: batch size of A and B mismatch: {" + to_string_range(a.lens()) +
-                           "} x {" + to_string_range(b.lens()) + "}");
-        }
-        std::size_t dim_0 = a.lens().size() - 2;
-        std::size_t dim_1 = a.lens().size() - 1;
-        if(a.lens()[dim_1] != b.lens()[dim_0])
-        {
-            MIGRAPHX_THROW("DOT: inner dimensions do not match: {" + to_string_range(a.lens()) +
-                           "} x {" + to_string_range(b.lens()) + "}");
-        }
-        auto out_lens   = a.lens();
-        out_lens[dim_1] = b.lens()[dim_1];
-        return {t, out_lens};
+            auto s0 = a.to_dynamic();
+            auto s1 = b.to_dynamic();
+            if(not std::equal(s0.dyn_dims().rbegin() + 2,
+                              s0.dyn_dims().rend(),
+                              s1.dyn_dims().rbegin() + 2,
+                              s1.dyn_dims().rend()))
+            {
+                MIGRAPHX_THROW("DOT: dynamic outer dimensions of A and B mismatch: {" +
+                               to_string_range(s0.dyn_dims()) + "} x {" +
+                               to_string_range(s1.dyn_dims()) + "}");
+            }
+            std::size_t dim_0 = s0.ndim() - 2;
+            std::size_t dim_1 = s0.ndim() - 1;
+            if(s0.dyn_dims()[dim_1] != s1.dyn_dims()[dim_0])
+            {
+                MIGRAPHX_THROW("DOT: dynamic inner dimensions do not match: {" +
+                               to_string_range(s0.dyn_dims()) + "} x {" +
+                               to_string_range(s1.dyn_dims()) + "}");
+            }
+            auto out_dyn_dims   = s0.dyn_dims();
+            out_dyn_dims[dim_1] = s1.dyn_dims()[dim_1];
+            return {t, out_dyn_dims};
+        }
+        else
+        {
+            // only handle the case that all the dimensions except the last two are the same
+            if(not std::equal(
+                   a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2, b.lens().rend()))
+            {
+                MIGRAPHX_THROW("DOT: static outer dimensions of A and B mismatch: {" +
+                               to_string_range(a.lens()) + "} x {" + to_string_range(b.lens()) +
+                               "}");
+            }
+            std::size_t dim_0 = a.ndim() - 2;
+            std::size_t dim_1 = a.ndim() - 1;
+            if(a.lens()[dim_1] != b.lens()[dim_0])
+            {
+                MIGRAPHX_THROW("DOT: static inner dimensions do not match: {" +
+                               to_string_range(a.lens()) + "} x {" + to_string_range(b.lens()) +
+                               "}");
+            }
+            auto out_lens   = a.lens();
+            out_lens[dim_1] = b.lens()[dim_1];
+            return {t, out_lens};
+        }
     }

-    argument compute(shape output_shape, std::vector<argument> args) const
+    argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
     {
-        argument result = argument{output_shape};
+        argument result = argument{dyn_out.computed_shape};
         visit_all(result, args[0], args[1])(
             [&](auto cmat, auto amat, auto bmat) { gemm(cmat, amat, bmat, 1.0f, 0.0f); });
         return result;
...
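The shape rule above is the usual batched-matmul rule: the leading (outer) dimensions of A and B must match, A's last dimension must equal B's second-to-last, and the output copies A's lengths with the last one replaced by B's last. A small standalone C++ sketch of that rule, using plain vectors rather than migraphx::shape (an assumption for illustration only):

#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Assumes both inputs have at least two dimensions, as the dot op requires.
std::vector<std::size_t> dot_output_lens(const std::vector<std::size_t>& a,
                                         const std::vector<std::size_t>& b)
{
    // outer (batch) dimensions must be identical
    if(not std::equal(a.rbegin() + 2, a.rend(), b.rbegin() + 2, b.rend()))
        throw std::runtime_error("outer dimensions mismatch");
    // inner dimensions must agree: a = [..., m, k], b = [..., k, n]
    if(a[a.size() - 1] != b[b.size() - 2])
        throw std::runtime_error("inner dimensions mismatch");
    auto out            = a;
    out[out.size() - 1] = b[b.size() - 1]; // result is [..., m, n]
    return out;
}
// e.g. dot_output_lens({2, 3, 4}, {2, 4, 5}) == {2, 3, 5}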
src/include/migraphx/op/flatten.hpp — view file @ 9b929d4e

@@ -55,17 +55,47 @@ struct flatten
     std::string name() const { return "flatten"; }

     shape normalize_compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(1).standard();
-        auto&& lens = inputs.front().lens();
-        auto x =
-            std::accumulate(lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<>{});
-        auto y =
-            std::accumulate(lens.begin() + axis, lens.end(), std::size_t{1}, std::multiplies<>{});
-        return {inputs.at(0).type(), {x, y}};
+        check_shapes{inputs, *this, true}.has(1);
+        auto s = inputs[0];
+        if(s.dynamic())
+        {
+            auto min_lens = s.min_lens();
+            auto max_lens = s.max_lens();
+            auto opt_lens = s.opt_lens();
+            // If any of the opt values is 0, output opt will be 0
+            shape::dynamic_dimension x = {
+                std::accumulate(
+                    min_lens.begin(), min_lens.begin() + axis, std::size_t{1}, std::multiplies<>{}),
+                std::accumulate(
+                    max_lens.begin(), max_lens.begin() + axis, std::size_t{1}, std::multiplies<>{}),
+                std::accumulate(
+                    opt_lens.begin(), opt_lens.begin() + axis, std::size_t{1}, std::multiplies<>{})};
+            shape::dynamic_dimension y = {
+                std::accumulate(
+                    min_lens.begin() + axis, min_lens.end(), std::size_t{1}, std::multiplies<>{}),
+                std::accumulate(
+                    max_lens.begin() + axis, max_lens.end(), std::size_t{1}, std::multiplies<>{}),
+                std::accumulate(
+                    opt_lens.begin() + axis, opt_lens.end(), std::size_t{1}, std::multiplies<>{}),
+            };
+            return {s.type(), {x, y}};
+        }
+        else
+        {
+            check_shapes{inputs, *this}.standard();
+            auto&& lens = s.lens();
+            auto x      = std::accumulate(
+                lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<>{});
+            auto y = std::accumulate(
+                lens.begin() + axis, lens.end(), std::size_t{1}, std::multiplies<>{});
+            return {s.type(), {x, y}};
+        }
     }

-    argument compute(shape output_shape, std::vector<argument> args) const
+    argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
     {
-        return args[0].reshape(output_shape);
+        return args[0].reshape(dyn_out.computed_shape);
     }
     std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
 };
...
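Flatten collapses the dimensions before `axis` into one length and the rest into another; for a dynamic shape the same product is taken separately over the min, max, and opt lengths. A standalone sketch of that product-before/product-after split (plain vectors, illustration only):

#include <cstddef>
#include <functional>
#include <numeric>
#include <utility>
#include <vector>

// Collapse lens[0..axis) into x and lens[axis..n) into y, as flatten does.
std::pair<std::size_t, std::size_t> flatten_dims(const std::vector<std::size_t>& lens,
                                                 std::size_t axis)
{
    auto x =
        std::accumulate(lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<>{});
    auto y =
        std::accumulate(lens.begin() + axis, lens.end(), std::size_t{1}, std::multiplies<>{});
    return {x, y};
}
// e.g. flatten_dims({2, 3, 4, 5}, 2) yields {6, 20}; a dynamic shape applies this
// independently to its min, max, and opt lengths.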
src/include/migraphx/op/layout.hpp — new file (0 → 100644) — view file @ 9b929d4e
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OP_LAYOUT_HPP
#define MIGRAPHX_GUARD_OP_LAYOUT_HPP
#include <migraphx/config.hpp>
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/op/unary.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

struct layout : unary<layout>
{
    std::vector<int64_t> permutation;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.permutation, "permutation"));
    }

    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1).only_dims(permutation.size());
        auto lens = inputs.at(0).lens();
        auto t    = inputs.at(0).type();
        return shape::from_permutation(t, lens, permutation);
    }

    auto apply() const
    {
        return [](auto x) { return x; };
    }
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_OP_LAYOUT_HPP
src/include/migraphx/op/pooling.hpp — view file @ 9b929d4e

@@ -31,7 +31,7 @@
 #include <migraphx/argument.hpp>
 #include <migraphx/par_for.hpp>
 #include <migraphx/shape_for_each.hpp>
-#include <migraphx/int_divide.hpp>
+#include <migraphx/dyn_output.hpp>
 #include <cmath>
 #include <utility>

@@ -49,6 +49,9 @@ struct pooling
     bool ceil_mode = false;
     int lp_order   = 2;

+    // Global pooling with dynamic shape input
+    bool dyn_global = false;
+
     template <class Self, class F>
     static auto reflect(Self& self, F f)
     {

@@ -57,7 +60,8 @@ struct pooling
                     f(self.stride, "stride"),
                     f(self.lengths, "lengths"),
                     f(self.ceil_mode, "ceil_mode"),
-                    f(self.lp_order, "lp_order"));
+                    f(self.lp_order, "lp_order"),
+                    f(self.dyn_global, "dyn_global"));
     }

     std::string name() const { return "pooling"; }

@@ -65,51 +69,111 @@ struct pooling
     void check_attribute_size() const
     {
         if((padding.size() != stride.size() and (padding.size() / 2) != stride.size()) or
-           stride.size() != lengths.size())
+           (not dyn_global and stride.size() != lengths.size()))
         {
             MIGRAPHX_THROW("POOLING: inconsistent attribute sizes");
         }
     }

+    size_t kdims() const
+    {
+        check_attribute_size();
+        return stride.size();
+    }
+
     value attributes() const { return {{"normalize_padding", "padding"}}; }

+    std::vector<std::size_t> calc_spatial_dim_out(const std::vector<std::size_t>& input_lens,
+                                                  std::size_t kdims) const
+    {
+        std::vector<std::size_t> output_lens{};
+        for(size_t i = 0; i < kdims; ++i)
+        {
+            if(input_lens[i + 2] == 0)
+            {
+                // handle opt = 0
+                output_lens.push_back(0);
+            }
+            else
+            {
+                std::size_t padding_factor = 2 * padding[i];
+                if(padding.size() == 2 * kdims)
+                    padding_factor = padding[i] + padding[i + kdims];
+                assert(input_lens[i + 2] + padding_factor >= lengths[i]);
+                std::size_t dim_size = input_lens[i + 2] + padding_factor - lengths[i];
+                std::size_t len =
+                    (ceil_mode) ? dim_size / stride[i] +
+                                      static_cast<std::size_t>((dim_size % stride[i] != 0)) // ceil uint divide
+                                : dim_size / stride[i];                                     // floor divide
+                output_lens.push_back(len + 1);
+            }
+        }
+        return output_lens;
+    }
+
     shape normalize_compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(1);
+        check_shapes{inputs, *this, true}.has(1);
         check_attribute_size();
         const shape& input = inputs.at(0);
-        auto input_lens    = input.lens();
-        size_t kdims       = input_lens.size() - 2;
-        auto input_size    = inputs[0].lens().size();
         auto padding_size  = padding.size();
-        if(input_size != padding_size / 2 + 2 and input_size != padding_size + 2)
+        size_t kdims       = input.ndim() - 2;
+        if(input.ndim() != padding_size / 2 + 2 and input.ndim() != padding_size + 2)
         {
             MIGRAPHX_THROW("POOLING: input and attribute size mismatch!");
         }

-        std::vector<std::size_t> output_lens(input_lens.begin(), input_lens.begin() + 2);
-        for(size_t i = 0; i < kdims; i++)
+        if(input.dynamic())
         {
-            std::ptrdiff_t dim_size;
-            auto padding_factor = 2 * padding[i];
-            if(padding_size == 2 * kdims)
-                padding_factor = padding[i] + padding[i + kdims];
-            dim_size = input_lens[i + 2] + padding_factor - lengths[i];
-            assert(dim_size >= 0);
-            std::size_t len = (ceil_mode) ? ceil_divide<std::ptrdiff_t>(dim_size, stride[i])
-                                          : floor_divide<std::ptrdiff_t>(dim_size, stride[i]);
-            output_lens.push_back(std::size_t(std::max<std::ptrdiff_t>(1, len + 1)));
+            auto input_dyn_dims = input.dyn_dims();
+            std::vector<shape::dynamic_dimension> output_dyn_dims(input_dyn_dims.begin(),
+                                                                  input_dyn_dims.begin() + 2);
+            if(dyn_global)
+            {
+                for(size_t i = 0; i < kdims; ++i)
+                {
+                    output_dyn_dims.push_back(shape::dynamic_dimension{1, 1, 1});
+                }
+                return {input.type(), output_dyn_dims};
+            }
+            else
+            {
+                auto min_spatial_dims = calc_spatial_dim_out(input.min_lens(), kdims);
+                auto max_spatial_dims = calc_spatial_dim_out(input.max_lens(), kdims);
+                auto opt_spatial_dims = calc_spatial_dim_out(input.opt_lens(), kdims);
+                for(size_t i = 0; i < kdims; ++i)
+                {
+                    output_dyn_dims.push_back(shape::dynamic_dimension{
+                        min_spatial_dims[i], max_spatial_dims[i], opt_spatial_dims[i]});
+                }
+                return {input.type(), output_dyn_dims};
+            }
         }
-        return inputs[0].with_lens(output_lens);
+        else
+        {
+            auto input_lens = input.lens();
+            std::vector<std::size_t> output_lens(input_lens.begin(), input_lens.begin() + 2);
+            // Used for when normalize_compute_shape() is called again at model eval time
+            // for an originally dynamic shape. Since kernel shape is not used with dyn_global.
+            if(dyn_global)
+            {
+                for(size_t i = 0; i < kdims; ++i)
+                {
+                    output_lens.push_back(1);
+                }
+                return {input.type(), output_lens};
+            }
+            else
+            {
+                auto output_spatial_lens = calc_spatial_dim_out(input_lens, kdims);
+                output_lens.insert(output_lens.end(),
+                                   output_spatial_lens.begin(),
+                                   output_spatial_lens.end());
+                return inputs[0].with_lens(output_lens);
+            }
+        }
     }

-    size_t kdims() const
-    {
-        check_attribute_size();
-        return stride.size();
-    }
-
     struct lpnorm_pool

@@ -158,7 +222,11 @@ struct pooling
     };

     template <class Type, class Out, class In, class Op>
-    void calc_pooling(const shape& output_shape, Out& output, const In& input, Op op) const
+    void calc_pooling(const shape& output_shape,
+                      Out& output,
+                      const In& input,
+                      const std::vector<std::size_t>& kernel_dims,
+                      Op op) const
     {
         auto in_s    = input.get_shape();
         auto in_lens = in_s.lens();

@@ -172,7 +240,7 @@ struct pooling
             auto d_2  = dim - 2;
             int start =
                 static_cast<int>(idx_o[dim] * stride[d_2]) - static_cast<int>(padding[d_2]);
-            int end   = std::min(start + lengths[d_2], in_lens[dim]);
+            int end   = std::min(start + kernel_dims[d_2], in_lens[dim]);
             start     = std::max(start, 0);
             win_start.push_back(start);
             win_size.push_back(end - start);

@@ -198,21 +266,32 @@ struct pooling
         });
     }

-    argument compute(const shape& output_shape, std::vector<argument> args) const
+    argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
     {
-        argument result{output_shape};
+        argument result{dyn_out.computed_shape};
+        auto input_lens = args[0].get_shape().lens();
+        std::vector<std::size_t> kernel_dims;
+        if(dyn_global)
+        {
+            kernel_dims.insert(kernel_dims.end(), input_lens.begin() + 2, input_lens.end());
+        }
+        else
+        {
+            kernel_dims = this->lengths;
+        }
         visit_all(result, args[0])([&](auto output, auto input) {
             using type = typename decltype(output)::value_type;
             switch(mode)
             {
             case migraphx::op::pooling_mode::average:
-                calc_pooling<type>(output_shape, output, input, avg_pool{});
+                calc_pooling<type>(dyn_out.computed_shape, output, input, kernel_dims, avg_pool{});
                 break;
             case migraphx::op::pooling_mode::max:
-                calc_pooling<type>(output_shape, output, input, max_pool{});
+                calc_pooling<type>(dyn_out.computed_shape, output, input, kernel_dims, max_pool{});
                 break;
             case migraphx::op::pooling_mode::lpnorm:
-                calc_pooling<type>(output_shape, output, input, lpnorm_pool{lp_order});
+                calc_pooling<type>(
+                    dyn_out.computed_shape, output, input, kernel_dims, lpnorm_pool{lp_order});
                 break;
             }
         });
...
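The spatial output length in calc_spatial_dim_out follows the standard pooling formula out = floor((in + pad_total − kernel) / stride) + 1, with the division rounded up instead when ceil_mode is set. A standalone sketch of just that arithmetic (plain integers, illustration only):

#include <cassert>
#include <cstddef>

std::size_t pool_out_len(std::size_t in, std::size_t pad_total, std::size_t kernel,
                         std::size_t stride, bool ceil_mode)
{
    assert(in + pad_total >= kernel);
    std::size_t dim_size = in + pad_total - kernel;
    // floor divide by default; round a partial final window up when ceil_mode is set
    std::size_t len = ceil_mode
                          ? dim_size / stride + static_cast<std::size_t>(dim_size % stride != 0)
                          : dim_size / stride;
    return len + 1;
}
// e.g. pool_out_len(7, 0, 3, 2, false) == 3, while pool_out_len(8, 0, 3, 2, true) == 4
// because the last stride only partially covers the input.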
src/include/migraphx/op/softmax.hpp — view file @ 9b929d4e

@@ -53,15 +53,15 @@ struct softmax
     std::string name() const { return "softmax"; }

     shape normalize_compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(1);
-        if(inputs.at(0).packed())
+        check_shapes{inputs, *this, true}.has(1);
+        auto s0 = inputs[0];
+        if(s0.dynamic() or s0.packed())
         {
-            return inputs.at(0);
+            return s0;
         }
         else
         {
-            auto lens = inputs.at(0).lens();
-            return {inputs.at(0).type(), lens};
+            return {s0.type(), s0.lens()};
         }
     }
...
src/include/migraphx/op/squeeze.hpp — view file @ 9b929d4e

@@ -59,9 +59,8 @@ struct squeeze
         auto input_shape = inputs[0];
         if(input_shape.dynamic())
         {
-            std::vector<shape::dynamic_dimension> one_dyn_dims{{1, 1, 0}, {1, 1, 1}};
             if(std::any_of(axes.begin(), axes.end(), [&](auto axis) {
-                   return not contains(one_dyn_dims, input_shape.dyn_dims()[axis]);
+                   return input_shape.dyn_dims()[axis] != 1;
               }))
             {
                 MIGRAPHX_THROW(

@@ -70,14 +69,10 @@ struct squeeze
             std::vector<shape::dynamic_dimension> dyn_dims = {};
             if(axes.empty())
             {
-                for(auto i : range(input_shape.ndim()))
-                {
-                    auto dd = input_shape.dyn_dims()[i];
-                    if(not contains(one_dyn_dims, dd))
-                    {
-                        dyn_dims.push_back(dd);
-                    }
-                }
+                std::copy_if(input_shape.dyn_dims().cbegin(),
+                             input_shape.dyn_dims().cend(),
+                             std::back_inserter(dyn_dims),
+                             [&](auto dd) { return dd != 1; });
             }
             else
             {
...
src/include/migraphx/op/transpose.hpp — view file @ 9b929d4e

@@ -29,6 +29,7 @@
 #include <migraphx/config.hpp>
 #include <migraphx/value.hpp>
 #include <migraphx/op/normalize_attribute.hpp>
+#include <migraphx/dyn_output.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

@@ -45,17 +46,15 @@ struct transpose
     }
     std::string name() const { return "transpose"; }
     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(1);
+        check_shapes{inputs, *this, true}.has(1);
         auto input = inputs.at(0);
-        auto input_lens    = input.lens();
-        auto input_strides = input.strides();
-        auto t             = input.type();
-        if(dims.size() != input_lens.size())
+        if(dims.size() != input.ndim())
         {
-            MIGRAPHX_THROW("Permutation has wrong number of axes");
+            MIGRAPHX_THROW("TRANSPOSE: Permutation has wrong number of axes");
         }
         std::vector<int64_t> axes(dims.size());
         std::iota(axes.begin(), axes.end(), 0);

@@ -63,19 +62,36 @@ struct transpose
         {
             MIGRAPHX_THROW("TRANSPOSE: Invalid permutation");
         }
-        std::vector<size_t> output_lens(input_lens.size());
-        std::vector<size_t> output_strides(input_lens.size());
-        for(std::size_t i = 0; i < output_lens.size(); i++)
+        if(input.dynamic())
         {
-            output_lens[i]    = input_lens[dims[i]];
-            output_strides[i] = input_strides[dims[i]];
+            std::vector<shape::dynamic_dimension> output_dyn_dims(input.ndim());
+            std::transform(dims.cbegin(), dims.cend(), output_dyn_dims.begin(), [&](auto dim) {
+                return input.dyn_dims()[dim];
+            });
+            return {input.type(), output_dyn_dims};
         }
-        return {t, output_lens, output_strides};
+        else
+        {
+            auto input_lens    = input.lens();
+            auto input_strides = input.strides();
+            std::vector<size_t> output_lens(input.ndim());
+            std::vector<size_t> output_strides(input.ndim());
+            for(std::size_t i = 0; i < input.ndim(); i++)
+            {
+                output_lens[i]    = input_lens[dims[i]];
+                output_strides[i] = input_strides[dims[i]];
+            }
+            return {input.type(), output_lens, output_strides};
+        }
     }

-    argument compute(shape output_shape, std::vector<argument> args) const
+    argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
     {
-        return args[0].reshape(output_shape);
+        return args[0].reshape(dyn_out.computed_shape);
     }
     std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
 };
...
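Transpose here is metadata-only: it permutes lengths and strides and then reshapes the input argument, so no element data moves. A standalone sketch of permuting lens and strides together (plain vectors, illustration only; the struct below is hypothetical, not a MIGraphX type):

#include <cstddef>
#include <vector>

struct dims_and_strides
{
    std::vector<std::size_t> lens;
    std::vector<std::size_t> strides;
};

// Apply permutation `perm` to both the lengths and the strides, as transpose does.
dims_and_strides permute(const dims_and_strides& in, const std::vector<std::size_t>& perm)
{
    dims_and_strides out{std::vector<std::size_t>(perm.size()),
                         std::vector<std::size_t>(perm.size())};
    for(std::size_t i = 0; i < perm.size(); ++i)
    {
        out.lens[i]    = in.lens[perm[i]];
        out.strides[i] = in.strides[perm[i]];
    }
    return out;
}
// e.g. permuting NCHW lens {1, 3, 8, 8} / strides {192, 64, 8, 1} with {0, 2, 3, 1}
// gives lens {1, 8, 8, 3} / strides {192, 8, 1, 64} without touching the data.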
src/include/migraphx/program.hpp — view file @ 9b929d4e

@@ -115,6 +115,7 @@ struct program
                     print_func) const;
     void print_graph(std::ostream& os, bool brief = false) const;
+    void print_py(std::ostream& os) const;
     void print_cpp(std::ostream& os) const;

     void dry_run(parameter_map params) const;
...
src/include/migraphx/shape.hpp — view file @ 9b929d4e

@@ -101,6 +101,12 @@ struct shape
         friend bool operator==(const dynamic_dimension& x, const dynamic_dimension& y);
         friend bool operator!=(const dynamic_dimension& x, const dynamic_dimension& y);
         friend std::ostream& operator<<(std::ostream& os, const dynamic_dimension& x);
+
+        // compare to fixed std::size_t dimension
+        friend bool operator==(const dynamic_dimension& x, const std::size_t& y);
+        friend bool operator==(const std::size_t& x, const dynamic_dimension& y);
+        friend bool operator!=(const dynamic_dimension& x, const std::size_t& y);
+        friend bool operator!=(const std::size_t& x, const dynamic_dimension& y);
     };

     static const std::vector<type_t>& types();
...
src/include/migraphx/shape_for_each.hpp — view file @ 9b929d4e

@@ -31,6 +31,9 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

+/**
+ * Iterates the given function over the indices from the shape in order.
+ */
 template <class F>
 void shape_for_each(const migraphx::shape& s, F f)
 {

@@ -51,7 +54,6 @@ void shape_for_each(const migraphx::shape& s, F f)
         call(indices);
     }
 }
-
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
...
src/insert_pad.cpp — view file @ 9b929d4e

@@ -77,14 +77,14 @@ static void update_pooling(const instruction_ref& input, const instruction_ref&
     {
         return;
     }

-    auto kdims = input->get_shape().lens().size() - 2;
+    auto kdims = input->get_shape().ndim() - 2;

     if(std::equal(op.padding.begin(),
                   op.padding.begin() + kdims,
                   op.padding.begin() + kdims,
                   op.padding.end()))
         return;

-    std::vector<int64_t> padding(input->get_shape().lens().size() * 2, 0);
+    std::vector<int64_t> padding(input->get_shape().ndim() * 2, 0);
     std::vector<size_t> pads_l(op.padding.begin(), op.padding.begin() + kdims);
     std::vector<size_t> pads_r(op.padding.begin() + kdims, op.padding.end());
     op.padding = std::vector<size_t>(kdims * 2, 0);
...
src/instruction.cpp — file mode changed 100644 → 100755 — view file @ 9b929d4e

@@ -302,6 +302,24 @@ void instruction::replace_mod_argument(module_ref old, module_ref new_mod)
     std::replace(module_args.begin(), module_args.end(), old, new_mod);
 }

+bool instruction::is_undefined() const
+{
+    if(op.name() == "undefined")
+    {
+        return true;
+    }
+    else if(this->inputs().empty())
+    {
+        return false;
+    }
+    else
+    {
+        return std::all_of(this->inputs().begin(), this->inputs().end(), [](auto arg) {
+            return arg->is_undefined();
+        });
+    }
+}
+
 bool instruction::can_eval() const
 {
     if(op.name() == "@literal")
...
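is_undefined above returns true for an "undefined" op and otherwise propagates through the graph: an instruction with inputs counts as undefined only if every input is. A standalone sketch of that recursion over a minimal node type (illustration only; `node` is hypothetical, not a MIGraphX class):

#include <algorithm>
#include <memory>
#include <string>
#include <vector>

struct node
{
    std::string name;
    std::vector<std::shared_ptr<node>> inputs;
};

bool is_undefined(const node& n)
{
    if(n.name == "undefined")
        return true;
    if(n.inputs.empty())
        return false;
    // undefined only if every input is undefined
    return std::all_of(n.inputs.begin(), n.inputs.end(),
                       [](const auto& in) { return is_undefined(*in); });
}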
src/layout_nhwc.cpp — new file (0 → 100644) — view file @ 9b929d4e
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/layout_nhwc.hpp>
#include <migraphx/module.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

template <class Predicate>
std::vector<instruction_ref> find_lasts(const module& m, Predicate pred)
{
    std::vector<instruction_ref> result;
    fix([&](auto self, auto ins) {
        if(pred(ins))
        {
            result.push_back(ins);
            return;
        }
        for(auto input : ins->inputs())
            self(input);
    })(std::prev(m.end()));
    return result;
}

std::unordered_set<instruction_ref> preserve_output_layout(module& m)
{
    std::unordered_set<instruction_ref> result;
    std::vector<instruction_ref> outputs = find_lasts(m, [](auto ins) {
        return ins->name() == "convolution" and ins->get_shape().lens().size() == 4;
    });
    for(auto output : outputs)
    {
        auto permutation = find_permutation(output->get_shape());
        auto layout      = m.insert_instruction(
            std::next(output), make_op("layout", {{"permutation", permutation}}), output);
        result.insert(m.replace_instruction(output, layout));
    }
    return result;
}

void transform_convolutions(module& m)
{
    for(auto ins : iterator_for(m))
    {
        if(ins->name() != "convolution")
            continue;
        if(ins->get_shape().lens().size() != 4)
            continue;
        auto v = ins->get_operator().to_value();
        if(v.at("group").to<int>() > 1)
            continue;
        auto args = ins->inputs();
        std::transform(args.begin(), args.end(), args.begin(), [&](const auto& i) {
            return m.insert_instruction(
                ins, make_op("layout", {{"permutation", {0, 2, 3, 1}}}), i);
        });
        auto conv = m.insert_instruction(ins, ins->get_operator(), args);
        auto c    = m.insert_instruction(ins, make_op("contiguous"), conv);
        m.replace_instruction(ins, c);
    }
}

void remove_layout(module& m, const std::unordered_set<instruction_ref>& output_layouts)
{
    for(auto ins : iterator_for(m))
    {
        if(ins->name() != "layout")
            continue;
        if(ins->get_shape() != ins->inputs().front()->get_shape())
            continue;
        if(contains(output_layouts, ins))
            continue;
        m.replace_instruction(ins, ins->inputs().front());
    }
}

void layout_nhwc::apply(module_pass_manager& mpm) const
{
    std::unordered_set<instruction_ref> output_layouts = preserve_output_layout(mpm.get_module());
    transform_convolutions(mpm.get_module());
    mpm.run_pass(dead_code_elimination{});
    mpm.run_pass(eliminate_contiguous{"contiguous"});
    mpm.run_pass(dead_code_elimination{});
    remove_layout(mpm.get_module(), output_layouts);
    mpm.run_pass(dead_code_elimination{});
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
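transform_convolutions above inserts a layout op with permutation {0, 2, 3, 1} in front of every eligible 4-D convolution input, which is the NCHW-to-NHWC reordering. A tiny standalone sketch of how that permutation relabels the axes (plain C++, illustration only):

#include <array>
#include <cstddef>

// permutation {0, 2, 3, 1} applied to NCHW dimension labels yields NHWC
std::array<char, 4> apply_perm(const std::array<char, 4>& labels,
                               const std::array<std::size_t, 4>& perm)
{
    std::array<char, 4> out{};
    for(std::size_t i = 0; i < 4; ++i)
        out[i] = labels[perm[i]];
    return out;
}
// apply_perm({'N', 'C', 'H', 'W'}, {0, 2, 3, 1}) == {'N', 'H', 'W', 'C'}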
src/module.cpp — view file @ 9b929d4e

@@ -789,6 +789,22 @@ static std::string cpp_var_name(const std::string& name)
     return to_c_id("x_" + replace_string(name, ":", "_module_"));
 }

+static void print_py_op(std::ostream& os, const operation& op)
+{
+    auto v = op.to_value();
+    os << "migraphx.op(" << enclose_name(op.name());
+    auto default_values = make_op(op.name()).to_value();
+    for(auto&& x : v)
+    {
+        auto name = x.get_key();
+        if(default_values[name] == x)
+            continue;
+        os << ", " << name << "=" << to_json_string(x.without_key());
+    }
+    os << ")";
+}
+
 static void print_make_op(std::ostream& os, const operation& op)
 {
     auto v = op.to_value();

@@ -804,6 +820,14 @@ static void print_make_op(std::ostream& os, const operation& op)
     os << ")";
 }

+static void print_py_shape(std::ostream& os, const migraphx::shape& s)
+{
+    os << "migraphx.shape(" << s.type_string() << ", lens=" << to_json_string(s.lens());
+    if(not s.standard())
+        os << ", strides=" << to_json_string(s.strides());
+    os << ")";
+}
+
 static void print_cpp_shape(std::ostream& os, const migraphx::shape& s)
 {
     os << "migraphx::shape{migraphx::shape::" << s.type_string();

@@ -813,6 +837,68 @@ static void print_cpp_shape(std::ostream& os, const migraphx::shape& s)
     os << "}";
 }

+std::unordered_map<instruction_ref, std::string>
+module::print_py(std::ostream& os,
+                 const std::string& mname,
+                 std::unordered_map<instruction_ref, std::string> names) const
+{
+    // cppcheck-suppress variableScope
+    unsigned long seed = names.size();
+    auto last          = std::prev(this->end());
+    names              = this->print(
+        [&](auto ins, auto ins_names) {
+            std::vector<std::string> input_vars;
+            std::transform(ins->inputs().begin(),
+                           ins->inputs().end(),
+                           std::back_inserter(input_vars),
+                           [&](auto input) { return cpp_var_name(ins_names.at(input)); });
+            if(ins != last)
+                os << cpp_var_name(ins_names.at(ins)) << " = ";
+            if(ins->name() == "@literal")
+            {
+                os << mname << ".add_literal(";
+                bool use_abs = false;
+                ins->get_literal().visit([&](auto v) {
+                    use_abs = std::none_of(v.begin(), v.end(), [](auto x) { return x < 0; });
+                });
+                // Disable abs for now
+                use_abs = false;
+                if(use_abs)
+                    os << "migraphx.abs_literal(";
+                os << "migraphx.generate_literal(";
+                print_py_shape(os, ins->get_shape());
+                os << ", " << seed << ")";
+                if(use_abs)
+                    os << ")";
+                os << ")" << std::endl;
+                seed++;
+            }
+            else if(ins->name() == "@param")
+            {
+                std::string name = any_cast<builtin::param>(ins->get_operator()).parameter;
+                os << mname << ".add_parameter(" << enclose_name(name) << ",";
+                print_py_shape(os, ins->get_shape());
+                os << ")" << std::endl;
+            }
+            else if(ins->name() == "@return")
+            {
+                os << mname << ".add_return([" << join_strings(input_vars, ", ") << "])"
+                   << std::endl;
+            }
+            else
+            {
+                assert(ins->name().front() != '@');
+                os << mname << ".add_instruction(";
+                print_py_op(os, ins->get_operator());
+                os << ", [" << join_strings(input_vars, ", ") << "]";
+                os << ")" << std::endl;
+            }
+        },
+        names);
+    return names;
+}
+
 std::unordered_map<instruction_ref, std::string>
 module::print_cpp(std::ostream& os,
                   const std::string& mname,

@@ -874,6 +960,8 @@ module::print_cpp(std::ostream& os,
     return names;
 }

+void module::print_py(std::ostream& os) const { this->print_py(os, this->name(), {}); }
+
 void module::print_cpp(std::ostream& os) const { this->print_cpp(os, this->name(), {}); }

 void module::annotate(std::ostream& os, std::function<void(instruction_ref)> a) const
...
src/onnx/onnx_parser.cpp — view file @ 9b929d4e

@@ -393,18 +393,31 @@ literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const
 literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
 {
     std::vector<std::size_t> dims(t.dims().begin(), t.dims().end());
-    if(not t.external_data().empty())
+    auto type = get_type(t.data_type());
+    shape tensor_shape(type, dims);
+    auto external_data = t.external_data();
+    if(not external_data.empty())
     {
-        const std::string& data_file = t.external_data().at(0).value();
-        auto raw_buffer              = read_buffer(path + "/" + data_file);
+        const std::string& data_file = external_data.at(0).value();
+        size_t num_data_fields       = external_data.size();
+        size_t offset                = 0;
+        size_t nbytes                = tensor_shape.bytes();
+        if(num_data_fields > 1) // if offset field is present
+        {
+            offset = std::stoul(t.external_data().at(1).value());
+        }
+        if(num_data_fields > 2) // if nbytes field is present
+        {
+            nbytes = std::stoul(t.external_data().at(2).value());
+        }
+        auto raw_buffer = read_buffer(path + "/" + data_file, offset, nbytes);
         std::string s(raw_buffer.begin(), raw_buffer.end());
-        auto type = get_type(t.data_type());
         return create_literal(type, dims, s.data());
     }
     if(t.has_raw_data())
     {
         const std::string& s = t.raw_data();
         auto type            = get_type(t.data_type());
         return create_literal(type, dims, s.data());
     }
...
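The parser now honors the optional offset and length fields of ONNX external data, defaulting the length to the tensor's byte size. A standalone sketch of reading such a byte range from a file with standard C++ streams (this illustrates the idea only; it is not MIGraphX's read_buffer):

#include <cstddef>
#include <fstream>
#include <string>
#include <vector>

// Read `nbytes` bytes starting at `offset` from `filename`.
std::vector<char> read_range(const std::string& filename, std::size_t offset, std::size_t nbytes)
{
    std::ifstream is(filename, std::ios::binary);
    is.seekg(static_cast<std::streamoff>(offset));
    std::vector<char> buffer(nbytes);
    is.read(buffer.data(), static_cast<std::streamsize>(nbytes));
    return buffer;
}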
src/onnx/parse_binary_op.cpp — view file @ 9b929d4e

@@ -57,6 +57,12 @@ struct parse_binary_op : op_parser<parse_binary_op>
                 parser.parse_value(info.attributes.at("broadcast")).at<uint64_t>();
             if(broadcasted != 0)
             {
+                if(std::any_of(args.cbegin(), args.cend(), [](auto a) {
+                       return a->get_shape().dynamic();
+                   }))
+                {
+                    MIGRAPHX_THROW(
+                        "Binary op broadcast attribute not supported for dynamic input shapes");
+                }
                 uint64_t axis = parser.parse_value(info.attributes.at("axis")).at<uint64_t>();
                 auto l        = info.add_instruction(
                     make_op("broadcast",
...
src/onnx/parse_pooling.cpp — view file @ 9b929d4e

@@ -47,52 +47,42 @@ struct parse_pooling : op_parser<parse_pooling>
                 {"GlobalLpPool", "lpnorm"}};
     }

-    instruction_ref parse(const op_desc& opd,
-                          const onnx_parser& /*parser*/,
-                          onnx_parser::node_info info,
-                          std::vector<instruction_ref> args) const
+    value handle_values(const op_desc& opd,
+                        onnx_parser::node_info info,
+                        const shape& in_shape,
+                        value values) const
     {
-        const std::unordered_map<std::string, op::pooling_mode> mode_map = {
-            {"max", op::pooling_mode::max},
-            {"average", op::pooling_mode::average},
-            {"lpnorm", op::pooling_mode::lpnorm}};
-        std::string mode = opd.op_name;
-        if(not contains(mode_map, mode))
-        {
-            MIGRAPHX_THROW("onnx pooling mode must be [\"max\", \"average\", \"lpnorm\"]");
-        }
-        operation op = make_op("pooling", {{"mode", mode_map.at(mode)}});
-        value values = op.to_value();
-        auto l0      = args[0];
-        auto in_lens = l0->get_shape().lens();
-        assert(in_lens.size() > 2);
-        auto kdims = in_lens.size() - 2;
+        auto kdims = in_shape.ndim() - 2;
         if(starts_with(opd.onnx_name, "Global"))
         {
-            values["lengths"] = std::vector<size_t>(in_lens.begin() + 2, in_lens.end());
+            // if spatial dimensions are dynamic use dyn_global flag
+            if(in_shape.dynamic() and
+               std::any_of(in_shape.dyn_dims().cbegin() + 2,
+                           in_shape.dyn_dims().cend(),
+                           [](auto dd) { return not dd.is_fixed(); }))
+            {
+                values["dyn_global"] = true;
+                values["lengths"]    = std::vector<size_t>();
+            }
+            else
+            {
+                // works with static and fixed dynamic shape
+                auto m_lens       = in_shape.max_lens();
+                values["lengths"] = std::vector<size_t>(m_lens.begin() + 2, m_lens.end());
+            }
         }

         // does not support ceil_mode
         if(contains(info.attributes, "ceil_mode"))
         {
             values["ceil_mode"] = static_cast<bool>(info.attributes.at("ceil_mode").i());
         }

-        // count include padding, if count include pad is 1, we always use
-        // explicit pad
-        int count_include_pad = 0;
-        if(contains(info.attributes, "count_include_pad"))
-        {
-            count_include_pad = info.attributes.at("count_include_pad").i();
-        }
-
         if(contains(info.attributes, "strides"))
         {
             values["stride"].clear();
             copy(info.attributes["strides"].ints(), std::back_inserter(values["stride"]));
             check_attr_sizes(kdims, values["stride"].size(), "PARSE_POOLING: inconsistent strides");
         }
         if(contains(info.attributes, "kernel_shape"))
         {
             values["lengths"].clear();

@@ -110,6 +100,46 @@ struct parse_pooling : op_parser<parse_pooling>
         // ensure pads availabe only when auto_pad is "NOT_SET"
         check_padding_mode(info, "POOLING");
+
+        return values;
+    }
+
+    instruction_ref parse(const op_desc& opd,
+                          const onnx_parser& /*parser*/,
+                          onnx_parser::node_info info,
+                          std::vector<instruction_ref> args) const
+    {
+        std::string mode = opd.op_name;
+        const std::unordered_map<std::string, op::pooling_mode> mode_map = {
+            {"max", op::pooling_mode::max},
+            {"average", op::pooling_mode::average},
+            {"lpnorm", op::pooling_mode::lpnorm}};
+        if(not contains(mode_map, mode))
+        {
+            MIGRAPHX_THROW(
+                "PARSE_POOLING: onnx pooling mode must be [\"max\", \"average\", \"lpnorm\"]");
+        }
+        operation op  = make_op("pooling", {{"mode", mode_map.at(mode)}});
+        value values  = op.to_value();
+        auto l0       = args[0];
+        auto in_shape = l0->get_shape();
+        assert(in_shape.ndim() > 2);
+        auto kdims = in_shape.ndim() - 2;
+
+        values = handle_values(opd, info, in_shape, values);
+
+        // count include padding, if count include pad is 1, we always use
+        // explicit pad
+        int count_include_pad = 0;
+        if(contains(info.attributes, "count_include_pad"))
+        {
+            if(in_shape.dynamic())
+            {
+                MIGRAPHX_THROW("PARSE_POOLING: count_include_pad attribute is not supported for "
+                               "dynamic input shape");
+            }
+            count_include_pad = info.attributes.at("count_include_pad").i();
+        }

         std::vector<int64_t> paddings;
         float pad_val = ((mode == "max") ? std::numeric_limits<float>::lowest() : 0.0f);

@@ -122,6 +152,13 @@ struct parse_pooling : op_parser<parse_pooling>
         }

         if(contains(info.attributes, "auto_pad"))
         {
+            if(in_shape.dynamic())
+            {
+                MIGRAPHX_THROW(
+                    "PARSE_POOLING: Auto padding pooling with dynamic input shape not supported");
+            }
+            else
+            {
             values["padding"].clear();
             // return paddings could be empty, then setting to 0 for no padding

@@ -129,9 +166,10 @@ struct parse_pooling : op_parser<parse_pooling>
                 values,
                 values["lengths"].to_vector<std::size_t>(),
                 {1, 1},
-                in_lens,
+                in_shape.lens(),
                 paddings);
+            }
         }

         if(paddings.size() != 2 * kdims)
         {

@@ -150,6 +188,7 @@ struct parse_pooling : op_parser<parse_pooling>
             values["stride"].resize(kdims);
             std::fill_n(values["stride"].begin(), kdims, 1);
         }
         // used to calculate the supposed output shape
         std::vector<int64_t> orig_padding = paddings;

@@ -159,6 +198,11 @@ struct parse_pooling : op_parser<parse_pooling>
         if(not slice_start.empty())
         {
+            if(in_shape.dynamic())
+            {
+                MIGRAPHX_THROW(
+                    "PARSE_POOLING: asymmetric padding not supported for dynamic input shape");
+            }
             // calculate expected output shape
             orig_padding.insert(orig_padding.begin() + kdims, 2, 0);
             orig_padding.insert(orig_padding.begin(), 2, 0);
...