Commit 06fc51ce authored by Jared Casper

Merge branch 'main' into checkpoint_util

parents ec561daa 0ed2f6ac
...@@ -242,7 +242,7 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -242,7 +242,7 @@ class ColumnParallelLinear(torch.nn.Module):
input_size: first dimension of matrix A. input_size: first dimension of matrix A.
output_size: second dimension of matrix A. output_size: second dimension of matrix A.
bias: If true, add bias bias: If true, add bias
gather_output: If true, call all-gether on output and make Y available        gather_output: If true, call all-gather on output and make Y available
to all GPUs, otherwise, every GPU will have its output to all GPUs, otherwise, every GPU will have its output
which is Y_i = XA_i which is Y_i = XA_i
init_method: method to initialize weights. Note that bias is always set init_method: method to initialize weights. Note that bias is always set
......
...@@ -24,7 +24,6 @@ from torch import _C ...@@ -24,7 +24,6 @@ from torch import _C
from torch.cuda import _lazy_call, device as device_ctx_manager from torch.cuda import _lazy_call, device as device_ctx_manager
from torch.utils.checkpoint import detach_variable from torch.utils.checkpoint import detach_variable
from megatron import get_args
from megatron.memory import allocate_mem_buff from megatron.memory import allocate_mem_buff
from .initialize import get_data_parallel_rank from .initialize import get_data_parallel_rank
...@@ -88,18 +87,100 @@ def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False): ...@@ -88,18 +87,100 @@ def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
def gather_split_1d_tensor(tensor): def gather_split_1d_tensor(tensor):
"""Opposite of above function, gather values from model parallel ranks.""" """Opposite of above function, gather values from model parallel ranks."""
world_size = get_tensor_model_parallel_world_size() numel_gathered = torch.numel(tensor) * \
numel = torch.numel(tensor) get_tensor_model_parallel_world_size()
numel_gathered = world_size * numel
gathered = torch.empty(numel_gathered, dtype=tensor.dtype, gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
device=torch.cuda.current_device(), device=torch.cuda.current_device(),
requires_grad=False) requires_grad=False)
chunks = [gathered[i*numel:(i+1)*numel] for i in range(world_size)] # TODO: This API is experimental in pytorch (as of Feb 2022) and
torch.distributed.all_gather(chunks, tensor, # this might break in future pytorch releases. We chose this API
group=get_tensor_model_parallel_group()) # as opposed to torch.distributed.all_gather for efficiency reasons.
# This API calls directly NCCL all-gather versus the former does
# internal copies and can potentially cause slow down.
torch.distributed._all_gather_base(gathered, tensor,
group=get_tensor_model_parallel_group())
return gathered return gathered
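The TODO above explains the switch to torch.distributed._all_gather_base. For comparison, a minimal sketch of the equivalent gather built on the public torch.distributed.all_gather API, essentially the pattern the removed code used; the helper name and the explicit group/world_size arguments are illustrative, not part of this commit:

import torch
import torch.distributed as dist

def gather_split_1d_tensor_with_all_gather(tensor, group, world_size):
    # Gather equal-sized 1D chunks from every rank into one flat buffer.
    numel = tensor.numel()
    gathered = torch.empty(numel * world_size, dtype=tensor.dtype,
                           device=tensor.device, requires_grad=False)
    # One view per rank into the output buffer; all_gather copies each
    # rank's contribution into its slice. Those internal copies are the
    # slowdown the TODO refers to, which _all_gather_base avoids by
    # calling NCCL all-gather directly on the flat buffer.
    chunks = [gathered[i * numel:(i + 1) * numel] for i in range(world_size)]
    dist.all_gather(chunks, tensor, group=group)
    return gathered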
def _kernel_make_viewless_tensor(inp, requires_grad):
'''Make a viewless tensor.
View tensors have the undesirable side-effect of retaining a reference
to the originally-viewed tensor, even after manually setting the '.data'
field. This method creates a new tensor that links to the old tensor's
data, without linking the viewed tensor, referenced via the '._base'
field.
'''
out = torch.empty(
(1,),
dtype = inp.dtype,
device = inp.device,
requires_grad = requires_grad,
)
out.data = inp.data
return out
class MakeViewlessTensor(torch.autograd.Function):
'''
Autograd function to make a viewless tensor.
This function should be used in cases where the computation graph needs
to be propagated, but we only want a viewless tensor (e.g.,
ParallelTransformer's hidden_states). Call this function by passing
'keep_graph = True' to 'make_viewless_tensor()'.
'''
@staticmethod
def forward(ctx, inp, requires_grad):
return _kernel_make_viewless_tensor(inp, requires_grad)
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
def make_viewless_tensor(inp, requires_grad, keep_graph):
'''
Entry-point for creating viewless tensors.
This method should be used, rather than calling 'MakeViewlessTensor'
or '_kernel_make_viewless_tensor' directly. This method acts as a
switch for determining if an autograd function or a regular method
should be used to create the tensor.
'''
# return tensor as-is, if not a 'view'
if inp._base is None:
return inp
# create viewless tensor
if keep_graph:
return MakeViewlessTensor.apply(inp, requires_grad)
else:
return _kernel_make_viewless_tensor(inp, requires_grad)
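A minimal sketch of what make_viewless_tensor does, assuming the functions above are in scope; the tensor sizes are arbitrary and chosen only for illustration:

import torch

base = torch.randn(1024, 1024)
hidden = base.view(-1)                # a view: hidden._base is base
viewless = make_viewless_tensor(hidden, requires_grad=False, keep_graph=False)
assert viewless._base is None         # no hidden reference back to 'base'
assert viewless.data_ptr() == hidden.data_ptr()   # but the data is shared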
def assert_viewless_tensor(tensor, extra_msg = None):
'''Assert that a tensor is not a view (i.e., its '._base' field is
not set).'''
if isinstance(tensor, list):
[ assert_viewless_tensor(t) for t in tensor ]
return tensor
if not isinstance(tensor, torch.Tensor):
return tensor
assert tensor._base is None, (
"Ensure tensor._base is None before setting tensor.data or storing "
"tensor to memory buffer. Otherwise, a memory leak will occur (and "
"likely accumulate over iterations). %s"
) % extra_msg
return tensor
def safely_set_viewless_tensor_data(tensor, new_data_tensor):
'''Safely set tensor's '.data' field.
Check first that the tensor is viewless (i.e., '._base' not set). If not,
raise an exception.
'''
assert_viewless_tensor(tensor, extra_msg = "FYI, tensor._base has shape %s, and new_data_tensor has shape %s." % ("--" if tensor._base is None else tensor._base.shape, new_data_tensor.shape))
tensor.data = new_data_tensor
class CudaRNGStatesTracker: class CudaRNGStatesTracker:
"""Tracker for the cuda RNG states. """Tracker for the cuda RNG states.
...@@ -243,8 +324,9 @@ class CheckpointFunction(torch.autograd.Function): ...@@ -243,8 +324,9 @@ class CheckpointFunction(torch.autograd.Function):
# the chunk corresponding to the current rank. # the chunk corresponding to the current rank.
if distribute_checkpointed_activations: if distribute_checkpointed_activations:
ctx.input_0_shape = args[0].data.shape ctx.input_0_shape = args[0].data.shape
args[0].data = split_tensor_into_1d_equal_chunks(args[0].data, safely_set_viewless_tensor_data(
new_buffer=True) args[0],
split_tensor_into_1d_equal_chunks(args[0].data, new_buffer=True))
# Store everything. # Store everything.
ctx.save_for_backward(*args) ctx.save_for_backward(*args)
...@@ -258,8 +340,9 @@ class CheckpointFunction(torch.autograd.Function): ...@@ -258,8 +340,9 @@ class CheckpointFunction(torch.autograd.Function):
"please use .backward() if possible") "please use .backward() if possible")
inputs = ctx.saved_tensors inputs = ctx.saved_tensors
if ctx.distribute_checkpointed_activations: if ctx.distribute_checkpointed_activations:
inputs[0].data = gather_split_1d_tensor(inputs[0].data) safely_set_viewless_tensor_data(
inputs[0].data = inputs[0].data.view(ctx.input_0_shape) inputs[0],
gather_split_1d_tensor(inputs[0].data).view(ctx.input_0_shape))
# Store the current states. # Store the current states.
bwd_cpu_rng_state = torch.get_rng_state() bwd_cpu_rng_state = torch.get_rng_state()
......
...@@ -23,35 +23,68 @@ from .grad_scaler import ConstantGradScaler, DynamicGradScaler ...@@ -23,35 +23,68 @@ from .grad_scaler import ConstantGradScaler, DynamicGradScaler
from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer
def _get_params_for_weight_decay_optimization(modules): def get_param_groups(modules,
"""Divide params into with-weight-decay and without-weight-decay groups. no_weight_decay_cond,
Layernorms and biases will have no weight decay but the rest will.        scale_lr_cond,
lr_mult):
"""creates param groups based on weight decay condition (regularized vs non regularized)
and learning rate scale condition (args.lr vs lr_mult * args.lr)
scale_lr_cond is used during finetuning where the head of the network requires a scaled
version of the base learning rate.
""" """
wd_no_scale_lr = []
weight_decay_params = {'params': []} wd_scale_lr = []
no_weight_decay_params = {'params': [], 'weight_decay': 0.0} no_wd_no_scale_lr = []
no_wd_scale_lr = []
for module in modules: for module in modules:
for module_ in module.modules(): for name, param in module.named_parameters():
if isinstance(module_, LayerNorm): if not param.requires_grad:
no_weight_decay_params['params'].extend( continue
[p for p in list(module_._parameters.values())
if p is not None]) if no_weight_decay_cond is not None:
no_wd = no_weight_decay_cond(name, param)
else: else:
weight_decay_params['params'].extend( # do not regularize biases nor Norm parameters
[p for n, p in list(module_._parameters.items()) no_wd = name.endswith(".bias") or len(param.shape) == 1
if p is not None and n != 'bias'])
no_weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items())
if p is not None and n == 'bias'])
return weight_decay_params, no_weight_decay_params if scale_lr_cond is not None:
scale_lr = scale_lr_cond(name, param)
else:
scale_lr = False
if not no_wd and not scale_lr:
wd_no_scale_lr.append(param)
elif not no_wd and scale_lr:
wd_scale_lr.append(param)
elif no_wd and not scale_lr:
no_wd_no_scale_lr.append(param)
else:
no_wd_scale_lr.append(param)
def get_megatron_optimizer(model): param_groups = []
if len(wd_no_scale_lr):
param_groups.append({'params': wd_no_scale_lr, 'wd_mult': 1.0, 'lr_mult': 1.0})
if len(wd_scale_lr):
param_groups.append({'params': wd_scale_lr, 'wd_mult': 1.0, 'lr_mult': lr_mult})
if len(no_wd_no_scale_lr):
param_groups.append({'params': no_wd_no_scale_lr, 'wd_mult': 0.0, 'lr_mult': 1.0})
if len(no_wd_scale_lr):
param_groups.append({'params': no_wd_scale_lr, 'wd_mult': 0.0, 'lr_mult': lr_mult})
return param_groups
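For reference, a hedged sketch of how a caller might use the new hooks exposed through get_megatron_optimizer (signature just below), e.g. to keep weight decay off biases and norm parameters and to scale the learning rate of a task head during finetuning; the predicate bodies and the 'head.' prefix are assumptions for illustration, not taken from this commit:

def no_wd_cond(name, param):
    # Skip regularization for biases and 1-D (norm) parameters.
    return name.endswith(".bias") or len(param.shape) == 1

def scale_head_lr_cond(name, param):
    # Hypothetical: parameters under a 'head.' prefix get lr_mult * lr.
    return name.startswith("head.")

optimizer = get_megatron_optimizer(model,
                                   no_weight_decay_cond=no_wd_cond,
                                   scale_lr_cond=scale_head_lr_cond,
                                   lr_mult=0.1)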
def get_megatron_optimizer(model,
no_weight_decay_cond=None,
scale_lr_cond=None,
lr_mult=1.0):
args = get_args() args = get_args()
# Base optimizer. # Base optimizer.
param_groups = _get_params_for_weight_decay_optimization(model) param_groups = get_param_groups(model,
no_weight_decay_cond,
scale_lr_cond,
lr_mult)
if args.optimizer == 'adam': if args.optimizer == 'adam':
optimizer = Adam(param_groups, optimizer = Adam(param_groups,
lr=args.lr, lr=args.lr,
......
...@@ -13,19 +13,20 @@ ...@@ -13,19 +13,20 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Learning rate decay functions.""" """Learning rate decay and weight decay incr functions."""
import math import math
from megatron import print_rank_0 from megatron import print_rank_0
class AnnealingLR(object): class OptimizerParamScheduler(object):
"""Anneals the learning rate.""" """Anneals learning rate and weight decay"""
def __init__(self, optimizer, max_lr, min_lr, def __init__(self, optimizer, max_lr, min_lr,
warmup_steps, decay_steps, decay_style, lr_warmup_steps, lr_decay_steps, lr_decay_style,
use_checkpoint_lr_scheduler=True, start_wd, end_wd, wd_incr_steps, wd_incr_style,
override_lr_scheduler=False): use_checkpoint_opt_param_scheduler=True,
override_opt_param_scheduler=False):
# Class values. # Class values.
self.optimizer = optimizer self.optimizer = optimizer
...@@ -35,24 +36,55 @@ class AnnealingLR(object): ...@@ -35,24 +36,55 @@ class AnnealingLR(object):
assert self.min_lr >= 0.0 assert self.min_lr >= 0.0
assert self.max_lr >= self.min_lr assert self.max_lr >= self.min_lr
self.warmup_steps = warmup_steps self.lr_warmup_steps = lr_warmup_steps
self.num_steps = 0 self.num_steps = 0
self.decay_steps = decay_steps self.lr_decay_steps = lr_decay_steps
assert self.decay_steps > 0 assert self.lr_decay_steps > 0
assert self.warmup_steps < self.decay_steps assert self.lr_warmup_steps < self.lr_decay_steps
self.decay_style = decay_style self.lr_decay_style = lr_decay_style
self.override_lr_scheduler = override_lr_scheduler self.start_wd = start_wd
self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler self.end_wd = end_wd
if self.override_lr_scheduler: assert self.start_wd >= 0.0
assert not self.use_checkpoint_lr_scheduler, 'both override and '\ assert self.end_wd >= self.start_wd
self.wd_incr_steps = wd_incr_steps
self.wd_incr_style = wd_incr_style
self.override_opt_param_scheduler = override_opt_param_scheduler
self.use_checkpoint_opt_param_scheduler = use_checkpoint_opt_param_scheduler
if self.override_opt_param_scheduler:
assert not self.use_checkpoint_opt_param_scheduler, 'both override and '\
'use-checkpoint are set.' 'use-checkpoint are set.'
# Set the learning rate # Set the learning rate
self.step(0) self.step(0)
print_rank_0('> learning rate decay style: {}'.format(self.lr_decay_style))
def get_wd(self):
""" Weight decay incr functions"""
if self.num_steps > self.wd_incr_steps:
return self.end_wd
if self.wd_incr_style == 'constant':
assert self.start_wd == self.end_wd
return self.end_wd
print_rank_0('> learning rate decay style: {}'.format(self.decay_style)) incr_ratio = float(self.num_steps) / float(self.wd_incr_steps)
assert incr_ratio >= 0.0
assert incr_ratio <= 1.0
delta_wd = self.end_wd - self.start_wd
if self.wd_incr_style == 'linear':
coeff = incr_ratio
elif self.wd_incr_style == 'cosine':
coeff = 0.5 * (math.cos(math.pi * (1 - incr_ratio)) + 1.0)
else:
raise Exception('{} weight decay increment style is not supported.'.format(
self.wd_incr_style))
return self.start_wd + coeff * delta_wd
def get_lr(self): def get_lr(self):
...@@ -60,33 +92,33 @@ class AnnealingLR(object): ...@@ -60,33 +92,33 @@ class AnnealingLR(object):
https://openreview.net/pdf?id=BJYwwY9ll pg. 4""" https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""
# Use linear warmup for the initial part. # Use linear warmup for the initial part.
if self.warmup_steps > 0 and self.num_steps <= self.warmup_steps: if self.lr_warmup_steps > 0 and self.num_steps <= self.lr_warmup_steps:
return self.max_lr * float(self.num_steps) / \ return self.max_lr * float(self.num_steps) / \
float(self.warmup_steps) float(self.lr_warmup_steps)
# If the learning rate is constant, just return the initial value. # If the learning rate is constant, just return the initial value.
if self.decay_style == 'constant': if self.lr_decay_style == 'constant':
return self.max_lr return self.max_lr
# For any steps larger than `self.decay_steps`, use `self.min_lr`. # For any steps larger than `self.lr_decay_steps`, use `self.min_lr`.
if self.num_steps > self.decay_steps: if self.num_steps > self.lr_decay_steps:
return self.min_lr return self.min_lr
# If we are done with the warmup period, use the decay style. # If we are done with the warmup period, use the decay style.
num_steps_ = self.num_steps - self.warmup_steps num_steps_ = self.num_steps - self.lr_warmup_steps
decay_steps_ = self.decay_steps - self.warmup_steps decay_steps_ = self.lr_decay_steps - self.lr_warmup_steps
decay_ratio = float(num_steps_) / float(decay_steps_) decay_ratio = float(num_steps_) / float(decay_steps_)
assert decay_ratio >= 0.0 assert decay_ratio >= 0.0
assert decay_ratio <= 1.0 assert decay_ratio <= 1.0
delta_lr = self.max_lr - self.min_lr delta_lr = self.max_lr - self.min_lr
if self.decay_style == 'linear': if self.lr_decay_style == 'linear':
coeff = (1.0 - decay_ratio) coeff = (1.0 - decay_ratio)
elif self.decay_style == 'cosine': elif self.lr_decay_style == 'cosine':
coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0) coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
else: else:
raise Exception('{} decay style is not supported.'.format( raise Exception('{} decay style is not supported.'.format(
self.decay_style)) self.lr_decay_style))
return self.min_lr + coeff * delta_lr return self.min_lr + coeff * delta_lr
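As a quick numeric check of the two cosine formulas above (a standalone sketch, no scheduler object needed): halfway through the schedule both coefficients equal 0.5, so the learning rate sits midway between min_lr and max_lr and the weight decay midway between start_wd and end_wd.

import math

max_lr, min_lr = 1.0e-4, 1.0e-5
start_wd, end_wd = 0.0, 0.1

decay_ratio = 0.5                                             # halfway through lr decay
lr_coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)      # 0.5
lr = min_lr + lr_coeff * (max_lr - min_lr)                    # 5.5e-05

incr_ratio = 0.5                                              # halfway through wd increase
wd_coeff = 0.5 * (math.cos(math.pi * (1 - incr_ratio)) + 1.0) # 0.5
wd = start_wd + wd_coeff * (end_wd - start_wd)                # 0.05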
...@@ -95,18 +127,24 @@ class AnnealingLR(object): ...@@ -95,18 +127,24 @@ class AnnealingLR(object):
"""Set lr for all parameters groups.""" """Set lr for all parameters groups."""
self.num_steps += increment self.num_steps += increment
new_lr = self.get_lr() new_lr = self.get_lr()
new_wd = self.get_wd()
for group in self.optimizer.param_groups: for group in self.optimizer.param_groups:
group['lr'] = new_lr group['lr'] = new_lr * group.get('lr_mult', 1.0)
group['weight_decay'] = new_wd * group.get('wd_mult', 1.0)
def state_dict(self): def state_dict(self):
state_dict = { state_dict = {
'max_lr': self.max_lr, 'max_lr': self.max_lr,
'warmup_steps': self.warmup_steps, 'lr_warmup_steps': self.lr_warmup_steps,
'num_steps': self.num_steps, 'num_steps': self.num_steps,
'decay_style': self.decay_style, 'lr_decay_style': self.lr_decay_style,
'decay_steps': self.decay_steps, 'lr_decay_steps': self.lr_decay_steps,
'min_lr': self.min_lr 'min_lr': self.min_lr,
'start_wd': self.start_wd,
'end_wd': self.end_wd,
'wd_incr_style': self.wd_incr_style,
'wd_incr_steps': self.wd_incr_steps
} }
return state_dict return state_dict
...@@ -114,13 +152,13 @@ class AnnealingLR(object): ...@@ -114,13 +152,13 @@ class AnnealingLR(object):
def _check_and_set(self, cls_value, sd_value, name): def _check_and_set(self, cls_value, sd_value, name):
"""Auxiliary function for checking the values in the checkpoint and """Auxiliary function for checking the values in the checkpoint and
setting them.""" setting them."""
if self.override_lr_scheduler: if self.override_opt_param_scheduler:
print_rank_0(' > overriding {} value to {}'.format(name, cls_value)) print_rank_0(' > overriding {} value to {}'.format(name, cls_value))
return cls_value return cls_value
if not self.use_checkpoint_lr_scheduler: if not self.use_checkpoint_opt_param_scheduler:
assert cls_value == sd_value, \ assert cls_value == sd_value, \
f'AnnealingLR: class input value {cls_value} and checkpoint' \ f'OptimizerParamScheduler: class input value {cls_value} and checkpoint' \
f' value {sd_value} for {name} do not match' f' value {sd_value} for {name} do not match'
print_rank_0(' > using checkpoint value {} for {}'.format(sd_value, print_rank_0(' > using checkpoint value {} for {}'.format(sd_value,
name)) name))
...@@ -140,25 +178,57 @@ class AnnealingLR(object): ...@@ -140,25 +178,57 @@ class AnnealingLR(object):
'minimum learning rate') 'minimum learning rate')
if 'warmup_iter' in sd: if 'warmup_iter' in sd:
warmup_steps_ = sd['warmup_iter'] lr_warmup_steps_ = sd['warmup_iter']
elif 'warmup_steps' in sd:
lr_warmup_steps_ = sd['warmup_steps']
else: else:
warmup_steps_ = sd['warmup_steps'] lr_warmup_steps_ = sd['lr_warmup_steps']
self.warmup_steps = self._check_and_set(self.warmup_steps, self.lr_warmup_steps = self._check_and_set(self.lr_warmup_steps,
warmup_steps_, lr_warmup_steps_,
'warmup iterations') 'warmup iterations')
if 'end_iter' in sd: if 'end_iter' in sd:
decay_steps_ = sd['end_iter'] lr_decay_steps_ = sd['end_iter']
elif 'decay_steps' in sd:
lr_decay_steps_ = sd['decay_steps']
else: else:
decay_steps_ = sd['decay_steps'] lr_decay_steps_ = sd['lr_decay_steps']
self.decay_steps = self._check_and_set(self.decay_steps, decay_steps_, self.lr_decay_steps = self._check_and_set(self.lr_decay_steps, lr_decay_steps_,
'total number of iterations') 'total number of iterations')
self.decay_style = self._check_and_set(self.decay_style,
sd['decay_style'], if 'decay_style' in sd:
'decay style') lr_decay_style_ = sd['decay_style']
else:
lr_decay_style_ = sd['lr_decay_style']
self.lr_decay_style = self._check_and_set(self.lr_decay_style,
lr_decay_style_,
'learning rate decay style')
if 'num_iters' in sd: if 'num_iters' in sd:
num_steps = sd['num_iters'] num_steps = sd['num_iters']
else: else:
num_steps = sd['num_steps'] num_steps = sd['num_steps']
self.step(increment=num_steps) self.step(increment=num_steps)
if 'start_wd' in sd:
self.start_wd = self._check_and_set(self.start_wd,
sd['start_wd'],
"start weight decay")
self.end_wd = self._check_and_set(self.end_wd,
sd['end_wd'],
"end weight decay")
self.wd_incr_steps = self._check_and_set(self.wd_incr_steps,
sd['wd_incr_steps'],
"total number of weight decay iterations")
self.wd_incr_style = self._check_and_set(self.wd_incr_style,
sd['wd_incr_style'],
"weight decay incr style")
...@@ -142,10 +142,16 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next, ...@@ -142,10 +142,16 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
if recv_prev: if recv_prev:
tensor_recv_prev = mpu.gather_split_1d_tensor( tensor_recv_prev = mpu.gather_split_1d_tensor(
tensor_recv_prev).view(tensor_shape).requires_grad_() tensor_recv_prev).view(tensor_shape).requires_grad_()
tensor_recv_prev = mpu.make_viewless_tensor(tensor_recv_prev,
requires_grad = True,
keep_graph = False)
if recv_next: if recv_next:
tensor_recv_next = mpu.gather_split_1d_tensor( tensor_recv_next = mpu.gather_split_1d_tensor(
tensor_recv_next).view(tensor_shape).requires_grad_() tensor_recv_next).view(tensor_shape).requires_grad_()
tensor_recv_next = mpu.make_viewless_tensor(tensor_recv_next,
requires_grad = True,
keep_graph = False)
return tensor_recv_prev, tensor_recv_next return tensor_recv_prev, tensor_recv_next
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
from contextlib import contextmanager from contextlib import contextmanager
import torch import torch
from torch.autograd.variable import Variable
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron import get_args from megatron import get_args
...@@ -33,17 +34,80 @@ def get_forward_backward_func(): ...@@ -33,17 +34,80 @@ def get_forward_backward_func():
if mpu.get_pipeline_model_parallel_world_size() > 1: if mpu.get_pipeline_model_parallel_world_size() > 1:
if args.virtual_pipeline_model_parallel_size is not None: if args.virtual_pipeline_model_parallel_size is not None:
forward_backward_func = forward_backward_pipelining_with_interleaving forward_backward_func = forward_backward_pipelining_with_interleaving
assert get_num_microbatches() % args.pipeline_model_parallel_size == 0, \ assert get_num_microbatches() % \
'number of microbatches is not divisible by pipeline-parallel ' \ args.pipeline_model_parallel_size == 0, \
'size when using interleaved schedule' 'number of microbatches (%d) is not divisible by pipeline-' \
'model-parallel-size (%d) when using interleaved schedule' % (
get_num_microbatches(),
args.pipeline_model_parallel_size,
)
else: else:
forward_backward_func = forward_backward_pipelining_without_interleaving forward_backward_func = forward_backward_pipelining_without_interleaving
else: else:
forward_backward_func = forward_backward_no_pipelining forward_backward_func = forward_backward_no_pipelining
return forward_backward_func return forward_backward_func
def deallocate_output_tensor(out):
def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced): '''Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field.
This method should be called right after the output tensor has been
sent to the next pipeline stage. At this point, the output tensor is
only useful for its '.grad_fn' field, and not its '.data'.
'''
if out is None:
return
assert isinstance(out, torch.Tensor), \
"expected Tensor, found %s." % type(out).__name__
assert out._base is None, \
"counter-productive to free a view of another tensor."
out.data = torch.empty(
(1,),
device = out.device,
dtype = out.dtype,
)
def custom_backward(output, grad_output):
'''Directly call C++ autograd engine.
To make the 'deallocate_output_tensor' (above) optimization work, the C++
autograd engine must be called directly, bypassing Pytorch's
torch.autograd.backward. Pytorch's 'backward' checks that the output and
grad have the same shape, while C++'s 'backward' does not.
'''
assert output.numel() == 1, \
"output should be pseudo-'freed' in schedule, to optimize memory"
assert isinstance(output, torch.Tensor), \
"output == '%s'." % type(output).__name__
assert isinstance(grad_output, (torch.Tensor, type(None))), \
"grad_output == '%s'." % type(grad_output).__name__
# Handle scalar output
if grad_output is None:
assert output.numel() == 1, "implicit grad requires scalar output."
grad_output = torch.ones_like(
output,
memory_format = torch.preserve_format,
)
# Call c++ engine [ see torch/csrc/autograd/python_engine.cpp ]
Variable._execution_engine.run_backward(
tensors = (output,),
grad_tensors = (grad_output,),
keep_graph = False,
create_graph = False,
inputs = tuple(),
allow_unreachable=True,
accumulate_grad=True,
)
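Together, deallocate_output_tensor and custom_backward implement the memory optimization used by the pipeline schedules below: free an output's activation storage as soon as it has been sent downstream, keep only its grad_fn, and later run backward with the gradient received from the next stage. A schematic sketch of that pattern; the send/receive helpers here are placeholders, not the actual p2p_communication functions:

output_tensor = model_chunk(input_tensor)     # full-size activation
send_to_next_stage(output_tensor)             # placeholder for the p2p send
deallocate_output_tensor(output_tensor)       # .data shrunk to one element,
                                              # .grad_fn kept for backward
output_grad = recv_grad_from_next_stage()     # placeholder for the p2p recv
custom_backward(output_tensor, output_grad)   # shapes now differ, so bypass
                                              # torch.autograd.backward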
def forward_step(forward_step_func,
data_iterator,
model,
input_tensor,
forward_data_store,
collect_non_loss_data=False):
"""Forward step for passed-in model. """Forward step for passed-in model.
If first stage, input tensor is obtained from data_iterator, otherwise If first stage, input tensor is obtained from data_iterator, otherwise
...@@ -65,10 +129,15 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_r ...@@ -65,10 +129,15 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_r
unwrapped_model.set_input_tensor(input_tensor) unwrapped_model.set_input_tensor(input_tensor)
output_tensor, loss_func = forward_step_func(data_iterator, model) output_tensor, loss_func = forward_step_func(data_iterator, model)
if mpu.is_pipeline_last_stage(): if mpu.is_pipeline_last_stage():
output_tensor = loss_func(output_tensor) if not collect_non_loss_data:
loss, loss_reduced = output_tensor output_tensor = loss_func(output_tensor)
output_tensor = loss / get_num_microbatches() loss, loss_reduced = output_tensor
losses_reduced.append(loss_reduced) output_tensor = loss / get_num_microbatches()
forward_data_store.append(loss_reduced)
else:
data = loss_func(output_tensor, non_loss_data=True)
forward_data_store.append(data)
timers('forward-compute').stop() timers('forward-compute').stop()
# If T5 model (or other model with encoder and decoder) # If T5 model (or other model with encoder and decoder)
...@@ -116,7 +185,7 @@ def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad): ...@@ -116,7 +185,7 @@ def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
# Backward pass. # Backward pass.
if output_tensor_grad[0] is None: if output_tensor_grad[0] is None:
output_tensor = optimizer.scale_loss(output_tensor[0]) output_tensor = optimizer.scale_loss(output_tensor[0])
torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0]) custom_backward(output_tensor[0], output_tensor_grad[0])
# Collect the grad of the input_tensor. # Collect the grad of the input_tensor.
input_tensor_grad = [None] input_tensor_grad = [None]
...@@ -151,8 +220,12 @@ def dummy_handler(): ...@@ -151,8 +220,12 @@ def dummy_handler():
pass pass
def forward_backward_no_pipelining(forward_step_func, data_iterator, model, def forward_backward_no_pipelining(forward_step_func,
optimizer, timers, forward_only): data_iterator, model,
optimizer,
timers,
forward_only,
collect_non_loss_data=False):
"""Run forward and backward passes with no pipeline parallelism """Run forward and backward passes with no pipeline parallelism
(no inter-stage communication). (no inter-stage communication).
...@@ -164,35 +237,41 @@ def forward_backward_no_pipelining(forward_step_func, data_iterator, model, ...@@ -164,35 +237,41 @@ def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
if isinstance(model, torchDDP): if isinstance(model, torchDDP):
context_handler = model.no_sync context_handler = model.no_sync
losses_reduced = [] forward_data_store = []
input_tensor, output_tensor_grad = None, None input_tensor, output_tensor_grad = None, None
with context_handler(): with context_handler():
for i in range(get_num_microbatches() - 1): for i in range(get_num_microbatches() - 1):
output_tensor = forward_step(forward_step_func, data_iterator, model, output_tensor = forward_step(forward_step_func, data_iterator,
input_tensor, losses_reduced) model, input_tensor, forward_data_store,
collect_non_loss_data)
if not forward_only: if not forward_only:
backward_step(optimizer, input_tensor, output_tensor, backward_step(optimizer, input_tensor, output_tensor,
output_tensor_grad) output_tensor_grad)
# Run computation for last microbatch out of context handler (want to # Run computation for last microbatch out of context handler (want to
# synchronize gradients). # synchronize gradients).
output_tensor = forward_step(forward_step_func, data_iterator, model, output_tensor = forward_step(forward_step_func, data_iterator,
input_tensor, losses_reduced) model, input_tensor, forward_data_store,
collect_non_loss_data)
if not forward_only: if not forward_only:
backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad) backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad)
return losses_reduced return forward_data_store
def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator, model, def forward_backward_pipelining_with_interleaving(forward_step_func,
optimizer, timers, forward_only): data_iterator, model,
optimizer,
timers,
forward_only,
collect_non_loss_data=False):
"""Run interleaved 1F1B schedule (model split into model chunks), with """Run interleaved 1F1B schedule (model split into model chunks), with
communication between pipeline stages as needed. communication between pipeline stages as needed.
Returns dictionary with losses if the last stage, empty dict otherwise.""" Returns dictionary with losses if the last stage, empty dict otherwise."""
input_tensors = [[] for _ in range(len(model))] input_tensors = [[] for _ in range(len(model))]
output_tensors = [[] for _ in range(len(model))] output_tensors = [[] for _ in range(len(model))]
losses_reduced = [] forward_data_store = []
if not forward_only: if not forward_only:
output_tensor_grads = [[] for _ in range(len(model))] output_tensor_grads = [[] for _ in range(len(model))]
...@@ -252,7 +331,9 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat ...@@ -252,7 +331,9 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
output_tensor = forward_step(forward_step_func, output_tensor = forward_step(forward_step_func,
data_iterator[model_chunk_id], data_iterator[model_chunk_id],
model[model_chunk_id], model[model_chunk_id],
input_tensor, losses_reduced) input_tensor,
forward_data_store,
collect_non_loss_data)
output_tensors[model_chunk_id].append(output_tensor) output_tensors[model_chunk_id].append(output_tensor)
# if forward-only, no need to save tensors for a backward pass # if forward-only, no need to save tensors for a backward pass
...@@ -325,6 +406,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat ...@@ -325,6 +406,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
tensor_shape=tensor_shape, tensor_shape=tensor_shape,
timers=timers) timers=timers)
input_tensors[next_forward_model_chunk_id].append(input_tensor) input_tensors[next_forward_model_chunk_id].append(input_tensor)
deallocate_output_tensor(output_tensor)
# Run 1F1B in steady state. # Run 1F1B in steady state.
for k in range(num_microbatches_remaining): for k in range(num_microbatches_remaining):
...@@ -388,6 +470,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat ...@@ -388,6 +470,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
output_tensor, input_tensor_grad, output_tensor, input_tensor_grad,
recv_prev=recv_prev, recv_next=recv_next, recv_prev=recv_prev, recv_next=recv_next,
tensor_shape=tensor_shape, timers=timers) tensor_shape=tensor_shape, timers=timers)
deallocate_output_tensor(output_tensor)
# Put input_tensor and output_tensor_grad in data structures in the # Put input_tensor and output_tensor_grad in data structures in the
# right location. # right location.
...@@ -417,7 +500,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat ...@@ -417,7 +500,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
tensor_shape=tensor_shape, tensor_shape=tensor_shape,
timers=timers)) timers=timers))
return losses_reduced return forward_data_store
def get_tensor_shapes(rank, model_type): def get_tensor_shapes(rank, model_type):
...@@ -514,13 +597,18 @@ def send_backward_recv_forward(input_tensor_grads, tensor_shapes, timers): ...@@ -514,13 +597,18 @@ def send_backward_recv_forward(input_tensor_grads, tensor_shapes, timers):
return input_tensors return input_tensors
def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator, def forward_backward_pipelining_without_interleaving(forward_step_func,
model, optimizer, timers, data_iterator,
forward_only): model,
optimizer,
timers,
forward_only,
collect_non_loss_data=False):
"""Run non-interleaved 1F1B schedule, with communication between pipeline """Run non-interleaved 1F1B schedule, with communication between pipeline
stages. stages.
Returns dictionary with losses if the last stage, empty dict otherwise.""" Returns dictionary with losses if the last stage, empty dict otherwise."""
args = get_args()
timers = get_timers() timers = get_timers()
assert len(model) == 1 assert len(model) == 1
...@@ -550,18 +638,20 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite ...@@ -550,18 +638,20 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
if not forward_only: if not forward_only:
input_tensors = [] input_tensors = []
output_tensors = [] output_tensors = []
losses_reduced = [] forward_data_store = []
# Run warmup forward passes. # Run warmup forward passes.
for i in range(num_warmup_microbatches): for i in range(num_warmup_microbatches):
input_tensor = recv_forward(recv_tensor_shapes, timers=timers) input_tensor = recv_forward(recv_tensor_shapes, timers=timers)
output_tensor = forward_step(forward_step_func, data_iterator, model, output_tensor = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced) input_tensor, forward_data_store,
collect_non_loss_data)
send_forward(output_tensor, send_tensor_shapes, timers=timers) send_forward(output_tensor, send_tensor_shapes, timers=timers)
if not forward_only: if not forward_only:
input_tensors.append(input_tensor) input_tensors.append(input_tensor)
output_tensors.append(output_tensor) output_tensors.append(output_tensor)
deallocate_output_tensor(output_tensor[0])
# Before running 1F1B, need to receive first forward tensor. # Before running 1F1B, need to receive first forward tensor.
# If all microbatches are run in warmup / cooldown phase, then no need to # If all microbatches are run in warmup / cooldown phase, then no need to
...@@ -574,7 +664,8 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite ...@@ -574,7 +664,8 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
last_iteration = (i == (num_microbatches_remaining - 1)) last_iteration = (i == (num_microbatches_remaining - 1))
output_tensor = forward_step(forward_step_func, data_iterator, model, output_tensor = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced) input_tensor, forward_data_store,
collect_non_loss_data)
if forward_only: if forward_only:
send_forward(output_tensor, send_tensor_shapes, timers=timers) send_forward(output_tensor, send_tensor_shapes, timers=timers)
...@@ -590,6 +681,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite ...@@ -590,6 +681,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
# Add input_tensor and output_tensor to end of list. # Add input_tensor and output_tensor to end of list.
input_tensors.append(input_tensor) input_tensors.append(input_tensor)
output_tensors.append(output_tensor) output_tensors.append(output_tensor)
deallocate_output_tensor(output_tensor[0])
# Pop input_tensor and output_tensor from the start of the list for # Pop input_tensor and output_tensor from the start of the list for
# the backward pass. # the backward pass.
...@@ -622,4 +714,4 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite ...@@ -622,4 +714,4 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers) send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers)
return losses_reduced return forward_data_store
<!-- coding=utf-8-->
<!-- Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.-->
<!---->
<!-- Licensed under the Apache License, Version 2.0 (the "License");-->
<!-- you may not use this file except in compliance with the License.-->
<!-- You may obtain a copy of the License at-->
<!---->
<!-- http://www.apache.org/licenses/LICENSE-2.0-->
<!---->
<!-- Unless required by applicable law or agreed to in writing, software-->
<!-- distributed under the License is distributed on an "AS IS" BASIS,-->
<!-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.-->
<!-- See the License for the specific language governing permissions and-->
<!-- limitations under the License.-->
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8" />
<title>Megatron</title>
<style>
.wrapper {
max-width: 75%;
margin: auto;
}
h1 {
margin: 3rem 0 1rem 0;
padding: 0;
font-size: 1.5rem;
}
textarea {
width: 100%;
min-height: 300px;
resize: none;
border-radius: 8px;
border: 1px solid #ddd;
padding: 0.5rem;
box-shadow: inset 0 0 0.25rem #ddd;
&:focus {
outline: none;
border: 1px solid darken(#ddd, 5%);
box-shadow: inset 0 0 0.5rem darken(#ddd, 5%);
}
}
#the-count {
float: right;
padding: 0.1rem 0 0 0;
font-size: 0.875rem;
}
/* Chat containers */
.container {
font-family: 'Arial', sans-serif;
font-size: 16px;
border: 2px solid #dedede;
background-color: #f1f1f1;
border-radius: 5px;
padding: 15px;
margin: 10px 0;
}
/* Clear floats */
.container::after {
content: "";
clear: both;
display: table;
}
/* Style images */
.container img {
float: left;
max-width: 60px;
width: 100%;
margin-right: 20px;
border-radius: 50%;
}
</style>
</head>
<body>
<div class="wrapper">
<h1>Prompt Megatron</h1>
<textarea name="prompt" id="prompt" maxlength="1024" placeholder="Add prompt"autofocus></textarea>
<label for="tokens_to_generate">Number tokens to generate (1-1024):</label>
<input type="number" id="tokens_to_generate" name="tokens_to_generate" min="10" max="256", value=32>
<button onclick="submit_query()">Submit</button>
<div id="the-count">
<span id="current">0</span>
<span id="maximum">/ 1000</span>
</div>
<textarea name="response" id="response" maxlength="2048" placeholder="Megatron response..."></textarea>
</div>
<script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.5.1/jquery.min.js"></script>
<script type="text/javascript">
function submit_query() {
$("#response").val("Waiting for Megatron response...");
$.ajax({
url:"api",
type:"PUT",
data:JSON.stringify({prompts: [$("#prompt").val()], tokens_to_generate: parseInt($("#tokens_to_generate").val(),10)}),
contentType:"application/json; charset=utf-8",
dataType:"json",
success: function(data){
data.max_len=35;
$("#response").val(data.text);
}
});
}
$('textarea').keyup(function() {
var characterCount = $(this).val().length,
current = $('#current'),
maximum = $('#maximum'),
theCount = $('#the-count');
current.text(characterCount);
if (characterCount >= 800) {
maximum.css('color', '#8f0001');
current.css('color', '#8f0001');
theCount.css('font-weight','bold');
} else {
maximum.css('color','#666');
theCount.css('font-weight','normal');
}
});
</script>
</body>
</html>
...@@ -35,7 +35,10 @@ def generate_and_post_process(model, ...@@ -35,7 +35,10 @@ def generate_and_post_process(model,
top_p_sampling=0.0, top_p_sampling=0.0,
temperature=1.0, temperature=1.0,
add_BOS=False, add_BOS=False,
use_eod_token_for_early_termination=True): use_eod_token_for_early_termination=True,
stop_on_double_eol=False,
stop_on_eol=False,
random_seed=-1):
"""Run inference and post-process outputs, i.e., detokenize, """Run inference and post-process outputs, i.e., detokenize,
move to cpu and convert to list.""" move to cpu and convert to list."""
...@@ -49,7 +52,10 @@ def generate_and_post_process(model, ...@@ -49,7 +52,10 @@ def generate_and_post_process(model,
top_p_sampling=top_p_sampling, top_p_sampling=top_p_sampling,
temperature=temperature, temperature=temperature,
add_BOS=add_BOS, add_BOS=add_BOS,
use_eod_token_for_early_termination=use_eod_token_for_early_termination) use_eod_token_for_early_termination=use_eod_token_for_early_termination,
stop_on_double_eol=stop_on_double_eol,
stop_on_eol=stop_on_eol,
random_seed=random_seed)
# Only post-process on first stage. # Only post-process on first stage.
if mpu.is_pipeline_first_stage(): if mpu.is_pipeline_first_stage():
...@@ -74,7 +80,10 @@ def generate(model, ...@@ -74,7 +80,10 @@ def generate(model,
top_p_sampling=0.0, top_p_sampling=0.0,
temperature=1.0, temperature=1.0,
add_BOS=False, add_BOS=False,
use_eod_token_for_early_termination=True): use_eod_token_for_early_termination=True,
stop_on_double_eol=False,
stop_on_eol=False,
random_seed=-1):
"""Given prompts and input parameters, run inference and return: """Given prompts and input parameters, run inference and return:
tokens: prompts plus the generated tokens. tokens: prompts plus the generated tokens.
lengths: length of the prompt + generations. Note that we can lengths: length of the prompt + generations. Note that we can
...@@ -87,8 +96,11 @@ def generate(model, ...@@ -87,8 +96,11 @@ def generate(model,
values = [tokens_to_generate, values = [tokens_to_generate,
return_output_log_probs, return_output_log_probs,
top_k_sampling, top_p_sampling, top_k_sampling, top_p_sampling,
temperature, add_BOS, use_eod_token_for_early_termination] temperature, add_BOS, use_eod_token_for_early_termination,
values_float_tensor = broadcast_float_list(7, float_list=values) stop_on_double_eol,
stop_on_eol,
random_seed]
values_float_tensor = broadcast_float_list(10, float_list=values)
tokens_to_generate = int(values_float_tensor[0].item()) tokens_to_generate = int(values_float_tensor[0].item())
return_output_log_probs = bool(values_float_tensor[1].item()) return_output_log_probs = bool(values_float_tensor[1].item())
top_k_sampling = int(values_float_tensor[2].item()) top_k_sampling = int(values_float_tensor[2].item())
...@@ -96,6 +108,12 @@ def generate(model, ...@@ -96,6 +108,12 @@ def generate(model,
temperature = values_float_tensor[4].item() temperature = values_float_tensor[4].item()
add_BOS = bool(values_float_tensor[5].item()) add_BOS = bool(values_float_tensor[5].item())
use_eod_token_for_early_termination = bool(values_float_tensor[6].item()) use_eod_token_for_early_termination = bool(values_float_tensor[6].item())
stop_on_double_eol = bool(values_float_tensor[7].item())
stop_on_eol = bool(values_float_tensor[8].item())
random_seed = int(values_float_tensor[9].item())
if random_seed != -1:
torch.random.manual_seed(random_seed)
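The three new flags ride along in the same broadcast as the existing sampling parameters: the scalars are packed into a single float tensor, shared with every rank, and unpacked with int()/bool()/.item() as above. A hedged sketch of that pack-and-broadcast pattern, assuming a source rank that owns the scalar values; this illustrates the idea rather than the implementation of broadcast_float_list itself:

import torch
import torch.distributed as dist

def broadcast_float_list_sketch(size, float_list=None, src=0):
    # On the source rank, pack the Python scalars into a CUDA float tensor;
    # elsewhere, allocate an empty buffer of the same size.
    if dist.get_rank() == src:
        values = torch.tensor(float_list, dtype=torch.float32,
                              device=torch.cuda.current_device())
    else:
        values = torch.empty(size, dtype=torch.float32,
                             device=torch.cuda.current_device())
    dist.broadcast(values, src)
    return values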
# Tokenize prompts and get the batch. # Tokenize prompts and get the batch.
# Note that these tensors are broadcast to all ranks.        # Note that these tensors are broadcast to all ranks.
...@@ -108,7 +126,7 @@ def generate(model, ...@@ -108,7 +126,7 @@ def generate(model,
if tokens_to_generate == 0: if tokens_to_generate == 0:
return score_and_return_on_first_stage( return score_and_return_on_first_stage(
model, context_tokens_tensor, context_length_tensor) model, context_tokens_tensor, context_length_tensor)
# Main inference function. # Main inference function.
# Note that the outputs are available on the first stage. # Note that the outputs are available on the first stage.
return generate_tokens_probs_and_return_on_first_stage( return generate_tokens_probs_and_return_on_first_stage(
...@@ -117,4 +135,6 @@ def generate(model, ...@@ -117,4 +135,6 @@ def generate(model,
top_k=top_k_sampling, top_k=top_k_sampling,
top_p=top_p_sampling, top_p=top_p_sampling,
temperature=temperature, temperature=temperature,
use_eod_token_for_early_termination=use_eod_token_for_early_termination) use_eod_token_for_early_termination=use_eod_token_for_early_termination,
stop_on_double_eol=stop_on_double_eol,
stop_on_eol=stop_on_eol)
...@@ -96,7 +96,10 @@ def generate_tokens_probs_and_return_on_first_stage( ...@@ -96,7 +96,10 @@ def generate_tokens_probs_and_return_on_first_stage(
return_output_log_probs=False, return_output_log_probs=False,
top_k=0, top_p=0.0, top_k=0, top_p=0.0,
temperature=1.0, temperature=1.0,
use_eod_token_for_early_termination=True): use_eod_token_for_early_termination=True,
stop_on_double_eol=False,
stop_on_eol=False
):
"""Main token generation function. """Main token generation function.
Arguments: Arguments:
model: no interleaving is supported. model: no interleaving is supported.
...@@ -130,6 +133,10 @@ def generate_tokens_probs_and_return_on_first_stage( ...@@ -130,6 +133,10 @@ def generate_tokens_probs_and_return_on_first_stage(
min_prompt_length = lengths.min().item() min_prompt_length = lengths.min().item()
max_sequence_length = tokens.size(1) max_sequence_length = tokens.size(1)
max_sequence_length = min(max_sequence_length, args.max_position_embeddings) max_sequence_length = min(max_sequence_length, args.max_position_embeddings)
# If the context is too big, this happens
if min_prompt_length >= max_sequence_length:
raise ValueError("context length + tokens_to_generate too large")
# forward step. # forward step.
forward_step = ForwardStep(model, batch_size, max_sequence_length) forward_step = ForwardStep(model, batch_size, max_sequence_length)
...@@ -227,8 +234,20 @@ def generate_tokens_probs_and_return_on_first_stage( ...@@ -227,8 +234,20 @@ def generate_tokens_probs_and_return_on_first_stage(
# Check if all the sequences have hit the termination_id. # Check if all the sequences have hit the termination_id.
done = None done = None
if mpu.is_pipeline_last_stage(): if mpu.is_pipeline_last_stage():
done_token = (new_sample == termination_id).byte() & \ # TODO(rprenger) These stopping methods are tokenizer dependent
started.byte() # instead tokenization should be in the inference loop so stop sequences can be used
if stop_on_double_eol:
hit_double_eol = (new_sample == 628).byte() & started.byte()
hit_two_eols = (new_sample == 198).byte() & (tokens[:, context_length-1] == 198).byte() & started.byte()
done_token = hit_double_eol | hit_two_eols
elif stop_on_eol:
hit_double_eol = (new_sample == 628).byte() & started.byte()
hit_eol = (new_sample == 198).byte() & started.byte()
done_token = hit_double_eol | hit_eol
else:
done_token = (new_sample == termination_id).byte() & \
started.byte()
just_finished = (done_token & ~is_generation_done).bool() just_finished = (done_token & ~is_generation_done).bool()
generated_sequence_lengths[just_finished.view(-1)] = \ generated_sequence_lengths[just_finished.view(-1)] = \
context_length + 1 context_length + 1
......
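The hard-coded ids in the stopping logic above (628 and 198) are the GPT-2 BPE tokens for "\n\n" and "\n", which is why the TODO flags them as tokenizer dependent. A hedged sketch of deriving them from the active tokenizer instead, assuming its tokenize() method returns token ids for a raw string and that each newline sequence maps to a single token:

from megatron import get_tokenizer

tokenizer = get_tokenizer()
eol_id = tokenizer.tokenize("\n")[0]            # 198 for the GPT-2 BPE vocab
double_eol_id = tokenizer.tokenize("\n\n")[0]   # 628 for the GPT-2 BPE vocab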
...@@ -36,9 +36,6 @@ class MegatronGenerate(Resource): ...@@ -36,9 +36,6 @@ class MegatronGenerate(Resource):
def put(self): def put(self):
args = get_args() args = get_args()
print("request IP: " + str(request.remote_addr))
print(json.dumps(request.get_json()),flush=True)
print("current time: ", datetime.datetime.now())
if not "prompts" in request.get_json(): if not "prompts" in request.get_json():
return "prompts argument required", 400 return "prompts argument required", 400
...@@ -101,20 +98,60 @@ class MegatronGenerate(Resource): ...@@ -101,20 +98,60 @@ class MegatronGenerate(Resource):
add_BOS = request.get_json()["add_BOS"] add_BOS = request.get_json()["add_BOS"]
if not isinstance(add_BOS, bool): if not isinstance(add_BOS, bool):
return "add_BOS must be a boolean value" return "add_BOS must be a boolean value"
if any([len(prompt) == 0 for prompt in prompts]) and not add_BOS:
return "Empty prompts require add_BOS=true"
stop_on_double_eol = False
if "stop_on_double_eol" in request.get_json():
stop_on_double_eol = request.get_json()["stop_on_double_eol"]
if not isinstance(stop_on_double_eol, bool):
return "stop_on_double_eol must be a boolean value"
stop_on_eol = False
if "stop_on_eol" in request.get_json():
stop_on_eol = request.get_json()["stop_on_eol"]
if not isinstance(stop_on_eol, bool):
return "stop_on_eol must be a boolean value"
random_seed = -1
if "random_seed" in request.get_json():
random_seed = request.get_json()["random_seed"]
if not isinstance(random_seed, int):
return "random_seed must be integer"
if random_seed < 0:
return "random_seed must be a positive integer"
no_log = False
if "no_log" in request.get_json():
no_log = request.get_json()["no_log"]
if not isinstance(no_log, bool):
return "no_log must be a boolean value"
with lock: # Need to get lock to keep multiple threads from hitting code with lock: # Need to get lock to keep multiple threads from hitting code
if not no_log:
print("request IP: " + str(request.remote_addr))
print(json.dumps(request.get_json()),flush=True)
print("start time: ", datetime.datetime.now())
MegatronGenerate.send_do_generate() # Tell other ranks we're doing generate MegatronGenerate.send_do_generate() # Tell other ranks we're doing generate
response, response_seg, response_logprobs, _ = \ try:
generate_and_post_process( response, response_seg, response_logprobs, _ = \
self.model, generate_and_post_process(
prompts=prompts, self.model,
tokens_to_generate=tokens_to_generate, prompts=prompts,
return_output_log_probs=logprobs, tokens_to_generate=tokens_to_generate,
top_k_sampling=top_k, return_output_log_probs=logprobs,
top_p_sampling=top_p, top_k_sampling=top_k,
temperature=temperature, top_p_sampling=top_p,
add_BOS=add_BOS, temperature=temperature,
use_eod_token_for_early_termination=True) add_BOS=add_BOS,
use_eod_token_for_early_termination=True,
stop_on_double_eol=stop_on_double_eol,
stop_on_eol=stop_on_eol,
random_seed=random_seed)
except ValueError as ve:
return "Length of prompt + tokens_to_generate longer than allowed"
print("end time: ", datetime.datetime.now())
return jsonify({"text": response, return jsonify({"text": response,
"segments": response_seg, "segments": response_seg,
......
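The same request the web page above issues can also be sent programmatically; a minimal Python sketch against the server's PUT endpoint, where the host, port, and /api path are assumptions based on the JavaScript client rather than values taken from this commit:

import json
import requests

payload = {
    "prompts": ["Megatron-LM is"],
    "tokens_to_generate": 32,
    "stop_on_double_eol": True,   # one of the flags added in this change
    "random_seed": 1234,
    "no_log": True,
}
resp = requests.put("http://localhost:5000/api",
                    data=json.dumps(payload),
                    headers={"Content-Type": "application/json; charset=utf-8"})
print(resp.json()["text"])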
...@@ -21,11 +21,11 @@ import sys ...@@ -21,11 +21,11 @@ import sys
import time import time
# The earliest we can measure the start time. # The earliest we can measure the start time.
_TRAIN_START_TIME = time.time() _TRAIN_START_TIME = time.time()
import torch import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron import get_args from megatron import get_args
from megatron import get_signal_handler
from megatron import get_timers from megatron import get_timers
from megatron import get_tensorboard_writer from megatron import get_tensorboard_writer
from megatron import get_current_global_batch_size from megatron import get_current_global_batch_size
...@@ -42,7 +42,7 @@ from megatron.model import ModelType ...@@ -42,7 +42,7 @@ from megatron.model import ModelType
from megatron.optimizer import get_megatron_optimizer from megatron.optimizer import get_megatron_optimizer
from megatron.initialize import initialize_megatron from megatron.initialize import initialize_megatron
from megatron.initialize import write_args_to_tensorboard from megatron.initialize import write_args_to_tensorboard
from megatron.learning_rates import AnnealingLR from megatron.optimizer_param_scheduler import OptimizerParamScheduler
from megatron.model import DistributedDataParallel as LocalDDP from megatron.model import DistributedDataParallel as LocalDDP
from megatron.utils import check_adlr_autoresume_termination from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import unwrap_model from megatron.utils import unwrap_model
...@@ -50,7 +50,7 @@ from megatron.data.data_samplers import build_pretraining_data_loader ...@@ -50,7 +50,7 @@ from megatron.data.data_samplers import build_pretraining_data_loader
from megatron.utils import calc_params_l2_norm from megatron.utils import calc_params_l2_norm
from megatron.schedules import get_forward_backward_func from megatron.schedules import get_forward_backward_func
from megatron.utils import report_memory from megatron.utils import report_memory
from megatron.model.vision.knn_monitor import compute_feature_bank
def print_datetime(string): def print_datetime(string):
...@@ -64,6 +64,7 @@ def pretrain(train_valid_test_dataset_provider, ...@@ -64,6 +64,7 @@ def pretrain(train_valid_test_dataset_provider,
model_provider, model_provider,
model_type, model_type,
forward_step_func, forward_step_func,
process_non_loss_data_func=None,
extra_args_provider=None, extra_args_provider=None,
args_defaults={}): args_defaults={}):
"""Main training program. """Main training program.
...@@ -85,6 +86,10 @@ def pretrain(train_valid_test_dataset_provider, ...@@ -85,6 +86,10 @@ def pretrain(train_valid_test_dataset_provider,
the info we would like to monitor during training, for example the info we would like to monitor during training, for example
`lm-loss: value`. We also require that this function add `lm-loss: value`. We also require that this function add
`batch generator` to the timers class. `batch generator` to the timers class.
process_non_loss_data_func: a function to post process outputs of the
network. It can be used for dumping output tensors (e.g images) to
tensorboard. It takes `collected data`(list of tensors),
`current iteration index` and `tensorboard writer` as arguments.
extra_args_provider: a function that takes a parser and adds arguments extra_args_provider: a function that takes a parser and adds arguments
to it. It is used for programs to add their own arguments. to it. It is used for programs to add their own arguments.
args_defaults: a dictionary from argument-name to argument-value. It args_defaults: a dictionary from argument-name to argument-value. It
...@@ -112,7 +117,7 @@ def pretrain(train_valid_test_dataset_provider, ...@@ -112,7 +117,7 @@ def pretrain(train_valid_test_dataset_provider,
# Model, optimizer, and learning rate. # Model, optimizer, and learning rate.
timers('model-and-optimizer-setup').start() timers('model-and-optimizer-setup').start()
model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider, model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider,
model_type) model_type)
timers('model-and-optimizer-setup').stop() timers('model-and-optimizer-setup').stop()
print_datetime('after model, optimizer, and learning rate ' print_datetime('after model, optimizer, and learning rate '
...@@ -143,25 +148,28 @@ def pretrain(train_valid_test_dataset_provider, ...@@ -143,25 +148,28 @@ def pretrain(train_valid_test_dataset_provider,
iteration = 0 iteration = 0
if args.do_train and args.train_iters > 0: if args.do_train and args.train_iters > 0:
iteration = train(forward_step_func, iteration = train(forward_step_func,
model, optimizer, lr_scheduler, model, optimizer, opt_param_scheduler,
train_data_iterator, valid_data_iterator) train_data_iterator, valid_data_iterator,
process_non_loss_data_func)
print_datetime('after training is done') print_datetime('after training is done')
if args.do_valid: if args.do_valid:
prefix = 'the end of training for val data' prefix = 'the end of training for val data'
evaluate_and_print_results(prefix, forward_step_func, evaluate_and_print_results(prefix, forward_step_func,
valid_data_iterator, model, valid_data_iterator, model,
iteration, False) iteration, process_non_loss_data_func,
False)
if args.save and iteration != 0: if args.save and iteration != 0:
save_checkpoint(iteration, model, optimizer, lr_scheduler) save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
if args.do_test: if args.do_test:
# Run on test data. # Run on test data.
prefix = 'the end of training for test data' prefix = 'the end of training for test data'
evaluate_and_print_results(prefix, forward_step_func, evaluate_and_print_results(prefix, forward_step_func,
test_data_iterator, model, test_data_iterator, model,
0, True) 0, process_non_loss_data_func,
True)
def update_train_iters(args): def update_train_iters(args):
...@@ -284,7 +292,10 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap ...@@ -284,7 +292,10 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
args.accumulate_allreduce_grads_in_fp32, args.accumulate_allreduce_grads_in_fp32,
args.use_contiguous_buffers_in_local_ddp) args.use_contiguous_buffers_in_local_ddp)
for model_module in model] for model_module in model]
# broadcast params from data parallel src rank to other data parallel ranks
if args.data_parallel_random_init:
for model_module in model:
model_module.broadcast_params()
else: else:
raise NotImplementedError('Unknown DDP implementation specified: ' raise NotImplementedError('Unknown DDP implementation specified: '
'{}. Exiting.'.format(args.DDP_impl)) '{}. Exiting.'.format(args.DDP_impl))
...@@ -292,7 +303,7 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap ...@@ -292,7 +303,7 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
return model return model
def get_learning_rate_scheduler(optimizer): def get_optimizer_param_scheduler(optimizer):
"""Build the learning rate scheduler.""" """Build the learning rate scheduler."""
args = get_args() args = get_args()
...@@ -300,11 +311,12 @@ def get_learning_rate_scheduler(optimizer): ...@@ -300,11 +311,12 @@ def get_learning_rate_scheduler(optimizer):
if args.train_iters: if args.train_iters:
if args.lr_decay_iters is None: if args.lr_decay_iters is None:
args.lr_decay_iters = args.train_iters args.lr_decay_iters = args.train_iters
decay_steps = args.lr_decay_iters * args.global_batch_size lr_decay_steps = args.lr_decay_iters * args.global_batch_size
wd_incr_steps = args.train_iters * args.global_batch_size
if args.lr_warmup_fraction is not None: if args.lr_warmup_fraction is not None:
warmup_steps = args.lr_warmup_fraction * decay_steps lr_warmup_steps = args.lr_warmup_fraction * lr_decay_steps
else: else:
warmup_steps = args.lr_warmup_iters * args.global_batch_size lr_warmup_steps = args.lr_warmup_iters * args.global_batch_size
# Sample-based training. # Sample-based training.
elif args.train_samples: elif args.train_samples:
# We need to set training iters for later use. Technically # We need to set training iters for later use. Technically
...@@ -313,29 +325,38 @@ def get_learning_rate_scheduler(optimizer): ...@@ -313,29 +325,38 @@ def get_learning_rate_scheduler(optimizer):
update_train_iters(args) update_train_iters(args)
if args.lr_decay_samples is None: if args.lr_decay_samples is None:
args.lr_decay_samples = args.train_samples args.lr_decay_samples = args.train_samples
decay_steps = args.lr_decay_samples lr_decay_steps = args.lr_decay_samples
wd_incr_steps = args.train_samples
if args.lr_warmup_fraction is not None: if args.lr_warmup_fraction is not None:
warmup_steps = args.lr_warmup_fraction * decay_steps lr_warmup_steps = args.lr_warmup_fraction * lr_decay_steps
else: else:
warmup_steps = args.lr_warmup_samples lr_warmup_steps = args.lr_warmup_samples
else: else:
raise Exception( raise Exception(
'either train-iters or train-samples should be provided.') 'either train-iters or train-samples should be provided.')
lr_scheduler = AnnealingLR( opt_param_scheduler = OptimizerParamScheduler(
optimizer, optimizer,
max_lr=args.lr, max_lr=args.lr,
min_lr=args.min_lr, min_lr=args.min_lr,
warmup_steps=warmup_steps, lr_warmup_steps=lr_warmup_steps,
decay_steps=decay_steps, lr_decay_steps=lr_decay_steps,
decay_style=args.lr_decay_style, lr_decay_style=args.lr_decay_style,
use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler, start_wd=args.start_weight_decay,
override_lr_scheduler=args.override_lr_scheduler) end_wd=args.end_weight_decay,
wd_incr_steps=wd_incr_steps,
return lr_scheduler wd_incr_style=args.weight_decay_incr_style,
use_checkpoint_opt_param_scheduler=args.use_checkpoint_opt_param_scheduler,
override_opt_param_scheduler=args.override_opt_param_scheduler)
def setup_model_and_optimizer(model_provider_func, model_type):
return opt_param_scheduler
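To make the step bookkeeping above concrete, a worked example with made-up settings for iteration-based training (the scheduler counts steps in consumed samples):

# Hypothetical settings, for illustration only.
train_iters = 1000
global_batch_size = 512
lr_warmup_fraction = 0.01

lr_decay_steps = train_iters * global_batch_size        # 512000 samples
wd_incr_steps = train_iters * global_batch_size         # 512000 samples
lr_warmup_steps = lr_warmup_fraction * lr_decay_steps   # 5120.0 samples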
def setup_model_and_optimizer(model_provider_func,
model_type,
no_wd_decay_cond=None,
scale_lr_cond=None,
lr_mult=1.0):
"""Setup model and optimizer.""" """Setup model and optimizer."""
args = get_args() args = get_args()
...@@ -343,9 +364,10 @@ def setup_model_and_optimizer(model_provider_func, model_type): ...@@ -343,9 +364,10 @@ def setup_model_and_optimizer(model_provider_func, model_type):
unwrapped_model = unwrap_model(model, unwrapped_model = unwrap_model(model,
(torchDDP, LocalDDP, Float16Module)) (torchDDP, LocalDDP, Float16Module))
optimizer = get_megatron_optimizer(unwrapped_model) optimizer = get_megatron_optimizer(unwrapped_model, no_wd_decay_cond,
scale_lr_cond, lr_mult)
lr_scheduler = get_learning_rate_scheduler(optimizer) opt_param_scheduler = get_optimizer_param_scheduler(optimizer)
if args.load is not None: if args.load is not None:
timers = get_timers() timers = get_timers()
...@@ -353,7 +375,7 @@ def setup_model_and_optimizer(model_provider_func, model_type): ...@@ -353,7 +375,7 @@ def setup_model_and_optimizer(model_provider_func, model_type):
# max time. # max time.
torch.distributed.barrier() torch.distributed.barrier()
timers('load-checkpoint').start() timers('load-checkpoint').start()
args.iteration = load_checkpoint(model, optimizer, lr_scheduler) args.iteration = load_checkpoint(model, optimizer, opt_param_scheduler)
torch.distributed.barrier() torch.distributed.barrier()
timers('load-checkpoint').stop() timers('load-checkpoint').stop()
timers.log(['load-checkpoint']) timers.log(['load-checkpoint'])
...@@ -372,11 +394,11 @@ def setup_model_and_optimizer(model_provider_func, model_type): ...@@ -372,11 +394,11 @@ def setup_model_and_optimizer(model_provider_func, model_type):
if args.fp16: if args.fp16:
optimizer.reload_model_params() optimizer.reload_model_params()
return model, optimizer, lr_scheduler return model, optimizer, opt_param_scheduler
def train_step(forward_step_func, data_iterator, def train_step(forward_step_func, data_iterator,
model, optimizer, lr_scheduler): model, optimizer, opt_param_scheduler):
"""Single training step.""" """Single training step."""
args = get_args() args = get_args()
timers = get_timers() timers = get_timers()
...@@ -426,19 +448,45 @@ def train_step(forward_step_func, data_iterator, ...@@ -426,19 +448,45 @@ def train_step(forward_step_func, data_iterator,
else: else:
grad = word_embeddings_weight.grad grad = word_embeddings_weight.grad
torch.distributed.all_reduce(grad, group=mpu.get_embedding_group()) torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
# All-reduce position_embeddings grad across first (encoder) and split (decoder)
# stages to ensure that position embedding parameters stay in sync.
# This should only run for T5 models with pipeline parallelism.
if mpu.is_rank_in_position_embedding_group() and \
mpu.get_pipeline_model_parallel_world_size() > 1 and \
args.pipeline_model_parallel_split_rank is not None:
unwrapped_model = model[0]
unwrapped_model = unwrap_model(
unwrapped_model, (torchDDP, LocalDDP, Float16Module))
assert args.DDP_impl == 'local', \
'T5 model is only supported with local DDP mode'
grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
timers('backward-embedding-all-reduce').stop() timers('backward-embedding-all-reduce').stop()
if args.vision_pretraining and args.vision_pretraining_type == "dino":
unwrapped_model = unwrap_model(model[0],
(torchDDP, LocalDDP, Float16Module))
unwrapped_model.cancel_gradients_last_layer(args.curr_iteration)
# Update parameters. # Update parameters.
timers('optimizer').start() timers('optimizer').start()
update_successful, grad_norm, num_zeros_in_grad = optimizer.step() update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
timers('optimizer').stop() timers('optimizer').stop()
if args.vision_pretraining and args.vision_pretraining_type == "dino":
unwrapped_model = unwrap_model(model[0],
(torchDDP, LocalDDP, Float16Module))
unwrapped_model.update_momentum(args.curr_iteration)
# Update learning rate. # Update learning rate.
if update_successful: if update_successful:
increment = get_num_microbatches() * \ increment = get_num_microbatches() * \
args.micro_batch_size * \ args.micro_batch_size * \
args.data_parallel_size args.data_parallel_size
lr_scheduler.step(increment=increment) opt_param_scheduler.step(increment=increment)
skipped_iter = 0 skipped_iter = 0
else: else:
skipped_iter = 1 skipped_iter = 1
...@@ -544,6 +592,10 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration, ...@@ -544,6 +592,10 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
writer.add_scalar('loss-scale', loss_scale, iteration) writer.add_scalar('loss-scale', loss_scale, iteration)
writer.add_scalar('loss-scale vs samples', loss_scale, writer.add_scalar('loss-scale vs samples', loss_scale,
args.consumed_train_samples) args.consumed_train_samples)
if args.log_world_size_to_tensorboard:
writer.add_scalar('world-size', args.world_size, iteration)
writer.add_scalar('world-size vs samples', args.world_size,
args.consumed_train_samples)
if grad_norm is not None: if grad_norm is not None:
writer.add_scalar('grad-norm', grad_norm, iteration) writer.add_scalar('grad-norm', grad_norm, iteration)
writer.add_scalar('grad-norm vs samples', grad_norm, writer.add_scalar('grad-norm vs samples', grad_norm,
...@@ -624,20 +676,21 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration, ...@@ -624,20 +676,21 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
return report_memory_flag return report_memory_flag
def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler): def save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler):
timers = get_timers() timers = get_timers()
# Extra barrier is added to make sure # Extra barrier is added to make sure
# all ranks report the max time. # all ranks report the max time.
torch.distributed.barrier() torch.distributed.barrier()
timers('save-checkpoint').start() timers('save-checkpoint').start()
save_checkpoint(iteration, model, optimizer, lr_scheduler) save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
torch.distributed.barrier() torch.distributed.barrier()
timers('save-checkpoint').stop() timers('save-checkpoint').stop()
timers.log(['save-checkpoint']) timers.log(['save-checkpoint'])
def train(forward_step_func, model, optimizer, lr_scheduler, def train(forward_step_func, model, optimizer, opt_param_scheduler,
train_data_iterator, valid_data_iterator): train_data_iterator, valid_data_iterator,
process_non_loss_data_func):
"""Train the model function.""" """Train the model function."""
args = get_args() args = get_args()
timers = get_timers() timers = get_timers()
...@@ -660,12 +713,13 @@ def train(forward_step_func, model, optimizer, lr_scheduler, ...@@ -660,12 +713,13 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
report_memory_flag = True report_memory_flag = True
while iteration < args.train_iters: while iteration < args.train_iters:
update_num_microbatches(args.consumed_train_samples) update_num_microbatches(args.consumed_train_samples)
args.curr_iteration = iteration
loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \ loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \
train_step(forward_step_func, train_step(forward_step_func,
train_data_iterator, train_data_iterator,
model, model,
optimizer, optimizer,
lr_scheduler) opt_param_scheduler)
iteration += 1 iteration += 1
args.consumed_train_samples += mpu.get_data_parallel_world_size() * \ args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
args.micro_batch_size * \ args.micro_batch_size * \
...@@ -686,7 +740,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler, ...@@ -686,7 +740,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
if args.adlr_autoresume and \ if args.adlr_autoresume and \
(iteration % args.adlr_autoresume_interval == 0): (iteration % args.adlr_autoresume_interval == 0):
check_adlr_autoresume_termination(iteration, model, optimizer, check_adlr_autoresume_termination(iteration, model, optimizer,
lr_scheduler) opt_param_scheduler)
# Evaluation # Evaluation
if args.eval_interval and iteration % args.eval_interval == 0 and \ if args.eval_interval and iteration % args.eval_interval == 0 and \
...@@ -694,14 +748,23 @@ def train(forward_step_func, model, optimizer, lr_scheduler, ...@@ -694,14 +748,23 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
prefix = 'iteration {}'.format(iteration) prefix = 'iteration {}'.format(iteration)
evaluate_and_print_results(prefix, forward_step_func, evaluate_and_print_results(prefix, forward_step_func,
valid_data_iterator, model, valid_data_iterator, model,
iteration, False) iteration, process_non_loss_data_func,
False)
# Checkpointing # Checkpointing
saved_checkpoint = False saved_checkpoint = False
if args.exit_signal_handler:
signal_handler = get_signal_handler()
if any(signal_handler.signals_received()):
save_checkpoint_and_time(iteration, model, optimizer,
opt_param_scheduler)
print_datetime('exiting program after receiving SIGTERM.')
sys.exit()
if args.save and args.save_interval and \ if args.save and args.save_interval and \
iteration % args.save_interval == 0: iteration % args.save_interval == 0:
save_checkpoint_and_time(iteration, model, optimizer, save_checkpoint_and_time(iteration, model, optimizer,
lr_scheduler) opt_param_scheduler)
saved_checkpoint = True saved_checkpoint = True
# Exiting based on duration # Exiting based on duration
...@@ -715,7 +778,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler, ...@@ -715,7 +778,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
if done: if done:
if not saved_checkpoint: if not saved_checkpoint:
save_checkpoint_and_time(iteration, model, optimizer, save_checkpoint_and_time(iteration, model, optimizer,
lr_scheduler) opt_param_scheduler)
print_datetime('exiting program after {} minutes'.format(train_time)) print_datetime('exiting program after {} minutes'.format(train_time))
sys.exit() sys.exit()
...@@ -723,7 +786,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler, ...@@ -723,7 +786,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
if args.exit_interval and iteration % args.exit_interval == 0: if args.exit_interval and iteration % args.exit_interval == 0:
if not saved_checkpoint: if not saved_checkpoint:
save_checkpoint_and_time(iteration, model, optimizer, save_checkpoint_and_time(iteration, model, optimizer,
lr_scheduler) opt_param_scheduler)
torch.distributed.barrier() torch.distributed.barrier()
print_datetime('exiting program at iteration {}'.format(iteration)) print_datetime('exiting program at iteration {}'.format(iteration))
sys.exit() sys.exit()
...@@ -732,10 +795,17 @@ def train(forward_step_func, model, optimizer, lr_scheduler, ...@@ -732,10 +795,17 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
return iteration return iteration
def evaluate(forward_step_func, data_iterator, model, verbose=False): def evaluate(forward_step_func,
data_iterator,
model,
process_non_loss_data_func,
verbose=False):
"""Evaluation.""" """Evaluation."""
args = get_args() args = get_args()
if args.vision_pretraining and args.vision_pretraining_type == "dino":
compute_feature_bank(model)
# Turn on evaluation mode which disables dropout. # Turn on evaluation mode which disables dropout.
for model_module in model: for model_module in model:
model_module.eval() model_module.eval()
...@@ -769,6 +839,12 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False): ...@@ -769,6 +839,12 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
args.consumed_valid_samples += mpu.get_data_parallel_world_size() \ args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
* args.micro_batch_size \ * args.micro_batch_size \
* get_num_microbatches() * get_num_microbatches()
collected_non_loss_data = None
if process_non_loss_data_func is not None and is_last_rank():
collected_non_loss_data = forward_backward_func(
forward_step_func, data_iterator, model, optimizer=None,
timers=None, forward_only=True, collect_non_loss_data=True)
# Move model back to the train mode. # Move model back to the train mode.
for model_module in model: for model_module in model:
model_module.train() model_module.train()
...@@ -776,16 +852,19 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False): ...@@ -776,16 +852,19 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
for key in total_loss_dict: for key in total_loss_dict:
total_loss_dict[key] /= args.eval_iters * get_num_microbatches() total_loss_dict[key] /= args.eval_iters * get_num_microbatches()
return total_loss_dict return total_loss_dict, collected_non_loss_data
def evaluate_and_print_results(prefix, forward_step_func, def evaluate_and_print_results(prefix, forward_step_func,
data_iterator, model, data_iterator, model,
iteration, verbose=False): iteration, process_non_loss_data_func,
verbose=False):
"""Helper function to evaluate and dump results on screen.""" """Helper function to evaluate and dump results on screen."""
args = get_args() args = get_args()
writer = get_tensorboard_writer() writer = get_tensorboard_writer()
total_loss_dict = evaluate(forward_step_func, data_iterator, model, verbose) total_loss_dict, collected_non_loss_data = evaluate(
forward_step_func, data_iterator, model,
process_non_loss_data_func, verbose)
string = ' validation loss at {} | '.format(prefix) string = ' validation loss at {} | '.format(prefix)
for key in total_loss_dict: for key in total_loss_dict:
string += '{} value: {:.6E} | '.format(key, total_loss_dict[key].item()) string += '{} value: {:.6E} | '.format(key, total_loss_dict[key].item())
...@@ -804,6 +883,9 @@ def evaluate_and_print_results(prefix, forward_step_func, ...@@ -804,6 +883,9 @@ def evaluate_and_print_results(prefix, forward_step_func,
writer.add_scalar('{} validation ppl vs samples'.format(key), writer.add_scalar('{} validation ppl vs samples'.format(key),
ppl, args.consumed_train_samples) ppl, args.consumed_train_samples)
if process_non_loss_data_func is not None and writer and is_last_rank():
process_non_loss_data_func(collected_non_loss_data, iteration, writer)
length = len(string) + 1 length = len(string) + 1
print_rank_last('-' * length) print_rank_last('-' * length)
print_rank_last(string) print_rank_last(string)
...@@ -882,7 +964,6 @@ def build_train_valid_test_data_iterators( ...@@ -882,7 +964,6 @@ def build_train_valid_test_data_iterators(
args.do_valid = flags[1].item() args.do_valid = flags[1].item()
args.do_test = flags[2].item() args.do_test = flags[2].item()
# Build iterators. # Build iterators.
dl_type = args.dataloader_type dl_type = args.dataloader_type
assert dl_type in ['single', 'cyclic'] assert dl_type in ['single', 'cyclic']
......
...@@ -126,7 +126,7 @@ def print_params_min_max_norm(optimizer, iteration): ...@@ -126,7 +126,7 @@ def print_params_min_max_norm(optimizer, iteration):
def check_adlr_autoresume_termination(iteration, model, def check_adlr_autoresume_termination(iteration, model,
optimizer, lr_scheduler): optimizer, opt_param_scheduler):
"""Check for autoresume signal and exit if it is received.""" """Check for autoresume signal and exit if it is received."""
from megatron.checkpointing import save_checkpoint from megatron.checkpointing import save_checkpoint
...@@ -136,7 +136,7 @@ def check_adlr_autoresume_termination(iteration, model, ...@@ -136,7 +136,7 @@ def check_adlr_autoresume_termination(iteration, model,
torch.distributed.barrier() torch.distributed.barrier()
if autoresume.termination_requested(): if autoresume.termination_requested():
if args.save: if args.save:
save_checkpoint(iteration, model, optimizer, lr_scheduler) save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
print_rank_0(">>> autoresume termination request found!") print_rank_0(">>> autoresume termination request found!")
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
autoresume.request_resume() autoresume.request_resume()
......
...@@ -21,21 +21,33 @@ from functools import partial ...@@ -21,21 +21,33 @@ from functools import partial
from megatron import get_args, get_timers, mpu, print_rank_0 from megatron import get_args, get_timers, mpu, print_rank_0
from megatron.data.vit_dataset import build_train_valid_datasets from megatron.data.vit_dataset import build_train_valid_datasets
from megatron.model import ModelType from megatron.model import ModelType
from megatron.model.vit_model import VitModel from megatron.model.vision.classification import VitClassificationModel
from megatron.model.vision.classification import MitClassificationModel
from megatron.training import pretrain from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group from megatron.utils import average_losses_across_data_parallel_group
def model_provider(pre_process=True, post_process=True): def model_provider(pre_process=True, post_process=True):
"""Build the model.""" """Build the model."""
print_rank_0("building VIT model ...")
args = get_args() args = get_args()
model = VitModel(num_classes=args.num_classes, if args.vision_backbone_type == 'vit':
pre_process=pre_process, print_rank_0("building VIT model ...")
post_process=post_process) model = VitClassificationModel(num_classes=args.num_classes,
pre_process=pre_process,
post_process=post_process)
elif args.vision_backbone_type == 'mit':
print_rank_0("building MIT model ...")
model = MitClassificationModel(num_classes=args.num_classes,
pre_process=pre_process,
post_process=post_process)
else:
raise Exception('{} vision backbone is not supported.'.format(
args.vision_backbone_type))
return model return model
def get_batch(data_iterator): def get_batch(data_iterator):
"""Build the batch.""" """Build the batch."""
data = next(data_iterator) data = next(data_iterator)
...@@ -46,6 +58,7 @@ def get_batch(data_iterator): ...@@ -46,6 +58,7 @@ def get_batch(data_iterator):
return images, labels return images, labels
def loss_func(labels, output_tensor): def loss_func(labels, output_tensor):
logits = output_tensor.contiguous().float() logits = output_tensor.contiguous().float()
loss = F.cross_entropy(logits, labels) loss = F.cross_entropy(logits, labels)
...@@ -58,6 +71,7 @@ def loss_func(labels, output_tensor): ...@@ -58,6 +71,7 @@ def loss_func(labels, output_tensor):
return loss, {"loss": averaged_loss[0], "accuracy": averaged_loss[1]} return loss, {"loss": averaged_loss[0], "accuracy": averaged_loss[1]}
def forward_step(data_iterator, model): def forward_step(data_iterator, model):
"""Forward step.""" """Forward step."""
timers = get_timers() timers = get_timers()
...@@ -82,7 +96,10 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): ...@@ -82,7 +96,10 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
print_rank_0( print_rank_0(
"> building train, validation, and test datasets " "for VIT ..." "> building train, validation, and test datasets " "for VIT ..."
) )
train_ds, valid_ds = build_train_valid_datasets(data_path=args.data_path) train_ds, valid_ds = build_train_valid_datasets(
data_path=args.data_path,
image_size=(args.img_h, args.img_w)
)
print_rank_0("> finished creating VIT datasets ...") print_rank_0("> finished creating VIT datasets ...")
return train_ds, valid_ds, None return train_ds, valid_ds, None
...@@ -95,5 +112,5 @@ if __name__ == "__main__": ...@@ -95,5 +112,5 @@ if __name__ == "__main__":
model_provider, model_provider,
ModelType.encoder_or_decoder, ModelType.encoder_or_decoder,
forward_step, forward_step,
args_defaults={'dataloader_type': 'cyclic'} args_defaults={'dataloader_type': 'cyclic', 'vision_pretraining': True}
) )
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import torch.distributed as dist
from functools import partial
from megatron import get_args, get_timers, mpu, print_rank_0
from megatron.data.vit_dataset import build_train_valid_datasets
from megatron.model.vision.dino import DINOPretrainModel
from megatron.model.vision.knn_monitor import knn_predict, get_feature_bank
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group, unwrap_model
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
from megatron.model import ModelType
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
return DINOPretrainModel(pre_process=pre_process, post_process=post_process)
def get_batch(data_iterator):
"""Build the batch."""
data = next(data_iterator)
# only data parallelism; no need for broadcast
if isinstance(data[0], list):
images = [aug.cuda() for aug in data[0]]
else:
images = data[0].cuda()
labels = data[1].cuda()
return images, labels
def loss_func(model, labels, output_tensor, collect_data=False):
args = get_args()
model = unwrap_model(
model,
(torchDDP, LocalDDP, Float16Module)
)
if model.training:
student_output, teacher_output = output_tensor
loss = model.dino_loss(student_output, teacher_output, args.curr_iteration)
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {"loss": averaged_loss[0]}
else:
_, teacher_feature = output_tensor
feature_bank, feature_labels, classes = get_feature_bank()
feature = F.normalize(teacher_feature.float(), dim=1)
knn_accs = []
for k in [10, 20, 100, 200]:
pred_labels = knn_predict(feature, feature_bank,
feature_labels, classes, k, 0.07)
knn_acc = (pred_labels[:, 0] == labels).float().mean()
knn_accs.append(knn_acc)
averaged_loss = average_losses_across_data_parallel_group(knn_accs)
return 0, {"knn_acc_10": averaged_loss[0],
"knn_acc_20": averaged_loss[1],
"knn_acc_100": averaged_loss[2],
"knn_acc_200": averaged_loss[3]}
def forward_step(data_iterator, model):
"""Forward step."""
timers = get_timers()
# Get the batch.
timers("batch-generator").start()
(
images,
labels,
) = get_batch(data_iterator)
timers("batch-generator").stop()
return model(images), partial(loss_func, model, labels)
def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build train, valid, and test datasets."""
args = get_args()
print_rank_0(
"> building train, validation, and test datasets " "for VIT ..."
)
train_ds, valid_ds = build_train_valid_datasets(
data_path=args.data_path,
image_size=(args.img_h, args.img_w)
)
print_rank_0("> finished creating VIT datasets ...")
return train_ds, valid_ds, None
if __name__ == "__main__":
pretrain(
train_valid_test_datasets_provider,
model_provider,
ModelType.encoder_or_decoder,
forward_step,
args_defaults={'dataloader_type': 'cyclic', 'vision_pretraining': True}
)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain VIT"""
import torch
import torch.nn.functional as F
from functools import partial
from megatron import get_args, get_timers, mpu, print_rank_0, print_rank_last
from megatron.data.vit_dataset import build_train_valid_datasets
from megatron.model.vision.inpainting import VitInpaintingModel
from megatron.model.vision.inpainting import MitInpaintingModel
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
from tasks.vision.metrics import SSIM, PSNR
from megatron.model import ModelType
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
args = get_args()
if args.vision_backbone_type == 'vit':
model = VitInpaintingModel(pre_process=pre_process,
post_process=post_process)
elif args.vision_backbone_type == 'mit':
model = MitInpaintingModel(pre_process=pre_process,
post_process=post_process)
else:
raise Exception('{} vision backbone is not supported.'.format(
args.vision_backbone_type))
return model
def get_batch(data_iterator):
"""Build the batch."""
data = next(data_iterator)
# only data parallelism; no need for broadcast
images = data[0][0].cuda()
masks = data[0][1].cuda()
return images, masks
def loss_func(images, masks, masked_images, outputs, collect_data=False):
outputs = outputs.contiguous().float()
masks_flip = 1-masks
flip_masked_outputs = outputs.masked_fill(masks_flip.bool(), 0)
flip_masked_images = images.masked_fill(masks_flip.bool(), 0)
ssim_fun = SSIM()
psnr_fun = PSNR()
if not collect_data:
mask_count = torch.count_nonzero(masks)
loss = F.mse_loss(
flip_masked_outputs,
flip_masked_images.float(),
reduction="sum"
)
loss = loss/mask_count
ssim = ssim_fun(flip_masked_outputs, flip_masked_images.float())
psnr = psnr_fun(flip_masked_outputs, flip_masked_images.float())
averaged_loss = average_losses_across_data_parallel_group(
[loss, psnr, ssim]
)
return loss, {"loss": averaged_loss[0],
"psnr": averaged_loss[1],
'ssim': averaged_loss[2]}
else:
synth_images = masked_images.float() + flip_masked_outputs
ssim = ssim_fun(synth_images, images.float())
psnr = psnr_fun(synth_images, images.float())
return torch.cat((images, masked_images, synth_images), dim=2), ssim, psnr
def forward_step(data_iterator, model):
"""Forward step."""
timers = get_timers()
# Get the batch.
timers("batch-generator").start()
(
images,
masks,
) = get_batch(data_iterator)
timers("batch-generator").stop()
masked_images = images.masked_fill(masks.bool(), 0)
outputs = model(masked_images)
# Forward mode
return outputs, partial(loss_func, images, masks, masked_images)
def process_non_loss_data(data, iteration, writer):
psnr_sum = 0
ssim_sum = 0
for (output_tb, ssim, psnr) in data:
output_tb[output_tb < 0] = 0
output_tb[output_tb > 1] = 1
writer.add_images("gt-input-output-vald", output_tb,
global_step=iteration, walltime=None,
dataformats='NCHW')
psnr_sum = psnr_sum + psnr.item()
ssim_sum = ssim_sum + ssim.item()
psnr = psnr_sum/len(data)
ssim = ssim_sum/len(data)
writer.add_scalar('PSNR generate value-validation', psnr, iteration)
writer.add_scalar('SSIM generate value-validation', ssim, iteration)
def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build train, valid, and test datasets."""
args = get_args()
print_rank_0(
"> building train, validation, and test datasets " "for VIT ..."
)
train_ds, valid_ds = build_train_valid_datasets(
data_path=args.data_path,
image_size=(args.img_h, args.img_w)
)
print_rank_0("> finished creating VIT datasets ...")
return train_ds, valid_ds, None
if __name__ == "__main__":
pretrain(
train_valid_test_datasets_provider,
model_provider,
ModelType.encoder_or_decoder,
forward_step,
process_non_loss_data,
args_defaults={'dataloader_type': 'cyclic', 'vision_pretraining': True}
)
...@@ -25,6 +25,7 @@ from megatron import get_timers ...@@ -25,6 +25,7 @@ from megatron import get_timers
from megatron import mpu from megatron import mpu
from megatron.checkpointing import load_checkpoint from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint from megatron.checkpointing import save_checkpoint
from megatron.model import ModelType
from megatron.training import evaluate_and_print_results from megatron.training import evaluate_and_print_results
from megatron.training import setup_model_and_optimizer from megatron.training import setup_model_and_optimizer
from megatron.training import train_step from megatron.training import train_step
...@@ -153,7 +154,7 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset, ...@@ -153,7 +154,7 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset,
return train_dataloader, valid_dataloader return train_dataloader, valid_dataloader
def _train(model, optimizer, lr_scheduler, forward_step, def _train(model, optimizer, opt_param_scheduler, forward_step,
train_dataloader, valid_dataloader, end_of_epoch_callback): train_dataloader, valid_dataloader, end_of_epoch_callback):
"""Train the model.""" """Train the model."""
args = get_args() args = get_args()
...@@ -194,7 +195,7 @@ def _train(model, optimizer, lr_scheduler, forward_step, ...@@ -194,7 +195,7 @@ def _train(model, optimizer, lr_scheduler, forward_step,
start_iteration = 0 start_iteration = 0
# Train for one step. # Train for one step.
out = train_step(forward_step, batch, model, optimizer, lr_scheduler) out = train_step(forward_step, batch, model, optimizer, opt_param_scheduler)
losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = out losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = out
iteration += 1 iteration += 1
...@@ -214,13 +215,13 @@ def _train(model, optimizer, lr_scheduler, forward_step, ...@@ -214,13 +215,13 @@ def _train(model, optimizer, lr_scheduler, forward_step,
if args.adlr_autoresume and \ if args.adlr_autoresume and \
(iteration % args.adlr_autoresume_interval == 0): (iteration % args.adlr_autoresume_interval == 0):
check_adlr_autoresume_termination(iteration, model, check_adlr_autoresume_termination(iteration, model,
optimizer, lr_scheduler) optimizer, opt_param_scheduler)
# Checkpointing # Checkpointing
saved_checkpoint = False saved_checkpoint = False
if args.save and args.save_interval and \ if args.save and args.save_interval and \
iteration % args.save_interval == 0: iteration % args.save_interval == 0:
save_checkpoint(iteration, model, optimizer, lr_scheduler) save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
saved_checkpoint = True saved_checkpoint = True
# Evaluation # Evaluation
...@@ -233,14 +234,14 @@ def _train(model, optimizer, lr_scheduler, forward_step, ...@@ -233,14 +234,14 @@ def _train(model, optimizer, lr_scheduler, forward_step,
# Exiting based on iterations # Exiting based on iterations
if args.exit_interval and iteration % args.exit_interval == 0: if args.exit_interval and iteration % args.exit_interval == 0:
if not saved_checkpoint: if not saved_checkpoint:
save_checkpoint(iteration, model, optimizer, lr_scheduler) save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
torch.distributed.barrier() torch.distributed.barrier()
print_rank_0('exiting program at iteration {}'.format(iteration)) print_rank_0('exiting program at iteration {}'.format(iteration))
sys.exit() sys.exit()
# Checkpointing at the end of each epoch. # Checkpointing at the end of each epoch.
if args.save: if args.save:
save_checkpoint(iteration, model, optimizer, lr_scheduler) save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
# Callback at the end of each epoch. # Callback at the end of each epoch.
if end_of_epoch_callback is not None: if end_of_epoch_callback is not None:
...@@ -248,6 +249,7 @@ def _train(model, optimizer, lr_scheduler, forward_step, ...@@ -248,6 +249,7 @@ def _train(model, optimizer, lr_scheduler, forward_step,
def finetune(train_valid_datasets_provider, model_provider, def finetune(train_valid_datasets_provider, model_provider,
model_type=ModelType.encoder_or_decoder,
forward_step=_cross_entropy_forward_step, forward_step=_cross_entropy_forward_step,
end_of_epoch_callback_provider=None, end_of_epoch_callback_provider=None,
task_collate_fn=None): task_collate_fn=None):
...@@ -277,7 +279,7 @@ def finetune(train_valid_datasets_provider, model_provider, ...@@ -277,7 +279,7 @@ def finetune(train_valid_datasets_provider, model_provider,
# Build model, optimizer and learning rate scheduler. # Build model, optimizer and learning rate scheduler.
timers('model and optimizer').start() timers('model and optimizer').start()
model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider) model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider, model_type)
timers('model and optimizer').stop() timers('model and optimizer').stop()
# If pretrained checkpoint is provided and we have not trained for # If pretrained checkpoint is provided and we have not trained for
...@@ -305,7 +307,7 @@ def finetune(train_valid_datasets_provider, model_provider, ...@@ -305,7 +307,7 @@ def finetune(train_valid_datasets_provider, model_provider,
# Finetune the model. # Finetune the model.
if args.epochs > 0: if args.epochs > 0:
_train(model, optimizer, lr_scheduler, forward_step, _train(model, optimizer, opt_param_scheduler, forward_step,
train_dataloader, valid_dataloader, end_of_epoch_callback) train_dataloader, valid_dataloader, end_of_epoch_callback)
# Or just evaluate. # Or just evaluate.
else: else:
......
# Multi-Stage Prompting for Knowledgeable Dialogue Generation
Below we present the steps to run our multi-stage dialogue prompting (MSDP) framework.
## Multi-Stage Dialogue Prompting
### Data Preparation
1. Dataset Download: [Wizard of Wikipedia](https://parl.ai/projects/wizard_of_wikipedia/) and [Wizard of Internet](https://parl.ai/projects/sea/)
2. Data Processing: We provide a script for the [`data processing`](../../examples/msdp/data_processing.sh) of the datasets.
### Stage-1: Prompting for Knowledge Generation
1. We provide the script to perform the [`first-stage prompting`](../../examples/msdp/prompt_knwl_gen.sh) for the knowledge generation.
2. We provide the [`evaluation script`](../../examples/msdp/eval_knwl_generation.sh) for the automatic evaluation (i.e., F1, BLEU, METEOR, and ROUGE-L) of the knowledge generation.
### Stage-2: Prompting for Response Generation
1. We provide the script to [`prepare the input file`](../../examples/msdp/prep_resp_gen.sh) for the response generation (based on the previously generated knowledge file).
2. We provide the script to perform the [`second-stage prompting`](../../examples/msdp/prompt_resp_gen.sh) for the response generation.
3. We provide the [`evaluation script`](../../examples/msdp/eval_resp_generation.sh) for the automatic evaluation (i.e., F1, KF1, BLEU, METEOR, and ROUGE-L) of the response generation.
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model evaluation"""
from megatron import get_args
from megatron import print_rank_0
from tasks.msdp.metrics import F1Metric
from tqdm import tqdm
def evaluate_f1(guess_file, answer_file):
"""Evaluating F1 Score"""
guess_list = []
print_rank_0('reading %s' % guess_file)
with open(guess_file, "r") as f:
for i, line in enumerate(tqdm(f)):
line = line.strip()
if "<|endoftext|>" in line:
line = line.replace("<|endoftext|>", "")
guess_list.append(line)
answer_list = []
print_rank_0('reading %s' % answer_file)
with open(answer_file, "r") as f:
for i, line in enumerate(tqdm(f)):
line = line.strip()
if line == "no_passages_used":
line = ""
answer_list.append(line)
assert len(guess_list) == len(answer_list), \
"lengths of guess and answer are different!"
precision, recall, f1 = F1Metric.compute_all_pairs(guess_list, answer_list)
print_rank_0('Precision: %.4f; recall: %.4f; f1: %.4f' % (precision, recall, f1))
print_rank_0('done :-)')
def main():
args = get_args()
evaluate_f1(args.guess_file, args.answer_file)
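For illustration, a direct call with hypothetical file paths; both files are expected to contain one generated or reference sentence per line, aligned by line number:

# File names are placeholders, not part of the repository.
evaluate_f1('generated_knowledge.txt', 'reference_knowledge.txt')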
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run multi-stage dialogue prompting (MSDP)."""
import os
import sys
sys.path.append(os.path.abspath(os.path.join(
os.path.join(os.path.dirname(__file__), os.path.pardir), os.path.pardir)))
from megatron import get_args
from megatron.initialize import initialize_megatron
def get_tasks_args(parser):
"""Provide extra arguments required for tasks."""
group = parser.add_argument_group(title='tasks')
# parameters for the knowledgeable dialogue generation
group.add_argument('--task', type=str, required=True,
help='Task name.')
group.add_argument("--sample-input-file", type=str, default=None,
help='Get input from file instead of interactive mode, '
'each line is an input.')
group.add_argument("--sample-output-file", type=str, default=None,
help='Output file got from --sample-input-file')
group.add_argument('--prompt-file', type=str, default=None,
help='prompting file')
group.add_argument('--prompt-type', type=str, default=None,
choices=['knowledge', 'response'],
help='prompt type (knowledge or response)')
group.add_argument('--num-prompt-examples', type=int, default=10,
help='number of prompt examples')
group.add_argument('--guess-file', type=str, default=None,
help='datapath for generated sentences')
group.add_argument('--answer-file', type=str, default=None,
help='datapath for golden sentences')
group.add_argument('--out-seq-length', type=int, default=100,
help='output sequence length')
group.add_argument('--api-prompt', default=False, action="store_true",
help='setup model api for prompting')
group.add_argument('--megatron-api-url', type=str, default=None,
help='url of the megatron api')
return parser
if __name__ == '__main__':
initialize_megatron(extra_args_provider=get_tasks_args)
args = get_args()
if args.num_layers_per_virtual_pipeline_stage is not None:
print("Interleaved pipeline schedule is not yet supported for downstream tasks.")
exit()
if args.task == 'MSDP-PROMPT':
from tasks.msdp.prompt import main
elif args.task == 'MSDP-EVAL-F1':
from tasks.msdp.evaluate import main
else:
raise NotImplementedError('Task {} is not implemented.'.format(
args.task))
main()
# The following code is adapted from
# https://github.com/facebookresearch/ParlAI/blob/master/parlai/core/metrics.py,
# which is licensed under the MIT license. More details on the license can be
# found at https://github.com/facebookresearch/ParlAI/blob/master/LICENSE.
"""Provides standard metric evaluations for dialog."""
from collections import Counter
from typing import List
import numpy as np
import re
re_art = re.compile(r'\b(a|an|the)\b')
re_punc = re.compile(r'[!"#$%&()*+,-./:;<=>?@\[\]\\^`{|}~_\']')
def normalize_answer(s):
"""
Lower text and remove punctuation, articles and extra whitespace.
"""
s = s.lower()
s = re_punc.sub(' ', s)
s = re_art.sub(' ', s)
s = ' '.join(s.split())
return s
class F1Metric:
"""
Helper class which computes token-level F1.
"""
@staticmethod
def _prec_recall_f1_score(pred_items, gold_items):
"""
Compute precision, recall and f1 given a set of gold and prediction items.
:param pred_items: iterable of predicted values
:param gold_items: iterable of gold values
:return: tuple (p, r, f1) for precision, recall, f1
"""
common = Counter(gold_items) & Counter(pred_items)
num_same = sum(common.values())
if num_same == 0:
return 0, 0, 0
precision = 1.0 * num_same / len(pred_items)
recall = 1.0 * num_same / len(gold_items)
f1 = (2 * precision * recall) / (precision + recall)
return precision, recall, f1
@staticmethod
def compute_each_pair(guess: str, answer: str):
if answer == "":
return None, None, None
if guess == "":
return 0, 0, 0
g_tokens = normalize_answer(guess).split()
a_tokens = normalize_answer(answer).split()
precision, recall, f1 = F1Metric._prec_recall_f1_score(g_tokens, a_tokens)
return precision, recall, f1
@staticmethod
def compute_all_pairs(guesses: List[str], answers: List[str]):
# sanity check: guesses and answers must have the same length
assert len(guesses) == len(answers)
precision_list, recall_list, f1_list = [], [], []
for guess, answer in zip(guesses, answers):
precision, recall, f1 = F1Metric.compute_each_pair(guess, answer)
if precision is None or recall is None or f1 is None:
continue
precision_list.append(precision)
recall_list.append(recall)
f1_list.append(f1)
return np.mean(precision_list), np.mean(recall_list), np.mean(f1_list)
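A minimal usage sketch of the token-level F1 helper above, with toy strings for illustration only:

guesses = ['The Eiffel Tower is in Paris.', 'It was finished in 1889.']
answers = ['The Eiffel Tower is located in Paris, France.',
           'Construction finished in 1889.']
precision, recall, f1 = F1Metric.compute_all_pairs(guesses, answers)
print('Precision: %.4f; recall: %.4f; f1: %.4f' % (precision, recall, f1))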