Commit c25a91b6 authored by aiss

Merge branch 'ds-v0.9.2-rocm' into 'main'

Ds v0.9.2 rocm

See merge request dcutoolkit/deeplearing/deepspeed!2
parents d1596c94 af82b300
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import types
import torch
import numpy as np
......@@ -49,13 +51,13 @@ class ZeroOneAdam(torch.optim.Optimizer):
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
def __init__(self,
params,
deepspeed=None,
lr=1e-3,
bias_correction=True,
betas=(0.9, 0.999),
eps=1e-8,
eps_inside_sqrt=False,
weight_decay=0.,
......@@ -102,11 +104,12 @@ class ZeroOneAdam(torch.optim.Optimizer):
if self.comm_backend_name == 'nccl':
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
assert (
(TORCH_MAJOR == 1 and TORCH_MINOR >= 8) or TORCH_MAJOR >= 2
), "Please use torch 1.8 or greater to enable NCCL backend in 0/1 Adam. Alternatively, please specify 'mpi' as the 'comm_backend_name' in config file to proceed with the MPI backend"
assert dist.is_initialized() == True, "Please initialize the torch distributed backend."
from deepspeed.runtime.comm.nccl import NcclBackend
self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce')
self.comm_backend_handle = NcclBackend(self.deepspeed.mpu)
elif self.comm_backend_name == 'mpi':
......@@ -181,16 +184,12 @@ class ZeroOneAdam(torch.optim.Optimizer):
state['corrected_tensor_size'] = state['tensor_size']
if state['tensor_size'] % (self.size * self.divider) != 0:
state['corrected_tensor_size'] += ((self.size * self.divider) - (state['tensor_size'] %
(self.size * self.divider)))
state['server_chunk_size'] = state['corrected_tensor_size'] // self.size
get_accelerator().empty_cache()
state['worker_error'] = torch.zeros(state['corrected_tensor_size'], device=p.device)
state['server_error'] = torch.zeros(state['server_chunk_size'], device=p.device)
# Accumulation of momentum, i.e., the u variable in the 0/1 Adam paper
state['momentum_accumulator'] = torch.zeros_like(p.data)
get_accelerator().empty_cache()
......@@ -213,16 +212,10 @@ class ZeroOneAdam(torch.optim.Optimizer):
if self.size > 1:
with torch.no_grad():
grad_onebit = self.comm_backend_handle.compressed_allreduce(
grad, state['worker_error'], state['server_error'], self.deepspeed.local_rank)
if 'exp_avg_mask' in group:
if grad_onebit.device != group['exp_avg_mask'].device:
group['exp_avg_mask'] = group['exp_avg_mask'].to(device=grad_onebit.device)
grad_onebit.mul_(group['exp_avg_mask'])
exp_avg.mul_(beta1).add_(1 - beta1, grad_onebit)
else:
......@@ -233,15 +226,12 @@ class ZeroOneAdam(torch.optim.Optimizer):
if not self.initialize:
if self.size > 1:
comm_buffer.set_(
self.comm_backend_handle.compressed_allreduce(comm_buffer, state['worker_error'],
state['server_error'],
self.deepspeed.local_rank))
if 'exp_avg_mask' in group:
if comm_buffer.device != group['exp_avg_mask'].device:
group['exp_avg_mask'] = group['exp_avg_mask'].to(device=comm_buffer.device)
comm_buffer.mul_(group['exp_avg_mask'])
if self.initialize:
......@@ -252,22 +242,18 @@ class ZeroOneAdam(torch.optim.Optimizer):
p.data.add_(-group['lr'] * update)
if self.freeze_key is True:
comm_buffer.add_(-group['lr'] * update)
if state['step'] % state['local_step_interval'] == 0 and self.freeze_key:
with torch.no_grad():
p.data.add_(-1 * comm_buffer)
comm_buffer.mul_(exp_avg_sq.sqrt() + group['eps'])
if self.size > 1:
comm_buffer.copy_(
self.comm_backend_handle.compressed_allreduce(comm_buffer, state['worker_error'],
state['server_error'],
self.deepspeed.local_rank))
if 'exp_avg_mask' in group:
if comm_buffer.device != group['exp_avg_mask'].device:
group['exp_avg_mask'] = group['exp_avg_mask'].to(device=comm_buffer.device)
comm_buffer.mul_(group['exp_avg_mask'])
exp_avg.zero_().add_(comm_buffer / state['lrs'], alpha=-1)
p.data.add_(comm_buffer / (exp_avg_sq.sqrt() + group['eps']))
......@@ -298,9 +284,8 @@ class ZeroOneAdam(torch.optim.Optimizer):
state['local_step_counter'] += 1
if state['local_step_counter'] == self.local_step_scaler:
state['local_step_counter'] = 0
state['local_step_interval'] = min(self.local_step_clipper,
state['local_step_interval'] * 2)
if not self.initialize:
print('Pop out errors', flush=True)
......@@ -343,14 +328,13 @@ class ZeroOneAdam(torch.optim.Optimizer):
for i, group in enumerate(self.param_groups):
if 'exp_avg_mask' in group:
state_dict['param_groups'][i]['exp_avg_mask'] = group['exp_avg_mask']
elif 'exp_avg_mask' not in group and 'exp_avg_mask' in state_dict['param_groups'][i]:
state_dict['param_groups'][i].pop('exp_avg_mask')
super().load_state_dict(state_dict)
if self.state[self.param_groups[0]['params'][0]]['step'] < self.var_freeze_step:
self.var_freeze_key = False
if (self.state[self.param_groups[0]['params'][0]]['step'] +
1) % self.state[self.param_groups[0]['params'][0]]['var_interval'] == 0:
if self.using_pipeline:
self.deepspeed.pipeline_enable_backward_allreduce = True
else:
......
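For context, a minimal usage sketch of selecting this 0/1 Adam optimizer through a DeepSpeed config dict; this is illustrative only and not part of the patch, the model and all values below are placeholders, and the parameter names follow the documented 0/1 Adam options (the assert above is what enforces the torch/NCCL requirement, with "mpi" as the fallback backend):

import torch
import deepspeed

model = torch.nn.Linear(8, 2)  # placeholder model for illustration only
ds_config = {
    "train_batch_size": 32,
    "fp16": {"enabled": True},
    "optimizer": {
        "type": "ZeroOneAdam",
        "params": {
            "lr": 1e-3,
            "weight_decay": 0.01,
            "var_freeze_step": 1000,
            "local_step_scaler": 1000,
            "local_step_clipper": 16,
            "cuda_aware": False,
            "comm_backend_name": "nccl"  # or "mpi" when an NCCL-capable torch >= 1.8 is unavailable
        }
    }
}
# deepspeed.initialize wires the optimizer from the config; values above are placeholders.
model_engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                                     model_parameters=model.parameters(),
                                                     config=ds_config)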
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
Copyright NVIDIA/apex
This file is adapted from FP16_Optimizer in NVIDIA/apex
"""
from deepspeed.moe.utils import split_params_grads_into_shared_and_expert_params
import torch
......@@ -24,6 +26,7 @@ class FP16_UnfusedOptimizer(DeepSpeedOptimizer):
For usage example please see, TODO: DeepSpeed V2 Tutorial
"""
def __init__(self,
init_optimizer,
deepspeed=None,
......@@ -105,9 +108,7 @@ class FP16_UnfusedOptimizer(DeepSpeedOptimizer):
self.mpu = mpu
self.overflow = False
self.overflow_checker = CheckOverflow(self.fp16_groups, mpu=self.mpu, deepspeed=deepspeed)
self.initialize_optimizer_states()
......@@ -137,45 +138,33 @@ class FP16_UnfusedOptimizer(DeepSpeedOptimizer):
expert_norm_groups = []
for i, group in enumerate(self.fp16_groups):
grads = [
torch.zeros(p.size(), dtype=p.dtype, device=p.device) if p.grad is None else p.grad for p in group
]
grads_groups.append(grads)
grads_groups_flat.append(_flatten_dense_tensors(grads))
grads_for_norm, expert_grads_for_norm = split_params_grads_into_shared_and_expert_params(group)
norm_group_value = 0.0
if len(grads_for_norm) > 0:
norm_group_value = get_weight_norm(_flatten_dense_tensors(grads_for_norm), mpu=self.mpu)
norm_groups.append(norm_group_value)
expert_norm_group_value = 0.0
if len(expert_grads_for_norm) > 0:
expert_norm_group_value = get_weight_norm(_flatten_dense_tensors(expert_grads_for_norm), mpu=self.mpu)
expert_norm_groups.append(expert_norm_group_value)
self.overflow = self.overflow_checker.check_using_norm(norm_groups + expert_norm_groups)
prev_scale = self.cur_scale
self._update_scale(self.overflow)
if self.overflow:
if self.verbose:
logger.info("[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss "
"scale: {}, reducing to {}".format(prev_scale, self.cur_scale))
return self.overflow
self._global_grad_norm = get_global_norm(norm_list=norm_groups)
combined_scale = self.unscale_and_clip_grads(self._global_grad_norm, apply_scale=False)
self.optimizer.step(grads=grads_groups, output_params=self.fp16_groups, scale=combined_scale)
for fp32_group, fp16_group in zip(self.fp32_groups, self.fp16_groups):
for idx, (fp32_param, fp16_param) in enumerate(zip(fp32_group, fp16_group)):
......@@ -199,9 +188,7 @@ class FP16_UnfusedOptimizer(DeepSpeedOptimizer):
def override_loss_scale(self, loss_scale):
if loss_scale != self.external_loss_scale:
logger.info(f'[deepspeed] setting loss scale from {self.external_loss_scale} -> {loss_scale}')
self.custom_loss_scaler = True
self.external_loss_scale = loss_scale
......@@ -219,10 +206,8 @@ class FP16_UnfusedOptimizer(DeepSpeedOptimizer):
self._update_scale(self.overflow)
if self.overflow:
if self.verbose:
logger.info("[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss "
"scale: {}, reducing to {}".format(prev_scale, self.cur_scale))
return self.overflow
norm_groups = []
......@@ -236,9 +221,7 @@ class FP16_UnfusedOptimizer(DeepSpeedOptimizer):
# copying gradients to fp32 to work with fp32 parameters
for fp32_param, fp16_param in zip(self.fp32_groups[i], self.fp16_groups[i]):
if fp16_param.grad is None:
fp32_param.grad = torch.zeros(fp16_param.size(), dtype=fp32_param.dtype, device=fp32_param.device)
else:
fp32_param.grad = fp16_param.grad.to(fp32_param.dtype)
......@@ -294,25 +277,19 @@ class FP16_UnfusedOptimizer(DeepSpeedOptimizer):
if self.dynamic_loss_scale:
prev_scale = self.cur_scale
if skip:
self.cur_scale = max(self.cur_scale / self.scale_factor, self.min_loss_scale)
self.last_overflow_iter = self.cur_iter
if self.verbose:
logger.info("Grad overflow on iteration: %s", self.cur_iter)
logger.info(f"Reducing dynamic loss scale from {prev_scale} to {self.cur_scale}")
else:
# Ensure self.scale_window updates since last overflow
stable_interval = (self.cur_iter - self.last_overflow_iter) - 1
if (stable_interval > 0) and (stable_interval % self.scale_window == 0):
self.cur_scale *= self.scale_factor
if self.verbose:
logger.info(f"No Grad overflow for {self.scale_window} iterations")
logger.info(f"Increasing dynamic loss scale from {prev_scale} to {self.cur_scale}")
else:
if skip:
logger.info("Grad overflow on iteration %s", self.cur_iter)
......
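The dynamic loss-scaling policy exercised above can be summarized in a small standalone sketch (illustrative only, mirroring the `_update_scale` logic shown in the hunk): the scale shrinks by `scale_factor` and is floored at `min_loss_scale` on overflow, and grows by the same factor after `scale_window` consecutive overflow-free iterations.

def update_loss_scale(cur_scale, cur_iter, last_overflow_iter, overflow,
                      scale_factor=2.0, scale_window=1000, min_loss_scale=1.0):
    # Shrink on overflow, remember when it happened.
    if overflow:
        cur_scale = max(cur_scale / scale_factor, min_loss_scale)
        last_overflow_iter = cur_iter
    else:
        # Grow once every scale_window iterations since the last overflow.
        stable_interval = (cur_iter - last_overflow_iter) - 1
        if stable_interval > 0 and stable_interval % scale_window == 0:
            cur_scale *= scale_factor
    return cur_scale, last_overflow_iter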
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
from deepspeed.inference.config import DeepSpeedInferenceConfig
from deepspeed.module_inject.replace_policy import replace_policies
from deepspeed.module_inject.utils import policy_to_ds_container
from .engine import DeepSpeedEngine
from .utils import TLinear, get_inactive_params
from deepspeed.runtime.zero import GatheredParameters
import time
import gc
import math
from deepspeed import comm as dist
from deepspeed.accelerator import get_accelerator
from torch import nn
from deepspeed.utils import logger
from deepspeed.ops.op_builder import InferenceBuilder
from deepspeed.module_inject.layers import LinearLayer, Normalize, EmbeddingLayer, OPTEmbedding
try:
import transformers
OPTLearnedPositionalEmbedding = transformers.models.opt.modeling_opt.OPTLearnedPositionalEmbedding
except:
OPTLearnedPositionalEmbedding = None
inference_cuda_module = None
class DeepSpeedHybridEngine(DeepSpeedEngine):
r"""DeepSpeed engine for training and inference."""
inference_mp_group = None
def __init__(self, args, model, **kwargs):
super().__init__(args, model, **kwargs)
# synch seed between all GPUs
_rng_state = get_accelerator().get_rng_state().to(get_accelerator().current_device_name())
dist.broadcast(_rng_state, 0)
get_accelerator().set_rng_state(_rng_state.cpu())
self.Z3_enabled = (self._config.zero_config.stage == 3)
self.gather_all_layers = self._config.hybrid_engine.pin_parameters
# inference containers / fwds
self._inference_containers = []
self._orig_modules = []
self._orig_fwds = []
self.create_inference_module()
# Performance stats
self._t_start = None
self._total_latency = 0
self._iters = 0
self._training_start_time = None
self._generate_latency = 0
self._training_latency = 0
self._total_batch_size = None
self._gather_latency = 0
global inference_cuda_module
if inference_cuda_module is None:
builder = InferenceBuilder()
inference_cuda_module = builder.load()
self.is_lora_fused = False
def convert_to_linear_transposed(self, model):
def _replace_linear_layer(r_module, parent_type=None, prev_type=None):
for name, child in r_module.named_children():
if child.__class__ in [torch.nn.Linear] and \
(parent_type is torch.nn.ModuleList or prev_type is torch.nn.ModuleList):
setattr(r_module, name, TLinear(child, name))
else:
_replace_linear_layer(child, type(r_module), prev_type=parent_type)
return r_module
_replace_linear_layer(model)
def new_inference_container(self, orig_layer, policy_cls, layer_id):
policy = policy_cls(orig_layer, inference=True)
_container = policy_to_ds_container(
policy=policy,
config=DeepSpeedInferenceConfig(set_empty_params=True,
max_out_tokens=self._config.hybrid_engine.max_out_tokens,
min_out_tokens=self._config.hybrid_engine.max_out_tokens,
transposed_mode=True),
model_config=self.module.config if hasattr(self.module, 'config') else None,
layer_id=layer_id,
child=orig_layer)
_container.set_dtype(self._config.fp16_enabled)
if self.mpu is not None:
if hasattr(self.mpu, 'get_model_parallel_world_size'):
_container.set_tensor_parallel_config(self.mpu.get_model_parallel_world_size(),
self.mpu.get_model_parallel_group())
else:
_container.set_tensor_parallel_config(self.mpu.get_tensor_model_parallel_world_size(),
self.mpu.get_tensor_model_parallel_group())
else:
_container.set_tensor_parallel_config(self._config.hybrid_engine.inference_tp_size, self.mp_group)
_container.initialize_tensors(enable_training=True)
_container.create_ds_model_config()
_container.create_module()
_container.set_params_wo_copy(Z3_enabled=self.Z3_enabled)
return _container
def populate_all_inference_policies(self):
self.inference_policies = {}
for plcy in replace_policies:
_ = plcy(None)
if isinstance(plcy._orig_layer_class, list):
for orig_layer_class in plcy._orig_layer_class:
self.inference_policies.update({orig_layer_class: (self.new_inference_container, plcy)})
elif plcy._orig_layer_class is not None:
self.inference_policies.update({plcy._orig_layer_class: (self.new_inference_container, plcy)})
self.inference_policies.update({
nn.Linear: (LinearLayer, ),
nn.Embedding: (EmbeddingLayer, ),
nn.LayerNorm: (Normalize, ),
OPTLearnedPositionalEmbedding: (OPTEmbedding, )
})
def _fuse_lora(self, params, lora_params):
maybe_has_lora_params = [p for p in params if len(p.shape) > 1]
for lora_param, weight in zip(lora_params, maybe_has_lora_params):
if len(lora_param) == 3:
lora_right_weight, \
lora_left_weight, \
lora_scaling = lora_param
weight.data += lora_scaling * torch.matmul(lora_left_weight.t(), lora_right_weight.t())
def fuse_lora_weight(self):
for layer_id in range(len(self.layer_params)):
self._fuse_lora(self.layer_params[layer_id], self.lora_params[layer_id])
def _unfuse_lora(self, params, lora_params):
maybe_has_lora_params = [p for p in params if len(p.shape) > 1]
for lora_param, weight in zip(lora_params, maybe_has_lora_params):
if len(lora_param) == 3:
lora_right_weight, \
lora_left_weight, \
lora_scaling = lora_param
weight.data -= lora_scaling * torch.matmul(lora_left_weight.t(), lora_right_weight.t())
def unfuse_lora_weight(self):
for layer_id in range(len(self.layer_params)):
self._unfuse_lora(self.layer_params[layer_id], self.lora_params[layer_id])
def unfuse_lora_weight_non_pinned(self):
for layer_id in range(len(self.layer_params)):
non_active_params = get_inactive_params(self.layer_params[layer_id])
non_active_lora_params = get_inactive_params(self.layer_lora_params[layer_id])
non_active_params.extend(non_active_lora_params)
with GatheredParameters(non_active_params):
self._unfuse_lora(self.layer_params[layer_id], self.lora_params[layer_id])
def retake_inference_cache(self):
if self._config.hybrid_engine.release_inference_cache:
retake_success = inference_cuda_module.retake_workspace()
if not retake_success:
logger.warning("Unable to acquire workspace on first attempt, emtpying cache and retrying.")
gc.collect()
get_accelerator().empty_cache()
retake_success = inference_cuda_module.retake_workspace()
if not retake_success:
raise RuntimeError("Unable to retake inference workspace.")
def generate(self, *inputs, **kwargs):
if self._total_batch_size is None:
bsz = inputs[0].shape[0] if len(inputs) > 0 else \
kwargs['input_ids'].shape[0]
self._total_batch_size = bsz * dist.get_world_size()
self._t0 = time.time()
if self.Z3_enabled and self.gather_all_layers:
if self._config.hybrid_engine.inference_tp_size > 1:
non_tp_params = []
for other_layer in self._other_layers:
non_tp_params.extend(list(other_layer.parameters()))
partition_size = self._config.hybrid_engine.tp_gather_partition_size
layer_groups = math.ceil(len(self.layer_params) / partition_size)
for lg in range(layer_groups):
non_active_params = []
non_active_lora_params = []
for layer_id in range(lg * partition_size, min(len(self.layer_params), (lg + 1) * partition_size),
1):
non_tp_params.extend(self.layer_params[layer_id][:4])
non_active_params.extend(get_inactive_params(self.layer_params[layer_id]))
non_active_params.extend(get_inactive_params(self.layer_lora_params[layer_id]))
with GatheredParameters(non_active_params):
for layer_id in range(lg * partition_size,
min(len(self.layer_params), (lg + 1) * partition_size), 1):
if len(self.all_lora_params) > 0:
self._fuse_lora(self.layer_params[layer_id], self.lora_params[layer_id])
if self.mpu is not None:
self._inference_containers[layer_id].apply_tensor_parallelism(
mp_group=self.mp_group, tp_size=self._config.hybrid_engine.inference_tp_size)
# TODO(cmikeh2) Evaluate if this can be deferred when release_inference_cache
# is enabled.
gc.collect()
get_accelerator().empty_cache()
self._gather_latency = time.time() - self._t0
input_shape = inputs[0].shape if len(inputs) > 0 else \
kwargs['input_ids'].shape
output = torch.zeros(
(input_shape[0] * self._config.hybrid_engine.inference_tp_size, ) + input_shape[1:],
dtype=inputs[0].dtype if len(inputs) > 0 else kwargs['input_ids'].dtype,
device=inputs[0].device if len(inputs) > 0 else kwargs['input_ids'].device)
input_cont = inputs[0].contiguous() if len(inputs) > 0 else kwargs['input_ids'].contiguous()
dist.all_gather_into_tensor(output, input_cont, group=self.mp_group)
if len(inputs) > 0:
inputs = (output, )
else:
kwargs['input_ids'] = output
self.retake_inference_cache()
non_active_params = get_inactive_params(non_tp_params)
with GatheredParameters(non_active_params):
generate_ret_vals = self._generate(*inputs, **kwargs)
for layer_id in range(len(self.layer_params)):
self._inference_containers[layer_id].release_memory()
rank = dist.get_rank(group=self.mp_group)
generate_ret_vals = generate_ret_vals[input_shape[0] * rank:input_shape[0] * (rank + 1)]
else:
non_active_layers = get_inactive_params(self.all_layers_params)
non_active_lora_params = get_inactive_params(self.all_lora_params)
non_active_layers.extend(non_active_lora_params)
with GatheredParameters(non_active_layers):
self._gather_latency = time.time() - self._t0
if len(self.all_lora_params) > 0:
self.fuse_lora_weight()
self.retake_inference_cache()
generate_ret_vals = self._generate(*inputs, **kwargs)
if len(self.all_lora_params) > 0:
self.unfuse_lora_weight()
else:
if len(self.all_lora_params) > 0 and (not self.Z3_enabled):
self.fuse_lora_weight()
self.retake_inference_cache()
generate_ret_vals = self._generate(*inputs, **kwargs)
if len(self.all_lora_params) > 0:
if (not self.Z3_enabled):
self.unfuse_lora_weight()
else:
self.unfuse_lora_weight_non_pinned()
self.is_lora_fused = False
if self._config.hybrid_engine.release_inference_cache:
inference_cuda_module.release_workspace()
gc.collect()
get_accelerator().empty_cache()
self._generate_latency = time.time() - self._t0 - self._gather_latency
return generate_ret_vals
def create_inference_containers(self, module, layer_id=0):
for name, child in module.named_children():
if child.__class__ in self.inference_policies:
if self.inference_policies[child.__class__][0] == self.new_inference_container:
self._inference_containers.append(self.inference_policies[child.__class__][0](
child, self.inference_policies[child.__class__][-1], layer_id))
self._orig_modules.append(child)
self._orig_fwds.append(child.forward)
self.layer_params.append(self._inference_containers[layer_id].get_all_params())
self.lora_params.append(self._inference_containers[layer_id].get_lora_params())
self.layer_lora_params.append([])
for lora_param in self.lora_params[layer_id]:
self.layer_lora_params[layer_id].extend(lora_param[:-1])
self.all_lora_params.extend(lora_param[:-1])
layer_id += 1
else:
self._other_layers.append(self.inference_policies[child.__class__][0](
weight=child.weight, bias=child.bias if hasattr(child, 'bias') else None))
self._orig_modules_others.append(child)
self._orig_fwds_others.append(child.forward)
else:
self.create_inference_containers(child, layer_id=layer_id)
def create_inference_module(self):
self.layer_params = []
self.layer_lora_params = []
self.lora_params = []
self.all_lora_params = []
self._other_layers = []
self._orig_modules_others = []
self._orig_fwds_others = []
if self._config.hybrid_engine.inference_tp_size > 1:
if self.mpu is not None:
global_rank = dist.get_rank()
world_size = dist.get_world_size()
mp_group_id = global_rank // self._config.hybrid_engine.inference_tp_size
num_mp_groups = world_size // self._config.hybrid_engine.inference_tp_size
for mp_group_id in range(num_mp_groups):
ranks = list(
range(mp_group_id * self._config.hybrid_engine.inference_tp_size, \
(mp_group_id + 1) * self._config.hybrid_engine.inference_tp_size, \
1)
)
mp_group = dist.new_group(ranks)
if global_rank in ranks:
self.mp_group = mp_group
else:
self.mp_group = self.mpu.get_model_parallel_group() if hasattr(self.mpu, 'get_model_parallel_group') else \
self.mpu.get_tensor_model_parallel_group()
else:
self.mp_group = None
self.populate_all_inference_policies()
self.all_layers_params = list(self.module.parameters())
self.create_inference_containers(self.module)
if len(self._inference_containers) > 0:
self._generate = self.module.generate
self.module.generate = self.generate
self._t0 = time.time()
def _zero3_forward(self, layer_id):
def run_forward(*inputs, **kwargs):
non_active_params = get_inactive_params(self.layer_params[layer_id])
non_active_lora_params = get_inactive_params(self.layer_lora_params[layer_id])
non_active_params.extend(non_active_lora_params)
with GatheredParameters(non_active_params):
if len(self.all_lora_params) > 0:
# Use the is_lora_fused flag to prevent multiple fusion in Z3 with non-pinned memory
if not self.is_lora_fused:
self._fuse_lora(self.layer_params[layer_id], self.lora_params[layer_id])
# Set the is_lora_fused to true when reaching the last layer
if layer_id == len(self.layer_params) - 1:
self.is_lora_fused = True
return self._inference_containers[layer_id].module.forward(*inputs, **kwargs)
return run_forward
def eval(self):
if self._t_start is not None:
latency = time.time() - self._t_start
self._total_latency = self._total_latency + latency
self._iters = self._iters + 1
if not dist.is_initialized() or dist.get_rank() == 0:
others = latency - (self._generate_latency + self._training_latency)
print(f'|E2E latency={(latency):.2f}s ' + \
f'|Gather latency={self._gather_latency:.2f}s ({(self._gather_latency / latency * 100):.2f}%) '
f'|Generate time={(self._generate_latency):.2f}s ({(self._generate_latency / latency * 100):.2f}%) ' + \
f'|Training time={(self._training_latency):.2f}s ({(self._training_latency / latency * 100):.2f}%) ' + \
f'|Others={others:.2f} ({(others / latency * 100):.2f}%)'
f'|CurSamplesPerSec={(1 / latency * self._total_batch_size):.2f} ' + \
f'|AvgSamplesPerSec={(1 / (self._total_latency / self._iters) * self._total_batch_size):.2f}')
self._t_start = time.time()
self._training_latency = 0
super().eval()
if len(self._inference_containers) > 0:
for i, (orig_module, inference_container) in enumerate(zip(self._orig_modules,
self._inference_containers)):
if self.Z3_enabled and not self.gather_all_layers:
orig_module.forward = self._zero3_forward(i)
else:
orig_module.forward = inference_container.module.forward
inference_container.align_merged_qkv()
if not self.Z3_enabled or self.gather_all_layers:
for orig_module, inference_layer in zip(self._orig_modules_others, self._other_layers):
orig_module.forward = inference_layer.forward
if self.Z3_enabled:
gc.collect()
get_accelerator().empty_cache()
if self._t_start is None:
self._t_start = time.time()
def train(self, mode=True):
if mode and len(self._orig_modules) > 0:
for inference_container, orig_module, orig_fwd in zip(self._inference_containers, self._orig_modules,
self._orig_fwds):
inference_container.partition_merged_qkv()
orig_module.forward = orig_fwd
for orig_module, orig_fwd in zip(self._orig_modules_others, self._orig_fwds_others):
orig_module.forward = orig_fwd
super().train(mode)
if mode:
self._training_start_time = time.time()
def step(self, lr_kwargs=None):
super().step(lr_kwargs=lr_kwargs)
if len(self._inference_containers) > 0:
if(self._inference_containers[0].module.attention.attn_qkvw is not None and \
self._inference_containers[0].q_k_v is not None):
for inference_container in self._inference_containers:
inference_container.reset_qkv()
if self._training_start_time is not None:
self._training_latency += (time.time() - self._training_start_time)
self._training_start_time = time.time()
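For readers unfamiliar with the LoRA fuse/unfuse step used by the hybrid engine above, here is a tiny self-contained sketch of the round trip performed by `_fuse_lora` and `_unfuse_lora`; the shapes and scaling value are hypothetical, chosen only so the matmul matches the weight shape. Fusing folds the low-rank update into the base weight for generation, and unfusing subtracts it again before training resumes.

import torch

out_f, in_f, rank = 4, 6, 2
weight = torch.randn(out_f, in_f)            # base linear weight
lora_right_weight = torch.randn(in_f, rank)  # hypothetical LoRA factors; left.t() @ right.t()
lora_left_weight = torch.randn(rank, out_f)  # has shape (out_f, in_f), matching the weight
lora_scaling = 0.5

original = weight.clone()
weight += lora_scaling * torch.matmul(lora_left_weight.t(), lora_right_weight.t())  # fuse
weight -= lora_scaling * torch.matmul(lora_left_weight.t(), lora_right_weight.t())  # unfuse
assert torch.allclose(weight, original, atol=1e-5)  # round trip restores the base weight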
"""
Copyright 2019 The Microsoft DeepSpeed Team
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
Implementation of learning rate schedules.
Taken and modified from PyTorch v1.0.1 source
https://github.com/pytorch/pytorch/blob/v1.1.0/torch/optim/lr_scheduler.py
"""
import argparse
......@@ -53,28 +54,15 @@ TOTAL_NUM_STEPS = 'total_num_steps'
def add_tuning_arguments(parser):
group = parser.add_argument_group('Convergence Tuning', 'Convergence tuning configurations')
# LR scheduler
group.add_argument('--lr_schedule', type=str, default=None, help='LR schedule for training.')
# Learning rate range test
group.add_argument("--lr_range_test_min_lr",
type=float,
default=0.001,
help='Starting lr value.')
group.add_argument("--lr_range_test_step_rate",
type=float,
default=1.0,
help='scaling rate for LR range test.')
group.add_argument("--lr_range_test_step_size",
type=int,
default=1000,
help='training steps per LR change.')
group.add_argument("--lr_range_test_min_lr", type=float, default=0.001, help='Starting lr value.')
group.add_argument("--lr_range_test_step_rate", type=float, default=1.0, help='scaling rate for LR range test.')
group.add_argument("--lr_range_test_step_size", type=int, default=1000, help='training steps per LR change.')
group.add_argument("--lr_range_test_staircase",
type=bool,
default=False,
......@@ -89,66 +77,34 @@ def add_tuning_arguments(parser):
type=int,
default=-1,
help='first stair count for 1Cycle schedule.')
group.add_argument("--cycle_second_step_size",
type=int,
default=-1,
help='size of second step of 1Cycle schedule (default first_step_size).')
group.add_argument("--cycle_second_stair_count",
type=int,
default=-1,
help='second stair count for 1Cycle schedule.')
group.add_argument("--decay_step_size",
type=int,
default=1000,
help='size of intervals for applying post cycle decay (training steps).')
# 1Cycle LR
group.add_argument("--cycle_min_lr",
type=float,
default=0.01,
help='1Cycle LR lower bound.')
group.add_argument("--cycle_max_lr",
type=float,
default=0.1,
help='1Cycle LR upper bound.')
group.add_argument("--decay_lr_rate",
type=float,
default=0.0,
help='post cycle LR decay rate.')
group.add_argument("--cycle_min_lr", type=float, default=0.01, help='1Cycle LR lower bound.')
group.add_argument("--cycle_max_lr", type=float, default=0.1, help='1Cycle LR upper bound.')
group.add_argument("--decay_lr_rate", type=float, default=0.0, help='post cycle LR decay rate.')
# 1Cycle Momentum
group.add_argument('--cycle_momentum', default=False, action='store_true', help='Enable 1Cycle momentum schedule.')
group.add_argument("--cycle_min_mom", type=float, default=0.8, help='1Cycle momentum lower bound.')
group.add_argument("--cycle_max_mom", type=float, default=0.9, help='1Cycle momentum upper bound.')
group.add_argument("--decay_mom_rate", type=float, default=0.0, help='post cycle momentum decay rate.')
# Warmup LR
group.add_argument('--warmup_min_lr', type=float, default=0, help='WarmupLR minimum/initial LR value')
group.add_argument('--warmup_max_lr', type=float, default=0.001, help='WarmupLR maximum LR value.')
group.add_argument('--warmup_num_steps', type=int, default=1000, help='WarmupLR step count for LR warmup.')
group.add_argument('--warmup_type',
type=str,
default=WARMUP_LOG_RATE,
......@@ -168,16 +124,13 @@ def override_lr_range_test_params(args, params):
if hasattr(args, LR_RANGE_TEST_MIN_LR) and args.lr_range_test_min_lr is not None:
params[LR_RANGE_TEST_MIN_LR] = args.lr_range_test_min_lr
if hasattr(args, LR_RANGE_TEST_STEP_RATE) and args.lr_range_test_step_rate is not None:
params[LR_RANGE_TEST_STEP_RATE] = args.lr_range_test_step_rate
if hasattr(args, LR_RANGE_TEST_STEP_SIZE) and args.lr_range_test_step_size is not None:
params[LR_RANGE_TEST_STEP_SIZE] = args.lr_range_test_step_size
if hasattr(args, LR_RANGE_TEST_STAIRCASE) and args.lr_range_test_staircase is not None:
params[LR_RANGE_TEST_STAIRCASE] = args.lr_range_test_staircase
......@@ -185,15 +138,13 @@ def override_1cycle_params(args, params):
if hasattr(args, CYCLE_FIRST_STEP_SIZE) and args.cycle_first_step_size is not None:
params[CYCLE_FIRST_STEP_SIZE] = args.cycle_first_step_size
if hasattr(args, CYCLE_FIRST_STAIR_COUNT) and args.cycle_first_stair_count is not None:
params[CYCLE_FIRST_STAIR_COUNT] = args.cycle_first_stair_count
if hasattr(args, CYCLE_SECOND_STEP_SIZE) and args.cycle_second_step_size is not None:
params[CYCLE_SECOND_STEP_SIZE] = args.cycle_second_step_size
if hasattr(args, CYCLE_SECOND_STAIR_COUNT) and args.cycle_second_stair_count is not None:
params[CYCLE_SECOND_STAIR_COUNT] = args.cycle_second_stair_count
if hasattr(args, DECAY_STEP_SIZE) and args.decay_step_size is not None:
......@@ -301,8 +252,7 @@ def get_torch_optimizer(optimizer):
if hasattr(optimizer, 'optimizer') and isinstance(optimizer.optimizer, Optimizer):
return optimizer.optimizer
raise TypeError('{} is not a subclass of torch.optim.Optimizer'.format(type(optimizer).__name__))
class LRRangeTest(object):
......@@ -343,6 +293,7 @@ class LRRangeTest(object):
_A disciplined approach to neural network hyper-parameters: Part 1 -- learning rate, batch size, momentum, and weight decay:
https://arxiv.org/abs/1803.09820
"""
def __init__(self,
optimizer: Optimizer,
lr_range_test_min_lr: float = 1e-3,
......@@ -353,13 +304,10 @@ class LRRangeTest(object):
self.optimizer = get_torch_optimizer(optimizer)
if isinstance(lr_range_test_min_lr, list) or isinstance(lr_range_test_min_lr, tuple):
if len(lr_range_test_min_lr) != len(self.optimizer.param_groups):
raise ValueError("expected {} lr_range_test_min_lr, got {}".format(
len(self.optimizer.param_groups),
len(lr_range_test_min_lr)))
raise ValueError("expected {} lr_range_test_min_lr, got {}".format(len(self.optimizer.param_groups),
len(lr_range_test_min_lr)))
self.min_lr = list(lr_range_test_min_lr)
else:
self.min_lr = [lr_range_test_min_lr] * len(self.optimizer.param_groups)
......@@ -384,9 +332,7 @@ class LRRangeTest(object):
def get_lr(self):
lr_increase = self._get_increase()
return [lr_range_test_min_lr * lr_increase for lr_range_test_min_lr in self.min_lr]
def get_last_lr(self):
""" Return last computed learning rate by current scheduler.
......@@ -480,6 +426,7 @@ class OneCycle(object):
.. _A disciplined approach to neural network hyper-parameters: Part 1 -- learning rate, batch size, momentum, and weight decay: https://arxiv.org/abs/1803.09820
"""
def __init__(self,
optimizer,
cycle_min_lr,
......@@ -499,26 +446,16 @@ class OneCycle(object):
self.optimizer = get_torch_optimizer(optimizer)
# Initialize cycle shape
self._initialize_cycle(cycle_first_step_size, cycle_second_step_size, cycle_first_stair_count,
cycle_second_stair_count, decay_step_size)
# Initialize cycle lr
self._initialize_lr(self.optimizer, cycle_min_lr, cycle_max_lr, decay_lr_rate, last_batch_iteration)
# Initialize cyclic momentum
self.cycle_momentum = cycle_momentum
if cycle_momentum:
self._initialize_momentum(self.optimizer, cycle_min_mom, cycle_max_mom, decay_mom_rate,
last_batch_iteration)
# Initialize batch iteration tracker
......@@ -526,16 +463,11 @@ class OneCycle(object):
# Configure cycle shape
def _initialize_cycle(self, cycle_first_step_size, cycle_second_step_size, cycle_first_stair_count,
cycle_second_stair_count, decay_step_size):
cycle_first_step_size = float(cycle_first_step_size)
cycle_second_step_size = float(
cycle_second_step_size) if cycle_second_step_size is not None else cycle_first_step_size
self.total_size = cycle_first_step_size + cycle_second_step_size
self.step_ratio = cycle_first_step_size / self.total_size
......@@ -551,12 +483,7 @@ class OneCycle(object):
self.skip_mom_decay = False
# Configure lr schedule
def _initialize_lr(self, optimizer, cycle_min_lr, cycle_max_lr, decay_lr_rate, last_batch_iteration):
self.min_lrs = [cycle_min_lr] * len(optimizer.param_groups)
if last_batch_iteration == -1:
for lr, group in zip(self.min_lrs, optimizer.param_groups):
......@@ -569,12 +496,7 @@ class OneCycle(object):
self.skip_lr_decay = True
# Configure momentum schedule
def _initialize_momentum(self, optimizer, cycle_min_mom, cycle_max_mom, decay_mom_rate, last_batch_iteration):
if 'betas' not in optimizer.defaults:
optimizer_name = type(optimizer).__name__
logger.warn(
......@@ -722,6 +644,7 @@ class WarmupLR(object):
>>> scheduler.step()
"""
def __init__(self,
optimizer: Optimizer,
warmup_min_lr: float = 0.0,
......@@ -738,9 +661,8 @@ class WarmupLR(object):
self.warmup_num_steps = max(2, warmup_num_steps)
# Currently only support linear and log function
if warmup_type not in {WARMUP_LOG_RATE, WARMUP_LINEAR_RATE}:
logger.warning(f"Using unknown warmup_type: {warmup_type}. The increasing function "
f"is set to default (log)")
warmup_type = WARMUP_LOG_RATE
self.warmup_type = warmup_type
self.inverse_log_warm_up = 1.0 / math.log(self.warmup_num_steps)
......@@ -748,15 +670,10 @@ class WarmupLR(object):
def get_lr(self):
if self.last_batch_iteration < 0:
logger.warning("Attempting to get learning rate from scheduler before it has started")
return [0.0]
gamma = self._get_gamma()
return [min_lr + (delta_lr * gamma) for min_lr, delta_lr in zip(self.min_lrs, self.delta_lrs)]
def get_last_lr(self):
""" Return last computed learning rate by current scheduler.
......@@ -789,10 +706,8 @@ class WarmupLR(object):
def _format_param(self, optimizer, param_value, param_name):
if isinstance(param_value, list) or isinstance(param_value, tuple):
if len(param_value) != len(optimizer.param_groups):
raise ValueError("expected {} value for {}, got {}".format(
len(optimizer.param_groups),
param_name,
FileNotFoundError(param_value)))
raise ValueError("expected {} value for {}, got {}".format(len(optimizer.param_groups), param_name,
FileNotFoundError(param_value)))
return list(param_value)
return [param_value] * len(optimizer.param_groups)
......@@ -819,6 +734,7 @@ class WarmupDecayLR(WarmupLR):
>>> scheduler.step()
"""
def __init__(self,
optimizer: Optimizer,
total_num_steps: int,
......@@ -829,17 +745,11 @@ class WarmupDecayLR(WarmupLR):
last_batch_iteration: int = -1):
self.total_num_steps = total_num_steps
super(WarmupDecayLR, self).__init__(optimizer, warmup_min_lr, warmup_max_lr, warmup_num_steps, warmup_type,
last_batch_iteration)
if self.total_num_steps < self.warmup_num_steps:
logger.warning('total_num_steps {} is less than warmup_num_steps {}'.format(
total_num_steps, warmup_num_steps))
def _get_gamma(self):
if self.last_batch_iteration < self.warmup_num_steps:
......@@ -850,5 +760,4 @@ class WarmupDecayLR(WarmupLR):
return max(
0.0,
float(self.total_num_steps - self.last_batch_iteration) /
float(max(1.0, self.total_num_steps - self.warmup_num_steps)))
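To make the warmup/decay arithmetic above concrete, here is a standalone sketch of the schedule that WarmupDecayLR implements; it is illustrative only, with the default values below chosen for the example: a log (or linear) warmup from warmup_min_lr to warmup_max_lr over warmup_num_steps, followed by a linear decay that reaches zero at total_num_steps.

import math

def warmup_decay_lr(step, warmup_min_lr=0.0, warmup_max_lr=0.001,
                    warmup_num_steps=1000, total_num_steps=10000, warmup_type="log"):
    if step < warmup_num_steps:
        # Warmup phase: gamma rises from 0 to 1 over warmup_num_steps.
        if warmup_type == "log":
            gamma = math.log(step + 1) / math.log(warmup_num_steps)
        else:  # linear warmup
            gamma = step / warmup_num_steps
    else:
        # Decay phase: gamma falls linearly to 0 at total_num_steps.
        gamma = max(0.0, (total_num_steps - step) / max(1.0, total_num_steps - warmup_num_steps))
    return warmup_min_lr + (warmup_max_lr - warmup_min_lr) * gamma

# warmup_decay_lr(0) == 0.0, warmup_decay_lr(1000) == 0.001, warmup_decay_lr(10000) == 0.0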
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .module import PipelineModule, LayerSpec, TiedLayerSpec
from .topology import ProcessTopology
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from types import MethodType
......@@ -12,6 +15,8 @@ from deepspeed.accelerator import get_accelerator
from ..engine import DeepSpeedEngine, MEMORY_OPT_ALLREDUCE_SIZE
from ..utils import PartitionedTensor
from ..dataloader import RepeatingLoader
from ..zero.config import ZeroStageEnum
from ..activation_checkpointing import checkpointing as ds_checkpointing
from .module import PipelineModule, PipelineError
from . import p2p
......@@ -41,18 +46,8 @@ class PipelineEngine(DeepSpeedEngine):
is provided.
"""
ID_TO_DTYPE = [
torch.float32, torch.float64, torch.complex64, torch.complex128, torch.float16, torch.bfloat16, torch.uint8,
torch.int8, torch.int16, torch.int32, torch.int64, torch.bool
]
DTYPE_TO_ID = {dtype: id_ for id_, dtype in enumerate(ID_TO_DTYPE)}
......@@ -134,8 +129,7 @@ class PipelineEngine(DeepSpeedEngine):
if self.global_rank != min(d['ranks']):
tied_params += sum(p.numel() for p in d['module'].parameters())
unique_params -= tied_params
params_tensor = torch.LongTensor(data=[num_params, unique_params]).to(self.device)
dist.all_reduce(params_tensor, group=self.grid.get_model_parallel_group())
params_tensor = params_tensor.tolist()
total_params = params_tensor[0]
......@@ -156,10 +150,10 @@ class PipelineEngine(DeepSpeedEngine):
# Pipeline buffers
self.num_pipe_buffers = 0
self.pipe_buffers = {
'inputs': [], # batch input and received activations
'labels': [], # labels from batch input
'outputs': [], # activations
'output_tensors': [], # tensor object to preserve backward graph
}
self.pipe_recv_buf = None
self.grad_layer = None
......@@ -178,8 +172,7 @@ class PipelineEngine(DeepSpeedEngine):
self.dp_group_loss = torch.tensor(0.0, requires_grad=False).to(self.device)
if self._config.pipeline['activation_checkpoint_interval'] > 0:
self.module.activation_checkpoint_interval = self._config.pipeline['activation_checkpoint_interval']
self.module.checkpoint_parallel_write_pipeline = self._config.checkpoint_parallel_write_pipeline
......@@ -220,11 +213,10 @@ class PipelineEngine(DeepSpeedEngine):
self.has_attention_mask = value
def _build_data_iter(self, dataset):
sampler = torch.utils.data.distributed.DistributedSampler(dataset,
num_replicas=self.dp_world_size,
rank=self.mpu.get_data_parallel_rank(),
shuffle=False)
# Build a loader and make it repeating.
pipe_dataloader = self.deepspeed_io(dataset, data_sampler=sampler)
pipe_dataloader = RepeatingLoader(pipe_dataloader)
......@@ -251,11 +243,10 @@ class PipelineEngine(DeepSpeedEngine):
self._force_grad_boundary = True
if self.pipeline_enable_backward_allreduce:
if self.bfloat16_enabled():
if self.zero_optimization_stage() < ZeroStageEnum.gradients:
self._bf16_reduce_grads()
else:
raise NotImplementedError("PP+BF16 only work for ZeRO Stage 1")
else:
self.allreduce_gradients(bucket_size=MEMORY_OPT_ALLREDUCE_SIZE)
self._force_grad_boundary = False
......@@ -317,8 +308,7 @@ class PipelineEngine(DeepSpeedEngine):
The arithmetic mean of the losses computed this batch.
"""
if not torch._C.is_grad_enabled():
raise RuntimeError(f'train_batch() requires gradients enabled. Use eval_batch() instead.')
# Curriculum learning could change activation shape
if self.curriculum_enabled_legacy():
......@@ -360,28 +350,17 @@ class PipelineEngine(DeepSpeedEngine):
# Monitoring
if self.global_rank == 0 and self.monitor.enabled:
self.summary_events = [(f'Train/Samples/train_loss', self.agg_train_loss.mean().item(),
self.global_samples)]
self.monitor.write_events(self.summary_events)
if self.wall_clock_breakdown() and self.global_steps % self.steps_per_print() == 0:
self.timers.log(['pipe_send_output', 'pipe_send_grad', 'pipe_recv_input', 'pipe_recv_grad'])
# TODO: should return precisely what loss returned and allow others to be queried?
return self.agg_train_loss
def eval_batch(self, data_iter, return_logits=False, compute_loss=True, reduce_output='avg'):
"""Evaluate the pipeline on a batch of data from ``data_iter``. The
engine will evaluate ``self.train_batch_size()`` total samples
collectively across all workers.
......@@ -448,9 +427,7 @@ class PipelineEngine(DeepSpeedEngine):
eval_output = self._bcast_pipe_scalar(eval_output)
if self.global_rank == 0 and self.monitor.enabled:
self.summary_events = [(f'Train/Samples/eval_loss', eval_output.mean().item(), self.global_samples)]
self.monitor.write_events(self.summary_events)
# Restore the training iterator
......@@ -510,8 +487,7 @@ class PipelineEngine(DeepSpeedEngine):
reduced /= self.dp_world_size
else:
for idx in range(len(reduced)):
dist.all_reduce(reduced[idx], group=self.mpu.get_data_parallel_group())
reduced[idx] /= self.dp_world_size
return reduced
......@@ -525,13 +501,11 @@ class PipelineEngine(DeepSpeedEngine):
assert src_rank in self.grid.pp_group
if self.global_rank == src_rank:
result = data.clone().detach().type(dtype).to(self.device)
else:
result = torch.Tensor([0.]).type(dtype).to(self.device)
dist.broadcast(tensor=result, src=src_rank, group=self.mpu.get_pipe_parallel_group())
return result
......@@ -550,18 +524,14 @@ class PipelineEngine(DeepSpeedEngine):
assert self.global_rank in self.grid.pp_group
losses = torch.Tensor([self.dp_group_loss, agg_loss]).to(self.device)
dist.broadcast(tensor=losses, src=self.global_rank, group=self.mpu.get_pipe_parallel_group())
else:
# Get loss from last stage
src_rank = self.grid.stage_to_global(self.num_stages - 1)
assert src_rank in self.grid.pp_group
losses = torch.Tensor([0., 0.]).to(self.device)
dist.broadcast(tensor=losses, src=src_rank, group=self.grid.get_pipe_parallel_group())
self.dp_group_loss = losses[0].clone().detach()
agg_loss = losses[1].clone().detach()
......@@ -638,10 +608,9 @@ class PipelineEngine(DeepSpeedEngine):
# collect the partitioned input from the previous stage
if self.is_pipe_partitioned and not self.is_first_stage():
part_input = PartitionedTensor.from_meta(meta=inputs[0],
local_part=inputs[1],
group=self.grid.get_slice_parallel_group())
inputs = (part_input.full(), *inputs[2:])
inputs[0].requires_grad = True
......@@ -657,23 +626,24 @@ class PipelineEngine(DeepSpeedEngine):
outputs = super().forward(inputs)
# Reset activation checkpointing buffers.
# Need to call this between evaluation iterations
if not self.module.training:
ds_checkpointing.reset()
# Partition the outputs if we are not the last stage
if self.is_pipe_partitioned and not self.is_last_stage():
if isinstance(outputs, tuple):
first_output = outputs[0]
# TODO: Improve pipe partitioning to pass multiple tensors that require grads
assert all([torch.is_tensor(elt) and elt.requires_grad is False for elt in outputs[1:]])
outputs_tail = outputs[1:]
elif torch.is_tensor(outputs):
first_output = outputs
outputs_tail = []
else:
raise ValueError("expecting a tensor or a tuple of tensors")
part = PartitionedTensor(tensor=first_output, group=self.grid.get_slice_parallel_group())
# Clear the large output data, but save the computation graph
first_output.data = torch.zeros(1)
self.pipe_buffers['output_tensors'][buffer_id] = first_output
......@@ -732,10 +702,9 @@ class PipelineEngine(DeepSpeedEngine):
# careful to also restore the computational graph of the tensors we partitioned.
if self.is_pipe_partitioned:
if self.is_grad_partitioned:
part_output = PartitionedTensor.from_meta(meta=outputs[0],
local_part=outputs[1],
group=self.grid.get_slice_parallel_group())
self.pipe_buffers['output_tensors'][buffer_id].data = part_output.full()
outputs = (self.pipe_buffers['output_tensors'][buffer_id], *outputs[2:])
else:
......@@ -746,10 +715,9 @@ class PipelineEngine(DeepSpeedEngine):
grad_tensors = self.grad_layer
if self.is_grad_partitioned:
#print(f'RANK={self.global_rank} BEFORE-BWD restoring grad={self.grad_layer[0].size()} {self.grad_layer[1].size()}')
part_grad = PartitionedTensor.from_meta(meta=self.grad_layer[0],
local_part=self.grad_layer[1],
group=self.grid.get_slice_parallel_group())
grad_tensors = (part_grad.full(), *grad_tensors[2:])
part_grad = None
#print(f'RANK={self.global_rank} BEFORE-BWD restored grad={self.grad_layer[0].size()} {self.grad_layer[1].size()}')
......@@ -795,7 +763,7 @@ class PipelineEngine(DeepSpeedEngine):
loaded = batch[0].clone().to(self.device).detach()
loaded.requires_grad = loaded.is_floating_point()
else:
assert isinstance(batch[0], (tuple, list))
# Assume list or tuple
loaded = []
for x in batch[0]:
......@@ -865,8 +833,7 @@ class PipelineEngine(DeepSpeedEngine):
assert isinstance(tensor, torch.Tensor)
send_shape = torch.LongTensor(data=tensor.size()).to(self.device)
send_ndims = torch.LongTensor(data=[len(tensor.size())]).to(self.device)
send_dtype = torch.LongTensor(data=[self.DTYPE_TO_ID[tensor.dtype]]).to(self.device)
p2p.send(send_dtype, recv_stage)
p2p.send(send_ndims, recv_stage)
p2p.send(send_shape, recv_stage)
......@@ -990,17 +957,14 @@ class PipelineEngine(DeepSpeedEngine):
if isinstance(inputs, tuple):
first_input = inputs[0]
assert all([torch.is_tensor(elt) for elt in inputs[1:]])
inputs_grad_tail = [elt.grad for elt in inputs[1:] if elt.grad is not None]
elif torch.is_tensor(inputs):
first_input = inputs
inputs_grad_tail = []
else:
raise ValueError("expecting a tensor or a tuple of tensors")
assert torch.is_tensor(first_input)
part = PartitionedTensor(tensor=first_input.grad, group=self.grid.get_slice_parallel_group())
inputs = (part.to_meta(), part.data(), *inputs_grad_tail)
......@@ -1060,9 +1024,7 @@ class PipelineEngine(DeepSpeedEngine):
# XXX hardcode meta type
if self.is_pipe_partitioned and idx == 0 and buffer.dtype != torch.long:
if self.meta_buffer is None:
self.meta_buffer = torch.zeros(buffer.size(), dtype=torch.long, device=self.device)
buffer = self.meta_buffer
p2p.recv(buffer, self.prev_stage)
......@@ -1091,10 +1053,9 @@ class PipelineEngine(DeepSpeedEngine):
# XXX these shapes are hardcoded for Megatron
# Restore partitioned output if it was partitioned and we are sending full gradients
if self.is_pipe_partitioned and not self.is_grad_partitioned:
part_output = PartitionedTensor.from_meta(meta=outputs[0],
local_part=outputs[1],
group=self.grid.get_slice_parallel_group())
outputs[0].data = part_output.full()
outputs = (outputs[0], *outputs[2:])
# save for backward
......@@ -1104,9 +1065,7 @@ class PipelineEngine(DeepSpeedEngine):
if self.grad_layer is None:
if isinstance(outputs, torch.Tensor):
s = list(outputs.size())
self.grad_layer = self._allocate_buffer(s, dtype=outputs.dtype, num_buffers=1)[0]
else:
# XXX This is a HACK
# When we exchange activations/gradients, the two pipe stages
......@@ -1123,17 +1082,12 @@ class PipelineEngine(DeepSpeedEngine):
# branches on is_grad_partitioned so we don't filter out the
# metadata tensor.
if self.is_grad_partitioned:
sizes_and_dtypes = [(list(t.size()), t.dtype)
for t in outputs[:2]] + [(list(t.size()), t.dtype)
for t in outputs[2:] if t.is_floating_point()]
else:
sizes_and_dtypes = [(list(t.size()), t.dtype) for t in outputs if t.is_floating_point()]
self.grad_layer = self._allocate_buffers(sizes_and_dtypes, num_buffers=1)[0]
if isinstance(self.grad_layer, torch.Tensor):
p2p.recv(self.grad_layer, self.next_stage)
......@@ -1142,9 +1096,7 @@ class PipelineEngine(DeepSpeedEngine):
for idx, buffer in enumerate(self.grad_layer):
# XXX GPT-2 hack
if self.is_grad_partitioned and idx == 0 and buffer.dtype != torch.long:
buffer.data = torch.zeros(buffer.size(), dtype=torch.long, device=self.device)
p2p.recv(buffer, self.next_stage)
if self.wall_clock_breakdown():
......@@ -1163,13 +1115,10 @@ class PipelineEngine(DeepSpeedEngine):
self.mem_status('AFTER STEP')
if self.global_rank == 0 and self.monitor.enabled:
self.summary_events = [(f'Train/Samples/lr',
self.get_lr()[0],
self.global_samples)]
self.summary_events = [(f'Train/Samples/lr', self.get_lr()[0], self.global_samples)]
if self.fp16_enabled() and hasattr(self.optimizer, 'cur_scale'):
self.summary_events.append((f'Train/Samples/loss_scale',
self.optimizer.cur_scale,
self.global_samples))
self.summary_events.append(
(f'Train/Samples/loss_scale', self.optimizer.cur_scale, self.global_samples))
self.monitor.write_events(self.summary_events)
if self.wall_clock_breakdown():
......@@ -1177,22 +1126,11 @@ class PipelineEngine(DeepSpeedEngine):
self.timers('step').stop()
if self.global_steps % self.steps_per_print() == 0:
self.timers.log([
'batch_input',
'forward_microstep',
'backward_microstep',
'backward_inner_microstep',
'backward_allreduce_microstep',
'backward_tied_allreduce_microstep',
'step_microstep'
'batch_input', 'forward_microstep', 'backward_microstep', 'backward_inner_microstep',
'backward_allreduce_microstep', 'backward_tied_allreduce_microstep', 'step_microstep'
])
if self.global_steps % self.steps_per_print() == 0:
self.timers.log([
'forward',
'backward',
'backward_inner',
'backward_allreduce',
'step'
])
self.timers.log(['forward', 'backward', 'backward_inner', 'backward_allreduce', 'step'])
def _zero_grads(self, inputs):
if isinstance(inputs, torch.Tensor):
......@@ -1236,10 +1174,7 @@ class PipelineEngine(DeepSpeedEngine):
for count in range(num_buffers):
buffer = []
for shape, dtype in shapes_and_dtypes:
buffer.append(
self._allocate_zeros(shape,
dtype=dtype,
requires_grad=requires_grad))
buffer.append(self._allocate_zeros(shape, dtype=dtype, requires_grad=requires_grad))
buffers.append(buffer)
return buffers
......@@ -1298,11 +1233,9 @@ class PipelineEngine(DeepSpeedEngine):
max_cached /= 1024**3
print(
f'RANK={rank} STAGE={self.stage_id} STEP={self.global_steps} MEMSTATS',
msg,
f'RANK={rank} STAGE={self.stage_id} STEP={self.global_steps} MEMSTATS', msg,
f'current alloc={new_alloced:0.4f}GB (delta={delta_alloced:0.4f}GB max={max_alloced:0.4f}GB) '
f'current cache={new_cached:0.4f}GB (delta={delta_cached:0.4f}GB max={max_cached:0.4f}GB)'
)
f'current cache={new_cached:0.4f}GB (delta={delta_cached:0.4f}GB max={max_cached:0.4f}GB)')
def module_state_dict(self):
"""Override hack to save a pipe model and return the directory path of the save.
......@@ -1318,11 +1251,10 @@ class PipelineEngine(DeepSpeedEngine):
assert self._curr_ckpt_path is not None, \
"PipelineEngine expects module_state_dict() to be called from save_checkpoint()"
self.module.save_state_dict(self._curr_ckpt_path,
checkpoint_engine=self.checkpoint_engine)
self.module.save_state_dict(self._curr_ckpt_path, checkpoint_engine=self.checkpoint_engine)
return None
def load_module_state_dict(self, state_dict, strict=True, custom_load_fn=None):
def load_module_state_dict(self, checkpoint, strict=True, custom_load_fn=None):
"""Override hack to instead use a directory path.
This is important because pipeline models checkpoint by layer instead of rank.
......@@ -1334,6 +1266,7 @@ class PipelineEngine(DeepSpeedEngine):
strict (bool, optional): Strict state loading. Defaults to True.
"""
assert custom_load_fn is None, "custom_load_fn not supported w. pipeline parallelism"
state_dict = checkpoint['module']
if (state_dict is not None) and (not isinstance(state_dict, str)):
super().load_module_state_dict(state_dict, strict)
return
......@@ -1367,9 +1300,7 @@ class PipelineEngine(DeepSpeedEngine):
# For each instruction in the step
for cmd in step_cmds:
if type(cmd) not in self._INSTRUCTION_MAP:
raise RuntimeError(
f'{self.__class__.__name__} does not understand instruction {repr(cmd)}'
)
raise RuntimeError(f'{self.__class__.__name__} does not understand instruction {repr(cmd)}')
# Equivalent to: self._exec_forward_pass(buffer_id=0)
self._exec_instr = MethodType(self._INSTRUCTION_MAP[type(cmd)], self)
......
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import os
import glob
......@@ -45,6 +48,7 @@ class LayerSpec:
LayerSpec(torch.nn.Linear, self.hidden_hidden, self.out_dim)]
]
"""
def __init__(self, typename, *module_args, **module_kwargs):
self.typename = typename
self.module_args = module_args
......@@ -59,9 +63,7 @@ class LayerSpec:
self.global_rank = -1
def __repr__(self):
return ds_utils.call_to_str(self.typename.__name__,
self.module_args,
self.module_kwargs)
return ds_utils.call_to_str(self.typename.__name__, self.module_args, self.module_kwargs)
def build(self, log=False):
"""Build the stored specification."""
......@@ -72,13 +74,8 @@ class LayerSpec:
class TiedLayerSpec(LayerSpec):
def __init__(self,
key,
typename,
*module_args,
forward_fn=None,
tied_weight_attr='weight',
**module_kwargs):
def __init__(self, key, typename, *module_args, forward_fn=None, tied_weight_attr='weight', **module_kwargs):
super().__init__(typename, *module_args, **module_kwargs)
self.key = key
self.forward_fn = forward_fn
......@@ -120,6 +117,7 @@ class PipelineModule(nn.Module):
activation_checkpoint_func (callable, optional): The function to use for activation checkpointing. Defaults to ``deepspeed.checkpointing.checkpoint``.
checkpointable_layers (list, optional): Layer class names eligible for activation checkpointing; when provided, a range of layers is checkpointed only if every layer in it appears in this list. Defaults to None, which applies no additional filtering.
"""
def __init__(self,
layers,
num_stages=None,
......@@ -154,9 +152,7 @@ class PipelineModule(nn.Module):
seed_str = self.seed_fn.__name__
except AttributeError:
seed_str = None
print(
f'SEED_LAYERS={self.seed_layers} BASE_SEED={self.base_seed} SEED_FN={seed_str}'
)
print(f'SEED_LAYERS={self.seed_layers} BASE_SEED={self.base_seed} SEED_FN={seed_str}')
# Setup world info
self.world_group = dist.new_group(ranks=range(dist.get_world_size()))
......@@ -173,15 +169,13 @@ class PipelineModule(nn.Module):
if topology is None:
if self.world_size % self.num_stages != 0:
raise RuntimeError(
f'num_stages ({self.num_stages}) must divide distributed world size ({self.world_size})'
)
f'num_stages ({self.num_stages}) must divide distributed world size ({self.world_size})')
dp = self.world_size // num_stages
topology = PipeDataParallelTopology(num_pp=num_stages, num_dp=dp)
self._topo = topology
# Construct communicators for pipeline topology
self._grid = PipelineParallelGrid(process_group=self.world_group,
topology=self._topo)
self._grid = PipelineParallelGrid(process_group=self.world_group, topology=self._topo)
self.stage_id = self._topo.get_coord(self.global_rank).pipe
......@@ -245,9 +239,7 @@ class PipelineModule(nn.Module):
self.forward_funcs.append(self.tied_modules[layer.key])
else:
# User specified fn with args (module, input)
self.forward_funcs.append(
partial(layer.forward_fn,
self.tied_modules[layer.key]))
self.forward_funcs.append(partial(layer.forward_fn, self.tied_modules[layer.key]))
# LayerSpec objects contain an nn.Module that should be allocated now.
elif isinstance(layer, LayerSpec):
......@@ -304,8 +296,7 @@ class PipelineModule(nn.Module):
idxs.append(idx)
if len(idxs) == 0:
raise RuntimeError(
f"Partitioning '{layername}' found no valid layers to partition.")
raise RuntimeError(f"Partitioning '{layername}' found no valid layers to partition.")
return idxs
def forward(self, forward_input):
......@@ -327,8 +318,7 @@ class PipelineModule(nn.Module):
for idx, layer in enumerate(self.forward_funcs[start:end]):
self.curr_layer = idx + self._local_start
if self.seed_layers:
new_seed = (self.base_seed *
local_micro_offset) + self.curr_layer
new_seed = (self.base_seed * local_micro_offset) + self.curr_layer
if self.seed_fn:
self.seed_fn(new_seed)
else:
......@@ -346,8 +336,7 @@ class PipelineModule(nn.Module):
num_layers = len(self.forward_funcs)
x = forward_input
for start_idx in range(0, num_layers, self.activation_checkpoint_interval):
end_idx = min(start_idx + self.activation_checkpoint_interval,
num_layers)
end_idx = min(start_idx + self.activation_checkpoint_interval, num_layers)
funcs = self.forward_funcs[start_idx:end_idx]
# Since we either pass tensors or tuples of tensors without unpacking, we
......@@ -356,10 +345,7 @@ class PipelineModule(nn.Module):
x = (x, )
if self._is_checkpointable(funcs):
x = self.activation_checkpoint_func(
exec_range_func(start_idx,
end_idx),
*x)
x = self.activation_checkpoint_func(exec_range_func(start_idx, end_idx), *x)
else:
x = exec_range_func(start_idx, end_idx)(*x)
return x
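A hedged check of the chunking loop above, outside this patch: with activation_checkpoint_interval = 4 and 10 layers, the forward pass is executed as the ranges [0, 4), [4, 8), [8, 10), and only checkpointable ranges go through the checkpoint function.

num_layers = 10
activation_checkpoint_interval = 4
ranges = []
for start_idx in range(0, num_layers, activation_checkpoint_interval):
    end_idx = min(start_idx + activation_checkpoint_interval, num_layers)
    ranges.append((start_idx, end_idx))
assert ranges == [(0, 4), (4, 8), (8, 10)]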
......@@ -376,19 +362,16 @@ class PipelineModule(nn.Module):
# Each stage gets a simple uniform number of layers.
if method == 'uniform':
num_layers = len(self._layer_specs)
self.parts = ds_utils.partition_uniform(num_items=num_layers,
num_parts=num_stages)
self.parts = ds_utils.partition_uniform(num_items=num_layers, num_parts=num_stages)
elif method == 'parameters':
param_counts = self._count_layer_params()
self.parts = ds_utils.partition_balanced(weights=param_counts,
num_parts=num_stages)
self.parts = ds_utils.partition_balanced(weights=param_counts, num_parts=num_stages)
elif method.startswith('type:'):
layertype = method.split(':')[1]
binary_weights = [0] * len(self._layer_specs)
for idx in self._find_layer_type(layertype):
binary_weights[idx] = 1
self.parts = ds_utils.partition_balanced(weights=binary_weights,
num_parts=num_stages)
self.parts = ds_utils.partition_balanced(weights=binary_weights, num_parts=num_stages)
elif method == 'profile':
raise NotImplementedError(f'Partitioning method {method} not implemented.')
else:
......@@ -436,8 +419,7 @@ class PipelineModule(nn.Module):
def _synchronize_tied_weights(self):
for key, comm in self.tied_comms.items():
dist.broadcast(
getattr(comm['module'],
comm['weight_attr']),
getattr(comm['module'], comm['weight_attr']),
src=min(comm['ranks']),
group=comm['group'],
)
......@@ -467,14 +449,9 @@ class PipelineModule(nn.Module):
tied_ranks = []
for s in sorted(tied_stages):
if self._grid.get_slice_parallel_world_size() > 1:
tied_ranks.append(
self._grid.stage_to_global(stage_id=s,
data=dp,
model=mp))
tied_ranks.append(self._grid.stage_to_global(stage_id=s, data=dp, model=mp))
else:
tied_ranks.append(
self._grid.stage_to_global(stage_id=s,
data=dp))
tied_ranks.append(self._grid.stage_to_global(stage_id=s, data=dp))
group = dist.new_group(ranks=tied_ranks)
# Record this tied module if we own a local copy of it.
......@@ -587,7 +564,7 @@ class PipelineModule(nn.Module):
start, end = 0, num_layers
layer_list = self.forward_funcs[start:end]
os.makedirs(save_dir, exist_ok=True)
checkpoint_engine.makedirs(save_dir, exist_ok=True)
for idx, layer in enumerate(layer_list):
model_ckpt_path = self.ckpt_layer_path(save_dir, start + idx)
if not hasattr(layer, 'state_dict'):
......@@ -599,10 +576,7 @@ class PipelineModule(nn.Module):
# It is expected that the garbage collector will reclaim the cloned tensor storage to avoid memory bloat.
# See https://pytorch.org/docs/stable/notes/serialization.html#preserve-storage-sharing
orig_state_dict = layer.state_dict()
final_state_dict = type(orig_state_dict)(
{k: v.clone()
for k,
v in orig_state_dict.items()})
final_state_dict = type(orig_state_dict)({k: v.clone() for k, v in orig_state_dict.items()})
checkpoint_engine.save(final_state_dict, model_ckpt_path)
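The clone above matters because torch.save preserves storage sharing: saving a view of a large flat buffer would serialize the whole backing storage. A minimal standalone sketch of the difference (not DeepSpeed code):

import io
import torch

big = torch.zeros(1_000_000)
view = big[:10]                      # shares storage with `big`

def saved_size(obj):
    buf = io.BytesIO()
    torch.save(obj, buf)
    return buf.getbuffer().nbytes

# Saving the view drags the full 1M-element storage along; cloning keeps only 10 elements.
assert saved_size({'w': view}) > saved_size({'w': view.clone()})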
def load_state_dir(self, load_dir, checkpoint_engine, strict=True):
......@@ -616,10 +590,9 @@ class PipelineModule(nn.Module):
mp_rank = self._grid.get_slice_parallel_rank()
mp_world_size = self._grid.get_slice_parallel_world_size()
sd_loader = SDLoaderFactory.get_sd_loader(
model_ckpt_list,
version=2.0,
checkpoint_engine=checkpoint_engine)
sd_loader = SDLoaderFactory.get_sd_loader(model_ckpt_list,
version=2.0,
checkpoint_engine=checkpoint_engine)
load_path, checkpoint, _ = sd_loader.load(mp_world_size, mp_rank, module_key=None, is_pipe_parallel=True)
layer.load_state_dict(checkpoint)
......@@ -636,8 +609,7 @@ class PipelineModule(nn.Module):
# Some layers like torch.nn.Embedding will not receive grads if checkpointed, which breaks things.
# I presume it's related to the discrete inputs that cannot require_grad? Need to revisit.
if self.__class__.__name__ in ('GPTModelPipe', 'GPT2ModelPipe'):
return all('ParallelTransformerLayerPipe' in f.__class__.__name__
for f in funcs)
return all('ParallelTransformerLayerPipe' in f.__class__.__name__ for f in funcs)
if self.checkpointable_layers is not None:
return all(f.__class__.__name__ in self.checkpointable_layers for f in funcs)
......
'''
Copyright 2019 The Microsoft DeepSpeed Team
'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import pickle
import typing
......@@ -111,12 +112,10 @@ def send_obj(msg: typing.Any, dest: int):
# serialize the message
msg = pickle.dumps(msg)
# construct a tensor to send
msg = torch.ByteTensor(torch.ByteStorage.from_buffer(msg)).to(
get_accelerator().device_name())
msg = torch.ByteTensor(torch.ByteStorage.from_buffer(msg)).to(get_accelerator().device_name())
# Send meta and message
length_tensor = torch.tensor([len(msg)],
dtype=torch.long).to(get_accelerator().device_name())
length_tensor = torch.tensor([len(msg)], dtype=torch.long).to(get_accelerator().device_name())
dist.send(length_tensor, dst=dest)
dist.send(msg, dst=dest)
......@@ -135,8 +134,7 @@ def recv_obj(sender: int) -> typing.Any:
dist.recv(length, src=sender)
# Receive and deserialize
msg = torch.empty(length.item(),
dtype=torch.uint8).to(get_accelerator().device_name())
msg = torch.empty(length.item(), dtype=torch.uint8).to(get_accelerator().device_name())
dist.recv(msg, src=sender)
msg = pickle.loads(msg.cpu().numpy().tobytes())
......
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from ..utils import call_to_str
......@@ -42,6 +45,7 @@ class PipeSchedule(ABC):
stages (int): The number of pipeline stages.
stage_id (int): The pipe stage that will execute the generated schedule.
"""
def __init__(self, micro_batches, stages, stage_id):
super().__init__()
self.micro_batches = micro_batches
......@@ -131,6 +135,7 @@ class PipeSchedule(ABC):
class InferenceSchedule(PipeSchedule):
"""A schedule for inferencing batches using pipeline parallelism.
"""
def steps(self):
""""""
prev_micro_batch_id = -1
......@@ -188,6 +193,7 @@ class TrainSchedule(PipeSchedule):
convergence follows that of a data parallel approach with the same batch
size.
"""
def steps(self):
""""""
prev_micro_batch_id = -1
......@@ -206,19 +212,15 @@ class TrainSchedule(PipeSchedule):
# Exchange activations
if is_forward:
if self._valid_micro_batch(micro_batch_id) and self._valid_stage(
self.prev_stage):
cmds.append(RecvActivation(curr_buffer))
if self._valid_micro_batch(prev_micro_batch_id) and self._valid_stage(
self.prev_stage):
if self._valid_micro_batch(prev_micro_batch_id) and self._valid_stage(self.prev_stage):
cmds.append(SendGrad(prev_buffer))
if self._valid_micro_batch(micro_batch_id) and self._valid_stage(self.prev_stage):
cmds.append(RecvActivation(curr_buffer))
else:
if self._valid_micro_batch(prev_micro_batch_id) and self._valid_stage(
self.next_stage):
cmds.append(SendActivation(prev_buffer))
if self._valid_micro_batch(micro_batch_id) and self._valid_stage(
self.next_stage):
if self._valid_micro_batch(micro_batch_id) and self._valid_stage(self.next_stage):
cmds.append(RecvGrad(curr_buffer))
if self._valid_micro_batch(prev_micro_batch_id) and self._valid_stage(self.next_stage):
cmds.append(SendActivation(prev_buffer))
# First/last stage loads
if self.stage_id == 0 or self.stage_id == self.stages - 1:
......@@ -243,9 +245,14 @@ class TrainSchedule(PipeSchedule):
yield cmds
def num_pipe_buffers(self):
"""As many buffers as the distance from this stage to the last stage.
"""Return the number of pipeline buffers required for this stage.
This is equivalent to the maximum number of in-flight forward passes,
since we need to remember the activations of forward passes in order
to run backpropagation. For synchronous 1F1B, this is equivalent to
the index difference between this stage and the last stage.
"""
buffers = min(self.stages - self.stage_id + 1, self.micro_batches)
buffers = min(self.stages - self.stage_id, self.micro_batches)
return max(2, buffers)
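A worked check of the corrected buffer count above (standalone, mirrors the two lines of arithmetic only): for a 4-stage pipeline with plenty of micro-batches, stage 0 must hold activations for up to 4 in-flight forward passes, while the last stage can alternate between 2 buffers.

def num_pipe_buffers(stages, stage_id, micro_batches):
    buffers = min(stages - stage_id, micro_batches)
    return max(2, buffers)

assert [num_pipe_buffers(4, s, micro_batches=8) for s in range(4)] == [4, 3, 2, 2]
assert num_pipe_buffers(4, 0, micro_batches=1) == 2   # still at least double-buffered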
def _step_to_micro_batch(self, step_id):
......@@ -295,6 +302,7 @@ class DataParallelSchedule(PipeSchedule):
"""An example schedule that trains using traditional data parallelism with gradient
accumulation.
"""
def steps(self):
""""""
for step_id in range(self.micro_batches):
......@@ -325,6 +333,7 @@ class PipeInstruction:
Args:
kwargs (optional): keyword arguments to store as members
"""
def __init__(self, **kwargs):
self.name = self.__class__.__name__
self.kwargs = kwargs
......@@ -369,6 +378,7 @@ class BufferOpInstruction(PipeInstruction):
Args:
buffer_id (int): the index of the pipeline buffer() to modify.
"""
def __init__(self, buffer_id, **kwargs):
super().__init__(buffer_id=buffer_id, **kwargs)
......
# Copyright 2019 The Microsoft DeepSpeed Team
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from deepspeed import comm as dist
......@@ -19,6 +22,7 @@ class ProcessTopology:
Some methods return ProcessCoord namedtuples.
"""
def __init__(self, axes, dims):
"""Create a mapping of n-dimensional tensor coordinates to linear indices.
......@@ -62,12 +66,7 @@ class ProcessTopology:
"""Return a list of the axis names in the ordering of the topology. """
return self.axes
def get_rank_repr(self,
rank,
omit_axes=['data',
'pipe'],
inner_sep='_',
outer_sep='-'):
def get_rank_repr(self, rank, omit_axes=['data', 'pipe'], inner_sep='_', outer_sep='-'):
"""Return a string representation of a rank.
This method is primarily used for checkpointing model data.
......@@ -181,6 +180,7 @@ class ProcessTopology:
Returns:
The list of ranks whose coordinates match filter_kwargs.
"""
def _filter_helper(x):
for key, val in filter_kwargs.items():
if getattr(x, key) != val:
......@@ -236,12 +236,14 @@ class PipeDataParallelTopology(ProcessTopology):
reductions to use high-bandwidth intra-node links and lower-volume
pipeline communications to use low-bandwidth inter-node links.
"""
def __init__(self, num_pp, num_dp):
super().__init__(axes=['pipe', 'data'], dims=[num_pp, num_dp])
class PipeModelDataParallelTopology(ProcessTopology):
""" A topology for hybrid pipeline, model, and data parallelism. """
def __init__(self, num_pp, num_mp, num_dp):
super().__init__(axes=['pipe', 'data', 'model'], dims=[num_pp, num_dp, num_mp])
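To make the coordinate-to-rank mapping concrete, here is a hedged sketch of a row-major layout for axes=['pipe', 'data'] with dims=[2, 4]; the real ProcessTopology may differ in details, but the idea is that the last axis varies fastest, so consecutive data-parallel ranks stay adjacent (and typically intra-node).

from itertools import product

axes, dims = ['pipe', 'data'], [2, 4]
coords = list(product(*[range(d) for d in dims]))     # rank == index in this list
rank_of = {c: r for r, c in enumerate(coords)}

assert rank_of[(0, 0)] == 0 and rank_of[(0, 3)] == 3   # pipe stage 0 holds ranks 0..3
assert rank_of[(1, 0)] == 4                            # pipe stage 1 starts at rank 4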
......@@ -268,6 +270,7 @@ class PipelineParallelGrid:
data_parallel_id = 0, or similarly [9,5] represents wrapped around stages [4,0]
for data_parallel_id = 1.
"""
def __init__(self, topology=None, process_group=None):
# TODO use process_group if provided
self.global_rank = dist.get_rank()
......
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import numpy as np
from deepspeed.utils import log_dist
......@@ -13,6 +16,7 @@ class ProgressiveLayerDrop(object):
The lower the theta value, the faster the training speed. Default value: 0.5.
gamma (float): a hyper-parameter that controls how fast the drop ratio increases. Default value: 0.001.
"""
def __init__(self, theta=0.5, gamma=0.001):
super().__init__()
......@@ -29,6 +33,7 @@ class ProgressiveLayerDrop(object):
return self.current_theta
def update_state(self, global_step):
def _prob(x, gamma, p):
return (1. - p) * np.exp(-gamma * x) + p
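The _prob curve above starts at 1.0 (keep every layer) and decays toward theta as training progresses; gamma controls how fast. A small numeric check, assuming the default theta=0.5 and gamma=0.001:

import numpy as np

def _prob(x, gamma, p):
    return (1. - p) * np.exp(-gamma * x) + p

theta, gamma = 0.5, 0.001
assert _prob(0, gamma, theta) == 1.0                    # keep all layers at step 0
assert abs(_prob(10_000, gamma, theta) - theta) < 1e-4  # decays toward theta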
......
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
import math
......@@ -9,6 +12,7 @@ TWO_D_PARAMS = 6
class Quantizer(object):
def __init__(self,
q_groups=1,
q_mixed_fp16=False,
......@@ -39,17 +43,12 @@ class Quantizer(object):
result = False
for index in range(self.layer_num):
if self.q_start_bits[index] != self.q_target_bits:
next_step = self.qsteps + (
TWO_D_PARAMS * (self.layer_num if self.layer_num != 0 else 1))
next_step = self.qsteps + (TWO_D_PARAMS * (self.layer_num if self.layer_num != 0 else 1))
if next_step >= self.q_period[index]:
result = True
return result
def quantize(self,
parameter_group,
overflow,
eigenvalue_enabled,
block_eigenvalue={}):
def quantize(self, parameter_group, overflow, eigenvalue_enabled, block_eigenvalue={}):
if overflow and not eigenvalue_enabled:
return
......@@ -65,7 +64,8 @@ class Quantizer(object):
if block_eigenvalue is None:
eigenvalue, layer_id = None, 0
else:
eigenvalue, layer_id = block_eigenvalue[param_id] if param_id in block_eigenvalue else (None, 0)
eigenvalue, layer_id = block_eigenvalue[param_id] if param_id in block_eigenvalue else (None,
0)
if eigenvalue is not None:
factor = 1 + math.floor(eigenvalue * 4)
p.data = self.compute_quantization(p.data, layer_id, factor)
......@@ -91,15 +91,11 @@ class Quantizer(object):
if self.q_type == 'symmetric':
scale = 2 * torch.max(torch.abs(g_min), torch.abs(g_max)) / q_range
zero_point = 0.
input_flat = (input_flat / scale + p).round().clamp(
-(q_range >> 1),
(q_range >> 1) - 1) * scale
input_flat = (input_flat / scale + p).round().clamp(-(q_range >> 1), (q_range >> 1) - 1) * scale
elif self.q_type == 'asymmetric':
scale = (g_max - g_min) / q_range
zero_point = (g_min / scale).round() * scale
input_flat = ((input_flat - zero_point) / scale + p).round().clamp(
0,
(q_range - 1)) * scale + zero_point
input_flat = ((input_flat - zero_point) / scale + p).round().clamp(0, (q_range - 1)) * scale + zero_point
output = input_flat.reshape(inputs.shape).contiguous()
return output
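A simplified, hedged sketch of the symmetric branch above: map values onto q_range integer levels centred on zero, then rescale. The stochastic-rounding term p is omitted here and q_range is assumed to be 2**bits; this is an illustration, not the kernel path.

import torch

def fake_quant_symmetric(x, bits=8):
    q_range = 2 ** bits
    scale = 2 * x.abs().max() / q_range
    q = (x / scale).round().clamp(-(q_range >> 1), (q_range >> 1) - 1)
    return q * scale

x = torch.randn(1024)
step = 2 * x.abs().max() / 2 ** 8
# reconstruction error stays within about one quantization step
assert (x - fake_quant_symmetric(x)).abs().max() <= 1.01 * step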
......@@ -126,8 +122,7 @@ class Quantizer(object):
def mixed_fp16_quantize(self, input, input_q, index):
if self.q_mixed_fp16 and self.q_start_bits[index] >= (self.q_target_bits - 1):
input_q = input * self.quantize_real_ratio + (
1 - self.quantize_real_ratio) * input_q
input_q = input * self.quantize_real_ratio + (1 - self.quantize_real_ratio) * input_q
return input_q
return input_q
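The mixed-precision branch above linearly blends the full-precision and quantized values; with quantize_real_ratio = 0.75, for instance, the result is 75% original and 25% quantized. A one-line check, treating quantize_real_ratio as a plain float:

import torch

quantize_real_ratio = 0.75
x = torch.tensor([1.0, -2.0])
x_q = torch.tensor([0.9, -2.1])
blended = x * quantize_real_ratio + (1 - quantize_real_ratio) * x_q
assert torch.allclose(blended, torch.tensor([0.975, -2.025]))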
......@@ -152,15 +147,12 @@ class Quantizer(object):
if self.use_quantizer_kernel:
if input.start_bits <= 2:
raise ValueError(
'Quantization bit is too low, please do it without quantization kernel!'
)
input_q = ds_quantizer(
input.data.clone(),
self.q_groups,
input.start_bits,
asym=False if self.q_type == 'symmetric' else True,
sr=False if self.q_rounding == 'nearest_neighbor' else True)
raise ValueError('Quantization bit is too low, please do it without quantization kernel!')
input_q = ds_quantizer(input.data.clone(),
self.q_groups,
input.start_bits,
asym=False if self.q_type == 'symmetric' else True,
sr=False if self.q_rounding == 'nearest_neighbor' else True)
else:
if input.start_bits >= 3:
input_flat = self.quantize_highbit(input.data, input.start_bits)
......
"""
Copyright 2020 The Microsoft DeepSpeed Team
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
Implementation of a compressed sparse tensor. Similar in
functionality to TensorFlow's IndexedSlices implementation.
"""
......@@ -10,6 +12,7 @@ import torch
class SparseTensor(object):
""" Compressed Sparse Tensor """
def __init__(self, dense_tensor=None):
self.orig_dense_tensor = dense_tensor
self.is_sparse = dense_tensor.is_sparse
......@@ -29,9 +32,7 @@ class SparseTensor(object):
self.dense_size = None
def to_coo_tensor(self):
return torch.sparse_coo_tensor(self.indices.unsqueeze(0),
self.values,
self.dense_size)
return torch.sparse_coo_tensor(self.indices.unsqueeze(0), self.values, self.dense_size)
@staticmethod
def type():
......@@ -40,10 +41,7 @@ class SparseTensor(object):
def to_dense(self):
it = self.indices.unsqueeze(1)
full_indices = torch.cat([it for _ in range(self.dense_size[1])], dim=1)
return self.values.new_zeros(self.dense_size).scatter_add_(
0,
full_indices,
self.values)
return self.values.new_zeros(self.dense_size).scatter_add_(0, full_indices, self.values)
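The to_dense path above reconstructs a row-sparse matrix by scattering value rows back to their row indices; the same pattern can be checked with plain torch ops (a standalone sketch, not the SparseTensor API):

import torch

dense_size = (5, 3)
indices = torch.tensor([1, 4])                 # rows that are non-zero
values = torch.tensor([[1., 2., 3.],
                       [4., 5., 6.]])          # one row of values per index

full_indices = torch.cat([indices.unsqueeze(1)] * dense_size[1], dim=1)
dense = values.new_zeros(dense_size).scatter_add_(0, full_indices, values)

assert torch.equal(dense[1], values[0]) and torch.equal(dense[4], values[1])
assert dense[0].abs().sum() == 0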
def sparse_size(self):
index_size = list(self.indices.size())
......
'''
Copyright 2020 The Microsoft DeepSpeed Team
'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
import os
......@@ -18,6 +19,7 @@ AUTO_MODULE_KEY = 'auto'
class SDLoaderFactory:
@staticmethod
def get_sd_loader_json(json_file, checkpoint_engine):
if isinstance(json_file, str):
......@@ -33,10 +35,7 @@ class SDLoaderFactory:
mp_size = data.get('mp_size', 0)
if sd_type.lower() in ['bloom', 'ds_model']:
return data
return SDLoaderFactory.get_sd_loader(ckpt_list,
checkpoint_engine,
sd_type,
version)
return SDLoaderFactory.get_sd_loader(ckpt_list, checkpoint_engine, sd_type, version)
@staticmethod
def get_sd_loader(ckpt_list, checkpoint_engine, sd_type='Megatron', version=None):
......@@ -47,12 +46,12 @@ class SDLoaderFactory:
class SDLoaderBase(ABC):
def __init__(self, ckpt_list, version, checkpoint_engine):
self.module_key = None
self.ckpt_list = ckpt_list
self.version = version
self.checkpoint_engine = TorchCheckpointEngine(
) if checkpoint_engine is None else checkpoint_engine
self.checkpoint_engine = TorchCheckpointEngine() if checkpoint_engine is None else checkpoint_engine
self.check_ckpt_list()
def load(self,
......@@ -99,9 +98,9 @@ class SDLoaderBase(ABC):
loc: storage)
if quantize:
quantizer = WeightQuantization(mlp_extra_grouping=mlp_extra_grouping,
mp_size=mp_world_size)
sd_module, all_scales = quantizer.sd_quantize_megatron(self.get_module(sd), quantize_bits, quantize_groups)
quantizer = WeightQuantization(mlp_extra_grouping=mlp_extra_grouping, mp_size=mp_world_size)
sd_module, all_scales = quantizer.sd_quantize_megatron(self.get_module(sd), quantize_bits,
quantize_groups)
self.set_module(sd, sd_module)
else:
all_scales = None
......@@ -118,17 +117,10 @@ class SDLoaderBase(ABC):
assert num_ckpt % mp_world_size == 0, 'Invalid checkpoints and world size for sd merge'
num_to_merge = num_ckpt // mp_world_size
ckpt_list = [
self.ckpt_list[i] for i in range(num_to_merge * mp_rank,
num_to_merge * (mp_rank + 1))
]
ckpt_list = [self.ckpt_list[i] for i in range(num_to_merge * mp_rank, num_to_merge * (mp_rank + 1))]
logger.info(f"mp_rank: {mp_rank}, ckpt_list: {ckpt_list}")
sd_list = [
self.checkpoint_engine.load(ckpt,
map_location=lambda storage,
loc: storage) for ckpt in ckpt_list
]
sd_list = [self.checkpoint_engine.load(ckpt, map_location=lambda storage, loc: storage) for ckpt in ckpt_list]
return sd_list
def get_split_state_dict(self, mp_world_size, mp_rank):
......@@ -139,18 +131,15 @@ class SDLoaderBase(ABC):
ckpt_index = mp_rank // num_to_split
ckpt_offset = mp_rank % num_to_split
logger.info(
f"mp_rank: {mp_rank}, ckpt_list: {self.ckpt_list[ckpt_index]}, offset: {ckpt_offset}"
)
logger.info(f"mp_rank: {mp_rank}, ckpt_list: {self.ckpt_list[ckpt_index]}, offset: {ckpt_offset}")
sd = self.checkpoint_engine.load(self.ckpt_list[ckpt_index],
map_location=lambda storage,
loc: storage)
sd = self.checkpoint_engine.load(self.ckpt_list[ckpt_index], map_location=lambda storage, loc: storage)
return sd, num_to_split, ckpt_offset
def _choose_module_key(self, sd):
assert not ('module' in sd and 'model' in sd), "checkpoint has both 'model' and 'module' keys, not sure how to proceed"
assert not ('module' in sd
and 'model' in sd), "checkpoint has both 'model' and 'module' keys, not sure how to proceed"
assert 'module' in sd or 'model' in sd, "checkpoint contains neither 'model' nor 'module' keys, not sure how to proceed"
if 'module' in sd:
return 'module'
......@@ -178,32 +167,19 @@ class SDLoaderBase(ABC):
#logger.info(f'checkpoint file list: {self.ckpt_list}')
assert len(self.ckpt_list) > 0
sd = self.checkpoint_engine.load(self.ckpt_list[0],
map_location=lambda storage,
loc: storage)
sd = self.checkpoint_engine.load(self.ckpt_list[0], map_location=lambda storage, loc: storage)
# check checkpoint count is same with saved mp_world_size
if 'mp_world_size' in sd.keys():
assert len(self.ckpt_list) == sd['mp_world_size'], f"checkpoint count {len(self.ckpt_list)} is different from saved mp_world_size {sd['mp_world_size']}"
assert len(self.ckpt_list) == sd[
'mp_world_size'], f"checkpoint count {len(self.ckpt_list)} is different from saved mp_world_size {sd['mp_world_size']}"
@abstractmethod
def merge_state_dict(self,
mp_world_size,
mp_rank,
quantize,
quantize_bits,
groups,
mlp_extra_grouping):
def merge_state_dict(self, mp_world_size, mp_rank, quantize, quantize_bits, groups, mlp_extra_grouping):
pass
@abstractmethod
def split_state_dict(self,
mp_world_size,
mp_rank,
quantize,
quantize_bits,
groups,
mlp_extra_grouping):
def split_state_dict(self, mp_world_size, mp_rank, quantize, quantize_bits, groups, mlp_extra_grouping):
pass
@abstractmethod
......@@ -212,6 +188,7 @@ class SDLoaderBase(ABC):
class MegatronSDLoader(SDLoaderBase):
def __init__(self, ckpt_list, version, checkpoint_engine):
super().__init__(ckpt_list, version, checkpoint_engine)
"""
......@@ -340,40 +317,27 @@ class MegatronSDLoader(SDLoaderBase):
ckpt_ver = self.get_checkpoint_version(ds_sd)
logger.info(f"checkpoint version: {ckpt_ver}")
if quantize:
quantizer = WeightQuantization(mlp_extra_grouping=mlp_extra_grouping,
mp_size=mp_world_size)
quantizer = WeightQuantization(mlp_extra_grouping=mlp_extra_grouping, mp_size=mp_world_size)
for key in keys:
value_list = [sd[key] for sd in client_sd_list]
if "attention.dense.weight" in key or "mlp.dense_4h_to_h.weight" in key:
if quantize:
value_list = quantizer.Quantize(value_list,
quantize_bits,
groups,
key=key,
merge_dim=1)
value_list = quantizer.Quantize(value_list, quantize_bits, groups, key=key, merge_dim=1)
new_client_sd[key] = torch.cat(value_list, axis=1)
elif "attention.query_key_value" in key:
if quantize and "attention.query_key_value.weight" in key:
value_list = quantizer.Quantize(value_list,
quantize_bits,
groups,
key=key)
value_list = quantizer.Quantize(value_list, quantize_bits, groups, key=key)
new_client_sd[key] = torch.cat(value_list, axis=0)
else:
if quantize:
new_client_sd[key] = torch.cat(value_list, axis=0)
else:
new_client_sd[key] = self.merge_query_key_value(
value_list,
ckpt_ver)
new_client_sd[key] = self.merge_query_key_value(value_list, ckpt_ver)
elif "mlp.dense_h_to_4h.weight" in key or "word_embeddings.weight" in key or "mlp.dense_h_to_4h.bias" in key:
if quantize and "mlp.dense_h_to_4h.weight" in key:
value_list = quantizer.Quantize(value_list,
quantize_bits,
groups,
key=key)
value_list = quantizer.Quantize(value_list, quantize_bits, groups, key=key)
new_client_sd[key] = torch.cat(value_list, axis=0)
else:
new_client_sd[key] = value_list[0]
......@@ -402,8 +366,7 @@ class MegatronSDLoader(SDLoaderBase):
logger.info(f"checkpoint version: {ckpt_ver}")
if quantize:
quantizer = WeightQuantization(mlp_extra_grouping=mlp_extra_grouping,
mp_size=mp_world_size)
quantizer = WeightQuantization(mlp_extra_grouping=mlp_extra_grouping, mp_size=mp_world_size)
for key in client_sd.keys():
value = client_sd[key]
......@@ -419,11 +382,7 @@ class MegatronSDLoader(SDLoaderBase):
if quantize and "attention.query_key_value.weight" in key:
q_vals = quantizer.Quantize([value], quantize_bits, groups, key)
value = q_vals[0]
new_client_sd[key] = self.split_query_key_value(
value,
num_to_split,
ckpt_offset,
ckpt_ver)
new_client_sd[key] = self.split_query_key_value(value, num_to_split, ckpt_offset, ckpt_ver)
elif "mlp.dense_h_to_4h.weight" in key or "word_embeddings.weight" in key or "mlp.dense_h_to_4h.bias" in key or "final_linear.weight" in key:
assert value.shape[0] % num_to_split == 0
split_size = value.shape[0] // num_to_split
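A worked example of the row split above, for the case where one saved checkpoint is split across two model-parallel ranks: each rank takes a contiguous slice of size value.shape[0] // num_to_split along dim 0. The narrow call below is illustrative; the actual split statement is collapsed in this diff.

import torch

value = torch.arange(8 * 4, dtype=torch.float32).reshape(8, 4)   # e.g. dense_h_to_4h.weight
num_to_split, ckpt_offset = 2, 1                                  # this rank takes the 2nd half
assert value.shape[0] % num_to_split == 0
split_size = value.shape[0] // num_to_split
shard = value.narrow(0, ckpt_offset * split_size, split_size)
assert shard.shape == (4, 4) and torch.equal(shard, value[4:8])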
......@@ -443,16 +402,11 @@ class MegatronSDLoader(SDLoaderBase):
def sanity_check(self, ckpt_file_name):
keys_to_check = [
"attention.dense.weight",
"mlp.dense_4h_to_h.weight",
"attention.query_key_value",
"mlp.dense_h_to_4h.weight",
"mlp.dense_h_to_4h.bias"
"attention.dense.weight", "mlp.dense_4h_to_h.weight", "attention.query_key_value",
"mlp.dense_h_to_4h.weight", "mlp.dense_h_to_4h.bias"
]
sd = self.checkpoint_engine.load(ckpt_file_name,
map_location=lambda storage,
loc: storage)
sd = self.checkpoint_engine.load(ckpt_file_name, map_location=lambda storage, loc: storage)
# partial_key is a sub-string of one key in the sd
def check_key_exist(partial_key, sd):
......@@ -465,10 +419,9 @@ class MegatronSDLoader(SDLoaderBase):
return found
for key in keys_to_check:
assert check_key_exist(key, self.get_module(sd)), f'key: {key} is not found in the checkpoint {ckpt_file_name}'
assert check_key_exist(key,
self.get_module(sd)), f'key: {key} is not found in the checkpoint {ckpt_file_name}'
def get_checkpoint_version(self, state_dict):
# Use 0 if version info doesn't exist
return self.version if self.version is not None else state_dict.get(
'checkpoint_version',
0)
return self.version if self.version is not None else state_dict.get('checkpoint_version', 0)
'''
Copyright 2020 The Microsoft DeepSpeed Team.
Licensed under the MIT license.
'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
'''
Copyright 2020 The Microsoft DeepSpeed Team.
Licensed under the MIT license.
'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from deepspeed.runtime.config_utils import get_scalar_param
from deepspeed.runtime.swap_tensor.constants import *
......@@ -19,26 +19,11 @@ def get_aio_config(param_dict):
if AIO in param_dict.keys() and param_dict[AIO] is not None:
aio_dict = param_dict[AIO]
return {
AIO_BLOCK_SIZE:
get_scalar_param(aio_dict,
AIO_BLOCK_SIZE,
AIO_BLOCK_SIZE_DEFAULT),
AIO_QUEUE_DEPTH:
get_scalar_param(aio_dict,
AIO_QUEUE_DEPTH,
AIO_QUEUE_DEPTH_DEFAULT),
AIO_THREAD_COUNT:
get_scalar_param(aio_dict,
AIO_THREAD_COUNT,
AIO_THREAD_COUNT_DEFAULT),
AIO_SINGLE_SUBMIT:
get_scalar_param(aio_dict,
AIO_SINGLE_SUBMIT,
AIO_SINGLE_SUBMIT_DEFAULT),
AIO_OVERLAP_EVENTS:
get_scalar_param(aio_dict,
AIO_OVERLAP_EVENTS,
AIO_OVERLAP_EVENTS_DEFAULT)
AIO_BLOCK_SIZE: get_scalar_param(aio_dict, AIO_BLOCK_SIZE, AIO_BLOCK_SIZE_DEFAULT),
AIO_QUEUE_DEPTH: get_scalar_param(aio_dict, AIO_QUEUE_DEPTH, AIO_QUEUE_DEPTH_DEFAULT),
AIO_THREAD_COUNT: get_scalar_param(aio_dict, AIO_THREAD_COUNT, AIO_THREAD_COUNT_DEFAULT),
AIO_SINGLE_SUBMIT: get_scalar_param(aio_dict, AIO_SINGLE_SUBMIT, AIO_SINGLE_SUBMIT_DEFAULT),
AIO_OVERLAP_EVENTS: get_scalar_param(aio_dict, AIO_OVERLAP_EVENTS, AIO_OVERLAP_EVENTS_DEFAULT)
}
return AIO_DEFAULT_DICT
"""
Copyright 2020 The Microsoft DeepSpeed Team.
Licensed under the MIT license.
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
Functionality of swapping tensors to/from (NVMe) storage devices.
"""
import torch
......@@ -15,6 +16,7 @@ ASYNC_SWAPPER_WAIT_TIMER = 'async_swap_gradient_wait'
class AsyncTensorSwapper(object):
def __init__(self, aio_handle, numel_alignment, timers):
self.free_buffer_index = []
self.swapping_buffer_index = []
......@@ -70,9 +72,7 @@ class AsyncTensorSwapper(object):
if dist.get_rank() == 0:
element_size = torch.tensor([], dtype=self.dtype).element_size()
swapped_GB = (self.num_elements_swapped * element_size) / (1024**3)
logger.debug(
f'{message} num_elems = {self.num_elements_swapped}, {swapped_GB:5.2f} GB'
)
logger.debug(f'{message} num_elems = {self.num_elements_swapped}, {swapped_GB:5.2f} GB')
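A quick sanity check of the GB arithmetic in the debug message above: one billion fp16 elements occupy 2 bytes each, i.e. roughly 1.86 GB once divided by 1024**3.

import torch

num_elements_swapped = 1_000_000_000
element_size = torch.tensor([], dtype=torch.float16).element_size()   # 2 bytes
swapped_GB = (num_elements_swapped * element_size) / (1024 ** 3)
assert abs(swapped_GB - 1.86) < 0.01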
def _swap_out_tensor(self, tensor, swap_path):
assert len(self.all_buffers) > 0
......
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
"Copyright 2020 The Microsoft DeepSpeed Team.
Licensed under the MIT license.
AIO
"""
#########################################
# AIO
#########################################
AIO_FORMAT = '''
"aio": {
"block_size": 1048576,
......
"""
Copyright 2020 The Microsoft DeepSpeed Team.
Licensed under the MIT license.
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
Functionality of swapping tensors to/from (NVMe) storage devices.
"""
......@@ -17,6 +18,7 @@ from deepspeed.runtime.swap_tensor.utils import SwapBufferManager, SwapBufferPoo
class FlattenedTensorSwapInfo(object):
def __init__(self, path, length, offset):
self.path = path
self.offset = offset
......@@ -24,6 +26,7 @@ class FlattenedTensorSwapInfo(object):
class OptimizerStateSwapInfo(object):
def __init__(self, parameter, numel, base_folder):
self.tensors = []
self.param_id = id(parameter)
......@@ -66,13 +69,8 @@ class OptimizerStateSwapInfo(object):
gradient_paths = []
for offset, length in zip(offsets, lengths):
if not offset in self.swapped_gradients.keys():
path = os.path.join(
self.swap_folder,
f'{self.param_id}_gradient_{offset}_{length}.tensor.swp')
self.swapped_gradients[offset] = FlattenedTensorSwapInfo(
path,
length,
offset)
path = os.path.join(self.swap_folder, f'{self.param_id}_gradient_{offset}_{length}.tensor.swp')
self.swapped_gradients[offset] = FlattenedTensorSwapInfo(path, length, offset)
gradient_paths.append(self.swapped_gradients[offset].path)
......@@ -86,11 +84,7 @@ class OptimizerStateSwapInfo(object):
def get_swap_gradient_buffers(self, swap_buffer):
assert self.numel() <= swap_buffer.numel()
return [
swap_buffer.narrow(0,
grad.offset,
grad.length) for grad in self.swapped_gradients.values()
]
return [swap_buffer.narrow(0, grad.offset, grad.length) for grad in self.swapped_gradients.values()]
def get_swap_gradient_paths(self):
return [grad.path for grad in self.swapped_gradients.values()]
......@@ -116,24 +110,15 @@ SWAP_OUT_GRADIENT_TIMER = 'swap_out_gradient'
class OptimizerSwapper(object):
def __init__(self,
swap_config,
aio_config,
base_folder,
optimizer,
largest_numel,
device,
dtype,
timers):
def __init__(self, swap_config, aio_config, base_folder, optimizer, largest_numel, device, dtype, timers):
self.swap_config = swap_config
self.aio_config = aio_config
# NVMe swap management
self.swap_params_info = {}
self.swap_element_size = torch.tensor([], dtype=dtype).element_size()
self.swap_folder = os.path.join(base_folder,
'optimizer',
f'rank{dist.get_rank()}')
self.swap_folder = os.path.join(base_folder, 'optimizer', f'rank{dist.get_rank()}')
os.makedirs(self.swap_folder, exist_ok=True)
self.optimizer = optimizer
......@@ -191,11 +176,7 @@ class OptimizerSwapper(object):
self.timer_names.add(SWAP_OUT_GRADIENT_TIMER)
self.timer_names.update(gradient_swapper.get_timer_names())
def _swap_out_gradients(self,
parameter,
gradient_offsets,
gradient_tensors,
gradient_swapper):
def _swap_out_gradients(self, parameter, gradient_offsets, gradient_tensors, gradient_swapper):
if not id(parameter) in self.swap_params_info.keys():
return
......@@ -205,10 +186,8 @@ class OptimizerSwapper(object):
swappable_offsets = []
swappable_lengths = []
aligned_gradients, aligned_offsets = self._adjust_for_misaligned_lengths(
tensors=gradient_tensors,
offsets=gradient_offsets
)
aligned_gradients, aligned_offsets = self._adjust_for_misaligned_lengths(tensors=gradient_tensors,
offsets=gradient_offsets)
self._start_timer(SWAP_OUT_GRADIENT_TIMER)
for tensor, offset in zip(aligned_gradients, aligned_offsets):
......@@ -222,38 +201,26 @@ class OptimizerSwapper(object):
if len(swappable_tensors) > 0:
if not gradient_swapper.has_buffers():
pinned_buffers = self.swap_buffer_manager.allocate_all(
num_elems=self.largest_numel,
dtype=self.dtype)
pinned_buffers = self.swap_buffer_manager.allocate_all(num_elems=self.largest_numel, dtype=self.dtype)
gradient_swapper.add_buffers(pinned_buffers)
swappable_paths = swap_info.get_or_create_gradient_paths(
swappable_offsets,
swappable_lengths)
swappable_paths = swap_info.get_or_create_gradient_paths(swappable_offsets, swappable_lengths)
gradient_swapper.swap_out_tensors(tensor_list=swappable_tensors,
path_list=swappable_paths)
gradient_swapper.swap_out_tensors(tensor_list=swappable_tensors, path_list=swappable_paths)
self._stop_timer(SWAP_OUT_GRADIENT_TIMER)
self.timer_names.add(SWAP_OUT_GRADIENT_TIMER)
def _initialize_from_swapped_fp16_params(self,
aio_handle,
fp16_partitions_info,
fp16_num_elems,
fp16_pinned_buffers,
fp32_parameters):
def _initialize_from_swapped_fp16_params(self, aio_handle, fp16_partitions_info, fp16_num_elems,
fp16_pinned_buffers, fp32_parameters):
assert len(fp32_parameters) == len(fp16_partitions_info)
assert len(fp32_parameters) == len(fp16_num_elems)
assert all([buffer.is_pinned() for buffer in fp16_pinned_buffers])
fp32_swap_paths = self._get_swap_paths(parameters=fp32_parameters,
num_elems=fp16_num_elems)
fp32_swap_paths = self._get_swap_paths(parameters=fp32_parameters, num_elems=fp16_num_elems)
fp32_pinned_buffers = self.swap_buffer_manager.allocate_all(
num_elems=self.largest_numel,
dtype=self.dtype)
fp32_pinned_buffers = self.swap_buffer_manager.allocate_all(num_elems=self.largest_numel, dtype=self.dtype)
fp16_buffer_numel = [buf.numel() for buf in fp16_pinned_buffers]
assert all([numel >= self.largest_numel for numel in fp16_buffer_numel]), \
......@@ -264,11 +231,10 @@ class OptimizerSwapper(object):
curr_index = 0
while curr_index < len(fp32_parameters):
fp16_pinned_tensors = self._swap_in_fp16_params(
aio_handle=aio_handle,
fp16_num_elems=fp16_num_elems[curr_index:],
fp16_partitions_info=fp16_partitions_info[curr_index:],
fp16_swap_buffers=fp16_swap_buffers)
fp16_pinned_tensors = self._swap_in_fp16_params(aio_handle=aio_handle,
fp16_num_elems=fp16_num_elems[curr_index:],
fp16_partitions_info=fp16_partitions_info[curr_index:],
fp16_swap_buffers=fp16_swap_buffers)
if dist.get_rank() == 0 and SWAPPER_DEBUG_MODE:
for i, tensor in enumerate(fp16_pinned_tensors):
......@@ -277,11 +243,10 @@ class OptimizerSwapper(object):
f'swap_in_fp16_param: fp32_id = {id(fp32_parameters[true_index])} index = {true_index} orig_num_elem = {fp16_num_elems[true_index]}, swap_num_elem = {fp16_pinned_tensors[i].numel()}'
)
swap_out_count = self._swap_out_fp16_params(
aio_handle=aio_handle,
fp32_swap_paths=fp32_swap_paths[curr_index:],
fp32_swap_buffers=fp32_swap_buffers,
fp16_pinned_tensors=fp16_pinned_tensors)
swap_out_count = self._swap_out_fp16_params(aio_handle=aio_handle,
fp32_swap_paths=fp32_swap_paths[curr_index:],
fp32_swap_buffers=fp32_swap_buffers,
fp16_pinned_tensors=fp16_pinned_tensors)
assert swap_out_count == len(fp16_pinned_tensors), \
f"{swap_out_count} does not match {len(fp16_pinned_tensors)}"
......@@ -291,11 +256,7 @@ class OptimizerSwapper(object):
self.swap_buffer_manager.free(fp32_pinned_buffers)
def _swap_in_fp16_params(self,
aio_handle,
fp16_num_elems,
fp16_partitions_info,
fp16_swap_buffers):
def _swap_in_fp16_params(self, aio_handle, fp16_num_elems, fp16_partitions_info, fp16_swap_buffers):
assert len(fp16_num_elems) > 0
swapped_fp16_tensors = []
......@@ -330,11 +291,7 @@ class OptimizerSwapper(object):
return swapped_fp16_tensors
def _swap_out_fp16_params(self,
aio_handle,
fp32_swap_paths,
fp32_swap_buffers,
fp16_pinned_tensors):
def _swap_out_fp16_params(self, aio_handle, fp32_swap_paths, fp32_swap_buffers, fp16_pinned_tensors):
assert len(fp16_pinned_tensors) <= len(fp32_swap_paths)
swap_out_count = 0
......@@ -343,11 +300,8 @@ class OptimizerSwapper(object):
fp32_swap_buffers.swap_out(aio_handle)
fp32_swap_buffers.reset()
pinned_tensor, _ = fp32_swap_buffers.insert_tensor(
fp16_tensor,
fp32_swap_paths[i],
self._io_aligned_numel(fp16_tensor.numel())
)
pinned_tensor, _ = fp32_swap_buffers.insert_tensor(fp16_tensor, fp32_swap_paths[i],
self._io_aligned_numel(fp16_tensor.numel()))
assert pinned_tensor is not None
swap_out_count += 1
......@@ -359,15 +313,12 @@ class OptimizerSwapper(object):
def _initialize_parameters(self, parameters, src_tensors, aio_handle):
assert len(parameters) == len(src_tensors)
swap_paths = self._get_swap_paths(parameters=parameters,
num_elems=[src.numel() for src in src_tensors])
swap_paths = self._get_swap_paths(parameters=parameters, num_elems=[src.numel() for src in src_tensors])
SWAP_INIT_TIMER = "swap_init_write"
self._start_timer(SWAP_INIT_TIMER)
pinned_buffers = self.swap_buffer_manager.allocate_all(
num_elems=self.largest_numel,
dtype=self.dtype)
pinned_buffers = self.swap_buffer_manager.allocate_all(num_elems=self.largest_numel, dtype=self.dtype)
assert pinned_buffers is not None
self._swap_out_unpinned_tensors(aio_handle=aio_handle,
......@@ -397,11 +348,7 @@ class OptimizerSwapper(object):
swap_paths = [info.swap_paths[0] for info in swap_info_list]
return swap_paths
def _swap_out_unpinned_tensors(self,
aio_handle,
unpinned_tensors,
dest_paths,
pinned_buffers):
def _swap_out_unpinned_tensors(self, aio_handle, unpinned_tensors, dest_paths, pinned_buffers):
swap_buffer_count = len(pinned_buffers)
unpinned_tensor_count = len(unpinned_tensors)
......@@ -441,8 +388,7 @@ class OptimizerSwapper(object):
continue
# Split into two by making remainder a tensor
aligned_length = (orig_tensor.numel() //
self.numel_alignment) * self.numel_alignment
aligned_length = (orig_tensor.numel() // self.numel_alignment) * self.numel_alignment
new_tensors.append(orig_tensor.narrow(0, 0, aligned_length))
new_offsets.append(orig_offset)
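A worked example of the alignment split above: with a numel_alignment of 256, a 1000-element gradient is split into an aligned 768-element head (which can go through the aio swap path) and a 232-element remainder handled separately.

numel, numel_alignment, orig_offset = 1000, 256, 0
aligned_length = (numel // numel_alignment) * numel_alignment
assert aligned_length == 768
remainder = numel - aligned_length
assert remainder == 232
# head: narrow(0, 0, 768) at orig_offset; tail: narrow(0, 768, 232) at offset orig_offset + 768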
......@@ -489,10 +435,9 @@ class OptimizerSwapper(object):
param_id = id(parameter)
assert not param_id in self.swap_params_info
self.swap_params_info[param_id] = OptimizerStateSwapInfo(
parameter=parameter,
numel=numel,
base_folder=self.swap_folder)
self.swap_params_info[param_id] = OptimizerStateSwapInfo(parameter=parameter,
numel=numel,
base_folder=self.swap_folder)
swap_info = self.swap_params_info[param_id]
self._update_param_state_info(swap_info, parameter)
......
"""
Copyright 2020 The Microsoft DeepSpeed Team
Licensed under the MIT license.
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
Functionality of swapping optimizer tensors to/from (NVMe) storage devices.
"""
......@@ -25,30 +26,14 @@ SWAP_IN_GRADIENT_TIMER = 'swap_in_gradient'
class PartitionedOptimizerSwapper(OptimizerSwapper):
def __init__(self,
swap_config,
aio_config,
base_folder,
optimizer,
largest_numel,
device,
dtype,
timers):
super(PartitionedOptimizerSwapper,
self).__init__(swap_config,
aio_config,
base_folder,
optimizer,
largest_numel,
device,
dtype,
timers)
def __init__(self, swap_config, aio_config, base_folder, optimizer, largest_numel, device, dtype, timers):
super(PartitionedOptimizerSwapper, self).__init__(swap_config, aio_config, base_folder, optimizer,
largest_numel, device, dtype, timers)
aio_op = AsyncIOBuilder().load()
self.aio_handle = aio_op.aio_handle(aio_config[AIO_BLOCK_SIZE],
aio_config[AIO_QUEUE_DEPTH],
aio_config[AIO_SINGLE_SUBMIT],
aio_config[AIO_OVERLAP_EVENTS],
self.aio_handle = aio_op.aio_handle(aio_config[AIO_BLOCK_SIZE], aio_config[AIO_QUEUE_DEPTH],
aio_config[AIO_SINGLE_SUBMIT], aio_config[AIO_OVERLAP_EVENTS],
aio_config[AIO_THREAD_COUNT])
# Overlap swapping out
......@@ -56,33 +41,21 @@ class PartitionedOptimizerSwapper(OptimizerSwapper):
numel_alignment=self.numel_alignment,
timers=self.timers)
self.print_exclude_list += [
'aio_handle',
'gradient_swapper',
'print_exclude_list'
]
self.print_exclude_list += ['aio_handle', 'gradient_swapper', 'print_exclude_list']
if dist.get_rank() == 0:
print_object(obj=self,
name='PartitionedOptimizerSwapper',
exclude_list=self.print_exclude_list)
print_object(obj=self, name='PartitionedOptimizerSwapper', exclude_list=self.print_exclude_list)
def initialize_parameters(self, parameters, src_tensors):
self._initialize_parameters(parameters=parameters,
src_tensors=src_tensors,
aio_handle=self.aio_handle)
def initialize_from_swapped_fp16_params(self,
fp16_partitions_info,
fp16_num_elems,
fp16_pinned_buffers,
self._initialize_parameters(parameters=parameters, src_tensors=src_tensors, aio_handle=self.aio_handle)
def initialize_from_swapped_fp16_params(self, fp16_partitions_info, fp16_num_elems, fp16_pinned_buffers,
fp32_parameters):
self._initialize_from_swapped_fp16_params(
aio_handle=self.aio_handle,
fp16_partitions_info=fp16_partitions_info,
fp16_num_elems=fp16_num_elems,
fp16_pinned_buffers=fp16_pinned_buffers,
fp32_parameters=fp32_parameters)
self._initialize_from_swapped_fp16_params(aio_handle=self.aio_handle,
fp16_partitions_info=fp16_partitions_info,
fp16_num_elems=fp16_num_elems,
fp16_pinned_buffers=fp16_pinned_buffers,
fp32_parameters=fp32_parameters)
def flush_gradients(self):
self._flush_gradient_swapper(self.gradient_swapper)
......@@ -94,8 +67,7 @@ class PartitionedOptimizerSwapper(OptimizerSwapper):
self._flush_gradient_swapper(self.gradient_swapper)
required_buffer_count = len(
swap_info.tensors) + (1 if swap_info.has_gradients() else 0)
required_buffer_count = len(swap_info.tensors) + (1 if swap_info.has_gradients() else 0)
aligned_numel = self._io_aligned_numel(swap_info.numel())
pinned_buffers = self.swap_buffer_manager.allocate(num_elems=aligned_numel,
count=required_buffer_count,
......@@ -111,9 +83,7 @@ class PartitionedOptimizerSwapper(OptimizerSwapper):
self.timer_names.add(SWAP_IN_PARAM_TIMER)
self._start_timer(SWAP_IN_GRADIENT_TIMER)
self._swap_in_gradients(aio_handle=self.aio_handle,
parameter=parameter,
dest_buffer=pinned_buffers[-1])
self._swap_in_gradients(aio_handle=self.aio_handle, parameter=parameter, dest_buffer=pinned_buffers[-1])
self._stop_timer(SWAP_IN_GRADIENT_TIMER)
self.timer_names.add(SWAP_IN_GRADIENT_TIMER)
......@@ -125,10 +95,7 @@ class PartitionedOptimizerSwapper(OptimizerSwapper):
self._start_timer(SWAP_OUT_PARAM_TIMER)
pinned_tensors, pinned_paths, unpinned_tensors, unpinned_paths = self._separate_pinned_tensors(swap_info)
swap_bytes = sum([
self._io_aligned_numel(t.numel()) * t.element_size()
for t in swap_info.tensors
])
swap_bytes = sum([self._io_aligned_numel(t.numel()) * t.element_size() for t in swap_info.tensors])
WRITE_TIMER = 'swap_submit_write'
self._start_timer(WRITE_TIMER)
......@@ -139,9 +106,7 @@ class PartitionedOptimizerSwapper(OptimizerSwapper):
t.data = torch.Tensor()
if len(unpinned_tensors) > 0:
pinned_buffers = self.swap_buffer_manager.allocate_all(
num_elems=self.largest_numel,
dtype=self.dtype)
pinned_buffers = self.swap_buffer_manager.allocate_all(num_elems=self.largest_numel, dtype=self.dtype)
self._swap_out_unpinned_tensors(aio_handle=self.aio_handle,
unpinned_tensors=unpinned_tensors,
dest_paths=unpinned_paths,
......@@ -176,8 +141,7 @@ class PartitionedOptimizerSwapper(OptimizerSwapper):
assert len(swap_info.tensors) <= len(dest_buffers)
swap_lengths = [self._io_aligned_numel(swap_info.numel())] * len(
swap_info.tensors)
swap_lengths = [self._io_aligned_numel(swap_info.numel())] * len(swap_info.tensors)
swap_buffers = get_sized_buffers(dest_buffers, swap_lengths)
READ_TIMER = 'swap_submit_read_param'
......@@ -187,8 +151,7 @@ class PartitionedOptimizerSwapper(OptimizerSwapper):
swap_in_tensors(aio_handle, swap_buffers, swap_info.swap_paths)
self._stop_timer(READ_TIMER)
swap_bytes = sum(
[buffer.numel() * buffer.element_size() for buffer in swap_buffers])
swap_bytes = sum([buffer.numel() * buffer.element_size() for buffer in swap_buffers])
self._start_timer(WAIT_TIMER)
aio_handle.wait()
......@@ -223,11 +186,7 @@ class PartitionedOptimizerSwapper(OptimizerSwapper):
def _swap_in_pinned_gradients(self, aio_handle, parameter, gradient_tensor):
swap_info = self.swap_params_info[id(parameter)]
param_gradients = swap_info.swapped_gradients.values()
swap_buffers = [
gradient_tensor.narrow(0,
grad.offset,
grad.length) for grad in param_gradients
]
swap_buffers = [gradient_tensor.narrow(0, grad.offset, grad.length) for grad in param_gradients]
swap_paths = [grad.path for grad in param_gradients]
SWAP_READ_GRADIENTS = 'swap_submit_read_gradient'
SWAP_WAIT_GRADIENTS = 'swap_submit_wait_gradient'
......@@ -256,5 +215,4 @@ class PartitionedOptimizerSwapper(OptimizerSwapper):
self._swap_in_pinned_gradients(aio_handle, parameter, parameter.grad)
if swap_info.unswapped_gradients:
self._retrieve_unswapped_grad_partitions(swap_info=swap_info,
dest_buffer=parameter.grad)
self._retrieve_unswapped_grad_partitions(swap_info=swap_info, dest_buffer=parameter.grad)