Unverified Commit 27e123df authored by lin bin, committed by GitHub

fix data device type bug (#3856)

parent 009722a6
......@@ -39,6 +39,7 @@ After getting mixed precision engine, users can do inference with input data.
Note
* We recommend using "cpu" (host) as the data device (for both inference data and calibration data), since data should initially reside on the host and will be transferred to the device before inference. If the data device is not "cpu" (host), this tool will transfer the data back to "cpu", which may introduce unnecessary overhead (see the usage sketch below).
* Users can also do post-training quantization by leveraging TensorRT directly (a calibration dataset must be provided).
* Not all op types are supported right now. At present, NNI supports Conv, Linear, ReLU and MaxPool; more op types will be supported in future releases.
......
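A minimal, hypothetical usage sketch of the recommendation above: keep both calibration and inference data on the host. The import path, constructor arguments (`config`, `calib_data_loader`, `batchsize`), and the `compress`/`inference` calls are assumptions inferred from the `ModelSpeedupTensorRT` class shown in this diff and the surrounding tutorial, not a verified signature.

```python
import torch
import torchvision.models as models
# Import path may differ across NNI versions; this is an assumption.
from nni.compression.pytorch.quantization_speedup import ModelSpeedupTensorRT

model = models.resnet18(pretrained=True).eval()

# Both tensors live on "cpu" (host), so no device-to-host copy is needed
# before building the engine or running inference.
calib_data = torch.randn(64, 3, 224, 224, device="cpu")
test_data = torch.rand(1, 3, 224, 224, device="cpu")   # float32 by default

# Assumed constructor/method names based on the class in this diff.
engine = ModelSpeedupTensorRT(model, input_shape=(1, 3, 224, 224),
                              config=None, calib_data_loader=calib_data,
                              batchsize=32)
engine.compress()
output, latency = engine.inference(test_data)
```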
......@@ -290,6 +290,9 @@ class ModelSpeedupTensorRT(BaseModelSpeedup):
calib_data_set.append(data)
calib_data = np.concatenate(calib_data_set)
elif isinstance(self.calib_data_loader, torch.Tensor):
# TensorRT needs numpy arrays as calibration data; only CPU tensors can be converted to numpy directly
if self.calib_data_loader.device != torch.device("cpu"):
self.calib_data_loader = self.calib_data_loader.to("cpu")
calib_data = self.calib_data_loader.numpy()
else:
raise ValueError("Unsupported calibration data type")
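The pattern this hunk adds can be read as a small standalone helper. The sketch below is illustrative only; `tensor_to_numpy` is a hypothetical name, not part of NNI.

```python
import numpy as np
import torch

def tensor_to_numpy(data: torch.Tensor) -> np.ndarray:
    """Hypothetical helper mirroring the fix: TensorRT calibration wants a
    numpy array, and only CPU tensors convert to numpy directly."""
    if data.device != torch.device("cpu"):
        data = data.to("cpu")   # copy the device tensor back to the host
    return data.numpy()

# A CUDA tensor is transparently copied to the host; a CPU tensor passes through.
device = "cuda" if torch.cuda.is_available() else "cpu"
calib = tensor_to_numpy(torch.randn(8, 3, 32, 32, device=device))
```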
......@@ -326,6 +329,8 @@ class ModelSpeedupTensorRT(BaseModelSpeedup):
Model input tensor
"""
# convert PyTorch tensor to a numpy ndarray
if test_data.device != torch.device("cpu"):
test_data = test_data.to("cpu")
test_data = test_data.numpy()
# Numpy dtype should be float32
assert test_data.dtype == np.float32
......
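Because the inference path above asserts `float32` after conversion, callers holding double-precision tensors should cast on the torch side before handing data over. A hypothetical caller-side snippet:

```python
import numpy as np
import torch

# The engine moves data to "cpu" itself (per the fix above), but the float32
# assertion means a float64 tensor must be cast by the caller first.
test_data = torch.rand(1, 3, 224, 224, dtype=torch.float64)
test_data = test_data.float()                     # float64 -> float32
assert test_data.cpu().numpy().dtype == np.float32
```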