Unverified Commit b0c7d09f authored by rohithkrn, committed by GitHub

bfloat16 support for mgpu (#19)

* bfloat16 support for apex DDP

* enable mgpu tests for fp16 and bf16

* update Dockerfile
parent aea81c0c
-ARG FROM_IMAGE=lcskrishna/rocm-pytorch:rocm3.3_ubuntu16.04_py3.6_pytorch_updated
+ARG FROM_IMAGE=lcskrishna/rocm-pytorch:rocm3.3_ubuntu16.04_py3.6_pytorch_bfloat16_mgpu
FROM ${FROM_IMAGE}
RUN \
@@ -48,8 +48,8 @@ def apply_flat_dist_call(bucket, call, extra_args=None):
for buf, synced in zip(bucket, unflatten(coalesced, bucket)):
buf.copy_(synced)
-def split_half_float_double(tensors):
-dtypes = ["torch.cuda.HalfTensor", "torch.cuda.FloatTensor", "torch.cuda.DoubleTensor"]
+def split_half_float_double_bfloat16(tensors):
+dtypes = ["torch.cuda.HalfTensor", "torch.cuda.FloatTensor", "torch.cuda.DoubleTensor", "torch.cuda.BFloat16Tensor"]
buckets = []
for i, dtype in enumerate(dtypes):
bucket = [t for t in tensors if t.type() == dtype]
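Not part of the commit, but for reference, a minimal self-contained sketch of the bucketing helper as extended above. The tail of the function is not visible in this hunk, so the loop body is completed here on the assumption that non-empty per-dtype buckets are collected and returned:

import torch

def split_half_float_double_bfloat16(tensors):
    # Group tensors by their CUDA type string; bfloat16 now gets its own bucket.
    dtypes = ["torch.cuda.HalfTensor",
              "torch.cuda.FloatTensor",
              "torch.cuda.DoubleTensor",
              "torch.cuda.BFloat16Tensor"]
    buckets = []
    for dtype in dtypes:
        bucket = [t for t in tensors if t.type() == dtype]
        if bucket:
            buckets.append(bucket)
    return buckets

# Illustration: mixed-precision gradients end up in one bucket per dtype present.
if torch.cuda.is_available():
    grads = [torch.randn(4, device="cuda", dtype=dt)
             for dt in (torch.float16, torch.float32, torch.bfloat16)]
    print([b[0].type() for b in split_half_float_double_bfloat16(grads)])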
@@ -240,7 +240,8 @@ class DistributedDataParallel(Module):
self.param_type_to_tmp_i = {"torch.cuda.HalfTensor" : 0,
"torch.cuda.FloatTensor" : 1,
"torch.cuda.DoubleTensor" : 2}
"torch.cuda.DoubleTensor" : 2,
"torch.cuda.BFloat16Tensor" : 3}
if multi_tensor_applier.available:
# TODO: I really need to centralize the C++ backed imports
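As a quick sanity check (an illustration added here, not part of the diff), the new dictionary key matches the legacy type string PyTorch reports for a bfloat16 CUDA tensor:

import torch

if torch.cuda.is_available():
    p = torch.zeros(1, device="cuda", dtype=torch.bfloat16)
    # param_type_to_tmp_i looks up this string to route the parameter
    # into the new fourth tmp bucket (index 3).
    assert p.type() == "torch.cuda.BFloat16Tensor"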
@@ -498,7 +499,7 @@ class DistributedDataParallel(Module):
else:
grads = [param.grad.data for param in self.module.parameters() if param.grad is not None]
-split_buckets = split_half_float_double(grads)
+split_buckets = split_half_float_double_bfloat16(grads)
# If retain_allreduce_buffers is True and delay_allreduce is False,
# this will only be done during the first backward pass, ignored by the
@@ -578,8 +579,8 @@ class DistributedDataParallel(Module):
if self.needs_refresh:
self.active_i_buckets = []
self.buckets = []
-self.tmp_buckets = [[], [], []] # [running half, float, double buckets]
-self.tmp_numels = [0, 0, 0]
+self.tmp_buckets = [[], [], [], []] # [running half, float, double, bfloat16 buckets]
+self.tmp_numels = [0, 0, 0, 0]
self.bucket_sizes = []
self.param_id_to_active_i = {id(param) : i for i, param in enumerate(param_list)}
self.param_id_to_bucket = {}
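A rough sketch, not taken from the source, of how the four running buckets above are presumably filled during the first backward pass: each parameter is routed by its type string into the matching tmp bucket and its element count is accumulated. The real logic also tracks active bucket indices and flushes buckets by message size, which is omitted here; route_param is a hypothetical helper for illustration only.

param_type_to_tmp_i = {"torch.cuda.HalfTensor": 0,
                       "torch.cuda.FloatTensor": 1,
                       "torch.cuda.DoubleTensor": 2,
                       "torch.cuda.BFloat16Tensor": 3}

tmp_buckets = [[], [], [], []]   # running half, float, double, bfloat16 buckets
tmp_numels = [0, 0, 0, 0]

def route_param(param):
    # Hypothetical helper: pick the running bucket for this param's dtype.
    i = param_type_to_tmp_i[param.type()]
    tmp_buckets[i].append(param)
    tmp_numels[i] += param.numel()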
@@ -9,6 +9,7 @@ parser = argparse.ArgumentParser()
# FOR DISTRIBUTED: Parse for the local_rank argument, which will be supplied
# automatically by torch.distributed.launch.
parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument("--opt_level", default="O2", type=str)
args = parser.parse_args()
# FOR DISTRIBUTED: If we are running under torch.distributed.launch,
@@ -42,7 +43,7 @@ y = torch.randn(N, D_out, device='cuda')
model = torch.nn.Linear(D_in, D_out).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer = amp.initialize(model, optimizer, opt_level="O2")
model, optimizer = amp.initialize(model, optimizer, opt_level=args.opt_level)
if args.distributed:
# FOR DISTRIBUTED: After amp.initialize, wrap the model with
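For context, the surrounding test script presumably sets up distributed training like the sketch below; module paths and tensor sizes are assumptions, and only the argparse and amp.initialize lines are confirmed by the diff. opt_level now comes from the command line, so the same script can exercise fp16 and, in this fork, the bfloat16 level used by the runner script further down (O5 is assumed to be the bfloat16 counterpart of O2).

import argparse
import torch
from apex import amp
from apex.parallel import DistributedDataParallel

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument("--opt_level", default="O2", type=str)
args = parser.parse_args()

# FOR DISTRIBUTED: each launched process binds to its own GPU.
torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend="nccl", init_method="env://")

model = torch.nn.Linear(64, 16).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# opt_level is taken from the command line instead of being hard-coded to "O2".
model, optimizer = amp.initialize(model, optimizer, opt_level=args.opt_level)
# After amp.initialize, wrap the model with the apex DDP that now buckets bf16 grads.
model = DistributedDataParallel(model)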
@@ -14,6 +14,9 @@ for model_rank0, model_rank1, master_rank0, master_rank1 in zip(
model_params_rank1,
master_params_rank0,
master_params_rank1):
+# converting model params to float is a hack since allclose doesn't support bfloat16 yet.
+model_rank0 = model_rank0.float()
+model_rank1 = model_rank1.float()
assert torch.allclose(model_rank0, model_rank1), "Model param mismatch"
assert torch.allclose(master_rank0, master_rank1), "Master param mismatch"
# Some debugging/investigation assistance code:
@@ -23,6 +26,6 @@ for model_rank0, model_rank1, master_rank0, master_rank1 in zip(
# print(maxval.item(), maxind.item(), offending_val_half.item(), offending_val_float.item(),
# offending_val_float.half().item())
# rtol needs to be > 2^-11 because of denormals...
-assert torch.allclose(model_rank0, master_rank0.half(), rtol=.005), "Model-master mismatch"
+assert torch.allclose(model_rank0, master_rank0, rtol=.005), "Model-master mismatch"
print("OK: Model and master params match across ranks.")
#!/bin/bash
set -e
# To run the test on 2 gpus
export WORLD_SIZE=2
# Test with opt_level="O2"
echo "running opt_level O2"
python3.6 -m torch.distributed.launch --nproc_per_node=2 amp_master_params.py --opt_level "O2"
python3.6 compare.py
# delete the model files
echo -e "O2 test completed. Deleting model files\n"
rm rank0model.pth
rm rank1model.pth
rm rank0master.pth
rm rank1master.pth
# Test with opt_level="O5"
echo "running opt_level O5"
python3.6 -m torch.distributed.launch --nproc_per_node=2 amp_master_params.py --opt_level "O5"
python3.6 compare.py
# delete the model files
echo "O5 test completed. Deleting model files"
rm rank0model.pth
rm rank1model.pth
rm rank0master.pth
rm rank1master.pth