Commit 8c001321 authored by boomb0om's avatar boomb0om
Browse files

Add autocast support

parent 861000cc
...@@ -25,6 +25,7 @@ class RealESRGAN: ...@@ -25,6 +25,7 @@ class RealESRGAN:
self.model.eval() self.model.eval()
self.model.to(self.device) self.model.to(self.device)
@torch.cuda.amp.autocast()
def predict(self, lr_image, batch_size=4, patches_size=192, def predict(self, lr_image, batch_size=4, patches_size=192,
padding=24, pad_size=15): padding=24, pad_size=15):
scale = self.scale scale = self.scale
...@@ -41,7 +42,7 @@ class RealESRGAN: ...@@ -41,7 +42,7 @@ class RealESRGAN:
for i in range(batch_size, img.shape[0], batch_size): for i in range(batch_size, img.shape[0], batch_size):
res = torch.cat((res, self.model(img[i:i+batch_size])), 0) res = torch.cat((res, self.model(img[i:i+batch_size])), 0)
sr_image = res.permute((0,2,3,1)).cpu().clamp_(0, 1) sr_image = res.permute((0,2,3,1)).clamp_(0, 1).cpu()
np_sr_image = sr_image.numpy() np_sr_image = sr_image.numpy()
padded_size_scaled = tuple(np.multiply(p_shape[0:2], scale)) + (3,) padded_size_scaled = tuple(np.multiply(p_shape[0:2], scale)) + (3,)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment