Commit 8c001321 authored by boomb0om's avatar boomb0om
Browse files

Add autocast support

parent 861000cc
......@@ -25,6 +25,7 @@ class RealESRGAN:
self.model.eval()
self.model.to(self.device)
@torch.cuda.amp.autocast()
def predict(self, lr_image, batch_size=4, patches_size=192,
padding=24, pad_size=15):
scale = self.scale
......@@ -41,7 +42,7 @@ class RealESRGAN:
for i in range(batch_size, img.shape[0], batch_size):
res = torch.cat((res, self.model(img[i:i+batch_size])), 0)
sr_image = res.permute((0,2,3,1)).cpu().clamp_(0, 1)
sr_image = res.permute((0,2,3,1)).clamp_(0, 1).cpu()
np_sr_image = sr_image.numpy()
padded_size_scaled = tuple(np.multiply(p_shape[0:2], scale)) + (3,)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment