Unverified commit 39021408, authored by Vasiliy Kuznetsov, committed by GitHub

torchvision QAT tutorial: update for QAT with DDP (#2280)

Summary:

We've made two recent changes to QAT in PyTorch core (both shown in the sketch below):
1. added support for SyncBatchNorm
2. made the eager mode QAT prepare functions respect device affinity
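
Taken together, the two changes let an eager mode QAT model be prepared on its target device and then wrapped in DDP with synchronized batch norms. A minimal sketch follows; the model choice (`mobilenet_v2`), the `fbgemm` backend, and the NCCL process-group setup are illustrative assumptions, not part of this PR:

```
import os

import torch
import torchvision

# assumes launch via `python -m torch.distributed.launch --use_env ...`,
# which exports LOCAL_RANK (and the rendezvous env vars) for each worker
local_rank = int(os.environ["LOCAL_RANK"])
torch.distributed.init_process_group(backend="nccl")
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)

# start from a pre-trained fp32 reference model
model = torchvision.models.quantization.mobilenet_v2(pretrained=True, quantize=False)
model.to(device)  # change 2: the QAT prepare pass now respects the model's device

model.fuse_model()
model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
torch.quantization.prepare_qat(model, inplace=True)

# change 1: a QAT-prepared model can now be converted to use SyncBatchNorm
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
```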

This PR updates the torchvision QAT reference script to take
advantage of both changes. It should land after
https://github.com/pytorch/pytorch/pull/39337 (the last of the
PyTorch fixes) to avoid compatibility issues.

Test Plan:

```
python -m torch.distributed.launch \
  --nproc_per_node 8 \
  --use_env \
  references/classification/train_quantization.py \
  --data-path {imagenet1k_subset} \
  --output-dir {tmp} \
  --sync-bn
```
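
Here `--use_env` makes the launcher pass each worker's rank via the `LOCAL_RANK` environment variable instead of a `--local_rank` argument, and `--sync-bn` exercises the new SyncBatchNorm path; `{imagenet1k_subset}` and `{tmp}` are placeholders for a local ImageNet subset and a scratch output directory.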

parent 34810c0c
```
--- a/references/classification/train_quantization.py
+++ b/references/classification/train_quantization.py
@@ -51,12 +51,16 @@ def main(args):
     print("Creating model", args.model)
     # when training quantized models, we always start from a pre-trained fp32 reference model
     model = torchvision.models.quantization.__dict__[args.model](pretrained=True, quantize=args.test_only)
+    model.to(device)
 
     if not (args.test_only or args.post_training_quantize):
         model.fuse_model()
         model.qconfig = torch.quantization.get_default_qat_qconfig(args.backend)
         torch.quantization.prepare_qat(model, inplace=True)
 
+        if args.distributed and args.sync_bn:
+            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+
     optimizer = torch.optim.SGD(
         model.parameters(), lr=args.lr, momentum=args.momentum,
         weight_decay=args.weight_decay)
@@ -65,8 +69,6 @@ def main(args):
         step_size=args.lr_step_size,
         gamma=args.lr_gamma)
 
-    model.to(device)
-
     criterion = nn.CrossEntropyLoss()
     model_without_ddp = model
     if args.distributed:
@@ -224,6 +226,12 @@ def parse_args():
         It also serializes the transforms",
         action="store_true",
     )
+    parser.add_argument(
+        "--sync-bn",
+        dest="sync_bn",
+        help="Use sync batch norm",
+        action="store_true",
+    )
     parser.add_argument(
         "--test-only",
         dest="test_only",
```
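
Note the ordering in the updated script: `model.to(device)` now runs before `prepare_qat`, so the observers and fake-quantize modules inserted by the prepare pass are created directly on the training device (this is where the device-affinity fix pays off), and `convert_sync_batchnorm` runs after `prepare_qat`, so the batch norms converted to SyncBatchNorm are the ones the QAT prepare pass has already configured.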