Unverified Commit 7ed3950e authored by Vasiliy Kuznetsov, committed by GitHub

vision classification QAT tutorial: fix for DDP (redo) (#2230)

Summary:

Redo of https://github.com/pytorch/vision/pull/2191

Makes the classification QAT tutorial not crash when used
with DDP. There were two issues:

1. The model was moved to GPU before the observers were added, but the
observers are created on CPU. In the context of this repo, the fix is to
finish preparing the model for QAT before moving it to GPU. We can
potentially follow up with a better error message in a separate PR.
2. The QAT conversion was run on the DDP-wrapped model, which caused
various failures. The fix is to unwrap the model from DDP before cloning
it for evaluation (see the sketch below).

There is still work to do to verify that BN works correctly under
QAT + DDP, but that is left for a separate PR.
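For reference, below is a minimal sketch of the intended ordering, assuming the eager-mode `torch.quantization` API. It is illustrative only and not the actual `train_quantization.py` code: the helper names `build_qat_model` and `make_int8_eval_copy`, the `'fbgemm'` qconfig choice, and the DDP arguments are assumptions made to keep the example self-contained.

```python
# Minimal sketch, not the real reference script: QAT preparation happens on
# CPU first, then the model is moved to GPU, then (optionally) wrapped in DDP.
# The int8 conversion is done on a deep copy of the *unwrapped* model.
import copy

import torch
import torchvision


def build_qat_model(arch, device, distributed, gpu_index=0):
    # Create the fp32 model and prepare it for QAT while it is still on CPU,
    # so the fake-quant observers live on the same device as the weights.
    model = torchvision.models.quantization.__dict__[arch](pretrained=True, quantize=False)
    model.fuse_model()
    model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
    torch.quantization.prepare_qat(model, inplace=True)

    # Only after QAT preparation is finished, move the model to the GPU.
    model.to(device)

    # Keep a handle to the unwrapped model; DDP wrapping comes last.
    model_without_ddp = model
    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu_index])
    return model, model_without_ddp


def make_int8_eval_copy(model_without_ddp):
    # Convert a deep copy of the unwrapped model on CPU; converting the
    # DDP wrapper itself is what caused the crashes described above.
    quantized_eval_model = copy.deepcopy(model_without_ddp)
    quantized_eval_model.eval()
    quantized_eval_model.to(torch.device('cpu'))
    torch.quantization.convert(quantized_eval_model, inplace=True)
    return quantized_eval_model
```

This mirrors the ordering the diff below enforces: prepare for QAT, then `model.to(device)`, then DDP, and always convert `model_without_ddp` rather than the DDP wrapper.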

Test Plan:

```
python -m torch.distributed.launch --use_env references/classification/train_quantization.py --data-path {path_to_imagenet_1k} --output_dir {output_dir}
```

Reviewers:

Subscribers:

Tasks:

Tags:
parent e6d3f8c5
references/classification/train_quantization.py
@@ -51,7 +51,6 @@ def main(args):
     print("Creating model", args.model)
     # when training quantized models, we always start from a pre-trained fp32 reference model
     model = torchvision.models.quantization.__dict__[args.model](pretrained=True, quantize=args.test_only)
-    model.to(device)
 
     if not (args.test_only or args.post_training_quantize):
         model.fuse_model()
@@ -66,6 +65,8 @@ def main(args):
                                                    step_size=args.lr_step_size,
                                                    gamma=args.lr_gamma)
 
+    model.to(device)
+
     criterion = nn.CrossEntropyLoss()
     model_without_ddp = model
     if args.distributed:
@@ -129,7 +130,7 @@ def main(args):
             print('Evaluate QAT model')
             evaluate(model, criterion, data_loader_test, device=device)
 
-            quantized_eval_model = copy.deepcopy(model)
+            quantized_eval_model = copy.deepcopy(model_without_ddp)
             quantized_eval_model.eval()
             quantized_eval_model.to(torch.device('cpu'))
             torch.quantization.convert(quantized_eval_model, inplace=True)
...