Unverified commit 220e0ff9, authored by Vasilis Vryniotis and committed by GitHub
Browse files

Add multi-type support on `get_weight()` (#4967)

* Add multi-type support on get_weight.

* Fix bug on method call.

* Adding logging suffix for QAT.
parent ebc4ca76
......@@ -30,12 +30,12 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg
for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
start_time = time.time()
image, target = image.to(device), target.to(device)
with torch.cuda.amp.autocast(enabled=args.amp):
with torch.cuda.amp.autocast(enabled=scaler is not None):
output = model(image)
loss = criterion(output, target)
optimizer.zero_grad()
if args.amp:
if scaler is not None:
scaler.scale(loss).backward()
if args.clip_grad_norm is not None:
# we should unscale the gradients of optimizer's assigned params if do gradient clipping
......
......@@ -121,7 +121,7 @@ def main(args):
if args.distributed:
train_sampler.set_epoch(epoch)
print("Starting training for epoch", epoch)
train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq)
train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args)
lr_scheduler.step()
with torch.inference_mode():
if epoch >= args.num_observer_update_epochs:
......@@ -132,7 +132,7 @@ def main(args):
model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
print("Evaluate QAT model")
evaluate(model, criterion, data_loader_test, device=device)
evaluate(model, criterion, data_loader_test, device=device, log_suffix="QAT")
quantized_eval_model = copy.deepcopy(model_without_ddp)
quantized_eval_model.eval()
quantized_eval_model.to(torch.device("cpu"))
......@@ -261,6 +261,7 @@ def get_args_parser(add_help=True):
parser.add_argument(
"--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
)
parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
# Prototype models only
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
......
......@@ -101,6 +101,11 @@ def get_weight(fn: Callable, weight_name: str) -> Weights:
# TODO: Replace ann.__args__ with typing.get_args(ann) after python >= 3.8
for t in ann.__args__: # type: ignore[union-attr]
if isinstance(t, type) and issubclass(t, Weights):
# ensure the name exists. handles builders with multiple types of weights like in quantization
try:
t.from_str(weight_name)
except ValueError:
continue
weights_class = t
break
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment