"""Re-key a RIDNet checkpoint so it loads into the BasicSR ``RIDNet`` arch.

The source checkpoint's tensors are assigned, strictly by position, to the
parameter names reported by ``RIDNet.named_parameters()``; the resulting
state dict is loaded into the model and saved under a new file name.
"""
from collections import OrderedDict

import torch

from basicsr.archs.ridnet_arch import RIDNet

if __name__ == '__main__':
    src_path = 'experiments/pretrained_models/RIDNet/RIDNet_official_original.pt'
    dst_path = 'experiments/pretrained_models/RIDNet/RIDNet.pth'

    # Keep every tensor on the storage it is deserialized to (no device move).
    # NOTE(review): torch.load unpickles the file -- only run this on a
    # checkpoint obtained from a trusted source.
    ori_net_checkpoint = torch.load(
        src_path, map_location=lambda storage, loc: storage)

    rid_net = RIDNet(3, 64, 3)
    rid_net_namelist = [name for name, _ in rid_net.named_parameters()]

    # The mapping below is purely positional, so a size mismatch would
    # otherwise surface as an opaque IndexError (checkpoint too long) or a
    # late missing-key failure in load_state_dict (checkpoint too short).
    if len(ori_net_checkpoint) != len(rid_net_namelist):
        raise ValueError(
            f'Cannot map checkpoint onto RIDNet by position: checkpoint has '
            f'{len(ori_net_checkpoint)} entries but the model has '
            f'{len(rid_net_namelist)} parameters.')

    # Pair the i-th model parameter name with the i-th checkpoint tensor.
    new_ridnet_dict = OrderedDict(
        zip(rid_net_namelist, ori_net_checkpoint.values()))

    # load_state_dict validates key names and tensor shapes before we save.
    rid_net.load_state_dict(new_ridnet_dict)
    torch.save(rid_net.state_dict(), dst_path)