Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
eb8f205b
"examples/instruct_pix2pix/requirements.txt" did not exist on "c2283310688ff75e8fb4be3d9938ed0818cb038d"
Commit
eb8f205b
authored
May 25, 2022
by
charlie
Browse files
Dynamic weight handling
parent
a0dd2ef9
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
14 additions
and
16 deletions
+14
-16
src/include/migraphx/op/convolution.hpp
src/include/migraphx/op/convolution.hpp
+11
-14
src/program.cpp
src/program.cpp
+3
-2
No files found.
src/include/migraphx/op/convolution.hpp
View file @
eb8f205b
...
@@ -64,21 +64,19 @@ struct convolution
...
@@ -64,21 +64,19 @@ struct convolution
const
shape
&
input
=
inputs
.
at
(
0
);
const
shape
&
input
=
inputs
.
at
(
0
);
const
shape
&
weights
=
inputs
.
at
(
1
);
const
shape
&
weights
=
inputs
.
at
(
1
);
if
(
weights
.
dynamic
())
{
MIGRAPHX_THROW
(
"CONVOLUTION: dynamic weights not supported"
);
}
const
size_t
num_spatial_dims
=
input_size
-
2
;
const
size_t
num_spatial_dims
=
input_size
-
2
;
if
(
num_spatial_dims
!=
this
->
kdims
())
if
(
num_spatial_dims
!=
this
->
kdims
())
{
{
MIGRAPHX_THROW
(
"CONVOLUTION: input k-dims does not match attribute size"
);
MIGRAPHX_THROW
(
"CONVOLUTION: input k-dims does not match attribute size"
);
}
}
if
(
!
input
.
dynamic
()
and
input
.
lens
().
at
(
1
)
!=
(
weights
.
lens
().
at
(
1
)
*
group
))
if
(
not
input
.
dynamic
()
and
not
weights
.
dynamic
()
and
input
.
lens
().
at
(
1
)
!=
(
weights
.
lens
().
at
(
1
)
*
group
))
MIGRAPHX_THROW
(
"CONVOLUTION: mismatched channel numbers"
);
MIGRAPHX_THROW
(
"CONVOLUTION: mismatched channel numbers"
);
auto
calc_output_lens
=
auto
calc_output_lens
=
[
this
,
&
weights
,
&
num_spatial_dims
,
&
padding_size
](
std
::
vector
<
std
::
size_t
>
lens
)
{
[
this
,
&
num_spatial_dims
,
&
padding_size
](
std
::
vector
<
std
::
size_t
>
i_lens
,
std
::
vector
<
std
::
size_t
>
w_lens
)
{
std
::
vector
<
size_t
>
ret
=
{};
std
::
vector
<
size_t
>
ret
=
{};
// calculate the output shape of the convolution: ((W - K + 2P) / S) + 1
// calculate the output shape of the convolution: ((W - K + 2P) / S) + 1
for
(
size_t
i
=
0
;
i
<
num_spatial_dims
;
i
++
)
for
(
size_t
i
=
0
;
i
<
num_spatial_dims
;
i
++
)
...
@@ -91,8 +89,7 @@ struct convolution
...
@@ -91,8 +89,7 @@ struct convolution
}
}
ret
.
push_back
(
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
ret
.
push_back
(
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
1
,
(
lens
[
i
+
2
]
-
(
1
+
dilation
[
i
]
*
(
weights
.
lens
()[
i
+
2
]
-
1
))
+
(
i_lens
[
i
+
2
]
-
(
1
+
dilation
[
i
]
*
(
w_lens
[
i
+
2
]
-
1
))
+
padding_factor
)
/
padding_factor
)
/
stride
[
i
]
+
stride
[
i
]
+
1
)));
1
)));
}
}
...
@@ -103,9 +100,9 @@ struct convolution
...
@@ -103,9 +100,9 @@ struct convolution
{
{
std
::
vector
<
shape
::
dynamic_dimension
>
output_dyn_dims
=
{
input
.
dyn_dims
().
at
(
0
),
std
::
vector
<
shape
::
dynamic_dimension
>
output_dyn_dims
=
{
input
.
dyn_dims
().
at
(
0
),
input
.
dyn_dims
().
at
(
1
)};
input
.
dyn_dims
().
at
(
1
)};
auto
min_spatial_dims
=
calc_output_lens
(
input
.
min_lens
());
auto
min_spatial_dims
=
calc_output_lens
(
input
.
min_lens
()
,
weights
.
min_lens
()
);
auto
max_spatial_dims
=
calc_output_lens
(
input
.
max_lens
());
auto
max_spatial_dims
=
calc_output_lens
(
input
.
max_lens
()
,
weights
.
max_lens
()
);
auto
opt_spatial_dims
=
calc_output_lens
(
input
.
opt_lens
());
auto
opt_spatial_dims
=
calc_output_lens
(
input
.
opt_lens
()
,
weights
.
opt_lens
()
);
for
(
size_t
i
=
0
;
i
<
num_spatial_dims
;
++
i
)
for
(
size_t
i
=
0
;
i
<
num_spatial_dims
;
++
i
)
{
{
output_dyn_dims
.
push_back
(
shape
::
dynamic_dimension
{
output_dyn_dims
.
push_back
(
shape
::
dynamic_dimension
{
...
@@ -116,7 +113,7 @@ struct convolution
...
@@ -116,7 +113,7 @@ struct convolution
else
else
{
{
std
::
vector
<
size_t
>
output_lens
{
input
.
lens
()[
0
],
weights
.
lens
()[
0
]};
std
::
vector
<
size_t
>
output_lens
{
input
.
lens
()[
0
],
weights
.
lens
()[
0
]};
auto
spatial_lens
=
calc_output_lens
(
input
.
lens
());
auto
spatial_lens
=
calc_output_lens
(
input
.
lens
()
,
weights
.
lens
()
);
std
::
for_each
(
spatial_lens
.
begin
(),
spatial_lens
.
end
(),
[
&
output_lens
](
auto
x
)
{
std
::
for_each
(
spatial_lens
.
begin
(),
spatial_lens
.
end
(),
[
&
output_lens
](
auto
x
)
{
output_lens
.
push_back
(
x
);
output_lens
.
push_back
(
x
);
});
});
...
...
src/program.cpp
View file @
eb8f205b
...
@@ -312,7 +312,9 @@ std::vector<argument> generic_eval(const module* mod,
...
@@ -312,7 +312,9 @@ std::vector<argument> generic_eval(const module* mod,
return
shapes
;
return
shapes
;
};
};
// TODO: Consider how this will be handled when memoized.
// TODO: Consider how this will be handled when memoized.
// Could memoize these output shapes now so not recalculating
// Could memoize these output shapes into a map so not recalculating
// TODO: Issue with incompatible input tensor to kernel and needing to set
// padding/strides
output_shape
=
ins
->
get_operator
().
compute_shape
(
to_shapes
(
values
));
output_shape
=
ins
->
get_operator
().
compute_shape
(
to_shapes
(
values
));
}
}
else
else
...
@@ -333,7 +335,6 @@ std::vector<argument> generic_eval(const module* mod,
...
@@ -333,7 +335,6 @@ std::vector<argument> generic_eval(const module* mod,
}));
}));
}
}
assert
(
results
.
find
(
ins
)
!=
results
.
end
());
assert
(
results
.
find
(
ins
)
!=
results
.
end
());
// TODO: update this assert for dynamic shapes
if
(
not
ins
->
get_shape
().
dynamic
())
if
(
not
ins
->
get_shape
().
dynamic
())
{
{
assert
(
results
.
at
(
ins
).
get_shape
()
==
ins
->
get_shape
());
assert
(
results
.
at
(
ins
).
get_shape
()
==
ins
->
get_shape
());
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment