Unverified Commit 01726ce2 authored by Ammar Ahmad Awan, committed by GitHub

Add 1-bit Adam support to DeepSpeed (#380)



* 1-bit adam (#353)
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Co-authored-by: Your Name <you@example.com>
Co-authored-by: tanghl1994 <htang14@ur.rochester.edu>
Co-authored-by: Hank <tanghl1994@gmail.com>
Co-authored-by: root <root@node2x12b.cs.rochester.edu>
Co-authored-by: Ammar Ahmad Awan <awan.ammar@microsoft.com>
parent 234bba0c
......@@ -12,7 +12,7 @@ jobs:
cuda.version: '10.0'
pytorch.version: '1.2'
torchvision.version: '0.4.0'
runmodeltests: true
runmodeltests: false
#PyTorch15-CUDA101:
# python.version: '3.7'
# cuda.version: '10.1'
......@@ -40,6 +40,7 @@ jobs:
conda install -q --yes pip
conda install -q --yes gxx_linux-64
if [[ $(cuda.version) != "10.2" ]]; then conda install --yes -c conda-forge cudatoolkit-dev=$(cuda.version) ; fi
echo "PATH=$PATH, LD_LIBRARY_PATH=$LD_LIBRARY_PATH"
displayName: 'Setup environment python=$(python.version) pytorch=$(pytorch.version) cuda=$(cuda.version)'
# Manually install torch/torchvision first to enforce versioning.
......
import torch
import warnings
import importlib
try:
import deepspeed as ds
import deepspeed
print("deepspeed successfully imported")
except ImportError as err:
raise err
print(f"torch install path: {torch.__path__}")
print(f"torch version: {torch.__version__}")
print(f"deepspeed info: {ds.__version__}, {ds.__git_hash__}, {ds.__git_branch__}")
print(f"deepspeed install path: {deepspeed.__path__}")
print(
f"deepspeed info: {deepspeed.__version__}, {deepspeed.__git_hash__}, {deepspeed.__git_branch__}"
)
try:
apex_C = importlib.import_module('apex_C')
......@@ -17,12 +21,6 @@ try:
except Exception as err:
raise err
try:
fused_lamb = importlib.import_module('deepspeed.ops.lamb.fused_lamb_cuda')
print('deepspeed fused lamb kernels successfully installed')
except Exception as err:
raise err
try:
from apex.optimizers import FP16_Optimizer
print("using old-style apex")
......@@ -30,8 +28,19 @@ except ImportError:
print("using new-style apex")
try:
ds_transformer = importlib.import_module(
'deepspeed.ops.transformer.transformer_cuda')
print('deepspeed transformer kernels successfully installed')
importlib.import_module('deepspeed.ops.lamb.fused_lamb_cuda')
print('deepspeed lamb successfully installed.')
except Exception as err:
raise err
warnings.warn("deepspeed lamb is NOT installed.")
try:
importlib.import_module('deepspeed.ops.transformer.transformer_cuda')
print('deepspeed transformer kernels successfully installed.')
except Exception as err:
warnings.warn('deepspeed transformer kernels are NOT installed.')
try:
importlib.import_module('deepspeed.ops.sparse_attention.cpp_utils')
print('deepspeed sparse attention successfully installed.')
except ImportError:
warnings.warn('deepspeed sparse attention is NOT installed.')
# Copyright 2020 The Microsoft DeepSpeed Team
#############################################
# Torch distributed constants
#############################################
TORCH_DISTRIBUTED_DEFAULT_PORT = 29500
PDSH_LAUNCHER = 'pdsh'
PDSH_MAX_FAN_OUT = 1024
OPENMPI_LAUNCHER = 'openmpi'
MVAPICH_LAUNCHER = 'mvapich'
MVAPICH_TMP_HOSTFILE = '/tmp/deepspeed_mvapich_hostfile'
# Copyright 2020 The Microsoft DeepSpeed Team
"""
Copyright 2020 The Microsoft DeepSpeed Team: deepspeed@microsoft.com
DeepSpeed launcher: this is similar to torch.distributed.launch but supports
additional features such as arbitrary GPU exclusion.
deepspeed.launcher.launch is intended to be run on a single worker node and
will spawn several worker sub-processes depending on how many devices/ranks
are on the worker.
"""
import sys
......@@ -10,7 +16,8 @@ import base64
from collections import defaultdict
from argparse import ArgumentParser, REMAINDER
from deepspeed.utils import logger
from .constants import TORCH_DISTRIBUTED_DEFAULT_PORT
from ..utils import logger
def parse_args():
......@@ -32,7 +39,7 @@ def parse_args():
" single node multi-proc training, the"
" --master_addr can simply be 127.0.0.1")
parser.add_argument("--master_port",
default=29500,
default=TORCH_DISTRIBUTED_DEFAULT_PORT,
type=int,
help="Master node (rank 0)'s free port that needs to "
"be used for communication during distributed "
......
import os
import sys
import shutil
import subprocess
import warnings
from abc import ABC, abstractmethod
from ..utils import logger
from .constants import PDSH_MAX_FAN_OUT, MVAPICH_TMP_HOSTFILE
class MultiNodeRunner(ABC):
def __init__(self, args, world_info_base64):
self.args = args
self.user_arguments = self.parse_user_args()
self.user_script = args.user_script
self.world_info_base64 = world_info_base64
self.exports = {}
@abstractmethod
def backend_exists(self):
pass
@abstractmethod
def get_cmd(self, environment, active_resources):
pass
def add_export(self, key, var):
self.exports[key.strip()] = var.strip()
def parse_user_args(self):
return self.args.user_args
class PDSHRunner(MultiNodeRunner):
def __init__(self, args, world_info_base64):
super().__init__(args, world_info_base64)
def backend_exists(self):
return shutil.which('pdsh')
def parse_user_args(self):
return list(
map(lambda x: x if x.startswith("-") else "'{}'".format(x),
self.args.user_args))
def get_cmd(self, environment, active_resources):
environment['PDSH_RCMD_TYPE'] = 'ssh'
active_workers = ",".join(active_resources.keys())
logger.info("Running on the following workers: %s" % active_workers)
# PDSH flags for max node fan out and specific hosts to launch on
# See https://linux.die.net/man/1/pdsh for flag details
pdsh_cmd_args = ['pdsh', '-f', str(PDSH_MAX_FAN_OUT), '-w', active_workers]
exports = ""
for key, val in self.exports.items():
exports += "export {}={}; ".format(key, val)
deepspeed_launch = [
exports,
"cd {};".format(os.path.abspath('.')),
sys.executable,
"-u",
"-m",
"deepspeed.launcher.launch",
'--world_info={}'.format(self.world_info_base64),
"--node_rank=%n",
"--master_addr={}".format(self.args.master_addr),
"--master_port={}".format(self.args.master_port)
]
return pdsh_cmd_args + deepspeed_launch + [self.user_script
] + self.user_arguments
class OpenMPIRunner(MultiNodeRunner):
def __init__(self, args, world_info_base64, resource_pool):
super().__init__(args, world_info_base64)
self.resource_pool = resource_pool
self.add_export('UCX_TLS', 'tcp')
def backend_exists(self):
#TODO: if IB is available we should suggest mvapich
return shutil.which('ompi_info')
def get_cmd(self, environment, active_resources):
#TODO: Allow for include/exclude at node-level but not gpu-level
assert self.args.include == "" and self.args.exclude == "", 'openmpi backend does not support worker include/exclusion'
assert self.args.num_nodes == -1 and self.args.num_gpus == -1, 'openmpi backend does not support limiting num nodes/gpus'
total_process_count = sum(self.resource_pool.values())
mpirun_cmd = [
'mpirun',
'-n',
f'{total_process_count}',
'-hostfile',
f'{self.args.hostfile}',
'--mca',
'btl',
'^openib',
'--mca',
'btl_tcp_if_include',
'eth0',
]
export_cmd = []
for k, v in self.exports.items():
export_cmd += ['-x', f'{k}={v}']
python_exec = [sys.executable, "-u"]
return mpirun_cmd + export_cmd + python_exec + [self.user_script
] + self.user_arguments
class MVAPICHRunner(MultiNodeRunner):
def __init__(self, args, world_info_base64, resource_pool):
super().__init__(args, world_info_base64)
self.resource_pool = resource_pool
# Disable the CMA kernel module, not available on Ubuntu systems
self.add_export('MV2_SMP_USE_CMA', '0')
# If we fail this will output more verbose logging
self.add_export('MV2_DEBUG_SHOW_BACKTRACE', '1')
# Enable CUDA-aware communication
self.add_export('MV2_USE_CUDA', '1')
# Support deep learning frameworks: http://hidl.cse.ohio-state.edu/userguide/horovod/
self.add_export('MV2_SUPPORT_DL', '1')
# Support MPI_THREAD_MULTIPLE
self.add_export('MV2_ENABLE_AFFINITY', '0')
# Performance tuning flags for allgather
self.add_export('MV2_INTER_ALLGATHER_TUNING', '5')
self.add_export('MV2_CUDA_USE_NAIVE', '0')
def backend_exists(self):
#TODO: if IB is available we should suggest mvapich
mpiname_exists = shutil.which('mpiname')
exists = False
if not mpiname_exists:
warnings.warn("mpiname does not exist, mvapich is not installed properly")
else:
results = subprocess.check_output('mpiname', shell=True)
mpiname_results = results.decode('utf-8').strip()
if "MVAPICH2-GDR" in mpiname_results:
exists = True
else:
warnings.warn(
f"Expected MVAPICH2-GDR as return for mpiname but received {mpiname_results}"
)
return exists
def get_cmd(self, environment, active_resources):
#TODO: Allow for include/exclude at node-level but not gpu-level
assert self.args.include == "" and self.args.exclude == "", 'mvapich backend does not support worker include/exclusion'
assert self.args.num_nodes == -1 and self.args.num_gpus == -1, 'mvapich backend does not support limiting num nodes/gpus'
devices_per_node = self.resource_pool.values()
total_process_count = sum(devices_per_node)
process_per_node = list(devices_per_node)[0]
assert all([n == process_per_node for n in devices_per_node]), "mvapich requires same number of devices per node"
with open(MVAPICH_TMP_HOSTFILE, 'w') as fd:
for host in self.resource_pool.keys():
fd.write(f'{host}\n')
mpirun_cmd = [
'mpirun',
'-np',
f'{total_process_count}',
'-ppn',
f'{process_per_node}',
'--hostfile',
f'{MVAPICH_TMP_HOSTFILE}',
]
export_cmd = []
for k, v in self.exports.items():
export_cmd += ['-env', f'{k}={v}']
python_exec = [sys.executable, "-u"]
return mpirun_cmd + export_cmd + python_exec + [self.user_script
] + self.user_arguments
# Copyright 2020 The Microsoft DeepSpeed Team
"""
Copyright 2020 The Microsoft DeepSpeed Team
The DeepSpeed runner is the main front-end for launching multi-worker
training jobs with DeepSpeed. By default it uses pdsh to ssh in parallel
into multiple worker nodes and launch all the necessary processes
per rank for training.
"""
import os
......@@ -14,11 +18,13 @@ from copy import deepcopy
import torch.cuda
from deepspeed.runtime.constants import TORCH_DISTRIBUTED_DEFAULT_PORT
from deepspeed.utils import logger
from .multinode_runner import PDSHRunner, OpenMPIRunner, MVAPICHRunner
from .constants import TORCH_DISTRIBUTED_DEFAULT_PORT, \
PDSH_LAUNCHER, OPENMPI_LAUNCHER, MVAPICH_LAUNCHER
from ..utils import logger
DLTS_HOSTFILE = "/job/hostfile"
EXPORT_ENVS = ["NCCL", "PYTHON"]
EXPORT_ENVS = ["NCCL", "PYTHON", "MV2", 'UCX']
DEEPSPEED_ENVIRONMENT_NAME = ".deepspeed_env"
DEEPSPEED_ENVIRONMENT_PATHS = [os.path.expanduser("~"), '.']
PDSH_MAX_FAN_OUT = 1024
......@@ -62,12 +68,20 @@ def parse_args(args=None):
resources except slot 0 on worker-1.
''')
parser.add_argument("--num_nodes", type=int, default=-1, help="")
parser.add_argument("--num_nodes",
type=int,
default=-1,
help="Total number of worker nodes to run on, this will use "
"the top N hosts from the given hostfile.")
parser.add_argument("--num_gpus", type=int, default=-1, help="")
parser.add_argument("--num_gpus",
type=int,
default=-1,
help="Max number of GPUs to use on each node, will use "
"[0:N) GPU ids on each node.")
parser.add_argument("--master_port",
default=int(TORCH_DISTRIBUTED_DEFAULT_PORT),
default=TORCH_DISTRIBUTED_DEFAULT_PORT,
type=int,
help="(optional) Port used by PyTorch distributed for "
"communication during training.")
......@@ -78,6 +92,18 @@ def parse_args(args=None):
help="(optional) IP address of node 0, will be "
"inferred via 'hostname -I' if not specified.")
parser.add_argument("--launcher",
default=PDSH_LAUNCHER,
type=str,
help="(optional) choose launcher backend for multi-node"
"training. Options currently include PDSH, OpenMPI, MVAPICH.")
parser.add_argument("--launcher_args",
default="",
type=str,
help="(optional) pass launcher specific arguments as a "
"single quoted argument.")
parser.add_argument("user_script",
type=str,
help="User script to launch, followed by any required "
......@@ -292,17 +318,18 @@ def main(args=None):
]
cmd = deepspeed_launch + [args.user_script] + args.user_args
else:
env['PDSH_RCMD_TYPE'] = 'ssh'
active_workers = ",".join(active_resources.keys())
logger.info("Running on the following workers: %s" % active_workers)
# PDSH flags for max node fan out and specific hosts to launch on
# See https://linux.die.net/man/1/pdsh for flag details
pdsh_cmd_args = ['pdsh', '-f', str(PDSH_MAX_FAN_OUT), '-w', active_workers]
args.launcher = args.launcher.lower()
if args.launcher == PDSH_LAUNCHER:
runner = PDSHRunner(args, world_info_base64)
elif args.launcher == OPENMPI_LAUNCHER:
runner = OpenMPIRunner(args, world_info_base64, resource_pool)
elif args.launcher == MVAPICH_LAUNCHER:
runner = MVAPICHRunner(args, world_info_base64, resource_pool)
else:
raise NotImplementedError(f"Unknown launcher {args.launcher}")
num_nodes = len(active_resources.keys())
num_gpus_per_node = None
if not runner.backend_exists():
raise RuntimeError(f"launcher '{args.launcher}' not installed.")
curr_path = os.path.abspath('.')
if 'PYTHONPATH' in env:
......@@ -312,33 +339,20 @@ def main(args=None):
exports = ""
for var in env.keys():
if any(map(lambda name: var.startswith(name), EXPORT_ENVS)):
exports += "export {}={}; ".format(var, env[var])
if any([var.startswith(name) for name in EXPORT_ENVS]):
runner.add_export(var, env[var])
for environ_path in DEEPSPEED_ENVIRONMENT_PATHS:
environ_file = os.path.join(environ_path, DEEPSPEED_ENVIRONMENT_NAME)
if os.path.isfile(environ_file):
with open(environ_file, 'r') as fd:
for var in fd.readlines():
exports += "export {}; ".format(var.strip())
key, val = var.split('=')
runner.add_export(key, val)
deepspeed_launch = [
exports,
"cd {};".format(curr_path),
sys.executable,
"-u",
"-m",
"deepspeed.launcher.launch",
'--world_info={}'.format(world_info_base64),
"--node_rank=%n",
"--master_addr={}".format(args.master_addr),
"--master_port={}".format(args.master_port)
]
user_args = list(
map(lambda x: x if x.startswith("-") else "'{}'".format(x),
args.user_args))
cmd = pdsh_cmd_args + deepspeed_launch + [args.user_script] + user_args
logger.info("cmd={}".format(cmd))
cmd = runner.get_cmd(env, active_resources)
logger.info("cmd = {}".format(' '.join(cmd)))
result = subprocess.Popen(cmd, env=env)
result.wait()
......
......@@ -14,9 +14,10 @@ from deepspeed.runtime.activation_checkpointing.config import DeepSpeedActivatio
from deepspeed.utils import logger
TENSOR_CORE_ALIGN_SIZE = 8
ONEBIT_ADAM_OPTIMIZER = 'onebitadam'
ADAM_OPTIMIZER = 'adam'
LAMB_OPTIMIZER = 'lamb'
DEEPSPEED_OPTIMIZERS = [ADAM_OPTIMIZER, LAMB_OPTIMIZER]
DEEPSPEED_OPTIMIZERS = [ADAM_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER]
def get_amp_enabled(param_dict):
......
'''
Copyright 2019 The Microsoft DeepSpeed Team
'''
from mpi4py import MPI
import numpy as np
import cupy
def my_igather(rank, size, comm, sendbuf, recbuf, root):
req = []
if rank == root:
for idx in range(size):
if idx != rank:
req.append(comm.Irecv(recbuf[idx], source=idx))
else:
recbuf[rank] = sendbuf
else:
req.append(comm.Isend(sendbuf, dest=root))
return req
def gather_cuda(rank,
world_size,
comm,
cupy_sign_list_packed,
cupy_recvbuf_sign,
cupy_worker_scale,
cupy_recvbuf_scale):
# We do in-place operations on cupy buffers so we do not return any buffers
requests = []
for idx in range(world_size):
req_sign = my_igather(rank,
world_size,
comm,
cupy_sign_list_packed[idx],
cupy_recvbuf_sign,
root=idx)
requests += req_sign
for idx in range(world_size):
req_scale = my_igather(rank,
world_size,
comm,
cupy_worker_scale,
cupy_recvbuf_scale,
root=idx)
requests += req_scale
MPI.Request.Waitall(requests)
def gather_host(rank,
world_size,
comm,
cupy_sign_list_packed,
cupy_recvbuf_sign,
cupy_worker_scale,
cupy_recvbuf_scale):
# In-place operations are not possible for newly created cupy arrays
# so we need to return the new buffers
numpy_recvbuf_sign = np.zeros([world_size,
cupy_sign_list_packed[rank].size],
dtype=cupy_sign_list_packed[0].dtype)
numpy_recvbuf_scale = np.zeros([world_size, 1], dtype=cupy_worker_scale.dtype)
# 1. convert from cupy to numpy
numpy_sign_list_packed = cupy_sign_list_packed
for idx in range(world_size):
numpy_sign_list_packed[idx] = cupy.asnumpy(cupy_sign_list_packed[idx])
numpy_worker_scale = cupy.asnumpy(cupy_worker_scale)
numpy_recvbuf_scale = cupy.asnumpy(cupy_recvbuf_scale)
cupy.cuda.get_current_stream().synchronize()
# 2. use numpy buffers for communication
requests = []
for idx in range(world_size):
req_sign = my_igather(rank,
world_size,
comm,
numpy_sign_list_packed[idx],
numpy_recvbuf_sign,
root=idx)
requests += req_sign
for idx in range(world_size):
req_scale = my_igather(rank,
world_size,
comm,
numpy_worker_scale,
numpy_recvbuf_scale,
root=idx)
requests += req_scale
MPI.Request.Waitall(requests)
# 3. Convert back from numpy to cupy
cupy_recvbuf_sign = cupy.asarray(numpy_recvbuf_sign)
for idx in range(world_size):
cupy_sign_list_packed[idx] = cupy.asarray(numpy_sign_list_packed[idx])
cupy_worker_scale = cupy.asarray(numpy_worker_scale)
cupy_recvbuf_scale = cupy.asarray(numpy_recvbuf_scale)
cupy.cuda.get_current_stream().synchronize()
return cupy_sign_list_packed, cupy_recvbuf_sign, cupy_worker_scale, cupy_recvbuf_scale
def allgather_cuda(comm,
cupy_server_sign_packed,
cupy_recvbuf_sign_server,
cupy_server_scale,
cupy_recvbuf_scale_server):
comm.Allgather(cupy_server_sign_packed, cupy_recvbuf_sign_server)
comm.Allgather(cupy_server_scale, cupy_recvbuf_scale_server)
def allgather_host(comm,
cupy_server_sign_packed,
cupy_recvbuf_sign_server,
cupy_server_scale,
cupy_recvbuf_scale_server):
# 1. Convert cupy to numpy
numpy_recvbuf_sign_server = np.zeros([comm.Get_size(),
cupy_server_sign_packed.size],
dtype=cupy_server_sign_packed.dtype)
numpy_recvbuf_scale_server = np.zeros([comm.Get_size(),
1],
dtype=cupy_server_scale.dtype)
numpy_server_sign_packed = cupy.asnumpy(cupy_server_sign_packed)
numpy_recvbuf_sign_server = cupy.asnumpy(cupy_recvbuf_sign_server)
numpy_server_scale = cupy.asnumpy(cupy_server_scale)
numpy_recvbuf_scale_server = cupy.asnumpy(cupy_recvbuf_scale_server)
cupy.cuda.get_current_stream().synchronize()
# 2. Communicate numpy buffers
comm.Allgather(numpy_server_sign_packed, numpy_recvbuf_sign_server)
comm.Allgather(numpy_server_scale, numpy_recvbuf_scale_server)
comm.Barrier()
# 3. Convert numpy back to cupy
cupy_server_sign_packed = cupy.asarray(numpy_server_sign_packed)
cupy_recvbuf_sign_server = cupy.asarray(numpy_recvbuf_sign_server)
cupy_server_scale = cupy.asarray(numpy_server_scale)
cupy_recvbuf_scale_server = cupy.asarray(numpy_recvbuf_scale_server)
cupy.cuda.get_current_stream().synchronize()
return cupy_server_sign_packed, cupy_recvbuf_sign_server, cupy_server_scale, cupy_recvbuf_scale_server
......@@ -18,7 +18,8 @@ from deepspeed.runtime.activation_checkpointing import checkpointing as activati
from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
from deepspeed.runtime.config import DeepSpeedConfig, \
ADAM_OPTIMIZER, LAMB_OPTIMIZER, DEEPSPEED_OPTIMIZERS
ADAM_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, DEEPSPEED_OPTIMIZERS
from deepspeed.runtime.dataloader import DeepSpeedDataLoader
from deepspeed.runtime.constants import \
ROUTE_TRAIN, ROUTE_PREDICT, ROUTE_EVAL, \
......@@ -27,8 +28,6 @@ from deepspeed.runtime.constants import \
from deepspeed.runtime.csr_tensor import CSRTensor
import deepspeed.runtime.lr_schedules as lr_schedules
from deepspeed.ops.lamb import FusedLamb
from deepspeed.utils import logger
from deepspeed.utils.timer import ThroughputTimer, SynchronizedWallClockTimer
......@@ -122,6 +121,7 @@ class DeepSpeedEngine(Module):
self.config_params = config_params
self.loaded_checkpoint_mp_world_size = None
self.loaded_checkpoint_dp_world_size = None
self.enable_backward_allreduce = True
if dist_init_required is None:
dist_init_required = not dist.is_initialized()
......@@ -527,6 +527,7 @@ class DeepSpeedEngine(Module):
def _configure_basic_optimizer(self, model_parameters):
optimizer_parameters = self.optimizer_params()
# print(optimizer_parameters.keys())
if 'max_grad_norm' in optimizer_parameters.keys():
raise ValueError(
"'max_grad_norm' is not supported as an optimizer parameter, please switch to using the deepspeed parameter 'gradient_clipping' see: https://www.deepspeed.ai/docs/config-json/#gradient-clipping for more details"
......@@ -535,7 +536,11 @@ class DeepSpeedEngine(Module):
from apex.optimizers.fused_adam import FusedAdam
optimizer = FusedAdam(model_parameters, **optimizer_parameters)
elif self.optimizer_name() == LAMB_OPTIMIZER:
from deepspeed.ops.lamb import FusedLamb
optimizer = FusedLamb(model_parameters, **optimizer_parameters)
elif self.optimizer_name() == ONEBIT_ADAM_OPTIMIZER:
from deepspeed.runtime.fp16.onebit_adam import OnebitAdam
optimizer = OnebitAdam(model_parameters, self, **optimizer_parameters)
else:
torch_optimizer = getattr(torch.optim, self.optimizer_name())
optimizer = torch_optimizer(model_parameters, **optimizer_parameters)
......@@ -545,7 +550,8 @@ class DeepSpeedEngine(Module):
initial_dynamic_scale = self.initial_dynamic_scale()
dynamic_loss_args = self.dynamic_loss_scale_args()
clip_grad = self.gradient_clipping()
if self.optimizer_name() == ADAM_OPTIMIZER:
if self.optimizer_name() == ADAM_OPTIMIZER or self.optimizer_name(
) == ONEBIT_ADAM_OPTIMIZER:
if self.dynamic_loss_scale():
logger.info('Creating fp16 optimizer with dynamic loss scale')
timers = self.timers if self.wall_clock_breakdown() else None
......@@ -734,7 +740,7 @@ class DeepSpeedEngine(Module):
else:
self.buffered_allreduce_fallback(elements_per_buffer=bucket_size)
def backward(self, loss, allreduce_gradients=True):
def backward(self, loss, allreduce_gradients=True, release_loss=False):
r"""Execute backward pass on the loss
Arguments:
......@@ -796,7 +802,7 @@ class DeepSpeedEngine(Module):
self.timers('backward_allreduce_microstep').start()
self.timers('backward_allreduce').start()
if allreduce_gradients:
if allreduce_gradients and self.enable_backward_allreduce:
self.allreduce_gradients()
if self.wall_clock_breakdown():
......@@ -805,6 +811,10 @@ class DeepSpeedEngine(Module):
self.timers('backward').stop()
self.timers('backward_microstep').stop()
if release_loss:
# loss.data = None
pass
return loss
def is_gradient_accumulation_boundary(self):
......
......@@ -101,6 +101,20 @@ class FP16_Optimizer(object):
self.overflow = False
self.overflow_checker = CheckOverflow(self.fp16_groups, mpu=self.mpu)
self.initialize_optimizer_states()
def initialize_optimizer_states(self):
for i, group in enumerate(self.fp16_groups):
self.fp32_groups_flat[i].grad = torch.zeros(
self.fp32_groups_flat[i].size(),
device=self.fp32_groups_flat[i].device)
self.optimizer.step()
for i, group in enumerate(self.fp16_groups):
self.fp32_groups_flat[i].grad = None
return
def zero_grad(self, set_grads_to_None=True):
"""
......@@ -204,6 +218,9 @@ class FP16_Optimizer(object):
if p.grad is None else p.grad.to(data_type) for p in group
]))
for p in group:
p.grad = None
self.fp32_groups_flat[i].grad = grads_groups_flat[i]
self.start_timers([COMPUTE_NORM])
......@@ -223,6 +240,7 @@ class FP16_Optimizer(object):
"scale: {}, reducing to {}".format(prev_scale,
self.cur_scale))
self.log_timers(OVERFLOW_TIMERS)
grads_groups_flat = None
return self.overflow
self.start_timers([UNSCALE_AND_CLIP])
......
'''
Copyright 2020 The Microsoft DeepSpeed Team
'''
import types
import torch
import torch.distributed as dist
import importlib
import numpy as np
import time
import cupy
from torch.utils.dlpack import to_dlpack
from torch.utils.dlpack import from_dlpack
from deepspeed.utils.logging import logger
from mpi4py import MPI
from deepspeed.runtime.custom_collectives import gather_cuda, gather_host, allgather_cuda, allgather_host
class OnebitAdam(torch.optim.Optimizer):
"""Implements the 1-bit Adam algorithm. Currently GPU-only.
For a usage example, please see: TODO DeepSpeed Tutorial
It has been proposed in APMSqueeze (https://arxiv.org/abs/2008.11343)
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
freeze_step (int, optional): Number of steps for warmup (uncompressed)
stage before we start using compressed communication. (default 100000)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square. (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
max_coeff(float, optional): maximum value of the lamb coefficient (default: 10.0)
min_coeff(float, optional): minimum value of the lamb coefficient (default: 0.01)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False) NOT SUPPORTED in 1-bit Adam!
eps_inside_sqrt (boolean, optional): in the 'update parameters' step,
adds eps to the bias-corrected second moment estimate before
evaluating square root instead of adding it to the square root of
second moment estimate as in the original paper. (default: False)
cuda_aware (boolean, required): Set True if the underlying MPI implementation
supports CUDA-Aware communication. (default: False)
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
def __init__(self,
params,
deepspeed=None,
lr=1e-3,
freeze_step=100000,
bias_correction=True,
betas=(0.9,
0.999),
eps=1e-8,
eps_inside_sqrt=False,
weight_decay=0.,
max_grad_norm=0.,
amsgrad=False,
cuda_aware=False):
if amsgrad:
raise RuntimeError('1-bit Adam does not support the AMSGrad variant.')
defaults = dict(lr=lr,
bias_correction=bias_correction,
betas=betas,
eps=eps,
weight_decay=weight_decay,
max_grad_norm=max_grad_norm)
super(OnebitAdam, self).__init__(params, defaults)
from mpi4py import MPI
self.eps_mode = 0 if eps_inside_sqrt else 1
self.comm = MPI.COMM_WORLD
self.rank = self.comm.Get_rank()
self.size = self.comm.Get_size()
self.comm_time = 0.0
self.step_time = 0.0
self.ave_step = 1
self.bk_time = 0.0
self.divider = int(self.size * 8 / np.gcd(self.size, 8))
self.deepspeed = deepspeed
self.adam_freeze_key = False
self.initialize = False
self.freeze_step = freeze_step
self.cuda_aware = cuda_aware
def torch2cupy(self, tensor):
return cupy.fromDlpack(to_dlpack(tensor))
def cupy2torch(self, cupy_tensor):
return from_dlpack(cupy_tensor.toDlpack())
def compress_by_chunk(self, cupy_bool_tensor, num_chunks):
packed_sign = cupy.packbits(cupy_bool_tensor)
sign_list_packed = cupy.split(packed_sign, num_chunks)
cupy.cuda.get_current_stream().synchronize()
return sign_list_packed
def Compressed_Allreduce(self,
buffer_m: torch.tensor,
worker_error,
server_error,
rank,
world_size,
comm,
local_rank):
all_start_time = time.time()
original_size = buffer_m.numel()
cupy.cuda.Device(local_rank).use()
if torch.numel(buffer_m) != torch.numel(worker_error):
empty_tensor = torch.zeros(torch.numel(worker_error) - torch.numel(buffer_m),
device=buffer_m.device)
buffer_m = torch.cat([buffer_m, empty_tensor])
buffer_m.add_(worker_error)
worker_scale = torch.norm(buffer_m) / np.sqrt(torch.numel(buffer_m))
sign_buffer_m = buffer_m.sign().add_(1).bool()
sign_buffer_m = sign_buffer_m.float()
sign_buffer_m.add_(-0.5).mul_(2.0)
worker_error.set_((buffer_m - worker_scale * sign_buffer_m))
sign_buffer_m = None
compensated_buffer_m = buffer_m
compensated_buffer_m.sign_()
compensated_buffer_m = compensated_buffer_m.add_(1).bool()
cupy_worker_scale = self.torch2cupy(worker_scale)
cupy_compensated_buffer_m = self.torch2cupy(compensated_buffer_m)
compensated_buffer_m = None
cupy_sign_list_packed = self.compress_by_chunk(cupy_compensated_buffer_m,
world_size)
cupy_compensated_buffer_m = None
cupy_recvbuf_sign = cupy.zeros([world_size,
cupy_sign_list_packed[rank].size],
dtype=cupy_sign_list_packed[0].dtype)
cupy_recvbuf_scale = cupy.zeros([world_size, 1], dtype=cupy_worker_scale.dtype)
# Communication Phase 1
gather_start = time.time()
if self.cuda_aware:
gather_cuda(rank,
world_size,
comm,
cupy_sign_list_packed,
cupy_recvbuf_sign,
cupy_worker_scale,
cupy_recvbuf_scale)
else:
cupy_sign_list_packed, cupy_recvbuf_sign, cupy_worker_scale, cupy_recvbuf_scale = gather_host(rank,
world_size,
comm,
cupy_sign_list_packed,
cupy_recvbuf_sign,
cupy_worker_scale,
cupy_recvbuf_scale)
gather_end = time.time()
cupy_unpacked_sign = (cupy.unpackbits(cupy_recvbuf_sign.flatten())).reshape(
world_size,
-1)
cupy_recvbuf_sign = None
unpacked_sign = self.cupy2torch(cupy_unpacked_sign).float()
cupy_unpacked_sign = None
unpacked_sign = unpacked_sign.add_(-0.5).mul_(2.0)
worker_scale = self.cupy2torch(cupy_recvbuf_scale).mul_(1 / world_size)
compensated_server_m = unpacked_sign.mul_(worker_scale).sum(0)
unpacked_sign = None
compensated_server_m.add_(server_error)
server_scale = torch.norm(compensated_server_m) / np.sqrt(
compensated_server_m.numel())
sign_server_m = compensated_server_m.sign().add_(1).bool()
sign_server_m = sign_server_m.float()
sign_server_m.add_(-0.5).mul_(2.0)
server_error.set_(compensated_server_m - server_scale * sign_server_m)
sign_server_m = None
compensated_server_m.sign_()
compensated_server_m = compensated_server_m.add_(1).bool()
cupy_server_scale = self.torch2cupy(server_scale)
cupy_compensated_server_m = self.torch2cupy(compensated_server_m)
compensated_server_m = None
cupy_server_sign_packed = self.compress_by_chunk(cupy_compensated_server_m, 1)
cupy_recvbuf_sign_server = cupy.zeros(
[world_size,
cupy_server_sign_packed[0].size],
dtype=cupy_sign_list_packed[0].dtype)
cupy_recvbuf_scale_server = cupy.zeros([world_size,
1],
dtype=cupy_worker_scale.dtype)
# Communication Phase 2
if self.cuda_aware:
allgather_cuda(comm,
cupy_server_sign_packed[0],
cupy_recvbuf_sign_server,
cupy_server_scale,
cupy_recvbuf_scale_server)
else:
cupy_server_sign_packed[0], cupy_recvbuf_sign_server, cupy_server_scale, cupy_recvbuf_scale_server = allgather_host(comm,
cupy_server_sign_packed[0],
cupy_recvbuf_sign_server,
cupy_server_scale,
cupy_recvbuf_scale_server)
cupy_server_unpacked_sign = (cupy.unpackbits(
cupy_recvbuf_sign_server.flatten())).reshape(world_size,
-1)
cupy_recvbuf_sign_server = None
server_unpacked_sign = self.cupy2torch(cupy_server_unpacked_sign)
cupy_server_unpacked_sign = None
server_unpacked_sign = server_unpacked_sign.float().add_(-0.5).mul_(2.0)
server_scale = self.cupy2torch(cupy_recvbuf_scale_server)
buffer_m = server_unpacked_sign.mul_(server_scale).flatten()[0:original_size]
return buffer_m
def step(self, closure=None, grads=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
grads (list of tensors, optional): weight gradient to use for the
optimizer update. If gradients have type torch.half, parameters
are expected to be in type torch.float. (default: None)
output params (list of tensors, optional): A reduced precision copy
of the updated weights written out in addition to the regular
updated weights. These have to be of the same type as gradients. (default: None)
scale (float, optional): factor to divide gradient tensor values
by before applying to weights. (default: 1)
"""
loss = None
if closure is not None:
loss = closure()
gather_time = 0
allgather_time = 0
all_time = 0
if self.adam_freeze_key is False:
v_diff_buffer = 0.0
if grads is None:
grads_group = [None] * len(self.param_groups)
# backward compatibility
# assuming a list/generator of parameters means a single group
elif isinstance(grads, types.GeneratorType):
grads_group = [grads]
elif type(grads[0]) != list:
grads_group = [grads]
else:
grads_group = grads
for group, grads_this_group in zip(self.param_groups, grads_group):
if grads_this_group is None:
grads_this_group = [None] * len(group['params'])
bias_correction = 1 if group['bias_correction'] else 0
for p, grad in zip(group['params'], grads_this_group):
if p.grad is None and grad is None:
continue
if grad is None:
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError(
'FusedAdam does not support sparse gradients, please consider SparseAdam instead'
)
state = self.state[p]
# State initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p.data)
state['tensor_size'] = torch.numel(p.data)
state['corrected_tensor_size'] = state['tensor_size']
if state['tensor_size'] % (self.size * self.divider) != 0:
state['corrected_tensor_size'] += ((self.size * self.divider) -
(state['tensor_size'] %
(self.size * self.divider)))
state['server_chunk_size'] = state[
'corrected_tensor_size'] // self.size
if not self.initialize or (self.adam_freeze_key
and 'worker_error' not in state.keys()):
torch.cuda.empty_cache()
state['worker_error'] = torch.zeros(state['corrected_tensor_size'],
device=p.device)
state['server_error'] = torch.zeros(state['server_chunk_size'],
device=p.device)
torch.cuda.empty_cache()
self.adam_freeze_key = True
if not self.initialize and torch.distributed.get_rank() == 0:
print("Cupy Buffers Initialized Successfully.")
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['betas']
state['step'] += 1
if self.adam_freeze_key is False:
exp_avg.mul_(beta1).add_(1 - beta1, grad)
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
grad = None
if self.initialize:
update = exp_avg / (exp_avg_sq.sqrt() + group['eps'])
else:
if 'non_freeze' in group.keys() and group['non_freeze'] is True:
dist.all_reduce(grad)
grad.mul_(1 / dist.get_world_size())
exp_avg.mul_(beta1).add_(1 - beta1, grad)
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
grad = None
else:
if self.initialize is True:
exp_avg.mul_(beta1).add_(1 - beta1, grad)
grad = None
if self.size > 1:
exp_avg.set_(
self.Compressed_Allreduce(exp_avg,
state['worker_error'],
state['server_error'],
self.rank,
self.size,
self.comm,
self.deepspeed.local_rank))
if self.initialize:
update = exp_avg / (exp_avg_sq.sqrt() + group['eps'])
if self.initialize:
if group['weight_decay'] > 0.0:
update += group['weight_decay'] * p.data
with torch.no_grad():
p.add_(-group['lr'] * update)
if not self.initialize:
print('Pop out errors', flush=True)
state.pop('worker_error')
state.pop('server_error')
if not self.initialize:
self.adam_freeze_key = False
self.initialize = True
print(
f"Finished the initialization step at rant {torch.distributed.get_rank()}"
)
return loss
if self.adam_freeze_key is False:
if state['step'] >= self.freeze_step:
self.adam_freeze_key = True
self.deepspeed.enable_backward_allreduce = False
return loss
......@@ -8,6 +8,7 @@ Helper functions and classes from multiple sources.
import torch
from torch._six import inf
import torch.distributed as dist
from deepspeed.utils import logger
......@@ -23,7 +24,8 @@ class CheckOverflow(object):
for param in group:
self.params.append(param)
def check_using_norm(self, norm_group):
def check_using_norm(self, norm_group, reduce_overflow=True):
#TODO: I don't think reduce_overflow is needed if mpu is None
overflow = -1 in norm_group
if self.mpu is not None:
......@@ -32,6 +34,11 @@ class CheckOverflow(object):
op=torch.distributed.ReduceOp.MAX,
group=self.mpu.get_model_parallel_group())
overflow = overflow_gpu[0].item()
elif reduce_overflow:
cuda_overflow = torch.cuda.FloatTensor([overflow])
dist.all_reduce(cuda_overflow, op=torch.distributed.ReduceOp.MAX)
dist.barrier()
overflow = cuda_overflow[0].item()
return bool(overflow)
......
......@@ -9,8 +9,7 @@ date: 2020-05-15
* Please see our [Azure tutorial](/tutorials/azure/) to get started with DeepSpeed on Azure!
* If you're not on Azure, we recommend using our docker image via `docker pull deepspeed/deepspeed:latest` which contains a pre-installed version of DeepSpeed and all the necessary dependencies.
* If you want to install DeepSpeed manually, we provide an install script
* `install.sh` to help install on a local machine or across an entire cluster.
* If you want to install DeepSpeed manually, we provide an install script `install.sh` to help install on a local machine or across an entire cluster.
## Writing DeepSpeed Models
DeepSpeed model training is accomplished using the DeepSpeed engine. The engine
......
---
title: "1-bit Adam: Up to 5x less communication volume and up to 2x faster training"
---
In this tutorial, we introduce the 1-bit Adam optimizer in DeepSpeed. 1-bit Adam can improve model training speed on communication-constrained clusters, especially for communication-intensive large models, by reducing the overall communication volume by up to 5x.
To illustrate the benefits and usage of 1-bit Adam optimizer in DeepSpeed, we use the following two training tasks as examples:
1. BingBertSQuAD Fine-tuning
2. BERT Pre-training
For more details on these tasks, please refer to the tutorial posts on [BingBertSQuAD Fine-tuning](https://www.deepspeed.ai/tutorials/bert-finetuning/) and [BERT Pre-training](https://www.deepspeed.ai/tutorials/bert-pretraining/).
## Overview
If you don't already have a copy of the DeepSpeed repository, please clone it
now and check out the DeepSpeedExamples submodule that contains the BingBertSQuAD and BERT Pre-training examples.
```shell
git clone https://github.com/microsoft/DeepSpeed
cd DeepSpeed
git submodule update --init --recursive
cd DeepSpeedExamples/
```
## Pre-requisites for 1-bit Adam
1-bit Adam uses advanced communication schemes that are not yet supported by PyTorch distributed and NCCL. We rely on Message Passing Interface (MPI) for these advanced communication primitives.
We package the necessary dependencies in the DeepSpeed docker images. However, if you are using a different build system, please install MPI and mpi4py on your system. We have tested CUDA-Aware MPI communication using the [MVAPICH2-GDR](http://mvapich.cse.ohio-state.edu/userguide/gdr/) library. However, any CUDA-Aware communication library including [OpenMPI](https://www.open-mpi.org/) should work fine with these examples.
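Before launching a job, a minimal sanity check like the hypothetical snippet below can confirm that mpi4py and CuPy (the dependencies this feature relies on) are importable; it is not part of DeepSpeed itself.

```python
# Hypothetical sanity check (not part of DeepSpeed) that the 1-bit Adam
# dependencies referenced in this tutorial are importable.
try:
    from mpi4py import MPI
    print(f"mpi4py OK, world size = {MPI.COMM_WORLD.Get_size()}")
except ImportError:
    print("mpi4py is NOT installed; 1-bit Adam requires an MPI library")

try:
    import cupy
    print(f"cupy OK, version {cupy.__version__}")
except ImportError:
    print("cupy is NOT installed; install the cupy-cuda* wheel matching your CUDA version")
```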
An example launch command for 1-bit Adam using the `deepspeed` launcher is as follows:
```shell
deepspeed --launcher=[mvapich|openmpi] script.py
```
Alternatively, the standard mpirun launcher can also be used as follows:
```shell
mpirun -np [#processes] -ppn [#GPUs on each node] -hostfile [hostfile] [MPI flags] bash [training_script.sh]
```
### Configuration
The 1-bit Adam feature can be used by setting the optimizer configuration options as follows. An example json config file is shown below.
```json
{
"train_batch_size": 4096,
"train_micro_batch_size_per_gpu": 64,
"optimizer": {
"type": "OneBitAdam",
"params": {
"lr": 2e-4,
"freeze_step": 400,
"cuda_aware": true
}
},
"fp16": {
"enabled": true,
}
}
```
Please note the two new parameters `freeze_step` and `cuda_aware` that have been added to support the 1-bit Adam feature.
`cuda_aware` is used to indicate that the underlying MPI library supports CUDA-Aware communication.
This feature is only supported on systems with InfiniBand interconnect and a CUDA-Aware MPI library like [MVAPICH2-GDR](http://mvapich.cse.ohio-state.edu/userguide/gdr/) or OpenMPI built with CUDA-Aware support. Setting `cuda_aware` to false allows training on Ethernet-based systems; however, communication will then use sender- and receiver-side memory copies between CPU and GPU buffers before and after each communication step.
`freeze_step` is the number of warmup steps before 1-bit compression gets applied to the communication. To determine the number of warmup steps, one strategy is to set 15-25% of the total training steps for a given model, as illustrated in the sketch below. If that provides the desired outcome, one can try to extract more performance by reducing the steps systematically. In the future, we plan to introduce a threshold that can automatically search for and decide the number of warmup steps for different models. The examples below have been tuned for the number of warmup steps, and the `freeze_step` parameter has already been set to the best number we found in the corresponding run scripts.
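As a rough illustration of the 15-25% heuristic above, consider the following minimal sketch; the step counts are placeholders, not recommendations.

```python
# Rough illustration of the freeze_step heuristic described above:
# start with roughly 15-25% of the total number of training steps,
# then reduce it systematically if accuracy allows.
total_training_steps = 2000                      # placeholder value, not a recommendation
freeze_step = int(0.20 * total_training_steps)   # 20% warmup -> 400 uncompressed steps
print(f"Starting point for freeze_step: {freeze_step}")
```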
## 1. BingBertSQuAD fine-tuning with 1-bit Adam
* Download the SQuAD dataset:
* Training set: [train-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json)
* Validation set: [dev-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json)
* Download the HuggingFace checkpoint and config files:
* [bert-large-uncased-whole-word-masking](https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-pytorch_model.bin)
* [bert json config](https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json)
You can also use a pre-trained BERT model checkpoint from either DeepSpeed, [HuggingFace](https://github.com/huggingface/transformers), or [TensorFlow](https://github.com/google-research/bert#pre-trained-models) to run the fine-tuning.
### 1.1 Running BingBertSQuAD with DeepSpeed and 1-bit Adam
The main part of training is done in `nvidia_run_squad_deepspeed.py`, which has
already been modified to use DeepSpeed. The `run_squad_deepspeed.sh` script
helps invoke training and set up several hyperparameters relevant
to the training process.
- **DeepSpeed-enabled:** Start training with DeepSpeed by providing the following 4 arguments to this script:
```shell
bash run_squad_deepspeed.sh <NUM_GPUS> <PATH_TO_CHECKPOINT> <PATH_TO_DATA_DIR> <PATH_TO_OUTPUT_DIR>
```
The first argument is the number of GPUs to train with, the second is the path to the pre-training checkpoint, the third is the path to the training and validation sets (e.g., train-v1.1.json), and the fourth is the path to an output folder where the results will be saved. This script will invoke `nvidia_run_squad_deepspeed.py`.
- **DeepSpeed with 1-bit Adam enabled:** In order to run with the 1-bit Adam feature enabled, the same script (`nvidia_run_squad_deepspeed.py`) can be used, but there are two options for launching it properly: 1) launching with the deepspeed launcher and 2) launching with mpirun.
To enable 1-bit compressed training, 1-bit Adam uses an MPI library (e.g., MVAPICH2-GDR, OpenMPI) as the communication backend, which means that we can use `mpirun` to launch the training job. However, our user-friendly `deepspeed` launcher has been enhanced to launch MPI jobs as well.
### Launch with deepspeed
The following helper script in the DeepSpeedExamples/BingBertSQuAD directory will launch the training without the need to set any `mpirun` parameters.
```shell
bash run_squad_deepspeed_onebitadam.sh
```
### Launch with mpirun
Alternatively, we show how the standard `mpirun` launcher can be used for launching the fine-tuning job.
```shell
mpirun -np [#processes] -ppn [#GPUs on each node] -hostfile [hostfile] [MPI flags] bash run_squad_deepspeed_onebitadam.sh
```
For example, in order to use 32 GPUs (4 GPUs/node, 8 nodes in total) with InfiniBand support, you can use the `mpirun` launcher packaged with the MVAPICH2 library. Please run the following command:
```shell
mpirun -np 32 -ppn 4 -hostfile hosts -env MV2_USE_CUDA=1 -env MV2_SUPPORT_DL=1 -env MV2_ENABLE_AFFINITY=0 -env MV2_SMP_USE_CMA=0 bash run_squad_deepspeed_onebitadam.sh
```
### 1.2 Configuration for BingBertSQuAD with DeepSpeed and 1-bit Adam enabled
The `deepspeed_bsz96_onebit_config.json` file gives the user the ability to specify DeepSpeed
options in terms of batch size, micro batch size, optimizer, learning rate, and other parameters.
When running `nvidia_run_squad_deepspeed.py`, in addition to the
`--deepspeed` flag to enable DeepSpeed, the appropriate DeepSpeed configuration
file must be specified using `--deepspeed_config deepspeed_bsz96_config.json`, as sketched below.
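For reference, the sketch below shows one minimal, hypothetical way a training script can wire these flags together using `deepspeed.add_config_arguments` and `deepspeed.initialize`; the placeholder model and the argument handling are illustrative only and may differ from the actual `nvidia_run_squad_deepspeed.py`.

```python
import argparse
import torch
import deepspeed

# deepspeed.add_config_arguments adds --deepspeed and --deepspeed_config to the parser.
parser = argparse.ArgumentParser(description="SQuAD fine-tuning (sketch)")
parser = deepspeed.add_config_arguments(parser)
args = parser.parse_args()  # e.g. --deepspeed --deepspeed_config deepspeed_bsz96_onebit_config.json

model = torch.nn.Linear(768, 2)  # placeholder model, not the BERT SQuAD model

# The engine reads the "optimizer" section of the JSON config, so selecting
# "OneBitAdam" there is all that is needed to enable 1-bit Adam.
model_engine, optimizer, _, _ = deepspeed.initialize(args=args,
                                                     model=model,
                                                     model_parameters=model.parameters())
```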
Table 1 shows the fine-tuning configuration we used in our experiments.
| Parameters | Value |
| ------------------------------ | ---------------------|
| Total batch size | 96 |
| Train micro batch size per GPU | 3 |
| Optimizer | **OnebitAdam** |
| Learning rate | 3e-5 |
| Sequence-length | 384 |
| Weight-decay | 0.0 |
| Epoch count | 2 |
| **freeze_step** | 400 |
| **cuda_aware** | True |
Table 1. Fine-tuning configuration
### 1.3 Results for BingBertSQuAD Fine-tuning
The results are summarized in the table below. The total batch size is set to 96 and training is conducted
on 32 GPUs for 2 epochs. A set of parameters (seeds and learning rates) was tried and the best ones were selected.
We fixed the learning rate to 3e-5. The table below shows the F1 and EM scores we achieved, which are on par with or better than the [HuggingFace results](https://github.com/huggingface/transformers/tree/master/examples/question-answering).
| Case | Model | Precision | EM | F1 |
| ----------- | ------------------------------------- | --------- | ----- | ----- |
| HuggingFace | [Bert-large-uncased-whole-word-masking](https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-pytorch_model.bin) | FP16 | 87.26 | 93.32 |
**Note:** For more details about loading checkpoints, argument parsing, initialization, forward pass, backward pass, weight update and evaluation, please refer to the [BingBertSQuAD Fine-tuning](https://www.deepspeed.ai/tutorials/bert-finetuning/) tutorial.
## 2. BERT Pre-training with 1-bit Adam
For data downloading and pre-processing, please refer to the [BERT Pre-training](https://www.deepspeed.ai/tutorials/bert-pretraining/) post
for more details.
### 2.1 Running Pre-training with DeepSpeed and 1-bit Adam
The main part of training is done in `deepspeed_train.py`, which has
already been modified to use DeepSpeed. The `ds_train_bert_onebitadam_bsz4k_seq128.sh` and `ds_train_bert_bsz64k_seq128.sh` shell scripts
help invoke training and set up several hyperparameters relevant
to the training process.
- **DeepSpeed-enabled:** Start training with DeepSpeed by running the command below:
```shell
bash ds_train_bert_bsz64k_seq128.sh
```
- **DeepSpeed with 1-bit Adam enabled:** In order to run with the 1-bit Adam feature enabled, the same script (`deepspeed_train.py`) can be used, but there are two options for launching it properly:
### Launch with deepspeed
As discussed for BingBertSQuAD fine-tuning, we can simply use the `deepspeed` launcher to launch our BERT pre-training jobs as follows.
```shell
bash ds_train_bert_onebitadam_bsz4k_seq128.sh
```
### Launch with mpirun
Alternatively, use the following command to launch using `mpirun`.
```shell
mpirun -np [#processes] -ppn [#GPUs on each node] -hostfile [hostfile] [MPI flags] bash ds_train_bert_onebitadam_bsz4k_seq128.sh
```
For example, in order to use 32 GPUs (4 GPUs/node, 8 nodes in total) with InfiniBand support, you can use MVAPICH2 as the launcher and run the following command:
```shell
mpirun -np 32 -ppn 4 -hostfile hosts -env MV2_USE_CUDA=1 -env MV2_SUPPORT_DL=1 -env MV2_ENABLE_AFFINITY=0 -env MV2_SMP_USE_CMA=0 bash ds_train_bert_onebitadam_bsz4k_seq128.sh
```
### 2.2 Configuration for BERT Pre-training with DeepSpeed and 1-bit Adam enabled
The `deepspeed_bsz4k_onebit_config_seq128.json` file gives the user the ability to specify DeepSpeed
options in terms of batch size, micro batch size, optimizer, learning rate, and other parameters.
Below is the DeepSpeed configuration file for running BERT-large pre-training with sequence length of 128.
```json
{
"train_batch_size": 4096,
"train_micro_batch_size_per_gpu": 64,
"steps_per_print": 1000,
"optimizer": {
"type": "Adam",
"params": {
"lr": 2e-4,
"max_grad_norm": 1.0,
"weight_decay": 0.01,
"bias_correction": false,
"freeze_step": 23000,
"cuda_aware": true
}
},
"fp16": {
"enabled": true,
"loss_scale": 0,
"initial_scale_power": 16
}
}
```
Notice that for BERT-base training (sequence length 128), the suggested freeze_step is 16000. For the rest of the pre-training, which uses sequence length 512, we suggest using a freeze_step of 1500.
### 2.3 Results for BERT pre-training
Using 1-bit Adam, we are able to achieve significantly higher throughput compared to the original Adam optimizer. We note that the increased training speed during the compressed stage enables an overall training speedup of up to 3.5x on Ethernet-based systems where communication bandwidth is significantly limited. However, we are able to achieve up to 1.7x overall speedup even on a 40 Gigabit InfiniBand QDR based system. Furthermore, it is important to highlight that we are able to achieve feasible BERT pre-training using 1-bit Adam with a significantly smaller batch size of 4k, compared to 32k and 64k for the LAMB optimizer.
Graphs to be added from the blog post ...
......@@ -239,5 +239,5 @@ else
pdsh -w $hosts "python $tmp_wheel_path/basic_install_test.py"
echo "Installation is successful"
fi
pdsh -w $hosts "if [ -d $tmp_wheel_path ]; then rm $tmp_wheel_path/*.whl $tmp_wheel_path/basic_install_test.py $tmp_wheel_path/requirements.txt; rmdir $tmp_wheel_path; fi"
pdsh -w $hosts "if [ -d $tmp_wheel_path ]; then rm $tmp_wheel_path/*.whl $tmp_wheel_path/basic_install_test.py; rmdir $tmp_wheel_path; fi"
fi
......@@ -27,6 +27,11 @@ install_requires = fetch_requirements('requirements/requirements.txt')
dev_requires = fetch_requirements('requirements/requirements-dev.txt')
sparse_attn_requires = fetch_requirements('requirements/requirements-sparse-attn.txt')
onebit_adam_requires = fetch_requirements('requirements/requirements-1bit-adam.txt')
if torch.cuda.is_available():
onebit_adam_requires.append(f"cupy-cuda{torch.version.cuda.replace('.','')[:3]}")
install_requires += onebit_adam_requires
# Build environment variables for custom builds
DS_BUILD_LAMB_MASK = 1
DS_BUILD_TRANSFORMER_MASK = 10
......@@ -227,7 +232,7 @@ setup(name='deepspeed',
description='DeepSpeed library',
author='DeepSpeed Team',
author_email='deepspeed@microsoft.com',
url='http://aka.ms/deepspeed',
url='http://deepspeed.ai',
install_requires=install_requires,
packages=find_packages(exclude=["docker",
"third_party",
......
from mpi4py import MPI
import time
import torch
import torch.distributed as dist
import numpy as np
import deepspeed
from deepspeed.runtime.fp16.onebit_adam import OnebitAdam
comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()
#TODO: Detect the hostname we are running on automatically
torch.distributed.init_process_group(backend='nccl',
init_method='tcp://worker-1:2245',
world_size=size,
rank=rank)
dummy_model = [torch.nn.Parameter(torch.ones(10))]
# Set cuda_aware to True to use CUDA buffers for communication
dummy_optim = OnebitAdam(dummy_model, cuda_aware=True)
device = torch.device('cuda', rank % torch.cuda.device_count())
def torch_sim(a):
a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
scale = a.norm() / np.sqrt(a.numel())
a_compressed = scale * a_sign
a_sign = None
worker_error = a - a_compressed
dist.all_reduce(a_compressed)
a_compressed.mul_(1 / dist.get_world_size())
a_server_sign = a_compressed.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
a_list = torch.chunk(a_compressed, chunks=dist.get_world_size())
server_scale = [chunk_a.norm() / np.sqrt(chunk_a.numel()) for chunk_a in a_list]
a_sign_list = torch.chunk(a_server_sign, dist.get_world_size())
a_server_compressed = torch.cat(
[server_scale[i] * a_sign_list[i] for i in range(dist.get_world_size())])
rank = dist.get_rank()
server_error = a_list[rank] - server_scale[rank] * a_sign_list[rank]
torch.cuda.synchronize()
torch.distributed.barrier()
return a_server_compressed, worker_error, server_error
tensor_size = 100 * 2**20
server_size = int(tensor_size / size)
if tensor_size % (8 * size) != 0:
right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size)))
else:
right_tensor_size = tensor_size
right_server_size = right_tensor_size // size
# Adding bias to the initialization of the gradient we are communicating
# In order to get rid of the case where some elements in the gradient are too small
a = (torch.rand(tensor_size, device=device) - 0.5) + 0.01 * rank
worker_error = torch.zeros(right_tensor_size, device=device)
server_error = torch.zeros(right_server_size, device=device)
a_torch, worker_error_torch, server_error_torch = torch_sim(a)
torch.cuda.empty_cache()
local_rank = rank % torch.cuda.device_count()
a_after = dummy_optim.Compressed_Allreduce(a,
worker_error,
server_error,
rank,
size,
comm,
local_rank)
threshold = 1e-6
magnitude_threshold = 1e-6
diff_mask = (a_after - a_torch) > threshold
diff_server_mask = torch.chunk(diff_mask, size)[rank]
mpi_server = torch.chunk(a_after, size)[rank] + server_error
torch_server = torch.chunk(a_torch, size)[rank] + server_error_torch
# If the number in the compensated_server_m is too small (e.g 1e-8), then calling sign() might be problematic
# The test would skip those numbers that are too small in compensated_server_m
if torch.sum(diff_server_mask) == 0:
print('Successfully passed the test for 1bit Adam at Rank {}'.format(rank))
else:
check_mag_mask = mpi_server[diff_mask] > magnitude_threshold
if torch.sum(check_mag_mask) == 0:
print('Successfully passed the test for 1bit Adam at Rank {}'.format(rank))
else:
print('Fails at {} of positions'.format(torch.sum(check_mag_mask)))
from mpi4py import MPI
import time
import torch
import torch.distributed as dist
import numpy as np
import deepspeed
from deepspeed.runtime.fp16.onebit_adam import OnebitAdam
comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()
#TODO: Detect the hostname we are running on automatically
torch.distributed.init_process_group(backend='nccl',
init_method='tcp://worker-1:2245',
world_size=size,
rank=rank)
dummy_model = [torch.nn.Parameter(torch.ones(10))]
# Set cuda_aware to False to use host buffers for communication
dummy_optim = OnebitAdam(dummy_model, cuda_aware=False)
device = torch.device('cuda', rank % torch.cuda.device_count())
def torch_sim(a):
a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
scale = a.norm() / np.sqrt(a.numel())
a_compressed = scale * a_sign
a_sign = None
worker_error = a - a_compressed
dist.all_reduce(a_compressed)
a_compressed.mul_(1 / dist.get_world_size())
a_server_sign = a_compressed.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
a_list = torch.chunk(a_compressed, chunks=dist.get_world_size())
server_scale = [chunk_a.norm() / np.sqrt(chunk_a.numel()) for chunk_a in a_list]
a_sign_list = torch.chunk(a_server_sign, dist.get_world_size())
a_server_compressed = torch.cat(
[server_scale[i] * a_sign_list[i] for i in range(dist.get_world_size())])
rank = dist.get_rank()
server_error = a_list[rank] - server_scale[rank] * a_sign_list[rank]
torch.cuda.synchronize()
torch.distributed.barrier()
return a_server_compressed, worker_error, server_error
tensor_size = 100 * 2**20
server_size = int(tensor_size / size)
if tensor_size % (8 * size) != 0:
right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size)))
else:
right_tensor_size = tensor_size
right_server_size = right_tensor_size // size
# Adding bias to the initialization of the gradient we are communicating
# In order to get rid of the case where some elements in the gradient are too small
a = (torch.rand(tensor_size, device=device) - 0.5) + 0.01 * rank
worker_error = torch.zeros(right_tensor_size, device=device)
server_error = torch.zeros(right_server_size, device=device)
a_torch, worker_error_torch, server_error_torch = torch_sim(a)
torch.cuda.empty_cache()
local_rank = rank % torch.cuda.device_count()
a_after = dummy_optim.Compressed_Allreduce(a,
worker_error,
server_error,
rank,
size,
comm,
local_rank)
threshold = 1e-6
magnitude_threshold = 1e-6
diff_mask = (a_after - a_torch) > threshold
diff_server_mask = torch.chunk(diff_mask, size)[rank]
mpi_server = torch.chunk(a_after, size)[rank] + server_error
torch_server = torch.chunk(a_torch, size)[rank] + server_error_torch
# If the number in the compensated_server_m is too small (e.g 1e-8), then calling sign() might be problematic
# The test would skip those numbers that are too small in compensated_server_m
if torch.sum(diff_server_mask) == 0:
print('Successfully passed the test for 1bit Adam at Rank {}'.format(rank))
else:
check_mag_mask = mpi_server[diff_mask] > magnitude_threshold
if torch.sum(check_mag_mask) == 0:
print('Successfully passed the test for 1bit Adam at Rank {}'.format(rank))
else:
print('Fails at {} of positions'.format(torch.sum(check_mag_mask)))
from mpi4py import MPI
import time
import torch
import torch.distributed as dist
import numpy as np
import deepspeed
from deepspeed.runtime.fp16.onebit_adam import OnebitAdam
comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()
torch.distributed.init_process_group(backend='nccl',
init_method='tcp://worker-0:2245',
world_size=size,
rank=rank)
dummy_model = [torch.nn.Parameter(torch.ones(10))]
dummy_optim = OnebitAdam(dummy_model, cuda_aware=False)
device = torch.device('cuda', rank % torch.cuda.device_count())
def torch_sim(a):
a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
scale = a.norm() / np.sqrt(a.numel())
a_compressed = scale * a_sign
a_sign = None
worker_error = a - a_compressed
dist.all_reduce(a_compressed)
a_compressed.mul_(1 / dist.get_world_size())
a_server_sign = a_compressed.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
a_list = torch.chunk(a_compressed, chunks=dist.get_world_size())
server_scale = [chunk_a.norm() / np.sqrt(chunk_a.numel()) for chunk_a in a_list]
a_sign_list = torch.chunk(a_server_sign, dist.get_world_size())
a_server_compressed = torch.cat(
[server_scale[i] * a_sign_list[i] for i in range(dist.get_world_size())])
rank = dist.get_rank()
server_error = a_list[rank] - server_scale[rank] * a_sign_list[rank]
torch.cuda.synchronize()
torch.distributed.barrier()
return a_server_compressed, worker_error, server_error
# Input Tensor size
tensor_size = 100 * 2**20
server_size = int(tensor_size / size)
if tensor_size % (8 * size) != 0:
right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size)))
else:
right_tensor_size = tensor_size
right_server_size = right_tensor_size // size
# The -0.5 is required for avoiding sign flips/errors
a = torch.rand(tensor_size, device=device) - 0.5
worker_error = torch.zeros(right_tensor_size, device=device)
server_error = torch.zeros(right_server_size, device=device)
a_torch, worker_error_torch, server_error_torch = torch_sim(a)
torch.cuda.empty_cache()
local_rank = rank % torch.cuda.device_count()
# Test the 1-bit Adam optimizer
a_after = dummy_optim.Compressed_Allreduce(a,
worker_error,
server_error,
rank,
size,
comm,
local_rank)
# If the error is below the threshold, it is acceptable for training
threshold = 1e-6
diff_pos = ((a_after - a_torch) > threshold)
if rank == 0:
before_diff = torch.chunk(a_after - a_torch,
size)[rank] + server_error - server_error_torch
if torch.norm(before_diff) / torch.norm(torch.chunk(a_after,
size)[rank]) < threshold:
print('Successfully passed the test')
else:
print('The difference for the tensor before allgather is {}'.format(
torch.norm(before_diff)))