".github/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "36815ef914216a61307477ccff3ee6fb0d7fda20"
Unverified Commit 674e8140 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Add weight averaging and storing methods in references utils (#3352)

* Adding the average_checkpoints() method.

* Adding the store_model_weights() method.
parent 03fec9c7
from collections import defaultdict, deque from collections import defaultdict, deque, OrderedDict
import copy
import datetime import datetime
import hashlib
import time import time
import torch import torch
import torch.distributed as dist import torch.distributed as dist
...@@ -252,3 +254,126 @@ def init_distributed_mode(args): ...@@ -252,3 +254,126 @@ def init_distributed_mode(args):
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
world_size=args.world_size, rank=args.rank) world_size=args.world_size, rank=args.rank)
setup_for_distributed(args.rank == 0) setup_for_distributed(args.rank == 0)
def average_checkpoints(inputs):
"""Loads checkpoints from inputs and returns a model with averaged weights. Original implementation taken from:
https://github.com/pytorch/fairseq/blob/a48f235636557b8d3bc4922a6fa90f3a0fa57955/scripts/average_checkpoints.py#L16
Args:
inputs (List[str]): An iterable of string paths of checkpoints to load from.
Returns:
A dict of string keys mapping to various values. The 'model' key
from the returned dict should correspond to an OrderedDict mapping
string parameter names to torch Tensors.
"""
params_dict = OrderedDict()
params_keys = None
new_state = None
num_models = len(inputs)
for fpath in inputs:
with open(fpath, "rb") as f:
state = torch.load(
f,
map_location=(
lambda s, _: torch.serialization.default_restore_location(s, "cpu")
),
)
# Copies over the settings from the first checkpoint
if new_state is None:
new_state = state
model_params = state["model"]
model_params_keys = list(model_params.keys())
if params_keys is None:
params_keys = model_params_keys
elif params_keys != model_params_keys:
raise KeyError(
"For checkpoint {}, expected list of params: {}, "
"but found: {}".format(f, params_keys, model_params_keys)
)
for k in params_keys:
p = model_params[k]
if isinstance(p, torch.HalfTensor):
p = p.float()
if k not in params_dict:
params_dict[k] = p.clone()
# NOTE: clone() is needed in case of p is a shared parameter
else:
params_dict[k] += p
averaged_params = OrderedDict()
for k, v in params_dict.items():
averaged_params[k] = v
if averaged_params[k].is_floating_point():
averaged_params[k].div_(num_models)
else:
averaged_params[k] //= num_models
new_state["model"] = averaged_params
return new_state
def store_model_weights(model, checkpoint_path, checkpoint_key='model', strict=True):
"""
This method can be used to prepare weights files for new models. It receives as
input a model architecture and a checkpoint from the training script and produces
a file with the weights ready for release.
Examples:
from torchvision import models as M
# Classification
model = M.mobilenet_v3_large(pretrained=False)
print(store_model_weights(model, './class.pth'))
# Quantized Classification
model = M.quantization.mobilenet_v3_large(pretrained=False, quantize=False)
model.fuse_model()
model.qconfig = torch.quantization.get_default_qat_qconfig('qnnpack')
_ = torch.quantization.prepare_qat(model, inplace=True)
print(store_model_weights(model, './qat.pth'))
# Object Detection
model = M.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, pretrained_backbone=False)
print(store_model_weights(model, './obj.pth'))
# Segmentation
model = M.segmentation.deeplabv3_mobilenet_v3_large(pretrained=False, pretrained_backbone=False, aux_loss=True)
print(store_model_weights(model, './segm.pth', strict=False))
Args:
model (pytorch.nn.Module): The model on which the weights will be loaded for validation purposes.
checkpoint_path (str): The path of the checkpoint we will load.
checkpoint_key (str, optional): The key of the checkpoint where the model weights are stored.
Default: "model".
strict (bool): whether to strictly enforce that the keys
in :attr:`state_dict` match the keys returned by this module's
:meth:`~torch.nn.Module.state_dict` function. Default: ``True``
Returns:
output_path (str): The location where the weights are saved.
"""
# Store the new model next to the checkpoint_path
checkpoint_path = os.path.abspath(checkpoint_path)
output_dir = os.path.dirname(checkpoint_path)
# Deep copy to avoid side-effects on the model object.
model = copy.deepcopy(model)
checkpoint = torch.load(checkpoint_path, map_location='cpu')
# Load the weights to the model to validate that everything works
# and remove unnecessary weights (such as auxiliaries, etc)
model.load_state_dict(checkpoint[checkpoint_key], strict=strict)
tmp_path = os.path.join(output_dir, str(model.__hash__()))
torch.save(model.state_dict(), tmp_path)
sha256_hash = hashlib.sha256()
with open(tmp_path, "rb") as f:
# Read and update hash string value in blocks of 4K
for byte_block in iter(lambda: f.read(4096), b""):
sha256_hash.update(byte_block)
hh = sha256_hash.hexdigest()
output_path = os.path.join(output_dir, "weights-" + str(hh[:8]) + ".pth")
os.replace(tmp_path, output_path)
return output_path
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment