r"""PyTorch port of the sketch-simplification networks.

Adapted from the original Lua implementation published at
``https://github.com/bobbens/sketch_simplification``.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# from canvas import DOWNLOAD_TO_CACHE
from artist import DOWNLOAD_TO_CACHE

__all__ = [
    'SketchSimplification',
    'sketch_simplification_gan',
    'sketch_simplification_mse',
    'sketch_to_pencil_v1',
    'sketch_to_pencil_v2',
]


class SketchSimplification(nn.Module):
    r"""Fully-convolutional network that maps a gray sketch to a gray sketch.

    NOTE:
        1. The input image must have exactly one (gray) channel.
        2. The input height and width should be divisible by 8.
        3. In both the input and the output image the strokes are dark and
           the background is light.
    """

    def __init__(self, mean, std):
        r"""Build the network.

        Args:
            mean: Python ``float`` subtracted from the input in ``forward``.
            std: Python ``float`` the centred input is divided by.

        Raises:
            AssertionError: if ``mean`` or ``std`` is not a ``float``.
        """
        # NOTE(review): this check is an ``assert``, so it is skipped when
        # Python runs with ``-O``.
        assert isinstance(mean, float) and isinstance(std, float)
        super().__init__()

        # Stored as plain attributes (not parameters or buffers), so they are
        # not part of the module's state dict.
        self.mean = mean
        self.std = std

        conv, deconv = nn.Conv2d, nn.ConvTranspose2d

        # One row per learnable layer:
        # (layer type, in_channels, out_channels, kernel_size, stride, padding)
        # Three stride-2 convolutions are later matched by three stride-2
        # transposed convolutions -- presumably the reason input sizes are
        # expected to be divisible by 8.
        specs = (
            (conv, 1, 48, 5, 2, 2),          # stride 2
            (conv, 48, 128, 3, 1, 1),
            (conv, 128, 128, 3, 1, 1),
            (conv, 128, 128, 3, 2, 1),       # stride 2
            (conv, 128, 256, 3, 1, 1),
            (conv, 256, 256, 3, 1, 1),
            (conv, 256, 256, 3, 2, 1),       # stride 2
            (conv, 256, 512, 3, 1, 1),
            (conv, 512, 1024, 3, 1, 1),
            (conv, 1024, 1024, 3, 1, 1),
            (conv, 1024, 1024, 3, 1, 1),
            (conv, 1024, 1024, 3, 1, 1),
            (conv, 1024, 512, 3, 1, 1),
            (conv, 512, 256, 3, 1, 1),
            (deconv, 256, 256, 4, 2, 1),     # stride-2 transposed conv
            (conv, 256, 256, 3, 1, 1),
            (conv, 256, 128, 3, 1, 1),
            (deconv, 128, 128, 4, 2, 1),     # stride-2 transposed conv
            (conv, 128, 128, 3, 1, 1),
            (conv, 128, 48, 3, 1, 1),
            (deconv, 48, 48, 4, 2, 1),       # stride-2 transposed conv
            (conv, 48, 24, 3, 1, 1),
            (conv, 24, 1, 3, 1, 1),
        )

        # Every layer is followed by an in-place ReLU, except the last one,
        # which is followed by a sigmoid. Modules are created in table order
        # so their indices inside ``self.layers`` (and therefore the state
        # dict keys) stay 0..45.
        modules = []
        final = len(specs) - 1
        for position, (layer_cls, *layer_args) in enumerate(specs):
            modules.append(layer_cls(*layer_args))
            if position == final:
                modules.append(nn.Sigmoid())
            else:
                modules.append(nn.ReLU(inplace=True))
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        r"""Normalise ``x`` with ``mean``/``std`` and run it through the layers.

        x: [B, 1, H, W] within range [0, 1]. Sketch pixels in dark color.

        Returns the output of the final sigmoid.
        """
        normalized = (x - self.mean) / self.std
        return self.layers(normalized)


def _load_cached_weights(model, cache_key):
    r"""Load the checkpoint identified by ``cache_key`` into ``model``.

    ``DOWNLOAD_TO_CACHE`` comes from another module; presumably it returns a
    local path (or file object) for the key -- confirm against its source.
    """
    # NOTE(review): ``torch.load`` unpickles the file, so only trusted
    # checkpoints should be loaded here.
    state = torch.load(DOWNLOAD_TO_CACHE(cache_key), map_location='cpu')
    model.load_state_dict(state)


def sketch_simplification_gan(pretrained=False):
    r"""Return the ``gan`` sketch-simplification model.

    When ``pretrained`` is true the matching checkpoint is loaded onto CPU.
    """
    model = SketchSimplification(
        mean=0.9664114577640158,
        std=0.0858381272736797)
    if pretrained:
        # Unlike the sibling factories, this checkpoint key carries the
        # 'VideoComposer/Hangjie/' prefix; the key used previously was
        # 'models/sketch_simplification/sketch_simplification_gan.pth'.
        _load_cached_weights(
            model,
            'VideoComposer/Hangjie/models/sketch_simplification/sketch_simplification_gan.pth')
    return model


def sketch_simplification_mse(pretrained=False):
    r"""Return the ``mse`` sketch-simplification model.

    When ``pretrained`` is true the matching checkpoint is loaded onto CPU.
    """
    model = SketchSimplification(
        mean=0.9664423107454593,
        std=0.08583666033640507)
    if pretrained:
        _load_cached_weights(
            model,
            'models/sketch_simplification/sketch_simplification_mse.pth')
    return model


def sketch_to_pencil_v1(pretrained=False):
    r"""Return the ``sketch_to_pencil_v1`` model.

    When ``pretrained`` is true the matching checkpoint is loaded onto CPU.
    """
    model = SketchSimplification(
        mean=0.9817833515894078,
        std=0.0925009022585048)
    if pretrained:
        _load_cached_weights(
            model,
            'models/sketch_simplification/sketch_to_pencil_v1.pth')
    return model


def sketch_to_pencil_v2(pretrained=False):
    r"""Return the ``sketch_to_pencil_v2`` model.

    When ``pretrained`` is true the matching checkpoint is loaded onto CPU.
    """
    model = SketchSimplification(
        mean=0.9851298627337799,
        std=0.07418377454883571)
    if pretrained:
        _load_cached_weights(
            model,
            'models/sketch_simplification/sketch_to_pencil_v2.pth')
    return model