@@ -49,13 +51,13 @@ class ZeroOneAdam(torch.optim.Optimizer):
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
def __init__(self,
params,
deepspeed=None,
lr=1e-3,
bias_correction=True,
betas=(0.9,
0.999),
betas=(0.9,0.999),
eps=1e-8,
eps_inside_sqrt=False,
weight_decay=0.,
...
...
@@ -102,11 +104,12 @@ class ZeroOneAdam(torch.optim.Optimizer):
if self.comm_backend_name == 'nccl':
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
assert TORCH_MAJOR >= 1 and TORCH_MINOR >= 8, "Please use torch 1.8 or greater to enable NCCL backend in 0/1 Adam. Alternatively, please specify 'mpi' as the 'comm_backend_name' in config file to proceed with the MPI backend"
assert (
    (TORCH_MAJOR == 1 and TORCH_MINOR >= 8) or TORCH_MAJOR >= 2
), "Please use torch 1.8 or greater to enable NCCL backend in 0/1 Adam. Alternatively, please specify 'mpi' as the 'comm_backend_name' in config file to proceed with the MPI backend"
assert dist.is_initialized() == True, "Please initialize the torch distributed backend."