Commit 85369e3a authored by Bowen Bao's avatar Bowen Bao Committed by Soumith Chintala
Browse files

Fix inception v3 input transform for trace & onnx (#621)

* Fix inception v3 input transform for trace & onnx

* Input transform are in-place updates, which produce issues for tracing
and exporting to onnx.

* nit
parent 8f943d4e
...@@ -70,10 +70,10 @@ class Inception3(nn.Module): ...@@ -70,10 +70,10 @@ class Inception3(nn.Module):
def forward(self, x): def forward(self, x):
if self.transform_input: if self.transform_input:
x = x.clone() x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
# 299 x 299 x 3 # 299 x 299 x 3
x = self.Conv2d_1a_3x3(x) x = self.Conv2d_1a_3x3(x)
# 149 x 149 x 32 # 149 x 149 x 32
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment