"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "4125756e88e82370c197fecf28e9f0b4d7eee6c3"
Commit eb7e273d authored by Irvin Ho's avatar Irvin Ho Committed by Francisco Massa
Browse files

during normalize, construct mean and std tensors on the same device as the input tensor (#787)

parent 9d9f48a3
...@@ -203,8 +203,8 @@ def normalize(tensor, mean, std, inplace=False): ...@@ -203,8 +203,8 @@ def normalize(tensor, mean, std, inplace=False):
if not inplace: if not inplace:
tensor = tensor.clone() tensor = tensor.clone()
mean = torch.tensor(mean, dtype=torch.float32) mean = torch.tensor(mean, dtype=torch.float32, device=tensor.device)
std = torch.tensor(std, dtype=torch.float32) std = torch.tensor(std, dtype=torch.float32, device=tensor.device)
tensor.sub_(mean[:, None, None]).div_(std[:, None, None]) tensor.sub_(mean[:, None, None]).div_(std[:, None, None])
return tensor return tensor
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment