Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
a465fc9d
Commit
a465fc9d
authored
May 23, 2022
by
charlie
Browse files
Dynamic conv draft progress
parent
417d6644
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
41 additions
and
8 deletions
+41
-8
src/auto_contiguous.cpp
src/auto_contiguous.cpp
+1
-1
src/insert_pad.cpp
src/insert_pad.cpp
+5
-0
src/program.cpp
src/program.cpp
+25
-3
src/targets/ref/lowering.cpp
src/targets/ref/lowering.cpp
+4
-0
test/ref_ops_test.cpp
test/ref_ops_test.cpp
+6
-4
No files found.
src/auto_contiguous.cpp
View file @
a465fc9d
...
@@ -40,7 +40,7 @@ void auto_contiguous::apply(module& m) const
...
@@ -40,7 +40,7 @@ void auto_contiguous::apply(module& m) const
if
(
ins
->
outputs
().
empty
()
and
ins
!=
last
)
if
(
ins
->
outputs
().
empty
()
and
ins
!=
last
)
continue
;
continue
;
shape
s
=
ins
->
get_shape
();
shape
s
=
ins
->
get_shape
();
if
(
not
s
.
standard
()
and
s
.
elements
()
!=
0
)
if
(
not
s
.
dynamic
()
and
not
s
.
standard
()
and
s
.
elements
()
!=
0
)
{
{
auto
c
=
m
.
insert_instruction
(
std
::
next
(
ins
),
make_op
(
"contiguous"
),
ins
);
auto
c
=
m
.
insert_instruction
(
std
::
next
(
ins
),
make_op
(
"contiguous"
),
ins
);
m
.
replace_instruction
(
ins
,
c
);
m
.
replace_instruction
(
ins
,
c
);
...
...
src/insert_pad.cpp
View file @
a465fc9d
...
@@ -17,6 +17,11 @@ static void update_op(const instruction_ref& input, const instruction_ref& ins,
...
@@ -17,6 +17,11 @@ static void update_op(const instruction_ref& input, const instruction_ref& ins,
auto
val
=
op
.
to_value
();
auto
val
=
op
.
to_value
();
auto
op_padding
=
val
.
at
(
"padding"
).
to_vector
<
size_t
>
();
auto
op_padding
=
val
.
at
(
"padding"
).
to_vector
<
size_t
>
();
if
(
input
->
get_shape
().
dynamic
())
{
return
;
}
auto
kdims
=
input
->
get_shape
().
lens
().
size
()
-
2
;
auto
kdims
=
input
->
get_shape
().
lens
().
size
()
-
2
;
if
(
std
::
equal
(
op_padding
.
begin
(),
if
(
std
::
equal
(
op_padding
.
begin
(),
op_padding
.
begin
()
+
kdims
,
op_padding
.
begin
()
+
kdims
,
...
...
src/program.cpp
View file @
a465fc9d
...
@@ -264,8 +264,10 @@ std::vector<argument> generic_eval(const module* mod,
...
@@ -264,8 +264,10 @@ std::vector<argument> generic_eval(const module* mod,
auto
param_name
=
any_cast
<
builtin
::
param
>
(
ins
->
get_operator
()).
parameter
;
auto
param_name
=
any_cast
<
builtin
::
param
>
(
ins
->
get_operator
()).
parameter
;
if
(
not
contains
(
params
,
param_name
))
if
(
not
contains
(
params
,
param_name
))
MIGRAPHX_THROW
(
"Parameter not found: "
+
param_name
);
MIGRAPHX_THROW
(
"Parameter not found: "
+
param_name
);
auto
param
=
params
[
param_name
];
auto
param
=
params
[
param_name
];
if
(
param
.
get_shape
()
!=
ins
->
get_shape
())
// TODO: may want to check correct number of dimensions and/or was within bounds
if
(
not
ins
->
get_shape
().
dynamic
()
and
param
.
get_shape
()
!=
ins
->
get_shape
())
MIGRAPHX_THROW
(
"Incorrect shape {"
+
to_string
(
param
.
get_shape
())
+
MIGRAPHX_THROW
(
"Incorrect shape {"
+
to_string
(
param
.
get_shape
())
+
"} for parameter: "
+
param_name
);
"} for parameter: "
+
param_name
);
return
param
;
return
param
;
...
@@ -297,6 +299,25 @@ std::vector<argument> generic_eval(const module* mod,
...
@@ -297,6 +299,25 @@ std::vector<argument> generic_eval(const module* mod,
return
results
[
i
];
return
results
[
i
];
});
});
shape
output_shape
;
auto
ins_shape
=
ins
->
get_shape
();
if
(
ins_shape
.
dynamic
())
{
// Make into a std::vector<instruction_ref> of inputs
auto
to_shapes
=
[](
std
::
vector
<
argument
>
args
)
{
std
::
vector
<
shape
>
shapes
(
args
.
size
());
std
::
transform
(
args
.
begin
(),
args
.
end
(),
shapes
.
begin
(),
[](
argument
i
)
{
return
i
.
get_shape
();
});
return
shapes
;
};
output_shape
=
ins
->
get_operator
().
compute_shape
(
to_shapes
(
values
));
}
else
{
output_shape
=
ins_shape
;
}
const
auto
&
mod_args
=
ins
->
module_inputs
();
const
auto
&
mod_args
=
ins
->
module_inputs
();
auto
module_eval
=
[
&
](
module_ref
smod
,
auto
module_eval
=
[
&
](
module_ref
smod
,
const
std
::
unordered_map
<
std
::
string
,
argument
>&
inputs
)
{
const
std
::
unordered_map
<
std
::
string
,
argument
>&
inputs
)
{
...
@@ -306,11 +327,12 @@ std::vector<argument> generic_eval(const module* mod,
...
@@ -306,11 +327,12 @@ std::vector<argument> generic_eval(const module* mod,
results
.
emplace
(
ins
,
trace
(
ins
,
[
&
]
{
results
.
emplace
(
ins
,
trace
(
ins
,
[
&
]
{
return
ins
->
normalized_operator
().
compute
(
return
ins
->
normalized_operator
().
compute
(
ctx
,
ins
->
ge
t_shape
()
,
values
,
mod_args
,
module_eval
);
ctx
,
outpu
t_shape
,
values
,
mod_args
,
module_eval
);
}));
}));
}
}
assert
(
results
.
find
(
ins
)
!=
results
.
end
());
assert
(
results
.
find
(
ins
)
!=
results
.
end
());
assert
(
results
.
at
(
ins
).
get_shape
()
==
ins
->
get_shape
());
// TODO: update this assert for dynamic shapes
// assert(results.at(ins).get_shape() == ins->get_shape());
}
}
return
{
results
.
at
(
std
::
prev
(
mod
->
end
()))};
return
{
results
.
at
(
std
::
prev
(
mod
->
end
()))};
}
}
...
...
src/targets/ref/lowering.cpp
View file @
a465fc9d
...
@@ -210,6 +210,10 @@ struct ref_convolution : auto_register_op<ref_convolution<Op>>
...
@@ -210,6 +210,10 @@ struct ref_convolution : auto_register_op<ref_convolution<Op>>
}
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
if
(
output_shape
.
dynamic
())
{
op
.
normalize_compute_shape
({
args
.
at
(
0
).
get_shape
(),
args
.
at
(
1
).
get_shape
()});
}
argument
result
{
output_shape
};
argument
result
{
output_shape
};
visit_quantize
(
result
,
args
[
0
],
args
[
1
])([
&
](
auto
output
,
auto
input
,
auto
weights
)
{
visit_quantize
(
result
,
args
[
0
],
args
[
1
])([
&
](
auto
output
,
auto
input
,
auto
weights
)
{
auto
in_lens
=
input
.
get_shape
().
lens
();
auto
in_lens
=
input
.
get_shape
().
lens
();
...
...
test/ref_ops_test.cpp
View file @
a465fc9d
...
@@ -855,11 +855,11 @@ TEST_CASE(conv_dynamic_batch_test)
...
@@ -855,11 +855,11 @@ TEST_CASE(conv_dynamic_batch_test)
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
input_shape
{
migraphx
::
shape
::
float_type
,
migraphx
::
shape
input_
dyn_
shape
{
migraphx
::
shape
::
float_type
,
{{
1
,
100
,
0
},
{
3
,
3
,
0
},
{
4
,
4
,
0
},
{
4
,
4
,
0
}}};
{{
1
,
100
,
0
},
{
3
,
3
,
0
},
{
4
,
4
,
0
},
{
4
,
4
,
0
}}};
migraphx
::
shape
weights_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
3
,
3
}};
migraphx
::
shape
weights_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
3
,
3
}};
auto
input
=
mm
->
add_parameter
(
"X"
,
input_shape
);
auto
input
=
mm
->
add_parameter
(
"X"
,
input_
dyn_
shape
);
auto
weights
=
mm
->
add_parameter
(
"W"
,
weights_shape
);
auto
weights
=
mm
->
add_parameter
(
"W"
,
weights_shape
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"convolution"
,
{{
"padding"
,
{
1
,
1
}},
{
"stride"
,
{
2
,
2
}}}),
mm
->
add_instruction
(
migraphx
::
make_op
(
"convolution"
,
{{
"padding"
,
{
1
,
1
}},
{
"stride"
,
{
2
,
2
}}}),
input
,
input
,
...
@@ -910,8 +910,10 @@ TEST_CASE(conv_dynamic_batch_test)
...
@@ -910,8 +910,10 @@ TEST_CASE(conv_dynamic_batch_test)
-
0.16138598
,
-
0.16138598
,
0.79344082
};
0.79344082
};
migraphx
::
shape
input_fixed_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
4
}};
migraphx
::
parameter_map
params
;
migraphx
::
parameter_map
params
;
params
[
"X"
]
=
migraphx
::
argument
(
input_shape
,
a
.
data
());
params
[
"X"
]
=
migraphx
::
argument
(
input_
fixed_
shape
,
a
.
data
());
params
[
"W"
]
=
migraphx
::
argument
(
weights_shape
,
c
.
data
());
params
[
"W"
]
=
migraphx
::
argument
(
weights_shape
,
c
.
data
());
auto
result
=
p
.
eval
(
params
).
back
();
auto
result
=
p
.
eval
(
params
).
back
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment