"""This script defines the base network model for Deep3DFaceRecon_pytorch
"""
importos
importnumpyasnp
importtorch
fromcollectionsimportOrderedDict
fromabcimportABC,abstractmethod
from.importnetworks
classBaseModel(ABC):
"""This class is an abstract base class (ABC) for models.
To create a subclass, you need to implement the following five functions:
-- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
-- <set_input>: unpack data from dataset and apply preprocessing.
-- <forward>: produce intermediate results.
-- <optimize_parameters>: calculate losses, gradients, and update network weights.
-- <modify_commandline_options>: (optionally) add model-specific options and set default options.
"""
def__init__(self,opt):
"""Initialize the BaseModel class.
Parameters:
opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
When creating your custom class, you need to implement your own initialization.
In this fucntion, you should first call <BaseModel.__init__(self, opt)>
Then, you need to define four lists:
-- self.loss_names (str list): specify the training losses that you want to plot and save.
-- self.model_names (str list): specify the images that you want to display and save.
-- self.visual_names (str list): define networks used in our training.
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
"""
self.opt=opt
self.isTrain=False
self.device=torch.device('cpu')
self.save_dir=" "# os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir
self.loss_names=[]
self.model_names=[]
self.visual_names=[]
self.parallel_names=[]
self.optimizers=[]
self.image_paths=[]
self.metric=0# used for learning rate policy 'plateau'
@staticmethod
defdict_grad_hook_factory(add_func=lambdax:x):
saved_dict=dict()
defhook_gen(name):
defgrad_hook(grad):
saved_vals=add_func(grad)
saved_dict[name]=saved_vals
returngrad_hook
returnhook_gen,saved_dict
@staticmethod
defmodify_commandline_options(parser,is_train):
"""Add new model-specific options, and rewrite default values for existing options.
Parameters:
parser -- original option parser
is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
Returns:
the modified parser.
"""
returnparser
@abstractmethod
defset_input(self,input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Parameters:
input (dict): includes the data itself and its metadata information.
"""
pass
@abstractmethod
defforward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
pass
@abstractmethod
defoptimize_parameters(self):
"""Calculate losses, gradients, and update network weights; called in every training iteration"""
pass
defsetup(self,opt):
"""Load and print networks; create schedulers
Parameters:
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
"""Add new model-specific options and rewrite default values for existing options.
Parameters:
parser -- the option parser
is_train -- if it is training phase or test phase. You can use this flag to add training-specific or test-specific options.
Returns:
the modified parser.
"""
parser.set_defaults(dataset_mode='aligned')# You can rewrite default values for this model. For example, this model usually uses aligned dataset as its dataset.
ifis_train:
parser.add_argument('--lambda_regression',type=float,default=1.0,help='weight for the regression loss')# You can define new arguments for this model.
returnparser
def__init__(self,opt):
"""Initialize this model class.
Parameters:
opt -- training/test options
A few things can be done here.
- (required) call the initialization function of BaseModel
- define loss function, visualization images, model names, and optimizers
"""
BaseModel.__init__(self,opt)# call the initialization method of BaseModel
# specify the training losses you want to print out. The program will call base_model.get_current_losses to plot the losses to the console and save them to the disk.
self.loss_names=['loss_G']
# specify the images you want to save and display. The program will call base_model.get_current_visuals to save and display these images.
self.visual_names=['data_A','data_B','output']
# specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks to save and load networks.
# you can use opt.isTrain to specify different behaviors for training and test. For example, some networks will not be used during test, and you don't need to load them.
self.model_names=['G']
# define networks; you can use opt.isTrain to specify different behaviors for training and test.