"test/vscode:/vscode.git/clone" did not exist on "15385fb1f1577e5f2ca5f501e3025e76fc2ad4a6"
Commit e6a8bf6a authored by liucong's avatar liucong
Browse files

修改python示例程序

parent 25e4f86c
......@@ -62,32 +62,46 @@ def ort_seg_dcu(model_path,image):
provider_options=[{'device_id':'0','migraphx_fp16_enable':'true'}]
dcu_session = ort.InferenceSession(model_path, providers=['MIGraphXExecutionProvider'], provider_options=provider_options)
output_data = np.empty(dcu_session.get_outputs()[0].shape).astype(np.float32)
inputData={}
outputData={}
inputData[dcu_session.get_inputs()[0].name]=ort.OrtValue.ortvalue_from_numpy(image, device_type='cuda')
outputData[dcu_session.get_outputs()[0].name]=ort.OrtValue.ortvalue_from_numpy(output_data, device_type='cuda')
images = [image]
input_nodes = dcu_session.get_inputs()
input_names = [i_n.name for i_n in input_nodes]
output_nodes = dcu_session.get_outputs()
output_names = [o_n.name for o_n in output_nodes]
input_dict = {}
for i_d, i_n in zip(images, input_names):
input_dict[i_n] = i_d
io_binding = dcu_session.io_binding()
io_binding.bind_input(
name=dcu_session.get_inputs()[0].name,
device_type=inputData[dcu_session.get_inputs()[0].name].device_name(),
device_id=0,
element_type=np.float32,
shape=inputData[dcu_session.get_inputs()[0].name].shape(),
buffer_ptr=inputData[dcu_session.get_inputs()[0].name].data_ptr())
inputData = {}
for key in input_dict.keys():
inputData[key] = ort.OrtValue.ortvalue_from_numpy(input_dict[key], device_type='cuda')
io_binding.bind_input(
name = key,
device_type = inputData[key].device_name(),
device_id = 0,
element_type = np.float32,
shape = inputData[key].shape(),
buffer_ptr = inputData[key].data_ptr())
outputData = {}
output_data = {}
for index, o_n in enumerate(output_names):
output_data[o_n] = np.empty(dcu_session.get_outputs()[index].shape).astype(np.float32)
outputData[o_n] = ort.OrtValue.ortvalue_from_numpy(output_data[o_n], device_type='cuda')
io_binding.bind_output(
name=dcu_session.get_outputs()[0].name,
device_type=outputData[dcu_session.get_outputs()[0].name].device_name(),
device_id=0,
element_type=np.float32,
shape=outputData[dcu_session.get_outputs()[0].name].shape(),
buffer_ptr=outputData[dcu_session.get_outputs()[0].name].data_ptr())
name = o_n,
device_type = outputData[o_n].device_name(),
device_id = 0,
element_type = np.float32,
shape = outputData[o_n].shape(),
buffer_ptr = outputData[o_n].data_ptr())
dcu_session.run_with_iobinding(io_binding)
scores=np.array(io_binding.copy_outputs_to_cpu()[0])
result = io_binding.copy_outputs_to_cpu()[0]
scores = np.array(result)
print("ort result.shape:",scores.shape)
return scores
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment