"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "f96c42fca53230057b16941b078a0a9eee06e20f"
Unverified commit 220e0ff9, authored by Vasilis Vryniotis, committed by GitHub

Add multi-type support for `get_weight()` (#4967)

* Add multi-type support for get_weight.

* Fix bug in method call.

* Add logging suffix for QAT.
parent ebc4ca76
@@ -30,12 +30,12 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg
     for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
         start_time = time.time()
         image, target = image.to(device), target.to(device)
-        with torch.cuda.amp.autocast(enabled=args.amp):
+        with torch.cuda.amp.autocast(enabled=scaler is not None):
             output = model(image)
             loss = criterion(output, target)
         optimizer.zero_grad()
-        if args.amp:
+        if scaler is not None:
             scaler.scale(loss).backward()
             if args.clip_grad_norm is not None:
                 # we should unscale the gradients of optimizer's assigned params if do gradient clipping
......
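The switch from `args.amp` to `scaler is not None` suggests that AMP state is now carried by the scaler object itself; a minimal sketch of how the scaler is presumably created in main() (this setup is not part of the shown diff):

import torch

# Assumption: a GradScaler exists only when AMP is requested, so
# `scaler is not None` is equivalent to "AMP is enabled".
scaler = torch.cuda.amp.GradScaler() if args.amp else None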
@@ -121,7 +121,7 @@ def main(args):
         if args.distributed:
             train_sampler.set_epoch(epoch)
         print("Starting training for epoch", epoch)
-        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq)
+        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args)
         lr_scheduler.step()
         with torch.inference_mode():
             if epoch >= args.num_observer_update_epochs:
@@ -132,7 +132,7 @@ def main(args):
                 model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
             print("Evaluate QAT model")
-            evaluate(model, criterion, data_loader_test, device=device)
+            evaluate(model, criterion, data_loader_test, device=device, log_suffix="QAT")
             quantized_eval_model = copy.deepcopy(model_without_ddp)
             quantized_eval_model.eval()
             quantized_eval_model.to(torch.device("cpu"))
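The new `log_suffix="QAT"` argument presumably labels the metric-logger output so the QAT evaluation lines can be told apart from the later evaluation of the converted quantized model; a rough sketch of the assumed shape of `evaluate()` in the shared classification references (not part of this diff):

# Assumption about the shared evaluate() helper, shown only to explain log_suffix.
def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""):
    model.eval()
    header = f"Test: {log_suffix}"  # e.g. "Test: QAT" for the call above
    ...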
@@ -261,6 +261,7 @@ def get_args_parser(add_help=True):
     parser.add_argument(
         "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
     )
+    parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
 
     # Prototype models only
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
......
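The new --clip-grad-norm flag feeds the gradient-clipping branch shown in the first hunk; a sketch of how it is presumably consumed there under AMP, following the comment about unscaling (the exact lines are outside the shown diff):

# Assumption: unscale the optimizer's gradients before clipping when a scaler is in use.
if args.clip_grad_norm is not None:
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)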
@@ -101,6 +101,11 @@ def get_weight(fn: Callable, weight_name: str) -> Weights:
         # TODO: Replace ann.__args__ with typing.get_args(ann) after python >= 3.8
         for t in ann.__args__:  # type: ignore[union-attr]
             if isinstance(t, type) and issubclass(t, Weights):
+                # ensure the name exists. handles builders with multiple types of weights like in quantization
+                try:
+                    t.from_str(weight_name)
+                except ValueError:
+                    continue
                 weights_class = t
                 break
......
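A self-contained sketch of the lookup these new lines enable (every name below is made up for illustration; only the `from_str` / `ValueError` / `continue` pattern mirrors the diff): when the builder's `weights` parameter is annotated with a Union of several weights enums, the name is resolved against the first enum that actually defines it, instead of failing on whichever enum happens to come first.

from enum import Enum
from typing import Optional, Union, get_type_hints


class _WeightsBase(Enum):
    # Stand-in for the prototype Weights enum, which exposes from_str().
    @classmethod
    def from_str(cls, value):
        for v in cls:
            if v.name == value:
                return v
        raise ValueError(f"Invalid value {value} for enum {cls.__name__}.")


class FloatWeights(_WeightsBase):
    ImageNet1K_V1 = "float"


class QuantizedWeights(_WeightsBase):
    ImageNet1K_FBGEMM_V1 = "quantized"


def builder(weights: Optional[Union[QuantizedWeights, FloatWeights]] = None):
    ...


def resolve(fn, weight_name):
    # Mirrors the diff: inspect the Union annotation and skip enum types
    # that do not define the requested weight name.
    ann = get_type_hints(fn)["weights"]
    for t in ann.__args__:
        if isinstance(t, type) and issubclass(t, _WeightsBase):
            try:
                value = t.from_str(weight_name)
            except ValueError:
                continue
            return value
    raise ValueError(f"No weights named {weight_name}")


print(resolve(builder, "ImageNet1K_FBGEMM_V1"))  # -> QuantizedWeights.ImageNet1K_FBGEMM_V1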