Unverified Commit 65c2f974 authored by Shaden Smith, committed by GitHub

Pipeline parallel training engine. (#392)


Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
parent 41db1c2f
@@ -8,11 +8,14 @@ from . import ops
from .runtime.engine import DeepSpeedEngine
from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER, DEEPSPEED_ADAM
from .runtime.pipe.engine import PipelineEngine
from .runtime.lr_schedules import add_tuning_arguments
from .runtime.config import DeepSpeedConfig
from .runtime.activation_checkpointing import checkpointing
from .ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
-from .utils import logger
from .utils import log_dist
from .pipe import PipelineModule
try:
    from .git_version_info import version, git_hash, git_branch
@@ -99,23 +102,35 @@ def initialize(args,
    * ``lr_scheduler``: Wrapped lr scheduler if user ``lr_scheduler`` is passed, or
      if ``lr_scheduler`` specified in JSON configuration. Otherwise ``None``.
    """
-   logger.info(
-       "DeepSpeed info: version={}, git-hash={}, git-branch={}".format(
-           __version__,
-           __git_hash__,
-           __git_branch__),
-   )
    log_dist("DeepSpeed info: version={}, git-hash={}, git-branch={}".format(
        __version__,
        __git_hash__,
        __git_branch__),
             ranks=[0])
if not isinstance(model, PipelineModule):
        engine = DeepSpeedEngine(args=args,
                                 model=model,
                                 optimizer=optimizer,
                                 model_parameters=model_parameters,
                                 training_data=training_data,
                                 lr_scheduler=lr_scheduler,
                                 mpu=mpu,
                                 dist_init_required=dist_init_required,
                                 collate_fn=collate_fn,
                                 config_params=config_params)
else:
assert mpu is None, "mpu must be None with pipeline parallelism"
engine = PipelineEngine(args=args,
model=model,
optimizer=optimizer,
model_parameters=model_parameters,
training_data=training_data,
lr_scheduler=lr_scheduler,
mpu=model.mpu(),
dist_init_required=dist_init_required,
collate_fn=collate_fn,
config_params=config_params)
    return_items = [
        engine,
......
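Usage sketch for the dispatch above (not part of this commit; layer sizes, the argparse namespace, and the config path are illustrative): passing a PipelineModule to deepspeed.initialize() yields a PipelineEngine, while any other nn.Module yields a DeepSpeedEngine.

import argparse
import torch.nn as nn
import deepspeed
from deepspeed.pipe import PipelineModule

# Hypothetical two-stage pipeline over a small MLP; 'ds_config.json' is a placeholder path.
layers = [nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10)]
model = PipelineModule(layers=layers, loss_fn=nn.CrossEntropyLoss(), num_stages=2)

args = argparse.Namespace(deepspeed_config='ds_config.json', local_rank=0)
engine, optimizer, _, _ = deepspeed.initialize(
    args=args,
    model=model,
    model_parameters=[p for p in model.parameters() if p.requires_grad])
# Note: mpu is left as None here; the pipeline engine takes its grid from model.mpu().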
from ..runtime.pipe import PipelineModule, LayerSpec, TiedLayerSpec
@@ -480,6 +480,10 @@ class CheckpointFunction(torch.autograd.Function):
            timers.log(['forward'])
        if SYNCHRONIZE:
            torch.cuda.synchronize()
# Tensors returned from forward() may not be differentiable, e.g., attention mask
non_grad_outputs = [o for o in outputs if not o.is_floating_point()]
ctx.mark_non_differentiable(*non_grad_outputs)
        return outputs

    @staticmethod
@@ -548,7 +552,20 @@ class CheckpointFunction(torch.autograd.Function):
        if isinstance(outputs, torch.Tensor):
            outputs = (outputs, )
-       torch.autograd.backward(outputs, args)
# Go over args and build the list of gradient tensors. This is usually just args,
# but if the forward pass returns tensors that do not require_grad then we should
# adjust the arguments to autograd.backward() too. This happens when forward()
# returns indices or a mask (such as an attention mask).
# We skip the first needs_input_grad because it corresponds to run_function.
output_tensors = []
grad_tensors = []
for idx, need_grad in enumerate(ctx.needs_input_grad[1:]):
if need_grad:
output_tensors.append(outputs[idx])
grad_tensors.append(args[idx])
torch.autograd.backward(output_tensors, grad_tensors)
        if PROFILE_TIME:
            timers('backward').stop()
......
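A standalone sketch of the filtering idea introduced above (plain PyTorch, not DeepSpeed's checkpointing API): when a checkpointed forward() also returns a non-floating-point tensor such as an attention mask, only the floating-point outputs and their gradients are handed to autograd.backward().

import torch

hidden = torch.randn(2, 4, requires_grad=True)
mask = torch.ones(2, 4, dtype=torch.long)      # stand-in for an attention mask
outputs = (hidden * 2, mask)
grads = (torch.ones_like(outputs[0]), None)    # no incoming gradient for the mask

# Keep only differentiable outputs and their matching gradients.
out_tensors = [o for o, g in zip(outputs, grads) if o.is_floating_point()]
grad_tensors = [g for o, g in zip(outputs, grads) if o.is_floating_point()]
torch.autograd.backward(out_tensors, grad_tensors)
assert hidden.grad is not None                 # gradient flowed only through the float output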
@@ -324,6 +324,20 @@ def get_sparse_attention_type(param_dict):
    return SPARSE_ATTENTION_TYPE_DEFAULT
def get_pipeline_config(param_dict):
'''Parses pipeline engine configuration. '''
default_pipeline = {
'stages': 'auto',
'partition': 'best',
'seed_layers': False,
'activation_checkpoint_interval': 0
}
config = default_pipeline
for key, val in param_dict.get('pipeline', {}).items():
config[key] = val
return config
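Sketch of how get_pipeline_config() merges a user-supplied 'pipeline' section over these defaults (the input values below are illustrative, not from this commit):

param_dict = {
    "train_batch_size": 16,
    "pipeline": {
        "stages": "auto",
        "seed_layers": True
    }
}
config = get_pipeline_config(param_dict)
# config == {'stages': 'auto', 'partition': 'best',
#            'seed_layers': True, 'activation_checkpoint_interval': 0}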
def get_optimizer_name(param_dict):
    if OPTIMIZER in param_dict.keys() and \
            TYPE in param_dict[OPTIMIZER].keys():
@@ -523,6 +537,7 @@ class DeepSpeedConfig(object):
        self.tensorboard_job_name = get_tensorboard_job_name(param_dict)
        self.sparse_attention = get_sparse_attention(param_dict)
self.pipeline = get_pipeline_config(param_dict)
    def _batch_assertion(self):
@@ -592,10 +607,6 @@ class DeepSpeedConfig(object):
            assert False, \
                'Either train_batch_size or micro_batch_per_gpu needs to be provided'

-       logger.info(
-           f' After Train batch {self.train_batch_size} micro_batch {self.train_micro_batch_size_per_gpu} and grad_acc {self.gradient_accumulation_steps}'
-       )

    def _configure_train_batch_size(self):
        self._set_batch_related_parameters()
        self._batch_assertion()
@@ -646,12 +657,14 @@ class DeepSpeedConfig(object):
                MAX_GRAD_NORM in self.optimizer_params.keys() and \
                self.optimizer_params[MAX_GRAD_NORM] > 0:
            if fp16_enabled:
-               logger.warning(
-                   'DeepSpeedConfig: In FP16 mode, DeepSpeed will pass {}:{} to FP16 wrapper'
-                   .format(MAX_GRAD_NORM,
-                           self.optimizer_params[MAX_GRAD_NORM]))
                if self.global_rank == 0:
                    logger.warning(
                        'DeepSpeedConfig: In FP16 mode, DeepSpeed will pass {}:{} to FP16 wrapper'
                        .format(MAX_GRAD_NORM,
                                self.optimizer_params[MAX_GRAD_NORM]))
            else:
-               logger.warning(
-                   'DeepSpeedConfig: In FP32 mode, DeepSpeed does not permit MAX_GRAD_NORM ({}) > 0, setting to zero'
-                   .format(self.optimizer_params[MAX_GRAD_NORM]))
                if self.global_rank == 0:
                    logger.warning(
                        'DeepSpeedConfig: In FP32 mode, DeepSpeed does not permit MAX_GRAD_NORM ({}) > 0, setting to zero'
                        .format(self.optimizer_params[MAX_GRAD_NORM]))
                self.optimizer_params[MAX_GRAD_NORM] = 0.0
@@ -7,6 +7,29 @@ from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler
class RepeatingLoader:
def __init__(self, loader):
"""Wraps an iterator to allow for infinite iteration. This is especially useful
for DataLoader types that we wish to automatically restart upon completion.
Args:
loader (iterator): The data loader to repeat.
"""
self.loader = loader
self.data_iter = iter(self.loader)
def __iter__(self):
return self
def __next__(self):
try:
batch = next(self.data_iter)
except StopIteration:
self.data_iter = iter(self.loader)
batch = next(self.data_iter)
return batch
class DeepSpeedDataLoader(object):
    def __init__(self,
                 dataset,
......
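Usage sketch for RepeatingLoader (dataset and batch size are illustrative): because __next__() restarts the underlying loader on StopIteration, a training loop can draw more batches than the dataset holds without handling epoch boundaries itself.

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(10, 4), torch.randint(0, 2, (10,)))
loader = RepeatingLoader(DataLoader(dataset, batch_size=4))

data_iter = iter(loader)
for _ in range(5):                     # more iterations than the 3 underlying batches
    inputs, labels = next(data_iter)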
@@ -29,9 +29,11 @@ from deepspeed.runtime.zero.constants import \
    ZERO_OPTIMIZATION_OPTIMIZER_STATES, ZERO_OPTIMIZATION_GRADIENTS
from deepspeed.runtime.csr_tensor import CSRTensor
import deepspeed.runtime.lr_schedules as lr_schedules
-from deepspeed.utils import logger
from deepspeed.utils import logger, log_dist
from deepspeed.utils.timer import ThroughputTimer, SynchronizedWallClockTimer
from .utils import ensure_directory_exists

MEMORY_OPT_ALLREDUCE_SIZE = 500000000
SUMMARY_WRITER_DIR_NAME = "JobId"
@@ -114,6 +116,7 @@ class DeepSpeedEngine(Module):
        self.mpu = mpu
        self.data_parallel_group = None
        self.global_steps = 0
        self.global_samples = 0
        self.micro_steps = 0
        self.skipped_steps = 0
        self.gradient_average = True
@@ -145,7 +148,6 @@ class DeepSpeedEngine(Module):
        self._init_distributed(dist_init_required)

-       self.sample_count = 0
        if self.tensorboard_enabled() and self.global_rank == 0:
            self.summary_writer = self.get_summary_writer()
@@ -162,8 +164,10 @@ class DeepSpeedEngine(Module):
            steps_per_output=self.steps_per_print(),
            monitor_memory=False)

-       self.training_dataloader = self.deepspeed_io(
-           training_data) if training_data else None
        if training_data:
            self.training_dataloader = self.deepspeed_io(training_data)
        else:
            self.training_dataloader = None

        # Configure optimizer and scheduler
        self.optimizer = None
@@ -241,14 +245,20 @@ class DeepSpeedEngine(Module):
    def get_summary_writer(self,
                           name="DeepSpeedJobName",
-                          base=os.environ["HOME"] + "/tensorboard"):
-       if self.tensorboard_job_name():
-           name = self.tensorboard_job_name()
                           base=os.path.join(os.environ["HOME"],
                                             "tensorboard")):
        if self.tensorboard_output_path():
-           return SummaryWriter(log_dir=self.tensorboard_output_path())
-       if 'DLWS_JOB_ID' in os.environ:
-           SUMMARY_WRITER_DIR_NAME = os.environ['DLWS_JOB_ID'] + "/logs"
-       return SummaryWriter(log_dir=os.path.join(base, SUMMARY_WRITER_DIR_NAME, name))
            log_dir = self.tensorboard_output_path()
        else:
            if self.tensorboard_job_name():
                name = self.tensorboard_job_name()
            if 'DLWS_JOB_ID' in os.environ:
                SUMMARY_WRITER_DIR_NAME = os.path.join(os.environ['DLWS_JOB_ID'], "logs")
            log_dir = os.path.join(base, SUMMARY_WRITER_DIR_NAME, name)
        os.makedirs(log_dir, exist_ok=True)
        return SummaryWriter(log_dir=log_dir)

    def wall_clock_breakdown(self):
        return self._config.wall_clock_breakdown
@@ -362,13 +372,15 @@ class DeepSpeedEngine(Module):
        # First check for scheduler in json configuration
        lr_scheduler = self._scheduler_from_config(self.optimizer)
        if lr_scheduler:
-           logger.info(
-               f'DeepSpeed using configured LR scheduler = {self.scheduler_name()}')
            if self.global_rank == 0:
                logger.info(
                    f'DeepSpeed using configured LR scheduler = {self.scheduler_name()}')
            self.lr_scheduler = lr_scheduler
        else:
-           logger.warning('DeepSpeed using client LR scheduler')
            if self.global_rank == 0:
                logger.info('DeepSpeed using client LR scheduler')
            self.lr_scheduler = client_lr_scheduler
-       logger.info(f'DeepSpeed LR Scheduler = {self.lr_scheduler}')
        log_dist(f'DeepSpeed LR Scheduler = {self.lr_scheduler}', ranks=[0])

    def _configure_checkpointing(self, dist_init_required):
@@ -380,11 +392,12 @@ class DeepSpeedEngine(Module):
        self.save_non_zero_checkpoint = (dp_rank == 0)

        if self.zero_optimization():
-           pp_rank = torch.distributed.get_rank(group=self.optimizer.dp_process_group)
            param_rank = torch.distributed.get_rank(
                group=self.optimizer.dp_process_group)

            # Only the first parameter parallel process needs to store the
            # optimizer state checkpoints for zero
-           self.save_zero_checkpoint = (pp_rank == dp_rank)
            self.save_zero_checkpoint = (param_rank == dp_rank)

    def _scheduler_from_config(self, optimizer):
        scheduler_name = self.scheduler_name()
@@ -409,8 +422,6 @@ class DeepSpeedEngine(Module):
            self.device = torch.device("cuda", self.local_rank)
            self.world_size = dist.get_world_size()
            self.global_rank = dist.get_rank()
-           logger.info("Set device to local rank {} within node.".format(
-               self.local_rank))
        else:
            self.world_size = 1
            self.global_rank = 0
@@ -484,7 +495,6 @@ class DeepSpeedEngine(Module):
            self.broadcast_src_rank = _get_global_rank(
                self.mpu.get_data_parallel_group(),
                0)
-           logger.info(f"global src_rank={self.broadcast_src_rank}")

        if not self.amp_enabled():
            self._broadcast_model()
@@ -494,14 +504,17 @@ class DeepSpeedEngine(Module):
        if client_optimizer is not None:
            basic_optimizer = client_optimizer
-           logger.info('Using client Optimizer as basic optimizer')
            if self.global_rank == 0:
                logger.info('Using client Optimizer as basic optimizer')
        else:
            basic_optimizer = self._configure_basic_optimizer(model_parameters)
-           logger.info(
-               'Using DeepSpeed Optimizer param name {} as basic optimizer'.format(
-                   self.optimizer_name()))
            if self.global_rank == 0:
                logger.info(
                    'Using DeepSpeed Optimizer param name {} as basic optimizer'.format(
                        self.optimizer_name()))

-       logger.info('DeepSpeed Basic Optimizer = {}'.format(basic_optimizer))
        if self.global_rank == 0:
            logger.info('DeepSpeed Basic Optimizer = {}'.format(basic_optimizer))

        if self.zero_optimization():
            assert not self.amp_enabled(), "Amp and ZeRO are not currently compatible, please use (legacy) fp16 mode which performs similar to amp opt_mode=O2"
@@ -509,15 +522,16 @@ class DeepSpeedEngine(Module):
                assert self.zero_allow_untested_optimizer(), \
                    'You are using an untested ZeRO Optimizer. Please add <"zero_allow_untested_optimizer": true> in the configuration file to use it.'
-               logger.warning(
-                   "**** You are using ZeRO with an untested optimizer, proceed with caution *****"
-               )
                if self.global_rank == 0:
                    logger.warning(
                        "**** You are using ZeRO with an untested optimizer, proceed with caution *****"
                    )
            self.optimizer = self._configure_zero_optimizer(basic_optimizer)
        elif self.amp_enabled():
            assert not self.fp16_enabled(), "Cannot enable both amp with (legacy) fp16 mode"
            amp_params = self.amp_params()
-           logger.info(f"Initializing AMP with these params: {amp_params}")
            if self.global_rank == 0:
                logger.info(f"Initializing AMP with these params: {amp_params}")
            self.module, self.optimizer = amp.initialize(self.module, basic_optimizer, **amp_params)
            self._broadcast_model()
        elif self.fp16_enabled():
@@ -766,13 +780,10 @@ class DeepSpeedEngine(Module):
        if self.tensorboard_enabled():
            if self.is_gradient_accumulation_boundary():
                if self.global_rank == 0:
-                   self.sample_count += (self.train_micro_batch_size_per_gpu() *
-                                         self.dp_world_size *
-                                         self.gradient_accumulation_steps())
                    self.summary_events = [
                        (f'Train/Samples/train_loss',
                         loss.mean().item() * self.gradient_accumulation_steps(),
-                        self.sample_count)
                         self.global_samples)
                    ]
                    for event in self.summary_events:  # write_summary_events
                        self.summary_writer.add_scalar(event[0], event[1], event[2])
@@ -844,8 +855,47 @@ class DeepSpeedEngine(Module):
            torch.nn.utils.clip_grad_norm_(parameters=self.module.parameters(),
                                           max_norm=self.gradient_clipping())
def _take_model_step(self):
if self.gradient_clipping() > 0.0:
if not self.fp16_enabled() and not self.amp_enabled():
self.clip_fp32_gradients()
elif self.amp_enabled():
# AMP's recommended way of doing clipping
# https://nvidia.github.io/apex/advanced.html#gradient-clipping
master_params = amp.master_params(self.optimizer)
torch.nn.utils.clip_grad_norm_(parameters=master_params,
max_norm=self.gradient_clipping())
self.optimizer.step()
#zero grad in basic optimizer could be unreliable and may not exhibit
#the behaviour that we want
if not self.zero_optimization() and not self.fp16_enabled(
) and not self.amp_enabled():
self.zero_grad()
else:
self.optimizer.zero_grad()
report_progress = self.global_rank == 0 if self.global_rank else True
        # Check overflow here since in DS fp16 optimizer, the overflow is updated in above step() function.
overflow = False
if hasattr(self.optimizer, 'overflow'):
overflow = self.optimizer.overflow
if overflow:
self.skipped_steps += 1
else:
if self.lr_scheduler is not None:
self.lr_scheduler.step()
if report_progress and (self.global_steps + 1) % self.steps_per_print() == 0:
self._report_progress(self.global_steps + 1)
self.global_steps += 1
self.global_samples += self.train_batch_size()
    def step(self):
-       r"""Execute the weight update step after forward and backward propagation on effective_train_batch
        r"""Execute the weight update step after forward and backward propagation
        on effective_train_batch.
        """
        if self.wall_clock_breakdown():
            self.timers('step_microstep').start()
@@ -855,42 +905,9 @@ class DeepSpeedEngine(Module):
            "init in order to use step"
        report_progress = self.global_rank == 0 if self.global_rank else True

        # Update the model when we reach gradient accumulation boundaries
        if self.is_gradient_accumulation_boundary():
self._take_model_step()
-           if self.gradient_clipping() > 0.0:
-               if not self.fp16_enabled() and not self.amp_enabled():
-                   self.clip_fp32_gradients()
-               elif self.amp_enabled():
-                   # AMP's recommended way of doing clipping
-                   # https://nvidia.github.io/apex/advanced.html#gradient-clipping
-                   master_params = amp.master_params(self.optimizer)
-                   torch.nn.utils.clip_grad_norm_(parameters=master_params,
-                                                  max_norm=self.gradient_clipping())
-           self.optimizer.step()
-           #zero grad in basic optimizer could be unreliable and may not exhibit
-           #the behaviour that we want
-           if not self.zero_optimization() and not self.fp16_enabled(
-           ) and not self.amp_enabled():
-               self.zero_grad()
-           else:
-               self.optimizer.zero_grad()
-           # Check overlow here since in DS fp16 optimizer, the overflow is updated in above step() function.
-           overflow = False
-           if hasattr(self.optimizer, 'overflow'):
-               overflow = self.optimizer.overflow
-           if overflow:
-               self.skipped_steps += 1
-           else:
-               if self.lr_scheduler is not None:
-                   self.lr_scheduler.step()
-           if report_progress and (self.global_steps +
-                                   1) % self.steps_per_print() == 0:
-               self._report_progress(self.global_steps + 1)
-           self.global_steps += 1
        self.tput_timer.stop(report_progress)
@@ -900,7 +917,13 @@ class DeepSpeedEngine(Module):
            if self.global_rank == 0:
                self.summary_events = [(f'Train/Samples/lr',
                                        self.get_lr()[0],
-                                       self.sample_count)]
                                        self.global_samples)]
                for event in self.summary_events:  # write_summary_events
                    self.summary_writer.add_scalar(event[0], event[1], event[2])

                if self.fp16_enabled() and hasattr(self.optimizer, 'cur_scale'):
                    self.summary_events.append((f'Train/Samples/loss_scale',
                                                self.optimizer.cur_scale,
                                                self.global_samples))
                for event in self.summary_events:  # write_summary_events
                    self.summary_writer.add_scalar(event[0], event[1], event[2])
                self.summary_writer.flush()
@@ -924,20 +947,20 @@ class DeepSpeedEngine(Module):
                self.summary_events = [
                    (f'Train/Samples/elapsed_time_ms_forward',
                     self.timers('forward').elapsed(reset=False) * 1000.0,
-                    self.sample_count),
                     self.global_samples),
                    (f'Train/Samples/elapsed_time_ms_backward',
                     self.timers('backward').elapsed(reset=False) * 1000.0,
-                    self.sample_count),
                     self.global_samples),
                    (f'Train/Samples/elapsed_time_ms_backward_inner',
                     self.timers('backward_inner').elapsed(reset=False) * 1000.0,
-                    self.sample_count),
                     self.global_samples),
                    (f'Train/Samples/elapsed_time_ms_backward_allreduce',
                     self.timers('backward_allreduce').elapsed(reset=False) *
                     1000.0,
-                    self.sample_count),
                     self.global_samples),
                    (f'Train/Samples/elapsed_time_ms_step',
                     self.timers('step').elapsed(reset=False) * 1000.0,
-                    self.sample_count)
                     self.global_samples)
                ]
                for event in self.summary_events:  # write_summary_events
                    self.summary_writer.add_scalar(event[0], event[1], event[2])
@@ -977,12 +1000,8 @@ class DeepSpeedEngine(Module):
    def _report_progress(self, step):
        lr = self.get_lr()
        mom = self.get_mom()
-       logger.info('rank:{} step={}, skipped={}, lr={}, mom={}'.format(
-           self.global_rank,
-           step,
-           self.skipped_steps,
-           lr,
-           mom))
        log_dist(f'step={step}, skipped={self.skipped_steps}, lr={lr}, mom={mom}',
                 ranks=[0])

    def allreduce_bucket(self, bucket):
        tensor = flatten(bucket)
@@ -1138,18 +1157,12 @@ class DeepSpeedEngine(Module):
        return self._get_rank_zero_ckpt_name(checkpoints_path, tag, mp_rank, pp_rank)

    def _get_ckpt_name(self, checkpoints_path, tag):
        mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
        ckpt_name = os.path.join(checkpoints_path,
                                 str(tag),
                                 'mp_rank_{:02d}'.format(mp_rank) + '_model_states.pt')
        return ckpt_name

-   def _ensure_directory_exists(self, filename):
-       dirname = os.path.dirname(filename)
-       if not os.path.exists(dirname):
-           os.makedirs(dirname)

    def load_checkpoint(self,
                        load_dir,
                        tag,
@@ -1197,7 +1210,7 @@ class DeepSpeedEngine(Module):
                .format(load_path))
            return None, None

-       logger.info('Loading checkpoint: {}'.format(load_path))
        logger.info(f'rank: {self.global_rank} loading checkpoint: {load_path}')
        checkpoint = torch.load(load_path, map_location=lambda storage, loc: storage)

        self.load_module_state_dict(state_dict=checkpoint['module'],
@@ -1215,6 +1228,8 @@ class DeepSpeedEngine(Module):
        self.csr_tensor_module_names = checkpoint['csr_tensor_module_names']
        self.global_steps = checkpoint['global_steps']
        self.global_samples = checkpoint.get('global_samples',
                                             self.global_steps * self.train_batch_size())
        self.skipped_steps = checkpoint['skipped_steps']
        self.loaded_checkpoint_mp_world_size = checkpoint['mp_world_size']
        self.loaded_checkpoint_dp_world_size = checkpoint['dp_world_size']
@@ -1289,7 +1304,7 @@ class DeepSpeedEngine(Module):
                invalid_zero_ckpt_paths.append(ckpt_name)

        if len(invalid_zero_ckpt_paths) > 0:
-           logging.warn(
            logger.warn(
                f"Client provided zero checkpoint load paths: {invalid_zero_ckpt_paths} does not exist"
            )
            return None
@@ -1330,9 +1345,9 @@ class DeepSpeedEngine(Module):
        name_function = self._get_zero_ckpt_name if zero_checkpoint else self._get_ckpt_name
        try:
            checkpoint_name = name_function(save_dir, tag)
-           self._ensure_directory_exists(checkpoint_name)
            ensure_directory_exists(checkpoint_name)
        except:
-           logger.error(f'Failed Saving model checkpoint to {save_dir} with tag {tag}')
            logger.error(f'Failed saving model checkpoint to {save_dir} with tag {tag}')
            return False
        return True
@@ -1351,7 +1366,10 @@ class DeepSpeedEngine(Module):
    def _save_checkpoint(self, save_dir, tag, client_state={}):

        save_path = self._get_ckpt_name(save_dir, tag)
-       # self._ensure_directory_exists(save_path)
        # A hack to save the checkpointing directory. Pipeline parallelism overrides
        # module_state_dict() and uses this path to save the model. module_state_dict()
        # then instead just returns self._curr_save_path.
        self._curr_save_path = os.path.dirname(save_path)

        state = {
            'module':
@@ -1367,6 +1385,8 @@ class DeepSpeedEngine(Module):
            self.skipped_steps,
            'global_steps':
            self.global_steps,
            'global_samples':
            self.global_samples,
            'dp_world_size':
            self.dp_world_size,
            'mp_world_size':
@@ -1374,12 +1394,13 @@ class DeepSpeedEngine(Module):
        }
        state.update(client_state)
-       logger.info('Saving model checkpoint: {}'.format(save_path))
        log_dist(message=f'Saving model checkpoint: {save_path}', ranks=[0])
        #logger.info('Saving model checkpoint: {}'.format(save_path))
        torch.save(state, save_path)
        self._curr_save_path = None

    def _save_zero_checkpoint(self, save_path, tag):
        zero_checkpoint_name = self._get_zero_ckpt_name(save_path, tag)
-       # self._ensure_directory_exists(zero_checkpoint_name)
        zero_sd = {'optimizer_state_dict': self.optimizer.state_dict()}
        torch.save(zero_sd, zero_checkpoint_name)
        logger.info('zero checkpoint saved {}'.format(zero_checkpoint_name))
@@ -97,7 +97,7 @@ class FP16_Optimizer(object):
            self.clip_grad_norm = torch.nn.utils.clip_grad_norm_

        #model parallel object
-       self.mpu = None
        self.mpu = mpu

        self.overflow = False
        self.overflow_checker = CheckOverflow(self.fp16_groups, mpu=self.mpu)
@@ -237,8 +237,8 @@ class FP16_Optimizer(object):
        if self.overflow:
            if self.verbose:
                print("[deepspeed] OVERFLOW! Skipping step. Attempted loss "
-                     "scale: {}, reducing to {}".format(prev_scale,
                      "scale: {}, reducing to {} ".format(prev_scale,
                                                          self.cur_scale))
            self.log_timers(OVERFLOW_TIMERS)
            grads_groups_flat = None
            return self.overflow
......
@@ -93,11 +93,13 @@ class FP16_UnfusedOptimizer(object):
        else:
            self.clip_grad_norm = torch.nn.utils.clip_grad_norm_

-       self.mpu = None
        self.mpu = mpu

        self.overflow = False
        self.overflow_checker = CheckOverflow(self.fp16_groups, mpu=self.mpu)

        self.initialize_optimizer_states()

    def zero_grad(self, set_grads_to_None=True):
        """
        Zero FP16 parameter grads.
@@ -349,3 +351,26 @@ class FP16_UnfusedOptimizer(object):
    def __repr__(self):
        return repr(self.optimizer)
def initialize_optimizer_states(self):
for i, group in enumerate(self.fp16_groups):
for param in group:
param.grad = torch.zeros(param.size(),
dtype=param.dtype,
device=torch.cuda.current_device())
for i, group in enumerate(self.fp32_groups):
for param in group:
param.grad = torch.zeros(param.size(),
dtype=param.dtype,
device=torch.cuda.current_device())
self.optimizer.step()
for i, group in enumerate(self.fp16_groups):
for param in group:
param.grad = None
for i, group in enumerate(self.fp32_groups):
for param in group:
param.grad = None
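initialize_optimizer_states() above warms up the wrapped optimizer by running one step with all-zero gradients so its state tensors exist before training starts. A standalone sketch of the same pattern with a plain torch optimizer (parameter and hyperparameters are illustrative; with weight decay enabled such a warm-up step would also perturb the parameters):

import torch

param = torch.nn.Parameter(torch.randn(4))
optimizer = torch.optim.Adam([param], lr=1e-3)   # no weight decay

param.grad = torch.zeros_like(param)   # dummy zero gradient
optimizer.step()                       # allocates exp_avg / exp_avg_sq state now
param.grad = None                      # discard the dummy gradient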
from .module import PipelineModule, LayerSpec, TiedLayerSpec
# Copyright 2019 The Microsoft DeepSpeed Team
import time
import logging
import copy
import os
from types import MethodType
from numpy import prod
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from deepspeed.utils.logging import logger
from deepspeed.utils.timer import SynchronizedWallClockTimer, ThroughputTimer
from ..engine import DeepSpeedEngine, MEMORY_OPT_ALLREDUCE_SIZE
from ..utils import PartitionedTensor, ensure_directory_exists
from ..dataloader import RepeatingLoader
from .module import PipelineModule, PipelineError, TiedLayerSpec
from . import p2p
from . import schedule
TARGET_ID = -2
LOG_STAGE = -2
DATA_PARALLEL_ID = -2
def is_even(number):
return number % 2 == 0
mem_alloced = 0
mem_cached = 0
def _tensor_bytes(tensor):
return tensor.numel() * tensor.element_size()
class PipelineEngine(DeepSpeedEngine):
""" A model wrapper for pipeline-parallel execution.
Parallelism is achieved by executing micro-batches in a pipelined fashion with
gradient accumulation.
"""
def __init__(self, *super_args, **super_kwargs):
super().__init__(*super_args, **super_kwargs)
        assert isinstance(self.module, PipelineModule), "model must be a PipelineModule"
# pipeline step for logging
self.log_batch_step_id = -1
self.micro_batch_size = self.train_micro_batch_size_per_gpu()
self.micro_batches = self.gradient_accumulation_steps()
# Set Grid and Communication Groups
self.grid = self.module._grid
if self.grid.get_global_rank() == 0:
logger.info(f'CONFIG: micro_batches={self.micro_batches} '
f'micro_batch_size={self.micro_batch_size}')
self.global_rank = self.grid.get_global_rank()
assert self.dp_world_size == self.grid.data_parallel_size
assert self.train_batch_size() == \
self.micro_batch_size * self.micro_batches * self.grid.data_parallel_size
        # Set Stage Info
self.num_stages = self.grid.pipe_parallel_size
self.stage_id = self.grid.get_stage_id()
self.prev_stage = self.stage_id - 1
self.next_stage = self.stage_id + 1
self.data_iterator = None
self.batch_fn = None
self._force_grad_boundary = False
self.batch_timer = ThroughputTimer(batch_size=self.micro_batch_size *
self.micro_batches,
num_workers=self.dp_world_size,
logging_fn=self.tput_log,
monitor_memory=False,
steps_per_output=self.steps_per_print())
        # PipelineEngine needs to handle data loading specially: only the first
        # and last stages load inputs/labels. We construct a data-parallel sampler
        # for them in _build_data_iter().
if self.training_data:
self._build_data_iter(self.training_data)
self.is_pipe_parallel = self.grid.pipe_parallel_size > 1
self.is_data_parallel = self.grid.data_parallel_size > 1
self.is_model_parallel = self.grid.model_parallel_size > 1
# Partition input/output buffers
self.is_pipe_partitioned = self.is_model_parallel
self.is_grad_partitioned = False
model_parameters = filter(lambda p: p.requires_grad, self.module.parameters())
num_params = sum([p.numel() for p in model_parameters])
unique_params = num_params
# Subtract tied parameters if we don't own them
if self.module.tied_comms:
tied_params = 0
for key, d in self.module.tied_comms.items():
if self.global_rank != min(d['ranks']):
tied_params += sum(p.numel() for p in d['module'].parameters())
unique_params -= tied_params
params_tensor = torch.LongTensor(data=[num_params,
unique_params]).to(self.device)
dist.all_reduce(params_tensor, group=self.grid.get_model_parallel_group())
params_tensor = params_tensor.tolist()
total_params = params_tensor[0]
unique_params = params_tensor[1]
if self.grid.data_parallel_id == 0:
logger.info(f'RANK={self.global_rank} '
f'STAGE={self.stage_id} '
f'LAYERS={self.module._local_stop - self.module._local_start} '
f'[{self.module._local_start}, {self.module._local_stop}) '
f'STAGE_PARAMS={num_params} ({num_params/1e6:0.3f}M) '
f'TOTAL_PARAMS={total_params} ({total_params/1e6:0.3f}M) '
f'UNIQUE_PARAMS={unique_params} ({unique_params/1e6:0.3f}M)')
        # Initialize peer-to-peer communication and allreduce groups
if self.is_pipe_parallel:
p2p.init_process_groups(self.grid)
# Pipeline buffers
self.num_pipe_buffers = 0
self.pipe_buffers = {
'inputs' : [], # batch input and received activations
'labels' : [], # labels from batch input
'outputs' : [], # activations
'output_tensors' : [], # tensor object to preserve backward graph
}
self.pipe_recv_buf = None
self.grad_layer = None
self.meta_buffer = None
self.first_output_send = True
self.first_gradient_send = True
#stores the loss for the current micro batch being processed
self.loss = torch.tensor(0.0).to(self.device)
#stores the loss for the entire batch
self.total_loss = None
self.agg_loss = torch.tensor(0.0, requires_grad=False).to(self.device)
self.dp_group_loss = torch.tensor(0.0, requires_grad=False).to(self.device)
if self._config.pipeline['activation_checkpoint_interval'] > 0:
self.module.activation_checkpoint_interval = self._config.pipeline[
'activation_checkpoint_interval']
if self.is_last_stage():
self.loss_model = self.module.loss_fn
# Initialize pipeline communicators. Just send a 0.
if is_even(self.stage_id):
if not self.is_last_stage():
p2p.send(self.loss, self.next_stage)
if not self.is_first_stage():
p2p.recv(self.loss, self.prev_stage)
else:
if not self.is_first_stage():
p2p.recv(self.loss, self.prev_stage)
if not self.is_last_stage():
p2p.send(self.loss, self.next_stage)
# XXX look into timer reporting timing
# Initialize some timers because of early weirdness.
if self.wall_clock_breakdown():
self.timers('forward_microstep').start()
self.timers('forward_microstep').stop()
self.timers('backward_microstep').start()
self.timers('backward_microstep').stop()
self.timers('backward_inner_microstep').start()
self.timers('backward_inner_microstep').stop()
self.timers('backward_allreduce_microstep').start()
self.timers('backward_allreduce_microstep').stop()
self.timers('backward_allreduce').start()
self.timers('backward_allreduce').stop()
self.timers('step_microstep').start()
self.timers('step_microstep').stop()
def _build_data_iter(self, dataset):
sampler = torch.utils.data.distributed.DistributedSampler(
dataset,
num_replicas=self.dp_world_size,
rank=self.mpu.get_data_parallel_rank(),
shuffle=False)
# Build a loader and make it repeating.
pipe_dataloader = self.deepspeed_io(dataset, data_sampler=sampler)
pipe_dataloader = RepeatingLoader(pipe_dataloader)
self.set_dataloader(pipe_dataloader)
def _exec_reduce_tied_grads(self):
self.module.allreduce_tied_weight_gradients()
def _exec_reduce_grads(self):
self._force_grad_boundary = True
if self.is_data_parallel:
self.buffered_allreduce_fallback(
elements_per_buffer=MEMORY_OPT_ALLREDUCE_SIZE)
self._force_grad_boundary = False
def _reserve_pipe_buffers(self, num_buffers):
"""Ensure that each pipeline buffer has at least ``num_buffers`` slots.
This method only reserves slots and does not allocate tensors.
Args:
num_buffers (int): The number of buffers to reserve.
"""
if self.num_pipe_buffers >= num_buffers:
return
num_added = num_buffers - self.num_pipe_buffers
for key in self.pipe_buffers:
self.pipe_buffers[key].extend([None] * num_added)
self.num_pipe_buffers = num_buffers
def train_batch(self, data_iter=None):
"""Progress the pipeline to train the next batch of data.
Returns:
The arithmetic mean of the losses over all micro-batches.
"""
if not torch._C.is_grad_enabled():
raise RuntimeError(
f'train_batch() requires gradients enabled. Use eval_batch() instead.')
if data_iter:
self.set_dataiterator(data_iter)
self.module.train()
self.total_loss = None
self.timers('train_batch').start()
# Do the work
sched = schedule.TrainSchedule(micro_batches=self.micro_batches,
stages=self.num_stages,
stage_id=self.stage_id)
self._exec_schedule(sched)
self.agg_train_loss = self._aggregate_total_loss()
self.timers('train_batch').stop()
if self.global_steps % self.steps_per_print() == 0:
if self.global_rank == 0:
elapsed = self.timers('train_batch').elapsed(reset=True)
iter_time = elapsed / self.steps_per_print()
tput = self.train_batch_size() / elapsed
print(f'steps: {self.global_steps} '
f'loss: {self.agg_train_loss:0.4f} '
f'iter time (s): {iter_time:0.3f} '
f'samples/sec: {tput:0.3f}')
# Tensorboard
if self.tensorboard_enabled():
if self.global_rank == 0:
self.summary_events = [(f'Train/Samples/train_loss',
self.agg_train_loss.mean().item(),
self.global_samples)]
for event in self.summary_events: # write_summary_events
self.summary_writer.add_scalar(event[0], event[1], event[2])
if self.global_steps % self.steps_per_print() == 0:
self.summary_writer.flush()
if self.wall_clock_breakdown(
) and self.global_steps % self.steps_per_print() == 0:
self.timers.log([
'pipe_send_output',
'pipe_send_grad',
'pipe_recv_input',
'pipe_recv_grad'
])
# TODO: should return precisely what loss returned and allow others to be queried?
return self.agg_train_loss
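A hypothetical outer training loop over train_batch() (engine is assumed to come from deepspeed.initialize() with a PipelineModule and training_data, as sketched earlier; val_iter is an illustrative validation iterator):

for step in range(1000):
    loss = engine.train_batch()        # schedules micro_batches forward/backward passes
                                       # and one optimizer step across all pipeline stages
    if step % 100 == 0:
        eval_loss = engine.eval_batch(val_iter)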
def eval_batch(self, data_iter):
"""Evaluate the pipeline on a batch of data from ``data_iter``.
This method is equivalent to:
.. code-block:: python
module.eval()
with torch.no_grad():
output = module(batch)
Returns:
The arithmetic mean of the losses over all micro-batches.
"""
self.module.eval()
self.total_loss = None
# Use the provided data iterator
train_iterator = self.data_iterator
self.set_dataiterator(data_iter)
# Do the work
sched = schedule.InferenceSchedule(micro_batches=self.micro_batches,
stages=self.num_stages,
stage_id=self.stage_id)
with torch.no_grad():
self._exec_schedule(sched)
self.agg_eval_loss = self._aggregate_total_loss()
if self.tensorboard_enabled():
if self.global_rank == 0:
self.summary_events = [(f'Train/Samples/eval_loss',
self.agg_eval_loss.mean().item(),
self.global_samples)]
for event in self.summary_events: # write_summary_events
self.summary_writer.add_scalar(event[0], event[1], event[2])
self.summary_writer.flush()
# Restore the training iterator
self.set_dataiterator(train_iterator)
# Reset any buffers that may have been populated during the forward passes.
#ds_checkpointing.reset()
return self.agg_eval_loss
def _aggregate_total_loss(self):
# Scale loss, average among DP ranks, and bcast loss to the rest of my DP group
if self.is_last_stage():
loss = self._scale_loss(self.total_loss)
self.dp_group_loss = loss.clone().detach()
## Average loss across all data-parallel groups
agg_loss = self.dp_group_loss.clone().detach()
#print(f'RANK={self.global_rank} bcast SENDER src={self.global_rank} group={self.grid.pp_group}', flush=True)
if self.is_data_parallel:
dist.all_reduce(agg_loss, group=self.mpu.get_data_parallel_group())
agg_loss /= self.dp_world_size
assert self.global_rank in self.grid.pp_group
losses = torch.Tensor([self.dp_group_loss, agg_loss]).to(self.device)
dist.broadcast(tensor=losses,
src=self.global_rank,
group=self.mpu.get_pipe_parallel_group())
else:
# Get loss from last stage
src_rank = self.grid.stage_to_global(self.num_stages - 1)
assert src_rank in self.grid.pp_group
losses = torch.Tensor([0., 0.]).to(self.device)
dist.broadcast(tensor=losses,
src=src_rank,
group=self.grid.get_pipe_parallel_group())
self.dp_group_loss = losses[0].clone().detach()
agg_loss = losses[1].clone().detach()
return agg_loss
def set_dataloader(self, loader):
""" Store a DataLoader to sample for training data. """
if self.is_first_stage() or self.is_last_stage():
self.training_dataloader = loader
self.data_iterator = iter(self.training_dataloader)
def set_dataiterator(self, iterator):
""" Store an iterator to sample for training data. """
if self.is_first_stage() or self.is_last_stage():
self.training_dataloader = None
self.data_iterator = iterator
def set_batch_fn(self, fn):
self.batch_fn = fn
def is_gradient_accumulation_boundary(self):
"""True if the engine is executing a gradient reduction or optimizer step instruction.
This is overridden from :class:`DeepSpeedEngine` to force reductions
and steps when the pipeline engine is instructed to do so.
Returns:
bool: whether reductions and optimizer steps should occur.
"""
return self._force_grad_boundary
def log_for_device(self, *msg):
if LOG_STAGE == self.stage_id or LOG_STAGE == -1:
if DATA_PARALLEL_ID == self.grid.data_parallel_id or DATA_PARALLEL_ID == -1:
print(
f'RANK={dist.get_rank()} '
f'PIPE-ID={self.stage_id} '
f'DATA-ID={self.grid.data_parallel_id} '
f'MBATCH-ID={self.microbatch_id} '
f'STEP-ID={self.log_batch_step_id} '
'::',
*msg,
flush=True)
def tput_log(self, *msg):
if self.global_rank == 0 and self.global_steps % self.steps_per_print() == 0:
print(*msg)
def _next_batch(self):
if self.is_model_parallel:
mp_rank = self.grid.get_slice_parallel_rank()
else:
mp_rank = 0
batch = None
# Only MP rank 0 loads the data.
if mp_rank == 0:
if self.data_iterator is None:
raise ValueError(f"RANK={self.global_rank} no data iterator provided.")
batch = next(self.data_iterator)
# All MP ranks participate in batch_fn, where they might broadcast the data.
if self.batch_fn:
batch = self.batch_fn(batch)
# Sanity check dimensions.
# XXX: the last minibatch with size < micro_batch_size kills us
if torch.is_tensor(batch[0]):
if batch[0].size(0) != self.micro_batch_size:
print(f'size mismatch: {batch[0].size(0)} mb: {self.micro_batch_size}')
return self._next_batch()
else:
assert torch.is_tensor(batch[0][0])
if batch[0][0].size(0) != self.micro_batch_size:
return self._next_batch()
return batch
def _exec_forward_pass(self, buffer_id):
self.tput_timer.start()
self.mem_status('BEFORE FWD', reset_max=True)
if isinstance(self.pipe_buffers['inputs'][buffer_id], tuple):
inputs = tuple(t.clone() for t in self.pipe_buffers['inputs'][buffer_id])
else:
inputs = self.pipe_buffers['inputs'][buffer_id].clone()
# collect the partitioned input from the previous stage
if self.is_pipe_partitioned and not self.is_first_stage():
part_input = PartitionedTensor.from_meta(
meta=inputs[0],
local_part=inputs[1],
group=self.grid.get_slice_parallel_group())
inputs = tuple([part_input.full(), inputs[2]])
inputs[0].requires_grad = True
# skip mask
#inputs[1].requires_grad = True
part_input = None
self.pipe_buffers['inputs'][buffer_id] = inputs
# Zero out the gradients each time we use the tensor because only the data in
# tensor changes across batches
self._zero_grads(inputs)
outputs = super().forward(inputs)
# Partition the outputs if we are not the last stage
if self.is_pipe_partitioned and not self.is_last_stage():
part = PartitionedTensor(tensor=outputs[0],
group=self.grid.get_slice_parallel_group())
# Clear the large output data, but save the computation graph
outputs[0].data = torch.zeros(1)
self.pipe_buffers['output_tensors'][buffer_id] = outputs[0]
# Inject the partitioned tensor into the output before sending
outputs = tuple([part.to_meta(), part.data(), outputs[1]])
part = None
self.pipe_buffers['outputs'][buffer_id] = outputs
# Optionally compute loss on the last device
if self.is_last_stage():
if self.loss_model is not None:
labels = self.pipe_buffers['labels'][buffer_id]
self.loss = self.loss_model(outputs, labels)
else:
# Some models just return loss from forward()
self.loss = outputs
if isinstance(self.loss, torch.Tensor):
if self.total_loss is None:
self.total_loss = torch.zeros_like(self.loss)
self.total_loss += self.loss.detach()
else:
if self.total_loss is None:
self.total_loss = [torch.zeros_like(l) for l in self.loss]
for idx, l in enumerate(self.loss):
self.total_loss[idx] += l.detach()
def _exec_backward_pass(self, buffer_id):
assert self.optimizer is not None, "must provide optimizer during " \
"init in order to use backward"
self.mem_status('BEFORE BWD', reset_max=True)
# The last stage just runs backward on the loss using DeepSpeed's typical
# mechanisms.
if self.is_last_stage():
super().backward(self.loss, allreduce_gradients=False)
self.mem_status('AFTER BWD')
return
outputs = self.pipe_buffers['outputs'][buffer_id]
if self.wall_clock_breakdown():
self.timers('backward_microstep').start()
self.timers('backward').start()
self.timers('backward_inner_microstep').start()
self.timers('backward_inner').start()
# Reconstruct if we previously partitioned the output. We must be
# careful to also restore the computational graph of the tensors we partitioned.
if self.is_pipe_partitioned:
if self.is_grad_partitioned:
part_output = PartitionedTensor.from_meta(
meta=outputs[0],
local_part=outputs[1],
group=self.grid.get_slice_parallel_group())
self.pipe_buffers['output_tensors'][buffer_id].data = part_output.full()
outputs = tuple(
[self.pipe_buffers['output_tensors'][buffer_id],
outputs[2]])
else:
# Already restored from partition
self.pipe_buffers['output_tensors'][buffer_id].data = outputs[0]
outputs = tuple(
[self.pipe_buffers['output_tensors'][buffer_id],
outputs[1]])
grad_tensors = self.grad_layer
if self.is_grad_partitioned:
#print(f'RANK={self.global_rank} BEFORE-BWD restoring grad={self.grad_layer[0].size()} {self.grad_layer[1].size()}')
part_grad = PartitionedTensor.from_meta(
meta=self.grad_layer[0],
local_part=self.grad_layer[1],
group=self.grid.get_slice_parallel_group())
grad_tensors = tuple([part_grad.full(), self.grad_layer[2]])
part_grad = None
#print(f'RANK={self.global_rank} BEFORE-BWD restored grad={self.grad_layer[0].size()} {self.grad_layer[1].size()}')
# This handles either a single tensor or tuple of tensors.
if isinstance(outputs, tuple):
out_tensors = [t for t in outputs if t.is_floating_point()]
assert len(out_tensors) == len(grad_tensors)
torch.autograd.backward(tensors=out_tensors, grad_tensors=grad_tensors)
else:
torch.autograd.backward(tensors=(outputs, ), grad_tensors=(grad_tensors, ))
# Free up the memory from the output of forward()
self.pipe_buffers['output_tensors'][buffer_id] = None
self.pipe_buffers['outputs'][buffer_id] = None
grad_tensors = None
if self.wall_clock_breakdown():
self.timers('backward_inner').stop()
self.timers('backward_inner_microstep').stop()
self.timers('backward').stop()
self.timers('backward_microstep').stop()
self.mem_status('AFTER BWD')
def _exec_load_micro_batch(self, buffer_id):
if self.wall_clock_breakdown():
self.timers('batch_input').start()
batch = self._next_batch()
if self.is_first_stage():
loaded = None
if torch.is_tensor(batch[0]):
loaded = batch[0].clone().to(self.device).detach()
loaded.requires_grad = loaded.is_floating_point()
else:
assert isinstance(batch[0], tuple)
# Assume list or tuple
loaded = []
for x in batch[0]:
assert torch.is_tensor(x)
mine = x.clone().detach().to(self.device)
mine.requires_grad = mine.is_floating_point()
loaded.append(mine)
loaded = tuple(loaded)
self.pipe_buffers['inputs'][buffer_id] = loaded
if self.is_last_stage():
loaded = batch[1]
if torch.is_tensor(batch[1]):
loaded = batch[1].to(self.device)
elif isinstance(batch[1], tuple):
loaded = []
for x in batch[1]:
assert torch.is_tensor(x)
x = x.to(self.device).detach()
loaded.append(x)
loaded = tuple(loaded)
self.pipe_buffers['labels'][buffer_id] = loaded
if self.wall_clock_breakdown():
self.timers('batch_input').stop()
def _send_tensor_meta(self, buffer, recv_stage):
""" Communicate metadata about upcoming p2p transfers.
Metadata is communicated in this order:
            * type (0: tensor, 1: list, 2: tuple)
            * num_tensors if type=list or tuple
foreach tensor in buffer:
* ndims
* shape
"""
send_bytes = 0
if isinstance(buffer, torch.Tensor):
type_tensor = torch.LongTensor(data=[0]).to(self.device)
p2p.send(type_tensor, recv_stage)
send_shape = torch.LongTensor(data=buffer.size()).to(self.device)
send_ndims = torch.LongTensor(data=[len(buffer.size())]).to(self.device)
p2p.send(send_ndims, recv_stage)
p2p.send(send_shape, recv_stage)
send_bytes += _tensor_bytes(buffer)
elif isinstance(buffer, list):
assert (False)
type_tensor = torch.LongTensor(data=[1]).to(self.device)
p2p.send(type_tensor, recv_stage)
count_tensor = torch.LongTensor(data=[len(buffer)]).to(self.device)
p2p.send(count_tensor, recv_stage)
for tensor in buffer:
assert isinstance(tensor, torch.Tensor)
send_shape = torch.LongTensor(data=tensor.size()).to(self.device)
send_ndims = torch.LongTensor(data=[len(tensor.size())]).to(self.device)
p2p.send(send_ndims, recv_stage)
p2p.send(send_shape, recv_stage)
send_bytes += _tensor_bytes(tensor)
elif isinstance(buffer, tuple):
type_tensor = torch.LongTensor(data=[2]).to(self.device)
p2p.send(type_tensor, recv_stage)
count_tensor = torch.LongTensor(data=[len(buffer)]).to(self.device)
p2p.send(count_tensor, recv_stage)
for idx, tensor in enumerate(buffer):
assert isinstance(tensor, torch.Tensor)
send_shape = torch.LongTensor(data=tensor.size()).to(self.device)
send_ndims = torch.LongTensor(data=[len(tensor.size())]).to(self.device)
p2p.send(send_ndims, recv_stage)
p2p.send(send_shape, recv_stage)
# Useful for performance debugging.
'''
new_bytes = _tensor_bytes(tensor)
send_bytes += _tensor_bytes(tensor)
# Useful for performance debugging.
if self.grid.data_parallel_id == 0:
print(
f'STAGE={self.stage_id} pipe-send-volume[{idx}]: shape={send_shape} {new_bytes/1024**2:0.2f}MB'
)
'''
else:
raise NotImplementedError(f'Could not send meta type {type(buffer)}')
# Useful for performance debugging.
'''
if self.grid.data_parallel_id == 0:
print(f'STAGE={self.stage_id} pipe-send-volume: {send_bytes/1024**2:0.2f}MB')
'''
def _recv_tensor_meta(self, send_stage):
"""Receive metadata about upcoming p2p transfers and return allocated buffers.
Metadata is communicated in this order:
            * type (0: tensor, 1: list, 2: tuple)
            * num_tensors if type=list or tuple
foreach tensor in buffer:
* ndims
* shape
Returns:
Allocated buffer for receiving from send_stage.
"""
type_tensor = torch.LongTensor(data=[0]).to(self.device)
p2p.recv(type_tensor, send_stage)
recv_type = type_tensor.item()
# A single tensor will be sent.
if recv_type == 0:
recv_ndims = torch.LongTensor(data=[0]).to(self.device)
p2p.recv(recv_ndims, send_stage)
recv_ndims = recv_ndims.item()
recv_shape = torch.LongTensor([1] * recv_ndims).to(self.device)
p2p.recv(recv_shape, send_stage)
recv_shape = recv_shape.tolist()
return self._allocate_buffer(recv_shape, num_buffers=1)[0]
# List or tuple of tensors
elif recv_type == 1 or recv_type == 2:
count_tensor = torch.LongTensor(data=[0]).to(self.device)
p2p.recv(count_tensor, send_stage)
num_tensors = count_tensor.item()
recv_shapes = []
for idx in range(num_tensors):
recv_ndims = torch.LongTensor(data=[0]).to(self.device)
p2p.recv(recv_ndims, send_stage)
recv_ndims = recv_ndims.item()
recv_shape = torch.LongTensor([1] * recv_ndims).to(self.device)
p2p.recv(recv_shape, send_stage)
recv_shapes.append(recv_shape.tolist())
buffers = self._allocate_buffers(recv_shapes, num_buffers=1)[0]
# Convert to tuples if requested.
if recv_type == 2:
buffers = tuple(buffers)
return buffers
else:
raise NotImplementedError(f'Could not receive type {type(recv_type)}')
def _exec_send_activations(self, buffer_id):
if self.wall_clock_breakdown():
self.timers('pipe_send_output').start()
outputs = self.pipe_buffers['outputs'][buffer_id]
# NCCL does not like to send torch.BoolTensor types, so cast the mask to half().
# We could do char, but with half() we can eventually flatten with other fp16
# messages (TODO)
if self.module.__class__.__name__ == 'GPT2ModelPipe':
outputs = list(outputs)
outputs[-1] = outputs[-1].half()
outputs = tuple(outputs)
if self.first_output_send:
self.first_output_send = False
self._send_tensor_meta(outputs, self.next_stage)
if isinstance(outputs, torch.Tensor):
p2p.send(outputs, self.next_stage)
elif isinstance(outputs, tuple):
for idx, buffer in enumerate(outputs):
p2p.send(buffer, self.next_stage)
else:
raise NotImplementedError('Could not send output of type '
f'{type(outputs)}')
# Restore the boolean tensor
if self.module.__class__.__name__ == 'GPT2ModelPipe':
outputs = list(outputs)
outputs[-1] = outputs[-1].bool()
outputs = tuple(outputs)
if self.wall_clock_breakdown():
self.timers('pipe_send_output').stop()
def _exec_send_grads(self, buffer_id):
if self.wall_clock_breakdown():
self.timers('pipe_send_grad').start()
inputs = self.pipe_buffers['inputs'][buffer_id]
# Partition the gradient
if self.is_grad_partitioned:
part = PartitionedTensor(tensor=inputs[0].grad,
group=self.grid.get_slice_parallel_group())
# Clear the large output data, but save the computation graph
# Inject the partitioned tensor into the output before sending
# XXX Hack
inputs = tuple([part.to_meta(), part.data(), inputs[1]])
# XXX Terrible hack
# Drop the attention mask from the input buffer here. It does not have
# a grad that needs to be communicated. We free the buffer immediately
# after, so no need to restore it. The receiver also has a hack that skips
# the recv. This is because NCCL does not let us send torch.BoolTensor :-(.
if self.module.__class__.__name__ == 'GPT2ModelPipe':
inputs = list(inputs)
inputs.pop()
inputs = tuple(inputs)
if isinstance(inputs, torch.Tensor):
assert inputs.grad is not None
p2p.send(inputs.grad, self.prev_stage)
else:
# XXX terrible hacky branch
if self.is_grad_partitioned:
# First two sends are partitioned gradient
p2p.send(inputs[0], self.prev_stage)
p2p.send(inputs[1], self.prev_stage)
# XXX hack hack hack
#p2p.send(inputs[2].grad, self.prev_stage)
else:
for idx, buffer in enumerate(inputs):
# Skip tensors that will not produce a grad
if not buffer.is_floating_point():
assert buffer.grad is None
continue
assert buffer.grad is not None
p2p.send(buffer.grad, self.prev_stage)
# We can free up the input buffer now
self.pipe_buffers['inputs'][buffer_id] = None
if self.wall_clock_breakdown():
self.timers('pipe_send_grad').stop()
def _exec_recv_activations(self, buffer_id):
if self.wall_clock_breakdown():
self.timers('pipe_recv_input').start()
recvd = None
# Allocate the buffer if necessary
if self.pipe_recv_buf is None:
self.pipe_recv_buf = self._recv_tensor_meta(self.prev_stage)
if isinstance(self.pipe_recv_buf, torch.Tensor):
p2p.recv(self.pipe_recv_buf, self.prev_stage)
recvd = self.pipe_recv_buf.clone().detach()
recvd.requires_grad = recvd.is_floating_point()
else:
assert isinstance(self.pipe_recv_buf, tuple)
recvd = [None] * len(self.pipe_recv_buf)
for idx, buffer in enumerate(self.pipe_recv_buf):
assert torch.is_tensor(buffer)
# XXX hardcode meta type
if self.is_pipe_partitioned and idx == 0 and buffer.dtype != torch.long:
if self.meta_buffer is None:
self.meta_buffer = torch.zeros(buffer.size(),
dtype=torch.long,
device=self.device)
buffer = self.meta_buffer
p2p.recv(buffer, self.prev_stage)
recvd[idx] = buffer.clone().detach()
# NCCL does not like to send torch.BoolTensor types, so un-cast the
# attention mask
if self.module.__class__.__name__ == 'GPT2ModelPipe':
recvd[-1] = recvd[-1].bool()
recvd = tuple(recvd)
for buffer in recvd:
buffer.requires_grad = buffer.is_floating_point()
self.pipe_buffers['inputs'][buffer_id] = recvd
if self.wall_clock_breakdown():
self.timers('pipe_recv_input').stop()
def _exec_recv_grads(self, buffer_id):
if self.wall_clock_breakdown():
self.timers('pipe_recv_grad').start()
outputs = self.pipe_buffers['outputs'][buffer_id]
# XXX these shapes are hardcoded for Megatron
# Restore partitioned output if it was partitioned and we are sending full gradients
if self.is_pipe_partitioned and not self.is_grad_partitioned:
part_output = PartitionedTensor.from_meta(
meta=outputs[0],
local_part=outputs[1],
group=self.grid.get_slice_parallel_group())
outputs[0].data = part_output.full()
outputs = tuple([outputs[0], outputs[2]])
# save for backward
self.pipe_buffers['outputs'][buffer_id] = outputs
# Allocate gradient if necessary
if self.grad_layer is None:
if isinstance(outputs, torch.Tensor):
s = list(outputs.size())
self.grad_layer = self._allocate_buffer(s, num_buffers=1)[0]
else:
sizes = [list(t.size()) for t in outputs if t.is_floating_point()]
self.grad_layer = self._allocate_buffers(sizes, num_buffers=1)[0]
if isinstance(self.grad_layer, torch.Tensor):
p2p.recv(self.grad_layer, self.next_stage)
else:
assert isinstance(outputs, tuple)
for idx, buffer in enumerate(self.grad_layer):
# XXX GPT-2 hack
if self.is_grad_partitioned and idx == 0 and buffer.dtype != torch.long:
buffer.data = torch.zeros(buffer.size(),
dtype=torch.long,
device=self.device)
p2p.recv(buffer, self.next_stage)
if self.wall_clock_breakdown():
self.timers('pipe_recv_grad').stop()
def _exec_optimizer_step(self):
if self.wall_clock_breakdown():
self.timers('step_microstep').start()
self.timers('step').start()
self.mem_status('BEFORE STEP', reset_max=True)
self._force_grad_boundary = True
self._take_model_step()
self._force_grad_boundary = False
self.mem_status('AFTER STEP')
if self.tensorboard_enabled():
if self.global_rank == 0:
self.summary_events = [(f'Train/Samples/lr',
self.get_lr()[0],
self.global_samples)]
if self.fp16_enabled() and hasattr(self.optimizer, 'cur_scale'):
self.summary_events.append((f'Train/Samples/loss_scale',
self.optimizer.cur_scale,
self.global_samples))
for event in self.summary_events: # write_summary_events
self.summary_writer.add_scalar(event[0], event[1], event[2])
if self.wall_clock_breakdown():
self.timers('step_microstep').stop()
self.timers('step').stop()
if self.global_steps % self.steps_per_print() == 0:
self.timers.log([
'batch_input',
'forward_microstep',
'backward_microstep',
'backward_inner_microstep',
'backward_allreduce_microstep',
'backward_tied_allreduce_microstep',
'step_microstep'
])
if self.global_steps % self.steps_per_print() == 0:
self.timers.log([
'forward',
'backward',
'backward_inner',
'backward_allreduce',
'step'
])
def _zero_grads(self, inputs):
if isinstance(inputs, torch.Tensor):
if inputs.grad is not None:
inputs.grad.data.zero_()
else:
for t in inputs:
if t.grad is not None:
t.grad.data.zero_()
def _allocate_zeros(self, shape, fp16=None, **kwargs):
""" Allocate a tensor of zeros on the engine's device.
Arguments:
shape: the shape of the tensor to allocate
fp16 (bool): whether to use FP16. default: defer to self.fp16_enabled()
kwargs: passed to torch.zeros()
Returns:
A tensor from torch.zeros() allocated on self.device.
"""
if fp16 is None:
fp16 = self.fp16_enabled()
if fp16:
return torch.zeros(shape, dtype=torch.half, device=self.device, **kwargs)
else:
return torch.zeros(shape, device=self.device, **kwargs)
def _allocate_buffer(self, shape, num_buffers=-1, **kwargs):
buffers = []
if num_buffers == -1:
num_buffers = self.num_pipe_buffers
for count in range(num_buffers):
buffers.append(self._allocate_zeros(shape, **kwargs))
return buffers
def _allocate_buffers(self, shapes, requires_grad=False, num_buffers=-1):
buffers = []
if num_buffers == -1:
num_buffers = self.num_pipe_buffers
for count in range(num_buffers):
buffer = []
for shape in shapes:
buffer.append(self._allocate_zeros(shape, requires_grad=requires_grad))
buffers.append(buffer)
return buffers
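# Example (illustrative shapes): pre-allocating receive buffers for a stage that
# expects a (hidden_states, mask) pair for each of two in-flight micro-batches:
#
#   bufs = self._allocate_buffers([[4, 1024], [4, 1]], num_buffers=2)
#   # len(bufs) == 2 and len(bufs[0]) == 2; every entry comes from _allocate_zeros(),
#   # i.e. an fp16 zero tensor on self.device when fp16 is enabled.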
def forward(self, *args, **kwargs):
raise PipelineError("Only train_batch() is accessible in pipeline mode.")
def backward(self, *args, **kwargs):
raise PipelineError("Only train_batch() is accessible in pipeline mode.")
def step(self, *args, **kwargs):
raise PipelineError("Only train_batch() is accessible in pipeline mode.")
def mem_status(self, msg, print_rank=-1, reset_max=False):
return
global mem_alloced, mem_cached
if not self.global_steps == 0 or not self.global_steps == 9:
#return
pass
if self.mpu.get_data_parallel_rank() != 0:
return
if self.global_rank != 0:
return
rank = self.global_rank
if print_rank != -1 and rank != print_rank:
return
torch.cuda.synchronize()
if reset_max:
torch.cuda.reset_max_memory_cached()
torch.cuda.reset_max_memory_allocated()
new_alloced = torch.cuda.memory_allocated()
new_cached = torch.cuda.memory_cached()
delta_alloced = new_alloced - mem_alloced
delta_cached = new_cached - mem_cached
mem_cached = new_cached
mem_alloced = new_alloced
max_alloced = torch.cuda.max_memory_allocated()
max_cached = torch.cuda.max_memory_cached()
# convert to GB for printing
new_alloced /= 1024**3
new_cached /= 1024**3
delta_alloced /= 1024**3
delta_cached /= 1024**3
max_alloced /= 1024**3
max_cached /= 1024**3
print(
f'RANK={rank} STAGE={self.stage_id} STEP={self.global_steps} MEMSTATS',
msg,
f'current alloc={new_alloced:0.4f}GB (delta={delta_alloced:0.4f}GB max={max_alloced:0.4f}GB) '
f'current cache={new_cached:0.4f}GB (delta={delta_cached:0.4f}GB max={max_cached:0.4f}GB)'
)
def module_state_dict(self):
"""Override hack to save a pipe model and return the directory path of the save.
This method should only be called by DeepSpeed's ``save_checkpoint()``. The
recommended way of saving a ``PipelineModule`` outside of ``save_checkpoint()``
is ``save_state_dict()``.
Returns:
str: The directory path where the checkpoint was saved.
"""
assert isinstance(self.module, PipelineModule)
assert self._curr_save_path is not None, \
"PipelineEngine expects module_state_dict() to be called from save_checkpoint()"
self.module.save_state_dict(self._curr_save_path)
return self._curr_save_path
def load_module_state_dict(self, state_dict, strict=True):
"""Override hack to instead use a directory path.
This is important because pipeline models checkpoint by layer instead of rank.
If ``state_dict`` is not a ``str``, we revert to ``super()`` expecting a ``dict``.
Args:
state_dict (str): Path to the directory for checkpoint.
strict (bool, optional): Strict state loading. Defaults to True.
"""
if not isinstance(state_dict, str):
super().load_module_state_dict(state_dict, strict)
return
self.module.load_state_dir(state_dict, strict=strict)
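# Sketch of the intended flow (assuming the engine's standard checkpoint entry points):
# save_checkpoint() sets the save path and calls module_state_dict() above, so each
# pipeline stage writes only its own layer_XX-model_states.pt files, and
# load_checkpoint() hands the checkpoint directory to load_module_state_dict(), which
# forwards it to load_state_dir():
#
#   engine.save_checkpoint(save_dir, tag)
#   engine.load_checkpoint(save_dir, tag)
#
# torch.save(engine.module.state_dict(), ...) is not equivalent here, because layers
# owned by other stages do not exist in this process.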
def is_first_stage(self):
"""True if this process is in the first stage in the pipeline."""
return self.stage_id == 0
def is_last_stage(self):
"""True if this process is in the last stage in the pipeline."""
return self.stage_id == self.num_stages - 1
# A map of PipeInstruction types to methods. Each method will be executed with the
# kwargs provided to the PipeInstruction from the scheduler.
_INSTRUCTION_MAP = {
schedule.OptimizerStep: _exec_optimizer_step,
schedule.ReduceGrads: _exec_reduce_grads,
schedule.ReduceTiedGrads: _exec_reduce_tied_grads,
schedule.LoadMicroBatch: _exec_load_micro_batch,
schedule.ForwardPass: _exec_forward_pass,
schedule.BackwardPass: _exec_backward_pass,
schedule.SendActivation: _exec_send_activations,
schedule.RecvActivation: _exec_recv_activations,
schedule.SendGrad: _exec_send_grads,
schedule.RecvGrad: _exec_recv_grads,
}
def _exec_schedule(self, pipe_schedule):
self._reserve_pipe_buffers(pipe_schedule.num_pipe_buffers())
# For each step in the schedule
for step_cmds in pipe_schedule:
# For each instruction in the step
for cmd in step_cmds:
if type(cmd) not in self._INSTRUCTION_MAP:
raise RuntimeError(
f'{self.__class__.__name__} does not understand instruction {repr(cmd)}'
)
# Equivalent to: self._exec_forward_pass(buffer_id=0)
self._exec_instr = MethodType(self._INSTRUCTION_MAP[type(cmd)], self)
self._exec_instr(**cmd.kwargs)
import os
import enum
import re as regex
from collections import defaultdict
from functools import partial
import torch
import torch.nn as nn
import torch.distributed as dist
from deepspeed.utils import logger
from .. import utils as ds_utils
from ..activation_checkpointing import checkpointing
from .topology import PipeDataParallelTopology, PipelineParallelGrid
class PipelineError(Exception):
"""Errors related to the use of deepspeed.PipelineModule """
class LayerSpec:
"""Building block for specifying pipeline-parallel modules.
LayerSpec stores the type information and parameters for each stage in a
PipelineModule. For example:
nn.Sequential(
    torch.nn.Linear(self.in_dim, self.hidden_dim, bias=False),
    torch.nn.Linear(self.hidden_dim, self.out_dim)
)
becomes
layer_specs = [
    LayerSpec(torch.nn.Linear, self.in_dim, self.hidden_dim, bias=False),
    LayerSpec(torch.nn.Linear, self.hidden_dim, self.out_dim)
]
"""
def __init__(self, typename, *module_args, **module_kwargs):
self.typename = typename
self.module_args = module_args
self.module_kwargs = module_kwargs
if not issubclass(typename, nn.Module):
raise RuntimeError('LayerSpec only supports torch.nn.Module types.')
if dist.is_initialized():
self.global_rank = dist.get_rank()
else:
self.global_rank = -1
def __repr__(self):
return ds_utils.call_to_str(self.typename.__name__,
self.module_args,
self.module_kwargs)
def build(self, log=False):
"""Build the stored specification."""
if log:
logger.info(f'RANK={self.global_rank} building {repr(self)}')
return self.typename(*self.module_args, **self.module_kwargs)
class TiedLayerSpec(LayerSpec):
def __init__(self,
key,
typename,
*module_args,
forward_fn=None,
tied_weight_attr='weight',
**module_kwargs):
super().__init__(typename, *module_args, **module_kwargs)
self.key = key
self.forward_fn = forward_fn
self.tied_weight_attr = tied_weight_attr
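# Example (hypothetical layer classes): tying the embedding used at both ends of a
# GPT-style pipeline. Both specs share key='embed', so a stage that owns either end
# builds the module once and its 'weight' attribute is synchronized across the owning
# stages (see _index_tied_modules()/_synchronize_tied_weights() below):
#
#   specs = [
#       TiedLayerSpec('embed', EmbeddingPipe, vocab_size, hidden),
#       *[LayerSpec(TransformerLayerPipe, hidden) for _ in range(num_layers)],
#       TiedLayerSpec('embed', EmbeddingPipe, vocab_size, hidden,
#                     forward_fn=lambda module, x: module.project(x)),
#   ]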
class PipelineModule(nn.Module):
"""Base class for modules to be parallelized with pipeline parallelism.
Users should subclass PipelineModule and provide layer_specs(), which returns a list
of LayerSpec objects. This sequence of layers represents the pipeline-parallel model.
After initialization, a PipelineModule can be used as a traditional torch.nn.Module.
The forward pass is already provided by this base class. The key assumption is that
the output of each layer can be directly fed as input to the next, like a
torch.nn.Sequential.
The key constraint that enables pipeline parallelism is the representation of the
forward pass as a sequence of layers (i.e., stages) and the enforcement of a
simple interface between them.
Example:
class LinearPipeline(PipelineModule):
def __init__(self, in_dim, hidden_dim, out_dim):
self.in_dim = in_dim
self.hidden_dim = hidden_dim
self.out_dim = out_dim
super().__init__()
def layer_specs(self):
return [LayerSpec(torch.nn.Linear, self.in_dim, self.hidden_dim, bias=False),
LayerSpec(torch.nn.Linear, self.hidden_dim, self.out_dim)]
"""
def __init__(self,
layers,
num_stages=None,
loss_fn=None,
topology=None,
seed_layers=False,
seed_fn=None,
base_seed=1234,
partition_method='parameters',
activation_checkpoint_interval=0,
activation_checkpoint_func=checkpointing.checkpoint):
super().__init__()
if num_stages is None and topology is None:
raise RuntimeError('must provide num_stages or topology')
self.micro_offset = 0
self.loss_fn = loss_fn
self.seed_layers = seed_layers
self.seed_fn = seed_fn
self.base_seed = base_seed
if dist.get_rank() == 0:
try:
seed_str = self.seed_fn.__name__
except AttributeError:
seed_str = None
print(
f'SEED_LAYERS={self.seed_layers} BASE_SEED={self.base_seed} SEED_FN={seed_str}'
)
# Setup world info
self.world_group = dist.new_group(ranks=range(dist.get_world_size()))
self.global_rank = dist.get_rank(group=self.world_group)
self.world_size = dist.get_world_size(group=self.world_group)
if topology:
self._topo = topology
self.num_stages = self._topo.get_dim('pipe')
else:
self.num_stages = num_stages
if topology is None:
if self.world_size % self.num_stages != 0:
raise RuntimeError(
f'num_stages ({self.num_stages}) must divide distributed world size ({self.world_size})'
)
dp = self.world_size // num_stages
topology = PipeDataParallelTopology(num_pp=num_stages, num_dp=dp)
self._topo = topology
# Construct communicators for pipeline topology
self._grid = PipelineParallelGrid(process_group=self.world_group,
topology=self._topo)
self.stage_id = self._topo.get_coord(self.global_rank).pipe
# Initialize partition information
self._layer_specs = list(layers)
self._num_layers = len(self._layer_specs)
self._local_start = 0
self._local_stop = None
self._partition_layers(method=partition_method)
self.forward_funcs = []
self.tied_modules = nn.ModuleDict()
self.tied_weight_attrs = {}
# Offset the random seed by the stage ID.
#newseed = torch.cuda.initial_seed() + self._grid.get_stage_id()
#ds_utils.set_random_seed(newseed)
#with torch.random.fork_rng(devices=[torch.cuda.current_device()]):
self._build()
self.to('cuda')
self.tied_comms = self._index_tied_modules()
self._synchronize_tied_weights()
self.activation_checkpoint_interval = activation_checkpoint_interval
self.activation_checkpoint_func = activation_checkpoint_func
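# Typical construction (sketch; sizes and dataset are illustrative): the module is
# built from a flat list of LayerSpecs and then driven through the engine's
# train_batch(), since forward()/backward()/step() are disabled in pipeline mode:
#
#   specs = [LayerSpec(torch.nn.Linear, 1024, 1024) for _ in range(8)]
#   net = PipelineModule(layers=specs, num_stages=4, partition_method='parameters')
#   engine, _, _, _ = deepspeed.initialize(args=args, model=net,
#                                          model_parameters=net.parameters(),
#                                          training_data=train_set)
#   loss = engine.train_batch()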
def _build(self):
specs = self._layer_specs
for local_idx, layer in enumerate(specs[self._local_start:self._local_stop]):
layer_idx = local_idx + self._local_start
if self.seed_layers:
if self.seed_fn:
self.seed_fn(self.base_seed + layer_idx)
else:
ds_utils.set_random_seed(self.base_seed + layer_idx)
# Recursively build PipelineModule objects
if isinstance(layer, PipelineModule):
raise NotImplementedError('RECURSIVE BUILD NOT YET IMPLEMENTED')
# LayerSpec objects contain an nn.Module that should be allocated now.
elif isinstance(layer, nn.Module):
name = str(layer_idx)
self.forward_funcs.append(layer)
self.add_module(name, layer)
# TiedLayerSpec objects contain an nn.Module that should be allocated now.
elif isinstance(layer, TiedLayerSpec):
# Build and register the module if we haven't seen it before.
if layer.key not in self.tied_modules:
self.tied_modules[layer.key] = layer.build()
self.tied_weight_attrs[layer.key] = layer.tied_weight_attr
if layer.forward_fn is None:
# Just use forward()
self.forward_funcs.append(self.tied_modules[layer.key])
else:
# User specified fn with args (module, input)
self.forward_funcs.append(
partial(layer.forward_fn,
self.tied_modules[layer.key]))
# LayerSpec objects contain an nn.Module that should be allocated now.
elif isinstance(layer, LayerSpec):
module = layer.build()
name = str(layer_idx)
self.forward_funcs.append(module)
self.add_module(name, module)
# Last option: layer may be a functional (e.g., lambda). We do nothing in
# that case and just use it in forward()
else:
self.forward_funcs.append(layer)
# All pipeline parameters should be considered as model parallel in the context
# of our FP16 optimizer
for p in self.parameters():
p.model_parallel = True
def _count_layer_params(self):
"""Count the trainable parameters in individual layers.
This routine will only build one layer at a time.
Returns:
A list of the number of parameters in each layer.
"""
param_counts = [0] * len(self._layer_specs)
for idx, layer in enumerate(self._layer_specs):
if isinstance(layer, LayerSpec):
l = layer.build()
params = filter(lambda p: p.requires_grad, l.parameters())
param_counts[idx] = sum(p.numel() for p in params)
elif isinstance(layer, nn.Module):
params = filter(lambda p: p.requires_grad, layer.parameters())
param_counts[idx] = sum(p.numel() for p in params)
return param_counts
def _find_layer_type(self, layername):
idxs = []
typeregex = regex.compile(layername, regex.IGNORECASE)
for idx, layer in enumerate(self._layer_specs):
name = None
if isinstance(layer, LayerSpec):
name = layer.typename.__name__
elif isinstance(layer, nn.Module):
name = layer.__class__.__name__
else:
try:
name = layer.__name__
except AttributeError:
continue
if typeregex.search(name):
idxs.append(idx)
if len(idxs) == 0:
raise RuntimeError(
f"Partitioning '{layername}' found no valid layers to partition.")
return idxs
def forward(self, forward_input):
# We need to offset the seed by the microbatch ID. Save it in a local var to
# ensure it is preserved in the closure. Otherwise checkpointed forward funcs
# will see a different offset.
self.micro_offset += 1
def exec_range_func(start, end):
''' Helper function to be used with checkpoint()
Adapted from torch.utils.checkpoint:checkpoint_sequential()
'''
local_micro_offset = self.micro_offset + 1
def exec_func(*inputs):
# Single tensor inputs need to be unwrapped
if len(inputs) == 1:
inputs = inputs[0]
for idx, layer in enumerate(self.forward_funcs[start:end]):
self.curr_layer = idx + self._local_start
if self.seed_layers:
new_seed = (self.base_seed *
local_micro_offset) + self.curr_layer
if self.seed_fn:
self.seed_fn(new_seed)
else:
ds_utils.set_random_seed(new_seed)
inputs = layer(inputs)
return inputs
return exec_func
if self.activation_checkpoint_interval == 0:
func = exec_range_func(0, len(self.forward_funcs))
x = func(forward_input)
else:
num_layers = len(self.forward_funcs)
x = forward_input
for start_idx in range(0, num_layers, self.activation_checkpoint_interval):
end_idx = min(start_idx + self.activation_checkpoint_interval,
num_layers)
funcs = self.forward_funcs[start_idx:end_idx]
# Since we either pass tensors or tuples of tensors without unpacking, we
# need to be careful not to double-wrap tensors with tuple.
if not isinstance(x, tuple):
x = (x, )
if self._is_checkpointable(funcs):
x = self.activation_checkpoint_func(
exec_range_func(start_idx,
end_idx),
*x)
else:
x = exec_range_func(start_idx, end_idx)(*x)
return x
def _partition_layers(self, method='uniform'):
num_stages = self._topo.get_dim('pipe')
stage_id = self._topo.get_coord(self.global_rank).pipe
if self.global_rank == 0:
logger.info(f'Partitioning pipeline stages with method {method}')
method = method.lower()
# Each stage gets a simple uniform number of layers.
if method == 'uniform':
num_layers = len(self._layer_specs)
self.parts = ds_utils.partition_uniform(num_items=num_layers,
num_parts=num_stages)
elif method == 'parameters':
param_counts = self._count_layer_params()
self.parts = ds_utils.partition_balanced(weights=param_counts,
num_parts=num_stages)
elif method.startswith('type:'):
layertype = method.split(':')[1]
binary_weights = [0] * len(self._layer_specs)
for idx in self._find_layer_type(layertype):
binary_weights[idx] = 1
self.parts = ds_utils.partition_balanced(weights=binary_weights,
num_parts=num_stages)
elif method == 'profile':
raise NotImplementedError(f'Partitioning method {method} not implemented.')
else:
raise NotImplementedError(f'Partitioning method {method} not implemented.')
# Print some information on the partitioning.
if self.global_rank == 0:
for stage in range(num_stages):
start = self.parts[stage]
stop = self.parts[stage + 1]
print(f'stage={stage} layers={stop - start}')
for idx, layer in enumerate(self._layer_specs[start:stop]):
name = str(layer)
if isinstance(layer, LayerSpec):
name = layer.typename.__name__
if isinstance(layer, nn.Module):
name = layer.__class__.__name__
else:
try:
name = layer.__name__
except AttributeError:
pass
print(f' {idx+start:2d}: {name}')
if self.loss_fn:
try:
print(f' loss: {self.loss_fn.__name__}')
except AttributeError:
print(f' loss: {self.loss_fn.__class__.__name__}')
self._set_bounds(start=self.parts[stage_id], stop=self.parts[stage_id + 1])
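# Example (sketch; layer weights are illustrative): partition_method='parameters'
# feeds per-layer parameter counts into ds_utils.partition_balanced(), and the
# returned boundaries are consumed exactly like the uniform case, i.e. stage s builds
# layers parts[s]:parts[s + 1]:
#
#   param_counts = [0, 100, 100, 100, 100, 0]   # e.g. embed, 4 blocks, loss
#   parts = ds_utils.partition_balanced(weights=param_counts, num_parts=2)
#   # -> stage 0 builds layers parts[0]:parts[1], stage 1 builds parts[1]:parts[2]
#
# partition_method='type:<regex>' works the same way, but assigns weight 1 only to
# the layers whose class name matches the regex.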
def allreduce_tied_weight_gradients(self):
'''All reduce the gradients of the tied weights between tied stages'''
for key, comm in self.tied_comms.items():
weight = getattr(self.tied_modules[key], comm['weight_attr'])
dist.all_reduce(weight.grad, group=comm['group'])
def _synchronize_tied_weights(self):
for key, comm in self.tied_comms.items():
dist.broadcast(
getattr(comm['module'],
comm['weight_attr']),
src=min(comm['ranks']),
group=comm['group'],
)
def _index_tied_modules(self):
''' Build communication structures for tied modules. '''
tied_comms = {}
if self._topo.get_dim('pipe') == 1:
return tied_comms
specs = self._layer_specs
tie_keys = set(s.key for s in specs if isinstance(s, TiedLayerSpec))
for key in tie_keys:
# Find the layers that the tied module appears in
tied_layers = []
for idx, layer in enumerate(specs):
if isinstance(layer, TiedLayerSpec) and layer.key == key:
tied_layers.append(idx)
# Find all stages with this tied module
# TODO: Would be nice to remove the nested data/model parallelism loops and
# TODO: instead generalize in some way, since we really just care about the
# TODO: stage that owns the tied layer. Then loop over each (dp, mp, ...)
# TODO: fiber to generate process groups.
tied_stages = set(self.stage_owner(idx) for idx in tied_layers)
for dp in range(self._grid.data_parallel_size):
for mp in range(self._grid.model_parallel_size):
tied_ranks = []
for s in sorted(tied_stages):
if self._grid.model_parallel_size > 1:
tied_ranks.append(
self._grid.stage_to_global(stage_id=s,
data=dp,
model=mp))
else:
tied_ranks.append(
self._grid.stage_to_global(stage_id=s,
data=dp))
group = dist.new_group(ranks=tied_ranks)
# Record this tied module if we own a local copy of it.
if self.global_rank in tied_ranks:
assert key in self.tied_modules
if key in self.tied_modules:
tied_comms[key] = {
'ranks': tied_ranks,
'group': group,
'weight_attr': self.tied_weight_attrs[key],
'module': self.tied_modules[key],
}
# Only count the tied module once in the eyes of the FP16 optimizer
if self.global_rank != tied_ranks[0]:
for p in self.tied_modules[key].parameters():
p.model_parallel = False
'''
if len(tied_comms) > 0:
print(f'RANK={self.global_rank} tied_comms={tied_comms}')
'''
return tied_comms
def partitions(self):
return self.parts
def stage_owner(self, layer_idx):
assert 0 <= layer_idx < self._num_layers
for stage in range(self._topo.get_dim('pipe')):
if self.parts[stage] <= layer_idx < self.parts[stage + 1]:
return stage
raise RuntimeError(f'Layer {layer_idx} not owned? parts={self.parts}')
def _set_bounds(self, start=None, stop=None):
"""Manually define the range of layers that will be built on this process.
These boundaries are treated as list slices and so start is inclusive and stop is
exclusive. The default of None for both results in all layers being built
locally.
"""
self._local_start = start
self._local_stop = stop
def set_checkpoint_interval(self, interval):
""" Checkpoint activations after each ``interval`` layers. Use 0 to disable. """
assert interval >= 0
self.activation_checkpoint_interval = interval
def topology(self):
""" ProcessTopology object to query process mappings. """
return self._topo
def mpu(self):
return self._grid
def num_pipeline_stages(self):
return self._topo.get_dim('pipe')
def ckpt_prefix(self, checkpoints_path, tag):
"""Build a prefix for all checkpoint files written by this module. """
# All checkpoint files start with this
rank_name = 'module'
# Data parallelism is omitted from the naming convention because we are agnostic
# to this in the checkpoint.
omit_dims = frozenset(['data'])
axes = [a for a in self._grid._topo.get_axis_names() if a not in omit_dims]
for dim in axes:
rank = getattr(self._grid._topo.get_coord(rank=self.global_rank), dim)
rank_name += f'-{dim}_{rank:02d}'
ckpt_name = os.path.join(checkpoints_path, str(tag), rank_name)
return ckpt_name
def ckpt_layer_path(self, ckpt_dir, local_layer_idx):
"""Customize a prefix for a specific pipeline module layer. """
idx = local_layer_idx + self._local_start
layer_ckpt_path = os.path.join(ckpt_dir, f'layer_{idx:02d}')
rank_repr = self._grid._topo.get_rank_repr(rank=self.global_rank)
if rank_repr != '':
layer_ckpt_path += f'-{rank_repr}'
layer_ckpt_path += '-model_states.pt'
return layer_ckpt_path
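# Example (sketch, pipe-data topology, tag 'global_step100'): a rank at pipeline
# coordinate pipe=1 writes under
#
#   <checkpoints_path>/global_step100/module-pipe_01
#
# and a stage whose first local layer has global index 12 saves it as
#
#   <ckpt_dir>/layer_12-model_states.pt
#
# The 'data' axis is omitted from the prefix so that all data-parallel replicas of a
# stage agree on file names; get_rank_repr() is empty here because it omits 'data' and
# 'pipe' by default, so no extra '-<repr>' segment is inserted.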
def save_state_dict(self, save_dir):
if self._grid.data_parallel_id != 0:
return
os.makedirs(save_dir, exist_ok=True)
layer_offset = self._local_start
for idx, layer in enumerate(self.forward_funcs):
model_ckpt_path = self.ckpt_layer_path(save_dir, idx)
if not hasattr(layer, 'state_dict'):
continue
torch.save(layer.state_dict(), model_ckpt_path)
def load_state_dir(self, load_dir, strict=True):
rank = dist.get_rank()
layer_offset = self._local_start
for idx, layer in enumerate(self.forward_funcs):
# Functions, etc. will not have state_dicts
if not hasattr(layer, 'load_state_dict'):
continue
model_ckpt_path = self.ckpt_layer_path(load_dir, idx)
layer.load_state_dict(torch.load(model_ckpt_path,
map_location=lambda storage,
loc: storage),
strict=strict)
if self._grid.data_parallel_id == 0:
logger.info(
f'RANK={self.global_rank} Loaded layer={idx+layer_offset} file={model_ckpt_path}'
)
self._synchronize_tied_weights()
def _is_checkpointable(self, funcs):
if self.__class__.__name__ == 'GPT2ModelPipe':
return all('ParallelTransformerLayerPipe' in f.__class__.__name__
for f in funcs)
params = [f.parameters() for f in funcs if isinstance(f, torch.nn.Module)]
return any(len(list(p)) > 0 for p in params)
'''
Copyright 2019 The Microsoft DeepSpeed Team
'''
import torch.distributed as dist
_groups = None
_grid = None
#initializes adjacent process groups
#run this only after torch.distributed.init_process_group() has been called
def init_process_groups(grid):
global _groups, _grid
_grid = grid
assert _grid.pipe_parallel_size > 1, "There is no pipeline parallelism"
_groups = [dist.new_group(ranks=group) for group in _grid.p2p_groups]
def _is_valid_send_recv(src_stage, dest_stage):
first_stage = 0
last_stage = _grid.pipe_parallel_size - 1
assert abs(src_stage-dest_stage) == 1 or \
(src_stage == first_stage and dest_stage == last_stage) or \
(src_stage == last_stage and dest_stage == first_stage), \
"Functionality currently limited to send and receive between adjacent ranks only"
def send(tensor, dest_stage, async_op=False):
global _groups
async_op = False
src_stage = _grid.get_stage_id()
_is_valid_send_recv(src_stage, dest_stage)
group = _get_send_recv_group(src_stage, dest_stage)
src_rank = _grid.stage_to_global(stage_id=src_stage)
return dist.broadcast(tensor, src_rank, group=group, async_op=async_op)
def recv(tensor, src_stage, async_op=False):
global _groups
async_op = False
dest_stage = _grid.get_stage_id()
_is_valid_send_recv(src_stage, dest_stage)
group = _get_send_recv_group(src_stage, dest_stage)
src_rank = _grid.stage_to_global(stage_id=src_stage)
return dist.broadcast(tensor, src_rank, group=group, async_op=async_op)
def barrier(stage_id):
global _groups, _grid
group_id = _grid.stage_to_global(stage_id=stage_id)
if (dist.get_rank() >= 0):
print("Barrier Group ID", group_id)
print("Barrier Group", _grid.p2p_groups[group_id])
dist.barrier(group=_groups[group_id])
if (dist.get_rank() >= 0):
print("Exiting Barrier ", group_id)
def _get_send_recv_group(src_stage, dest_stage):
'''the group id is always the smaller rank unless its a wrap around'''
stage_id = None
first_stage = 0
last_stage = _grid.pipe_parallel_size - 1
if (src_stage == first_stage and dest_stage == last_stage
or dest_stage == first_stage and src_stage == last_stage):
stage_id = last_stage
elif src_stage > dest_stage:
stage_id = dest_stage
else:
stage_id = src_stage
'''group_id corresponds to group of [group_id, group_id+1]
unless group_id is the rank of the last stage
in which case group_id correspods to group[group_id-num_stages+1, group_id]
'''
group_id = _grid.stage_to_global(stage_id=stage_id)
return _groups[group_id]
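# Minimal usage sketch (after the pipeline grid has been constructed and
# torch.distributed is initialized): adjacent stages must pair their calls, and the
# receive buffer must already have the right shape, because each transfer is really a
# broadcast over a two-rank group:
#
#   init_process_groups(grid)
#   if grid.get_stage_id() == 0:
#       send(activations, dest_stage=1)
#   elif grid.get_stage_id() == 1:
#       recv(buffer, src_stage=0)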
from ..utils import call_to_str
from abc import ABC, abstractmethod
class PipeSchedule(ABC):
"""Directs the execution of a pipeline engine by generating sequences of
:class:`PipeInstruction`.
Schedules are generators that yield sequences of
:class:`PipeInstruction` to process the micro-batches in one batch.
Each yielded step is atomic in the sense that a barrier
synchronization can be placed between successive steps without
deadlock.
Below is an example schedule that implements data parallelism with gradient accumulation:
.. code-block:: python
class DataParallelSchedule(PipeSchedule):
def steps(self):
for step_id in range(self.micro_batches):
cmds = [
LoadMicroBatch(buffer_id=0),
ForwardPass(buffer_id=0),
BackwardPass(buffer_id=0),
]
if step_id == self.micro_batches - 1:
cmds.extend([
ReduceGrads(),
OptimizerStep(),
])
yield cmds
def num_pipe_buffers(self):
return 1
Args:
micro_batches (int): The number of micro-batches that comprise a batch.
stages (int): The number of pipeline stages.
stage_id (int): The pipe stage that will execute the generated schedule.
"""
def __init__(self, micro_batches, stages, stage_id):
super().__init__()
self.micro_batches = micro_batches
self.stages = stages
self.stage_id = stage_id
self.prev_stage = self.stage_id - 1
self.next_stage = self.stage_id + 1
@abstractmethod
def steps(self):
"""Yield a list of :class:`PipeInstruction` for each step in the schedule.
.. note::
Schedules must implement ``steps()`` to define the schedule.
Returns:
Instructions to be executed as one step of the pipeline
"""
pass
def num_pipe_buffers(self):
"""The number of pipeline buffers that will be used by this stage.
.. note::
Schedules should specialize ``num_pipe_buffers()`` for memory savings at scale.
Returns:
The number of buffers for the engine to allocate.
"""
return self.micro_batches
def _valid_micro_batch(self, micro_batch_id):
return 0 <= micro_batch_id < self.micro_batches
def _valid_stage(self, stage_id):
return 0 <= stage_id < self.stages
@property
def stage(self):
"""Stage index used to configure this schedule."""
return self.stage_id
@property
def num_stages(self):
"""The number of total pipeline stages used to configure this schedule."""
return self.stages
@property
def num_micro_batches(self):
"""The number of total micro_batches used to configure this schedule."""
return self.micro_batches
@property
def is_first_stage(self):
"""True if the configured ``stage_id`` is the first stage in the pipeline."""
return self.stage_id == 0
@property
def is_last_stage(self):
"""True if the configured ``stage_id`` is the last stage in the pipeline."""
return self.stage_id == self.stages - 1
def _buffer_idx(self, micro_batch_id):
"""Map a micro-batch index to a pipeline buffer index.
This method uses a cyclic allocation strategy.
Args:
micro_batch_id (int): The micro-batch index relative to the beginning of the schedule.
Returns:
int: The index of the buffer that should store data.
"""
assert self._valid_micro_batch(micro_batch_id)
return micro_batch_id % self.num_pipe_buffers()
def __iter__(self):
self.it = None
return self
def __next__(self):
if self.it is None:
self.it = self.steps()
return next(self.it)
class InferenceSchedule(PipeSchedule):
"""A schedule for inferencing batches using pipeline parallelism.
"""
def steps(self):
""""""
prev_micro_batch_id = -1
total_steps = self.micro_batches + self.stages - 1
for step_id in range(total_steps):
cmds = []
micro_batch_id = step_id - self.stage_id
# Alternate send/recv buffers
if _is_even(self.stage_id):
recv_buf = step_id % 2
send_buf = (step_id + 1) % 2
else:
recv_buf = (step_id + 1) % 2
send_buf = step_id % 2
if self.is_first_stage or self.is_last_stage:
if self._valid_micro_batch(micro_batch_id):
cmds.append(LoadMicroBatch(recv_buf))
if _is_even(self.stage_id):
if self._valid_stage(self.next_stage):
if self._valid_micro_batch(micro_batch_id - 1):
cmds.append(SendActivation(send_buf))
if self._valid_stage(self.prev_stage):
if self._valid_micro_batch(micro_batch_id):
cmds.append(RecvActivation(recv_buf))
else:
if self._valid_stage(self.prev_stage):
if self._valid_micro_batch(micro_batch_id):
cmds.append(RecvActivation(recv_buf))
if self._valid_stage(self.next_stage):
if self._valid_micro_batch(micro_batch_id - 1):
cmds.append(SendActivation(send_buf))
if self._valid_micro_batch(micro_batch_id):
cmds.append(ForwardPass(recv_buf))
yield cmds
def num_pipe_buffers(self):
"""Only two pipeline buffers are required for inferencing.
Returns:
``2``
"""
return 2
class TrainSchedule(PipeSchedule):
"""A schedule for training a batch using hybrid parallelism.
Pipeline parallelism is extracted through gradient accumulation and thus
convergence follows that of a data parallel approach with the same batch
size.
"""
def steps(self):
""""""
prev_micro_batch_id = -1
total_steps = 2 * (self.micro_batches + self.stages - 1)
for step_id in range(total_steps):
# Map the step of the pipeline to the micro-batch id and also whether it is a
# forward or backward pass step.
micro_batch_id, is_forward = self._step_to_micro_batch(step_id)
if self._valid_micro_batch(prev_micro_batch_id):
prev_buffer = self._buffer_idx(prev_micro_batch_id)
if self._valid_micro_batch(micro_batch_id):
curr_buffer = self._buffer_idx(micro_batch_id)
cmds = []
# Exchange activations
if is_forward:
if self._valid_micro_batch(micro_batch_id) and self._valid_stage(
self.prev_stage):
cmds.append(RecvActivation(curr_buffer))
if self._valid_micro_batch(prev_micro_batch_id) and self._valid_stage(
self.prev_stage):
cmds.append(SendGrad(prev_buffer))
else:
if self._valid_micro_batch(prev_micro_batch_id) and self._valid_stage(
self.next_stage):
cmds.append(SendActivation(prev_buffer))
if self._valid_micro_batch(micro_batch_id) and self._valid_stage(
self.next_stage):
cmds.append(RecvGrad(curr_buffer))
# First/last stage loads
if self.stage_id == 0 or self.stage_id == self.stages - 1:
if is_forward and self._valid_micro_batch(micro_batch_id):
cmds.append(LoadMicroBatch(curr_buffer))
# Computation
if self._valid_micro_batch(micro_batch_id):
if is_forward:
cmds.append(ForwardPass(curr_buffer))
else:
cmds.append(BackwardPass(curr_buffer))
# Model step at the end of the batch
if step_id == total_steps - 1:
cmds.append(ReduceTiedGrads())
cmds.append(ReduceGrads())
cmds.append(OptimizerStep())
# Prepare state for next time
prev_micro_batch_id = micro_batch_id
yield cmds
def num_pipe_buffers(self):
"""As many buffers as the distance from this stage to the last stage.
"""
buffers = min(self.stages - self.stage_id + 1, self.micro_batches)
return max(2, buffers)
def _step_to_micro_batch(self, step_id):
if _is_even(step_id) and _is_even(self.stage_id):
micro_batch_id = self._even_step_forward_id(step_id)
is_forward = True
elif _is_odd(step_id) and _is_odd(self.stage_id):
micro_batch_id = self._odd_step_forward_id(step_id)
is_forward = True
elif _is_even(step_id) and _is_odd(self.stage_id):
micro_batch_id = self._even_step_backward_id(step_id)
is_forward = False
elif _is_odd(step_id) and _is_even(self.stage_id):
micro_batch_id = self._odd_step_backward_id(step_id)
is_forward = False
else:
assert False
return micro_batch_id, is_forward
def _even_step_forward_id(self, step_id):
base = step_id // 2
micro_batch_id = int(base - self.stage_id // 2)
return micro_batch_id
def _odd_step_forward_id(self, step_id):
base = (step_id - 1) // 2
micro_batch_id = int(base - self.stage_id // 2)
return micro_batch_id
def _even_step_backward_id(self, step_id):
base = step_id // 2
micro_batch_id = int(base - self.stages + (self.stage_id + 1) // 2)
return micro_batch_id
def _odd_step_backward_id(self, step_id):
base = ((step_id - 1) // 2) - self.stages + 1
micro_batch_id = int(base + self.stage_id // 2)
return micro_batch_id
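# A small usage sketch: the engine iterates a schedule and executes each yielded list
# of instructions for its own stage (the printed forms below are illustrative):
#
#   sched = TrainSchedule(micro_batches=4, stages=2, stage_id=0)
#   for step_cmds in sched:
#       for cmd in step_cmds:
#           print(cmd)   # e.g. LoadMicroBatch(buffer_id=0), ForwardPass(buffer_id=0), ...
#
# The even/odd bookkeeping above is what interleaves forward and backward micro-batches
# so that communication for one buffer overlaps computation on another.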
class DataParallelSchedule(PipeSchedule):
"""An example schedule that trains using traditional data parallelism with gradient
accumulation.
"""
def steps(self):
""""""
for step_id in range(self.micro_batches):
cmds = [
LoadMicroBatch(buffer_id=0),
ForwardPass(buffer_id=0),
BackwardPass(buffer_id=0),
]
if step_id == self.micro_batches - 1:
cmds.extend([
ReduceGrads(),
OptimizerStep(),
])
yield cmds
def num_pipe_buffers(self):
"""Only one pipeline buffer needed.
"""
return 1
class PipeInstruction:
"""Base class for all instructions to be executed by the pipeline engine.
All keyword arguments are stored as members similar to a ``namedtuple``. These are
then accessible to the :class:`PipeEngine` during execution.
Args:
kwargs (optional): keyword arguments to store as members
"""
def __init__(self, **kwargs):
self.name = self.__class__.__name__
self.kwargs = kwargs
for key, val in kwargs.items():
setattr(self, key, val)
def __repr__(self):
return call_to_str(self.name, **self.kwargs)
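# Example: keyword arguments become attributes and are also kept in .kwargs so that
# the engine can re-apply them when dispatching the instruction:
#
#   >>> cmd = PipeInstruction(buffer_id=0, stage=1)
#   >>> cmd.buffer_id
#   0
#   >>> cmd.kwargs
#   {'buffer_id': 0, 'stage': 1}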
class OptimizerStep(PipeInstruction):
"""Performs one step with the optimizer and zeros gradients.
.. note:: Should be issued after :class:`ReduceGrads` and :class:`ReduceTiedGrads`.
.. note:: Can be a synchronization point among data-parallel ranks.
"""
pass
class ReduceGrads(PipeInstruction):
"""Reduce the computed gradients among data-parallel processes within the stage.
"""
pass
class ReduceTiedGrads(PipeInstruction):
"""Reduce the computed gradients of tied modules within a pipeline-parallel group.
.. warning::
The stages included in this synchronization point are not known until
the model is partitioned among pipeline stages. In the worst case, it
includes all pipeline stages. This instruction should be scheduled
carefully to avoid deadlocks.
"""
pass
class BufferOpInstruction(PipeInstruction):
"""A pipeline instruction that operates on pipeline buffer(s).
Args:
buffer_id (int): the index of the pipeline buffer to modify.
"""
def __init__(self, buffer_id, **kwargs):
super().__init__(buffer_id=buffer_id, **kwargs)
# IO
class LoadMicroBatch(BufferOpInstruction):
"""Load a micro-batch into a buffer.
Roughly:
.. code-block:: python
buffers['inputs'][buffer_id] = next(data_iter)
"""
pass
# Compute
class ForwardPass(BufferOpInstruction):
"""Compute a forward pass.
Roughly:
.. code-block:: python
buffers['outputs'][buffer_id] = forward(buffers['inputs'][buffer_id])
"""
pass
class BackwardPass(BufferOpInstruction):
"""Compute a backward pass and accumulate gradients.
Roughly:
.. code-block:: python
outputs = buffers['outputs'][buffer_id]
gradients = buffers['gradients'][buffer_id]
torch.autograd.backward(tensors=outputs,
grad_tensors=gradients)
"""
pass
# Communication
class SendActivation(BufferOpInstruction):
"""Send activations to the next stage in the pipeline.
Roughly:
.. code-block:: python
send(buffers['outputs'][buffer_id])
.. note::
The communication is blocking and must be paired with a :class:`RecvActivation`
on the next pipeline stage to avoid deadlock.
"""
pass
class RecvActivation(BufferOpInstruction):
"""Receive activations from the previous stage in the pipeline.
Roughly:
.. code-block:: python
buffers['inputs'][buffer_id] = recv()
.. note::
The communication is blocking and must be paired with a :class:`SendActivation`
on the previous pipeline stage to avoid deadlock.
"""
pass
class SendGrad(BufferOpInstruction):
"""Send computed gradients to the previous pipeline stage.
with respect to the received activations
.. note::
Only received tensors with ``requires_grad==True`` will produce gradients.
Missing gradients will be replaced with ``None`` on the receiving stage.
.. note::
The communication is blocking and must be paired with a :class:`RecvGrad`
on the previous pipeline stage to avoid deadlock.
"""
pass
class RecvGrad(BufferOpInstruction):
"""Receive computed gradients the next pipeline stage.
.. note::
Only activations with ``requires_grad==True`` will produce gradients.
Missing gradients will be replaced with ``None``.
.. note::
The communication is blocking and must be paired with a :class:`SendGrad`
on the next pipeline stage to avoid deadlock.
"""
pass
def _is_even(x):
return x % 2 == 0
def _is_odd(x):
return x % 2 != 0
# Copyright 2019 The Microsoft DeepSpeed Team
from deepspeed.utils import logger
import torch.distributed as dist
import sys
from collections import namedtuple
from itertools import product as cartesian_product
class ProcessTopology:
""" Manages the mapping of n-dimensional Cartesian coordinates to linear
indices. This mapping is used to map the rank of processes to the grid
for various forms of parallelism.
Each axis of the tensor is accessed by its name. The provided ordering
of the axes defines the layout of the topology. ProcessTopology uses a "row-major"
layout of the tensor axes, and so axes=['x', 'y'] would map coordinates (x,y) and
(x,y+1) to adjacent linear indices. If instead axes=['y', 'x'] was used, coordinates
(x,y) and (x+1,y) would be adjacent.
Some methods return ProcessCoord namedtuples.
"""
def __init__(self, axes, dims):
"""Create a mapping of n-dimensional tensor coordinates to linear indices.
Arguments:
axes (list): the names of the tensor axes
dims (list): the dimension (length) of each axis of the topology tensor
"""
self.axes = axes # names of each topology axis
self.dims = dims # length of each topology axis
# This is actually a class that lets us hash {'row':3, 'col':2} mappings
self.ProcessCoord = namedtuple('ProcessCoord', axes)
self.mapping = {}
ranges = [range(d) for d in dims]
# example: 1, (0,0,1)
for global_rank, coord in enumerate(cartesian_product(*ranges)):
key = {axis: coord[self.axes.index(axis)] for axis in self.axes}
key = self.ProcessCoord(**key)
# for example, {ProcessCoord(row=0, col=1) : 1}
self.mapping[key] = global_rank
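# Example: with axes=['pipe', 'data'] and dims=[2, 2] the row-major mapping is
#
#   ProcessCoord(pipe=0, data=0) -> 0
#   ProcessCoord(pipe=0, data=1) -> 1
#   ProcessCoord(pipe=1, data=0) -> 2
#   ProcessCoord(pipe=1, data=1) -> 3
#
# so adjacent ranks share a pipeline stage and differ only in their data-parallel
# coordinate.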
def get_rank(self, **coord_kwargs):
"""Return the global rank of a process via its coordinates.
Coordinates are specified as kwargs. For example:
>>> X = ProcessTopology(axes=['x', 'y'], dims=[2,3])
>>> X.get_rank(x=0, y=1)
1
"""
if len(coord_kwargs) != len(self.axes):
raise ValueError('get_rank() does not support slices. Use filter_match().')
key = self.ProcessCoord(**coord_kwargs)
assert key in self.mapping, f'key {coord_kwargs} invalid'
return self.mapping[key]
def get_axis_names(self):
"""Return a list of the axis names in the ordering of the topology. """
return self.axes
def get_rank_repr(self,
rank,
omit_axes=['data',
'pipe'],
inner_sep='_',
outer_sep='-'):
"""Return a string representation of a rank.
This method is primarily used for checkpointing model data.
For example:
>>> topo = Topo(axes=['a', 'b'], dims=[2, 2])
>>> topo.get_rank_repr(rank=3)
'a_01-b_01'
>>> topo.get_rank_repr(rank=3, omit_axes=['a'])
'b_01'
Args:
rank (int): A rank in the topology.
omit_axes (list, optional): Axes that should not be in the representation. Defaults to ['data', 'pipe'].
inner_sep (str, optional): Separator between an axis name and its rank. Defaults to '_'.
outer_sep (str, optional): Separator between the axis entries in the representation. Defaults to '-'.
Returns:
str: A string representation of the coordinate owned by ``rank``.
"""
omit_axes = frozenset(omit_axes)
axes = [a for a in self.get_axis_names() if a not in omit_axes]
names = []
for ax in axes:
ax_rank = getattr(self.get_coord(rank=rank), ax)
names.append(f'{ax}{inner_sep}{ax_rank:02d}')
return outer_sep.join(names)
def get_dim(self, axis):
"""Return the number of processes along the given axis.
For example:
>>> X = ProcessTopology(axes=['x', 'y'], dims=[2,3])
>>> X.get_dim('y')
3
"""
if axis not in self.axes:
return 0
return self.dims[self.axes.index(axis)]
def get_coord(self, rank):
"""Return the coordinate owned by a process rank.
The axes of the returned namedtuple can be directly accessed as members. For
example:
>>> X = ProcessTopology(axes=['x', 'y'], dims=[2,3])
>>> coord = X.get_coord(rank=1)
>>> coord.x
0
>>> coord.y
1
"""
for coord, idx in self.mapping.items():
if idx == rank:
return coord
raise ValueError(f'rank {rank} not found in topology.')
def get_axis_comm_lists(self, axis):
""" Construct lists suitable for a communicator group along axis ``axis``.
Example:
>>> topo = Topo(axes=['pipe', 'data', 'model'], dims=[2, 2, 2])
>>> topo.get_axis_comm_lists('pipe')
[
[0, 4], # data=0, model=0
[1, 5], # data=0, model=1
[2, 6], # data=1, model=0
[3, 7], # data=1, model=1
]
Returns:
A list of lists whose coordinates match in all axes *except* ``axis``.
"""
# We don't want to RuntimeError because it allows us to write more generalized
# code for hybrid parallelisms.
if axis not in self.axes:
return []
# Grab all axes but `axis`
other_axes = [a for a in self.axes if a != axis]
lists = []
# Construct all combinations of coords with other_axes
ranges = [range(self.get_dim(a)) for a in other_axes]
for coord in cartesian_product(*ranges):
other_keys = {a: coord[other_axes.index(a)] for a in other_axes}
# now go over all ranks in `axis`.
sub_list = []
for axis_key in range(self.get_dim(axis)):
key = self.ProcessCoord(**other_keys, **{axis: axis_key})
sub_list.append(self.mapping[key])
lists.append(sub_list)
return lists
def filter_match(self, **filter_kwargs):
"""Return the list of ranks whose coordinates match the provided criteria.
Example:
>>> X = ProcessTopology(axes=['pipe', 'data', 'model'], dims=[2, 2, 2])
>>> X.filter_match(pipe=0, data=1)
[2, 3]
>>> [X.get_coord(rank) for rank in X.filter_match(pipe=0, data=1)]
[ProcessCoord(pipe=0, data=1, model=0), ProcessCoord(pipe=0, data=1, model=1)]
Arguments:
**filter_kwargs (dict): criteria used to select coordinates.
Returns:
The list of ranks whose coordinates match filter_kwargs.
"""
def _filter_helper(x):
for key, val in filter_kwargs.items():
if getattr(x, key) != val:
return False
return True
coords = filter(_filter_helper, self.mapping.keys())
return [self.mapping[coo] for coo in coords]
def get_axis_list(self, axis, idx):
"""Returns the list of global ranks whose coordinate in an axis is idx.
For example:
>>> X = ProcessTopology(axes=['x', 'y'], dims=[2,3])
>>> X.get_axis_list(axis='x', idx=0)
[0, 1, 2]
>>> X.get_axis_list(axis='y', idx=0)
[0, 3]
"""
# This could be faster by generating the desired keys directly instead of
# filtering.
axis_num = self.axes.index(axis)
ranks = [self.mapping[k] for k in self.mapping.keys() if k[axis_num] == idx]
return ranks
def world_size(self):
return len(self.mapping)
def __str__(self):
return str(self.mapping)
def _prime_factors(N):
""" Returns the prime factorization of positive integer N. """
if N <= 0:
raise ValueError("Values must be strictly positive.")
primes = []
while N != 1:
for candidate in range(2, N + 1):
if N % candidate == 0:
primes.append(candidate)
N //= candidate
break
return primes
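# Examples:
#   >>> _prime_factors(12)
#   [2, 2, 3]
#   >>> _prime_factors(7)
#   [7]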
class PipeDataParallelTopology(ProcessTopology):
""" A topology specialiation for hybrid data and pipeline parallelism.
Uses data parallelism on the last dimension to encourage gradient
reductions to use high-bandwidth intra-node links and lower-volume
pipeline communications to use low-bandwidth inter-node links.
"""
def __init__(self, num_pp, num_dp):
super().__init__(axes=['pipe', 'data'], dims=[num_pp, num_dp])
class PipeModelDataParallelTopology(ProcessTopology):
""" A topology for hybrid pipeline, model, and data parallelism. """
def __init__(self, num_pp, num_mp, num_dp):
super().__init__(axes=['pipe', 'data', 'model'], dims=[num_pp, num_dp, num_mp])
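# Example: PipeDataParallelTopology(num_pp=2, num_dp=4) places ranks 0-3 in pipeline
# stage 0 and ranks 4-7 in stage 1. The data-parallel all-reduce groups ({0,1,2,3} and
# {4,5,6,7}) are contiguous blocks of ranks (typically within a node), while the lower
# volume pipeline traffic crosses blocks (0<->4, 1<->5, ...).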
class PipelineParallelGrid:
"""Implements a grid object that stores the data parallel ranks
corresponding to each o the model parallel stages
The grid object organizes the processes in a distributed pytorch job
into a 2D grid, of stage_id and data_parallel_id.
self.stage_id and self.data_parallel_id stores the stage id
and the data parallel id of current process.
self.dp_group groups the processes by stage_id.
self.dp_group[i], is a list containing all process ranks whose
stage_id is i.
self.p2p_groups stores a list of tuple, where each tuple
stores process ranks of adjacent stages for a given data_parallel_id.
For example if num_stage is 5 then a tuple [7,8] represents stages [3, 4],
with data_parallel id = 1. A stage wrap around will appear as non-adjacent ranks,
for example tuple [4,0] with representing wrap-around stage 4 and 0, for
data_parallel_id = 0, or similarly [9,5] represents wrapped around stages [4,0]
for data_parallel_id = 1.
"""
def __init__(self, topology=None, process_group=None):
# TODO use process_group if provided
self.global_rank = dist.get_rank()
self.world_size = dist.get_world_size()
if topology is not None:
if self.global_rank == 0:
print('Using topology:', topology)
self._topo = topology
else:
num_pp = 1
num_dp = 1
for idx, prime in enumerate(_prime_factors(self.world_size)):
if idx % 2 == 0:
num_pp *= prime
else:
num_dp *= prime
self._topo = PipeDataParallelTopology(num_dp=num_dp, num_pp=num_pp)
self.data_parallel_size = max(self._topo.get_dim('data'), 1)
self.pipe_parallel_size = max(self._topo.get_dim('pipe'), 1)
self.model_parallel_size = max(self._topo.get_dim('model'), 1)
assert self._is_grid_valid(), "Invalid Grid"
self.stage_id = self.get_stage_id()
self.data_parallel_id = self.get_data_parallel_id()
# Create new ProcessGroups for all model parallelism. DeepSpeedLight uses these
# to detect overflow, etc.
self.ds_model_proc_group = None
self.ds_model_rank = -1
for dp in range(self.data_parallel_size):
ranks = sorted(self._topo.get_axis_list(axis='data', idx=dp))
if self.global_rank == 0:
#print(f'RANK={self.global_rank} building DeepSpeed model group: {ranks}')
pass
proc_group = dist.new_group(ranks=ranks)
if self.global_rank in ranks:
self.ds_model_proc_group = proc_group
self.ds_model_world_size = len(ranks)
self.ds_model_rank = ranks.index(self.global_rank)
assert self.ds_model_rank > -1
assert self.ds_model_proc_group is not None
# Create new ProcessGroup for gradient all-reduces - these are the data parallel groups
self.dp_group = []
self.dp_groups = self._topo.get_axis_comm_lists('data')
for g in self.dp_groups:
proc_group = dist.new_group(ranks=g)
if self.global_rank in g:
self.dp_group = g
self.dp_proc_group = proc_group
self.is_first_stage = (self.stage_id == 0)
self.is_last_stage = (self.stage_id == (self.pipe_parallel_size - 1))
self.p2p_groups = self._build_p2p_groups()
# Create new ProcessGroup for pipeline collectives - these are pipe parallel groups
self.pp_group = []
self.pp_proc_group = None
self.pipe_groups = self._topo.get_axis_comm_lists('pipe')
for ranks in self.pipe_groups:
if self.global_rank == 0:
#print(f'RANK={self.global_rank} building pipeline group: {ranks}')
pass
proc_group = dist.new_group(ranks=ranks)
if self.global_rank in ranks:
self.pp_group = ranks
self.pp_proc_group = proc_group
assert self.pp_proc_group is not None
# Create new ProcessGroup for model (tensor-slicing) collectives
# Short circuit case without model parallelism.
# TODO: it would be nice if topology had bcast semantics to avoid this branching
# case?
if self.model_parallel_size == 1:
for group_rank in range(self.world_size):
group_rank = [group_rank]
group = dist.new_group(ranks=group_rank)
if group_rank[0] == self.global_rank:
self.slice_group = group_rank
self.slice_proc_group = group
return
else:
self.mp_group = []
self.model_groups = self._topo.get_axis_comm_lists('model')
for g in self.model_groups:
proc_group = dist.new_group(ranks=g)
if self.global_rank in g:
self.slice_group = g
self.slice_proc_group = proc_group
def get_stage_id(self):
return self._topo.get_coord(rank=self.global_rank).pipe
def get_data_parallel_id(self):
return self._topo.get_coord(rank=self.global_rank).data
def _build_p2p_groups(self):
"""Groups for sending and receiving activations and gradients across model
parallel stages.
"""
comm_lists = self._topo.get_axis_comm_lists('pipe')
p2p_lists = []
for rank in range(self.world_size):
for l in comm_lists:
assert len(l) == self.pipe_parallel_size
if rank in l:
idx = l.index(rank)
buddy_rank = l[(idx + 1) % self.pipe_parallel_size]
p2p_lists.append([rank, buddy_rank])
break # next global rank
assert len(p2p_lists) == self.world_size
return p2p_lists
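# Example: for PipeDataParallelTopology(num_pp=2, num_dp=2) (ranks 0-3), the 'pipe'
# comm lists are [0, 2] and [1, 3], so
#
#   p2p_lists == [[0, 2], [1, 3], [2, 0], [3, 1]]
#
# i.e. entry `rank` is [rank, next-stage buddy], wrapping around at the last stage.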
def _is_grid_valid(self):
ranks = 1
for ax in self._topo.get_axis_names():
ranks *= self._topo.get_dim(ax)
return ranks == dist.get_world_size()
# Returns the global rank of the process with the provided stage id
# which has the same data_parallel_id as the caller process.
def stage_to_global(self, stage_id, **kwargs):
me = self._topo.get_coord(self.global_rank)
transform = me._replace(pipe=stage_id, **kwargs)._asdict()
return self._topo.get_rank(**transform)
def topology(self):
return self._topo
# MPU functions for DeepSpeed integration
def get_global_rank(self):
return self.global_rank
def get_pipe_parallel_rank(self):
""" The stage of the pipeline this rank resides in. """
return self.get_stage_id()
def get_pipe_parallel_world_size(self):
""" The number of stages in the pipeline. """
return self.pipe_parallel_size
def get_pipe_parallel_group(self):
""" The group of ranks within the same pipeline. """
return self.pp_proc_group
def get_data_parallel_rank(self):
""" Which pipeline this rank resides in. """
return self.data_parallel_id
def get_data_parallel_world_size(self):
""" The number of pipelines. """
return self.data_parallel_size
def get_data_parallel_group(self):
""" The group of ranks within the same stage of all pipelines. """
return self.dp_proc_group
# These are model parallel groups across all types of model parallelism.
# Deepspeed uses them to detect overflow, etc.
def get_model_parallel_rank(self):
return self.ds_model_rank
def get_model_parallel_world_size(self):
return self.ds_model_world_size
def get_model_parallel_group(self):
return self.ds_model_proc_group
# For Megatron-style tensor slicing
def get_slice_parallel_rank(self):
if 'model' in self._topo.get_axis_names():
return self._topo.get_coord(rank=self.global_rank).model
else:
return 0
def get_slice_parallel_world_size(self):
return self.model_parallel_size
def get_slice_parallel_group(self):
return self.slice_proc_group
...@@ -6,11 +6,36 @@ Copyright NVIDIA/Megatron
Helper functions and classes from multiple sources.
'''
import os
from math import ceil
from math import floor
from bisect import bisect_left, bisect_right
import torch
import torch.distributed as dist
from torch._six import inf
from deepspeed.utils import logger
from numpy import prod
def ensure_directory_exists(filename):
"""Create the directory path to ``filename`` if it does not already exist.
Args:
filename (str): A file path.
"""
dirname = os.path.dirname(filename)
os.makedirs(dirname, exist_ok=True)
def set_random_seed(seed):
import numpy
import random
random.seed(seed)
numpy.random.seed(seed)
torch.manual_seed(seed)
class CheckOverflow(object): class CheckOverflow(object):
...@@ -85,6 +110,7 @@ class CheckOverflow(object): ...@@ -85,6 +110,7 @@ class CheckOverflow(object):
torch.distributed.all_reduce(overflow_gpu, torch.distributed.all_reduce(overflow_gpu,
op=torch.distributed.ReduceOp.MAX, op=torch.distributed.ReduceOp.MAX,
group=self.mpu.get_model_parallel_group()) group=self.mpu.get_model_parallel_group())
overflow = overflow_gpu[0].item() overflow = overflow_gpu[0].item()
return bool(overflow) return bool(overflow)
...@@ -160,9 +186,8 @@ def get_grad_norm(parameters, norm_type=2, mpu=None): ...@@ -160,9 +186,8 @@ def get_grad_norm(parameters, norm_type=2, mpu=None):
total_norm = 0. total_norm = 0.
for p in parameters: for p in parameters:
if mpu is not None: if mpu is not None:
if (mpu.get_model_parallel_rank() == 0) or (hasattr(p, if (mpu.get_model_parallel_rank() == 0
'model_parallel') ) or is_model_parallel_parameter(p):
and p.model_parallel):
param_norm = p.grad.data.float().norm(norm_type) param_norm = p.grad.data.float().norm(norm_type)
total_norm += param_norm.item()**norm_type total_norm += param_norm.item()**norm_type
else: else:
...@@ -218,9 +243,8 @@ def get_weight_norm(parameters, norm_type=2, mpu=None): ...@@ -218,9 +243,8 @@ def get_weight_norm(parameters, norm_type=2, mpu=None):
total_norm = 0. total_norm = 0.
for p in parameters: for p in parameters:
if mpu is not None: if mpu is not None:
if (mpu.get_model_parallel_rank() == 0) or (hasattr(p, if (mpu.get_model_parallel_rank() == 0
'model_parallel') ) or is_model_parallel_parameter(p):
and p.model_parallel):
try: try:
param_norm = float(torch.norm(p, norm_type, dtype=torch.float32)) param_norm = float(torch.norm(p, norm_type, dtype=torch.float32))
except TypeError as err: except TypeError as err:
...@@ -255,6 +279,255 @@ def is_model_parallel_parameter(p): ...@@ -255,6 +279,255 @@ def is_model_parallel_parameter(p):
return hasattr(p, 'model_parallel') and p.model_parallel return hasattr(p, 'model_parallel') and p.model_parallel
def prefix_sum_inc(weights):
""" Compute an inclusive prefix sum.
Example:
>>> prefix_sum_inc([3,4,5])
[3, 7, 12]
"""
weights_ = [w for w in weights]
for x in range(1, len(weights_)):
weights_[x] += weights_[x - 1]
return weights_
def partition_uniform(num_items, num_parts):
parts = [0] * (num_parts + 1)
# First check for the trivial edge case
if num_items <= num_parts:
for p in range(num_parts + 1):
parts[p] = min(p, num_items)
return parts
chunksize = floor(num_items / num_parts)
for p in range(num_parts):
parts[p] = min(chunksize * p, num_items)
parts[num_parts] = num_items
return parts
def _lprobe(weights, num_parts, bottleneck):
num_items = len(weights)
total_weight = weights[-1]
# initialize partitioning
parts = [0] * (num_parts + 1)
for p in range(1, num_parts + 1):
parts[p] = num_items
bsum = bottleneck # running sum of target weight for pth partition
chunksize = num_items // num_parts
step = chunksize
for p in range(1, num_parts):
# Jump to the next bucket
while (step < num_items) and (weights[step] < bsum):
step += chunksize
# Find the end index of partition p
parts[p] = bisect_left(weights,
bsum,
lo=step - chunksize,
hi=min(step,
num_items))
# Nothing more to partition, return early
if parts[p] == num_items:
# See if the current partition is overweight.
part_size = weights[-1] - weights[parts[p - 1]]
return parts, part_size < bottleneck
# Next partition target
bsum = weights[parts[p] - 1] + bottleneck
return parts, bsum >= total_weight
def _rb_partition_balanced(weights, num_parts, eps):
total_weight = weights[-1]
lower = total_weight / num_parts # best case heaviest partition
upper = total_weight # worst case heaviest partition
# Do a binary search for the best partitioning
while upper > lower + eps:
mid = lower + ((upper - lower) / 2)
parts, success = _lprobe(weights, num_parts, mid)
if success:
upper = mid
else:
lower = mid + eps
return upper
def partition_balanced(weights, num_parts, eps=1e-3):
num_items = len(weights)
# First check for the trivial edge case
if num_items <= num_parts:
return partition_uniform(num_items, num_parts)
weights_ = prefix_sum_inc(weights)
# Find the smallest bottleneck (weight of heaviest partition)
bottleneck = _rb_partition_balanced(weights_, num_parts, eps=eps)
# Now compute that partitioning
parts, success = _lprobe(weights_, num_parts, bottleneck)
assert success
return parts
class PartitionedTensor:
def __init__(self, tensor, group, partition_meta=None):
super().__init__()
self.group = group
self.num_parts = dist.get_world_size(group=self.group)
self.rank = dist.get_rank(group=self.group)
self.orig_size = list(tensor.size())
self.orig_device = tensor.device
self.local_data, self.partition = self._partition_tensor(tensor)
@classmethod
def from_meta(cls, meta, local_part, group, device='cuda'):
assert meta.dtype == torch.long
dummy = torch.ones(dist.get_world_size(group=group))
part_obj = cls(tensor=dummy, group=group)
meta = meta.tolist()
# [N, list0, ..., listN-1]
part_obj.orig_size = meta[1:(1 + meta[0])]
meta = meta[1 + meta[0]:]
part_obj.orig_device = device
part_obj.local_data = local_part.detach()
part_obj.group = group
# Partition is encoded like the rowptr of a CSR matrix:
# [num_parts, rank, 0, part_1, ..., part_num_parts]
# TODO: support shuffle between different partition granularities
assert part_obj.num_parts == meta[0]
assert part_obj.rank == meta[1]
part_obj.partition = meta[2:] # length num_parts+1
return part_obj
def _partition_tensor(self, tensor):
partition = partition_uniform(num_items=tensor.numel(), num_parts=self.num_parts)
start = partition[self.rank]
length = partition[self.rank + 1] - start
tensor_part = tensor.detach().contiguous().view(-1).narrow(
0,
start=start,
length=length).clone()
return tensor_part, partition
def full(self, device=None):
if device is None:
device = self.orig_device
# Allocate the full tensor as a flat buffer.
full_numel = prod(self.full_size())
flat_tensor = torch.zeros([full_numel],
dtype=self.local_data.dtype,
device=device)
# Prepare all-gather buffer
partition_tensors = []
for part_id in range(self.num_parts):
part_size = self.partition[part_id + 1] - self.partition[part_id]
buf = flat_tensor.narrow(0, start=self.partition[part_id], length=part_size)
if part_id == self.rank:
buf.copy_(self.local_data)
partition_tensors.append(buf)
# Collect the full tensor
dist.all_gather(partition_tensors,
partition_tensors[self.rank],
group=self.group)
for i in range(len(partition_tensors)):
partition_tensors[i].data = torch.zeros(1)
partition_tensors[i] = None
return flat_tensor.view(self.full_size()).clone().detach()
def to_meta(self):
"""Returns a torch.LongTensor that encodes partitioning information.
Can be used along with ``data()`` to serialize a ``PartitionedTensor`` for
communication.
Returns:
torch.LongTensor: a tensor encoding the meta-information for the partitioning
"""
meta = []
meta.append(len(self.orig_size))
meta += list(self.orig_size)
meta.append(self.num_parts)
meta.append(self.rank)
meta += self.partition
return torch.LongTensor(data=meta).to(self.orig_device)
def data(self):
return self.local_data
def local_size(self):
return self.local_data.size()
def full_size(self):
return self.orig_size
mem_alloced = 0
mem_cached = 0
def memory_status(msg, print_rank=-1, reset_max=False):
global mem_alloced, mem_cached
rank = dist.get_rank()
if print_rank != -1 and rank != print_rank:
return
torch.cuda.synchronize()
if reset_max:
torch.cuda.reset_max_memory_cached()
torch.cuda.reset_max_memory_allocated()
new_alloced = torch.cuda.memory_allocated()
new_cached = torch.cuda.memory_cached()
delta_alloced = new_alloced - mem_alloced
delta_cached = new_cached - mem_cached
mem_cached = new_cached
mem_alloced = new_alloced
max_alloced = torch.cuda.max_memory_allocated()
max_cached = torch.cuda.max_memory_cached()
# convert to GB for printing
new_alloced /= 1024**3
new_cached /= 1024**3
delta_alloced /= 1024**3
delta_cached /= 1024**3
max_alloced /= 1024**3
max_cached /= 1024**3
print(
f'RANK={rank} MEMSTATS',
msg,
f'device={torch.cuda.current_device()} '
f'current alloc={new_alloced:0.4f}GB (delta={delta_alloced:0.4f}GB max={max_alloced:0.4f}GB) '
f'current cache={new_cached:0.4f}GB (delta={delta_cached:0.4f}GB max={max_cached:0.4f}GB)'
)
def see_memory_usage(message): def see_memory_usage(message):
return return
if torch.distributed.is_initialized() and not torch.distributed.get_rank() == 0: if torch.distributed.is_initialized() and not torch.distributed.get_rank() == 0:
...@@ -278,3 +551,25 @@ def see_memory_usage(message): ...@@ -278,3 +551,25 @@ def see_memory_usage(message):
"Max cache Allocated %s GigaBytes", "Max cache Allocated %s GigaBytes",
torch.cuda.max_memory_cached() / (1024 * 1024 * 1024), torch.cuda.max_memory_cached() / (1024 * 1024 * 1024),
) )
def call_to_str(base, *args, **kwargs):
"""Construct a string representation of a call.
Args:
base (str): name of the call
args (tuple, optional): args to ``base``
kwargs (dict, optional): kwargs supplied to ``base``
Returns:
str: A string representation of base(*args, **kwargs)
"""
name = f'{base}('
    if args:
        name += ', '.join(repr(arg) for arg in args)
        if kwargs:
            name += ', '
    if kwargs:
        name += ', '.join(f'{key}={repr(arg)}' for key, arg in kwargs.items())
name += ')'
return name
...@@ -26,13 +26,6 @@ def get_group_alignment_padding(tensor_list, sub_partition_size, sub_partition_c ...@@ -26,13 +26,6 @@ def get_group_alignment_padding(tensor_list, sub_partition_size, sub_partition_c
padding = get_alignment_padding(flattened_size, i, sub_partition_size) padding = get_alignment_padding(flattened_size, i, sub_partition_size)
group_paddings.append(padding) group_paddings.append(padding)
logger.info("****Padding information*****")
logger.info(f"tensor_size = {flattened_size}")
logger.info(f"sub_partition_size = {sub_partition_size}")
logger.info(f"sub_partition_count = {sub_partition_count}")
for i, padding in enumerate(group_paddings):
logger.info(f"padding[{i}] = {padding}")
return group_paddings return group_paddings
......
from deepspeed.utils.logging import logger, log_dist from deepspeed.utils.logging import logger, log_dist
from deepspeed.runtime.dataloader import RepeatingLoader
...@@ -12,9 +12,9 @@ from deepspeed.utils import logger ...@@ -12,9 +12,9 @@ from deepspeed.utils import logger
def print_rank_0(message): def print_rank_0(message):
if torch.distributed.is_initialized(): if torch.distributed.is_initialized():
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
logger.info(message) print(message)
else: else:
logger.info(message) print(message)
class SynchronizedWallClockTimer: class SynchronizedWallClockTimer:
...@@ -85,12 +85,15 @@ class SynchronizedWallClockTimer: ...@@ -85,12 +85,15 @@ class SynchronizedWallClockTimer:
def log(self, names, normalizer=1.0, reset=True, memory_breakdown=False): def log(self, names, normalizer=1.0, reset=True, memory_breakdown=False):
"""Log a group of timers.""" """Log a group of timers."""
assert normalizer > 0.0 assert normalizer > 0.0
string = 'time (ms)' string = f'rank={torch.distributed.get_rank()} time (ms)'
for name in names: for name in names:
elapsed_time = self.timers[name].elapsed(reset=reset) * 1000.0 / normalizer if name in self.timers:
string += ' | {}: {:.2f}'.format(name, elapsed_time) elapsed_time = self.timers[name].elapsed(
if memory_breakdown: reset=reset) * 1000.0 / normalizer
string += self.memory_usage() string += ' | {}: {:.2f}'.format(name, elapsed_time)
        # TODO: use our logging utilities to selectively print. Useful for model
# parallelism because rank=0 is too restrictive.
print_rank_0(string) print_rank_0(string)
......
...@@ -66,5 +66,7 @@ lnav: ...@@ -66,5 +66,7 @@ lnav:
url: /tutorials/lrrt/ url: /tutorials/lrrt/
- title: "DeepSpeed Sparse Attention" - title: "DeepSpeed Sparse Attention"
url: /tutorials/sparse-attention/ url: /tutorials/sparse-attention/
- title: "Pipeline Parallelism"
url: /tutorials/pipeline/
- title: "Contributing" - title: "Contributing"
url: /contributing/ url: /contributing/
---
title: "Pipeline Parallelism"
---
DeepSpeed v0.3 includes new support for pipeline parallelism! Pipeline
parallelism improves both the memory and compute efficiency of deep learning
training by partitioning the layers of a model into stages that can be
processed in parallel.
DeepSpeed's training engine provides hybrid data and pipeline parallelism and
can be further combined with model parallelism such as
[Megatron-LM](https://github.com/NVIDIA/Megatron-LM).
An illustration of
3D parallelism is shown below. Our latest [results](linklinklink)
demonstrate that this 3D parallelism enables training models with over a
**trillion** parameters.
![3D parallelism in DeepSpeed](/assets/images/3d-parallelism.png)
DeepSpeed uses *gradient accumulation* to extract pipeline parallelism (shown
below). Each batch of training data is divided into micro-batches that can be
processed in parallel by the pipeline stages. Once a stage completes the
forward pass for a micro-batch, the activation memory is communicated to the
next stage in the pipeline. Similarly, as the next stage completes its
backward pass on a micro-batch, the gradient with respect to the activation
is communicated backwards through the pipeline. Each backward pass
accumulates gradients locally. Next, all data parallel groups perform
reductions of the gradients in parallel. Lastly, the optimizer updates the
model weights.
Below is an illustration of how DeepSpeed will train a batch with eight
micro-batches using hybrid two-way data parallelism and two-stage pipeline
parallelism. GPUs 0 and 2 are arranged in a pipeline and will alternate
forward (F) and backward (B) passes. They will then all-reduce (AR) gradients
with their data parallel counterparts, GPUs 1 and 3, respectively. Finally,
the two pipeline stages update their model weights.
![Pipeline Schedule](/assets/images/pipe-schedule.png)
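In terms of DeepSpeed's batch size settings, the schedule above uses eight
micro-batches per pipeline and two data parallel pipeline replicas. A minimal
sketch of a consistent configuration follows; the concrete sizes are
illustrative assumptions, not values from this tutorial:

```python
# Illustrative only: the effective batch size consumed by each train_batch() is
#   train_batch_size = micro_batch_size * gradient_accumulation_steps * data_parallel_size
micro_batch_size = 16     # assumed per-GPU micro-batch size
grad_accum_steps = 8      # micro-batches per pipeline, as in the figure
data_parallel_size = 2    # two pipeline replicas, as in the figure

ds_config = {
    "train_micro_batch_size_per_gpu": micro_batch_size,
    "gradient_accumulation_steps": grad_accum_steps,
    "train_batch_size": micro_batch_size * grad_accum_steps * data_parallel_size,
}
```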
## Getting Started with Pipeline Parallelism
DeepSpeed strives to accelerate *and* simplify the process of pipeline
parallel training. This section provides first steps with hybrid data and
pipeline parallel training by preparing `torchvision`'s
[AlexNet](https://pytorch.org/docs/1.2.0/_modules/torchvision/models/alexnet.html)
model.
### Expressing Pipeline Models
Pipeline parallelism requires models to be expressed as a sequence of layers.
In the forward pass, each layer consumes the output of the previous
layer. In fact, there is no need to specify a `forward()` for a pipeline
parallel model! The forward pass of a pipeline parallel model implicitly
takes the form:
```python
def forward(self, inputs):
x = inputs
for layer in self.layers:
x = layer(x)
return x
```
PyTorch's
[`torch.nn.Sequential`](https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html)
is a convenient container for expressing pipeline parallel models and can be
parallelized by DeepSpeed with no modification:
```python
net = nn.Sequential(
nn.Linear(in_features, hidden_dim),
nn.ReLU(inplace=True),
nn.Linear(hidden_dim, out_features)
)
from deepspeed.pipe import PipelineModule
net = PipelineModule(layers=net, num_stages=2)
```
`PipelineModule` uses its `layers` argument as the sequence of layers that
comprise the model. After initialization, `net` is divided into two pipeline
stages and its layers are moved to the corresponding GPUs. If more than two GPUs
are present, DeepSpeed will also use hybrid data parallelism.
**Note:** The total number of GPUs must be divisible by the number of pipeline
stages.
{: .notice--info}
**Note:** For large model training, see [memory-efficient model construction](#memory-efficient-module-initialization).
{: .notice--info}
### AlexNet
Let's look at an abbreviated implementation of `torchvision`'s
[AlexNet](https://pytorch.org/docs/1.2.0/_modules/torchvision/models/alexnet.html):
```python
class AlexNet(nn.Module):
def __init__(self, num_classes=1000):
super(AlexNet, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
...
nn.MaxPool2d(kernel_size=3, stride=2),
)
self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
self.classifier = nn.Sequential(
nn.Dropout(),
...
nn.Linear(4096, num_classes),
)
def forward(self, x):
x = self.features(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.classifier(x)
return x
```
`AlexNet` is mostly a composition of several `Sequential` submodules. We can
turn this into a `PipelineModule` by flattening its submodules into a single
sequence of layers:
```python
class AlexNetPipe(AlexNet):
def to_layers(self):
layers = [
*self.features,
self.avgpool,
lambda x: torch.flatten(x, 1),
*self.classifier
]
return layers
from deepspeed.pipe import PipelineModule
net = AlexNetPipe()
net = PipelineModule(layers=net.to_layers(), num_stages=2)
```
**Note:**
the `lambda` in the middle of `layers` above is not a `torch.nn.Module`
type. Any object that implements `__call__()` can be a layer in a
`PipelineModule`: this allows for convenient data transformations in the
pipeline. A sketch of an equivalent callable class follows this note.
{: .notice--info}
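For illustration, the `lambda` could also be written as a small callable
class (a sketch, not part of the tutorial's code):

```python
import torch

class Flatten:
    """A plain callable (not a torch.nn.Module) that can serve as a pipeline
    layer in place of the lambda above."""
    def __call__(self, x):
        # Collapse every dimension after the batch dimension.
        return torch.flatten(x, 1)
```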
### Inputs and Outputs
Following `torch.nn.Sequential`, the inputs and outputs of each layer must be
either a single `torch.Tensor` or a `tuple` of tensors. In practice, some
models may need to modify their forward pass to pack and unpack arguments to
`forward()`. Consider an abbreviated implementation of a stack of Transformer
blocks:
```python
class TransformerBlock(nn.Module):
...
def forward(self, hidden, mask):
output = self.compute(hidden, mask)
return output
...
stack = [ TransformerBlock() for _ in range(num_layers) ]
```
Two modifications to `TransformerBlock` are required:
1. The arguments must be collected into a `tuple`.
2. `mask` must also be returned from `forward()` to pass to the next layer.
These modifications can be accomplished with a short subclass:
```python
class TransformerBlockPipe(TransformerBlock):
def forward(self, inputs):
hidden, mask = inputs
outputs = super().forward(hidden, mask)
        return (outputs, mask)
stack = [ TransformerBlockPipe() for _ in range(num_layers) ]
```
### Training Loops
Pipeline parallelism interleaves forward and backward passes, and thus the
training loop cannot be divided into separate stages of `forward()`,
`backward()` and `step()`.
Instead, DeepSpeed's pipeline engine provides a `train_batch()` method that
advances the pipeline engine until the next batch of training data is
consumed and the model weights updated.
```python
train_iter = iter(train_loader)
loss = engine.train_batch(data_iter=train_iter)
```
The above `train_batch()` example is equivalent to the following with
traditional data parallel DeepSpeed:
```python
train_iter = iter(train_loader)
for micro_batch in range(engine.gradient_accumulation_steps()):
    batch = next(train_iter)
loss = engine(batch)
engine.backward(loss)
engine.step()
```
### Dealing with Data
Data parallel training typically has each worker perform IO independently at
the start of each batch. However, in a pipeline parallel environment, only the
first stage uses the input data, and only the last stage uses labels for loss
calculation.
**Note:**
The pipeline engine expects data loaders to return a `tuple` of two items. The
first returned item is the input batch data, and the second item is the data
to be used in the loss calculation. As before, inputs and labels should each be
either a `torch.Tensor` or a `tuple` of tensors; a sketch of a conforming
dataset follows this note.
{: .notice--info}
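A minimal sketch of a dataset that follows this convention; the class and
tensor shapes are hypothetical and only illustrate the `(inputs, labels)`
packing:

```python
import torch
from torch.utils.data import Dataset

class RandomTokenDataset(Dataset):
    """Hypothetical dataset: each sample is ((tokens, mask), labels), so the
    first pipeline stage receives a tuple of inputs and the last stage
    receives the labels for its loss calculation."""
    def __init__(self, num_samples=1024, seq_len=128, vocab_size=1000):
        self.tokens = torch.randint(vocab_size, (num_samples, seq_len))
        self.mask = torch.ones(num_samples, seq_len, dtype=torch.long)
        self.labels = torch.randint(vocab_size, (num_samples, seq_len))

    def __len__(self):
        return self.tokens.size(0)

    def __getitem__(self, idx):
        return (self.tokens[idx], self.mask[idx]), self.labels[idx]
```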
For convenience, the DeepSpeed pipeline engine can construct a distributed
data loader when a dataset is provided to `deepspeed.initialize()`. DeepSpeed
handles the rest of the complexity of data loading, and so the pipeline
training loop becomes:
```python
engine, _, _, _ = deepspeed.initialize(
args=args,
model=net,
model_parameters=[p for p in net.parameters() if p.requires_grad],
training_data=cifar_trainset())
for step in range(args.steps):
loss = engine.train_batch()
```
Of course, DeepSpeed will work with any data loader that you wish to use.
Data loaders should be constructed by the first and last stages in the
pipeline. Each worker should load micro-batches of size
`engine.train_micro_batch_size_per_gpu()` and will be queried
a total of `engine.gradient_accumulation_steps()` times per `train_batch()`.
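A sketch of sizing such a loader for a single training batch; `dataset` is
assumed to exist and `engine` is the object returned by
`deepspeed.initialize()`:

```python
from torch.utils.data import DataLoader

loader = DataLoader(dataset,
                    batch_size=engine.train_micro_batch_size_per_gpu(),
                    shuffle=True,
                    drop_last=True)

# train_batch() will pull engine.gradient_accumulation_steps() micro-batches
# from this iterator.
train_iter = iter(loader)
loss = engine.train_batch(data_iter=train_iter)
```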
**Watch out!**
The pipeline engine *pulls* data from an iterator instead of iterating over
it. It's critical that the data stream does not empty in the middle of a
training batch. Each invocation of `train_batch()` will pull
a total of `engine.gradient_accumulation_steps()` micro-batches of data from
the data iterator.
{: .notice--warning}
DeepSpeed provides a convenience class `deepspeed.utils.RepeatingLoader` that
simply wraps an iterable such as a data loader and restarts it whenever the
end is reached:
```python
train_loader = deepspeed.utils.RepeatingLoader(train_loader)
train_iter = iter(train_loader)
for step in range(args.steps):
    loss = engine.train_batch(data_iter=train_iter)
```
## Advanced Topics
### Load Balancing Pipeline Modules
The performance of pipeline parallel training depends strongly on how evenly
the computational load is balanced across stages. DeepSpeed provides several
mechanisms for partitioning the model across GPUs. These strategies can be set
with the `partition_method` keyword argument to `PipelineModule`; an example
follows the list below. The partitioning methods currently provided by
DeepSpeed are:
* `partition_method="parameters"` (**default**)
  balances the number of trainable parameters on each pipeline stage. This is
especially useful in memory-constrained environments and when the size of a
layer is proportional to the computation time.
* `partition_method="type:[regex]"`
balances layers whose class names match `[regex]`. The regular expression
is not case sensitive. For example, `partition_method="type:transformer"`
would balance the number of transformer layers per stage.
* `partition_method="uniform"` balances the number of layers per stage.
### Memory-Efficient Model Construction
Building a `Sequential` and providing it to `PipelineModule` is a convenient way
of specifying a pipeline parallel model. However, this approach encounters
scalability issues for massive models: every worker first allocates the entire
model in CPU memory, so a machine with 16 GPUs needs as much local CPU memory
as 16 copies of the model.
DeepSpeed provides a `LayerSpec` class that delays the construction of
modules until the model layers have been partitioned across workers. Then,
the modules are built on the GPU that owns the layer.
Here's an example of the abbreviated AlexNet model, but expressed only
with `LayerSpec`s. Note that the syntax is almost unchanged: `nn.ReLU(inplace=True)`
simply becomes `LayerSpec(nn.ReLU, inplace=True)`.
```python
from deepspeed.pipe import PipelineModule, LayerSpec
class AlexNetPipe(PipelineModule):
def __init__(self, num_classes=10, **kwargs):
self.num_classes = num_classes
specs = [
LayerSpec(nn.Conv2d, 3, 64, kernel_size=11, stride=4, padding=2),
LayerSpec(nn.ReLU, inplace=True),
...
LayerSpec(nn.ReLU, inplace=True),
LayerSpec(nn.Linear, 4096, self.num_classes),
]
super().__init__(layers=specs, loss_fn=nn.CrossEntropyLoss(), **kwargs)
```
### Tied Layers
Some models cannot be entirely expressed as pipeline parallel models because
some layers are reused in the pipeline. For example, Transformer based
language models commonly use an embedding layer early in the pipeline to map
vocabulary to hidden states, and then use the embedding to map hidden states
back to vocabulary at the end of the pipeline. If the model were restricted to
a strictly sequential partitioning, this embedding reuse would make pipeline
parallelism impossible.
DeepSpeed provides a `TiedLayerSpec` that is an extension of
`LayerSpec`. `TiedLayerSpec` requires an additional argument: `key`.
Each reuse of a layer is specified with a `TiedLayerSpec`, and the `key` field
is used to identify where a layer is reused.
Tied layers are replicated on every pipeline stage that reuses them. Training
then proceeds as normal, but an additional all-reduce of the tied gradients is
added after all backward passes complete. This all-reduce ensures that the
weights of the tied layer stay in sync across the pipeline stages that own a copy.
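As an illustration, weight tying between the input embedding and the output
end of a language model could be declared as follows. This is only a
structural sketch: the `nn` modules and sizes are stand-ins, and a real model
would use layers whose forward passes make sense at both ends of the pipeline;
only `TiedLayerSpec`, `LayerSpec`, and the `key` argument come from DeepSpeed:

```python
import torch.nn as nn
from deepspeed.pipe import PipelineModule, LayerSpec, TiedLayerSpec

vocab_size, hidden_dim, num_layers = 1000, 128, 4   # illustrative sizes

specs = [
    # First use of the tied module: map token ids to hidden states.
    TiedLayerSpec('embed', nn.Embedding, vocab_size, hidden_dim),
    # Body of the pipeline; nn.Linear stands in for real transformer blocks.
    *[LayerSpec(nn.Linear, hidden_dim, hidden_dim) for _ in range(num_layers)],
    # Reusing the same key ties this layer's weights to the first occurrence;
    # the owning stages all-reduce the tied gradients after each backward pass.
    TiedLayerSpec('embed', nn.Embedding, vocab_size, hidden_dim),
]

net = PipelineModule(layers=specs, num_stages=2)
```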