Commit e4575be9 authored by huaerkl

v1.0
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megatron grad scaler."""
from abc import ABC
from abc import abstractmethod
import torch
class MegatronGradScaler(ABC):
def __init__(self, initial_scale):
"""Initialize scale value with the input initial scale."""
assert initial_scale > 0.0
self._scale = torch.cuda.FloatTensor([initial_scale])
@property
def scale(self):
return self._scale
@property
def inv_scale(self):
return self._scale.double().reciprocal().float()
@abstractmethod
def update(self, found_inf):
pass
@abstractmethod
def state_dict(self):
pass
@abstractmethod
def load_state_dict(self, state_dict):
pass
class ConstantGradScaler(MegatronGradScaler):
def update(self, found_inf):
pass
def state_dict(self):
return dict()
def load_state_dict(self, state_dict):
pass
class DynamicGradScaler(MegatronGradScaler):
def __init__(self, initial_scale, min_scale,
growth_factor, backoff_factor,
growth_interval, hysteresis):
""""Grad scaler with dynamic scale that gets adjusted
during training."""
super(DynamicGradScaler, self).__init__(initial_scale)
# Lower bound on the scale.
assert min_scale > 0.0
assert min_scale <= initial_scale
self.min_scale = torch.cuda.FloatTensor([min_scale])
# Growth and backoff factors for the scale.
assert growth_factor > 1.0
self.growth_factor = torch.cuda.FloatTensor([growth_factor])
assert backoff_factor < 1.0
assert backoff_factor > 0.0
self.backoff_factor = torch.cuda.FloatTensor([backoff_factor])
# Interval over which if we don't see any inf/nan,
# we will scale the grad scale by the growth factor.
assert growth_interval > 0
self.growth_interval = growth_interval
# Number of inf/nans we should see before scaling down
# the grad scale by the backoff factor.
assert hysteresis > 0
self.hysteresis = hysteresis
# Trackers.
self._growth_tracker = 0
self._hysteresis_tracker = self.hysteresis
def update(self, found_inf):
# If we have an inf/nan, growth tracker is set to 0
# and hysteresis tracker is reduced by 1.
if found_inf:
self._growth_tracker = 0
self._hysteresis_tracker -= 1
# Now if we are out of hysteresis count, scale down the loss.
if self._hysteresis_tracker <= 0:
self._scale = torch.max(self._scale * self.backoff_factor,
self.min_scale)
else:
# If there is no nan/inf, increment the growth tracker.
self._growth_tracker += 1
# If we have had enough consecutive intervals with no nan/inf:
if self._growth_tracker == self.growth_interval:
# Reset the tracker and hysteresis trackers,
self._growth_tracker = 0
self._hysteresis_tracker = self.hysteresis
# and scale up the loss scale.
self._scale = self._scale * self.growth_factor
def state_dict(self):
state_dict = {}
state_dict['scale'] = self._scale
state_dict['growth_tracker'] = self._growth_tracker
state_dict['hysteresis_tracker'] = self._hysteresis_tracker
return state_dict
def load_state_dict(self, state_dict):
self._scale = state_dict['scale'].cuda(torch.cuda.current_device())
self._growth_tracker = state_dict['growth_tracker']
self._hysteresis_tracker = state_dict['hysteresis_tracker']
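# --- Illustrative usage sketch, not from the original commit ---
# Shows how DynamicGradScaler is expected to react to overflows and clean
# steps; the hyperparameter values are made up for the example and assume a
# CUDA device, since the scale lives in a torch.cuda.FloatTensor.
def _dynamic_grad_scaler_example():
    scaler = DynamicGradScaler(initial_scale=2.0 ** 16, min_scale=1.0,
                               growth_factor=2.0, backoff_factor=0.5,
                               growth_interval=2, hysteresis=2)
    scaler.update(found_inf=True)    # first overflow: growth tracker resets
    scaler.update(found_inf=True)    # second overflow: scale *= backoff_factor
    assert scaler.scale.item() == 2.0 ** 15
    scaler.update(found_inf=False)   # clean step 1
    scaler.update(found_inf=False)   # clean step 2 reaches growth_interval: scale *= growth_factor
    assert scaler.scale.item() == 2.0 ** 16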
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megatron optimizer."""
from abc import ABC
from abc import abstractmethod
import torch
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
from megatron import get_timers
from megatron import mpu
from megatron import print_rank_0
from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32
def _zero_grad_group_helper(group, set_to_none):
"""Zero out the gradient for a group of parameters.
Note: copied from torch.optim.optimizer."""
for param in group:
if param.grad is not None:
if set_to_none:
param.grad = None
else:
if param.grad.grad_fn is not None:
param.grad.detach_()
else:
param.grad.requires_grad_(False)
param.grad.zero_()
def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
"""Use multi-tensor-applier to copy values from one list to another.
We don't have a bfloat16 implementation so for now if the overflow_buf
is not provided, we default back to simple loop copy to be compatible
with bfloat16."""
if overflow_buf:
overflow_buf.fill_(0)
# Scaling with factor `1.0` is equivalent to copy.
multi_tensor_applier(amp_C.multi_tensor_scale,
overflow_buf,
[this, that],
1.0)
else:
for this_, that_ in zip(this, that):
that_.copy_(this_)
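# --- Illustrative sketch, not from the original commit ---
# Both branches of _multi_tensor_copy_this_to_that copy element-wise from
# `this` into `that`; only the fused apex path needs the int32 overflow
# buffer. The fallback loop can be exercised with throwaway CPU tensors
# (assuming the apex/megatron imports at the top of this file resolve):
def _copy_fallback_example():
    src = [torch.ones(4), torch.arange(4.0)]
    dst = [torch.zeros(4), torch.zeros(4)]
    _multi_tensor_copy_this_to_that(this=src, that=dst, overflow_buf=None)
    assert torch.equal(dst[0], src[0]) and torch.equal(dst[1], src[1])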
class MegatronOptimizer(ABC):
def __init__(self, optimizer, clip_grad,
log_num_zeros_in_grad,
params_have_main_grad):
"""Input optimizer is the base optimizer for example Adam."""
self.optimizer = optimizer
assert self.optimizer, 'no optimizer is provided.'
# Set gradient clipping and logging params.
self.clip_grad = clip_grad
self.log_num_zeros_in_grad = log_num_zeros_in_grad
self.params_have_main_grad = params_have_main_grad
def get_parameters(self):
params = []
for param_group in self.optimizer.param_groups:
for param in param_group['params']:
params.append(param)
return params
def clip_grad_norm(self, clip_grad):
params = self.get_parameters()
return clip_grad_norm_fp32(params, clip_grad)
def count_zeros(self):
params = self.get_parameters()
return count_zeros_fp32(params)
@abstractmethod
def zero_grad(self, set_to_none=True):
pass
@abstractmethod
def get_loss_scale(self):
"""The output should be a cuda tensor of size 1."""
pass
def scale_loss(self, loss):
"""Simple scaling."""
return self.get_loss_scale() * loss
@abstractmethod
def step(self):
pass
@abstractmethod
def reload_model_params(self):
"""Refreshes any internal state from the current model parameters.
Call whenever the parameters are changed outside of the optimizer.
For example, when we load a model from a checkpoint without loading
the optimizer, the model parameters are updated but for fp16 optimizer
with main parameters, the main parameters need to also be updated."""
pass
@abstractmethod
def state_dict(self):
pass
@abstractmethod
def load_state_dict(self, state_dict):
pass
# Promote state so it can be retrieved or set via
# "optimizer_instance.state"
def _get_state(self):
return self.optimizer.state
def _set_state(self, value):
self.optimizer.state = value
state = property(_get_state, _set_state)
# Promote param_groups so it can be retrieved or set via
# "optimizer_instance.param_groups"
# (for example, to adjust the learning rate)
def _get_param_groups(self):
return self.optimizer.param_groups
def _set_param_groups(self, value):
self.optimizer.param_groups = value
param_groups = property(_get_param_groups, _set_param_groups)
class Float16OptimizerWithFloat16Params(MegatronOptimizer):
"""Float16 optimizer for fp16 and bf16 data types.
Arguments:
optimizer: base optimizer such as Adam or SGD
clip_grad: clip gradients with this global L2 norm. Note
that clipping is ignored if clip_grad == 0
log_num_zeros_in_grad: return number of zeros in the gradients.
params_have_main_grad: flag indicating if parameters have
a `main_grad` field. If this is set, we are assuming
that the model gradients are stored in the `main_grad`
field instead of the typical `grad` field. This happens
for the DDP cases where there is a contiguous buffer
holding the gradients. For example for bfloat16, we want
to do gradient accumulation and all-reduces in float32
and as a result we store those gradients in the main_grad.
Note that main grad is not necessarily in float32.
bf16: if true, the model is running in bfloat16.
grad_scaler: used for scaling gradients. Note that this can be
None. This case happens when `bf16 = True` and we don't
use any loss scale. Note that for `bf16 = True`, we can have
a constant gradient scaler. Also for `bf16 = False`, we
always require a grad scaler.
"""
def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
params_have_main_grad, bf16, grad_scaler):
super(Float16OptimizerWithFloat16Params, self).__init__(
optimizer, clip_grad, log_num_zeros_in_grad,
params_have_main_grad)
self.bf16 = bf16
self.grad_scaler = grad_scaler
# None grad scaler is only supported for bf16.
if self.grad_scaler is None:
assert self.bf16, 'fp16 expects a grad scaler.'
# Tensor used to determine if a nan/inf has happened.
# Any non-zero value indicates inf/nan.
# Note that we keep this for the cases that grad scaler is none.
# We still record nan/inf if we have a bfloat16 with a grad scaler.
if self.grad_scaler:
self.found_inf = torch.cuda.FloatTensor([0.0])
# Dummy tensor needed for apex multi-tensor apply.
# For bfloat16, we don't have multi-tensor apply and for now
# we set it to none so the multi-tensor apply gets ignored.
if bf16:
self._dummy_overflow_buf = None
else:
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
# In case grad scaler is not passed, define the unity scale.
if self.grad_scaler is None:
self._scale_one = torch.cuda.FloatTensor([1.0])
# ======================
# main parameter stuff
# ======================
# Three groups of parameters:
# float16_groups: original float16 parameters
# fp32_from_float16_groups: fp32 copy of float16 parameters
# fp32_from_fp32_groups: original fp32 parameters
self.float16_groups = []
self.fp32_from_float16_groups = []
self.fp32_from_fp32_groups = []
# For all the groups in the original optimizer:
for param_group in self.optimizer.param_groups:
float16_params_this_group = []
fp32_params_this_group = []
fp32_from_float16_params_this_group = []
# For all the parameters in this group:
for i, param in enumerate(param_group['params']):
if param.requires_grad:
# float16 params:
if param.type() in ['torch.cuda.HalfTensor',
'torch.cuda.BFloat16Tensor']:
float16_params_this_group.append(param)
# Create a copy
main_param = param.detach().clone().float()
# Copy tensor model parallel attributes.
mpu.copy_tensor_model_parallel_attributes(main_param,
param)
if hasattr(param, 'shared'):
main_param.shared = param.shared
# Replace the optimizer params with the new fp32 copy.
param_group['params'][i] = main_param
fp32_from_float16_params_this_group.append(main_param)
# Reset existing state dict key to the new main param.
if param in self.optimizer.state:
self.optimizer.state[main_param] \
= self.optimizer.state.pop(param)
# fp32 params.
elif param.type() == 'torch.cuda.FloatTensor':
fp32_params_this_group.append(param)
param_group['params'][i] = param
else:
raise TypeError('Wrapped parameters must be one of '
'torch.cuda.FloatTensor, '
'torch.cuda.HalfTensor, or '
'torch.cuda.BFloat16Tensor. '
'Received {}'.format(param.type()))
self.float16_groups.append(float16_params_this_group)
self.fp32_from_float16_groups.append(
fp32_from_float16_params_this_group)
self.fp32_from_fp32_groups.append(fp32_params_this_group)
# Leverage state_dict() and load_state_dict() to
# recast preexisting per-param state tensors
self.optimizer.load_state_dict(self.optimizer.state_dict())
def zero_grad(self, set_to_none=True):
"""We only need to zero the model related parameters, i.e.,
float16_groups & fp32_from_fp32_groups."""
for group in self.float16_groups:
_zero_grad_group_helper(group, set_to_none)
for group in self.fp32_from_fp32_groups:
_zero_grad_group_helper(group, set_to_none)
def get_loss_scale(self):
if self.grad_scaler is None:
return self._scale_one
return self.grad_scaler.scale
def _copy_model_grads_to_main_grads(self):
# This only needs to be done for the float16 group.
for model_group, main_group in zip(self.float16_groups,
self.fp32_from_float16_groups):
for model_param, main_param in zip(model_group, main_group):
if self.params_have_main_grad:
main_param.grad = model_param.main_grad.float()
else:
if model_param.grad is not None:
main_param.grad = model_param.grad.float()
# For fp32 grads, we need to reset the grads to main grad.
if self.params_have_main_grad:
for model_group in self.fp32_from_fp32_groups:
for model_param in model_group:
model_param.grad = model_param.main_grad
def _unscale_main_grads_and_check_for_nan(self):
main_grads = []
# fp32 params from float16 ones.
for main_group in self.fp32_from_float16_groups:
for main_param in main_group:
if main_param.grad is not None:
main_grads.append(main_param.grad.data)
# Append fp32 parameters.
for main_group in self.fp32_from_fp32_groups:
for main_param in main_group:
if main_param.grad is not None:
main_grads.append(main_param.grad.data)
# Reset found inf.
self.found_inf.fill_(0.0)
# Unscale and set found inf/nan
torch._amp_foreach_non_finite_check_and_unscale_(
main_grads, self.found_inf, self.grad_scaler.inv_scale)
# Update across all model parallel instances.
torch.distributed.all_reduce(self.found_inf,
op=torch.distributed.ReduceOp.MAX,
group=mpu.get_model_parallel_group())
# Check for nan.
found_inf_flag = (self.found_inf.item() > 0)
return found_inf_flag
def _get_model_and_main_params_data_float16(self):
model_data = []
main_data = []
for model_group, main_group in zip(self.float16_groups,
self.fp32_from_float16_groups):
for model_param, main_param in zip(model_group, main_group):
model_data.append(model_param.data)
main_data.append(main_param.data)
return model_data, main_data
def _copy_main_params_to_model_params(self):
# Only needed for the float16 params.
model_data, main_data = self._get_model_and_main_params_data_float16()
_multi_tensor_copy_this_to_that(this=main_data, that=model_data,
overflow_buf=self._dummy_overflow_buf)
def _copy_model_params_to_main_params(self):
# Only needed for the float16 params.
model_data, main_data = self._get_model_and_main_params_data_float16()
_multi_tensor_copy_this_to_that(this=model_data, that=main_data,
overflow_buf=self._dummy_overflow_buf)
def reload_model_params(self):
self._copy_model_params_to_main_params()
@torch.no_grad()
def step(self):
timers = get_timers()
# Copy gradients from model params to main params.
timers('optimizer-copy-to-main-grad').start()
self._copy_model_grads_to_main_grads()
timers('optimizer-copy-to-main-grad').stop()
# Do unscale, check for inf, and update grad scaler only for
# the case that grad scaler is provided.
if self.grad_scaler:
# Unscale and check for inf/nan.
timers('optimizer-unscale-and-check-inf').start()
found_inf_flag = self._unscale_main_grads_and_check_for_nan()
timers('optimizer-unscale-and-check-inf').stop()
# We are done with scaling gradients
# so we can update the loss scale.
self.grad_scaler.update(found_inf_flag)
# If we found inf/nan, skip the update.
if found_inf_flag:
return False, None, None
# Clip the main gradients.
timers('optimizer-clip-main-grad').start()
grad_norm = None
if self.clip_grad > 0.0:
grad_norm = self.clip_grad_norm(self.clip_grad)
timers('optimizer-clip-main-grad').stop()
# count the zeros in the grads
num_zeros_in_grad = self.count_zeros() if \
self.log_num_zeros_in_grad else None
# Step the optimizer.
self.optimizer.step()
# Update params from main params.
timers('optimizer-copy-main-to-model-params').start()
self._copy_main_params_to_model_params()
timers('optimizer-copy-main-to-model-params').stop()
# Successful update.
return True, grad_norm, num_zeros_in_grad
def state_dict(self):
state_dict = {}
state_dict['optimizer'] = self.optimizer.state_dict()
if self.grad_scaler:
state_dict['grad_scaler'] = self.grad_scaler.state_dict()
state_dict['fp32_from_fp16_params'] = self.fp32_from_float16_groups
return state_dict
def load_state_dict(self, state_dict):
# Optimizer.
optimizer_key = 'optimizer'
if optimizer_key not in state_dict:
optimizer_key = 'optimizer_state_dict'
print_rank_0('***WARNING*** loading optimizer from '
'an old checkpoint ...')
self.optimizer.load_state_dict(state_dict[optimizer_key])
# Grad scaler.
if 'grad_scaler' not in state_dict:
print_rank_0('***WARNING*** found an old checkpoint, will not '
'load grad scaler ...')
else:
if self.grad_scaler:
self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
else:
print_rank_0('***WARNING*** found the grad scaler in the '
'checkpoint but it is None in the class. '
'Skipping loading grad scaler ...')
# Copy data for the main params.
fp32_from_float16_params_key = 'fp32_from_fp16_params'
if fp32_from_float16_params_key not in state_dict:
fp32_from_float16_params_key = 'fp32_from_fp16'
for current_group, saved_group in zip(
self.fp32_from_float16_groups,
state_dict[fp32_from_float16_params_key]):
for current_param, saved_param in zip(current_group, saved_group):
current_param.data.copy_(saved_param.data)
class FP32Optimizer(MegatronOptimizer):
def __init__(self, optimizer, clip_grad,
log_num_zeros_in_grad,
params_have_main_grad):
super(FP32Optimizer, self).__init__(
optimizer, clip_grad, log_num_zeros_in_grad,
params_have_main_grad)
self._scale = torch.cuda.FloatTensor([1.0])
def zero_grad(self, set_to_none=True):
"""Copied from torch.optim.optimizer"""
for group in self.optimizer.param_groups:
_zero_grad_group_helper(group['params'], set_to_none)
def get_loss_scale(self):
"""FP32 optimizer does not do any scaling."""
return self._scale
@torch.no_grad()
def step(self):
"""Clip gradients (if needed) and step the base optimizer.
Always return successful since there is no overflow."""
# Copy main_grads to grads.
if self.params_have_main_grad:
for param_group in self.optimizer.param_groups:
for param in param_group['params']:
param.grad = param.main_grad
# Clip gradients.
grad_norm = None
if self.clip_grad > 0.0:
grad_norm = self.clip_grad_norm(self.clip_grad)
# count the zeros in the grads
num_zeros_in_grad = self.count_zeros() if \
self.log_num_zeros_in_grad else None
# Update parameters.
self.optimizer.step()
# No overflow for FP32 optimizer.
return True, grad_norm, num_zeros_in_grad
def reload_model_params(self):
pass
def state_dict(self):
return self.optimizer.state_dict()
def load_state_dict(self, state_dict):
self.optimizer.load_state_dict(state_dict)
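# --- Illustrative wiring sketch, not from the original commit ---
# How a training loop is expected to drive one of these wrappers. The
# `megatron_optimizer` argument is assumed to be an already-constructed
# Float16OptimizerWithFloat16Params or FP32Optimizer; the function name is
# a placeholder, not a Megatron API.
def _train_step_sketch(megatron_optimizer, loss):
    megatron_optimizer.zero_grad(set_to_none=True)
    # Scale the loss before backward so fp16 gradients stay representable.
    megatron_optimizer.scale_loss(loss).backward()
    # step() returns (update_successful, grad_norm, num_zeros_in_grad);
    # on inf/nan it skips the parameter update and only shrinks the loss scale.
    return megatron_optimizer.step()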
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import reduce
import operator
import torch
from megatron import get_args
from megatron import mpu
def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
use_ring_exchange=False):
"""Communicate tensors between stages. Used as helper method in other
communication methods that are used in megatron/schedules.py.
Takes the following arguments:
tensor_send_next: tensor to send to next rank (no tensor sent if
set to None).
tensor_send_prev: tensor to send to prev rank (no tensor sent if
set to None).
recv_prev: boolean for whether tensor should be received from
previous rank.
recv_next: boolean for whether tensor should be received from
next rank.
use_ring_exchange: boolean for whether torch.distributed.ring_exchange()
API should be used.
Returns:
(tensor_recv_prev, tensor_recv_next)
"""
args = get_args()
# Create placeholder tensors for receive in forward and backward directions
# if needed.
tensor_recv_prev = None
tensor_recv_next = None
tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
if args.scatter_gather_tensors_in_pipeline:
tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1) // \
mpu.get_tensor_model_parallel_world_size()
else:
tensor_chunk_shape = tensor_shape
dtype = args.params_dtype
if args.fp32_residual_connection:
dtype = torch.float
if recv_prev:
tensor_recv_prev = torch.empty(tensor_chunk_shape,
requires_grad=True,
device=torch.cuda.current_device(),
dtype=dtype)
if recv_next:
tensor_recv_next = torch.empty(tensor_chunk_shape,
requires_grad=True,
device=torch.cuda.current_device(),
dtype=dtype)
# Split tensor into smaller chunks if using scatter-gather optimization.
if args.scatter_gather_tensors_in_pipeline:
if tensor_send_next is not None:
tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)
if tensor_send_prev is not None:
tensor_send_prev = mpu.split_tensor_into_1d_equal_chunks(tensor_send_prev)
# Send tensors in both the forward and backward directions as appropriate.
if use_ring_exchange:
torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
tensor_recv_prev=tensor_recv_prev,
tensor_send_next=tensor_send_next,
tensor_recv_next=tensor_recv_next,
group=mpu.get_pipeline_model_parallel_group())
else:
ops = []
if tensor_send_prev is not None:
send_prev_op = torch.distributed.P2POp(
torch.distributed.isend, tensor_send_prev,
mpu.get_pipeline_model_parallel_prev_rank())
ops.append(send_prev_op)
if tensor_recv_prev is not None:
recv_prev_op = torch.distributed.P2POp(
torch.distributed.irecv, tensor_recv_prev,
mpu.get_pipeline_model_parallel_prev_rank())
ops.append(recv_prev_op)
if tensor_send_next is not None:
send_next_op = torch.distributed.P2POp(
torch.distributed.isend, tensor_send_next,
mpu.get_pipeline_model_parallel_next_rank())
ops.append(send_next_op)
if tensor_recv_next is not None:
recv_next_op = torch.distributed.P2POp(
torch.distributed.irecv, tensor_recv_next,
mpu.get_pipeline_model_parallel_next_rank())
ops.append(recv_next_op)
if len(ops) > 0:
reqs = torch.distributed.batch_isend_irecv(ops)
for req in reqs:
req.wait()
# To protect against race condition when using batch_isend_irecv().
torch.cuda.synchronize()
# If using scatter-gather optimization, gather smaller chunks.
if args.scatter_gather_tensors_in_pipeline:
if recv_prev:
tensor_recv_prev = mpu.gather_split_1d_tensor(
tensor_recv_prev).view(tensor_shape).requires_grad_()
if recv_next:
tensor_recv_next = mpu.gather_split_1d_tensor(
tensor_recv_next).view(tensor_shape).requires_grad_()
return tensor_recv_prev, tensor_recv_next
def recv_forward(timers=None):
"""Receive tensor from previous rank in pipeline (forward receive)."""
if mpu.is_pipeline_first_stage():
input_tensor = None
else:
if timers is not None:
timers('forward-recv').start()
input_tensor, _ = _communicate(
tensor_send_next=None,
tensor_send_prev=None,
recv_prev=True,
recv_next=False)
if timers is not None:
timers('forward-recv').stop()
return input_tensor
def recv_backward(timers=None):
"""Receive tensor from next rank in pipeline (backward receive)."""
if mpu.is_pipeline_last_stage():
output_tensor_grad = None
else:
if timers is not None:
timers('backward-recv').start()
_, output_tensor_grad = _communicate(
tensor_send_next=None,
tensor_send_prev=None,
recv_prev=False,
recv_next=True)
if timers is not None:
timers('backward-recv').stop()
return output_tensor_grad
def send_forward(output_tensor, timers=None):
"""Send tensor to next rank in pipeline (forward send)."""
if not mpu.is_pipeline_last_stage():
if timers is not None:
timers('forward-send').start()
_communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=False,
recv_next=False)
if timers is not None:
timers('forward-send').stop()
def send_backward(input_tensor_grad, timers=None):
"""Send tensor to previous rank in pipeline (backward send)."""
if not mpu.is_pipeline_first_stage():
if timers is not None:
timers('backward-send').start()
_communicate(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=False,
recv_next=False)
if timers is not None:
timers('backward-send').stop()
def send_forward_recv_backward(output_tensor, timers=None):
"""Batched send and recv with next rank in pipeline."""
if mpu.is_pipeline_last_stage():
output_tensor_grad = None
else:
if timers is not None:
timers('forward-send-backward-recv').start()
_, output_tensor_grad = _communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=False,
recv_next=True)
if timers is not None:
timers('forward-send-backward-recv').stop()
return output_tensor_grad
def send_backward_recv_forward(input_tensor_grad, timers=None):
"""Batched send and recv with previous rank in pipeline."""
if mpu.is_pipeline_first_stage():
input_tensor = None
else:
if timers is not None:
timers('backward-send-forward-recv').start()
input_tensor, _ = _communicate(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=True,
recv_next=False)
if timers is not None:
timers('backward-send-forward-recv').stop()
return input_tensor
def send_forward_recv_forward(output_tensor, recv_prev, timers=None):
"""Batched recv from previous rank and send to next rank in pipeline."""
if timers is not None:
timers('forward-send-forward-recv').start()
input_tensor, _ = _communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=recv_prev,
recv_next=False)
if timers is not None:
timers('forward-send-forward-recv').stop()
return input_tensor
def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None):
"""Batched recv from next rank and send to previous rank in pipeline."""
if timers is not None:
timers('backward-send-backward-recv').start()
_, output_tensor_grad = _communicate(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=False,
recv_next=recv_next)
if timers is not None:
timers('backward-send-backward-recv').stop()
return output_tensor_grad
def send_forward_backward_recv_forward_backward(
output_tensor, input_tensor_grad, recv_prev,
recv_next, timers=None):
"""Batched send and recv with previous and next ranks in pipeline."""
if timers is not None:
timers('forward-backward-send-forward-backward-recv').start()
input_tensor, output_tensor_grad = _communicate(
tensor_send_next=output_tensor,
tensor_send_prev=input_tensor_grad,
recv_prev=recv_prev,
recv_next=recv_next)
if timers is not None:
timers('forward-backward-send-forward-backward-recv').stop()
return input_tensor, output_tensor_grad
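# --- Illustrative pairing sketch, not from the original commit ---
# These helpers only make sense when the matching call runs on the
# neighbouring pipeline rank (send_forward on stage i pairs with recv_forward
# on stage i+1). A single forward hop, assuming mpu is initialized with
# pipeline parallelism and `run_stage_forward` is a placeholder for the
# stage's compute:
def _forward_hop_sketch(run_stage_forward):
    # The first stage gets None here and reads real data instead.
    input_tensor = recv_forward()
    output_tensor = run_stage_forward(input_tensor)
    # The last stage has no downstream neighbour, so send_forward is a no-op there.
    send_forward(output_tensor)
    return output_tensor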
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
MAJOR = 1
MINOR = 1.5
# Use the following formatting: (major, minor)
VERSION = (MAJOR, MINOR)
__version__ = '.'.join(map(str, VERSION)) + '.bs'
__package_name__ = 'megatron-lm'
__contact_names__ = 'NVIDIA INC'
__url__ = 'https://github.com/NVIDIA/Megatron-LM'
__download_url__ = 'https://github.com/NVIDIA/Megatron-LM/releases'
__description__ = 'Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism.'
__license__ = 'See https://github.com/NVIDIA/Megatron-LM/blob/master/LICENSE'
__keywords__ = 'deep learning, Megatron, gpu, NLP, nvidia, pytorch, torch, language'
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron import get_args
from megatron import get_num_microbatches
from megatron import get_timers
from megatron import mpu
from megatron import p2p_communication
from megatron.utils import unwrap_model
from megatron.model.distributed import DistributedDataParallel as LocalDDP
from megatron.model.module import Float16Module
def get_forward_backward_func():
args = get_args()
if mpu.get_pipeline_model_parallel_world_size() > 1:
if args.virtual_pipeline_model_parallel_size is not None:
forward_backward_func = forward_backward_pipelining_with_interleaving
else:
forward_backward_func = forward_backward_pipelining_without_interleaving
else:
forward_backward_func = forward_backward_no_pipelining
return forward_backward_func
def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
"""Forward step for passed-in model.
If first stage, input tensor is obtained from data_iterator, otherwise
passed-in input_tensor is used.
Returns output tensor."""
timers = get_timers()
args = get_args()
timers('forward-compute').start()
unwrapped_model = unwrap_model(
model, (torchDDP, LocalDDP, Float16Module))
if not args.deepspeed:
unwrapped_model.set_input_tensor(input_tensor)
else:
unwrapped_model.module.set_input_tensor(input_tensor)
output_tensor, loss_func = forward_step_func(data_iterator, model)
if mpu.is_pipeline_last_stage():
if 'accuracy' in loss_func:
loss_reduced = loss_func
output_tensor = output_tensor / get_num_microbatches()
losses_reduced.append(loss_reduced)
else:
output_tensor = loss_func(output_tensor)
loss, loss_reduced = output_tensor
output_tensor = loss / get_num_microbatches()
losses_reduced.append(loss_reduced)
timers('forward-compute').stop()
return output_tensor
def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad, model=None):
"""Backward step through passed-in output tensor.
If last stage, output_tensor_grad is None, otherwise gradient of loss
with respect to stage's output tensor.
Returns gradient of loss with respect to input tensor (None if first
stage)."""
args = get_args()
if args.deepspeed:
assert model is not None
timers = get_timers()
timers('backward-compute').start()
# Retain the grad on the input_tensor.
if input_tensor is not None:
input_tensor.retain_grad()
if args.deepspeed:
model.backward(output_tensor)
else:
# Backward pass.
if output_tensor_grad is None:
output_tensor = optimizer.scale_loss(output_tensor)
torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
# Collect the grad of the input_tensor.
input_tensor_grad = None
if input_tensor is not None:
input_tensor_grad = input_tensor.grad
timers('backward-compute').stop()
return input_tensor_grad
@contextmanager
def dummy_handler():
try:
yield
finally:
pass
def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
optimizer, timers, forward_only):
"""Run forward and backward passes with no pipeline parallelism
(no inter-stage communication).
Returns dictionary with losses."""
assert len(model) == 1
model = model[0]
context_handler = dummy_handler
if isinstance(model, torchDDP):
context_handler = model.no_sync
losses_reduced = []
input_tensor, output_tensor_grad = None, None
with context_handler():
for i in range(get_num_microbatches() - 1):
output_tensor = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced)
if not forward_only:
backward_step(optimizer, input_tensor, output_tensor,
output_tensor_grad, model)
# Run computation for last microbatch out of context handler (want to
# synchronize gradients).
output_tensor = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced)
if not forward_only:
backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad, model)
return losses_reduced
def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator, model,
optimizer, timers, forward_only):
"""Run interleaved 1F1B schedule (model split into model chunks), with
communication between pipeline stages as needed.
Returns dictionary with losses if the last stage, empty dict otherwise."""
input_tensors = [[] for _ in range(len(model))]
output_tensors = [[] for _ in range(len(model))]
losses_reduced = []
if not forward_only:
output_tensor_grads = [[] for _ in range(len(model))]
pipeline_parallel_size = mpu.get_pipeline_model_parallel_world_size()
pipeline_parallel_rank = mpu.get_pipeline_model_parallel_rank()
# Compute number of warmup and remaining microbatches.
num_model_chunks = len(model)
num_microbatches = get_num_microbatches() * num_model_chunks
all_warmup_microbatches = False
if forward_only:
num_warmup_microbatches = num_microbatches
else:
# Run all forward passes and then all backward passes if number of
# microbatches is just the number of pipeline stages.
# Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size warmup
# microbatches on all workers, followed by more microbatches depending on
# stage ID (more forward passes for earlier stages, later stages can
# immediately start with 1F1B).
if get_num_microbatches() == pipeline_parallel_size:
num_warmup_microbatches = num_microbatches
all_warmup_microbatches = True
else:
num_warmup_microbatches = \
(pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
num_warmup_microbatches += (
num_model_chunks - 1) * pipeline_parallel_size
num_warmup_microbatches = min(num_warmup_microbatches,
num_microbatches)
num_microbatches_remaining = \
num_microbatches - num_warmup_microbatches
def get_model_chunk_id(microbatch_id, forward):
"""Helper method to get the model chunk ID given the iteration number."""
microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
if not forward:
model_chunk_id = (num_model_chunks - model_chunk_id - 1)
return model_chunk_id
def forward_step_helper(microbatch_id):
"""Helper method to run forward step with model split into chunks
(run set_virtual_pipeline_model_parallel_rank() before calling
forward_step())."""
model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
if mpu.is_pipeline_first_stage():
if len(input_tensors[model_chunk_id]) == \
len(output_tensors[model_chunk_id]):
input_tensors[model_chunk_id].append(None)
input_tensor = input_tensors[model_chunk_id][-1]
output_tensor = forward_step(forward_step_func,
data_iterator[model_chunk_id],
model[model_chunk_id],
input_tensor, losses_reduced)
output_tensors[model_chunk_id].append(output_tensor)
return output_tensor
def backward_step_helper(microbatch_id):
"""Helper method to run backward step with model split into chunks
(run set_virtual_pipeline_model_parallel_rank() before calling
backward_step())."""
model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
if mpu.is_pipeline_last_stage():
if len(output_tensor_grads[model_chunk_id]) == 0:
output_tensor_grads[model_chunk_id].append(None)
input_tensor = input_tensors[model_chunk_id].pop(0)
output_tensor = output_tensors[model_chunk_id].pop(0)
output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
input_tensor_grad = \
backward_step(optimizer,
input_tensor,
output_tensor,
output_tensor_grad)
return input_tensor_grad
# Run warmup forward passes.
mpu.set_virtual_pipeline_model_parallel_rank(0)
input_tensors[0].append(
p2p_communication.recv_forward(timers))
for k in range(num_warmup_microbatches):
output_tensor = forward_step_helper(k)
# Determine if tensor should be received from previous stage.
next_forward_model_chunk_id = get_model_chunk_id(k+1, forward=True)
recv_prev = True
if mpu.is_pipeline_first_stage(ignore_virtual=True):
if next_forward_model_chunk_id == 0:
recv_prev = False
if k == (num_microbatches - 1):
recv_prev = False
# Don't send tensor downstream if on last stage.
if mpu.is_pipeline_last_stage():
output_tensor = None
# Send and receive tensors as appropriate (send tensors computed
# in this iteration; receive tensors for next iteration).
if k == (num_warmup_microbatches - 1) and not forward_only and \
not all_warmup_microbatches:
input_tensor_grad = None
recv_next = True
if mpu.is_pipeline_last_stage(ignore_virtual=True):
recv_next = False
input_tensor, output_tensor_grad = \
p2p_communication.send_forward_backward_recv_forward_backward(
output_tensor, input_tensor_grad,
recv_prev=recv_prev, recv_next=recv_next,
timers=timers)
output_tensor_grads[num_model_chunks-1].append(output_tensor_grad)
else:
input_tensor = \
p2p_communication.send_forward_recv_forward(
output_tensor, recv_prev, timers)
input_tensors[next_forward_model_chunk_id].append(input_tensor)
# Run 1F1B in steady state.
for k in range(num_microbatches_remaining):
# Forward pass.
forward_k = k + num_warmup_microbatches
output_tensor = forward_step_helper(forward_k)
# Backward pass.
backward_k = k
input_tensor_grad = backward_step_helper(backward_k)
# Send output_tensor and input_tensor_grad, receive input_tensor
# and output_tensor_grad.
# Determine if current stage has anything to send in either direction,
# otherwise set tensor to None.
forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
mpu.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
if mpu.is_pipeline_last_stage():
output_tensor = None
backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
mpu.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
if mpu.is_pipeline_first_stage():
input_tensor_grad = None
# Determine if peers are sending, and where in data structure to put
# received tensors.
recv_prev = True
if mpu.is_pipeline_first_stage(ignore_virtual=True):
# First stage is ahead of last stage by (pipeline_parallel_size - 1).
next_forward_model_chunk_id = get_model_chunk_id(
forward_k - (pipeline_parallel_size - 1), forward=True)
if next_forward_model_chunk_id == (num_model_chunks - 1):
recv_prev = False
next_forward_model_chunk_id += 1
else:
next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1,
forward=True)
recv_next = True
if mpu.is_pipeline_last_stage(ignore_virtual=True):
# Last stage is ahead of first stage by (pipeline_parallel_size - 1).
next_backward_model_chunk_id = get_model_chunk_id(
backward_k - (pipeline_parallel_size - 1), forward=False)
if next_backward_model_chunk_id == 0:
recv_next = False
next_backward_model_chunk_id -= 1
else:
next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1,
forward=False)
# If last iteration, don't receive; we already received one extra
# before the start of the for loop.
if k == (num_microbatches_remaining - 1):
recv_prev = False
# Communicate tensors.
input_tensor, output_tensor_grad = \
p2p_communication.send_forward_backward_recv_forward_backward(
output_tensor, input_tensor_grad,
recv_prev=recv_prev, recv_next=recv_next,
timers=timers)
# Put input_tensor and output_tensor_grad in data structures in the
# right location.
if recv_prev:
input_tensors[next_forward_model_chunk_id].append(input_tensor)
if recv_next:
output_tensor_grads[next_backward_model_chunk_id].append(
output_tensor_grad)
# Run cooldown backward passes (flush out pipeline).
if not forward_only:
if all_warmup_microbatches:
output_tensor_grads[num_model_chunks-1].append(
p2p_communication.recv_backward(timers))
for k in range(num_microbatches_remaining, num_microbatches):
input_tensor_grad = backward_step_helper(k)
next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False)
recv_next = True
if mpu.is_pipeline_last_stage(ignore_virtual=True):
if next_backward_model_chunk_id == (num_model_chunks - 1):
recv_next = False
if k == (num_microbatches - 1):
recv_next = False
output_tensor_grads[next_backward_model_chunk_id].append(
p2p_communication.send_backward_recv_backward(
input_tensor_grad, recv_next, timers))
return losses_reduced
def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator,
model, optimizer, timers,
forward_only):
"""Run non-interleaved 1F1B schedule, with communication between pipeline
stages.
Returns dictionary with losses if the last stage, empty dict otherwise."""
timers = get_timers()
assert len(model) == 1
model = model[0]
# Compute number of warmup microbatches.
num_microbatches = get_num_microbatches()
num_warmup_microbatches = \
(mpu.get_pipeline_model_parallel_world_size() -
mpu.get_pipeline_model_parallel_rank() - 1)
num_warmup_microbatches = min(
num_warmup_microbatches,
num_microbatches)
num_microbatches_remaining = \
num_microbatches - num_warmup_microbatches
input_tensors = []
output_tensors = []
losses_reduced = []
# Run warmup forward passes.
for i in range(num_warmup_microbatches):
input_tensor = p2p_communication.recv_forward(timers)
output_tensor = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced)
p2p_communication.send_forward(output_tensor, timers)
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
# Before running 1F1B, need to receive first forward tensor.
# If all microbatches are run in warmup / cooldown phase, then no need to
# receive this tensor here.
if num_microbatches_remaining > 0:
input_tensor = p2p_communication.recv_forward(timers)
# Run 1F1B in steady state.
for i in range(num_microbatches_remaining):
last_iteration = (i == (num_microbatches_remaining - 1))
output_tensor = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced)
if forward_only:
p2p_communication.send_forward(output_tensor, timers)
else:
output_tensor_grad = \
p2p_communication.send_forward_recv_backward(output_tensor,
timers)
# Add input_tensor and output_tensor to end of list, then pop from the
# start of the list for backward pass.
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
if forward_only:
if not last_iteration:
input_tensor = p2p_communication.recv_forward(timers)
else:
input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
input_tensor_grad = \
backward_step(optimizer, input_tensor, output_tensor,
output_tensor_grad, model)
if last_iteration:
input_tensor = None
p2p_communication.send_backward(input_tensor_grad, timers)
else:
input_tensor = \
p2p_communication.send_backward_recv_forward(
input_tensor_grad, timers)
# Run cooldown backward passes.
if not forward_only:
for i in range(num_warmup_microbatches):
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
output_tensor_grad = p2p_communication.recv_backward(timers)
input_tensor_grad = \
backward_step(optimizer, input_tensor, output_tensor,
output_tensor_grad, model)
p2p_communication.send_backward(input_tensor_grad, timers)
return losses_reduced
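# --- Illustrative driver sketch, not from the original commit ---
# How a training step is expected to pick and run one of the schedules above.
# `forward_step_func`, `train_data_iterator`, `model` (a list of model chunks)
# and `optimizer` are assumed to come from the surrounding Megatron setup; the
# function name is a placeholder.
def _run_schedule_sketch(forward_step_func, train_data_iterator, model, optimizer):
    forward_backward_func = get_forward_backward_func()
    # On the last pipeline stage this returns the per-microbatch reduced
    # losses; on other stages the list is empty.
    losses_reduced = forward_backward_func(
        forward_step_func, train_data_iterator, model,
        optimizer, get_timers(), forward_only=False)
    return losses_reduced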
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import importlib.util
import inspect
import logging
import numpy as np
import os
import random
import re
import shutil
import sys
import tempfile
import unittest
from distutils.util import strtobool
from io import StringIO
from packaging import version
from pathlib import Path
from typing import Iterator, Union
from unittest import mock
from unittest.case import SkipTest
try:
import torch
_torch_available = True
except ImportError:
_torch_available = False
def is_torch_available():
return _torch_available
def parse_flag_from_env(key, default=False):
try:
value = os.environ[key]
except KeyError:
# KEY isn't set, default to `default`.
_value = default
else:
# KEY is set, convert it to True or False.
try:
_value = strtobool(value)
except ValueError:
# More values are supported, but let's keep the message simple.
raise ValueError(f"If set, {key} must be yes or no.")
return _value
def parse_int_from_env(key, default=None):
try:
value = os.environ[key]
except KeyError:
_value = default
else:
try:
_value = int(value)
except ValueError:
raise ValueError(f"If set, {key} must be a int.")
return _value
def require_torch(test_case):
"""
Decorator marking a test that requires PyTorch.
These tests are skipped when PyTorch isn't installed.
"""
if not is_torch_available():
return unittest.skip("test requires PyTorch")(test_case)
else:
return test_case
def require_torch_multi_gpu(test_case):
"""
Decorator marking a test that requires a multi-GPU setup (in PyTorch). These tests are skipped on a machine without
multiple GPUs.
To run *only* the multi_gpu tests, assuming all test names contain multi_gpu: $ pytest -sv ./tests -k "multi_gpu"
"""
if not is_torch_available():
return unittest.skip("test requires PyTorch")(test_case)
import torch
if torch.cuda.device_count() < 2:
return unittest.skip("test requires multiple GPUs")(test_case)
else:
return test_case
def require_torch_non_multi_gpu(test_case):
"""
Decorator marking a test that requires 0 or 1 GPU setup (in PyTorch).
"""
if not is_torch_available():
return unittest.skip("test requires PyTorch")(test_case)
import torch
if torch.cuda.device_count() > 1:
return unittest.skip("test requires 0 or 1 GPU")(test_case)
else:
return test_case
def require_torch_up_to_2_gpus(test_case):
"""
Decorator marking a test that requires 0 or 1 or 2 GPU setup (in PyTorch).
"""
if not is_torch_available():
return unittest.skip("test requires PyTorch")(test_case)
import torch
if torch.cuda.device_count() > 2:
return unittest.skip("test requires 0 or 1 or 2 GPUs")(test_case)
else:
return test_case
def require_torch_tpu(test_case):
"""
Decorator marking a test that requires a TPU (in PyTorch).
"""
if not is_torch_tpu_available():
return unittest.skip("test requires PyTorch TPU")
else:
return test_case
if is_torch_available():
# Set env var CUDA_VISIBLE_DEVICES="" to force cpu-mode
import torch
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
else:
torch_device = None
def require_torch_gpu(test_case):
"""Decorator marking a test that requires CUDA and PyTorch."""
if torch_device != "cuda":
return unittest.skip("test requires CUDA")(test_case)
else:
return test_case
def require_datasets(test_case):
"""Decorator marking a test that requires datasets."""
if not is_datasets_available():
return unittest.skip("test requires `datasets`")(test_case)
else:
return test_case
def is_deepspeed_available():
return importlib.util.find_spec("deepspeed") is not None
def require_deepspeed(test_case):
"""
Decorator marking a test that requires deepspeed
"""
if not is_deepspeed_available():
return unittest.skip("test requires deepspeed")(test_case)
else:
return test_case
def is_bnb_available():
return importlib.util.find_spec("bitsandbytes") is not None
def require_bnb(test_case):
"""
Decorator marking a test that requires bitsandbytes
"""
if not is_bnb_available():
return unittest.skip("test requires bitsandbytes from https://github.com/facebookresearch/bitsandbytes")(test_case)
else:
return test_case
def require_bnb_non_decorator():
"""
Non-Decorator function that would skip a test if bitsandbytes is missing
"""
if not is_bnb_available():
raise SkipTest("Test requires bitsandbytes from https://github.com/facebookresearch/bitsandbytes")
def set_seed(seed: int=42):
"""
Helper function for reproducible behavior to set the seed in ``random``, ``numpy``, ``torch``
Args:
seed (:obj:`int`): The seed to set.
"""
random.seed(seed)
np.random.seed(seed)
if is_torch_available():
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# ^^ safe to call this function even if cuda is not available
def get_gpu_count():
"""
Return the number of available gpus (regardless of whether torch or tf is used)
"""
if is_torch_available():
import torch
return torch.cuda.device_count()
elif is_tf_available():
import tensorflow as tf
return len(tf.config.list_physical_devices("GPU"))
else:
return 0
def torch_assert_equal(actual, expected, **kwargs):
# assert_close was added around pt-1.9, it does better checks - e.g. it will check that dimensions match
if hasattr(torch.testing, "assert_close"):
return torch.testing.assert_close(actual, expected, rtol=0.0, atol=0.0, **kwargs)
else:
assert torch.allclose(actual, expected, rtol=0.0, atol=0.0)
def torch_assert_close(actual, expected, **kwargs):
# assert_close was added around pt-1.9, it does better checks - e.g. it will check that dimensions match
if hasattr(torch.testing, "assert_close"):
return torch.testing.assert_close(actual, expected, **kwargs)
else:
kwargs.pop("msg", None) # doesn't have msg arg
return torch.allclose(actual, expected, **kwargs)
def is_torch_bf16_available():
# from https://github.com/huggingface/transformers/blob/26eb566e43148c80d0ea098c76c3d128c0281c16/src/transformers/file_utils.py#L301
if is_torch_available():
import torch
if not torch.cuda.is_available() or torch.version.cuda is None:
return False
if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
return False
if int(torch.version.cuda.split(".")[0]) < 11:
return False
if not version.parse(torch.__version__) >= version.parse("1.09"):
return False
return True
else:
return False
def require_torch_bf16(test_case):
"""Decorator marking a test that requires CUDA hardware supporting bf16 and PyTorch >= 1.9."""
if not is_torch_bf16_available():
return unittest.skip("test requires CUDA hardware supporting bf16 and PyTorch >= 1.9")(test_case)
else:
return test_case
def get_tests_dir(append_path=None):
"""
Args:
append_path: optional path to append to the tests dir path
Return:
The full path to the `tests` dir, so that the tests can be invoked from anywhere. Optionally `append_path` is
joined after the `tests` dir if the former is provided.
"""
# this function caller's __file__
caller__file__ = inspect.stack()[1][1]
tests_dir = os.path.abspath(os.path.dirname(caller__file__))
if append_path:
return os.path.join(tests_dir, append_path)
else:
return tests_dir
#
# Helper functions for dealing with testing text outputs
# The original code came from:
# https://github.com/fastai/fastai/blob/master/tests/utils/text.py
# When any function contains print() calls that get overwritten, like progress bars,
# special care needs to be applied, since under pytest -s captured output (capsys
# or contextlib.redirect_stdout) contains any temporary printed strings, followed by
# \r's. This helper function ensures that the buffer will contain the same output
# with and without -s in pytest, by turning:
# foo bar\r tar mar\r final message
# into:
# final message
# it can handle a single string or a multiline buffer
def apply_print_resets(buf):
return re.sub(r"^.*\r", "", buf, 0, re.M)
def assert_screenout(out, what):
out_pr = apply_print_resets(out).lower()
match_str = out_pr.find(what.lower())
assert match_str != -1, f"expecting to find {what} in output: {out_pr}"
class CaptureStd:
"""
Context manager to capture:
- stdout: replay it, clean it up and make it available via ``obj.out``
- stderr: replay it and make it available via ``obj.err``
init arguments:
- out - capture stdout: ``True``/``False``, default ``True``
- err - capture stderr: ``True``/``False``, default ``True``
- replay - whether to replay or not: ``True``/``False``, default ``True``. By default each
captured stream gets replayed back on context's exit, so that one can see what the test was
doing. If this behavior is not wanted and the captured data shouldn't be replayed, pass
``replay=False`` to disable this feature.
Examples::
# to capture stdout only with auto-replay
with CaptureStdout() as cs:
print("Secret message")
assert "message" in cs.out
# to capture stderr only with auto-replay
import sys
with CaptureStderr() as cs:
print("Warning: ", file=sys.stderr)
assert "Warning" in cs.err
# to capture both streams with auto-replay
with CaptureStd() as cs:
print("Secret message")
print("Warning: ", file=sys.stderr)
assert "message" in cs.out
assert "Warning" in cs.err
# to capture just one of the streams, and not the other, with auto-replay
with CaptureStd(err=False) as cs:
print("Secret message")
assert "message" in cs.out
# but best use the stream-specific subclasses
# to capture without auto-replay
with CaptureStd(replay=False) as cs:
print("Secret message")
assert "message" in cs.out
"""
def __init__(self, out=True, err=True, replay=True):
self.replay = replay
if out:
self.out_buf = StringIO()
self.out = "error: CaptureStd context is unfinished yet, called too early"
else:
self.out_buf = None
self.out = "not capturing stdout"
if err:
self.err_buf = StringIO()
self.err = "error: CaptureStd context is unfinished yet, called too early"
else:
self.err_buf = None
self.err = "not capturing stderr"
def __enter__(self):
if self.out_buf:
self.out_old = sys.stdout
sys.stdout = self.out_buf
if self.err_buf:
self.err_old = sys.stderr
sys.stderr = self.err_buf
return self
def __exit__(self, *exc):
if self.out_buf:
sys.stdout = self.out_old
captured = self.out_buf.getvalue()
if self.replay:
sys.stdout.write(captured)
self.out = apply_print_resets(captured)
if self.err_buf:
sys.stderr = self.err_old
captured = self.err_buf.getvalue()
if self.replay:
sys.stderr.write(captured)
self.err = captured
def __repr__(self):
msg = ""
if self.out_buf:
msg += f"stdout: {self.out}\n"
if self.err_buf:
msg += f"stderr: {self.err}\n"
return msg
# in tests it's the best to capture only the stream that's wanted, otherwise
# it's easy to miss things, so unless you need to capture both streams, use the
# subclasses below (less typing). Or alternatively, configure `CaptureStd` to
# disable the stream you don't need to test.
class CaptureStdout(CaptureStd):
"""Same as CaptureStd but captures only stdout"""
def __init__(self, replay=True):
super().__init__(err=False, replay=replay)
class CaptureStderr(CaptureStd):
"""Same as CaptureStd but captures only stderr"""
def __init__(self, replay=True):
super().__init__(out=False, replay=replay)
class CaptureLogger:
"""
Context manager to capture `logging` streams
Args:
- logger: `logging` logger object
Results:
The captured output is available via `self.out`
Example::
>>> from transformers import logging
>>> from transformers.testing_utils import CaptureLogger
>>> msg = "Testing 1, 2, 3"
>>> logging.set_verbosity_info()
>>> logger = logging.get_logger("transformers.models.bart.tokenization_bart")
>>> with CaptureLogger(logger) as cl:
... logger.info(msg)
>>> assert cl.out, msg+"\n"
"""
def __init__(self, logger):
self.logger = logger
self.io = StringIO()
self.sh = logging.StreamHandler(self.io)
self.out = ""
def __enter__(self):
self.logger.addHandler(self.sh)
return self
def __exit__(self, *exc):
self.logger.removeHandler(self.sh)
self.out = self.io.getvalue()
def __repr__(self):
return f"captured: {self.out}\n"
@contextlib.contextmanager
# adapted from https://stackoverflow.com/a/64789046/9201239
def ExtendSysPath(path: Union[str, os.PathLike]) -> Iterator[None]:
"""
Temporarily add given path to `sys.path`.
Usage ::
with ExtendSysPath('/path/to/dir'):
mymodule = importlib.import_module('mymodule')
"""
path = os.fspath(path)
try:
sys.path.insert(0, path)
yield
finally:
sys.path.remove(path)
class TestCasePlus(unittest.TestCase):
"""
This class extends `unittest.TestCase` with additional features.
Feature 1: A set of fully resolved important file and dir path accessors.
In tests we often need to know where things are relative to the current test file, and it's not trivial since the
test could be invoked from more than one directory or could reside in sub-directories with different depths. This
class solves this problem by sorting out all the basic paths and provides easy accessors to them:
* ``pathlib`` objects (all fully resolved):
- ``test_file_path`` - the current test file path (=``__file__``)
- ``test_file_dir`` - the directory containing the current test file
- ``tests_dir`` - the directory of the ``tests`` test suite
- ``data_dir`` - the directory of the ``tests/data`` test suite
- ``repo_root_dir`` - the directory of the repository
- ``src_dir`` - the directory of ``src`` (i.e. where the ``transformers`` sub-dir resides)
* stringified paths---same as above but these return paths as strings, rather than ``pathlib`` objects:
- ``test_file_path_str``
- ``test_file_dir_str``
- ``tests_dir_str``
- ``data_dir_str``
- ``repo_root_dir_str``
- ``src_dir_str``
Feature 2: Flexible auto-removable temporary dirs which are guaranteed to get removed at the end of test.
1. Create a unique temporary dir:
::
def test_whatever(self):
tmp_dir = self.get_auto_remove_tmp_dir()
``tmp_dir`` will contain the path to the created temporary dir. It will be automatically removed at the end of the
test.
2. Create a temporary dir of your choice, ensure it's empty before the test starts, and don't
empty it after the test.
::
def test_whatever(self):
tmp_dir = self.get_auto_remove_tmp_dir("./xxx")
This is useful for debugging when you want to monitor a specific directory and make sure the previous tests
didn't leave any data in there.
3. You can override the first two options by directly overriding the ``before`` and ``after`` args, leading to the
following behavior:
``before=True``: the temporary dir will always be cleared at the beginning of the test.
``before=False``: if the temporary dir already existed, any existing files will remain there.
``after=True``: the temporary dir will always be deleted at the end of the test.
``after=False``: the temporary dir will always be left intact at the end of the test.
Note 1: In order to run the equivalent of ``rm -r`` safely, only subdirs of the project repository checkout are
allowed if an explicit ``tmp_dir`` is used, so that no ``/tmp`` or similar important part of the filesystem gets
nuked by mistake. I.e. please always pass paths that start with ``./``
Note 2: Each test can register multiple temporary dirs and they all will get auto-removed, unless requested
otherwise.
Feature 3: Get a copy of the ``os.environ`` object that sets up ``PYTHONPATH`` specific to the current test suite.
This is useful for invoking external programs from the test suite - e.g. distributed training.
::
def test_whatever(self):
env = self.get_env()
"""
def setUp(self):
# get_auto_remove_tmp_dir feature:
self.teardown_tmp_dirs = []
# figure out the resolved paths for repo_root, tests, etc.
self._test_file_path = inspect.getfile(self.__class__)
path = Path(self._test_file_path).resolve()
self._test_file_dir = path.parents[0]
tmp_dir = None
for up in [1, 2, 3]:
    candidate = path.parents[up]
    if (candidate / "megatron").is_dir() and (candidate / "tests").is_dir():
        tmp_dir = candidate
        break
if tmp_dir is not None:
self._repo_root_dir = tmp_dir
else:
raise ValueError(f"can't figure out the root of the repo from {self._test_file_path}")
self._tests_dir = self._repo_root_dir / "tests"
self._data_dir = self._repo_root_dir / "tests" / "data"
self._src_dir = self._repo_root_dir # megatron doesn't use "src/" prefix in the repo
@property
def test_file_path(self):
return self._test_file_path
@property
def test_file_path_str(self):
return str(self._test_file_path)
@property
def test_file_dir(self):
return self._test_file_dir
@property
def test_file_dir_str(self):
return str(self._test_file_dir)
@property
def tests_dir(self):
return self._tests_dir
@property
def tests_dir_str(self):
return str(self._tests_dir)
@property
def data_dir(self):
return self._data_dir
@property
def data_dir_str(self):
return str(self._data_dir)
@property
def repo_root_dir(self):
return self._repo_root_dir
@property
def repo_root_dir_str(self):
return str(self._repo_root_dir)
@property
def src_dir(self):
return self._src_dir
@property
def src_dir_str(self):
return str(self._src_dir)
def get_env(self):
"""
Return a copy of the ``os.environ`` object that sets up ``PYTHONPATH`` correctly. This is useful
for invoking external programs from the test suite - e.g. distributed training.
It always inserts the repo root first, then ``tests``, and
finally any preset ``PYTHONPATH`` (all fully resolved paths).
"""
env = os.environ.copy()
paths = [self.src_dir_str]
paths.append(self.tests_dir_str)
paths.append(env.get("PYTHONPATH", ""))
env["PYTHONPATH"] = ":".join(paths)
return env
def get_auto_remove_tmp_dir(self, tmp_dir=None, before=None, after=None):
"""
Args:
tmp_dir (:obj:`string`, `optional`):
if :obj:`None`:
- a unique temporary path will be created
- sets ``before=True`` if ``before`` is :obj:`None`
- sets ``after=True`` if ``after`` is :obj:`None`
else:
- :obj:`tmp_dir` will be created
- sets ``before=True`` if ``before`` is :obj:`None`
- sets ``after=False`` if ``after`` is :obj:`None`
before (:obj:`bool`, `optional`):
If :obj:`True` and the :obj:`tmp_dir` already exists, make sure to empty it right away. If :obj:`False`
and the :obj:`tmp_dir` already exists, any existing files will remain there.
after (:obj:`bool`, `optional`):
If :obj:`True`, delete the :obj:`tmp_dir` at the end of the test. If :obj:`False`, leave the
:obj:`tmp_dir` and its contents intact at the end of the test.
Returns:
tmp_dir(:obj:`string`): either the same value as passed via `tmp_dir` or the path to the auto-selected tmp
dir
"""
if tmp_dir is not None:
# defining the most likely desired behavior for when a custom path is provided.
# this most likely indicates the debug mode where we want an easily locatable dir that:
# 1. gets cleared out before the test (if it already exists)
# 2. is left intact after the test
if before is None:
before = True
if after is None:
after = False
# using provided path
path = Path(tmp_dir).resolve()
# to avoid nuking parts of the filesystem, only relative paths are allowed
if not tmp_dir.startswith("./"):
raise ValueError(
f"`tmp_dir` can only be a relative path, i.e. `./some/path`, but received `{tmp_dir}`"
)
# ensure the dir is empty to start with
if before is True and path.exists():
shutil.rmtree(tmp_dir, ignore_errors=True)
path.mkdir(parents=True, exist_ok=True)
else:
# defining the most likely desired behavior for when a unique tmp path is auto generated
# (not a debug mode), here we require a unique tmp dir that:
# 1. is empty before the test (it will be empty in this situation anyway)
# 2. gets fully removed after the test
if before is None:
before = True
if after is None:
after = True
# using unique tmp dir (always empty, regardless of `before`)
tmp_dir = tempfile.mkdtemp()
if after is True:
# register for deletion
self.teardown_tmp_dirs.append(tmp_dir)
return tmp_dir
def tearDown(self):
# get_auto_remove_tmp_dir feature: remove registered temp dirs
for path in self.teardown_tmp_dirs:
shutil.rmtree(path, ignore_errors=True)
self.teardown_tmp_dirs = []
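# A minimal usage sketch (hypothetical test, not part of this module): a TestCasePlus
# subclass that writes into an auto-removed temporary dir and builds an env for a
# subprocess. The class and file names are illustrative only.
#
#     class MyTest(TestCasePlus):
#         def test_tmp_dir(self):
#             tmp_dir = self.get_auto_remove_tmp_dir()   # removed in tearDown
#             (Path(tmp_dir) / "out.txt").write_text("data")
#             env = self.get_env()                       # PYTHONPATH includes the repo root and tests/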
def mockenv(**kwargs):
"""
This is a convenience wrapper that allows this ::
@mockenv(RUN_SLOW=True, USE_TF=False)
def test_something():
run_slow = os.getenv("RUN_SLOW", False)
use_tf = os.getenv("USE_TF", False)
"""
return mock.patch.dict(os.environ, kwargs)
# from https://stackoverflow.com/a/34333710/9201239
@contextlib.contextmanager
def mockenv_context(*remove, **update):
"""
Temporarily updates the ``os.environ`` dictionary in-place. Similar to ``mockenv``.
The ``os.environ`` dictionary is updated in-place so that the modification is sure to work in all situations.
Args:
remove: Environment variables to remove.
update: Dictionary of environment variables and values to add/update.
"""
env = os.environ
update = update or {}
remove = remove or []
# List of environment variables being updated or removed.
stomped = (set(update.keys()) | set(remove)) & set(env.keys())
# Environment variables and values to restore on exit.
update_after = {k: env[k] for k in stomped}
# Environment variables and values to remove on exit.
remove_after = frozenset(k for k in update if k not in env)
try:
env.update(update)
[env.pop(k, None) for k in remove]
yield
finally:
env.update(update_after)
[env.pop(k) for k in remove_after]
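# A minimal usage sketch (illustrative only): temporarily set one variable and remove
# another for the duration of the block; the original environment is restored on exit.
#
#     with mockenv_context("WORLD_SIZE", MASTER_PORT="9999"):
#         assert os.environ["MASTER_PORT"] == "9999"
#         assert "WORLD_SIZE" not in os.environ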
# --- distributed testing functions --- #
# adapted from https://stackoverflow.com/a/59041913/9201239
import asyncio # noqa
class _RunOutput:
def __init__(self, returncode, stdout, stderr):
self.returncode = returncode
self.stdout = stdout
self.stderr = stderr
async def _read_stream(stream, callback):
while True:
line = await stream.readline()
if line:
callback(line)
else:
break
async def _stream_subprocess(cmd, env=None, stdin=None, timeout=None, quiet=False, echo=False) -> _RunOutput:
if echo:
print("\nRunning: ", " ".join(cmd))
p = await asyncio.create_subprocess_exec(
cmd[0],
*cmd[1:],
stdin=stdin,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
env=env,
)
# note: there is a warning for a possible deadlock when using `wait` with huge amounts of data in the pipe
# https://docs.python.org/3/library/asyncio-subprocess.html#asyncio.asyncio.subprocess.Process.wait
#
# If it starts hanging, we will need to switch to the following code. The problem is that no data
# will be seen until it's done, and if it hangs there will be no debug info.
# out, err = await p.communicate()
# return _RunOutput(p.returncode, out, err)
out = []
err = []
def tee(line, sink, pipe, label=""):
line = line.decode("utf-8").rstrip()
sink.append(line)
if not quiet:
print(label, line, file=pipe)
# XXX: the timeout doesn't seem to make any difference here
await asyncio.wait(
[
_read_stream(p.stdout, lambda l: tee(l, out, sys.stdout, label="stdout:")),
_read_stream(p.stderr, lambda l: tee(l, err, sys.stderr, label="stderr:")),
],
timeout=timeout,
)
return _RunOutput(await p.wait(), out, err)
def execute_subprocess_async(cmd, env=None, stdin=None, timeout=180, quiet=False, echo=True) -> _RunOutput:
loop = asyncio.get_event_loop()
result = loop.run_until_complete(
_stream_subprocess(cmd, env=env, stdin=stdin, timeout=timeout, quiet=quiet, echo=echo)
)
cmd_str = " ".join(cmd)
if result.returncode > 0:
stderr = "\n".join(result.stderr)
raise RuntimeError(
f"'{cmd_str}' failed with returncode {result.returncode}\n\n"
f"The combined stderr from workers follows:\n{stderr}"
)
# check that the subprocess actually did run and produced some output, should the test rely on
# the remote side to do the testing
if not result.stdout and not result.stderr:
raise RuntimeError(f"'{cmd_str}' produced no output.")
return result
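# A minimal usage sketch (illustrative, from inside a TestCasePlus test and assuming a
# hypothetical `script.py`): run an external command with the test-suite PYTHONPATH and
# inspect the captured output.
#
#     cmd = [sys.executable, "script.py", "--arg", "value"]
#     result = execute_subprocess_async(cmd, env=self.get_env())
#     assert result.returncode == 0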
# --- Misc utils --- #
def flatten_arguments(args):
"""
Converts a dictionary of arguments to a flat list.
Note: we add "IGNORED" at the beginning as this value is ignored by the argparser
Example: {"arg1": "value1", "arg2": "value2"} -> ["IGNORED", "arg1", "value1", "arg2", "value2"]
"""
return ["IGNORED"] + [item for key_value in args.items() for item in key_value if item != ""]
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for generating text."""
import copy
import json
import os
import time
import torch
import torch.nn.functional as F
from megatron import get_args
from megatron import get_tokenizer
from megatron import mpu
from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model
from megatron.p2p_communication import recv_forward, send_forward
# These are needed to unwrap the model; it would be nice to put them in megatron.utils if possible.
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
def get_batch(context_tokens):
"""Generate batch from context tokens."""
args = get_args()
tokenizer = get_tokenizer()
# Move to GPU.
tokens = context_tokens.view(args.micro_batch_size, -1).contiguous().cuda()
# Get the attention mask and position ids.
attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
tokens,
tokenizer.eod,
args.reset_position_ids,
args.reset_attention_mask,
args.eod_mask_loss,
prefix_indices=None,
loss_on_targets_only=args.loss_on_targets_only
)
return tokens, attention_mask, position_ids
def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
""" This function has been mostly taken from huggingface conversational
ai code at
https://medium.com/huggingface/how-to-build-a-state-of-the-art-
conversational-ai-with-transfer-learning-2d818ac26313 """
if top_k > 0:
# Remove all tokens with a probability less than the
# last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p > 0.0:
# Convert to 1D
sorted_logits, sorted_indices = torch.sort(
logits, descending=True, dim=-1)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1),
dim=-1)
# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the indices to the right to keep also the first token
# above the threshold
sorted_indices_to_remove[..., 1:] \
= sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
for i in range(sorted_indices.size(0)):
indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
logits[i][indices_to_remove] = filter_value
return logits
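# A small sketch of what the filtering does (illustrative only): with top_k=2 every
# logit outside the two largest values is set to -inf, so sampling can only pick the
# top-2 tokens. Note the function modifies `logits` in place, hence the clone().
#
#     logits = torch.tensor([[1.0, 3.0, 2.0, 0.5]])
#     filtered = top_k_logits(logits.clone(), top_k=2)
#     # filtered -> [[-inf, 3.0, 2.0, -inf]]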
def generate_samples_input_from_file(model):
args = get_args()
tokenizer = get_tokenizer()
# Read the sample file and open the output file.
assert args.sample_input_file is not None, \
'sample input file is not provided.'
if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
fname = open(args.sample_input_file, "r")
all_raw_text = fname.readlines()
input_count = len(all_raw_text)
input_pos = 0
if args.sample_output_file is None:
sample_output_file = args.sample_input_file + ".out"
print('`sample-output-file` not specified, setting '
'it to {}'.format(sample_output_file))
else:
sample_output_file = args.sample_output_file
fname_out = open(sample_output_file, "w+")
context_count = 0
model.eval()
with torch.no_grad():
while True:
terminate_runs = 0
raw_text_len = 0
if mpu.is_pipeline_first_stage() \
and mpu.get_tensor_model_parallel_rank() == 0:
raw_text = all_raw_text[input_pos]
input_pos += 1
if input_pos == input_count:
raw_text = "stop"
raw_text_len = len(raw_text)
if "stop" in raw_text:
terminate_runs = 1
else:
context_tokens = tokenizer.tokenize(raw_text)
context_length = len(context_tokens)
if context_length >= (args.seq_length // 2):
print("\nContext length", context_length,
"\nPlease give smaller context (half of the "
"sequence length)!", flush=True)
continue
else:
context_tokens = tokenizer.tokenize("EMPTY TEXT")
context_length = 0
input_info = [terminate_runs, raw_text_len, context_length]
input_info_tensor = torch.cuda.LongTensor(input_info)
torch.distributed.all_reduce(input_info_tensor,
group=mpu.get_model_parallel_group())
terminate_runs = input_info_tensor[0].item()
raw_text_len = input_info_tensor[1].item()
context_length = input_info_tensor[2].item()
if terminate_runs == 1:
return
# For pipeline parallel we send context tokens to other stages
# so they get the lengths correct
if mpu.get_tensor_model_parallel_rank() == 0 \
and args.pipeline_model_parallel_size > 1:
if mpu.is_pipeline_first_stage():
src = mpu.get_pipeline_model_parallel_first_rank()
group = mpu.get_pipeline_model_parallel_group()
context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
torch.distributed.broadcast(context_tokens_tensor, src, group)
else:
src = mpu.get_pipeline_model_parallel_first_rank()
group = mpu.get_pipeline_model_parallel_group()
context_tokens_tensor = torch.empty(context_length,
dtype=torch.int64,
device=torch.device("cuda"))
torch.distributed.broadcast(context_tokens_tensor, src, group)
context_tokens = context_tokens_tensor.cpu().numpy().tolist()
token_stream = get_token_stream(model, [context_tokens])
for _, decode_tokens in enumerate(token_stream):
pass
if mpu.get_tensor_model_parallel_rank() == 0:
if mpu.is_pipeline_first_stage():
os.system('clear')
print("\nContext:", raw_text, flush=True)
fname_out.write("\nContext:")
fname_out.write(raw_text)
decode_tokens, _ = decode_tokens
decode_tokens = decode_tokens[0].cpu().numpy().tolist()
trim_decode_tokens = tokenizer.detokenize(
decode_tokens)[raw_text_len:]
print("\nMegatron-LM:", trim_decode_tokens, flush=True)
fname_out.write("\n\nMegatron-LM:")
fname_out.write(trim_decode_tokens)
fname_out.write("\n")
raw_text = None
context_count += 1
# We added this function to support task evaluations such as SQuAD
# and DROP in the https://github.com/EleutherAI/lm-evaluation-harness
# codebase. The lm-evaluation-harness code can now call this function
# similarly to its current generate function call used for gpt style models.
def generate_samples_eval(model, context, max_gen_length, eos_token_id):
# Generate samples for lm evaluation
# NEED TO THINK ABOUT eos token
args = get_args()
tokenizer = get_tokenizer()
raw_text_len = len(context)
model.eval()
context_tokens = tokenizer.tokenize(context)
args.out_seq_length = max_gen_length + len(context_tokens)
args.eos_id = eos_token_id
with torch.no_grad():
token_stream = get_token_stream(model, [context_tokens])
for counter, decode_tokens in enumerate(token_stream):
if counter == args.out_seq_length:
break
decode_tokens, _ = decode_tokens
decode_tokens = decode_tokens[0].cpu().numpy().tolist()
trim_decode_tokens = tokenizer.detokenize(
decode_tokens)[raw_text_len:]
return trim_decode_tokens
def generate_samples_interactive(model, print_frequency=24):
args = get_args()
tokenizer = get_tokenizer()
context_count = 0
model.eval()
with torch.no_grad():
while True:
terminate_runs = 0
raw_text_len = 0
if mpu.is_pipeline_first_stage() \
and mpu.get_tensor_model_parallel_rank() == 0:
os.system('clear')
raw_text = input("\nContext prompt (stop to exit) >>> ")
while not raw_text:
print('Prompt should not be empty!')
raw_text = input("\nContext prompt (stop to exit) >>> ")
raw_text_len = len(raw_text)
if "stop" in raw_text:
terminate_runs = 1
else:
context_tokens = tokenizer.tokenize(raw_text)
context_length = len(context_tokens)
if context_length >= (args.seq_length // 2):
print("\nContext length", context_length,
"\nPlease give smaller context (half of the "
"sequence length)!", flush=True)
continue
else:
context_tokens = tokenizer.tokenize("EMPTY TEXT")
context_length = 0
input_info = [terminate_runs, raw_text_len, context_length]
input_info_tensor = torch.cuda.LongTensor(input_info)
torch.distributed.all_reduce(input_info_tensor,
group=mpu.get_model_parallel_group())
terminate_runs = input_info_tensor[0].item()
raw_text_len = input_info_tensor[1].item()
context_length = input_info_tensor[2].item()
if terminate_runs == 1:
return
# For pipeline parallel we send context tokens to other stages
# so they get the lengths correct
if mpu.get_tensor_model_parallel_rank() == 0 \
and args.pipeline_model_parallel_size > 1:
if mpu.is_pipeline_first_stage():
src = mpu.get_pipeline_model_parallel_first_rank()
group = mpu.get_pipeline_model_parallel_group()
context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
torch.distributed.broadcast(context_tokens_tensor, src, group)
else:
src = mpu.get_pipeline_model_parallel_first_rank()
group = mpu.get_pipeline_model_parallel_group()
context_tokens_tensor = torch.empty(context_length,
dtype=torch.int64,
device=torch.device("cuda"))
torch.distributed.broadcast(context_tokens_tensor, src, group)
context_tokens = context_tokens_tensor.cpu().numpy().tolist()
token_stream = get_token_stream(model, [context_tokens])
for counter, decode_tokens in enumerate(token_stream):
if counter % print_frequency != 0 \
or mpu.get_tensor_model_parallel_rank() != 0 \
or not mpu.is_pipeline_first_stage():
continue
os.system('clear')
print("\nContext:", raw_text, flush=True)
decode_tokens, _ = decode_tokens
decode_tokens = decode_tokens[0].cpu().numpy().tolist()
trim_decode_tokens = tokenizer.detokenize(
decode_tokens)[raw_text_len:]
print("\nMegatron-LM:", trim_decode_tokens, flush=True)
if mpu.is_pipeline_first_stage() \
and mpu.get_tensor_model_parallel_rank() == 0:
os.system('clear')
print("\nContext:", raw_text, flush=True)
if not isinstance(decode_tokens, list):
decode_tokens, _ = decode_tokens
decode_tokens = decode_tokens[0].cpu().numpy().tolist()
trim_decode_tokens = tokenizer.detokenize(
decode_tokens)[raw_text_len:]
print("\nMegatron-LM:", trim_decode_tokens, flush=True)
input("\nPress Enter to continue >>>")
raw_text = None
context_count += 1
def generate_samples_unconditional(model):
args = get_args()
tokenizer = get_tokenizer()
num_samples = args.num_samples
context_tokens = [[tokenizer.eod]
for _ in range(args.micro_batch_size)]
ctr = 0
while True:
start_time = time.time()
for token_stream in get_token_stream(model,
copy.deepcopy(context_tokens)):
pass
if mpu.is_pipeline_last_stage() and \
mpu.get_tensor_model_parallel_rank() == 0:
if ctr % args.log_interval == 0:
print('Avg s/batch:',
(time.time() - start_time) / min(args.log_interval, ctr + 1))
start_time = time.time()
length = len(token_stream)
token_batch = token_stream[0].cpu().numpy().tolist()
length_batch = token_stream[1].cpu().numpy().tolist()
assert len(length_batch) == args.micro_batch_size
for tokens, length in zip(token_batch, length_batch):
tokens = tokens[1:length - 1]
text = tokenizer.detokenize(tokens)
is_finished = length < args.seq_length - 1
datum = {'text': text, 'length': length - 1, 'finished': is_finished}
yield datum
ctr += 1
if ctr >= num_samples:
break
else:
for _ in range(args.micro_batch_size):
yield None
ctr += 1
if ctr >= num_samples:
break
if ctr >= num_samples:
break
def generate_and_write_samples_unconditional(model):
args = get_args()
assert args.genfile is not None
with open(args.genfile, 'w') as f:
for datum in generate_samples_unconditional(model):
if mpu.is_pipeline_last_stage() and \
mpu.get_tensor_model_parallel_rank() == 0:
f.write(json.dumps(datum) + '\n')
def pad_batch(batch, pad_id, args):
context_lengths = []
for tokens in batch:
context_length = len(tokens)
if context_length < args.seq_length:
tokens.extend([pad_id] * (args.seq_length - context_length))
context_lengths.append(context_length)
return batch, context_lengths
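# A small sketch of the padding behavior (illustrative only): every context is padded
# with the pad id up to args.seq_length, and the original lengths are returned.
#
#     # with args.seq_length = 4 and pad_id = 0:
#     # pad_batch([[5, 6], [7]], 0, args) -> ([[5, 6, 0, 0], [7, 0, 0, 0]], [2, 1])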
def get_token_stream(model, context_tokens):
args = get_args()
tokenizer = get_tokenizer()
context_tokens, context_lengths = pad_batch(context_tokens,
tokenizer.eod, args)
context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
context_length_tensor = torch.cuda.LongTensor(context_lengths)
torch.distributed.broadcast(context_length_tensor,
mpu.get_tensor_model_parallel_src_rank(),
group=mpu.get_tensor_model_parallel_group())
torch.distributed.broadcast(context_tokens_tensor,
mpu.get_tensor_model_parallel_src_rank(),
group=mpu.get_tensor_model_parallel_group())
context_length = context_length_tensor.min().item()
tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)
batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
context_length_tensor,
attention_mask, position_ids)
for tokens, lengths in batch_token_iterator:
context_length += 1
if tokens is not None:
yield tokens[:, :context_length], lengths
else:
yield None, None
def switch(val1, val2, boolean):
boolean = boolean.type_as(val1)
return (1 - boolean) * val1 + boolean * val2
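# A small sketch of `switch` (illustrative only): it picks val2 where `boolean` is
# true and val1 elsewhere, element-wise.
#
#     # switch(tensor([1, 1, 1]), tensor([9, 9, 9]), tensor([False, True, False]))
#     # -> tensor([1, 9, 1])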
def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
layer_past=None, get_key_value=None,
forward_method_parallel_output=None):
# Hidden size changes when not using recompute, so we need to tell the
# p2p_communication functions the correct size
args = get_args()
orig_seq_length = args.seq_length
args.seq_length = tokens.shape[1]
input_tensor = recv_forward()
# Forward pass through the model.
unwrapped_model = unwrap_model(
model, (torchDDP, LocalDDP, Float16Module))
unwrapped_model.set_input_tensor(input_tensor)
output_tensor = model(tokens, position_ids, attention_mask,
tokentype_ids=tokentype_ids,
layer_past=layer_past,
get_key_value=get_key_value,
forward_method_parallel_output=forward_method_parallel_output)
if get_key_value:
output_tensor, layer_past = output_tensor
send_forward(output_tensor)
args.seq_length = orig_seq_length
if get_key_value:
return output_tensor, layer_past
return output_tensor
def sample_sequence_batch(model, context_tokens, context_lengths,
attention_mask, position_ids,
maxlen=None, type_ids=None):
args = get_args()
tokenizer = get_tokenizer()
model.eval()
with torch.no_grad():
context_length = context_lengths.min().item()
# added eos_id to support the function generate_samples_eval that passes
# eos_id as an argument and needs termination when that id is found.
if hasattr(args, 'eos_id'):
eos_id = args.eos_id
else:
eos_id = tokenizer.eod
counter = 0
org_context_length = context_length
layer_past = None
batch_size = context_tokens.size(0)
is_done = torch.zeros([batch_size]).byte().cuda()
tokens = context_tokens
if maxlen is None:
maxlen = args.seq_length - 1
if maxlen > (org_context_length + args.out_seq_length):
maxlen = org_context_length + args.out_seq_length
lengths = torch.ones([batch_size]).long().cuda() * maxlen
while context_length <= (maxlen):
if args.recompute:
output = forward_step(model, tokens,
position_ids,
attention_mask,
tokentype_ids=type_ids,
forward_method_parallel_output=False)
if mpu.is_pipeline_last_stage():
assert output is not None
logits = output[:, context_length - 1, :]
else:
types2use = None
if counter == 0:
tokens2use = tokens[:, :context_length]
positions2use = position_ids[:, :context_length]
if type_ids is not None:
types2use = type_ids[:, :context_length]
else:
tokens2use = tokens[:, context_length - 1].view(
batch_size, -1)
positions2use = position_ids[:, context_length - 1].view(
batch_size, -1)
if type_ids is not None:
types2use = type_ids[:, context_length - 1].view(
batch_size, -1)
output, layer_past = forward_step(model, tokens2use,
positions2use,
attention_mask,
layer_past=layer_past,
get_key_value=True,
tokentype_ids=types2use,
forward_method_parallel_output=False)
if mpu.is_pipeline_last_stage():
assert output is not None
logits = output[:, -1].view(batch_size, -1).contiguous()
if mpu.is_pipeline_last_stage():
if args.greedy:
prev = torch.argmax(logits, dim=-1).view(-1)
else:
logits = logits.float()
logits /= args.temperature
logits = top_k_logits(logits, top_k=args.top_k,
top_p=args.top_p)
log_probs = F.softmax(logits, dim=-1)
prev = torch.multinomial(log_probs, num_samples=1).view(-1)
started = context_lengths <= context_length
new_tokens = switch(
tokens[:, context_length].view(-1), prev, started)
tokens[:, context_length] = new_tokens
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_embedding_group()
torch.distributed.broadcast(new_tokens, src, group)
done_token = (prev == eos_id).byte() & started.byte()
just_finished = (done_token & ~is_done).bool()
lengths[just_finished.view(-1)] = context_length
is_done = is_done | done_token
done = torch.all(is_done)
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_pipeline_model_parallel_group()
torch.distributed.broadcast(done, src, group)
yield tokens, lengths
else:
if mpu.is_pipeline_first_stage():
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_embedding_group()
new_tokens = torch.empty_like(tokens[:, context_length])
torch.distributed.broadcast(new_tokens, src, group)
tokens[:, context_length] = new_tokens
yield tokens, None
else:
yield None, None
done = torch.cuda.ByteTensor([0])
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_pipeline_model_parallel_group()
torch.distributed.broadcast(done, src, group)
context_length += 1
counter += 1
if done:
break
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .tokenizer import build_tokenizer
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import re
import unicodedata
import six
def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
"""Checks whether the casing config is consistent with the checkpoint name."""
# The casing has to be passed in by the user and there is no explicit check
# as to whether it matches the checkpoint. The casing information probably
# should have been stored in the bert_config.json file, but it's not, so
# we have to heuristically detect it to validate.
if not init_checkpoint:
return
m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
if m is None:
return
model_name = m.group(1)
lower_models = [
"uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
"multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
]
cased_models = [
"cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
"multi_cased_L-12_H-768_A-12"
]
is_bad_config = False
if model_name in lower_models and not do_lower_case:
is_bad_config = True
actual_flag = "False"
case_name = "lowercased"
opposite_flag = "True"
if model_name in cased_models and do_lower_case:
is_bad_config = True
actual_flag = "True"
case_name = "cased"
opposite_flag = "False"
if is_bad_config:
raise ValueError(
"You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
"However, `%s` seems to be a %s model, so you "
"should pass in `--do_lower_case=%s` so that the fine-tuning matches "
"how the model was pre-training. If this error is wrong, please "
"just comment out this check." % (actual_flag, init_checkpoint,
model_name, case_name, opposite_flag))
def convert_to_unicode(text):
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text.decode("utf-8", "ignore")
elif isinstance(text, unicode):
return text
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Not running on Python2 or Python 3?")
def printable_text(text):
"""Returns text encoded in a way suitable for print or `tf.logging`."""
# These functions want `str` for both Python2 and Python3, but in one case
# it's a Unicode string and in the other it's a byte string.
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text
elif isinstance(text, unicode):
return text.encode("utf-8")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Not running on Python2 or Python 3?")
def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
index = 0
with open(vocab_file, "r") as reader:
while True:
token = convert_to_unicode(reader.readline())
if not token:
break
token = token.strip()
vocab[token] = index
index += 1
return vocab
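# A small sketch of the expected vocab file format (illustrative only): one token per
# line, and ids are assigned by line order.
#
#     # vocab.txt          load_vocab("vocab.txt")
#     # [PAD]          ->  {"[PAD]": 0, "[UNK]": 1, "hello": 2}
#     # [UNK]
#     # hello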
def convert_by_vocab(vocab, items):
"""Converts a sequence of [tokens|ids] using the vocab."""
output = []
for item in items:
output.append(vocab[item])
return output
def convert_tokens_to_ids(vocab, tokens):
return convert_by_vocab(vocab, tokens)
def convert_ids_to_tokens(inv_vocab, ids):
return convert_by_vocab(inv_vocab, ids)
def whitespace_tokenize(text):
"""Runs basic whitespace cleaning and splitting on a piece of text."""
text = text.strip()
if not text:
return []
tokens = text.split()
return tokens
class FullTokenizer(object):
"""Runs end-to-end tokenziation."""
def __init__(self, vocab_file, do_lower_case=True):
self.vocab = load_vocab(vocab_file)
self.inv_vocab = {v: k for k, v in self.vocab.items()}
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
def tokenize(self, text):
split_tokens = []
for token in self.basic_tokenizer.tokenize(text):
for sub_token in self.wordpiece_tokenizer.tokenize(token):
split_tokens.append(sub_token)
return split_tokens
def convert_tokens_to_ids(self, tokens):
return convert_by_vocab(self.vocab, tokens)
def convert_ids_to_tokens(self, ids):
return convert_by_vocab(self.inv_vocab, ids)
def vocab_size(self):
return len(self.vocab)
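# A minimal usage sketch (illustrative, assuming a WordPiece vocab file exists at the
# hypothetical path "vocab.txt"):
#
#     tokenizer = FullTokenizer("vocab.txt", do_lower_case=True)
#     ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("unaffable text"))
#     tokens = tokenizer.convert_ids_to_tokens(ids)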
class BasicTokenizer(object):
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
def __init__(self, do_lower_case=True):
"""Constructs a BasicTokenizer.
Args:
do_lower_case: Whether to lower case the input.
"""
self.do_lower_case = do_lower_case
def tokenize(self, text):
"""Tokenizes a piece of text."""
text = convert_to_unicode(text)
text = self._clean_text(text)
# This was added on November 1st, 2018 for the multilingual and Chinese
# models. This is also applied to the English models now, but it doesn't
# matter since the English models were not trained on any Chinese data
# and generally don't have any Chinese data in them (there are Chinese
# characters in the vocabulary because Wikipedia does have some Chinese
# words in the English Wikipedia.).
text = self._tokenize_chinese_chars(text)
orig_tokens = whitespace_tokenize(text)
split_tokens = []
for token in orig_tokens:
if self.do_lower_case:
token = token.lower()
token = self._run_strip_accents(token)
split_tokens.extend(self._run_split_on_punc(token))
output_tokens = whitespace_tokenize(" ".join(split_tokens))
return output_tokens
def _run_strip_accents(self, text):
"""Strips accents from a piece of text."""
text = unicodedata.normalize("NFD", text)
output = []
for char in text:
cat = unicodedata.category(char)
if cat == "Mn":
continue
output.append(char)
return "".join(output)
def _run_split_on_punc(self, text):
"""Splits punctuation on a piece of text."""
chars = list(text)
i = 0
start_new_word = True
output = []
while i < len(chars):
char = chars[i]
if _is_punctuation(char):
output.append([char])
start_new_word = True
else:
if start_new_word:
output.append([])
start_new_word = False
output[-1].append(char)
i += 1
return ["".join(x) for x in output]
def _tokenize_chinese_chars(self, text):
"""Adds whitespace around any CJK character."""
output = []
for char in text:
cp = ord(char)
if self._is_chinese_char(cp):
output.append(" ")
output.append(char)
output.append(" ")
else:
output.append(char)
return "".join(output)
def _is_chinese_char(self, cp):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
#
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
# despite its name. The modern Korean Hangul alphabet is a different block,
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
# like all of the other languages.
if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
(cp >= 0x3400 and cp <= 0x4DBF) or #
(cp >= 0x20000 and cp <= 0x2A6DF) or #
(cp >= 0x2A700 and cp <= 0x2B73F) or #
(cp >= 0x2B740 and cp <= 0x2B81F) or #
(cp >= 0x2B820 and cp <= 0x2CEAF) or
(cp >= 0xF900 and cp <= 0xFAFF) or #
(cp >= 0x2F800 and cp <= 0x2FA1F)): #
return True
return False
def _clean_text(self, text):
"""Performs invalid character removal and whitespace cleanup on text."""
output = []
for char in text:
cp = ord(char)
if cp == 0 or cp == 0xfffd or _is_control(char):
continue
if _is_whitespace(char):
output.append(" ")
else:
output.append(char)
return "".join(output)
class WordpieceTokenizer(object):
"""Runs WordPiece tokenziation."""
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
self.vocab = vocab
self.unk_token = unk_token
self.max_input_chars_per_word = max_input_chars_per_word
def tokenize(self, text):
"""Tokenizes a piece of text into its word pieces.
This uses a greedy longest-match-first algorithm to perform tokenization
using the given vocabulary.
For example:
input = "unaffable"
output = ["un", "##aff", "##able"]
Args:
text: A single token or whitespace separated tokens. This should have
already been passed through `BasicTokenizer`.
Returns:
A list of wordpiece tokens.
"""
text = convert_to_unicode(text)
output_tokens = []
for token in whitespace_tokenize(text):
chars = list(token)
if len(chars) > self.max_input_chars_per_word:
output_tokens.append(self.unk_token)
continue
is_bad = False
start = 0
sub_tokens = []
while start < len(chars):
end = len(chars)
cur_substr = None
while start < end:
substr = "".join(chars[start:end])
if start > 0:
substr = "##" + substr
if substr in self.vocab:
cur_substr = substr
break
end -= 1
if cur_substr is None:
is_bad = True
break
sub_tokens.append(cur_substr)
start = end
if is_bad:
output_tokens.append(self.unk_token)
else:
output_tokens.extend(sub_tokens)
return output_tokens
def _is_whitespace(char):
"""Checks whether `chars` is a whitespace character."""
# \t, \n, and \r are technically control characters but we treat them
# as whitespace since they are generally considered as such.
if char == " " or char == "\t" or char == "\n" or char == "\r":
return True
cat = unicodedata.category(char)
if cat == "Zs":
return True
return False
def _is_control(char):
"""Checks whether `chars` is a control character."""
# These are technically control characters but we count them as whitespace
# characters.
if char == "\t" or char == "\n" or char == "\r":
return False
cat = unicodedata.category(char)
if cat in ("Cc", "Cf"):
return True
return False
def _is_punctuation(char):
"""Checks whether `chars` is a punctuation character."""
cp = ord(char)
# We treat all non-letter/number ASCII as punctuation.
# Characters such as "^", "$", and "`" are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for
# consistency.
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
return True
cat = unicodedata.category(char)
if cat.startswith("P"):
return True
return False
# coding=utf-8
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for OpenAI GPT."""
from __future__ import (absolute_import, division, print_function,
unicode_literals)
import sys
import json
import os
import regex as re
from io import open
from megatron import logging
try:
from functools import lru_cache
except ImportError:
# Just a dummy decorator to get the checks to run on python2
# because honestly I don't want to support a byte-level unicode BPE
# tokenizer on python 2 right now.
def lru_cache():
return lambda func: func
logger = logging.get_logger(__name__)
PRETRAINED_VOCAB_ARCHIVE_MAP = {
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
}
PRETRAINED_MERGES_ARCHIVE_MAP = {
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
}
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
'gpt2': 1024,
}
VOCAB_NAME = 'vocab.json'
MERGES_NAME = 'merges.txt'
SPECIAL_TOKENS_NAME = 'special_tokens.txt'
@lru_cache()
def bytes_to_unicode():
"""
Returns a dict mapping utf-8 bytes to unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a significant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
This also avoids mapping to whitespace/control characters that the bpe code barfs on.
"""
_chr = unichr if sys.version_info[0] == 2 else chr
bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + \
list(range(ord("®"), ord("ÿ") + 1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
cs = [_chr(n) for n in cs]
return dict(zip(bs, cs))
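# A small sketch of the mapping (illustrative only): printable bytes map to themselves,
# while bytes the BPE code cannot handle (whitespace/control) map to unused codepoints.
#
#     b2u = bytes_to_unicode()
#     # b2u[ord("a")] -> "a"
#     # b2u[ord(" ")] -> "Ġ"  (space is remapped, which is why GPT-2 tokens often start with "Ġ")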
def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
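# A small sketch (illustrative only): `get_pairs` returns the set of adjacent symbol
# pairs that the BPE loop ranks against `bpe_ranks`.
#
#     # get_pairs(("l", "o", "w")) -> {("l", "o"), ("o", "w")}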
class GPT2Tokenizer(object):
"""
GPT-2 BPE tokenizer. Peculiarities:
- Byte-level BPE
"""
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
"""
Instantiate a PreTrainedBertModel from a pre-trained model file.
Download and cache the pre-trained model file if needed.
"""
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
special_tokens_file = None
else:
vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
if not os.path.exists(special_tokens_file):
special_tokens_file = None
else:
logger.info("loading special tokens file {}".format(special_tokens_file))
# redirect to the cache, if necessary
try:
from .file_utils import cached_path
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
except EnvironmentError:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find files {} and {} "
"at this path or url.".format(
pretrained_model_name_or_path,
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
pretrained_model_name_or_path,
vocab_file, merges_file))
return None
if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
logger.info("loading vocabulary file {}".format(vocab_file))
logger.info("loading merges file {}".format(merges_file))
else:
logger.info("loading vocabulary file {} from cache at {}".format(
vocab_file, resolved_vocab_file))
logger.info("loading merges file {} from cache at {}".format(
merges_file, resolved_merges_file))
if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
# if we're using a pretrained model, ensure the tokenizer won't index sequences longer
# than the number of positional embeddings
max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
# Instantiate tokenizer.
if special_tokens_file and 'special_tokens' not in kwargs:
special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
else:
special_tokens = kwargs.pop('special_tokens', [])
tokenizer = cls(
resolved_vocab_file,
resolved_merges_file,
special_tokens=special_tokens,
*inputs,
**kwargs)
return tokenizer
def __init__(self, vocab_file, merges_file, errors='replace',
special_tokens=None, max_len=None, max_token_len_cache=9):
"""
max_token_len_cache determines whether a normalized token will be cached: only tokens up to this length are stored,
on the heuristic that shorter tokens are more frequent. Increasing it may make tokenization faster but will also use more memory.
The size of the normalized token cache is capped at 1_000_000 entries.
"""
self.max_len = max_len if max_len is not None else int(1e12)
self.encoder = json.load(open(vocab_file))
self.decoder = {v: k for k, v in self.encoder.items()}
self.errors = errors # how to handle errors in decoding
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
bpe_merges = [tuple(merge.split()) for merge in bpe_data]
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
# Should have added re.IGNORECASE so BPE merges can happen for
# capitalized versions of contractions
self.pat = re.compile(
r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
self.special_tokens = {}
self.special_tokens_decoder = {}
self.set_special_tokens(special_tokens)
self.max_token_len_cache = max_token_len_cache
def __len__(self):
return len(self.encoder) + len(self.special_tokens)
def set_special_tokens(self, special_tokens):
""" Add a list of additional tokens to the encoder.
The additional tokens are indexed starting from the last index of the
current vocabulary in the order of the `special_tokens` list.
"""
if not special_tokens:
self.special_tokens = {}
self.special_tokens_decoder = {}
return
self.special_tokens = dict((tok, len(self.encoder) + i)
for i, tok in enumerate(special_tokens))
self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()}
logger.info("Special tokens {}".format(self.special_tokens))
@lru_cache(1_000_000)
def bpe(self, token):
word = tuple(token)
pairs = get_pairs(word)
if not pairs:
return token
while True:
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except BaseException:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = ' '.join(word)
return word
@lru_cache(1_000_000)
def normalize_token_and_cache(self, token):
return self.normalize_token(token)
def normalize_token(self, token):
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
ret = [bpe_token for bpe_token in self.bpe(token).split(' ')]
return ret
def tokenize(self, text):
""" Tokenize a string. """
max_token_len_cache = self.max_token_len_cache
bpe_tokens = []
if sys.version_info[0] == 2:
for token in re.findall(self.pat, text):
token = ''.join(self.byte_encoder[ord(b)] for b in token)
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
return bpe_tokens
for token in re.findall(self.pat, text):
if len(token) <= max_token_len_cache:
bpe_tokens.extend(self.normalize_token_and_cache(token))
else:
bpe_tokens.extend(self.normalize_token(token))
return bpe_tokens
def convert_tokens_to_ids(self, tokens):
""" Converts a sequence of tokens into ids using the vocab. """
ids = []
if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
if tokens in self.special_tokens:
return self.special_tokens[tokens]
else:
return self.encoder.get(tokens, 0)
for token in tokens:
if token in self.special_tokens:
ids.append(self.special_tokens[token])
else:
ids.append(self.encoder.get(token, 0))
if len(ids) > self.max_len:
logger.warning(
"Token indices sequence length is longer than the specified maximum "
" sequence length for this OpenAI GPT model ({} > {}). Running this"
" sequence through the model will result in indexing errors".format(
len(ids), self.max_len)
)
return ids
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
"""Converts a sequence of ids in BPE tokens using the vocab."""
tokens = []
for i in ids:
if i in self.special_tokens_decoder:
if not skip_special_tokens:
tokens.append(self.special_tokens_decoder[i])
else:
tokens.append(self.decoder[i])
return tokens
def encode(self, text):
return self.convert_tokens_to_ids(self.tokenize(text))
def decode(self, tokens):
text = ''.join([self.decoder[token] for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
return text
def save_vocabulary(self, vocab_path):
"""Save the tokenizer vocabulary and merge files to a directory."""
if not os.path.isdir(vocab_path):
logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
return
vocab_file = os.path.join(vocab_path, VOCAB_NAME)
merge_file = os.path.join(vocab_path, MERGES_NAME)
special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)
with open(vocab_file, 'w', encoding='utf-8') as f:
f.write(json.dumps(self.encoder, ensure_ascii=False))
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
writer.write(u'#version: 0.2\n')
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(merge_file))
index = token_index
writer.write(' '.join(bpe_tokens) + u'\n')
index += 1
index = len(self.encoder)
with open(special_tokens_file, 'w', encoding='utf-8') as writer:
for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(special_tokens_file))
index = token_index
writer.write(token + u'\n')
index += 1
return vocab_file, merge_file, special_tokens_file
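# A minimal usage sketch (illustrative, assuming local "vocab.json" and "merges.txt"
# files in the GPT-2 format): byte-level BPE encodes and decodes losslessly.
#
#     tokenizer = GPT2Tokenizer("vocab.json", "merges.txt")
#     ids = tokenizer.encode("Hello world")
#     assert tokenizer.decode(ids) == "Hello world"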
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megatron tokenizers."""
from abc import ABC
from abc import abstractmethod
from transformers import AutoTokenizer
from .bert_tokenization import FullTokenizer as FullBertTokenizer
from .gpt2_tokenization import GPT2Tokenizer
def build_tokenizer(args):
"""Initialize tokenizer."""
if args.rank == 0:
print('> building {} tokenizer ...'.format(args.tokenizer_type),
flush=True)
# Select and instantiate the tokenizer.
assert args.vocab_file is not None or args.tokenizer_type == "PretrainedFromHF"
if args.tokenizer_type == 'BertWordPieceLowerCase':
tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file,
lower_case=True,
vocab_extra_ids=args.vocab_extra_ids)
elif args.tokenizer_type == 'BertWordPieceCase':
tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file,
lower_case=False,
vocab_extra_ids=args.vocab_extra_ids)
elif args.tokenizer_type == 'GPT2BPETokenizer':
assert args.merge_file is not None
tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file)
elif args.tokenizer_type == "PretrainedFromHF":
assert args.tokenizer_name_or_path is not None
# prevent transformers from logging info and warnings on each rank
import transformers
import logging
if args.rank == 0:
transformers.utils.logging.set_verbosity(logging.INFO)
else:
# shut the warnings on replicas
transformers.utils.logging.set_verbosity(logging.ERROR)
if args.rank == 0:
print(" vocab file is un-used. loading tokenizer from pre-trained model")
tokenizer = _AutoTokenizer(args.tokenizer_name_or_path, vocab_extra_ids=args.vocab_extra_ids)
else:
raise NotImplementedError('{} tokenizer is not '
'implemented.'.format(args.tokenizer_type))
# Add vocab size.
args.padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size,
args)
return tokenizer
def _vocab_size_with_padding(orig_vocab_size, args):
"""Apply the requested rules to change the size of the vocabulary"""
if args.pad_vocab_size_to is not None:
if args.pad_vocab_size_to < orig_vocab_size:
raise ValueError(
f"You asked to pad the vocabulary to {args.pad_vocab_size_to} when the initial vocabulary size is "
f"{orig_vocab_size}. You can only pad to a higher value."
)
if args.make_vocab_size_divisible_by is not None and (args.pad_vocab_size_to % args.make_vocab_size_divisible_by) != 0:
raise ValueError(f"{args.pad_vocab_size_to} is not divisible by {args.make_vocab_size_divisible_by}")
after = args.pad_vocab_size_to
else:
# Pad vocab size so it is divisible by model parallel size and still has a GPU-friendly size.
after = orig_vocab_size
multiple = args.make_vocab_size_divisible_by * \
args.tensor_model_parallel_size
while (after % multiple) != 0:
after += 1
if args.rank == 0:
print(' > padded vocab (size: {}) with {} dummy tokens '
'(new size: {})'.format(
orig_vocab_size, after - orig_vocab_size, after), flush=True)
return after
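# A small worked example (illustrative only): with orig_vocab_size=50257,
# make_vocab_size_divisible_by=128 and tensor_model_parallel_size=2, the multiple is
# 256 and the padded size becomes 50432 (175 dummy tokens).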
class AbstractTokenizer(ABC):
"""Abstract class for tokenizer."""
def __init__(self, name):
self.name = name
super().__init__()
@property
@abstractmethod
def vocab_size(self):
pass
@property
@abstractmethod
def vocab(self):
"""Dictionary from vocab text token to id token."""
pass
@property
@abstractmethod
def inv_vocab(self):
"""Dictionary from vocab id token to text token."""
pass
@abstractmethod
def tokenize(self, text):
pass
def detokenize(self, token_ids):
raise NotImplementedError('detokenizer is not implemented for {} '
'tokenizer'.format(self.name))
@property
def cls(self):
raise NotImplementedError('CLS is not provided for {} '
'tokenizer'.format(self.name))
@property
def sep(self):
raise NotImplementedError('SEP is not provided for {} '
'tokenizer'.format(self.name))
@property
def pad(self):
raise NotImplementedError('PAD is not provided for {} '
'tokenizer'.format(self.name))
@property
def eod(self):
raise NotImplementedError('EOD is not provided for {} '
'tokenizer'.format(self.name))
@property
def mask(self):
raise NotImplementedError('MASK is not provided for {} '
'tokenizer'.format(self.name))
class _BertWordPieceTokenizer(AbstractTokenizer):
"""Original BERT wordpiece tokenizer."""
def __init__(self, vocab_file, lower_case=True, vocab_extra_ids=0):
if lower_case:
name = 'BERT Lower Case'
else:
name = 'BERT Upper Case'
super().__init__(name)
self.tokenizer = FullBertTokenizer(vocab_file, do_lower_case=lower_case)
self.cls_id = self.tokenizer.vocab['[CLS]']
self.sep_id = self.tokenizer.vocab['[SEP]']
self.pad_id = self.tokenizer.vocab['[PAD]']
self.mask_id = self.tokenizer.vocab['[MASK]']
self._additional_special_tokens = []
# (dsachan) Add BOS and EOS tokens
SPECIAL_TOKENS = {'eos_token': '[EOS]',
'bos_token': '[BOS]'}
self._bos_token = '[BOS]'
self.add_token(self._bos_token)
self._bos_token_id = self.vocab.get(self._bos_token)
self._eos_token = '[EOS]'
self.add_token(self._eos_token)
self._eos_token_id = self.vocab.get(self._eos_token)
# (dsachan) Add additional special tokens
# These can be used as sentinel tokens in T5 model inputs
additional_special_tokens = []
additional_special_tokens.extend(
["<extra_id_{}>".format(i) for i in range(vocab_extra_ids)])
self.add_additional_special_tokens(additional_special_tokens)
def add_token(self, token):
if token not in self.vocab:
self.inv_vocab[self.vocab_size] = token
# self.vocab_size comes from len(vocab)
# and it will increase as we add elements
self.vocab[token] = self.vocab_size
def add_additional_special_tokens(self, tokens_list):
setattr(self, "additional_special_tokens", tokens_list)
for value in tokens_list:
self.add_token(value)
@property
def vocab_size(self):
return self.tokenizer.vocab_size()
@property
def vocab(self):
return self.tokenizer.vocab
@property
def inv_vocab(self):
return self.tokenizer.inv_vocab
def tokenize(self, text):
text_tokens = self.tokenizer.tokenize(text)
return self.tokenizer.convert_tokens_to_ids(text_tokens)
def decode(self, ids):
tokens = self.tokenizer.convert_ids_to_tokens(ids)
return self.tokenizer.convert_tokens_to_string(tokens)
def decode_token_ids(self, token_ids):
tokens = self.tokenizer.convert_ids_to_tokens(token_ids)
exclude_list = ['[PAD]', '[CLS]']
non_pads = [t for t in tokens if t not in exclude_list]
result = ""
for s in non_pads:
if s.startswith("##"):
result += s[2:]
else:
result += " " + s
return result
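# Example of the "##" re-joining above (hypothetical tokens):
#   ['[CLS]', 'un', '##believ', '##able', '[SEP]'] -> ' unbelievable [SEP]'
# '[CLS]' is dropped by the exclude list, '##'-prefixed pieces are glued to the
# previous token, and every other token is prefixed with a space.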
@property
def cls(self):
return self.cls_id
@property
def sep(self):
return self.sep_id
@property
def pad(self):
return self.pad_id
@property
def mask(self):
return self.mask_id
@property
def bos_token(self):
""" Beginning of sentence token id """
return self._bos_token
@property
def eos_token(self):
""" End of sentence token id """
return self._eos_token
@property
def additional_special_tokens(self):
""" All the additional special tokens you may want to use (list of strings)."""
return self._additional_special_tokens
@property
def bos_token_id(self):
""" Id of the beginning of sentence token in the vocabulary."""
return self._bos_token_id
@property
def eos_token_id(self):
""" Id of the end of sentence token in the vocabulary."""
return self._eos_token_id
@property
def additional_special_tokens_ids(self):
""" Ids of all the additional special tokens in the vocabulary (list of integers)."""
return [self.vocab.get(token) for token in self._additional_special_tokens]
@additional_special_tokens.setter
def additional_special_tokens(self, value):
self._additional_special_tokens = value
class _GPT2BPETokenizer(AbstractTokenizer):
"""Original GPT2 BPE tokenizer."""
def __init__(self, vocab_file, merge_file):
name = 'GPT2 BPE'
super().__init__(name)
self.tokenizer = GPT2Tokenizer(vocab_file, merge_file, errors='replace',
special_tokens=[], max_len=None)
self.eod_id = self.tokenizer.encoder['<|endoftext|>']
@property
def vocab_size(self):
return len(self.tokenizer.encoder)
@property
def vocab(self):
return self.tokenizer.encoder
@property
def inv_vocab(self):
return self.tokenizer.decoder
def tokenize(self, text):
return self.tokenizer.encode(text)
def detokenize(self, token_ids):
return self.tokenizer.decode(token_ids)
@property
def eod(self):
return self.eod_id
class _AutoTokenizer(AbstractTokenizer):
"""AutoTokenizer for Hf Pretrained model loading."""
def __init__(self, tokenizer_name_or_path, vocab_extra_ids):
name = tokenizer_name_or_path
super().__init__(name)
hf_tokenizer_kwargs = {}
if vocab_extra_ids > 0:
# TODO @thomasw21 we might need to concatenate to a pre-existing list?
hf_tokenizer_kwargs["additional_special_tokens"] = [f"<extra_id_{_id}>" for _id in range(vocab_extra_ids)]
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, **hf_tokenizer_kwargs)
self.encoder = self.tokenizer.get_vocab()
self.decoder = {v: k for k, v in self.encoder.items()}
@property
def vocab_size(self):
return len(self.tokenizer) # vocab_size doesn't contain additional tokens
@property
def vocab(self):
# TODO @thomasw21 make sure that special tokens don't collapse with vocab tokens.
return {
**{special_token: self.tokenizer.convert_tokens_to_ids(special_token) for special_token in self.tokenizer.additional_special_tokens},
**self.tokenizer.vocab,
}
@property
def inv_vocab(self):
return {v: k for k, v in self.vocab.items()}
def tokenize(self, text):
return self.tokenizer.encode(text)
def detokenize(self, token_ids):
return self.tokenizer.decode(token_ids)
@property
def eod(self):
# TODO @thomasw21 might conflict with <eos>
return self.eos
@property
def cls(self):
candidate = self.tokenizer.cls_token_id
return self._check_token_candidate(candidate)
@property
def sep(self):
candidate = self.tokenizer.sep_token_id
return self._check_token_candidate(candidate)
@property
def pad(self):
candidate = self.tokenizer.pad_token_id
return self._check_token_candidate(candidate)
@property
def mask(self):
candidate = self.tokenizer.mask_token_id
return self._check_token_candidate(candidate)
@property
def bos(self):
raise NotImplementedError("Missing <bos>")
@property
def eos(self):
# TODO @thomasw21 might conflict with the notion of <eod>
candidate = self.tokenizer.eos_token_id
return self._check_token_candidate(candidate)
@property
def additional_special_tokens_ids(self):
""" All the additional special tokens you may want to use (list of strings)."""
return self.tokenizer.additional_special_tokens_ids
@staticmethod
def _check_token_candidate(candidate):
if candidate is None:
raise AttributeError("Token doesn't exist")
return candidate
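# Usage sketch for _AutoTokenizer (the model name below is hypothetical):
#   tokenizer = _AutoTokenizer("t5-small", vocab_extra_ids=100)
# passes additional_special_tokens=['<extra_id_0>', ..., '<extra_id_99>'] to
# AutoTokenizer.from_pretrained, so the sentinel ids are reachable through
# `vocab` and `additional_special_tokens_ids`.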
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain utilities."""
from datetime import datetime
import bisect
import math
import sys
import time
import json
from traceback import print_tb
# The earliest we can measure the start time.
_TRAIN_START_TIME = time.time()
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron import get_args
from megatron import get_timers
from megatron import get_tensorboard_writer
from megatron import get_current_global_batch_size
from megatron import get_num_microbatches
from megatron import is_last_rank
from megatron import update_num_microbatches
from megatron import mpu
from megatron import print_rank_0
from megatron import print_rank_last
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
from megatron.model.module import Float16Module
from megatron.optimizer import get_megatron_optimizer
from megatron.initialize import initialize_megatron
from megatron.initialize import write_args_to_tensorboard, log_restart_to_tensorboard
from megatron.learning_rates import AnnealingLR
from megatron.model.distributed import DistributedDataParallel as LocalDDP
from megatron.utils import check_adlr_autoresume_termination, get_parameters_in_billions
from megatron.utils import unwrap_model, found_kill_switch
from megatron.data.data_samplers import build_pretraining_data_loader
from megatron.utils import calc_params_l2_norm
from megatron.schedules import forward_backward_no_pipelining
from megatron.schedules import forward_backward_pipelining_without_interleaving
from megatron.schedules import forward_backward_pipelining_with_interleaving
from megatron.utils import report_memory, flops_calculator
from megatron.global_vars import codecarbon_tracker_start, codecarbon_tracker_stop
from megatron.data.dataset_utils import analyze_data_prefix
import os, glob
from PIL import Image
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
import deepspeed
def print_datetime(string):
"""Note that this call will sync across all ranks."""
torch.distributed.barrier()
time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
print_rank_0('[' + string + '] datetime: {} '.format(time_str))
def pretrain(train_valid_test_dataset_provider,
model_provider,
forward_step_func,
extra_args_provider=None,
args_defaults={}):
"""Main training program.
This function will run the following in the order provided:
1) initialize Megatron.
2) setup model, optimizer and lr schedule using the model_provider.
3) call train_valid_test_dataset_provider to get train/val/test datasets.
4) train the model using the forward_step_func.
Arguments:
train_valid_test_dataset_provider: a function that takes the size of
train/valid/test dataset and returns `train, valid, test` datasets.
model_provider: a function that returns a vanilla version of the
model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
forward_step_func: a function that takes a `data iterator` and `model`,
and returns a `loss` scalar with a dictionary with key:values being
the info we would like to monitor during training, for example
`lm-loss: value`. We also require that this function add
`batch generator` to the timers class.
extra_args_provider: a function that takes a parser and adds arguments
to it. It is used for programs to add their own arguments.
args_defaults: a dictionary from argument-name to argument-value. It is
used to set already-parsed arguments.
"""
# Initialize and get arguments, timers, and Tensorboard writer.
initialize_megatron(extra_args_provider=extra_args_provider,
args_defaults=args_defaults)
args = get_args()
if found_kill_switch():
print_datetime(f"Detected kill switch at {args.kill_switch_path}. Exiting")
sys.exit()
codecarbon_tracker_start()
# Adjust the startup time so it reflects the largest value.
# This will be closer to what scheduler will see (outside of
# image ... launches.
global _TRAIN_START_TIME
start_time_tensor = torch.cuda.FloatTensor([_TRAIN_START_TIME])
torch.distributed.all_reduce(start_time_tensor,
op=torch.distributed.ReduceOp.MIN)
_TRAIN_START_TIME = start_time_tensor.item()
print_rank_0('time to initialize megatron (seconds): {:.3f}'.format(
time.time() - _TRAIN_START_TIME))
print_datetime('after megatron is initialized')
timers = get_timers()
if args.deepspeed:
args.deepspeed_configuration = json.load(
open(args.deepspeed_config, 'r', encoding='utf-8'))
if "curriculum_learning" in args.deepspeed_configuration and \
"enabled" in args.deepspeed_configuration["curriculum_learning"]:
args.curriculum_learning = args.deepspeed_configuration[ \
"curriculum_learning"]["enabled"]
if args.curriculum_learning and \
args.pipeline_model_parallel_size >= 1:
from deepspeed.runtime.data_pipeline.curriculum_scheduler \
import CurriculumScheduler
args.curriculum_scheduler = CurriculumScheduler( \
args.deepspeed_configuration["curriculum_learning"])
# Model, optimizer, and learning rate.
timers('model-and-optimizer-setup').start()
model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
args.parameters_in_billions_no_embedding = get_parameters_in_billions(model, exclude_embeddings=True)
print_rank_0(f'estimated model parameters: {get_parameters_in_billions(model)}')
print_rank_0(f'estimated model parameters without embeddings: {get_parameters_in_billions(model, exclude_embeddings=True)}')
timers('model-and-optimizer-setup').stop()
print_datetime('after model, optimizer, and learning rate '
'scheduler are built')
# Data stuff.
timers('train/valid/test-data-iterators-setup').start()
if args.virtual_pipeline_model_parallel_size is not None:
all_data_iterators = [
build_train_valid_test_data_iterators(train_valid_test_dataset_provider)
for _ in range(len(model))
]
train_data_iterator = [data_iterators[0] for data_iterators in all_data_iterators]
valid_data_iterator = [data_iterators[1] for data_iterators in all_data_iterators]
test_data_iterator = [data_iterators[2] for data_iterators in all_data_iterators]
else:
train_data_iterator, valid_data_iterator, test_data_iterator = build_train_valid_test_data_iterators(
train_valid_test_dataset_provider)
if args.data_path is not None and len(args.data_path) > 1:
prefixes, weights = analyze_data_prefix(args.data_path)
setattr(args, "data_prefixes", prefixes)
setattr(args, "data_weights", weights)
elif args.train_weighted_split_paths is not None and len(args.train_weighted_split_paths[0]) > 1:
paths = args.train_weighted_split_paths[0]
weights = args.train_weighted_split_weights[0]
data_prefix = [j for i in [[w,p] for w,p in zip(weights, paths)] for j in i]
prefixes, weights = analyze_data_prefix(data_prefix)
setattr(args, "data_prefixes", prefixes)
setattr(args, "data_weights", weights)
else:
setattr(args, "data_prefixes", None)
setattr(args, "data_weights", None)
timers('train/valid/test-data-iterators-setup').stop()
print_datetime('after dataloaders are built')
# Print setup timing.
print_rank_0('done with setup ...')
timers.log(['model-and-optimizer-setup', 'train/valid/test-data-iterators-setup'])
print_rank_0('training ...')
iteration = 0
if args.do_train and args.train_iters > 0:
iteration = train(forward_step_func,
model, optimizer, lr_scheduler,
train_data_iterator, valid_data_iterator)
print_datetime('after training is done')
if args.do_valid and not args.eval_only:
names = args.valid_weighted_split_names
names = names if names is not None else ['valid'] * len(valid_data_iterator)
for iterator, name in zip(valid_data_iterator, names):
prefix = 'the end of training for val data'
evaluate_and_print_results(prefix, forward_step_func,
iterator, model,
iteration, False, data_group_name=name)
if args.save and iteration != 0:
save_checkpoint(iteration, model, optimizer, lr_scheduler)
if args.do_test:
# Run on test data.
names = args.test_weighted_split_names
names = names if names is not None else ['test'] * len(test_data_iterator)
for iterator, name in zip(test_data_iterator, names):
test(forward_step_func, iterator, model, verbose=False)
codecarbon_tracker_stop()
def update_train_iters(args):
# For iteration-based training, we don't need to do anything
if args.train_iters:
return
# Constant batch size with sample-based training.
if args.rampup_batch_size is None:
args.train_iters = args.train_samples // args.global_batch_size
else:
# Sample based training with rampup batch size.
iterations = 0
consumed_samples = 0
# Rampup phase.
while consumed_samples <= int(args.rampup_batch_size[2]):
update_num_microbatches(consumed_samples, consistency_check=False)
consumed_samples += get_current_global_batch_size()
iterations += 1
# Reset
update_num_microbatches(0, consistency_check=False)
# Constant phase
# Note that we throw away any partial last batch.
iterations += (args.train_samples - consumed_samples) // \
args.global_batch_size
args.train_iters = iterations
print_rank_0('setting training iterations to {}'.format(args.train_iters))
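# Worked example for the constant-batch-size branch above (hypothetical values):
# train_samples=1_000_000 and global_batch_size=512 give
# train_iters = 1_000_000 // 512 = 1953; the partial last batch is dropped.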
def get_model(model_provider_func):
"""Build the model."""
args = get_args()
# Build model.
if mpu.get_pipeline_model_parallel_world_size() > 1 and \
args.virtual_pipeline_model_parallel_size is not None:
model = []
for i in range(args.virtual_pipeline_model_parallel_size):
mpu.set_virtual_pipeline_model_parallel_rank(i)
# Set pre_process and post_process only after virtual rank is set.
pre_process = mpu.is_pipeline_first_stage()
post_process = mpu.is_pipeline_last_stage()
this_model = model_provider_func(
pre_process=pre_process,
post_process=post_process
)
model.append(this_model)
else:
pre_process = mpu.is_pipeline_first_stage()
post_process = mpu.is_pipeline_last_stage()
model = model_provider_func(
pre_process=pre_process,
post_process=post_process
)
if not isinstance(model, list):
model = [model]
# Set tensor model parallel attributes if not set.
# Only parameters that are already tensor model parallel have these
# attributes set for them. We should make sure the default attributes
# are set for all params so the optimizer can use them.
for model_module in model:
for param in model_module.parameters():
mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param)
# # Print number of parameters.
# Moved to `train` with extras
# if mpu.get_data_parallel_rank() == 0:
# print('Number of parameters on tensor={}, pipeline={}: {}'.format(
# mpu.get_tensor_model_parallel_rank(),
# mpu.get_pipeline_model_parallel_rank(),
# sum([sum([p.ds_numel if hasattr(p,'ds_id') else p.nelement() for p in model_module.parameters()])
# for model_module in model])), flush=True)
# torch.distributed.barrier()
# else:
# torch.distributed.barrier()
if args.deepspeed:
return model
# GPU allocation.
for model_module in model:
model_module.cuda(torch.cuda.current_device())
# Fp16 conversion.
if args.fp16 or args.bf16:
model = [Float16Module(model_module, args) for model_module in model]
if args.DDP_impl == 'torch':
i = torch.cuda.current_device()
model = [torchDDP(model_module, device_ids=[i], output_device=i,
process_group=mpu.get_data_parallel_group())
for model_module in model]
return model
if args.DDP_impl == 'local':
model = [LocalDDP(model_module,
args.accumulate_allreduce_grads_in_fp32,
args.use_contiguous_buffers_in_ddp)
for model_module in model]
return model
raise NotImplementedError('Unknown DDP implementation specified: {}. '
'Exiting.'.format(args.DDP_impl))
def get_learning_rate_scheduler(optimizer):
"""Build the learning rate scheduler."""
args = get_args()
# Iteration-based training.
if args.train_iters:
if args.lr_decay_iters is None:
args.lr_decay_iters = args.train_iters
decay_steps = args.lr_decay_iters * args.global_batch_size
if args.lr_warmup_fraction is not None:
warmup_steps = args.lr_warmup_fraction * decay_steps
else:
warmup_steps = args.lr_warmup_iters * args.global_batch_size
# Sample-based training.
elif args.train_samples:
# We need to set training iters for later use. Technically
# we need to adjust the training samples too (due to last
# batch being incomplete) but we leave it as is for now.
update_train_iters(args)
if args.lr_decay_samples is None:
args.lr_decay_samples = args.train_samples
decay_steps = args.lr_decay_samples
if args.lr_warmup_fraction is not None:
warmup_steps = args.lr_warmup_fraction * decay_steps
else:
warmup_steps = args.lr_warmup_samples
else:
raise Exception(
'either train-iters or train-samples should be provided.')
lr_scheduler = AnnealingLR(
optimizer,
max_lr=args.lr,
min_lr=args.min_lr,
warmup_steps=warmup_steps,
decay_steps=decay_steps,
decay_style=args.lr_decay_style,
use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
override_lr_scheduler=args.override_lr_scheduler)
return lr_scheduler
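# Worked example (hypothetical values): for iteration-based training with
# lr_decay_iters=100_000, global_batch_size=512 and lr_warmup_fraction=0.01,
# decay_steps = 100_000 * 512 = 51_200_000 samples and
# warmup_steps = 0.01 * decay_steps = 512_000 samples.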
def sync_hp_to_lp(optimizer):
optimizer.update_lp_params()
# for n,p in model.named_parameters():
# print(n)
# if p._hp_mapping is not None:
# #print(f'rank {rank} fixing hp for input_layernorm')
# #p._hp_mapping.update_hp()
# hp = p._hp_mapping.hp_fragment
# torch.distributed.all_reduce(hp, op=torch.distributed.ReduceOp.AVG, group=mpu.get_tensor_model_parallel_group())
# # 3. optim states
# for key in ['exp_avg', 'exp_avg_sq']:
# optim_state_fragment = p._hp_mapping.get_optim_state_fragment(key)
# #print(f'rank {rank} before reduce optim state fragment {key} = {optim_state_fragment}')
# torch.distributed.all_reduce(optim_state_fragment, op=torch.distributed.ReduceOp.AVG, group=mpu.get_tensor_model_parallel_group())
# #print(f'rank {rank} after reduce optim state fragment {key} = {optim_state_fragment}')
def setup_model_and_optimizer(model_provider_func):
"""Setup model and optimizer."""
args = get_args()
model = get_model(model_provider_func)
unwrapped_model = unwrap_model(model,
(torchDDP, LocalDDP, Float16Module))
if args.inference:
optimizer = None
lr_scheduler = None
else:
optimizer = get_megatron_optimizer(unwrapped_model)
lr_scheduler = get_learning_rate_scheduler(optimizer)
if args.deepspeed:
print_rank_0("DeepSpeed is enabled.")
#pp = mpu.get_pipeline_model_parallel_world_size()
import json
import io
with io.open(args.deepspeed_config, "r", encoding="utf-8") as f:
config = json.load(f)
if args.universal_checkpoint:
config["checkpoint"] = {"load_universal": True}
model, optimizer, _, lr_scheduler = deepspeed.initialize(
model=model[0],
optimizer=optimizer,
lr_scheduler=lr_scheduler,
# config=config,
args=args,
)
assert model.fp16_enabled() == args.fp16, "megatron fp16 config does not match deepspeed"
assert model.bfloat16_enabled() == args.bf16, "megatron bf16 config does not match deepspeed"
if isinstance(model, deepspeed.PipelineEngine):
# hack to get batch_fn from pretrain_gpt.py
model.set_batch_fn(model.module._megatron_batch_fn)
assert model.grid.get_pipe_parallel_rank() == mpu.get_pipeline_model_parallel_rank()
assert model.grid.get_slice_parallel_rank() == mpu.get_tensor_model_parallel_rank()
assert model.grid.get_data_parallel_rank() == mpu.get_data_parallel_rank()
model = [model]
if args.load is not None:
timers = get_timers()
# Extra barrier is added to make sure all ranks report the
# max time.
torch.distributed.barrier()
timers('load-checkpoint').start()
args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
torch.distributed.barrier()
timers('load-checkpoint').stop()
timers.log(['load-checkpoint'])
# hp -> lp
if args.deepspeed and args.universal_checkpoint:
sync_hp_to_lp(optimizer)
else:
args.iteration = 0
# tp_rank = mpu.get_tensor_model_parallel_rank()
# pp_rank = mpu.get_pipeline_model_parallel_rank()
# dp_rank = mpu.get_data_parallel_rank()
# for n,p in model[0].named_parameters():
# if 'word_embeddings.weight' not in n:
# continue
# if tp_rank == 0 and pp_rank == 0:
# print(f"{tp_rank=}{pp_rank=}{dp_rank=} bf16 {n=} {p[:10]=}")
# if p._hp_mapping is not None:
# hp = p._hp_mapping.hp_fragment
# print(f'{tp_rank=}{pp_rank=}{dp_rank=} fp32 {n=} {hp[:10]=}')
# if tp_rank == 0 and pp_rank == mpu.get_pipeline_model_parallel_world_size() - 1:
# print(f"{tp_rank=}{pp_rank=}{dp_rank=} bf16 {n=} {p[:10]=}")
# if p._hp_mapping is not None:
# hp = p._hp_mapping.hp_fragment
# print(f'{tp_rank=}{pp_rank=}{dp_rank=} fp32 {n=} {hp[:10]=}')
# We only support local DDP with multiple micro-batches.
if len(model) > 1 or mpu.get_pipeline_model_parallel_world_size() > 1:
assert args.DDP_impl == 'local'
# get model without FP16 and/or TorchDDP wrappers
if args.iteration == 0 and len(unwrapped_model) == 1 \
and hasattr(unwrapped_model[0], 'init_state_dict_from_bert'):
print_rank_0("Initializing ICT from pretrained BERT model")
unwrapped_model[0].init_state_dict_from_bert()
if args.fp16:
optimizer.reload_model_params()
return model, optimizer, lr_scheduler
def train_step(forward_step_func, data_iterator,
model, optimizer, lr_scheduler):
"""Single training step."""
args = get_args()
timers = get_timers()
if args.deepspeed:
assert isinstance(model[0], deepspeed.PipelineEngine), model
loss = model[0].train_batch(data_iter=data_iterator)
skipped_iter = 0
grad_norm = model[0].get_global_grad_norm()
num_zeros_in_grad = 0
return {'lm loss' : loss}, skipped_iter, grad_norm, num_zeros_in_grad
# Set grad to zero.
if not args.deepspeed:
if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_ddp:
for partition in model:
partition.zero_grad_buffer()
else:
optimizer.zero_grad()
if mpu.get_pipeline_model_parallel_world_size() > 1:
if args.virtual_pipeline_model_parallel_size is not None:
forward_backward_func = forward_backward_pipelining_with_interleaving
assert get_num_microbatches() % args.pipeline_model_parallel_size == 0, \
'number of microbatches is not divisible by pipeline-parallel ' \
'size when using interleaved schedule'
else:
forward_backward_func = forward_backward_pipelining_without_interleaving
else:
forward_backward_func = forward_backward_no_pipelining
losses_reduced = forward_backward_func(
forward_step_func, data_iterator, model,
optimizer, timers, forward_only=False)
# All-reduce if needed.
if not args.deepspeed and args.DDP_impl == 'local':
timers('backward-params-all-reduce').start()
for model_module in model:
model_module.allreduce_gradients()
timers('backward-params-all-reduce').stop()
# All-reduce word_embeddings' grad across first and last stages to ensure
# that word_embeddings parameters stay in sync.
# This should only run for models that support pipelined model parallelism
# (BERT and GPT-2).
timers('backward-embedding-all-reduce').start()
if not args.deepspeed:
if (mpu.is_pipeline_first_stage(ignore_virtual=True) or
mpu.is_pipeline_last_stage(ignore_virtual=True)) and \
mpu.get_pipeline_model_parallel_world_size() > 1:
if mpu.is_pipeline_first_stage(ignore_virtual=True):
unwrapped_model = model[0]
elif mpu.is_pipeline_last_stage(ignore_virtual=True):
unwrapped_model = model[-1]
unwrapped_model = unwrap_model(
unwrapped_model, (torchDDP, LocalDDP, Float16Module))
if unwrapped_model.share_word_embeddings:
word_embeddings_weight = unwrapped_model.word_embeddings_weight()
if args.DDP_impl == 'local':
grad = word_embeddings_weight.main_grad
else:
grad = word_embeddings_weight.grad
torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
timers('backward-embedding-all-reduce').stop()
# Update parameters.
timers('optimizer').start()
if args.deepspeed:
increment = get_num_microbatches() * \
args.micro_batch_size * \
args.data_parallel_size
model[0].step(lr_kwargs={'increment': increment})
update_successful = model[0].was_step_applied()
else:
update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
timers('optimizer').stop()
# Update learning rate.
if args.deepspeed:
skipped_iter = 0
grad_norm = None
num_zeros_in_grad = None
else:
if update_successful:
increment = get_num_microbatches() * \
args.micro_batch_size * \
args.data_parallel_size
lr_scheduler.step(increment=increment)
skipped_iter = 0
else:
skipped_iter = 1
if mpu.is_pipeline_last_stage(ignore_virtual=True):
# Average loss across microbatches.
loss_reduced = {}
for key in losses_reduced[0]:
losses_reduced_for_key = [x[key] for x in losses_reduced]
loss_reduced[key] = sum(losses_reduced_for_key) / len(losses_reduced_for_key)
return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad
return {}, skipped_iter, grad_norm, num_zeros_in_grad
def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
loss_scale, report_memory_flag, skipped_iter,
grad_norm, params_norm, num_zeros_in_grad,
model=None):
"""Log training information such as losses, timing, ...."""
args = get_args()
timers = get_timers()
writer = get_tensorboard_writer()
# Advanced, skipped, and Nan iterations.
advanced_iters_key = 'advanced iterations'
skipped_iters_key = 'skipped iterations'
nan_iters_key = 'nan iterations'
# Advanced iterations.
if not skipped_iter:
total_loss_dict[advanced_iters_key] = total_loss_dict.get(
advanced_iters_key, 0) + 1
else:
if advanced_iters_key not in total_loss_dict:
total_loss_dict[advanced_iters_key] = 0
# Skipped iterations.
total_loss_dict[skipped_iters_key] = total_loss_dict.get(
skipped_iters_key, 0) + skipped_iter
# Update losses and set nan iterations
got_nan = False
for key in loss_dict:
if not skipped_iter:
total_loss_dict[key] = total_loss_dict.get(
key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]
else:
value = loss_dict[key].float().sum().item()
is_nan = value == float('inf') or \
value == -float('inf') or \
value != value
got_nan = got_nan or is_nan
total_loss_dict[nan_iters_key] = total_loss_dict.get(
nan_iters_key, 0) + int(got_nan)
# Logging.
timers_to_log = []
def add_to_logging(name):
if name in timers.timers:
timers_to_log.append(name)
add_to_logging('forward-compute')
add_to_logging('forward-recv')
add_to_logging('forward-send')
add_to_logging('forward-backward-send-forward-backward-recv')
add_to_logging('backward-compute')
add_to_logging('backward-recv')
add_to_logging('backward-send')
add_to_logging('backward-send-forward-recv')
add_to_logging('backward-send-backward-recv')
add_to_logging('backward-params-all-reduce')
add_to_logging('backward-embedding-all-reduce')
add_to_logging('optimizer-copy-to-main-grad')
add_to_logging('optimizer-unscale-and-check-inf')
add_to_logging('optimizer-clip-main-grad')
add_to_logging('optimizer-copy-main-to-model-params')
add_to_logging('optimizer')
add_to_logging('batch-generator')
# Calculate batch size.
batch_size = args.micro_batch_size * args.data_parallel_size * \
get_num_microbatches()
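# E.g. (hypothetical values) micro_batch_size=2, data_parallel_size=8 and
# 32 micro-batches give a global batch size of 2 * 8 * 32 = 512 samples.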
total_iterations = total_loss_dict[advanced_iters_key] + \
total_loss_dict[skipped_iters_key]
# Tensorboard values.
if writer and (iteration % args.tensorboard_log_interval == 0) and \
is_last_rank():
writer.add_scalar('steps-vs-samples/y=steps,x=samples', iteration, args.consumed_train_samples)
writer.add_scalar('steps-vs-samples/y=samples,x=steps', args.consumed_train_samples, iteration)
writer.add_scalar('steps-vs-tokens/y=steps,x=tokens', iteration, args.consumed_train_tokens)
writer.add_scalar('steps-vs-tokens/y=tokens,x=steps', args.consumed_train_tokens, iteration)
if args.log_learning_rate_to_tensorboard:
writer.add_scalar('learning-rate/learning-rate', learning_rate, iteration)
writer.add_scalar('learning-rate/learning-rate vs samples', learning_rate,
args.consumed_train_samples)
writer.add_scalar('learning-rate/learning-rate vs tokens', learning_rate,
args.consumed_train_tokens)
if args.log_batch_size_to_tensorboard:
writer.add_scalar('batch-size/batch-size', batch_size, iteration)
writer.add_scalar('batch-size/batch-size vs samples', batch_size,
args.consumed_train_samples)
for key in loss_dict:
writer.add_scalar(f"lm-loss-training/{key}", loss_dict[key], iteration)
writer.add_scalar(f"lm-loss-training/{key}" + ' vs samples', loss_dict[key],
args.consumed_train_samples)
writer.add_scalar(f"lm-loss-training/{key}" + ' vs tokens', loss_dict[key],
args.consumed_train_tokens)
writer.add_scalar(f"lm-loss-training/{key}" + ' vs gigaflos (without embeddings)', loss_dict[key],
args.gigaflos_no_embeds)
if args.log_loss_scale_to_tensorboard and args.fp16:
writer.add_scalar('loss-scale/loss-scale', loss_scale, iteration)
writer.add_scalar('loss-scale/loss-scale vs samples', loss_scale,
args.consumed_train_samples)
writer.add_scalar('loss-scale/loss-scale vs tokens', loss_scale,
args.consumed_train_tokens)
if grad_norm is not None:
writer.add_scalar('grad-norm/grad-norm', grad_norm, iteration)
writer.add_scalar('grad-norm/grad-norm vs samples', grad_norm,
args.consumed_train_samples)
writer.add_scalar('grad-norm/grad-norm vs tokens', grad_norm,
args.consumed_train_tokens)
if num_zeros_in_grad is not None:
writer.add_scalar('num-zeros/num-zeros', num_zeros_in_grad, iteration)
writer.add_scalar('num-zeros/num-zeros vs samples', num_zeros_in_grad,
args.consumed_train_samples)
writer.add_scalar('num-zeros/num-zeros vs tokens', num_zeros_in_grad,
args.consumed_train_tokens)
if params_norm is not None:
writer.add_scalar('params-norm/params-norm', params_norm, iteration)
writer.add_scalar('params-norm/params-norm vs samples', params_norm,
args.consumed_train_samples)
writer.add_scalar('params-norm/params-norm vs tokens', params_norm,
args.consumed_train_tokens)
if args.curriculum_learning:
writer.add_scalar('curriculum_seqlen', args.curriculum_seqlen,
iteration)
# It's very questionable what this data contributes, other than huge unstripped file paths
# as keys and hundreds of TB boards that make the TB files very bloated. So disabling for now.
#
# if args.data_weights is not None:
# for prefix, weight in zip(args.data_prefixes, args.data_weights):
# name = prefix.split(",")[-1]
# writer.add_scalar(f'samples-per-dataset/{name}', args.consumed_train_samples * weight, args.consumed_train_samples)
# writer.add_scalar(f'steps-per-dataset/{name}', iteration * weight, iteration)
# writer.add_scalar(f'tokens-per-dataset/{name}', args.consumed_train_tokens * weight, args.consumed_train_tokens)
if args.log_timers_to_tensorboard:
timers.write(timers_to_log, writer, iteration,
normalizer=total_iterations)
if iteration % args.log_interval == 0:
elapsed_time = timers('interval-time').elapsed()
elapsed_time_per_iteration = elapsed_time / total_iterations
seq_len = args.curriculum_seqlen if args.curriculum_learning else args.seq_length
hidden_size = args.hidden_size
num_layers = args.num_layers
vocab_size = args.padded_vocab_size
# Compute throughput.
samples_per_sec = batch_size / elapsed_time_per_iteration
samples_per_sec_per_replica = samples_per_sec / args.data_parallel_size
tokens_per_sec = samples_per_sec * seq_len
tokens_per_sec_per_replica = tokens_per_sec / args.data_parallel_size
# General TFLOPs formula (borrowed from Equation 3 in Section 5.1 of
# https://arxiv.org/pdf/2104.04473.pdf).
# The factor of 4 applies when activation checkpointing is used; otherwise it is 3.
# For a 200B-parameter model, activation checkpointing is always on.
checkpoint_activations_factor = 4 if args.checkpoint_activations else 3
# GLU activations double the hidden states in the upscaling feed-forward in each transformer layer
# This changes the first feed-forward projection in each MLP from 8bsh^2 to 16bsh^2 FLOPs, so the coefficient increases by 8 (from 24 to 32).
# Refer to https://github.com/bigscience-workshop/Megatron-DeepSpeed/pull/283#issue-1260805063 for more details.
coefficient = 32 if args.glu_activation else 24
flops_per_iteration = (coefficient * checkpoint_activations_factor * batch_size * seq_len * num_layers * (hidden_size**2)) * (1. + (seq_len / (6. * hidden_size)) + (vocab_size / (16. * num_layers * hidden_size)))
tflops = flops_per_iteration / (elapsed_time_per_iteration * args.world_size * (10**12))
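# In symbols, the estimate above is
#   flops/iter = c * f * B * s * L * h^2 * (1 + s / (6h) + V / (16 * L * h))
# with c the GLU-aware coefficient (24 or 32), f the activation-checkpointing
# factor (3 or 4), B the global batch size, s the sequence length, L the number
# of layers, h the hidden size and V the padded vocab size; TFLOPs per GPU then
# divides by the iteration time, the world size and 10^12.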
# only the last rank process has a non-None _GLOBAL_TENSORBOARD_WRITER
if writer and is_last_rank():
if args.log_timers_to_tensorboard:
writer.add_scalar('iteration-time/iteration-time',
elapsed_time_per_iteration, iteration)
writer.add_scalar('iteration-time/iteration-time vs samples',
elapsed_time_per_iteration, args.consumed_train_samples)
writer.add_scalar('iteration-time/iteration-time vs tokens',
elapsed_time_per_iteration, args.consumed_train_tokens)
writer.add_scalar('iteration-time/samples per second',
samples_per_sec, args.iteration)
writer.add_scalar('iteration-time/samples per second per replica',
samples_per_sec_per_replica, args.iteration)
writer.add_scalar('iteration-time/tokens per second',
tokens_per_sec, args.iteration)
writer.add_scalar('iteration-time/tokens per second per replica',
tokens_per_sec_per_replica, args.iteration)
writer.add_scalar('iteration-time/TFLOPs per gpu (estimated)',
tflops, args.iteration)
log_string = ' iteration {:8d}/{:8d} |'.format(
iteration, args.train_iters)
log_string += ' consumed samples: {:12d} |'.format(
args.consumed_train_samples)
log_string += ' consumed tokens: {:12d} |'.format(
args.consumed_train_tokens)
log_string += ' elapsed time per iteration (s): {:.2f} |'.format(
elapsed_time_per_iteration)
log_string += ' learning rate: {:.3E} |'.format(learning_rate)
log_string += ' global batch size: {:5d} |'.format(batch_size)
for key in total_loss_dict:
if key not in [advanced_iters_key, skipped_iters_key,
nan_iters_key]:
avg = total_loss_dict[key].item() / \
float(max(1, total_loss_dict[advanced_iters_key]))
if avg > 0.0:
log_string += ' {}: {:.6E} |'.format(key, avg)
total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
if args.fp16:
log_string += ' loss scale: {:.1f} |'.format(loss_scale)
if grad_norm is not None:
log_string += ' grad norm: {:.3f} |'.format(grad_norm)
if num_zeros_in_grad is not None:
log_string += ' num zeros: {:.1f} |'.format(num_zeros_in_grad)
if params_norm is not None:
log_string += ' params norm: {:.3f} |'.format(params_norm)
if args.curriculum_learning:
log_string += ' curriculum seqlen: {:5d} |'.format(args.curriculum_seqlen)
log_string += ' number of skipped iterations: {:3d} |'.format(
total_loss_dict[skipped_iters_key])
log_string += ' number of nan iterations: {:3d} |'.format(
total_loss_dict[nan_iters_key])
log_string += ' samples per second: {:.3f} |'.format(samples_per_sec)
log_string += ' TFLOPs: {:.2f} |'.format(tflops)
total_loss_dict[advanced_iters_key] = 0
total_loss_dict[skipped_iters_key] = 0
total_loss_dict[nan_iters_key] = 0
print_rank_last(log_string)
if report_memory_flag and learning_rate > 0.:
# Report memory after optimizer state has been initialized.
report_memory('(after {} iterations)'.format(iteration))
report_memory_flag = False
timers.log(timers_to_log, normalizer=args.log_interval)
flops_calculator(model, args, elapsed_time)
return report_memory_flag
def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):
timers = get_timers()
# Extra barrier is added to make sure
# all ranks report the max time.
torch.distributed.barrier()
timers('save-checkpoint').start()
save_checkpoint(iteration, model, optimizer, lr_scheduler)
torch.distributed.barrier()
timers('save-checkpoint').stop()
timers.log(['save-checkpoint'])
def train(forward_step_func, model, optimizer, lr_scheduler,
train_data_iterator, valid_data_iterator):
"""Train the model function."""
args = get_args()
timers = get_timers()
if args.rank == 0:
print("Number of parameters: [tensor rank - pipeline rank] w/ and w/o embeddings:")
torch.distributed.barrier()
if mpu.get_data_parallel_rank() == 0:
tp_rank = mpu.get_tensor_model_parallel_rank()
pp_rank = mpu.get_pipeline_model_parallel_rank()
preamble = f"[{tp_rank:0>3d}-{pp_rank:0>3d}]"
print(f"{preamble} {get_parameters_in_billions(model):.4f}B / {get_parameters_in_billions(model, exclude_embeddings=True):.4f}B", flush=True)
torch.distributed.barrier()
else:
torch.distributed.barrier()
# Write args to tensorboard
write_args_to_tensorboard()
log_restart_to_tensorboard()
# Turn on training mode which enables dropout.
for model_module in model:
model_module.train()
# Tracking loss.
total_loss_dict = {}
# Iterations.
iteration = args.iteration
timers('interval-time').start()
print_datetime('before the start of training step')
report_memory_flag = True
# flush intervals prior to current iteration
if args.skip_train_iteration_range is not None:
ends = [end for start, end in args.skip_train_iteration_range]
index = bisect.bisect_left(ends, iteration)
for _ in range(index):
args.skip_train_iteration_range.popleft()
while iteration < args.train_iters:
if (
# train_data_iterator is not None
args.skip_train_iteration_range is not None
and len(args.skip_train_iteration_range) > 0
and args.skip_train_iteration_range[0][0] <= iteration + 1 <= args.skip_train_iteration_range[0][1]
):
start, end = args.skip_train_iteration_range.popleft()
print_rank_0(f"Skipped iterations {start} to {end} due to --skip-train-iteration-range flag.")
iteration_for_skipping = args.iteration
while iteration_for_skipping + 1 <= end:
try:
_ = next(train_data_iterator)
except TypeError:
pass
iteration_for_skipping += 1
continue
if found_kill_switch():
save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler)
print_datetime(f"Detected kill switch at {args.kill_switch_path}. Exiting")
sys.exit()
update_num_microbatches(args.consumed_train_samples)
if args.deepspeed:
# inform deepspeed of any batch size changes
global_batch_size = mpu.get_data_parallel_world_size() * \
args.micro_batch_size * \
get_num_microbatches()
model[0].set_train_batch_size(global_batch_size)
if args.curriculum_learning and \
args.pipeline_model_parallel_size >= 1:
args.curriculum_seqlen = args.curriculum_scheduler.update_difficulty( \
args.iteration + 1)
loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \
train_step(forward_step_func,
train_data_iterator,
model,
optimizer,
lr_scheduler)
iteration += 1
args.iteration = iteration
new_samples = mpu.get_data_parallel_world_size() * \
args.micro_batch_size * \
get_num_microbatches()
args.consumed_train_samples += new_samples
if args.curriculum_learning:
args.consumed_train_tokens += new_samples * args.curriculum_seqlen
else:
args.consumed_train_tokens += new_samples * args.seq_length
args.gigaflos_no_embeds += (6 * new_samples * args.seq_length * get_parameters_in_billions(model, exclude_embeddings=True))
# Logging.
loss_scale = None
if args.fp16:
if args.deepspeed:
loss_scale = model[0].optimizer.cur_scale
else:
loss_scale = optimizer.get_loss_scale().item()
params_norm = None
if args.log_params_norm:
params_norm = calc_params_l2_norm(model)
report_memory_flag = training_log(loss_dict, total_loss_dict,
optimizer.param_groups[0]['lr'],
iteration, loss_scale,
report_memory_flag, skipped_iter,
grad_norm, params_norm, num_zeros_in_grad,
model)
# Autoresume
if args.adlr_autoresume and \
(iteration % args.adlr_autoresume_interval == 0):
check_adlr_autoresume_termination(iteration, model, optimizer,
lr_scheduler)
# Evaluation
if args.eval_interval and iteration % args.eval_interval == 0 and \
args.do_valid:
prefix = 'iteration {}'.format(iteration)
names = args.valid_weighted_split_names
names = names if names is not None else ['valid'] * len(valid_data_iterator)
for iterator, name in zip(valid_data_iterator, names):
evaluate_and_print_results(prefix, forward_step_func,
iterator, model,
iteration, False, data_group_name=name)
# Checkpointing
saved_checkpoint = False
if args.save and args.save_interval and \
iteration % args.save_interval == 0:
save_checkpoint_and_time(iteration, model, optimizer,
lr_scheduler)
saved_checkpoint = True
# Exiting based on duration
if args.exit_duration_in_mins:
train_time = (time.time() - _TRAIN_START_TIME) / 60.0
done_cuda = torch.cuda.IntTensor(
[train_time > args.exit_duration_in_mins])
torch.distributed.all_reduce(
done_cuda, op=torch.distributed.ReduceOp.MAX)
done = done_cuda.item()
if done:
if not saved_checkpoint:
save_checkpoint_and_time(iteration, model, optimizer,
lr_scheduler)
print_datetime('exiting program after {} minutes'.format(train_time))
sys.exit()
# Exiting based on iterations
if args.exit_interval and iteration % args.exit_interval == 0:
if not saved_checkpoint:
save_checkpoint_and_time(iteration, model, optimizer,
lr_scheduler)
torch.distributed.barrier()
print_datetime('exiting program at iteration {}'.format(iteration))
sys.exit()
return iteration
def evaluate(forward_step_func, data_iterator, model, verbose=False):
"""Evaluation."""
args = get_args()
# Turn on evaluation mode which disables dropout.
for model_module in model:
model_module.eval()
if args.curriculum_learning and \
args.pipeline_model_parallel_size >= 1:
# When curriculum learning is used with pipeline parallelism, we need
# this logic to ensure that the eval data is not truncated. If there
# is a seqlen change due to that, we need to call
# reset_activation_shape() to reset some buffers in deepspeed pipeline
# engine.
if args.curriculum_seqlen < args.seq_length:
args.curriculum_seqlen = args.seq_length
model[0].reset_activation_shape()
total_loss_dict = {}
with torch.no_grad():
iteration = 0
while iteration < args.eval_iters:
iteration += 1
if verbose and iteration % args.log_interval == 0:
print_rank_0('Evaluating iter {}/{}'.format(iteration,
args.eval_iters))
if mpu.get_pipeline_model_parallel_world_size() > 1:
if args.virtual_pipeline_model_parallel_size is not None:
forward_backward_func = forward_backward_pipelining_with_interleaving
else:
forward_backward_func = forward_backward_pipelining_without_interleaving
else:
forward_backward_func = forward_backward_no_pipelining
if args.deepspeed:
# DeepSpeed uses eval_batch() and already aggregates losses.
assert isinstance(model, list) and len(model) == 1
loss = model[0].eval_batch(data_iterator)
loss_dicts = [{'lm loss' : loss}] * get_num_microbatches()
else:
loss_dicts = forward_backward_func(
forward_step_func, data_iterator, model, optimizer=None,
timers=None, forward_only=True)
if mpu.is_pipeline_last_stage(ignore_virtual=True):
# Reduce across processes.
for loss_dict in loss_dicts:
for key in loss_dict:
total_loss_dict[key] = total_loss_dict.get(
key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]
args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
* args.micro_batch_size \
* get_num_microbatches()
# Move model back to the train mode.
for model_module in model:
model_module.train()
for key in total_loss_dict:
total_loss_dict[key] /= args.eval_iters * get_num_microbatches()
if args.curriculum_learning and \
args.pipeline_model_parallel_size >= 1:
# roll back to actual curriculum seqlen at the end of eval.
args.curriculum_seqlen = args.curriculum_scheduler.update_difficulty( \
args.iteration + 1)
if args.curriculum_seqlen < args.seq_length:
model[0].reset_activation_shape()
return total_loss_dict
class Testdataset(Dataset):
def __init__(self, img_paths, crop_size=224, color_jitter=True):
self.imgs = img_paths
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
fp16_t = transforms.ConvertImageDtype(torch.half)
self.transforms = transforms.Compose(
[
transforms.Resize(crop_size),
transforms.CenterCrop(crop_size),
transforms.ToTensor(),
normalize,
fp16_t
]
)
def __getitem__(self, index):
img_path = self.imgs[index]
label = torch.from_numpy(np.array(0)).half()
data_path = img_path
pil_img = Image.open(img_path).convert('RGB')
data = self.transforms(pil_img)
return data, label, data_path
def __len__(self):
return len(self.imgs)
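# Usage sketch (the path below is hypothetical): Testdataset(["/data/img0.png"])
# yields, per item, a (3, crop_size, crop_size) float16 image tensor, a dummy
# half-precision label of 0 and the image path; test() below consumes it one
# sample at a time through a batch-size-1 DataLoader.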
def test(forward_step_func, data_iterator, model, verbose=False):
"""Test."""
args = get_args()
test_data_path = os.path.join(args.data_path[0], "test/images")
img_paths = sorted(glob.glob(test_data_path + "/*"))
test_dataset = Testdataset(img_paths=img_paths)
data_iterator = iter(DataLoader(test_dataset, batch_size=1, shuffle=False))
# Turn on evaluation mode which disables dropout.
for model_module in model:
model_module.eval()
if args.curriculum_learning and \
args.pipeline_model_parallel_size >= 1:
# When curriculum learning is used with pipeline parallelism, we need
# this logic to ensure that the eval data is not truncated. If there
# is a seqlen change due to that, we need to call
# reset_activation_shape() to reset some buffers in deepspeed pipeline
# engine.
if args.curriculum_seqlen < args.seq_length:
args.curriculum_seqlen = args.seq_length
model[0].reset_activation_shape()
with torch.no_grad():
iteration = 0
while iteration < len(img_paths):  # iterate over all test images
iteration += 1
if verbose and iteration % args.log_interval == 0:
print_rank_0('Evaluating iter {}/{}'.format(iteration, args.eval_iters))
if mpu.get_pipeline_model_parallel_world_size() > 1:
if args.virtual_pipeline_model_parallel_size is not None:
forward_backward_func = forward_backward_pipelining_with_interleaving
else:
forward_backward_func = forward_backward_pipelining_without_interleaving
else:
forward_backward_func = forward_backward_no_pipelining
if args.deepspeed:
# DeepSpeed uses eval_batch() and already aggregates losses.
assert isinstance(model, list) and len(model) == 1
data_path = next(data_iterator)[2][0]
logits = model[0].eval_batch(data_iterator, compute_loss = False, reduce_output = None)
logits = torch.cat(logits, 0)
outputs = torch.argmax(logits, -1)[0]
if args.rank == 0:
print(data_path,': ',outputs.cpu().numpy())
else:
data = next(data_iterator)
data_path = data[2][0]
images = data[0].cuda()
logits = model[0](images).contiguous().float()
outputs = torch.argmax(logits, -1)[0]
if args.rank == 0:
print(data_path,': ',outputs.cpu().numpy())
print('the end of training for test data')
def evaluate_and_print_results(prefix, forward_step_func,
data_iterator, model,
iteration, verbose=False, **kwargs):
"""Helper function to evaluate and dump results on screen."""
args = get_args()
writer = get_tensorboard_writer()
ds_name = kwargs.get("data_group_name", None)
# print corresponding dataset name (used for multiple validation datasets)
tf_plot_prefix = f"lm-loss-validation/{ds_name}" if ds_name else "lm-loss-validation"
data_iterator = iter(next(iter(data_iterator)))
total_loss_dict = evaluate(forward_step_func, data_iterator, model, verbose)
string = '{} loss at {} | '.format(ds_name, prefix) if ds_name is not None\
else 'validation loss at {} | '.format(prefix)
for key in total_loss_dict:
string += '{} value: {:.6E} | '.format(key, total_loss_dict[key].item())
ppl = math.exp(min(20, total_loss_dict[key].item()))
string += '{} PPL: {:.6E} | '.format(key, ppl)
if writer and is_last_rank():
writer.add_scalar(f'{tf_plot_prefix}/{key} validation',
total_loss_dict[key].item(),
iteration)
writer.add_scalar(f'{tf_plot_prefix}/{key} validation vs samples',
total_loss_dict[key].item(),
args.consumed_train_samples)
writer.add_scalar(f'{tf_plot_prefix}/{key} validation vs tokens',
total_loss_dict[key].item(),
args.consumed_train_tokens)
writer.add_scalar(f'{tf_plot_prefix}/{key} validation vs gigaflos (without embeddings)',
total_loss_dict[key].item(),
args.gigaflos_no_embeds)
if args.log_validation_ppl_to_tensorboard:
writer.add_scalar(f'{tf_plot_prefix}/{key} validation ppl', ppl,
iteration)
writer.add_scalar(f'{tf_plot_prefix}/{key} validation ppl vs samples',
ppl, args.consumed_train_samples)
writer.add_scalar(f'{tf_plot_prefix}/{key} validation ppl vs tokens',
ppl, args.consumed_train_tokens)
writer.add_scalar(f'{tf_plot_prefix}/{key} validation ppl vs gigaflos (without embeddings)',
ppl, args.gigaflos_no_embeds)
length = len(string) + 1
print_rank_last('-' * length)
print_rank_last(string)
print_rank_last('-' * length)
def cyclic_iter(iter):
while True:
for x in iter:
yield x
def build_train_valid_test_data_iterators(
build_train_valid_test_datasets_provider):
"""XXX"""
args = get_args()
(train_dataloader, valid_dataloaders, test_dataloaders) = (None, None, None)
print_rank_0('> building train, validation, and test datasets ...')
# Backward compatibility, assume fixed batch size.
if args.iteration > 0 and args.consumed_train_samples == 0:
assert args.train_samples is None, \
'only backward compatibility support for iteration-based training'
args.consumed_train_samples = args.iteration * args.global_batch_size
# it's possible that train was run, but not eval and it's valid if
# args.consumed_valid_samples == 0
# TODO: eval_interval could have changed between runs, so this might still be wrong
if args.iteration // args.eval_interval > 0 and args.consumed_valid_samples == 0:
assert args.train_samples is None, \
'only backward compatibility support for iteration-based training'
args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
args.eval_iters * args.global_batch_size
# Data loader only on rank 0 of each model parallel group.
if mpu.get_tensor_model_parallel_rank() == 0:
# Number of train/valid/test samples.
if args.train_samples:
train_samples = args.train_samples
else:
train_samples = args.train_iters * args.global_batch_size
eval_iters = (args.train_iters // args.eval_interval + 1) * \
args.eval_iters
test_iters = args.eval_iters
train_val_test_num_samples = [train_samples,
eval_iters * args.global_batch_size,
test_iters * args.global_batch_size]
print_rank_0(' > datasets target sizes (minimum size):')
print_rank_0(' train: {}'.format(train_val_test_num_samples[0]))
print_rank_0(' validation: {}'.format(train_val_test_num_samples[1]))
print_rank_0(' test: {}'.format(train_val_test_num_samples[2]))
# Build the datasets.
train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(train_val_test_num_samples)
# If the dataloading option is not 2, convert to lists so that the same
# interface can be used for the multiple validation and test data groups
# of option 2.
if type(train_ds) != list and train_ds is not None:
train_ds = [train_ds]
if type(valid_ds) != list and valid_ds is not None:
valid_ds = [valid_ds]
if type(test_ds) != list and test_ds is not None:
test_ds = [test_ds]
# Build dataloaders.
assert len(train_ds) == 1, "only one training dataset group is allowed"
# train_dataloader is a single item while valid_dataloaders
# and test_dataloaders are arrays
train_dataloader = build_pretraining_data_loader(
train_ds[0], args.consumed_train_samples)
# We collapse None and empty list as both should mean we don't run validation
# args.consumed_valid_samples accumulates the sum of valid steps for every dataset, which are all equal
#
# XXX: we get a deadlock in the dataloader on multi-dataset eval, after the first dataset,
# possibly due to this bug in pytorch https://github.com/pytorch/pytorch/pull/25158. Using
# num_workers=0 to work around it - the training can't use that since it impacts throughput
# by a few percent
valid_dataloaders = [build_pretraining_data_loader(d, args.consumed_valid_samples // len(valid_ds), num_workers=args.valid_num_workers)
for d in valid_ds] \
if valid_ds is not None else []
# We collapse None and empty list as both should mean we don't run test
test_dataloaders = [build_pretraining_data_loader(d, 0) for d in test_ds] \
if test_ds is not None else []
# Flags to know if we need to do training/validation/testing.
do_train = train_dataloader is not None and args.train_iters > 0 and not args.eval_only
# Need to broadcast the do-train flag and the numbers of validation/test dataloaders.
flags = torch.cuda.LongTensor([
int(do_train),
len(valid_dataloaders) if args.eval_iters > 0 else 0, # eval_iters == 0 is equivalent to having no validation
len(test_dataloaders) if args.eval_iters > 0 else 0, # eval_iters == 0 is equivalent to having no test
])
else:
flags = torch.cuda.LongTensor([0, 0, 0])
# Broadcast the flags.
torch.distributed.broadcast(flags,
mpu.get_tensor_model_parallel_src_rank(),
group=mpu.get_tensor_model_parallel_group())
args.do_train = flags[0].item()
num_valid_ds = flags[1].item()
num_test_ds = flags[2].item()
assert num_test_ds >= 0
assert num_valid_ds >= 0
args.do_valid = num_valid_ds > 0
args.do_test = num_test_ds > 0
# Build iterators.
dl_type = args.dataloader_type
assert dl_type in ['single', 'cyclic']
if train_dataloader is not None:
train_data_iterator = iter(train_dataloader) if dl_type in ['single'] \
else iter(cyclic_iter(train_dataloader))
else:
train_data_iterator = None
if valid_dataloaders is not None:
valid_data_iterators = [iter(vdl) if dl_type in ['single'] \
else iter(cyclic_iter(vdl))
for vdl in valid_dataloaders]
else:
valid_data_iterators = [None] * num_valid_ds
if test_dataloaders is not None:
test_data_iterators = [iter(tdl) if dl_type in ['single'] \
else iter(cyclic_iter(tdl))
for tdl in test_dataloaders]
else:
test_data_iterators = [None] * num_test_ds
return train_data_iterator, valid_data_iterators, test_data_iterators
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""General utilities."""
import os
import sys
import warnings
from random import randint
import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel as torchDDP
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
from megatron import get_args, logging
from megatron import print_rank_0
from megatron import get_adlr_autoresume
from megatron import mpu
from megatron.model.module import param_is_not_shared
from megatron.model.utils import log_debug_usage
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate, VocabParallelEmbedding
from megatron import get_num_microbatches
logger = logging.get_logger(__name__)
def unwrap_model(model, module_instances=(torchDDP)):
return_list = True
if not isinstance(model, list):
model = [model]
return_list = False
unwrapped_model = []
for model_module in model:
while isinstance(model_module, module_instances):
model_module = model_module.module
unwrapped_model.append(model_module)
if not return_list:
return unwrapped_model[0]
return unwrapped_model
def calc_params_l2_norm(model):
"""Calculate l2 norm of parameters """
args = get_args()
if not isinstance(model, list):
model = [model]
# Remove duplicate params.
params_data = []
for model_ in model:
for param in model_.parameters():
is_not_shared = param_is_not_shared(param)
is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
if is_not_shared and is_not_tp_duplicate:
if args.bf16:
params_data.append(param.data.float())
else:
params_data.append(param.data)
# Calculate norm
dummy_overflow_buf = torch.cuda.IntTensor([0])
norm, _ = multi_tensor_applier(
amp_C.multi_tensor_l2norm,
dummy_overflow_buf,
[params_data],
False # no per-parameter norm
)
norm_2 = norm * norm
# Sum across all model-parallel GPUs.
torch.distributed.all_reduce(norm_2,
op=torch.distributed.ReduceOp.SUM,
group=mpu.get_model_parallel_group())
return norm_2.item() ** 0.5
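# The reduction above computes the global norm as
#   sqrt(sum over model-parallel ranks of local_norm ** 2),
# i.e. each rank contributes the squared L2 norm of its own non-duplicated
# parameter shards before the all-reduce.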
def average_losses_across_data_parallel_group(losses):
"""Reduce a tensor of losses across all GPUs."""
averaged_losses = torch.cat(
[loss.clone().detach().view(1) for loss in losses])
torch.distributed.all_reduce(averaged_losses,
group=mpu.get_data_parallel_group())
averaged_losses = averaged_losses / \
torch.distributed.get_world_size(group=mpu.get_data_parallel_group())
return averaged_losses
def report_memory(name):
"""Simple GPU memory report."""
mega_bytes = 1024.0 * 1024.0
string = name + ' memory (MB)'
string += ' | allocated: {}'.format(
torch.cuda.memory_allocated() / mega_bytes)
string += ' | max allocated: {}'.format(
torch.cuda.max_memory_allocated() / mega_bytes)
string += ' | reserved: {}'.format(
torch.cuda.memory_reserved() / mega_bytes)
string += ' | max reserved: {}'.format(
torch.cuda.max_memory_reserved() / mega_bytes)
if mpu.get_data_parallel_rank() == 0:
print("[Rank {}] {}".format(torch.distributed.get_rank(), string),
flush=True)
def print_params_min_max_norm(optimizer, iteration):
"""Print min, max, and norm of all parameters."""
index = 0
rank = torch.distributed.get_rank()
string = 'iteration, rank, index, tensor-model-parallel, min, max, norm\n'
optimizer_ = optimizer.optimizer
for param_group in optimizer_.param_groups:
for param in param_group['params']:
index += 1
min_ = param.data.min()
max_ = param.data.max()
norm = torch.linalg.norm(param.data)
string += '{:7d}, {:4d}, {:4d}, {:2d}, '.format(
iteration, rank, index, int(param.tensor_model_parallel))
string += '{:.6E}, {:.6E}, {:.6E}\n'.format(min_, max_, norm)
print(string, flush=True)
def check_adlr_autoresume_termination(iteration, model,
optimizer, lr_scheduler):
"""Check for autoresume signal and exit if it is received."""
from megatron.checkpointing import save_checkpoint
args = get_args()
autoresume = get_adlr_autoresume()
    # Add barrier to ensure consistency.
torch.distributed.barrier()
if autoresume.termination_requested():
if args.save:
save_checkpoint(iteration, model, optimizer, lr_scheduler)
print_rank_0(">>> autoresume termination request found!")
if torch.distributed.get_rank() == 0:
autoresume.request_resume()
print_rank_0(">>> training terminated. Returning")
sys.exit(0)
def get_ltor_masks_and_position_ids(
data,
eod_token,
reset_position_ids,
reset_attention_mask,
eod_mask_loss,
prefix_indices,
loss_on_targets_only,
):
"""
    Build masks and position ids for a left-to-right model.
:param prefix_indices: argument can have multiple types:
- None signifies that the model is fully autoregressive.
- List[int] the argument holds all prefix indices that split a row into an input and a target
- List[List[int]] the argument holds all prefix indices that split documents between input and target.
    :param loss_on_targets_only: bool determining whether the loss on the prefix should be masked.
"""
# Extract batch size and sequence length.
micro_batch_size, seq_length = data.size()
# Attention mask (lower triangular).
if reset_attention_mask or prefix_indices is not None:
att_mask_batch = micro_batch_size
else:
att_mask_batch = 1
attention_mask = torch.tril(torch.ones(
(att_mask_batch, seq_length, seq_length), device=data.device)).view(
att_mask_batch, 1, seq_length, seq_length)
# Loss mask.
loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
if eod_mask_loss:
loss_mask[data == eod_token] = 0.0
# Position ids.
position_ids = torch.arange(seq_length, dtype=torch.long,
device=data.device)
position_ids = position_ids.unsqueeze(0).expand_as(data)
    # We need to clone as the ids will be modified based on batch index.
if reset_position_ids:
position_ids = position_ids.clone()
if reset_position_ids or reset_attention_mask or prefix_indices is not None:
# Loop through the batches:
for b in range(micro_batch_size):
            # Find indices where the EOD token is.
eod_index = position_ids[b, data[b] == eod_token]
            # If the last EOD token is not the last token of the sequence, we assume there is a partial document.
            # We treat this case as if an EOD token were appended at the end of the sequence.
if data[b][-1] != eod_token:
eod_index = torch.cat(
(eod_index, torch.tensor([len(data[b])], dtype=eod_index.dtype, device=eod_index.device))
)
            # Detach indices from positions if going to modify positions.
if reset_position_ids:
eod_index = eod_index.clone()
            # Loop through EOD indices:
prev_index = 0
for j in range(eod_index.size()[0]):
i = eod_index[j]
if reset_attention_mask:
# Prevent cross document interactions.
attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
# Prefix lm per document.
if prefix_indices:
assert isinstance(prefix_indices[b], list), f"prefix for a row has to be document specific, and consequently return a list, got {prefix_indices[b]}"
attention_mask[b, 0, prev_index: prefix_indices[b][j], prev_index: prefix_indices[b][j]] = 1
if loss_on_targets_only:
# Last token of the prefix should predict the prefix_index id
loss_mask[b, prev_index: prefix_indices[b][j] - 1] = 0.0
# Reset positions.
if reset_position_ids:
position_ids[b, (i + 1):] -= (i + 1 - prev_index)
prev_index = i + 1
# Prefix lm per row.
if prefix_indices is not None and (reset_attention_mask is False):
assert isinstance(prefix_indices[b], int), \
f"prefix for a row has to be row specific, and consequently return an int, got {prefix_indices[b]}"
attention_mask[b, 0, :prefix_indices[b], :prefix_indices[b]] = 1
if loss_on_targets_only:
# Last token of the prefix should predict the prefix_index id
loss_mask[b, :prefix_indices[b] - 1] = 0.0
# Convert attention mask to binary:
attention_mask = (attention_mask < 0.5)
return attention_mask, loss_mask, position_ids
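# Hedged usage sketch for get_ltor_masks_and_position_ids: the token ids and eod_token
# value below are assumptions chosen only to demonstrate the EOD-reset behaviour.
def _example_get_ltor_masks_and_position_ids():
    data = torch.tensor([[5, 3, 0, 7, 8]])  # an EOD (id 0) splits the row into two documents
    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
        data,
        eod_token=0,
        reset_position_ids=True,
        reset_attention_mask=True,
        eod_mask_loss=True,
        prefix_indices=None,
        loss_on_targets_only=False,
    )
    # Positions restart after the EOD token.
    assert position_ids.tolist() == [[0, 1, 2, 0, 1]]
    # The EOD token itself is excluded from the loss.
    assert loss_mask.tolist() == [[1.0, 1.0, 0.0, 1.0, 1.0]]
    # In the returned mask True means "masked out": token 3 cannot attend to token 1
    # across the document boundary, but still attends to itself.
    assert bool(attention_mask[0, 0, 3, 1]) and not bool(attention_mask[0, 0, 3, 3])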
def get_packed_attention_mask(is_causal: bool, causal_mask: torch.Tensor, decoder_is_inputs: torch.Tensor, segment_ids: torch.Tensor):
"""
Inspired by https://github.com/google-research/t5x/blob/7193407f98a8b18100b71a04ff777238be1682ca/t5x/examples/decoder_only/layers.py#L978
Arguments:
- is_causal: determines if the masking should be causal in the `inputs` part
- causal_mask: torch.BoolTensor [batch_size, sequence_length, sequence_length]
- decoder_is_inputs: torch.BoolTensor [batch_size, sequence_length]
- segment_ids: torch.IntTensor [batch_size, sequence_length]
Returns:
- attention_mask: torch.BoolTensor [batch_size, 1, sequence_length, sequence_length]
"""
"""Causal Inputs Mask:
mask = [[[[1, 1, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 0, 0],
[1, 1, 1, 1, 1, 0, 0],
[1, 1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 1, 1]]]]
"""
assert causal_mask.dtype == torch.bool
assert segment_ids.dtype == torch.long
if is_causal:
causal_inputs_mask = causal_mask
else:
assert decoder_is_inputs.dtype == torch.bool
inputs_mask = decoder_is_inputs[:, None, :, None] * decoder_is_inputs[:, None, None, :]
causal_inputs_mask = causal_mask + inputs_mask
"""Padding Mask:
mask = [[[[1, 1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 1, 0],
[0, 0, 0, 0, 0, 0, 0]]]]
"""
padding_mask = (segment_ids != 0)[:, None, :, None] * (segment_ids != 0)[:, None, None, :]
"""Segment Mask:
mask = [[[[1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0],
[0, 0, 0, 1, 1, 1, 0],
[0, 0, 0, 1, 1, 1, 0],
[0, 0, 0, 1, 1, 1, 0],
[0, 0, 0, 0, 0, 0, 0]]]]
"""
segment_mask = segment_ids[:, None, :, None] == segment_ids[:, None, None, :]
"""Final Mask:
mask = [[[[1, 1, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0],
[0, 0, 0, 1, 1, 0, 0],
[0, 0, 0, 1, 1, 0, 0],
[0, 0, 0, 1, 1, 1, 0],
[0, 0, 0, 0, 0, 0, 0]]]]
"""
attention_mask = causal_inputs_mask * padding_mask * segment_mask
# Convert attention mask to binary:
attention_mask = (attention_mask < 0.5)
return attention_mask
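# Hedged usage sketch for get_packed_attention_mask: one packed row containing two
# segments; the shapes, segment ids, and batch size below are assumptions for illustration.
def _example_get_packed_attention_mask():
    seq_length = 4
    causal_mask = torch.tril(torch.ones(1, 1, seq_length, seq_length)).bool()
    segment_ids = torch.tensor([[1, 1, 2, 2]])              # two packed documents
    decoder_is_inputs = torch.zeros(1, seq_length).bool()   # unused when is_causal=True
    mask = get_packed_attention_mask(
        is_causal=True,
        causal_mask=causal_mask,
        decoder_is_inputs=decoder_is_inputs,
        segment_ids=segment_ids,
    )
    # True means "masked out": token 2 (second segment) cannot attend to token 1
    # (first segment), while attention inside a segment stays causal.
    assert bool(mask[0, 0, 2, 1]) and not bool(mask[0, 0, 1, 0])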
def param_size(parameter):
return parameter.ds_numel if hasattr(parameter, 'ds_id') else parameter.nelement()
def unique_param_count(param_list):
# not actually deduplicating tied variables for now (which causes the PP > 1 double-counting bug)
return sum(dict((p.data_ptr(), param_size(p)) for p in param_list).values())
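# Hedged usage sketch for unique_param_count: tied parameters share storage, so they are
# counted once through their data_ptr; the toy linear layers below are assumptions.
def _example_unique_param_count():
    tied = nn.Linear(8, 8, bias=False)
    other = nn.Linear(8, 8, bias=False)
    other.weight = tied.weight  # tie the two weights to the same storage
    params = list(tied.parameters()) + list(other.parameters())
    # 8 * 8 = 64 elements, counted once even though the weight is listed twice.
    assert unique_param_count(params) == 64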
def non_embedding_params(module):
embedding_param_names = [
f"{name}.weight" for name, module_type in module.named_modules() if isinstance(module_type, nn.Embedding) or isinstance(module_type, VocabParallelEmbedding)
]
non_embedding_parameters = [
parameter for name, parameter in module.named_parameters() if name not in embedding_param_names
]
return unique_param_count(non_embedding_parameters)
def get_parameters_in_billions(model, exclude_embeddings=False):
gpus_per_model = torch.distributed.get_world_size(group=mpu.get_model_parallel_group())
if exclude_embeddings:
approx_parameters_in_billions = sum([non_embedding_params(model_module) for model_module in model])
else:
args = get_args()
if args.rank == 0:
warnings.warn("Parameter count with the embeddings will be inaccurate with PP > 1, as the first and last stage hold several copies of the embeddings")
approx_parameters_in_billions = unique_param_count([p for model_module in model for p in model_module.parameters()])
return approx_parameters_in_billions*gpus_per_model/(1e9)
def flops_calculator(model, args, iteration_time):
return # currently broken
gpus_per_model = torch.distributed.get_world_size(group = mpu.get_model_parallel_group())
approx_parameters_in_billions = get_parameters_in_billions(model)
batch_size = args.micro_batch_size * get_num_microbatches()
giga_flops_per_model_per_train_step = approx_parameters_in_billions * batch_size * args.seq_length * 2.0 * 4.0
effective_tera_flops_per_gpu = giga_flops_per_model_per_train_step / (iteration_time * 1000.0 * gpus_per_model)
print_rank_0(f"Effective Tera Flops per GPU: {round(effective_tera_flops_per_gpu, 2)} and total parameters {round(approx_parameters_in_billions, 3)} B")
def get_prefix_indices(data, eod_token, partial_prefix_indices, reset_attention_mask):
"""
Helper function in order to:
- randomly choose prefix index when there's no constraint
    - check that prefixes are compatible with the convention.
:param data: torch.Tensor
:param eod_token: int, token_id used to signal end of document
    :param partial_prefix_indices: this argument can have multiple types:
- None, it signals that all prefix indices are randomly sampled.
        - List[Optional[int]], its length has to be equal to the micro batch size. It stores the per-row prefix indices.
            Optional means that if set to None, we allow ourselves to sample one randomly.
- List[List[Optional[int]]], it follows the following rules:
            - The first dimension refers to the sample, i.e. len(partial_prefix_indices) == len(data)
            - The second dimension refers to the number of documents in that sample, i.e.
            len(partial_prefix_indices[b]) == (data[b] == eod_token).sum() (+1 for the last partial document).
            - partial_prefix_indices have to be interleaved with eod_indices, i.e.
            eod_indices[b][d-1] < partial_prefix_indices[b][d] < eod_indices[b][d] + 1, or be None.
            - Optional means that if set to None, we allow ourselves to sample one randomly.
:param reset_attention_mask: bool, determines if prefixes are to be per document or per row.
:return Depending if prefix is per document or per row, the method returns:
- List[List[int]]: prefix indices for each document in case of per document prefix
        - List[int]: prefix indices for each row otherwise.
"""
micro_batch_size, seq_length = data.size()
prefix_indices = []
assert partial_prefix_indices is None or len(partial_prefix_indices) == micro_batch_size, f"partial_prefix_indices has to be None or its length equal to {micro_batch_size}, got {len(partial_prefix_indices)}"
for batch_id in range(micro_batch_size):
# Prefix lm per document.
if reset_attention_mask:
prefix_indices.append([])
# Compute the index of all eod tokens in data.
eod_indices = (data[batch_id] == eod_token).nonzero().squeeze(-1)
            # If the last EOD token is not the last token of the sequence, we assume there is a partial document.
            # We treat this case as if an EOD token were appended at the end of the sequence.
if data[batch_id][-1] != eod_token:
eod_indices = torch.cat(
(eod_indices,
torch.tensor([len(data[batch_id])], dtype=eod_indices.dtype, device=eod_indices.device))
)
prev_index = 0
assert partial_prefix_indices is None or len(partial_prefix_indices[batch_id]) == len(eod_indices), f"The number of prefixes has to match the number of documents, complete or partial. Got {len(partial_prefix_indices[batch_id])} prefixes and {len(eod_indices)} documents"
for doc_id, eod_index in enumerate(eod_indices):
                assert partial_prefix_indices is None or isinstance(partial_prefix_indices[batch_id], list), f"Per document prefix has to store a list of indices for each row, got {partial_prefix_indices[batch_id]}"
# Prefix index is defined as the first index that isn't attended by all tokens in a document
if partial_prefix_indices is None or partial_prefix_indices[batch_id][doc_id] is None:
# We need to randomly generate a prefix index that satisfies the interleave condition in the docstring
prefix_index = randint(prev_index + 1, eod_index)
else:
# We get value from partial_prefix_indices, and run validation on that value
prefix_index = partial_prefix_indices[batch_id][doc_id]
                    assert prev_index + 1 <= prefix_index <= eod_index, f"Prefix index needs to lie between document indices, {prev_index + 1} <= {prefix_index} <= {eod_index} should be True."
prefix_indices[batch_id].append(prefix_index)
prev_index = eod_index + 1
# Prefix lm per row.
else:
            assert partial_prefix_indices is None or isinstance(partial_prefix_indices[batch_id], int), \
                f"Per row prefix has to store an int for each row, got {partial_prefix_indices[batch_id]}"
# Prefix index is defined as the first index that isn't attended by all previous tokens in a document
prefix_index: int
if partial_prefix_indices is None or partial_prefix_indices[batch_id] is None:
# 0 being the first prefix index makes no sense since 0 always attends to itself, and there are no other tokens before.
prefix_index = randint(1, seq_length)
else:
# We get value from partial_prefix_indices, and run validation on that value
prefix_index = partial_prefix_indices[batch_id]
                assert 1 <= prefix_index <= seq_length, f"Prefix index needs to lie within the sequence, 1 <= {prefix_index} <= {seq_length} should be True."
prefix_indices.append(prefix_index)
return prefix_indices
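# Hedged usage sketch for get_prefix_indices: the token ids below are assumptions used
# only to demonstrate the per-row and per-document conventions described above.
def _example_get_prefix_indices():
    data = torch.tensor([[5, 3, 0, 7, 8]])  # EOD id 0 splits the row into two documents
    # Per-row prefix: one int per row, validated against [1, seq_length].
    assert get_prefix_indices(data, eod_token=0, partial_prefix_indices=[3],
                              reset_attention_mask=False) == [3]
    # Per-document prefix: unconstrained indices are sampled; each one has to fall
    # strictly after the previous document and at or before that document's EOD position.
    per_doc = get_prefix_indices(data, eod_token=0, partial_prefix_indices=None,
                                 reset_attention_mask=True)
    assert len(per_doc) == 1 and len(per_doc[0]) == 2
    assert 1 <= per_doc[0][0] <= 2 and 4 <= per_doc[0][1] <= 5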
@log_debug_usage(logger, "Using loss reweighting")
def reweight_loss_mask_(loss_mask: torch.Tensor, tokens: torch.Tensor):
"""Reweight loss mask in-place"""
_, seq_length = tokens.shape
weight_loss = torch.arange(seq_length, 0, -1, dtype=torch.float, device=loss_mask.device) / (seq_length + 1) * 2
# in-place operation
loss_mask *= weight_loss[None, :]
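# Hedged worked example for reweight_loss_mask_: a toy sequence length of 4, chosen only
# to show that earlier positions get larger weights and that the weights average to 1.
def _example_reweight_loss_mask_():
    tokens = torch.zeros(1, 4, dtype=torch.long)
    loss_mask = torch.ones(1, 4)
    reweight_loss_mask_(loss_mask, tokens)
    # weights are [4, 3, 2, 1] / (4 + 1) * 2 = [1.6, 1.2, 0.8, 0.4]
    assert torch.allclose(loss_mask, torch.tensor([[1.6, 1.2, 0.8, 0.4]]))
    assert abs(loss_mask.mean().item() - 1.0) < 1e-6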
def found_kill_switch():
args = get_args()
if args.kill_switch_path is not None and os.path.exists(args.kill_switch_path):
return True
else:
return False
def get_fingerprint_header():
return f"{'min':^13} {'max':^13} {'mean':^13} {'l2 norm':^12} metadata"
def get_fingerprint(p):
return f"{p.min():13.6e} {p.max():13.6e} {p.mean():13.6e} {p.norm():12.6e}"
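# Hedged usage sketch for get_fingerprint_header / get_fingerprint: the tensor and the
# "example.weight" label below are assumptions used only to show the debug line format.
def _example_get_fingerprint():
    p = torch.tensor([1.0, -2.0, 3.0])
    print(get_fingerprint_header())
    print(f"{get_fingerprint(p)} example.weight {tuple(p.shape)}")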
def dump_weights(preamble, iteration, model, optimizer, tensor=None):
tp_rank = mpu.get_tensor_model_parallel_rank()
pp_rank = mpu.get_pipeline_model_parallel_rank()
dp_rank = mpu.get_data_parallel_rank()
dp_size = mpu.get_data_parallel_world_size()
fn = f"debug-bf16-{iteration}-pp{pp_rank}-tp{tp_rank}-dp{dp_rank}-{preamble}.txt"
# only care for first and last pp stages and dp0 tp0
#if not (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()):
# return
#if not (tp_rank == 0 and dp_rank == 0):
# return
if tensor is not None:
orig_tensor = tensor
if hasattr(tensor, "_hp_param"):
numel = tensor._hp_param.numel() # // dp_size
tensor = tensor.flatten().narrow(0, 0, numel)
#print(fn)
with open(fn, "w") as fh:
fh.write(f"{get_fingerprint_header()}\n")
if tensor is not None:
fh.write(f"{get_fingerprint(tensor)} tensor {tensor.shape}\n")
else:
for n, p in model[0].named_parameters():
fh.write(f"{get_fingerprint(p)} {n} {p.shape}\n")
return
    # until we figure out how to dump the actual fp32 values, don't do this
fn = f"debug-fp32-{iteration}-pp{pp_rank}-tp{tp_rank}-dp{dp_rank}-{preamble}.txt"
with open(fn, "w") as fh:
fh.write(f"{get_fingerprint_header()}\n")
if tensor is not None:
tensor = orig_tensor
if hasattr(tensor, "_hp_param"):
fh.write(f"{get_fingerprint(tensor._hp_param)} tensor {tensor._hp_param.shape}\n")
#fh.write(f"{get_fingerprint(tensor._hp_grad)} tensor grad\n")
else:
fh.write(f"{get_fingerprint(tensor)} tensor {tensor.shape}\n")
#fh.write(f"{get_fingerprint(tensor.grad)} tensor grad\n")
else:
if hasattr(model[0].module.tied_modules, "embed"):
p = model[0].module.tied_modules.embed.word_embeddings.weight._hp_param
fh.write(f"{get_fingerprint(p)} module.tied_modules.embed.word_embeddings.weight._hp_param {p.shape}\n")
# for i, param_group in enumerate(optimizer.param_groups):
# fh.write(f"{get_fingerprint(optimizer.fp32_groups_flat_partition[i])} group={i}\n")
#fh.write(f"{i}={optimizer.fp32_groups_flat_partition[i]}\n")
# if mpu.is_pipeline_first_stage():
# x = optimizer.fp32_groups_flat_partition[0]
# fh.write(f"fp32={x[:402432]}\n")
# if mpu.is_pipeline_last_stage()):
# x = optimizer.fp32_groups_flat_partition[1]
# fh.write(f"fp32={x[-402432:]}\n")
# import os
# import socket
# hostname = socket.gethostname()
# pid = os.getpid()
# global_rank = torch.distributed.get_rank()
#fn = f"debug-{iteration}-pp{pp_rank}-tp{tp_rank}-dp{dp_rank}-global{global_rank}-{preamble}-{pid}.txt"
# Model code
modelCode=342
# Model name
modelName=megatron-deepspeed-vit_pytorch
# Model description
modelDescription=Transformer-based image classification algorithm
# Application scenarios
appScenario=inference,training,image classification
# Framework type
frameType=PyTorch
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain BERT"""
from functools import partial
import torch
import torch.nn.functional as F
from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import mpu
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model.bert_model import BertModel
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
print_rank_0('building BERT model ...')
args = get_args()
num_tokentypes = 2 if args.bert_binary_head else 0
model = BertModel(
num_tokentypes=num_tokentypes,
add_binary_head=args.bert_binary_head,
parallel_output=True,
pre_process=pre_process,
post_process=post_process)
return model
def get_batch(data_iterator):
"""Build the batch."""
# Items and their type.
keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask']
datatype = torch.int64
# Broadcast data.
if data_iterator is not None:
data = next(data_iterator)
else:
data = None
data_b = mpu.broadcast_data(keys, data, datatype)
# Unpack.
tokens = data_b['text'].long()
types = data_b['types'].long()
sentence_order = data_b['is_random'].long()
loss_mask = data_b['loss_mask'].float()
lm_labels = data_b['labels'].long()
padding_mask = data_b['padding_mask'].long()
return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask
def loss_func(loss_mask, sentence_order, output_tensor):
lm_loss_, sop_logits = output_tensor
lm_loss_ = lm_loss_.float()
loss_mask = loss_mask.float()
lm_loss = torch.sum(
lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
if sop_logits is not None:
sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
sentence_order.view(-1),
ignore_index=-1)
sop_loss = sop_loss.float()
loss = lm_loss + sop_loss
averaged_losses = average_losses_across_data_parallel_group(
[lm_loss, sop_loss])
return loss, {'lm loss': averaged_losses[0],
'sop loss': averaged_losses[1]}
else:
loss = lm_loss
averaged_losses = average_losses_across_data_parallel_group(
[lm_loss])
return loss, {'lm loss': averaged_losses[0]}
def forward_step(data_iterator, model):
"""Forward step."""
args = get_args()
timers = get_timers()
# Get the batch.
timers('batch-generator').start()
tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch(
data_iterator)
timers('batch-generator').stop()
if not args.bert_binary_head:
types = None
# Forward pass through the model.
output_tensor = model(tokens, padding_mask, tokentype_ids=types,
lm_labels=lm_labels)
return output_tensor, partial(loss_func, loss_mask, sentence_order)
def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build train, valid, and test datasets."""
args = get_args()
print_rank_0('> building train, validation, and test datasets '
'for BERT ...')
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
data_prefix=args.data_path,
data_impl=args.data_impl,
splits_string=args.split,
train_valid_test_num_samples=train_val_test_num_samples,
max_seq_length=args.seq_length,
masked_lm_prob=args.mask_prob,
short_seq_prob=args.short_seq_prob,
seed=args.seed,
skip_warmup=(not args.mmap_warmup),
binary_head=args.bert_binary_head)
print_rank_0("> finished creating BERT datasets ...")
return train_ds, valid_ds, test_ds
if __name__ == "__main__":
pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain GPT"""
import torch
from functools import partial
from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import get_tokenizer
from megatron import mpu
from megatron.data.gpt_dataset import build_train_valid_test_datasets, build_dataset_group
from megatron.enums import AttnMaskType
from megatron.model import GPTModel, GPTModelPipe
from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids, get_prefix_indices
from megatron.utils import average_losses_across_data_parallel_group
import deepspeed
from deepspeed.runtime.utils import see_memory_usage
import os
try:
from torch.distributed.elastic.multiprocessing.errors import record
except ImportError:
# noop
def record(fn):
return fn
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
print_rank_0('building GPT model ...')
see_memory_usage(f"Before Building Model", force=True)
args = get_args()
with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(),
remote_device=None if args.remote_device == 'none' else args.remote_device,
config_dict_or_path=args.deepspeed_config,
enabled=args.zero_stage == 3,
mpu=mpu):
if args.deepspeed:
args.pretrain_causal_attention = True
model = GPTModelPipe(
num_tokentypes=0,
parallel_output=True,
attn_mask_type=AttnMaskType.causal
)
# This is a hack to give us a reference to get_batch_pipe from within training.py
# We need to call model.set_batch_fn after deepspeed.initialize
model._megatron_batch_fn = get_batch_pipe
else:
model = GPTModel(
num_tokentypes=0,
parallel_output=True,
pre_process=pre_process,
post_process=post_process
)
see_memory_usage(f"After Building Model", force=True)
return model
def get_batch(data_iterator):
"""Generate a batch"""
args = get_args()
tokenizer = get_tokenizer()
# Items and their type.
keys = ['text']
datatype = torch.int64
# Broadcast data.
if data_iterator is not None:
data = next(data_iterator)
else:
data = None
data_b = mpu.broadcast_data(keys, data, datatype)
# Unpack.
tokens_ = data_b['text'].long()
labels = tokens_[:, 1:].contiguous()
tokens = tokens_[:, :-1].contiguous()
    # Get the masks and position ids.
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
tokens,
tokenizer.eod,
args.reset_position_ids,
args.reset_attention_mask,
args.eod_mask_loss,
prefix_indices=None,
loss_on_targets_only=args.loss_on_targets_only
)
return tokens, labels, loss_mask, attention_mask, position_ids
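# Hedged sketch of the token/label shift performed in get_batch above: the ids below are
# assumptions for illustration; real batches come from mpu.broadcast_data.
def _example_token_label_shift():
    tokens_ = torch.tensor([[10, 11, 12, 13]])
    labels = tokens_[:, 1:].contiguous()   # [[11, 12, 13]]
    tokens = tokens_[:, :-1].contiguous()  # [[10, 11, 12]]
    # Each position in `tokens` is trained to predict the next id, found in `labels`.
    assert tokens.tolist() == [[10, 11, 12]] and labels.tolist() == [[11, 12, 13]]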
def get_batch_pipe(data):
"""Modification of `get_batch` to work on `next(data_iterator)` instead of `data_iterator`"""
args = get_args()
tokenizer = get_tokenizer()
# Items and their type.
keys = ['text']
datatype = torch.int64
# Broadcast data.
data_b = mpu.broadcast_data(keys, data, datatype)
# Unpack.
tokens_ = data_b['text'].long()
labels = tokens_[:, 1:].contiguous()
tokens = tokens_[:, :-1].contiguous()
# Get the masks and position ids.
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
tokens,
tokenizer.eod,
args.reset_position_ids,
args.reset_attention_mask,
args.eod_mask_loss,
prefix_indices=None,
loss_on_targets_only=args.loss_on_targets_only
)
if args.curriculum_learning and args.curriculum_seqlen < tokens.size()[1]:
# seqlen-based curriculum learning
# tokens, position_ids, labels, loss_mask have size [batch size, seqlen]
tokens = tokens[:, :args.curriculum_seqlen].contiguous()
position_ids = position_ids[:, :args.curriculum_seqlen].contiguous()
labels = labels[:, :args.curriculum_seqlen].contiguous()
loss_mask = loss_mask[:, :args.curriculum_seqlen].contiguous()
return (tokens, position_ids, attention_mask), (labels, loss_mask)
def loss_func(loss_mask, output_tensor):
losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float()
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {'lm loss': averaged_loss[0]}
def forward_step(data_iterator, model):
"""Forward step."""
args = get_args()
timers = get_timers()
# Get the batch.
timers('batch-generator').start()
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
data_iterator)
timers('batch-generator').stop()
output_tensor = model(tokens, position_ids, attention_mask,
labels=labels)
if args.curriculum_learning and args.curriculum_seqlen < args.seq_length:
loss_mask = loss_mask[:, :args.curriculum_seqlen].contiguous()
return output_tensor, partial(loss_func, loss_mask)
def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build train, valid, and test datasets."""
args = get_args()
train_ds, valid_ds, test_ds = None, None, None
print_rank_0('> building train, validation, and test datasets for GPT ...')
# Option 1 of data loading using --data-path
if args.data_path:
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
data_prefix=args.data_path,
data_impl=args.data_impl,
splits_string=args.split,
train_valid_test_num_samples=train_val_test_num_samples,
seq_length=args.seq_length,
seed=args.seed,
skip_warmup=(not args.mmap_warmup))
# Option 2 of data loading using --(train|valid|test)-weighted-split-paths
elif args.train_weighted_split_paths:
assigned_train_valid_test = []
if args.train_weighted_split_paths is not None:
train_ds = []
assigned_train_valid_test.append("train")
if args.valid_weighted_split_paths is not None:
valid_ds = []
assigned_train_valid_test.append("valid")
if args.test_weighted_split_paths is not None:
test_ds = []
assigned_train_valid_test.append("test")
for s in assigned_train_valid_test:
data_groups = zip(eval(f"args.{s}_weighted_split_paths"),
eval(f"args.{s}_weighted_split_weights"),
eval(f"args.{s}_weighted_split_splits"),
eval(f"args.{s}_weighted_split_names"))
for paths, weights, splits, name in data_groups:
d = build_dataset_group(name, paths, weights, splits,
args.data_impl,
train_val_test_num_samples,
args.seq_length, args.seed,
(not args.mmap_warmup),
train_valid_test=s)
eval(f"{s}_ds").append(d)
else:
raise NotImplementedError("No dataloading argument passed")
print_rank_0("> finished creating GPT datasets ...")
return train_ds, valid_ds, test_ds
@record
def main():
pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
if __name__ == "__main__":
main()
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain BERT for Inverse Cloze Task"""
import math
import torch
import torch.distributed as dist
import torch.nn.functional as F
from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import mpu
from megatron.data.biencoder_dataset_utils import get_ict_batch
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model.biencoder_model import biencoder_model_provider
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
def pretrain_ict_model_provider():
args = get_args()
model = biencoder_model_provider(
only_context_model=False,
only_query_model=False,
biencoder_shared_query_context_model=\
args.biencoder_shared_query_context_model)
return model
def get_group_world_size_rank():
group = mpu.get_data_parallel_group()
rank = torch.distributed.get_rank(group=group)
world_size = torch.distributed.get_world_size(group=group)
return group, rank, world_size
class AllgatherFromDataParallelRegion(torch.autograd.Function):
@staticmethod
def forward(ctx, input_):
assert input_.dim() == 2
group, rank, world_size = get_group_world_size_rank()
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
tensor_list[rank] = input_
torch.distributed.all_gather(tensor_list, input_, group=group)
output = torch.cat(tensor_list, dim=0).contiguous()
return output
@staticmethod
def backward(ctx, grad_output):
group, rank, world_size = get_group_world_size_rank()
assert grad_output.shape[0] % world_size == 0
dim_size = grad_output.shape[0] // world_size
output_list = torch.split(grad_output, dim_size, dim=0)
# get chunk from this rank
output = output_list[rank].contiguous()
return output
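# Hedged sketch of the backward chunking used in AllgatherFromDataParallelRegion: the
# gathered gradient is split along dim 0 and each rank keeps its own chunk. The
# world_size/rank values and the toy gradient below are assumptions for illustration;
# the real class additionally requires an initialized torch.distributed process group.
def _example_allgather_backward_chunking():
    world_size, rank = 4, 1
    grad_output = torch.arange(8.0).view(8, 1)  # as if gathered from 4 ranks, 2 rows each
    chunk = torch.split(grad_output, grad_output.shape[0] // world_size, dim=0)[rank]
    assert chunk.tolist() == [[2.0], [3.0]]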
def forward_step(data_iterator, model, input_tensor):
"""Forward step."""
args = get_args()
timers = get_timers()
# Get the batch.
timers('batch-generator').start()
query_tokens, query_mask, \
context_tokens, context_mask, context_indices = get_ict_batch(data_iterator)
timers('batch-generator').stop()
# Query and Context Types
query_types = torch.cuda.LongTensor(*query_tokens.shape).fill_(0)
context_types = torch.cuda.LongTensor(*context_tokens.shape).fill_(0)
# Forward model.
query_logits, context_logits = model(query_tokens, query_mask,
query_types, context_tokens,
context_mask, context_types)
micro_batch_size = query_logits.shape[0]
# recall we assert that tensor_model_parallel_size == 1
assert mpu.get_tensor_model_parallel_world_size() == 1, \
"Model parallel size > 1 not supported for ICT"
global_batch_size = dist.get_world_size() * micro_batch_size
all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
all_context_logits = AllgatherFromDataParallelRegion.apply(context_logits)
# scores are inner products between query and context embeddings
retrieval_scores = torch.matmul(all_query_logits,
torch.transpose(all_context_logits, 0, 1))
# scaling the retriever scores
if args.retriever_score_scaling:
retrieval_scores = retrieval_scores / math.sqrt(args.hidden_size)
softmax_scores = F.log_softmax(retrieval_scores, dim=1)
sorted_vals, sorted_indices = torch.topk(softmax_scores,
k=softmax_scores.shape[1], sorted=True)
def topk_accuracy(k):
return torch.cuda.FloatTensor([sum([int(i in sorted_indices[i, :k]) \
for i in range(global_batch_size)]) / global_batch_size])
topk_accs = [topk_accuracy(int(k)) for k in args.retriever_report_topk_accuracies]
labels = torch.arange(global_batch_size).long().cuda()
loss = F.nll_loss(softmax_scores, labels, reduction='mean')
reduced_losses = average_losses_across_data_parallel_group([loss, *topk_accs])
# Scale the retrieval loss
loss = loss * mpu.get_data_parallel_world_size()
# create stats_dict with retrieval loss and all specified top-k accuracies
topk_acc_dict = {'top{}_acc'.format(k): v * 100 for k, v in \
zip(args.retriever_report_topk_accuracies, reduced_losses[1:])}
stats_dict = dict(loss=reduced_losses[0], **topk_acc_dict)
return loss, stats_dict
def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build train, valid and test datasets."""
args = get_args()
print_rank_0('> building train, validation, and test datasets '
'for BERT ICT...')
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
data_prefix=args.data_path,
data_impl=args.data_impl,
splits_string=args.split,
train_valid_test_num_samples=train_val_test_num_samples,
max_seq_length=args.seq_length,
masked_lm_prob=args.mask_prob,
short_seq_prob=args.short_seq_prob,
seed=args.seed,
skip_warmup=(not args.mmap_warmup),
binary_head=False,
dataset_type='ict')
print_rank_0("> finished creating BERT ICT datasets ...")
return train_ds, valid_ds, test_ds
if __name__ == "__main__":
pretrain(train_valid_test_datasets_provider,
pretrain_ict_model_provider,
forward_step,
args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain GPT"""
import torch
from functools import partial
from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import get_tokenizer
from megatron import mpu
from megatron.data.gpt_dataset import build_train_valid_test_datasets, build_dataset_group
from megatron.enums import AttnMaskType
from megatron.model import GPTModel, GPTModelPipe
from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids, get_prefix_indices, reweight_loss_mask_
from megatron.utils import average_losses_across_data_parallel_group
import deepspeed
from deepspeed.runtime.utils import see_memory_usage
import subprocess
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
print_rank_0('building GPT model ...')
see_memory_usage(f"Before Building Model", force=True)
args = get_args()
with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(),
remote_device=None if args.remote_device == 'none' else args.remote_device,
config_dict_or_path=args.deepspeed_config,
enabled=args.zero_stage == 3,
mpu=mpu):
if args.deepspeed:
model = GPTModelPipe(
num_tokentypes=0,
parallel_output=True,
attn_mask_type=AttnMaskType.prefix
)
# This is a hack to give us a reference to get_batch_pipe from within training.py
# We need to call model.set_batch_fn after deepspeed.initialize
model._megatron_batch_fn = get_batch_pipe
else:
model = GPTModel(
num_tokentypes=0,
parallel_output=True,
pre_process=pre_process,
post_process=post_process,
prefix_lm=True
)
see_memory_usage(f"After Building Model", force=True)
return model
def get_batch(data_iterator):
"""Generate a batch"""
args = get_args()
tokenizer = get_tokenizer()
# Items and their type.
keys = ['text']
datatype = torch.int64
# Broadcast data.
if data_iterator is not None:
data = next(data_iterator)
else:
data = None
data_b = mpu.broadcast_data(keys, data, datatype)
# Unpack.
tokens_ = data_b['text'].long()
labels = tokens_[:, 1:].contiguous()
tokens = tokens_[:, :-1].contiguous()
# Prefix
prefix_indices = get_prefix_indices(
tokens,
tokenizer.eod,
partial_prefix_indices=None,
reset_attention_mask=args.reset_attention_mask
)
    # Get the masks and position ids.
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
tokens,
tokenizer.eod,
args.reset_position_ids,
args.reset_attention_mask,
args.eod_mask_loss,
prefix_indices=prefix_indices,
loss_on_targets_only=args.loss_on_targets_only
)
# weight loss_mask
if args.reweight_loss_based_on_position_frequency:
reweight_loss_mask_(loss_mask, tokens)
return tokens, labels, loss_mask, attention_mask, position_ids
def get_batch_pipe(data):
"""Modification of `get_batch` to work on `next(data_iterator)` instead of `data_iterator`"""
args = get_args()
tokenizer = get_tokenizer()
# Items and their type.
keys = ['text']
datatype = torch.int64
# Broadcast data.
data_b = mpu.broadcast_data(keys, data, datatype)
# Unpack.
tokens_ = data_b['text'].long()
labels = tokens_[:, 1:].contiguous()
tokens = tokens_[:, :-1].contiguous()
# Prefix
prefix_indices = get_prefix_indices(
tokens,
tokenizer.eod,
partial_prefix_indices=None,
reset_attention_mask=args.reset_attention_mask
)
# Get the masks and position ids.
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
tokens,
tokenizer.eod,
args.reset_position_ids,
args.reset_attention_mask,
args.eod_mask_loss,
prefix_indices=prefix_indices,
loss_on_targets_only=args.loss_on_targets_only
)
# weight loss_mask
if args.reweight_loss_based_on_position_frequency:
reweight_loss_mask_(loss_mask, tokens)
return (tokens, position_ids, attention_mask), (labels, loss_mask), prefix_indices
def loss_func(loss_mask, output_tensor):
losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float()
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {'lm loss': averaged_loss[0]}
def forward_step(data_iterator, model):
"""Forward step."""
args = get_args()
timers = get_timers()
# Get the batch.
timers('batch-generator').start()
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
data_iterator)
timers('batch-generator').stop()
output_tensor = model(tokens, position_ids, attention_mask,
labels=labels)
return output_tensor, partial(loss_func, loss_mask)
def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build train, valid, and test datasets."""
args = get_args()
train_ds, valid_ds, test_ds = None, None, None
print_rank_0('> building train, validation, and test datasets for GPT ...')
# Option 1 of data loading using --data-path
if args.data_path:
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
data_prefix=args.data_path,
data_impl=args.data_impl,
splits_string=args.split,
train_valid_test_num_samples=train_val_test_num_samples,
seq_length=args.seq_length,
seed=args.seed,
skip_warmup=(not args.mmap_warmup))
# Option 2 of data loading using --(train|valid|test)-weighted-split-paths
elif args.train_weighted_split_paths:
assigned_train_valid_test = []
if args.train_weighted_split_paths is not None:
train_ds = []
assigned_train_valid_test.append("train")
if args.valid_weighted_split_paths is not None:
valid_ds = []
assigned_train_valid_test.append("valid")
if args.test_weighted_split_paths is not None:
test_ds = []
assigned_train_valid_test.append("test")
for s in assigned_train_valid_test:
data_groups = zip(eval(f"args.{s}_weighted_split_paths"),
eval(f"args.{s}_weighted_split_weights"),
eval(f"args.{s}_weighted_split_splits"),
eval(f"args.{s}_weighted_split_names"))
for paths, weights, splits, name in data_groups:
d = build_dataset_group(name, paths, weights, splits,
args.data_impl,
train_val_test_num_samples,
args.seq_length, args.seed,
(not args.mmap_warmup),
train_valid_test=s)
eval(f"{s}_ds").append(d)
else:
raise NotImplementedError("No dataloading argument passed")
print_rank_0("> finished creating GPT datasets ...")
return train_ds, valid_ds, test_ds
def command_exists(cmd):
result = subprocess.Popen(f'type {cmd}', stdout=subprocess.PIPE, shell=True)
return result.wait() == 0
def git_ds_info():
from deepspeed.env_report import main as ds_report
ds_report()
# Write out version/git info
git_hash_cmd = "git rev-parse --short HEAD"
git_branch_cmd = "git rev-parse --abbrev-ref HEAD"
if command_exists('git'):
try:
result = subprocess.check_output(git_hash_cmd, shell=True)
git_hash = result.decode('utf-8').strip()
result = subprocess.check_output(git_branch_cmd, shell=True)
git_branch = result.decode('utf-8').strip()
except subprocess.CalledProcessError:
git_hash = "unknown"
git_branch = "unknown"
else:
git_hash = "unknown"
git_branch = "unknown"
print(f'**** Git info for Megatron: git_hash={git_hash} git_branch={git_branch} ****')
if __name__ == "__main__":
git_ds_info()
pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain T5"""
from functools import partial
import torch
from megatron import (
get_args,
get_timers,
mpu,
print_rank_0
)
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model.t5_model import T5Model
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
assert pre_process and post_process, "T5 doesn't yet support pipelining"
print_rank_0('building T5 model ...')
model = T5Model(num_tokentypes=0,
parallel_output=True)
return model
def get_batch(data_iterator):
"""Build the batch."""
keys = ['text_enc', 'text_dec', 'labels', 'loss_mask',
'enc_mask', 'dec_mask', 'enc_dec_mask']
datatype = torch.int64
# Broadcast data.
if data_iterator is not None:
data = next(data_iterator)
else:
data = None
data_b = mpu.broadcast_data(keys, data, datatype)
# Unpack.
tokens_enc = data_b['text_enc'].long()
tokens_dec = data_b['text_dec'].long()
labels = data_b['labels'].long()
loss_mask = data_b['loss_mask'].float()
enc_mask = (data_b['enc_mask'] < 0.5)
dec_mask = (data_b['dec_mask'] < 0.5)
enc_dec_mask = (data_b['enc_dec_mask'] < 0.5)
return tokens_enc, tokens_dec, loss_mask, labels, \
enc_mask, dec_mask, enc_dec_mask
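# Hedged sketch of the mask convention used in get_batch above: masks arrive as 0/1
# tensors (1 = attend, 0 = padding is an assumption consistent with the `< 0.5` checks),
# and the comparison flips them into boolean "masked out" form.
def _example_mask_inversion():
    enc_mask_01 = torch.tensor([[1, 1, 0]])
    enc_mask = (enc_mask_01 < 0.5)  # True marks positions that must not be attended
    assert enc_mask.tolist() == [[False, False, True]]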
def loss_func(loss_mask, output_tensor):
lm_loss_, _ = output_tensor
lm_loss_ = lm_loss_.float()
lm_loss = torch.sum(
lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
loss = lm_loss
averaged_losses = average_losses_across_data_parallel_group([lm_loss])
return loss, {'lm loss': averaged_losses[0]}
def forward_step(data_iterator, model):
"""Forward step."""
args = get_args()
timers = get_timers()
# Get the batch.
timers('batch generator').start()
tokens_enc, tokens_dec, loss_mask, lm_labels, enc_mask, dec_mask, enc_dec_mask \
= get_batch(data_iterator)
timers('batch generator').stop()
    # Forward model with lm_labels.
output_tensor = model(tokens_enc,
tokens_dec,
enc_mask,
dec_mask,
enc_dec_mask,
tokentype_ids=None,
lm_labels=lm_labels)
return output_tensor, partial(loss_func, loss_mask)
def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build train, valid, and test datasets."""
args = get_args()
print_rank_0('> building train, validation, and test datasets '
'for T5 ...')
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
data_prefix=args.data_path,
data_impl=args.data_impl,
splits_string=args.split,
train_valid_test_num_samples=train_val_test_num_samples,
max_seq_length=args.encoder_seq_length,
max_seq_length_dec=args.decoder_seq_length,
masked_lm_prob=args.mask_prob,
short_seq_prob=args.short_seq_prob,
seed=args.seed,
skip_warmup=(not args.mmap_warmup),
dataset_type='t5')
print_rank_0("> finished creating T5 datasets ...")
return train_ds, valid_ds, test_ds
if __name__ == "__main__":
pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain GPT"""
import torch
import torch.nn.functional as F
from functools import partial
from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import get_tokenizer
from megatron import mpu
from megatron.data.vit_dataset import build_train_valid_datasets
from megatron.data.gpt_dataset import build_train_valid_test_datasets, build_dataset_group
from megatron.enums import AttnMaskType
from megatron.model import GPTModel, GPTModelPipe
from megatron.model.vit_model import VitModel, VitModelPipe
from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids, get_prefix_indices
from megatron.utils import average_losses_across_data_parallel_group
import deepspeed
from deepspeed.runtime.utils import see_memory_usage
import os
try:
from torch.distributed.elastic.multiprocessing.errors import record
except ImportError:
# noop
def record(fn):
return fn
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
print_rank_0('building GPT model ...')
see_memory_usage(f"Before Building Model", force=True)
args = get_args()
with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(),
remote_device=None if args.remote_device == 'none' else args.remote_device,
config_dict_or_path=args.deepspeed_config,
enabled=args.zero_stage == 3,
mpu=mpu):
if args.deepspeed:
args.pretrain_causal_attention = True
model = VitModelPipe(num_classes=args.num_classes, finetune=False, attn_mask_type=AttnMaskType.causal)
# This is a hack to give us a reference to get_batch_pipe from within training.py
# We need to call model.set_batch_fn after deepspeed.initialize
model._megatron_batch_fn = get_batch_pipe
else:
model = VitModel(num_classes=args.num_classes)
see_memory_usage(f"After Building Model", force=True)
return model
def get_batch(data_iterator):
"""Build the batch."""
if data_iterator is not None:
data = next(data_iterator)
# only data parallelism; no need for broadcast
images = data[0].cuda()
labels = data[1].cuda()
return images, labels
def get_batch_pipe(data):
images = data[0].cuda()
labels = data[1].cuda()
return (images), (labels)
def forward_step(data_iterator, model):
"""Forward step."""
timers = get_timers()
# Get the batch.
timers("batch-generator").start()
(
images,
labels,
) = get_batch(data_iterator)
timers("batch-generator").stop()
    # Forward model.
logits = model(images).contiguous().float()
loss = F.cross_entropy(logits, labels)
outputs = torch.argmax(logits, -1)
correct = (outputs == labels).float()
accuracy = torch.mean(correct)
averaged_loss = average_losses_across_data_parallel_group([loss, accuracy])
return loss, {"loss": averaged_loss[0], "accuracy": averaged_loss[1]}
def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build train, valid, and test datasets."""
args = get_args()
print_rank_0(
"> building train, validation, and test datasets " "for VIT ..."
)
train_ds, valid_ds, test_ds = build_train_valid_datasets(data_path=args.data_path)
print_rank_0("> finished creating VIT datasets ...")
return train_ds, valid_ds, test_ds
@record
def main():
pretrain(
train_valid_test_datasets_provider,
model_provider,
forward_step,
args_defaults={'dataloader_type': 'cyclic'}
)
if __name__ == "__main__":
main()