Commit 97c70f1a authored by Your Name's avatar Your Name
Browse files

修改模型输入数据格式

parent 5274a8cb
...@@ -454,9 +454,10 @@ class det_rec_functions(object): ...@@ -454,9 +454,10 @@ class det_rec_functions(object):
img_part, shape_part_list = data_part img_part, shape_part_list = data_part
img_part = np.expand_dims(img_part, axis=0) img_part = np.expand_dims(img_part, axis=0)
shape_part_list = np.expand_dims(shape_part_list, axis=0) shape_part_list = np.expand_dims(shape_part_list, axis=0)
img_part = np.ascontiguousarray(img_part)
# migraphx推理 # migraphx推理
resultDets = self.modelDet.run({self.modelDet.get_parameter_names()[0]: migraphx.argument(img_part)}) resultDets = self.modelDet.run({self.modelDet.get_parameter_names()[0]: img_part})
# 获取第一个输出节点的数据,migraphx.argument类型 # 获取第一个输出节点的数据,migraphx.argument类型
resultDet = resultDets[0] resultDet = resultDets[0]
outs_part = np.array(resultDet) outs_part = np.array(resultDet)
...@@ -499,7 +500,7 @@ class det_rec_functions(object): ...@@ -499,7 +500,7 @@ class det_rec_functions(object):
img = img[np.newaxis, :] img = img[np.newaxis, :]
# migraphx推理 # migraphx推理
results = self.modelRec.run({self.modelRec.get_parameter_names()[0]: migraphx.argument(img)}) results = self.modelRec.run({self.modelRec.get_parameter_names()[0]: img})
# 获取第一个输出节点的数据,migraphx.argument类型 # 获取第一个输出节点的数据,migraphx.argument类型
result = results[0] result = results[0]
outs = np.array(result) outs = np.array(result)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment