"git@developer.sourcefind.cn:OpenDAS/fastmoe.git" did not exist on "976ef2b8095fc63068ffcd25d83d0aac590a96c5"
Unverified commit 68c8481b authored by Conglong Li, committed by GitHub

1-bit Adam v2 (#817)



Authors: @awan-10 @conglongli @samyam @jeffra

What's new:

* NCCL-based implementation, which provides better performance and usability compared to the MPI-based implementation.
* Support for momentum masks for parameters with constant zero gradients during training.
* Bug fixes (e.g., #813).

* NCCL-based 1-bit Adam + Code Refactor for Comm. Backends (#594)

* NCCL based 1-bit Implementation + Refactor to add communication backends (#593)

* add nccl 1-bit optim.

* temporary commit to save stuff.

* Use dist collectives instead of mpi routines.

* remove old code for comm.

* Fix bugs. still does not work.

* modify to test the nccl side code path

* Initial gather impl. Works intra-node.

* Updates to comm. phase 2. nccl comm. passed the tests.

* refactor code to introduce nccl/mpi as backends for onebit adam.

* Refactor updates to test/engine.

* Fix compile/runtime errors.

* simplify support for nccl/mpi backends.

* Add missing file

* Add compression backend in constructor. Revert later.

* modify test with some perf counting.

* Implement a true non-blocking gather for nccl side.

* Revert "Add compression backend in constructor. Revert later."

This reverts commit df8c40d3105e9f2542a8aa6619e80d675a09753f.

* improve the 1-bit adam test.

* Refactor comm. and compression backend in 1-bit adam.

* Fix the test.

* Fix runtime errors and typos in nccl backend

* fix mpi backend. modify tests.

* modify nccl perf test.

* fix mpi side errors.

* Add an mpi perf test

* Sync DSE.

* Remove old collectives file.

* Undo a typo.

* Graceful failure for torch versions that don't support nccl pt2pt.

* Revert "Merge branch 'master' into staging-1bit-nccl-v2"

This reverts commit 78400850703b4b2d84f11b73c109f56919e748ea, reversing
changes made to a6dba72aeafad63661dfe566d3accd03d00be78c.

* Revert "Revert "Merge branch 'master' into staging-1bit-nccl-v2""

This reverts commit 6dbdd9858bafef4d340c089fdc0e3ddde3706f47.

* comm optimization + 1-bit lamb

* Saving/debugging commit.

* finalizing 1-bit lamb

* finalizing 1-bit lamb

* add momentum mask and chkpt handling for 1-bit adam

* Cleanup and modify nccl test to be runnable with deepspeed launcher.

* Fix format.

* fix formatting again.

* make test runnable without mpi4py

* Add dist.alltoall and dist.allgather instead of custom functions.

* remove debug prints.

* formatting and renaming

* renaming

* renaming

* add unit test, fix existing tests

* skip unit test when torch < 1.8

* revert 1-bit lamb

* flatten momentum when dimension is more than 1

* add warning message for 1-bit adam under fp32

* improve version check

* add fp32 test

* 1-bit adam doc

* fix file name

* doc fix

* torch 1.8 is released

* doc fix

* fix tests

* update news

* add doc for momentum mask

* fix checkpoint handling, add unit test

* checkpoint handling doc

* doc final cleanup

* bump dates

* update tests

* url change

* doc fix

* fix test

* doc update
Co-authored-by: Ammar Ahmad Awan <ammar.awan@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
parent 12a53b43
......@@ -31,6 +31,7 @@ information [here](https://innovation.microsoft.com/en-us/exploring-ai-at-scale)
# News
* [2021/03/16] [1-bit Adam v2: NCCL-based implementation and more](https://www.deepspeed.ai/tutorials/onebit-adam/)
* [2021/03/08] [ZeRO-3 Offload: Scale your models to trillion parameters without code changes while leveraging both CPUs & GPUs](https://www.deepspeed.ai/news/2021/03/07/zero3-offload.html)
* [2020/11/12] [Simplified install, JIT compiled ops, PyPI releases, and reduced dependencies](#installation)
* [2020/11/10] [Efficient and robust compressed training through progressive layer dropping](https://www.deepspeed.ai/news/2020/10/28/progressive-layer-dropping-news.html)
......
'''
Copyright 2020 The Microsoft DeepSpeed Team
'''
import torch
import cupy
import time
import numpy as np
from mpi4py import MPI
from deepspeed.runtime.compression.cupy import CupyBackend
class MpiBackend(object):
def __init__(self, cuda_aware):
self.comm = MPI.COMM_WORLD
self.rank = self.comm.Get_rank()
self.size = self.comm.Get_size()
self.cuda_aware = cuda_aware
self.compression_backend = CupyBackend()
def my_igather(self, rank, size, comm, sendbuf, recbuf, root):
req = []
if rank == root:
for idx in range(size):
if idx != rank:
req.append(comm.Irecv(recbuf[idx], source=idx))
else:
recbuf[rank] = sendbuf
else:
req.append(comm.Isend(sendbuf, dest=root))
return req
def gather_cuda(self,
rank,
world_size,
comm,
cupy_sign_list_packed,
cupy_recvbuf_sign,
cupy_worker_scale,
cupy_recvbuf_scale):
# We do in-place operations on cupy buffers so we do not return any buffers
requests = []
for idx in range(world_size):
req_sign = self.my_igather(rank,
world_size,
comm,
cupy_sign_list_packed[idx],
cupy_recvbuf_sign,
root=idx)
requests += req_sign
for idx in range(world_size):
req_scale = self.my_igather(rank,
world_size,
comm,
cupy_worker_scale,
cupy_recvbuf_scale,
root=idx)
requests += req_scale
MPI.Request.Waitall(requests)
def gather_host(self,
rank,
world_size,
comm,
cupy_sign_list_packed,
cupy_recvbuf_sign,
cupy_worker_scale,
cupy_recvbuf_scale):
# In-place operations are not possible for newly created cupy arrays
# so we need to return the new buffers
numpy_recvbuf_sign = np.zeros([world_size,
cupy_sign_list_packed[rank].size],
dtype=cupy_sign_list_packed[0].dtype)
numpy_recvbuf_scale = np.zeros([world_size, 1], dtype=cupy_worker_scale.dtype)
# 1. convert from cupy to numpy
numpy_sign_list_packed = cupy_sign_list_packed
for idx in range(world_size):
numpy_sign_list_packed[idx] = cupy.asnumpy(cupy_sign_list_packed[idx])
numpy_worker_scale = cupy.asnumpy(cupy_worker_scale)
numpy_recvbuf_scale = cupy.asnumpy(cupy_recvbuf_scale)
cupy.cuda.get_current_stream().synchronize()
# 2. use numpy buffers for communication
requests = []
for idx in range(world_size):
req_sign = self.my_igather(rank,
world_size,
comm,
numpy_sign_list_packed[idx],
numpy_recvbuf_sign,
root=idx)
requests += req_sign
for idx in range(world_size):
req_scale = self.my_igather(rank,
world_size,
comm,
numpy_worker_scale,
numpy_recvbuf_scale,
root=idx)
requests += req_scale
MPI.Request.Waitall(requests)
# 3. Convert back from numpy to cupy
cupy_recvbuf_sign = cupy.asarray(numpy_recvbuf_sign)
for idx in range(world_size):
cupy_sign_list_packed[idx] = cupy.asarray(numpy_sign_list_packed[idx])
cupy_worker_scale = cupy.asarray(numpy_worker_scale)
cupy_recvbuf_scale = cupy.asarray(numpy_recvbuf_scale)
cupy.cuda.get_current_stream().synchronize()
return cupy_sign_list_packed, cupy_recvbuf_sign, cupy_worker_scale, cupy_recvbuf_scale
def allgather_cuda(self,
comm,
cupy_server_sign_packed,
cupy_recvbuf_sign_server,
cupy_server_scale,
cupy_recvbuf_scale_server):
comm.Allgather(cupy_server_sign_packed, cupy_recvbuf_sign_server)
comm.Allgather(cupy_server_scale, cupy_recvbuf_scale_server)
def allgather_host(self,
comm,
cupy_server_sign_packed,
cupy_recvbuf_sign_server,
cupy_server_scale,
cupy_recvbuf_scale_server):
# 1. Convert cupy to numpy
numpy_recvbuf_sign_server = np.zeros(
[comm.Get_size(),
cupy_server_sign_packed.size],
dtype=cupy_server_sign_packed.dtype)
numpy_recvbuf_scale_server = np.zeros([comm.Get_size(),
1],
dtype=cupy_server_scale.dtype)
numpy_server_sign_packed = cupy.asnumpy(cupy_server_sign_packed)
numpy_recvbuf_sign_server = cupy.asnumpy(cupy_recvbuf_sign_server)
numpy_server_scale = cupy.asnumpy(cupy_server_scale)
numpy_recvbuf_scale_server = cupy.asnumpy(cupy_recvbuf_scale_server)
cupy.cuda.get_current_stream().synchronize()
# 2. Communicate numpy buffers
comm.Allgather(numpy_server_sign_packed, numpy_recvbuf_sign_server)
comm.Allgather(numpy_server_scale, numpy_recvbuf_scale_server)
comm.Barrier()
# 3. Convert numpy back to cupy
cupy_server_sign_packed = cupy.asarray(numpy_server_sign_packed)
cupy_recvbuf_sign_server = cupy.asarray(numpy_recvbuf_sign_server)
cupy_server_scale = cupy.asarray(numpy_server_scale)
cupy_recvbuf_scale_server = cupy.asarray(numpy_recvbuf_scale_server)
cupy.cuda.get_current_stream().synchronize()
return cupy_server_sign_packed, cupy_recvbuf_sign_server, cupy_server_scale, cupy_recvbuf_scale_server
def compressed_allreduce(self,
buffer_m: torch.tensor,
worker_error,
server_error,
local_rank):
all_start_time = time.time()
original_shape = buffer_m.size()
if len(original_shape) > 1:
buffer_m = torch.flatten(buffer_m)
original_size = buffer_m.numel()
worker_error_size = worker_error.numel()
cupy.cuda.Device(local_rank).use()
if original_size != worker_error_size:
empty_tensor = torch.zeros(worker_error_size - original_size,
device=buffer_m.device)
buffer_m = torch.cat([buffer_m, empty_tensor])
buffer_m.add_(worker_error)
worker_scale = torch.norm(buffer_m) / np.sqrt(torch.numel(buffer_m))
worker_error.set_(buffer_m - worker_scale *
buffer_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))
cupy_sign_list_packed = self.compression_backend.compress_by_chunk(
self.compression_backend.torch2cupy(buffer_m.sign_().add_(1).bool()),
self.size)
cupy_worker_scale = self.compression_backend.torch2cupy(worker_scale)
cupy_recvbuf_sign = cupy.zeros(
[self.size,
cupy_sign_list_packed[self.rank].size],
dtype=cupy_sign_list_packed[0].dtype)
cupy_recvbuf_scale = cupy.zeros([self.size, 1], dtype=cupy_worker_scale.dtype)
# Communication Phase 1
gather_start = time.time()
if self.cuda_aware:
self.gather_cuda(self.rank,
self.size,
self.comm,
cupy_sign_list_packed,
cupy_recvbuf_sign,
cupy_worker_scale,
cupy_recvbuf_scale)
else:
_, cupy_recvbuf_sign, _, cupy_recvbuf_scale = self.gather_host(self.rank,
self.size,
self.comm,
cupy_sign_list_packed,
cupy_recvbuf_sign,
cupy_worker_scale,
cupy_recvbuf_scale)
gather_end = time.time()
# cupy_sign_list_packed, cupy_worker_scale, worker_scale = None, None, None
cupy_sign_list_packed = None
compensated_server_m = self.compression_backend.cupy2torch(
(cupy.unpackbits(cupy_recvbuf_sign.flatten())).reshape(
self.size,
-1)).float().add_(-0.5).mul_(2.0).mul_(
self.compression_backend.cupy2torch(cupy_recvbuf_scale).mul_(
1 / self.size)).sum(0)
compensated_server_m.add_(server_error)
server_scale = torch.norm(compensated_server_m) / np.sqrt(
compensated_server_m.numel())
server_error.set_(
compensated_server_m - server_scale *
compensated_server_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))
cupy_server_scale = self.compression_backend.torch2cupy(server_scale)
cupy_server_sign_packed = self.compression_backend.compress_by_chunk(
self.compression_backend.torch2cupy(
compensated_server_m.sign_().add_(1).bool()),
1)
compensated_server_m = None
cupy_recvbuf_sign_server = cupy.zeros(
[self.size,
cupy_server_sign_packed[0].size],
dtype=cupy_recvbuf_sign.dtype)
cupy_recvbuf_scale_server = cupy.zeros([self.size,
1],
dtype=cupy_recvbuf_scale.dtype)
# cupy_recvbuf_sign, cupy_recvbuf_scale = None, None
cupy_recvbuf_sign = None
# Communication Phase 2
if self.cuda_aware:
self.allgather_cuda(self.comm,
cupy_server_sign_packed[0],
cupy_recvbuf_sign_server,
cupy_server_scale,
cupy_recvbuf_scale_server)
else:
_, cupy_recvbuf_sign_server, _, cupy_recvbuf_scale_server = self.allgather_host(self.comm,
cupy_server_sign_packed[0],
cupy_recvbuf_sign_server,
cupy_server_scale,
cupy_recvbuf_scale_server)
# cupy_server_sign_packed, cupy_server_scale, server_scale = None, None, None
cupy_server_sign_packed = None
buffer_m.data.copy_(
self.compression_backend.cupy2torch(
(cupy.unpackbits(cupy_recvbuf_sign_server.flatten())).reshape(
self.size,
-1)).float().add_(-0.5).mul_(2.0).mul_(
self.compression_backend.cupy2torch(
cupy_recvbuf_scale_server)).flatten().data)
if original_size != worker_error_size:
buffer_m = buffer_m[0:original_size]
if len(original_shape) > 1:
buffer_m = buffer_m.reshape(original_shape)
# cupy_recvbuf_sign_server, cupy_recvbuf_scale_server = None, None
return buffer_m
'''
Copyright 2020 The Microsoft DeepSpeed Team
'''
import torch
import torch.distributed as dist
import time
import cupy
import numpy as np
from deepspeed.runtime.compression.cupy import CupyBackend
class NcclBackend(object):
def __init__(self):
self.world_group = dist.new_group(ranks=range(dist.get_world_size()))
self.rank = dist.get_rank(group=self.world_group)
self.size = dist.get_world_size(group=self.world_group)
self.compression_backend = CupyBackend()
def my_igather(self, rank, size, group, sendbuf, recvbuf, root):
req = []
if rank == root:
for idx in range(size):
if idx != rank:
req.append(dist.irecv(recvbuf[idx], src=idx, group=group))
else:
recvbuf[rank] = sendbuf
else:
req.append(dist.isend(sendbuf, group=group, dst=root))
return req
def my_gather(self, rank, size, group, sendbuf, recvbuf, root):
if rank == root:
for idx in range(size):
if idx != rank:
dist.recv(recvbuf[idx], src=idx, group=group)
else:
recvbuf[rank] = sendbuf
else:
dist.send(sendbuf, group=group, dst=root)
def compressed_allreduce(self,
buffer_m: torch.tensor,
worker_error,
server_error,
local_rank):
# all_start_time = time.time()
original_shape = buffer_m.size()
if len(original_shape) > 1:
buffer_m = torch.flatten(buffer_m)
original_size = buffer_m.numel()
worker_error_size = worker_error.numel()
cupy.cuda.Device(local_rank).use()
if original_size != worker_error_size:
empty_tensor = torch.zeros(worker_error_size - original_size,
device=buffer_m.device)
buffer_m = torch.cat([buffer_m, empty_tensor])
buffer_m.add_(worker_error)
worker_scale = torch.norm(buffer_m) / np.sqrt(torch.numel(buffer_m))
worker_error.set_(buffer_m - worker_scale *
buffer_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))
cupy_sign_list_packed = self.compression_backend.compress_by_chunk(
self.compression_backend.torch2cupy(buffer_m.sign_().add_(1).bool()),
self.size)
cupy_worker_scale = self.compression_backend.torch2cupy(worker_scale)
cupy_recvbuf_sign = cupy.zeros(
[self.size,
cupy_sign_list_packed[self.rank].size],
dtype=cupy_sign_list_packed[0].dtype)
# cupy_recvbuf_scale = cupy.zeros([self.size, 1], dtype=cupy_worker_scale.dtype)
sign_list_packed = [
self.compression_backend.cupy2torch(cupy_sign_list_packed[idx])
for idx in range(self.size)
]
# worker_scale = self.compression_backend.cupy2torch(cupy_worker_scale)
recvbuf_sign = self.compression_backend.cupy2torch(cupy_recvbuf_sign)
#recvbuf_scale = self.compression_backend.cupy2torch(cupy_recvbuf_scale)
recvbuf_scale = [
torch.zeros(1,
dtype=worker_scale.dtype,
device=torch.device(local_rank)) for i in range(self.size)
]
# communication phase 1
# gather_start = time.time()
# Alltoall for sign
dist.all_to_all_single(recvbuf_sign, torch.stack(sign_list_packed))
# Allgather for scale
dist.all_gather(recvbuf_scale, worker_scale)
# gather_end = time.time()
# cupy_sign_list_packed, sign_list_packed, cupy_worker_scale, worker_scale = None, None, None, None
cupy_sign_list_packed = None
cupy_recvbuf_sign = self.compression_backend.torch2cupy(recvbuf_sign)
#cupy_recvbuf_scale = self.compression_backend.torch2cupy(torch.stack(recvbuf_scale))
compensated_server_m = self.compression_backend.cupy2torch(
(cupy.unpackbits(cupy_recvbuf_sign.flatten())).reshape(
self.size,
-1)).float().add_(-0.5).mul_(2.0).mul_(
torch.stack(recvbuf_scale).mul_(1 / self.size)).sum(0)
compensated_server_m.add_(server_error)
server_scale = torch.norm(compensated_server_m) / np.sqrt(
compensated_server_m.numel())
server_error.set_(
compensated_server_m - server_scale *
compensated_server_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))
# cupy_server_scale = self.compression_backend.torch2cupy(server_scale)
cupy_server_sign_packed = self.compression_backend.compress_by_chunk(
self.compression_backend.torch2cupy(
compensated_server_m.sign_().add_(1).bool()),
1)
compensated_server_m = None
cupy_recvbuf_sign_server = cupy.zeros(
[self.size,
cupy_server_sign_packed[0].size],
dtype=cupy_recvbuf_sign.dtype)
# cupy_recvbuf_sign, recvbuf_sign = None, None
cupy_recvbuf_sign = None
server_sign_packed = [
self.compression_backend.cupy2torch(cupy_server_sign_packed[0])
]
recvbuf_sign_server = [
self.compression_backend.cupy2torch(cupy_recvbuf_sign_server[idx])
for idx in range(self.size)
]
# server_scale = self.compression_backend.cupy2torch(cupy_server_scale)
cupy_recvbuf_scale_server = cupy.zeros([self.size,
1],
dtype=cupy_worker_scale.dtype)
# cupy_recvbuf_scale, recvbuf_scale = None, None
recvbuf_scale_server = [
self.compression_backend.cupy2torch(cupy_recvbuf_scale_server[idx])
for idx in range(self.size)
]
# Communication Phase 2
dist.all_gather(recvbuf_sign_server, server_sign_packed[0])
dist.all_gather(recvbuf_scale_server, server_scale)
cupy_server_sign_packed = None
# need to convert from a tensor list to a single tensor
# dist.all_gather only provides a tensor list as the recv/output buffer
recvbuf_sign_server = torch.stack(recvbuf_sign_server)
cupy_recvbuf_sign_server = self.compression_backend.torch2cupy(
recvbuf_sign_server)
buffer_m.data.copy_(
self.compression_backend.cupy2torch(
(cupy.unpackbits(cupy_recvbuf_sign_server.flatten())).reshape(
self.size,
-1)).float().add_(-0.5).mul_(2.0).mul_(
self.compression_backend.cupy2torch(
cupy_recvbuf_scale_server)).flatten().data)
if original_size != worker_error_size:
buffer_m = buffer_m[0:original_size]
if len(original_shape) > 1:
buffer_m = buffer_m.reshape(original_shape)
return buffer_m
'''
Copyright 2020 The Microsoft DeepSpeed Team
'''
import cupy
from torch.utils.dlpack import to_dlpack
from torch.utils.dlpack import from_dlpack
class CupyBackend(object):
    def __init__(self):
        pass

    def torch2cupy(self, tensor):
        return cupy.fromDlpack(to_dlpack(tensor))

    def cupy2torch(self, cupy_tensor):
        return from_dlpack(cupy_tensor.toDlpack())

    def compress_by_chunk(self, cupy_bool_tensor, num_chunks):
        packed_sign = cupy.packbits(cupy_bool_tensor)
        sign_list_packed = cupy.split(packed_sign, num_chunks)
        cupy.cuda.get_current_stream().synchronize()
        return sign_list_packed
'''
Copyright 2019 The Microsoft DeepSpeed Team
'''
from mpi4py import MPI
import numpy as np
import cupy
def my_igather(rank, size, comm, sendbuf, recbuf, root):
req = []
if rank == root:
for idx in range(size):
if idx != rank:
req.append(comm.Irecv(recbuf[idx], source=idx))
else:
recbuf[rank] = sendbuf
else:
req.append(comm.Isend(sendbuf, dest=root))
return req
def gather_cuda(rank,
world_size,
comm,
cupy_sign_list_packed,
cupy_recvbuf_sign,
cupy_worker_scale,
cupy_recvbuf_scale):
# We do in-place operations on cupy buffers so we do not return any buffers
requests = []
for idx in range(world_size):
req_sign = my_igather(rank,
world_size,
comm,
cupy_sign_list_packed[idx],
cupy_recvbuf_sign,
root=idx)
requests += req_sign
for idx in range(world_size):
req_scale = my_igather(rank,
world_size,
comm,
cupy_worker_scale,
cupy_recvbuf_scale,
root=idx)
requests += req_scale
MPI.Request.Waitall(requests)
def gather_host(rank,
world_size,
comm,
cupy_sign_list_packed,
cupy_recvbuf_sign,
cupy_worker_scale,
cupy_recvbuf_scale):
# In-place operations are not possible for newly created cupy arrays
# so we need to return the new buffers
numpy_recvbuf_sign = np.zeros([world_size,
cupy_sign_list_packed[rank].size],
dtype=cupy_sign_list_packed[0].dtype)
numpy_recvbuf_scale = np.zeros([world_size, 1], dtype=cupy_worker_scale.dtype)
# 1. convert from cupy to numpy
numpy_sign_list_packed = cupy_sign_list_packed
for idx in range(world_size):
numpy_sign_list_packed[idx] = cupy.asnumpy(cupy_sign_list_packed[idx])
numpy_worker_scale = cupy.asnumpy(cupy_worker_scale)
numpy_recvbuf_scale = cupy.asnumpy(cupy_recvbuf_scale)
cupy.cuda.get_current_stream().synchronize()
# 2. use numpy buffers for communication
requests = []
for idx in range(world_size):
req_sign = my_igather(rank,
world_size,
comm,
numpy_sign_list_packed[idx],
numpy_recvbuf_sign,
root=idx)
requests += req_sign
for idx in range(world_size):
req_scale = my_igather(rank,
world_size,
comm,
numpy_worker_scale,
numpy_recvbuf_scale,
root=idx)
requests += req_scale
MPI.Request.Waitall(requests)
# 3. Convert back from numpy to cupy
cupy_recvbuf_sign = cupy.asarray(numpy_recvbuf_sign)
for idx in range(world_size):
cupy_sign_list_packed[idx] = cupy.asarray(numpy_sign_list_packed[idx])
cupy_worker_scale = cupy.asarray(numpy_worker_scale)
cupy_recvbuf_scale = cupy.asarray(numpy_recvbuf_scale)
cupy.cuda.get_current_stream().synchronize()
return cupy_sign_list_packed, cupy_recvbuf_sign, cupy_worker_scale, cupy_recvbuf_scale
def allgather_cuda(comm,
cupy_server_sign_packed,
cupy_recvbuf_sign_server,
cupy_server_scale,
cupy_recvbuf_scale_server):
comm.Allgather(cupy_server_sign_packed, cupy_recvbuf_sign_server)
comm.Allgather(cupy_server_scale, cupy_recvbuf_scale_server)
def allgather_host(comm,
cupy_server_sign_packed,
cupy_recvbuf_sign_server,
cupy_server_scale,
cupy_recvbuf_scale_server):
# 1. Convert cupy to numpy
numpy_recvbuf_sign_server = np.zeros([comm.Get_size(),
cupy_server_sign_packed.size],
dtype=cupy_server_sign_packed.dtype)
numpy_recvbuf_scale_server = np.zeros([comm.Get_size(),
1],
dtype=cupy_server_scale.dtype)
numpy_server_sign_packed = cupy.asnumpy(cupy_server_sign_packed)
numpy_recvbuf_sign_server = cupy.asnumpy(cupy_recvbuf_sign_server)
numpy_server_scale = cupy.asnumpy(cupy_server_scale)
numpy_recvbuf_scale_server = cupy.asnumpy(cupy_recvbuf_scale_server)
cupy.cuda.get_current_stream().synchronize()
# 2. Communicate numpy buffers
comm.Allgather(numpy_server_sign_packed, numpy_recvbuf_sign_server)
comm.Allgather(numpy_server_scale, numpy_recvbuf_scale_server)
comm.Barrier()
# 3. Convert numpy back to cupy
cupy_server_sign_packed = cupy.asarray(numpy_server_sign_packed)
cupy_recvbuf_sign_server = cupy.asarray(numpy_recvbuf_sign_server)
cupy_server_scale = cupy.asarray(numpy_server_scale)
cupy_recvbuf_scale_server = cupy.asarray(numpy_recvbuf_scale_server)
cupy.cuda.get_current_stream().synchronize()
return cupy_server_sign_packed, cupy_recvbuf_sign_server, cupy_server_scale, cupy_recvbuf_scale_server
......@@ -675,8 +675,12 @@ class DeepSpeedEngine(Module):
from deepspeed.ops.lamb import FusedLamb
optimizer = FusedLamb(model_parameters, **optimizer_parameters)
elif self.optimizer_name() == ONEBIT_ADAM_OPTIMIZER:
from deepspeed.runtime.fp16.onebit_adam import OnebitAdam
from deepspeed.runtime.fp16.onebit.adam import OnebitAdam
optimizer = OnebitAdam(model_parameters, self, **optimizer_parameters)
if not self.fp16_enabled():
logger.warning(
f'Currently the convergence of 1-bit Adam is only verified under FP16'
)
else:
torch_optimizer = getattr(torch.optim, self.optimizer_name())
optimizer = torch_optimizer(model_parameters, **optimizer_parameters)
......
......@@ -6,19 +6,15 @@ import torch
import importlib
import numpy as np
import time
import cupy
from torch.utils.dlpack import to_dlpack
from torch.utils.dlpack import from_dlpack
from deepspeed.utils.logging import logger
import torch.distributed as dist
from mpi4py import MPI
from deepspeed.runtime.custom_collectives import gather_cuda, gather_host, allgather_cuda, allgather_host
from deepspeed.utils.logging import logger
class OnebitAdam(torch.optim.Optimizer):
"""Implements the 1-bit Adam algorithm. Currently GPU-only.
For usage example please see, TODO DeepSpeed Tutorial
It has been proposed in APMSqueeze (https://arxiv.org/abs/2008.11343)
For usage example please see https://www.deepspeed.ai/tutorials/onebit-adam/
For technical details please read https://arxiv.org/abs/2102.02888
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
......@@ -31,8 +27,6 @@ class OnebitAdam(torch.optim.Optimizer):
eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
max_coeff(float, optional): maximum value of the lamb coefficient (default: 10.0)
min_coeff(float, optional): minimum value of the lamb coefficient (default: 0.01)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False) NOT SUPPORTED in 1-bit Adam!
......@@ -42,6 +36,7 @@ class OnebitAdam(torch.optim.Optimizer):
second moment estimate as in the original paper. (default: False)
cuda_aware (boolean, required): Set True if the underlying MPI implementation
supports CUDA-Aware communication. (default: False)
comm_backend_name (string, optional): Set to 'mpi' if needed. (default: 'nccl')
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _On the Convergence of Adam and Beyond:
......@@ -60,10 +55,12 @@ class OnebitAdam(torch.optim.Optimizer):
weight_decay=0.,
max_grad_norm=0.,
amsgrad=False,
cuda_aware=False):
cuda_aware=False,
comm_backend_name='nccl'):
if amsgrad:
raise RuntimeError('1-bit Adam does not support the AMSGrad variant.')
defaults = dict(lr=lr,
bias_correction=bias_correction,
betas=betas,
......@@ -72,160 +69,40 @@ class OnebitAdam(torch.optim.Optimizer):
max_grad_norm=max_grad_norm)
super(OnebitAdam, self).__init__(params, defaults)
from mpi4py import MPI
self.eps_mode = 0 if eps_inside_sqrt else 1
assert (dist.is_initialized())
self.comm = MPI.COMM_WORLD
self.rank = self.comm.Get_rank()
self.size = self.comm.Get_size()
self.comm_time = 0.0
self.step_time = 0.0
self.ave_step = 1
self.bk_time = 0.0
self.divider = int(self.size * 8 / np.gcd(self.size, 8))
self.deepspeed = deepspeed
self.adam_freeze_key = False
self.initialize = False
self.freeze_step = freeze_step
self.cuda_aware = cuda_aware
def torch2cupy(self, tensor):
return cupy.fromDlpack(to_dlpack(tensor))
def cupy2torch(self, cupy_tensor):
return from_dlpack(cupy_tensor.toDlpack())
def compress_by_chunk(self, cupy_bool_tensor, num_chunks):
packed_sign = cupy.packbits(cupy_bool_tensor)
sign_list_packed = cupy.split(packed_sign, num_chunks)
cupy.cuda.get_current_stream().synchronize()
return sign_list_packed
def Compressed_Allreduce(self,
buffer_m: torch.tensor,
worker_error,
server_error,
rank,
world_size,
comm,
local_rank):
all_start_time = time.time()
original_size = buffer_m.numel()
cupy.cuda.Device(local_rank).use()
if torch.numel(buffer_m) != torch.numel(worker_error):
empty_tensor = torch.zeros(torch.numel(worker_error) - torch.numel(buffer_m),
device=buffer_m.device)
buffer_m = torch.cat([buffer_m, empty_tensor])
buffer_m.add_(worker_error)
worker_scale = torch.norm(buffer_m) / np.sqrt(torch.numel(buffer_m))
sign_buffer_m = buffer_m.sign().add_(1).bool()
sign_buffer_m = sign_buffer_m.float()
sign_buffer_m.add_(-0.5).mul_(2.0)
worker_error.set_((buffer_m - worker_scale * sign_buffer_m))
sign_buffer_m = None
compensated_buffer_m = buffer_m
compensated_buffer_m.sign_()
compensated_buffer_m = compensated_buffer_m.add_(1).bool()
cupy_worker_scale = self.torch2cupy(worker_scale)
cupy_compensated_buffer_m = self.torch2cupy(compensated_buffer_m)
compensated_buffer_m = None
cupy_sign_list_packed = self.compress_by_chunk(cupy_compensated_buffer_m,
world_size)
cupy_compensated_buffer_m = None
cupy_recvbuf_sign = cupy.zeros([world_size,
cupy_sign_list_packed[rank].size],
dtype=cupy_sign_list_packed[0].dtype)
cupy_recvbuf_scale = cupy.zeros([world_size, 1], dtype=cupy_worker_scale.dtype)
# Communication Phase 1
gather_start = time.time()
if self.cuda_aware:
gather_cuda(rank,
world_size,
comm,
cupy_sign_list_packed,
cupy_recvbuf_sign,
cupy_worker_scale,
cupy_recvbuf_scale)
else:
cupy_sign_list_packed, cupy_recvbuf_sign, cupy_worker_scale, cupy_recvbuf_scale = gather_host(rank,
world_size,
comm,
cupy_sign_list_packed,
cupy_recvbuf_sign,
cupy_worker_scale,
cupy_recvbuf_scale)
gather_end = time.time()
cupy_unpacked_sign = (cupy.unpackbits(cupy_recvbuf_sign.flatten())).reshape(
world_size,
-1)
cupy_recvbuf_sign = None
unpacked_sign = self.cupy2torch(cupy_unpacked_sign).float()
cupy_unpacked_sign = None
unpacked_sign = unpacked_sign.add_(-0.5).mul_(2.0)
worker_scale = self.cupy2torch(cupy_recvbuf_scale).mul_(1 / world_size)
compensated_server_m = unpacked_sign.mul_(worker_scale).sum(0)
unpacked_sign = None
compensated_server_m.add_(server_error)
server_scale = torch.norm(compensated_server_m) / np.sqrt(
compensated_server_m.numel())
sign_server_m = compensated_server_m.sign().add_(1).bool()
sign_server_m = sign_server_m.float()
sign_server_m.add_(-0.5).mul_(2.0)
server_error.set_(compensated_server_m - server_scale * sign_server_m)
sign_server_m = None
compensated_server_m.sign_()
compensated_server_m = compensated_server_m.add_(1).bool()
cupy_server_scale = self.torch2cupy(server_scale)
cupy_compensated_server_m = self.torch2cupy(compensated_server_m)
compensated_server_m = None
cupy_server_sign_packed = self.compress_by_chunk(cupy_compensated_server_m, 1)
cupy_recvbuf_sign_server = cupy.zeros(
[world_size,
cupy_server_sign_packed[0].size],
dtype=cupy_sign_list_packed[0].dtype)
cupy_recvbuf_scale_server = cupy.zeros([world_size,
1],
dtype=cupy_worker_scale.dtype)
# Communication Phase 2
if self.cuda_aware:
allgather_cuda(comm,
cupy_server_sign_packed[0],
cupy_recvbuf_sign_server,
cupy_server_scale,
cupy_recvbuf_scale_server)
else:
cupy_server_sign_packed[0], cupy_recvbuf_sign_server, cupy_server_scale, cupy_recvbuf_scale_server = allgather_host(comm,
cupy_server_sign_packed[0],
cupy_recvbuf_sign_server,
cupy_server_scale,
cupy_recvbuf_scale_server)
self.comm_backend_name = comm_backend_name
# Empty initializer. Set handle based on the comm backend as follows.
self.comm_backend_handle = None
cupy_server_unpacked_sign = (cupy.unpackbits(
cupy_recvbuf_sign_server.flatten())).reshape(world_size,
-1)
cupy_recvbuf_sign_server = None
if self.comm_backend_name == 'nccl':
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
assert TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 8), "Please use torch 1.8 or greater to enable NCCL backend in 1-bit Adam. Alternatively, please specify 'mpi' as the 'comm_backend_name' in config file to proceed with the MPI backend"
assert dist.is_initialized() == True, "Please initialize the torch distributed backend."
from deepspeed.runtime.comm.nccl import NcclBackend
self.comm_backend_handle = NcclBackend()
server_unpacked_sign = self.cupy2torch(cupy_server_unpacked_sign)
cupy_server_unpacked_sign = None
elif self.comm_backend_name == 'mpi':
from deepspeed.runtime.comm.mpi import MpiBackend
self.comm_backend_handle = MpiBackend(cuda_aware)
server_unpacked_sign = server_unpacked_sign.float().add_(-0.5).mul_(2.0)
server_scale = self.cupy2torch(cupy_recvbuf_scale_server)
buffer_m = server_unpacked_sign.mul_(server_scale).flatten()[0:original_size]
self.size = self.comm_backend_handle.size
return buffer_m
self.divider = int(self.size * 8 / np.gcd(self.size, 8))
def step(self, closure=None, grads=None):
"""Performs a single optimization step.
......@@ -275,9 +152,7 @@ class OnebitAdam(torch.optim.Optimizer):
if grad is None:
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError(
'FusedAdam does not support sparse gradients, please consider SparseAdam instead'
)
raise RuntimeError('1-bit Adam does not support sparse gradients')
state = self.state[p]
......@@ -337,13 +212,24 @@ class OnebitAdam(torch.optim.Optimizer):
if self.size > 1:
exp_avg.set_(
self.Compressed_Allreduce(exp_avg,
state['worker_error'],
state['server_error'],
self.rank,
self.size,
self.comm,
self.deepspeed.local_rank))
self.comm_backend_handle.compressed_allreduce(
exp_avg,
state['worker_error'],
state['server_error'],
self.deepspeed.local_rank))
# Because 1-bit compression cannot represent exact zero, it is required to
# provide a momentum mask for those params that have constant exact zeros in their
# momentums, otherwise the compression error would keep accumulating.
# For example, for BERT pre-training seq 128, bert.embeddings.position_embeddings.weight
# always has exact zeros in its momentum for rows 129 to 512, because it only
# learns up to seq length 128 while the model supports up to 512 seq length.
# (See example in DeepSpeedExamples/bing_bert/deepspeed_train.py.)
if 'exp_avg_mask' in group:
if exp_avg.device != group['exp_avg_mask'].device:
group['exp_avg_mask'] = group['exp_avg_mask'].to(
device=exp_avg.device)
exp_avg.mul_(group['exp_avg_mask'])
if self.initialize:
update = exp_avg / (exp_avg_sq.sqrt() + group['eps'])
......@@ -372,3 +258,52 @@ class OnebitAdam(torch.optim.Optimizer):
self.deepspeed.enable_backward_allreduce = False
return loss
def load_state_dict(self, state_dict):
"""
Overrides load_state_dict() to add special handling when loading checkpoints
"""
# Because the exp_avg_mask may change at different stages (e.g.,
# BERT pre-training seqlen 128 and 512), we don't use the exp_avg_mask
# in checkpoints but always use the one user provided in training script.
# (See example in DeepSpeedExamples/bing_bert/deepspeed_train.py.)
# Thus here we keep the exp_avg_mask unchanged when loading checkpoint
for i, group in enumerate(self.param_groups):
if 'exp_avg_mask' in group:
state_dict['param_groups'][i]['exp_avg_mask'] = group['exp_avg_mask']
elif 'exp_avg_mask' not in group and 'exp_avg_mask' in state_dict[
'param_groups'][i]:
state_dict['param_groups'][i].pop('exp_avg_mask')
super().load_state_dict(state_dict)
if self.state[self.param_groups[0]['params'][0]]['step'] < self.freeze_step:
if torch.distributed.get_rank() == 0:
print("Checkpoint loaded and 1-bit Adam warmup stage starts/continues.")
if self.adam_freeze_key is True:
self.adam_freeze_key = False
self.deepspeed.enable_backward_allreduce = True
else:
if torch.distributed.get_rank() == 0:
print(
"Checkpoint loaded and 1-bit Adam compression stage starts/continues."
)
if self.adam_freeze_key is False:
self.adam_freeze_key = True
self.deepspeed.enable_backward_allreduce = False
# We reset the compression errors when loading checkpoints for 3 reasons:
# 1) The worker and server error at each GPU are distinct, so in current implementation
# only rank 0's errors are saved in the checkpoint. Thus we have to reset the errors.
# If we want to save them correctly we need O(num_gpu*model_size) memory in order to
# gather all the error, which is a very large memory requirement. It's possible to save
# them in a distributed way, but it will make the checkpoint saving/loading much more complicated.
# 2) Even if we are able to save the compression errors correctly, you need to have the
# exact same number of GPUs in order to load them correctly.
# 3) We verified on BERT pre-training that occasionally resetting the compression error
# at checkpoint loading does not affect the convergence.
# However, please avoid frequent checkpoint loading which could break the error
# compensation mechanism thus affect the convergence.
for group in self.param_groups:
for p in group['params']:
if 'worker_error' in self.state[p]:
self.state[p].pop('worker_error')
if 'server_error' in self.state[p]:
self.state[p].pop('server_error')
......@@ -60,7 +60,7 @@ The Adam optimizer also supports the following two params keys/values in additio
| torch\_adam | Use torch's implementation of adam instead of our fused adam implementation | false |
| adam\_w\_mode | Apply L2 regularization (also known as AdamW) | true |
Another example of ***optimizer*** with 1-bit Adam specific parameters is as follows.
Another example of ***optimizer*** with 1-bit Adam
```json
"optimizer": {
......@@ -74,11 +74,20 @@ The Adam optimizer also supports the following two params keys/values in additio
"eps": 1e-8,
"weight_decay": 3e-7,
"freeze_step": 400,
"cuda_aware": true
"cuda_aware": false,
"comm_backend_name": "nccl"
}
}
```
The 1-bit Adam optimizer supports the following three params keys/values in addition to the standard Adam (learn more in our [tutorial](/tutorials/onebit-adam/)):
| "params" key | Description | Default |
| ------------- | --------------------------------------------------------------------------- | ------- |
| freeze\_step | Number of warm up steps before 1-bit compression gets applied to the communication | 100000 |
| cuda\_aware | To indicate that the underlying MPI library supports CUDA-Aware communication | false |
| comm\_backend\_name | To indicate which backend implementation to use | "nccl" |
### Scheduler Parameters
***scheduler***: [dictionary]
......
---
title: "1-bit Adam: Up to 5x less communication volume and up to 2x faster training"
title: "1-bit Adam: Up to 5x less communication volume and up to 3.4x faster training"
---
**Note:**
This tutorial was updated on 03/04/2021 to reflect 1-bit Adam v2. Changes include: 1) an NCCL-based implementation which provides better performance and usability compared to the MPI-based implementation; 2) support for momentum masks for parameters with constant zero gradients during training; 3) bug fixes. See details below.
{: .notice--info}
**Watch out!**
1) The NCCL-based implementation requires PyTorch >= 1.8 (and NCCL >= 2.8.3 when you have 64 or more GPUs). See details below. 2) Although 1-bit Adam is compatible with both FP16 and FP32, currently we have only verified the convergence under mixed precision/FP16 training. 3) Currently 1-bit Adam is not compatible with pipeline parallelism. 4) Frequent checkpoint loading could hurt 1-bit Adam's convergence. See details below.
{: .notice--warning}
In this tutorial, we are going to introduce the 1-bit Adam optimizer in DeepSpeed. 1-bit Adam can improve model training speed on communication-constrained clusters, especially for communication-intensive large models by reducing the overall communication volume by up to 5x. Detailed description of the 1-bit Adam algorithm, its implementation in DeepSpeed, and performance evaluation is available from our [blog post](https://www.deepspeed.ai/news/2020/09/08/onebit-adam-blog-post.html). We also have a [paper](https://arxiv.org/abs/2102.02888) which provides the most complete details including algorithm, system implementation, theoretical analysis, and more evaluations.
To illustrate the benefits and usage of 1-bit Adam optimizer in DeepSpeed, we use the following two training tasks as examples:
......@@ -13,7 +21,7 @@ For more details on these tasks, please refer to the tutorial posts on [BingBert
## 1. Overview
### Pre-requisites for installing DeepSpeed
### 1.1 Pre-requisites for installing DeepSpeed
If you don't already have a copy of the DeepSpeed repository, please clone it
now and check out the DeepSpeedExamples submodule that contains the BingBertSQuAD and BERT Pre-training examples.
......@@ -25,9 +33,19 @@ git submodule update --init --recursive
cd DeepSpeedExamples/
```
### Pre-requisites for 1-bit Adam
### 1.2 Pre-requisites for 1-bit Adam
#### 1.2.1 (New in v2) NCCL-based implementation
In 1-bit Adam v2, we introduce a new system implementation for compressed communication using the NCCL backend of PyTorch distributed. This significantly improves the usability due to NCCL's integration with PyTorch distributed. The performance of our new NCCL-based implementation is also better than that of our earlier MPI-based implementation on Ethernet-based systems, and on par with it on InfiniBand-based systems. Thus we highly recommend choosing this implementation.
**Watch out!**
This NCCL-based implementation requires PyTorch >= 1.8. It also requires NCCL >= 2.8.3 when you have 64 or more GPUs, to avoid certain NCCL runtime bugs. Currently (2021/03/16) NCCL 2.8.3 is not officially supported by PyTorch. The solution we used is to load NCCL 2.8.3 via `LD_PRELOAD`: 1) Install NCCL 2.8.3. This works for us on a CUDA 11 system: `apt-get install -y libnccl2=2.8.3-1+cuda11.0 libnccl-dev=2.8.3-1+cuda11.0`. 2) Set `LD_PRELOAD` to the library path. This works for us: `LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libnccl.so.2.8.3`. To confirm that `LD_PRELOAD` is working, check the version reported in the NCCL logs when `NCCL_DEBUG=INFO` is set; it should say: NCCL version 2.8.3+cuda11.0.
{: .notice--warning}
1-bit Adam uses advanced communication schemes that are not yet supported by PyTorch distributed and NCCL. We rely on Message Passing Interface (MPI) for these advanced communication primitives.
#### 1.2.2 MPI-based implementation
For this implementation, we rely on Message Passing Interface (MPI) for advanced communication primitives.
We package the necessary dependencies in the DeepSpeed docker images. However, if you are using a different build system, please install MPI and mpi4py on your system. To install the prerequisites run:
......@@ -43,31 +61,32 @@ An example launch command for 1-bit Adam using the `deepspeed` launcher is as fo
deepspeed --launcher=[mvapich|openmpi] script.py
```
Please note that because 1-bit Adam uses MPI backend to communicate during the compression stage, the `--launcher=[mvapich|openmpi]` flag is required when using the `deepspeed` launcher.
Please note that for MPI-based implementation of 1-bit Adam, the `--launcher=[mvapich|openmpi]` flag is required when using the `deepspeed` launcher.
Alternatively, the standard mpirun launcher can also be used as follows:
```shell
mpirun -np [#processes] -ppn [#GPUs on each node] -hostfile [hostfile] [MPI flags] bash [training_script.sh]
mpirun -np [#processes] -ppn [#GPUs on each node] -hostfile [hostfile] [MPI flags] python [training_script.py]
```
### 1-bit Algorithm
### 1.3 1-bit Algorithm
The detailed description of the 1-bit Algorithm can be seen from our [blog post](https://www.deepspeed.ai/news/2020/09/09/onebit-adam-blog-post.html).
The detailed description of the 1-bit algorithm can be found in our [blog post](https://www.deepspeed.ai/news/2020/09/09/onebit-adam-blog-post.html) and our [paper](https://arxiv.org/abs/2102.02888).
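For a concrete feel for what gets communicated, below is a minimal, illustrative sketch (the function and variable names are ours, not a DeepSpeed API) of the local error-compensated 1-bit compression step that `compressed_allreduce` performs before its two communication phases: the accumulated compression error is folded back into the momentum update, only the signs plus a single scale (the L2 norm divided by the square root of the element count) are kept for transmission, and the residual becomes the next step's error.

```python
import torch


def one_bit_compress(buffer_m: torch.Tensor, error: torch.Tensor):
    """Illustrative sketch of the local error-compensated 1-bit compression
    used inside compressed_allreduce (not the DeepSpeed API itself)."""
    # Error feedback: fold the previous compression error into this update.
    compensated = buffer_m + error
    # One scale per tensor: ||x||_2 / sqrt(n).
    scale = compensated.norm() / compensated.numel()**0.5
    # 1-bit quantization: map signs to {-1, +1} (zeros map to +1, matching the
    # sign().add_(1).bool() trick used by the real implementation).
    signs = compensated.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
    # The residual is stored as the compression error for the next step.
    error.copy_(compensated - scale * signs)
    # Only the packed signs (1 bit/element) and the scalar scale are sent.
    return signs, scale


# Example: the receiver reconstructs the update as scale * signs.
x = torch.randn(16)
err = torch.zeros_like(x)
signs, scale = one_bit_compress(x, err)
print(scale.item(), signs[:4].tolist())
```

In the actual optimizer the same compression is applied a second time on the server side: the averaged, error-compensated result is compressed with its own server error and scale before the final allgather.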
### Configuration of 1-bit Adam
### 1.4 Configuration of 1-bit Adam
The 1-bit Adam feature can be used by setting the optimizer configuration options as follows. An example json config file is shown below.
```json
{
"train_batch_size": 4096,
"train_micro_batch_size_per_gpu": 64,
"train_micro_batch_size_per_gpu": 16,
"optimizer": {
"type": "OneBitAdam",
"params": {
"lr": 2e-4,
"freeze_step": 400,
"cuda_aware": true
"lr": 4e-4,
"freeze_step": 23000,
"cuda_aware": false,
"comm_backend_name": "nccl"
}
},
"fp16": {
......@@ -75,12 +94,20 @@ The 1-bit Adam feature can be used by setting the optimizer configuration option
}
}
```
Please note two new parameters `freeze_step` and `cuda_aware` that have been added to support the 1-bit Adam feature.
Please note three new parameters `freeze_step`, `cuda_aware`, and `comm_backend_name` that have been added to support the 1-bit Adam feature.
`freeze_step` is the number of warm up steps before 1-bit compression gets applied to the communication. In order to determine the number of warm up steps, one strategy is to set 15-25% of the total training steps for a given model (this is related to Adam's variance/second moment term; see the detailed analysis in our [paper](https://arxiv.org/abs/2102.02888)). If this provides the desired outcome, one can try to extract more performance by reducing the number of steps systematically. In the future, we plan to introduce a threshold that can automatically search for and decide the number of warm up steps for different models. The examples below have been tuned for the number of warm up steps. The `freeze_step` parameter has already been set to the best number we found in the corresponding run scripts.
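As a tiny illustration of that rule of thumb (the numbers here are made up, not the tuned values from the run scripts):

```python
# Warm up for roughly 15-25% of the total training steps before enabling
# 1-bit compression, then tune from there. Example numbers only.
total_training_steps = 100000
freeze_step = int(0.20 * total_training_steps)  # -> 20000 warmup steps
```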
`cuda_aware` is used to indicate that the underlying MPI library support CUDA-Aware communication.
This feature is only supported on systems with InfiniBand interconnect and a CUDA-Aware MPI library like [MVAPICH2-GDR](http://mvapich.cse.ohio-state.edu/userguide/gdr/) or OpenMPI built with CUDA-Aware support. Setting `cuda_aware` to False will allow training on Ethernet based systems. However, the communication will happen using sender as well as receiver side memory copies between CPU and GPU buffers before and after communication.
`cuda_aware` is used for MPI-based implementation to indicate that the underlying MPI library supports CUDA-Aware communication. This feature is only supported on systems with InfiniBand interconnect and a CUDA-Aware MPI library like [MVAPICH2-GDR](http://mvapich.cse.ohio-state.edu/userguide/gdr/) or OpenMPI built with CUDA-Aware support. Setting `cuda_aware` to False will allow training on Ethernet based systems. However, the communication will happen using sender as well as receiver side memory copies between CPU and GPU buffers before and after communication.
`freeze_step` is the number of warm up steps before 1-bit compression gets applied to the communication. In order to determine the number of warm up steps, one strategy is to set 15-25% of the total training steps for a given model. If it provides the desired outcome, one can try to extract more performance by reducing the steps systematically. In future, we plan to introduce a threshold that can automatically search and decide for the number of warm up steps for different models. The examples below have been tuned for the number of warm up steps. The `freeze_step` parameter has already been set to the best number we found in the corresponding run scripts.
(New in v2) `comm_backend_name` is used to indicate which backend implementation to use. You can choose between the NCCL-based and MPI-based implementations by setting `comm_backend_name` to "nccl" or "mpi", respectively. When using the NCCL-based implementation, there is no need to set `cuda_aware`.
#### 1.4.1 (New in v2) Momentum masks for parameters with constant zero gradients
Because 1-bit compression cannot represent exact zero, the compression error would keep accumulating in the momentum if a parameter has constant zero gradients during training. For example, for BERT pre-training with sequence length 128, `bert.embeddings.position_embeddings.weight` has constant zeros in its gradient and momentum for rows 129 to 512, because it only learns up to sequence length 128 while the model supports up to sequence length 512. Thus in 1-bit Adam v2 we added support for a momentum mask so that users can specify the parameters that have constant exact zeros in their gradients. See the [example script](https://github.com/microsoft/DeepSpeedExamples/blob/master/bing_bert/deepspeed_train.py) for how to configure this momentum mask. One thing to note is that we don't use the momentum mask saved in checkpoints, since this mask could change during training (e.g., BERT seqlen 128 and 512 require different masks). So you have to provide this mask every time in your training script.
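As a rough sketch of how such a mask can be wired up (the model, shapes, and names below are illustrative, not the DeepSpeedExamples code), the mask is attached under the `exp_avg_mask` key of the parameter group containing the affected weight; the optimizer in this release multiplies the momentum by this mask when the key is present:

```python
import torch

# Illustrative setup: a position-embedding table with 512 rows, of which only
# the first 128 ever receive non-zero gradients during seq-len-128 training.
max_pos, trained_seq_len, hidden = 512, 128, 768
pos_emb = torch.nn.Embedding(max_pos, hidden)

# 1 for rows that receive real gradients, 0 for rows whose gradients are
# constant exact zeros, so their compressed momentum stays exactly zero.
mask = torch.zeros(max_pos, hidden)
mask[:trained_seq_len, :] = 1.0

# Put the masked weight in its own parameter group and attach the mask under
# the 'exp_avg_mask' key that OnebitAdam checks for. This list would then be
# passed as model_parameters to deepspeed.initialize(), which constructs the
# optimizer from it.
param_groups = [
    {'params': [pos_emb.weight], 'exp_avg_mask': mask},
    # ... remaining parameters go into ordinary groups without a mask ...
]
```

Since the mask is not restored from checkpoints (see above), the same parameter-group construction has to run in the training script on every launch.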
**Watch out!**
1-bit Adam relies on a compression error compensation mechanism to maintain the convergence speed during the compression stage. When loading checkpoints, we actually reset the compression errors for 3 reasons: 1) The worker and server errors at each GPU are distinct, so in the current implementation only rank 0's errors are saved in the checkpoint. Thus we have to reset the errors. If we wanted to save them correctly we would need O(num_gpu*model_size) memory to gather all the errors, which is a very large memory requirement. It's possible to save them in a distributed way, but it would make checkpoint saving/loading much more complicated. 2) Even if we were able to save the compression errors correctly, you would need the exact same number of GPUs in order to load them correctly. 3) We verified on BERT pre-training that occasionally resetting the compression error at checkpoint loading does not affect the convergence. However, please avoid frequent checkpoint loading, which could break the error compensation mechanism and thus affect convergence.
{: .notice--warning}
## 2. BingBertSQuAD Fine-tuning with 1-bit Adam
......@@ -93,9 +120,13 @@ This feature is only supported on systems with InfiniBand interconnect and a CUD
You can also use a pre-trained BERT model checkpoint from either DeepSpeed, [HuggingFace](https://github.com/huggingface/transformers), or [TensorFlow](https://github.com/google-research/bert#pre-trained-models) to run the fine-tuning.
**Note:** For details about loading checkpoint, argument parsing, initialization, forward pass, backward pass, weight update and evaluation, please refer to the [BingBertSQuAD Fine-tuning](/tutorials/bert-finetuning/) tutorial.
### 2.1 Running BingBertSQuAD with DeepSpeed and 1-bit Adam
The main part of training is done in `nvidia_run_squad_deepspeed.py`, which has
We provide example scripts under [DeepSpeedExamples/BingBertSquad/1-bit_adam/](https://github.com/microsoft/DeepSpeedExamples/tree/master/BingBertSquad/1-bit_adam). There are 3 sets of scripts corresponding to the NCCL-based implementation, the MPI-based implementation on Ethernet systems, and the MPI-based implementation on InfiniBand systems. For the MPI-based implementation, we provide example scripts for launching with either the deepspeed launcher or mpirun.
<!-- The main part of training is done in `nvidia_run_squad_deepspeed.py`, which has
already been modified to use DeepSpeed. The `run_squad_deepspeed.sh` script
helps to invoke training and setup several different hyperparameters relevant
to the training process.
......@@ -132,7 +163,7 @@ For example, in order to use 32 GPUs (4GPUs/node, 8 nodes in total), with the su
```shell
mpirun -np 32 -ppn 4 -hostfile hosts -env MV2_USE_CUDA=1 -env MV2_SUPPORT_DL=1 -env MV2_ENABLE_AFFINITY=0 -env MV2_SMP_USE_CMA=0 bash run_squad_mpi_onebitadam.sh
```
``` -->
### 2.2 Configuration for BingBertSQuAD with DeepSpeed and 1-bit Adam enabled
......@@ -148,18 +179,16 @@ Table 1 shows the fine-tuning configuration we used in our experiments.
| ------------------------------ | ---------------------|
| Total batch size | 96 |
| Train micro batch size per GPU | 3 |
| Optimizer | **OnebitAdam** |
| Optimizer | **"OnebitAdam"** |
| Learning rate | 3e-5 |
| Sequence-length | 384 |
| Weight-decay | 0.0 |
| Epoch count | 2 |
| **freeze_step** | 400 |
| **cuda_aware** | True |
| **comm_backend_name** | "nccl" |
Table 1. Fine-tuning configuration
**Note:** For more details about loading checkpoint, argument parsing, initialization, forward pass, backward pass, weight update and evaluation, please refer to the [BingBertSQuAD Fine-tuning](/tutorials/bert-finetuning/) tutorial.
### 2.3 Performance Results for BingBertSQuAD Fine-tuning
***Accuracy:***
......@@ -174,19 +203,24 @@ We fixed the learning rate to 3e-5. The table below shows the F1 and the EM scor
***Training Speed and Scalability:***
1-bit Adam enables up to 2.7x overall speedup in training speed for SQuAD fine-tuning. This is made possible by up to 6.2x faster throughput during the compressed stage of the algorithm as shown in Figure 1.
<!-- 1-bit Adam enables up to 2.7x overall speedup in training speed for SQuAD fine-tuning. This is made possible by up to 6.2x faster throughput during the compressed stage of the algorithm as shown in Figure 1.
![SQuAD Finetuning](/assets/images/squad-scaling.png){: .align-center}
Figure 1: Scalability of 1-bit Adam for SQuAD Finetuning on V100 GPUs with batch size of 3/GPU.
Figure 1: Scalability of 1-bit Adam for SQuAD Finetuning on V100 GPUs with batch size of 3/GPU. -->
Performance results of SQuAD Fine-tuning can be seen from our [blog post](https://www.deepspeed.ai/news/2020/09/09/onebit-adam-blog-post.html) and our [paper](https://arxiv.org/abs/2102.02888).
## 3. BERT Pre-training with 1-bit Adam
For data downloading and pre-processing, please refer to the [BERT Pre-training](/tutorials/bert-pretraining/) post.
For data downloading and pre-processing, please refer to the [BERT Pre-training](/tutorials/bert-pretraining/) tutorial.
### 3.1 Running Pre-training with DeepSpeed and 1-bit Adam
The main part of training is done in `deepspeed_train.py`, which has
We provide example scripts under [DeepSpeedExamples/bing_bert/1-bit_adam/](https://github.com/microsoft/DeepSpeedExamples/tree/master/bing_bert/1-bit_adam). There are 3 sets of scripts corresponding to the NCCL-based implementation, the MPI-based implementation on Ethernet systems, and the MPI-based implementation on InfiniBand systems. For the MPI-based implementation, we provide example scripts for launching with either the deepspeed launcher or mpirun.
<!-- The main part of training is done in `deepspeed_train.py`, which has
already been modified to use DeepSpeed. The `ds_train_bert_onebit_bsz4k_seq128.sh` and `ds_train_bert_bsz64k_seq128.sh`
are the shell scripts that help to invoke training and setup several different hyperparameters relevant
to the training process.
......@@ -218,11 +252,11 @@ mpirun -np [#processes] -ppn [#GPUs on each node] -hostfile [hostfile] [MPI flag
For example, in order to use 32 GPUs (4GPUs/node, 8 nodes in total), with the support of InfiniBand, you can use MVAPICH2 as the launcher and run the following command:
```shell
mpirun -np 32 -ppn 4 -hostfile hosts -env MV2_USE_CUDA=1 -env MV2_SUPPORT_DL=1 -env MV2_ENABLE_AFFINITY=0 -env MV2_SMP_USE_CMA=0 bash ds_train_bert_onebit_bsz4k_seq128.sh
```
``` -->
### 3.2 Configuration for BERT Pre-training with DeepSpeed and 1-bit Adam enabled
The `deepspeed_bsz4k_onebit_config_seq128.json` file gives the user the ability to specify DeepSpeed
The `deepspeed_bsz4k_onebit_config_seq128_*.json` file gives the user the ability to specify DeepSpeed
options in terms of batch size, micro batch size, optimizer, learning rate, and other parameters.
Below is the DeepSpeed configuration file for running BERT-large pre-training with sequence length of 128 using the 1-bit Adam optimizer.
......@@ -240,7 +274,7 @@ Below is the DeepSpeed configuration file for running BERT-large pre-training wi
"weight_decay": 0.01,
"bias_correction": false,
"freeze_step": 23000,
"cuda_aware": true
"comm_backend_name": "nccl"
}
},
"gradient_clipping": 1.0,
......@@ -251,8 +285,8 @@ Below is the DeepSpeed configuration file for running BERT-large pre-training wi
}
}
```
The above file is for BERT-large but for BERT-base training (sequence length 128), the suggested `freeze_step` will need to be changed to 16000. For the rest of the pre-training using sequence 512, we suggest to use a `freeze_step` of 1500. And make sure to set the `cuda_aware` correctly as described above.
The above file is for BERT-large. For BERT-base training (sequence length 128), the suggested `freeze_step` is 16000. For sequence length 512 pre-training, we suggest using a `freeze_step` of 1500 for both BERT-base and BERT-large. Also make sure to set `comm_backend_name` and `cuda_aware` correctly as described above.
### 3.3 Performance Results for BERT Pre-training
Performance results of BERT Pre-training can be seen from our detailed [blog post](https://www.deepspeed.ai/news/2020/09/09/onebit-adam-blog-post.html).
Performance results of BERT Pre-training can be seen from our [blog post](https://www.deepspeed.ai/news/2020/09/09/onebit-adam-blog-post.html) and our [paper](https://arxiv.org/abs/2102.02888).
......@@ -17,4 +17,4 @@ FusedLamb (GPU)
OneBitAdam (GPU)
----------------------------
.. autoclass:: deepspeed.runtime.fp16.OneBitAdam
.. autoclass:: deepspeed.runtime.fp16.onebit.adam.OneBitAdam
......@@ -28,6 +28,7 @@ initiative to enable next-generation AI capabilities at scale, where you can fin
information [here](https://innovation.microsoft.com/en-us/exploring-ai-at-scale).
# What's New?
* [2021/03/16] [1-bit Adam v2: NCCL-based implementation and more](https://www.deepspeed.ai/tutorials/onebit-adam/)
* [2021/03/08] [ZeRO-3 Offload: Scale your models to trillion parameters without code changes while leveraging both CPUs & GPUs](https://www.deepspeed.ai/news/2021/03/07/zero3-offload.html)
* [2020/11/12] [Simplified install, JIT compiled ops, PyPI releases, and reduced dependencies](#installation)
* [2020/11/10] [Efficient and robust compressed training through progressive layer dropping](https://www.deepspeed.ai/news/2020/10/28/progressive-layer-dropping-news.html)
......
......@@ -4,26 +4,22 @@ import torch
import torch.distributed as dist
import numpy as np
import deepspeed
from deepspeed.runtime.fp16.onebit_adam import OnebitAdam
from deepspeed.runtime.comm.mpi import MpiBackend
comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()
#TODO: Detect the hostname we are running on automatically
torch.distributed.init_process_group(backend='nccl',
init_method='tcp://worker-1:2245',
world_size=size,
rank=rank)
dummy_model = [torch.nn.Parameter(torch.ones(10))]
deepspeed.init_distributed(dist_backend='nccl')
# Set cuda_aware to True to use CUDA buffers for communication
dummy_optim = OnebitAdam(dummy_model, cuda_aware=True)
# Change cuda_aware to True to test out CUDA-Aware MPI communication
backend = MpiBackend(cuda_aware=False)
device = torch.device('cuda', rank % torch.cuda.device_count())
# A simulated compression function using torch.distributed
def torch_sim(a):
a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
scale = a.norm() / np.sqrt(a.numel())
......@@ -52,21 +48,20 @@ if tensor_size % (8 * size) != 0:
else:
right_tensor_size = tensor_size
right_server_size = right_tensor_size // size
# Add a small rank-dependent bias to the simulated gradient being communicated
# to avoid elements whose magnitude is too close to zero
a = (torch.rand(tensor_size, device=device) - 0.5) + 0.01 * rank
worker_error = torch.zeros(right_tensor_size, device=device)
server_error = torch.zeros(right_server_size, device=device)
a_torch, worker_error_torch, server_error_torch = torch_sim(a)
torch.cuda.empty_cache()
local_rank = rank % torch.cuda.device_count()
a_after = dummy_optim.Compressed_Allreduce(a,
worker_error,
server_error,
rank,
size,
comm,
local_rank)
a_after = backend.compressed_allreduce(a, worker_error, server_error, local_rank)
threshold = 1e-6
magnitude_threshold = 1e-6
diff_mask = (a_after - a_torch) > threshold
......@@ -74,13 +69,16 @@ diff_server_mask = torch.chunk(diff_mask, size)[rank]
mpi_server = torch.chunk(a_after, size)[rank] + server_error
torch_server = torch.chunk(a_torch, size)[rank] + server_error_torch
test_correctness = True
# If a value in compensated_server_m is too small (e.g. 1e-8), calling sign() on it can be unreliable,
# so the test skips positions whose magnitude falls below magnitude_threshold
if torch.sum(diff_server_mask) == 0:
print('Successfully passed the test for 1bit Adam at Rank {}'.format(rank))
else:
check_mag_mask = mpi_server[diff_mask] > magnitude_threshold
if torch.sum(check_mag_mask) == 0:
print('Successfully passed the test for 1bit Adam at Rank {}'.format(rank))
if test_correctness:
if torch.sum(diff_server_mask) == 0:
print('Successfully passed the test for MPI Backend at Rank {}'.format(rank))
else:
print('Fails at {} of positions'.format(torch.sum(check_mag_mask)))
check_mag_mask = mpi_server[diff_server_mask] > magnitude_threshold
if torch.sum(check_mag_mask) == 0:
print('Successfully passed the test for MPI Backend at Rank {}'.format(rank))
else:
print('Fails at {} positions'.format(torch.sum(check_mag_mask)))
from mpi4py import MPI
import time
import torch
import torch.distributed as dist
import numpy as np
import deepspeed
from deepspeed.runtime.comm.mpi import MpiBackend
# Configure wall clock timer
from deepspeed.utils.timer import SynchronizedWallClockTimer
from statistics import mean
timers = SynchronizedWallClockTimer()
comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()
deepspeed.init_distributed(dist_backend='nccl')
# Change cuda_aware to True to test out CUDA-Aware MPI communication
backend = MpiBackend(cuda_aware=False)
device = torch.device('cuda', rank % torch.cuda.device_count())
tensor_size = 300 * 2**20
server_size = int(tensor_size / size)
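# Pad the tensor length up to a multiple of 8 * world size so that the signs
# pack into whole bytes (8 per byte) and every rank owns an equally sized
# server chunk.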
if tensor_size % (8 * size) != 0:
right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size)))
else:
right_tensor_size = tensor_size
right_server_size = right_tensor_size // size
# Add a small rank-dependent bias to the simulated gradient being communicated
# to avoid elements whose magnitude is too close to zero
a = (torch.rand(tensor_size, device=device) - 0.5) + 0.01 * rank
worker_error = torch.zeros(right_tensor_size, device=device)
server_error = torch.zeros(right_server_size, device=device)
warmup = 10
iters = 10
local_rank = rank % torch.cuda.device_count()
# Warmup
for i in range(warmup):
backend.compressed_allreduce(a, worker_error, server_error, local_rank)
time_list = []
for i in range(iters):
timers('compressed_allreduce').start()
backend.compressed_allreduce(a, worker_error, server_error, local_rank)
timers('compressed_allreduce').stop()
time_list.append(timers('compressed_allreduce').elapsed())
timer_names = ['compressed_allreduce']
timers.log(names=timer_names, normalizer=1, memory_breakdown=None)
places = 2
convert = 1e3
float_size = 4
if rank == 0:
for i in range(iters):
lat = time_list[i]
print("latency = ", lat * convert)
minlat = round(min(time_list) * convert)
maxlat = round(max(time_list) * convert)
meanlat = round(mean(time_list) * convert, places)
print("min, max, and mean = {} ms, {} ms, {} ms".format(minlat, maxlat, meanlat))
from mpi4py import MPI
import time
import torch
import torch.distributed as dist
import numpy as np
import argparse
import deepspeed
from deepspeed.runtime.fp16.onebit_adam import OnebitAdam
import os
comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()
from deepspeed.runtime.comm.nccl import NcclBackend
#TODO: Detect the hostname we are running on automatically
torch.distributed.init_process_group(backend='nccl',
init_method='tcp://worker-1:2245',
world_size=size,
rank=rank)
parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=-1)
args = parser.parse_args()
dummy_model = [torch.nn.Parameter(torch.ones(10))]
deepspeed.init_distributed(dist_backend='nccl')
args.local_rank = int(os.environ['LOCAL_RANK'])
# Set cuda_aware to False to use host buffers for communication
dummy_optim = OnebitAdam(dummy_model, cuda_aware=False)
torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank)
device = torch.device('cuda', rank % torch.cuda.device_count())
size = dist.get_world_size()
rank = dist.get_rank()
backend = NcclBackend()
local_rank = args.local_rank
# A simulated compression function using torch.distributed
def torch_sim(a):
a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
scale = a.norm() / np.sqrt(a.numel())
......@@ -45,28 +47,26 @@ def torch_sim(a):
return a_server_compressed, worker_error, server_error
tensor_size = 100 * 2**20
tensor_size = 300 * 2**20
server_size = int(tensor_size / size)
if tensor_size % (8 * size) != 0:
right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size)))
else:
right_tensor_size = tensor_size
right_server_size = right_tensor_size // size
# Add a small rank-dependent bias to the simulated gradient being communicated
# to avoid elements whose magnitude is too close to zero
a = (torch.rand(tensor_size, device=device) - 0.5) + 0.01 * rank
worker_error = torch.zeros(right_tensor_size, device=device)
server_error = torch.zeros(right_server_size, device=device)
a_torch, worker_error_torch, server_error_torch = torch_sim(a)
torch.cuda.empty_cache()
local_rank = rank % torch.cuda.device_count()
a_after = dummy_optim.Compressed_Allreduce(a,
worker_error,
server_error,
rank,
size,
comm,
local_rank)
a_after = backend.compressed_allreduce(a, worker_error, server_error, local_rank)
threshold = 1e-6
magnitude_threshold = 1e-6
diff_mask = (a_after - a_torch) > threshold
......@@ -74,13 +74,17 @@ diff_server_mask = torch.chunk(diff_mask, size)[rank]
mpi_server = torch.chunk(a_after, size)[rank] + server_error
torch_server = torch.chunk(a_torch, size)[rank] + server_error_torch
test_correctness = True
# If a value in compensated_server_m is too small (e.g. 1e-8), calling sign() on it can be unreliable,
# so the test skips positions whose magnitude falls below magnitude_threshold
if torch.sum(diff_server_mask) == 0:
print('Successfully passed the test for 1bit Adam at Rank {}'.format(rank))
else:
check_mag_mask = mpi_server[diff_mask] > magnitude_threshold
if torch.sum(check_mag_mask) == 0:
print('Successfully passed the test for 1bit Adam at Rank {}'.format(rank))
if test_correctness:
if torch.sum(diff_server_mask) == 0:
print('Successfully passed the test for NCCL Backend at Rank {}'.format(rank))
else:
print('Fails at {} of positions'.format(torch.sum(check_mag_mask)))
check_mag_mask = mpi_server[diff_server_mask] > magnitude_threshold
if torch.sum(check_mag_mask) == 0:
print(
'Successfully passed the test for NCCL Backend at Rank {}'.format(rank))
else:
print('Fails at {} positions'.format(torch.sum(check_mag_mask)))
import time
import torch
import torch.distributed as dist
import numpy as np
import argparse
import deepspeed
import os
from deepspeed.runtime.comm.nccl import NcclBackend
from deepspeed.utils.timer import SynchronizedWallClockTimer
from statistics import mean
timers = SynchronizedWallClockTimer()
parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=-1)
args = parser.parse_args()
deepspeed.init_distributed(dist_backend='nccl')
args.local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank)
size = dist.get_world_size()
rank = dist.get_rank()
backend = NcclBackend()
local_rank = args.local_rank
# Setting tensor_size (BERT-Large)
tensor_size = 300 * 2**20
server_size = int(tensor_size / size)
if tensor_size % (8 * size) != 0:
right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size)))
else:
right_tensor_size = tensor_size
right_server_size = right_tensor_size // size
# Add a small rank-dependent bias to the simulated gradient being communicated
# to avoid elements whose magnitude is too close to zero
a = (torch.rand(tensor_size, device=device) - 0.5) + 0.01 * rank
worker_error = torch.zeros(right_tensor_size, device=device)
server_error = torch.zeros(right_server_size, device=device)
warmup = 10
iters = 10
# Warmup
for i in range(warmup):
backend.compressed_allreduce(a, worker_error, server_error, local_rank)
time_list = []
a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
scale = a.norm() / np.sqrt(a.numel())
a_compressed = scale * a_sign
print("Shape of the compressed buffer:", a_compressed.shape) if rank == 0 else None
for i in range(iters):
timers('compressed_allreduce').start()
backend.compressed_allreduce(a, worker_error, server_error, local_rank)
#torch.distributed.all_reduce(a_compressed)
timers('compressed_allreduce').stop()
time_list.append(timers('compressed_allreduce').elapsed())
#timer_names = ['compressed_allreduce']
#timers.log(names=timer_names, normalizer=1, memory_breakdown=None)
places = 2
convert = 1e3
float_size = 4
if rank == 0:
for i in range(iters):
lat = time_list[i]
print("latency = ", lat * convert)
minlat = round(min(time_list) * convert)
maxlat = round(max(time_list) * convert)
meanlat = round(mean(time_list) * convert, places)
print("min, max, and mean = {} ms, {} ms, {} ms".format(minlat,
maxlat,
meanlat)) if rank == 0 else None
#print("tensor shape", a.shape)
duration = meanlat / 1e3
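# Algorithmic throughput: tensor_size fp32 elements * 4 bytes each, moved once
# per allreduce, divided by the mean latency in seconds.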
tput = ((tensor_size * 4) / duration)
print("algo throughput: %f Bytes/s, %f GB/s" % (tput, tput / 1e9)) if rank == 0 else None
size = tensor_size * 4
n = dist.get_world_size()
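# Bus bandwidth applies the standard allreduce correction factor 2*(n-1)/n,
# accounting for the data each of the n ranks must send and receive.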
busbw = (size / duration) * (2 * (n - 1) / n)
print("busbw: %f GB/s" % (busbw / 1e9)) if rank == 0 else None
from mpi4py import MPI
import time
import torch
import torch.distributed as dist
import numpy as np
import deepspeed
from deepspeed.runtime.fp16.onebit_adam import OnebitAdam
comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()
torch.distributed.init_process_group(backend='nccl',
init_method='tcp://worker-0:2245',
world_size=size,
rank=rank)
dummy_model = [torch.nn.Parameter(torch.ones(10))]
dummy_optim = OnebitAdam(dummy_model, cuda_aware=False)
device = torch.device('cuda', rank % torch.cuda.device_count())
def torch_sim(a):
a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
scale = a.norm() / np.sqrt(a.numel())
a_compressed = scale * a_sign
a_sign = None
worker_error = a - a_compressed
dist.all_reduce(a_compressed)
a_compressed.mul_(1 / dist.get_world_size())
a_server_sign = a_compressed.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
a_list = torch.chunk(a_compressed, chunks=dist.get_world_size())
server_scale = [chunk_a.norm() / np.sqrt(chunk_a.numel()) for chunk_a in a_list]
a_sign_list = torch.chunk(a_server_sign, dist.get_world_size())
a_server_compressed = torch.cat(
[server_scale[i] * a_sign_list[i] for i in range(dist.get_world_size())])
rank = dist.get_rank()
server_error = a_list[rank] - server_scale[rank] * a_sign_list[rank]
torch.cuda.synchronize()
torch.distributed.barrier()
return a_server_compressed, worker_error, server_error
# Input Tensor size
tensor_size = 100 * 2**20
server_size = int(tensor_size / size)
if tensor_size % (8 * size) != 0:
right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size)))
else:
right_tensor_size = tensor_size
right_server_size = right_tensor_size // size
# The -0.5 shift is required to avoid sign flips/errors
a = torch.rand(tensor_size, device=device) - 0.5
worker_error = torch.zeros(right_tensor_size, device=device)
server_error = torch.zeros(right_server_size, device=device)
a_torch, worker_error_torch, server_error_torch = torch_sim(a)
torch.cuda.empty_cache()
local_rank = rank % torch.cuda.device_count()
# Test the 1-bit Adam optimizer
a_after = dummy_optim.Compressed_Allreduce(a,
worker_error,
server_error,
rank,
size,
comm,
local_rank)
# If the error is below the threshold, it is acceptable for training
threshold = 1e-6
diff_pos = ((a_after - a_torch) > threshold)
if rank == 0:
before_diff = torch.chunk(a_after - a_torch,
size)[rank] + server_error - server_error_torch
if torch.norm(before_diff) / torch.norm(torch.chunk(a_after,
size)[rank]) < threshold:
print('Successfully passed the test')
else:
print('The difference for the tensor before allgather is {}'.format(
torch.norm(before_diff)))
import torch
import torch.distributed as dist
import deepspeed
import argparse
import pytest
import json
import os
import numpy as np
import time
from common import distributed_test
from simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict, create_deepspeed_args
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
if TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 8):
pytest.skip("NCCL-based 1-bit compression requires torch 1.8 or higher",
allow_module_level=True)
def test_onebitadam_fp16_basic(tmpdir):
config_dict = {
"train_batch_size": 2,
"steps_per_print": 1,
"optimizer": {
"type": "OneBitAdam",
"params": {
"lr": 0.00015,
"weight_decay": 0.01,
"freeze_step": 2,
"cuda_aware": False,
"comm_backend_name": "nccl"
}
},
"gradient_clipping": 1.0,
"fp16": {
"enabled": True,
"loss_scale": 0,
"initial_scale_power": 16
}
}
args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10
model = SimpleModel(hidden_dim)
@distributed_test(world_size=[1, 2])
def _test_onebitadam_fp16_basic(args, model, hidden_dim):
model, _, _, _ = deepspeed.initialize(args=args,
model=model,
model_parameters=model.parameters())
data_loader = random_dataloader(model=model,
total_samples=50,
hidden_dim=hidden_dim,
device=model.device)
for n, batch in enumerate(data_loader):
loss = model(batch[0], batch[1])
model.backward(loss)
model.step()
_test_onebitadam_fp16_basic(args=args, model=model, hidden_dim=hidden_dim)
def test_onebitadam_fp32_basic(tmpdir):
config_dict = {
"train_batch_size": 2,
"steps_per_print": 1,
"optimizer": {
"type": "OneBitAdam",
"params": {
"lr": 0.00015,
"weight_decay": 0.01,
"freeze_step": 2,
"cuda_aware": False,
"comm_backend_name": "nccl"
}
},
"gradient_clipping": 1.0,
}
args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10
model = SimpleModel(hidden_dim)
@distributed_test(world_size=[1, 2])
def _test_onebitadam_fp32_basic(args, model, hidden_dim):
model, _, _, _ = deepspeed.initialize(args=args,
model=model,
model_parameters=model.parameters())
data_loader = random_dataloader(model=model,
total_samples=50,
hidden_dim=hidden_dim,
device=model.device,
dtype=torch.float)
for n, batch in enumerate(data_loader):
loss = model(batch[0], batch[1])
model.backward(loss)
model.step()
_test_onebitadam_fp32_basic(args=args, model=model, hidden_dim=hidden_dim)
def test_onebitadam_exp_avg_mask(tmpdir):
config_dict = {
"train_batch_size": 2,
"steps_per_print": 1,
"optimizer": {
"type": "OneBitAdam",
"params": {
"lr": 0.00015,
"weight_decay": 0.01,
"freeze_step": 2,
"cuda_aware": False,
"comm_backend_name": "nccl"
}
},
"gradient_clipping": 1.0,
"fp16": {
"enabled": True,
"loss_scale": 0,
"initial_scale_power": 16
}
}
args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10
model = SimpleModel(hidden_dim)
param_optimizer = list(model.named_parameters())
mask1 = torch.zeros_like(param_optimizer[0][1].data)
for col in range(mask1.size()[1]):
mask1[0][col] += 1
mask1 = torch.flatten(mask1)
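# mask1 is 1 on the first row of the first parameter and 0 elsewhere; passed as
# exp_avg_mask, it makes 1-bit Adam zero out the momentum wherever the mask is 0
# (intended for parameters whose gradients stay constant zero during training).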
optimizer_grouped_parameters = [{
'params': [param_optimizer[0][1]],
'weight_decay': 0.01,
'exp_avg_mask': mask1
},
{
'params': [param_optimizer[1][1]],
'weight_decay': 0.01
}]
@distributed_test(world_size=[2])
def _test_onebitadam_exp_avg_mask(args, model, hidden_dim):
model, optimizer, _, _ = deepspeed.initialize(args=args,
model=model,
model_parameters=optimizer_grouped_parameters)
data_loader = random_dataloader(model=model,
total_samples=50,
hidden_dim=hidden_dim,
device=model.device)
for n, batch in enumerate(data_loader):
loss = model(batch[0], batch[1])
model.backward(loss)
model.step()
# Test whether the momentum mask works
for v in optimizer.state.values():
if v['exp_avg'].size() == mask1.size():
assert torch.allclose(v['exp_avg'], v['exp_avg'].mul(mask1.to(device=v['exp_avg'].device)), atol=1e-07), f"Momentum mask is not working properly"
_test_onebitadam_exp_avg_mask(args=args, model=model, hidden_dim=hidden_dim)
def test_onebitadam_checkpointing(tmpdir):
config_dict = {
"train_batch_size": 2,
"steps_per_print": 1,
"optimizer": {
"type": "OneBitAdam",
"params": {
"lr": 0.00015,
"weight_decay": 0.01,
"freeze_step": 2,
"cuda_aware": False,
"comm_backend_name": "nccl"
}
},
"gradient_clipping": 1.0,
"fp16": {
"enabled": True,
"loss_scale": 0,
"initial_scale_power": 16
}
}
args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10
model = SimpleModel(hidden_dim)
param_optimizer = list(model.named_parameters())
mask1 = torch.zeros_like(param_optimizer[0][1].data)
mask2 = torch.zeros_like(param_optimizer[0][1].data)
for col in range(mask1.size()[1]):
mask1[0][col] += 1
mask2[1][col] += 1
mask1 = torch.flatten(mask1)
mask2 = torch.flatten(mask2)
optimizer_grouped_parameters_1 = [{
'params': [param_optimizer[0][1]],
'weight_decay': 0.01,
'exp_avg_mask': mask1
},
{
'params': [param_optimizer[1][1]],
'weight_decay': 0.01
}]
optimizer_grouped_parameters_2 = [{
'params': [param_optimizer[0][1]],
'weight_decay': 0.01,
'exp_avg_mask': mask2
},
{
'params': [param_optimizer[1][1]],
'weight_decay': 0.01
}]
optimizer_grouped_parameters_3 = [{
'params': [param_optimizer[0][1]],
'weight_decay': 0.01
},
{
'params': [param_optimizer[1][1]],
'weight_decay': 0.01
}]
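# Three setups: optimizer 1 uses mask1, optimizer 2 uses a different mask2, and
# optimizer 3 uses no mask. The assertions below check that loading a checkpoint
# neither overwrites a newly supplied mask nor re-introduces a removed one.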
@distributed_test(world_size=[2])
def _test_onebitadam_checkpointing(mask1, mask2, args, model, hidden_dim):
model_1, optimizer_1, _, _ = deepspeed.initialize(args=args,
model=model,
model_parameters=optimizer_grouped_parameters_1)
data_loader = random_dataloader(model=model_1,
total_samples=10,
hidden_dim=hidden_dim,
device=model_1.device)
for n, batch in enumerate(data_loader):
loss = model_1(batch[0], batch[1])
model_1.backward(loss)
model_1.step()
# Test whether the momentum mask still exists after saving the checkpoint
assert optimizer_1.optimizer.adam_freeze_key is True
mask1 = mask1.to(device=optimizer_1.param_groups[0]['exp_avg_mask'].device)
assert torch.allclose(optimizer_1.param_groups[0]['exp_avg_mask'], mask1, atol=1e-07), f"Incorrect momentum mask"
save_folder = os.path.join(tmpdir, 'saved_checkpoint')
# optimizer_1.optimizer.gather_compression_errors()
model_1.save_checkpoint(save_folder, tag=None)
time.sleep(5)
assert torch.allclose(optimizer_1.param_groups[0]['exp_avg_mask'], mask1, atol=1e-07), f"Momentum mask should not change after saving checkpoint"
model_2, optimizer_2, _, _ = deepspeed.initialize(args=args,
model=model,
model_parameters=optimizer_grouped_parameters_2)
# Test whether momentum mask stays the same after loading checkpoint
mask2 = mask2.to(device=optimizer_2.param_groups[0]['exp_avg_mask'].device)
assert torch.allclose(optimizer_2.param_groups[0]['exp_avg_mask'], mask2, atol=1e-07), f"Incorrect momentum mask"
model_2.load_checkpoint(save_folder,
tag=None,
load_optimizer_states=True,
load_lr_scheduler_states=True)
assert torch.allclose(optimizer_2.param_groups[0]['exp_avg_mask'], mask2, atol=1e-07), f"Momentum mask should not change after loading checkpoint"
# Test whether the worker & server errors are reset
for v in optimizer_2.state.values():
assert 'worker_error' not in v, f"Incorrect worker error"
assert 'server_error' not in v, f"Incorrect server error"
assert optimizer_2.optimizer.adam_freeze_key is True
model_3, optimizer_3, _, _ = deepspeed.initialize(args=args,
model=model,
model_parameters=optimizer_grouped_parameters_3)
optimizer_3.optimizer.freeze_step = 20
data_loader = random_dataloader(model=model_3,
total_samples=50,
hidden_dim=hidden_dim,
device=model_3.device)
for n, batch in enumerate(data_loader):
loss = model_3(batch[0], batch[1])
model_3.backward(loss)
model_3.step()
assert optimizer_3.optimizer.adam_freeze_key is True
# Test that no momentum mask is present for optimizer 3, before or after loading the checkpoint
assert 'exp_avg_mask' not in optimizer_3.param_groups[0], f"Incorrect momentum mask"
model_3.load_checkpoint(save_folder,
tag=None,
load_optimizer_states=True,
load_lr_scheduler_states=True)
assert 'exp_avg_mask' not in optimizer_3.param_groups[0], f"Momentum mask should not change after loading checkpoint"
# Test whether the worker & server errors are reset
for v in optimizer_3.state.values():
assert 'worker_error' not in v, f"Incorrect worker error"
assert 'server_error' not in v, f"Incorrect server error"
assert optimizer_3.optimizer.adam_freeze_key is False
_test_onebitadam_checkpointing(mask1,
mask2,
args=args,
model=model,
hidden_dim=hidden_dim)
def test_compressed_allreduce_basic(tmpdir):
@distributed_test(world_size=[1, 2])
def _test_compressed_allreduce_basic():
from deepspeed.runtime.comm.nccl import NcclBackend
size = dist.get_world_size()
rank = dist.get_rank()
backend = NcclBackend()
local_rank = dist.get_rank()
device = torch.device("cuda", dist.get_rank())
# A simulated compression function using torch.distributed
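# torch_sim mirrors the two-stage scheme: compress the local tensor to
# sign * scale and keep the residual as worker_error, average the compressed
# tensors, then re-compress each rank's chunk and keep that residual as
# server_error. Its outputs serve as the reference for compressed_allreduce.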
def torch_sim(a):
a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
scale = a.norm() / np.sqrt(a.numel())
a_compressed = scale * a_sign
a_sign = None
worker_error = a - a_compressed
dist.all_reduce(a_compressed)
a_compressed.mul_(1 / dist.get_world_size())
a_server_sign = a_compressed.sign().add_(1).bool().float().add_(-0.5).mul_(
2.0)
a_list = torch.chunk(a_compressed, chunks=dist.get_world_size())
server_scale = [
chunk_a.norm() / np.sqrt(chunk_a.numel()) for chunk_a in a_list
]
a_sign_list = torch.chunk(a_server_sign, dist.get_world_size())
a_server_compressed = torch.cat(
[server_scale[i] * a_sign_list[i] for i in range(dist.get_world_size())])
rank = dist.get_rank()
server_error = a_list[rank] - server_scale[rank] * a_sign_list[rank]
torch.cuda.synchronize()
torch.distributed.barrier()
return a_server_compressed, worker_error, server_error
tensor_size = 300 * 2**20
server_size = int(tensor_size / size)
if tensor_size % (8 * size) != 0:
right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size)))
else:
right_tensor_size = tensor_size
right_server_size = right_tensor_size // size
# Add a small rank-dependent bias to the simulated gradient being communicated
# to avoid elements whose magnitude is too close to zero
a = (torch.rand(tensor_size, device=device) - 0.5) + 0.01 * rank
worker_error = torch.zeros(right_tensor_size, device=device)
server_error = torch.zeros(right_server_size, device=device)
a_torch, worker_error_torch, server_error_torch = torch_sim(a)
torch.cuda.empty_cache()
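# worker_error and server_error act as error-feedback buffers for the
# compression; compressed_allreduce is expected to update them in place so the
# residuals can be reused across calls (as the perf scripts above do).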
a_after = backend.compressed_allreduce(a, worker_error, server_error, local_rank)
threshold = 1e-6
magnitude_threshold = 1e-6
diff_mask = (a_after - a_torch) > threshold
diff_server_mask = torch.chunk(diff_mask, size)[rank]
mpi_server = torch.chunk(a_after, size)[rank] + server_error
torch_server = torch.chunk(a_torch, size)[rank] + server_error_torch
# If a value in compensated_server_m is too small (e.g. 1e-8), calling sign() on it can be unreliable,
# so the test skips positions whose magnitude falls below magnitude_threshold
check_mag_mask = mpi_server[diff_server_mask] > magnitude_threshold
if torch.sum(check_mag_mask) != 0:
print('Fails at {} positions'.format(torch.sum(check_mag_mask)))
assert torch.sum(diff_server_mask) == 0 or torch.sum(check_mag_mask) == 0
_test_compressed_allreduce_basic()