"csrc/vscode:/vscode.git/clone" did not exist on "96853af5a830d42496aa6cfd5c670d073d6a0209"
Commit 21db7647 authored by charlie's avatar charlie
Browse files

Resnet50 runs with dyn_test_runner

* changed split_single_dyn_dim to add a get_tuple_element instruction
needed to access tuple type created by select_module
* dyn_test_runner changed to have offload_copy to false
* split_single_dyn_dim is not going to work with offload_copy unless we
make `load`, `copy_to_gpu` and `copy_from_gpu` handle dynamic shapes
parent 5af9aac0
......@@ -128,13 +128,18 @@ void split_single_dyn_dim::apply(module_pass_manager& mpm) const
param_names.cend(),
std::back_inserter(sm_inputs),
[&](auto pn) { return mm->get_parameter(pn); });
migraphx::shape out_attr = migraphx::shape{mm->get_output_shapes()};
auto sm_ins = mm->add_instruction(
auto output_shapes = mm->get_output_shapes();
migraphx::shape out_attr = migraphx::shape{output_shapes};
auto ret = mm->add_instruction(
migraphx::make_op("select_module",
{{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
sm_inputs,
submodules);
mm->replace_return({sm_ins});
if (output_shapes.size() == 1)
{
ret = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
}
mm->replace_return({ret});
}
}
......
......@@ -247,16 +247,15 @@ def run_one_case(model, param_map):
# convert np array to model argument
pp = {}
for key, val in param_map.items():
pp[key] = migraphx.argument(val)
pp[key] = migraphx.to_gpu(migraphx.argument(val))
# run the model
model_outputs = model.run(param_map)
model_outputs = model.run(pp)
# convert argument to np array
outputs = []
outputs = [];
for output in model_outputs:
outputs.append(np.array(output))
host_output = migraphx.from_gpu(output)
outputs.append(np.array(host_output))
return outputs
......@@ -322,7 +321,7 @@ def main():
else:
model = migraphx.parse_onnx(model_path_name,
default_dyn_dim_value=default_dd_val)
model.compile(migraphx.get_target(target))
model.compile(migraphx.get_target(target), offload_copy=False)
# get test cases
cases = get_test_cases(test_loc)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment