Unverified Commit 27e123df authored by lin bin's avatar lin bin Committed by GitHub
Browse files

fix data device type bug (#3856)

parent 009722a6
...@@ -39,6 +39,7 @@ After getting mixed precision engine, users can do inference with input data. ...@@ -39,6 +39,7 @@ After getting mixed precision engine, users can do inference with input data.
Note Note
* We recommend using "cpu" (host) as the data device (for both inference data and calibration data), since data should reside on the host initially and will be transferred to the device before inference. If the data device is not "cpu" (host), this tool will transfer the data to "cpu", which may introduce unnecessary overhead.
* User can also do post-training quantization leveraging TensorRT directly(need to provide calibration dataset). * User can also do post-training quantization leveraging TensorRT directly(need to provide calibration dataset).
* Not all op types are supported right now. At present, NNI supports Conv, Linear, Relu and MaxPool. More op types will be supported in the following release. * Not all op types are supported right now. At present, NNI supports Conv, Linear, Relu and MaxPool. More op types will be supported in the following release.
......
...@@ -290,6 +290,9 @@ class ModelSpeedupTensorRT(BaseModelSpeedup): ...@@ -290,6 +290,9 @@ class ModelSpeedupTensorRT(BaseModelSpeedup):
calib_data_set.append(data) calib_data_set.append(data)
calib_data = np.concatenate(calib_data_set) calib_data = np.concatenate(calib_data_set)
elif type(self.calib_data_loader) == torch.Tensor: elif type(self.calib_data_loader) == torch.Tensor:
# trt need numpy as calibration data, only cpu data can convert to numpy directly
if self.calib_data_loader.device != torch.device("cpu"):
self.calib_data_loader = self.calib_data_loader.to("cpu")
calib_data = self.calib_data_loader.numpy() calib_data = self.calib_data_loader.numpy()
else: else:
raise ValueError("Not support calibration datatype") raise ValueError("Not support calibration datatype")
...@@ -326,6 +329,8 @@ class ModelSpeedupTensorRT(BaseModelSpeedup): ...@@ -326,6 +329,8 @@ class ModelSpeedupTensorRT(BaseModelSpeedup):
Model input tensor Model input tensor
""" """
# convert pytorch tensor to numpy darray # convert pytorch tensor to numpy darray
if test_data.device != torch.device("cpu"):
test_data = test_data.to("cpu")
test_data = test_data.numpy() test_data = test_data.numpy()
# Numpy dtype should be float32 # Numpy dtype should be float32
assert test_data.dtype == np.float32 assert test_data.dtype == np.float32
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment