# Copyright (c) OpenMMLab. All rights reserved.
import torch


class LoadNoInit:
    """Initialize model without parameter initialization.

    Context manager that, while active, replaces a fixed set of
    ``torch.nn.init`` initializer functions with no-ops. Code that calls them
    through the ``torch.nn.init`` module during the ``with`` block therefore
    leaves tensors untouched. The original functions are put back on exit,
    including when the block raises (exceptions are not suppressed).

    NOTE: only lookups through the ``torch.nn.init`` module attribute are
    affected; a reference bound earlier (e.g. ``from torch.nn.init import
    normal_``) still points at the real initializer.
    """

    # Names of the ``torch.nn.init`` functions patched while active.
    _INIT_NAMES = (
        'constant_',
        'zeros_',
        'ones_',
        'uniform_',
        'normal_',
        'kaiming_uniform_',
        'kaiming_normal_',
    )

    def __init__(self):
        # Keep the historical per-instance attributes (``self.normal_`` etc.)
        # populated for backward compatibility; they are refreshed in
        # ``__enter__`` so the restore target reflects the state at entry.
        self._capture()

    def _capture(self):
        """Store the current ``torch.nn.init`` functions on ``self``."""
        for name in self._INIT_NAMES:
            setattr(self, name, getattr(torch.nn.init, name))

    @staticmethod
    def _skip_init(*args, **kwargs):
        """No-op stand-in for an initializer; returns the tensor untouched.

        Real ``torch.nn.init`` functions return their (in-place modified)
        tensor argument, and callers may use that return value. So the
        stand-in hands back the first positional argument, or the ``tensor``
        keyword when the tensor was passed by name.
        """
        if args:
            return args[0]
        return kwargs.get('tensor')

    def __enter__(self, *args, **kwargs):
        """Replace initializers with no-ops and return this instance."""
        # Capture at entry (not only at construction) so an instance created
        # while another LoadNoInit is active does not later "restore" no-ops.
        self._capture()
        for name in self._INIT_NAMES:
            setattr(torch.nn.init, name, self._skip_init)
        return self

    def __exit__(self, *args, **kwargs):
        """Recover the initializers captured on entry.

        Returns ``None``, so any exception raised in the block propagates.
        """
        for name in self._INIT_NAMES:
            setattr(torch.nn.init, name, getattr(self, name))