Commit a730f38f authored by Michael Carilli

Adding simple distributed example for #200

parent 7f0d8c87
@@ -117,7 +117,7 @@ def main():
     args.world_size = 1

     if args.distributed:
-        args.gpu = args.local_rank % torch.cuda.device_count()
+        args.gpu = args.local_rank
         torch.cuda.set_device(args.gpu)
         torch.distributed.init_process_group(backend='nccl',
                                              init_method='env://')
@@ -334,7 +334,7 @@ def train(train_loader, model, criterion, optimizer, epoch):
         if args.prof: torch.cuda.nvtx.range_pop()

         if i%args.print_freq == 0:
-            # Every print_freq iterations, check the loss accuracy and speed.
+            # Every print_freq iterations, check the loss, accuracy, and speed.
             # For best performance, it doesn't make sense to print these metrics every
             # iteration, since they incur an allreduce and some host<->device syncs.
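As an aside, the allreduce mentioned in that comment is the kind of cross-process reduction typically used to average a metric before printing. The following is a minimal sketch using plain `torch.distributed`; the helper name is illustrative and not taken from main.py, and it assumes the process group has already been initialized:

```python
import torch
import torch.distributed as dist

def average_across_processes(t):
    # Hypothetical helper (not from main.py): all-reduce a scalar tensor and
    # divide by the world size, so every rank sees the same averaged value.
    # Assumes dist.init_process_group(...) has already been called.
    rt = t.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= dist.get_world_size()
    return rt
```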
**distributed_data_parallel.py** and **run.sh** show an example using Amp with
[apex.parallel.DistributedDataParallel](https://nvidia.github.io/apex/parallel.html) or
[torch.nn.parallel.DistributedDataParallel](https://pytorch.org/docs/stable/nn.html#distributeddataparallel)
and the PyTorch multiprocess launcher script,
[torch.distributed.launch](https://pytorch.org/docs/master/distributed.html#launch-utility).
Using `Amp` with distributed training requires no changes from ordinary
single-process use. The only gotcha is that wrapping your model with `DistributedDataParallel` must
come after the call to `amp.initialize`, as sketched below. Test via
```bash
bash run.sh
```
**This is intended purely as an instructional example, not a performance showcase.**
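As a quick illustration of the gotcha mentioned above, here is the relevant ordering, condensed directly from the full script that follows:

```python
# Initialize with Amp first...
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

if args.distributed:
    # ...and only then wrap the model for distributed training.
    model = DistributedDataParallel(model)
```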
**distributed_data_parallel.py**:

import torch
import argparse
import os
from apex import amp
# FOR DISTRIBUTED: (can also use torch.nn.parallel.DistributedDataParallel instead)
from apex.parallel import DistributedDataParallel

parser = argparse.ArgumentParser()
# FOR DISTRIBUTED: Parse for the local_rank argument, which will be supplied
# automatically by torch.distributed.launch.
parser.add_argument("--local_rank", default=0, type=int)
args = parser.parse_args()

# FOR DISTRIBUTED: If we are running under torch.distributed.launch,
# the 'WORLD_SIZE' environment variable will also be set automatically.
args.distributed = False
if 'WORLD_SIZE' in os.environ:
    args.distributed = int(os.environ['WORLD_SIZE']) > 1

if args.distributed:
    # FOR DISTRIBUTED: Set the device according to local_rank.
    torch.cuda.set_device(args.local_rank)

    # FOR DISTRIBUTED: Initialize the backend. torch.distributed.launch will provide
    # environment variables, and requires that you use init_method=`env://`.
    torch.distributed.init_process_group(backend='nccl',
                                         init_method='env://')

torch.backends.cudnn.benchmark = True

N, D_in, D_out = 64, 1024, 16

# Each process receives its own batch of "fake input data" and "fake target data."
# The "training loop" in each process just uses this fake batch over and over.
# https://github.com/NVIDIA/apex/tree/master/examples/imagenet provides a more realistic
# example of distributed data sampling for both training and validation.
x = torch.randn(N, D_in, device='cuda')
y = torch.randn(N, D_out, device='cuda')

model = torch.nn.Linear(D_in, D_out).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

if args.distributed:
    # FOR DISTRIBUTED: After amp.initialize, wrap the model with
    # apex.parallel.DistributedDataParallel.
    model = DistributedDataParallel(model)
    # torch.nn.parallel.DistributedDataParallel is also fine, with some added args:
    # model = torch.nn.parallel.DistributedDataParallel(model,
    #                                                   device_ids=[args.local_rank],
    #                                                   output_device=args.local_rank)

loss_fn = torch.nn.MSELoss()

for t in range(500):
    optimizer.zero_grad()
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()

if args.local_rank == 0:
    print("final loss = ", loss)
**run.sh**:

#!/bin/bash
python -m torch.distributed.launch --nproc_per_node=2 distributed_data_parallel.py
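For reference, `torch.distributed.launch` can also spawn processes across multiple nodes. A hypothetical two-node variant of run.sh might look like the following; the master address, port, and node ranks are placeholders and are not part of the original example:

```bash
#!/bin/bash
# Hypothetical multi-node variant (not part of the original example).
# Run this on node 0; on node 1, change --node_rank=1.
python -m torch.distributed.launch --nproc_per_node=2 \
    --nnodes=2 --node_rank=0 \
    --master_addr="192.168.1.1" --master_port=12345 \
    distributed_data_parallel.py
```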