Commit eb7e273d authored by Irvin Ho's avatar Irvin Ho Committed by Francisco Massa
Browse files

during normalize, construct mean and std tensors on the same device as the input tensor (#787)

parent 9d9f48a3
......@@ -203,8 +203,8 @@ def normalize(tensor, mean, std, inplace=False):
if not inplace:
tensor = tensor.clone()
mean = torch.tensor(mean, dtype=torch.float32)
std = torch.tensor(std, dtype=torch.float32)
mean = torch.tensor(mean, dtype=torch.float32, device=tensor.device)
std = torch.tensor(std, dtype=torch.float32, device=tensor.device)
tensor.sub_(mean[:, None, None]).div_(std[:, None, None])
return tensor
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment