Commit 458d7785 authored by rprenger

Fixing merge conflict

parents d7bf1ab5 9c5a830f
@@ -87,18 +87,100 @@ def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):

def gather_split_1d_tensor(tensor):
    """Opposite of above function, gather values from model parallel ranks."""
    numel_gathered = torch.numel(tensor) * \
        get_tensor_model_parallel_world_size()
    gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
                           device=torch.cuda.current_device(),
                           requires_grad=False)
    # TODO: This API is experimental in pytorch (as of Feb 2022) and
    # this might break in future pytorch releases. We chose this API
    # as opposed to torch.distributed.all_gather for efficiency reasons.
    # This API calls directly NCCL all-gather versus the former does
    # internal copies and can potentially cause slow down.
    torch.distributed._all_gather_base(gathered, tensor,
                                       group=get_tensor_model_parallel_group())
    return gathered
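
For context, a stand-alone sketch of the flat-buffer all-gather pattern used above; the helper name and the pre-initialized process group are assumptions for illustration, not part of this commit:

import torch

def gather_chunks_flat(local_chunk, group=None):
    # Gather equal-sized 1-D chunks from every rank into one contiguous buffer.
    world_size = torch.distributed.get_world_size(group=group)
    gathered = torch.empty(world_size * local_chunk.numel(),
                           dtype=local_chunk.dtype,
                           device=local_chunk.device)
    # One NCCL all-gather straight into the flat buffer; the older
    # torch.distributed.all_gather(list_of_chunks, ...) path goes through
    # per-chunk output tensors and extra copies.
    torch.distributed._all_gather_base(gathered, local_chunk.contiguous(),
                                       group=group)
    return gathered
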
def _kernel_make_viewless_tensor(inp, requires_grad):
'''Make a viewless tensor.
View tensors have the undesirable side-affect of retaining a reference
to the originally-viewed tensor, even after manually setting the '.data'
field. This method creates a new tensor that links to the old tensor's
data, without linking the viewed tensor, referenced via the '._base'
field.
'''
out = torch.empty(
(1,),
dtype = inp.dtype,
device = inp.device,
requires_grad = requires_grad,
)
out.data = inp.data
return out
class MakeViewlessTensor(torch.autograd.Function):
'''
Autograd function to make a viewless tensor.
This function should be used in cases where the computation graph needs
to be propagated, but we only want a viewless tensor (e.g.,
ParallelTransformer's hidden_states). Call this function by passing
'keep_graph = True' to 'make_viewless_tensor()'.
'''
@staticmethod
def forward(ctx, inp, requires_grad):
return _kernel_make_viewless_tensor(inp, requires_grad)
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
def make_viewless_tensor(inp, requires_grad, keep_graph):
'''
Entry-point for creating viewless tensors.
This method should be used, rather than calling 'MakeViewlessTensor'
or '_kernel_make_viewless_tensor' directly. This method acts as a
switch for determining if an autograd function or a regular method
should be used to create the tensor.
'''
# return tensor as-is, if not a 'view'
if inp._base is None:
return inp
# create viewless tensor
if keep_graph:
return MakeViewlessTensor.apply(inp, requires_grad)
else:
return _kernel_make_viewless_tensor(inp, requires_grad)
def assert_viewless_tensor(tensor, extra_msg = None):
'''Assert that a tensor is not a view (i.e., its '._base' field is
not set).'''
if isinstance(tensor, list):
[ assert_viewless_tensor(t) for t in tensor ]
return tensor
if not isinstance(tensor, torch.Tensor):
return tensor
assert tensor._base is None, (
"Ensure tensor._base is None before setting tensor.data or storing "
"tensor to memory buffer. Otherwise, a memory leak will occur (and "
"likely accumulate over iterations). %s"
) % extra_msg
return tensor
def safely_set_viewless_tensor_data(tensor, new_data_tensor):
'''Safely set tensor's '.data' field.
Check first that the tensor is viewless (i.e., '._base' not set). If not,
raise an exception.
'''
assert_viewless_tensor(tensor, extra_msg = "FYI, tensor._base has shape %s, and new_data_tensor has shape %s." % ("--" if tensor._base is None else tensor._base.shape, new_data_tensor.shape))
tensor.data = new_data_tensor
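
A minimal illustration (not part of this commit) of the view/'._base' behavior these helpers work around, using only plain PyTorch:

import torch

base = torch.zeros(4, 4)
view = base.view(16)           # a view; view._base refers to base
print(view._base is base)      # True

# Per the docstring above, assigning .data does not drop the ._base link,
# so the original 4x4 buffer stays reachable (and alive) through 'view'.
view.data = torch.zeros(16)
print(view._base is not None)  # still holds a reference to the viewed tensor

# The viewless fix: park the data on a fresh tensor that has no ._base link,
# so the original buffer can be freed and .data can be replaced safely later.
out = torch.empty((1,), dtype=view.dtype, device=view.device)
out.data = view.data
print(out._base is None)       # True
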
class CudaRNGStatesTracker:
    """Tracker for the cuda RNG states.
@@ -242,8 +324,9 @@ class CheckpointFunction(torch.autograd.Function):
        # the chunk corresponding to the current rank.
        if distribute_checkpointed_activations:
            ctx.input_0_shape = args[0].data.shape
            safely_set_viewless_tensor_data(
                args[0],
                split_tensor_into_1d_equal_chunks(args[0].data, new_buffer=True))

        # Store everything.
        ctx.save_for_backward(*args)
@@ -257,8 +340,9 @@ class CheckpointFunction(torch.autograd.Function):
                           "please use .backward() if possible")
        inputs = ctx.saved_tensors
        if ctx.distribute_checkpointed_activations:
            safely_set_viewless_tensor_data(
                inputs[0],
                gather_split_1d_tensor(inputs[0].data).view(ctx.input_0_shape))

        # Store the current states.
        bwd_cpu_rng_state = torch.get_rng_state()
...
@@ -23,35 +23,68 @@ from .grad_scaler import ConstantGradScaler, DynamicGradScaler
from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer


def get_param_groups(modules,
                     no_weight_decay_cond,
                     scale_lr_cond,
                     lr_mult):
    """Create param groups based on the weight decay condition (regularized vs.
    non-regularized) and the learning rate scale condition (args.lr vs.
    lr_mult * args.lr). scale_lr_cond is used during finetuning, where the head
    of the network requires a scaled version of the base learning rate.
    """
    wd_no_scale_lr = []
    wd_scale_lr = []
    no_wd_no_scale_lr = []
    no_wd_scale_lr = []
    for module in modules:
        for name, param in module.named_parameters():
            if not param.requires_grad:
                continue

            if no_weight_decay_cond is not None:
                no_wd = no_weight_decay_cond(name, param)
            else:
                # do not regularize biases nor Norm parameters
                no_wd = name.endswith(".bias") or len(param.shape) == 1

            if scale_lr_cond is not None:
                scale_lr = scale_lr_cond(name, param)
            else:
                scale_lr = False

            if not no_wd and not scale_lr:
                wd_no_scale_lr.append(param)
            elif not no_wd and scale_lr:
                wd_scale_lr.append(param)
            elif no_wd and not scale_lr:
                no_wd_no_scale_lr.append(param)
            else:
                no_wd_scale_lr.append(param)

    param_groups = []
    if len(wd_no_scale_lr):
        param_groups.append({'params': wd_no_scale_lr, 'wd_mult': 1.0, 'lr_mult': 1.0})
    if len(wd_scale_lr):
        param_groups.append({'params': wd_scale_lr, 'wd_mult': 1.0, 'lr_mult': lr_mult})
    if len(no_wd_no_scale_lr):
        param_groups.append({'params': no_wd_no_scale_lr, 'wd_mult': 0.0, 'lr_mult': 1.0})
    if len(no_wd_scale_lr):
        param_groups.append({'params': no_wd_scale_lr, 'wd_mult': 0.0, 'lr_mult': lr_mult})

    return param_groups
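
A small self-contained sketch (not part of this commit) of how the default grouping rule above sorts parameters; the toy model is illustrative and assumes get_param_groups is importable from megatron.optimizer:

import torch

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.LayerNorm(8))

# Mirror the defaults above: biases and 1-D (Norm) parameters get no weight
# decay; nothing gets a scaled learning rate.
groups = get_param_groups([model],
                          no_weight_decay_cond=None,
                          scale_lr_cond=None,
                          lr_mult=1.0)
for g in groups:
    print(g['wd_mult'], g['lr_mult'], [tuple(p.shape) for p in g['params']])
# Expected: one group with wd_mult=1.0 holding the 2-D Linear weight, and one
# group with wd_mult=0.0 holding the Linear bias plus the LayerNorm weight/bias.

The wd_mult/lr_mult entries are consumed later by OptimizerParamScheduler.step(), which scales the per-group lr and weight_decay.
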
def get_megatron_optimizer(model,
                           no_weight_decay_cond=None,
                           scale_lr_cond=None,
                           lr_mult=1.0):
    args = get_args()

    # Base optimizer.
    param_groups = get_param_groups(model,
                                    no_weight_decay_cond,
                                    scale_lr_cond,
                                    lr_mult)

    if args.optimizer == 'adam':
        optimizer = Adam(param_groups,
                         lr=args.lr,
...
@@ -13,19 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Learning rate decay and weight decay incr functions."""

import math

from megatron import print_rank_0


class OptimizerParamScheduler(object):
    """Anneals learning rate and weight decay"""

    def __init__(self, optimizer, max_lr, min_lr,
                 lr_warmup_steps, lr_decay_steps, lr_decay_style,
                 start_wd, end_wd, wd_incr_steps, wd_incr_style,
                 use_checkpoint_opt_param_scheduler=True,
                 override_opt_param_scheduler=False):

        # Class values.
        self.optimizer = optimizer
@@ -35,24 +36,55 @@ class AnnealingLR(object):
        assert self.min_lr >= 0.0
        assert self.max_lr >= self.min_lr

        self.lr_warmup_steps = lr_warmup_steps
        self.num_steps = 0
        self.lr_decay_steps = lr_decay_steps
        assert self.lr_decay_steps > 0
        assert self.lr_warmup_steps < self.lr_decay_steps
        self.lr_decay_style = lr_decay_style

        self.start_wd = start_wd
        self.end_wd = end_wd
        assert self.start_wd >= 0.0
        assert self.end_wd >= self.start_wd
        self.wd_incr_steps = wd_incr_steps
        self.wd_incr_style = wd_incr_style

        self.override_opt_param_scheduler = override_opt_param_scheduler
        self.use_checkpoint_opt_param_scheduler = use_checkpoint_opt_param_scheduler
        if self.override_opt_param_scheduler:
            assert not self.use_checkpoint_opt_param_scheduler, 'both override and '\
                'use-checkpoint are set.'

        # Set the learning rate
        self.step(0)
        print_rank_0('> learning rate decay style: {}'.format(self.lr_decay_style))

    def get_wd(self):
        """Weight decay incr functions"""
        if self.num_steps > self.wd_incr_steps:
            return self.end_wd

        if self.wd_incr_style == 'constant':
            assert self.start_wd == self.end_wd
            return self.end_wd

        incr_ratio = float(self.num_steps) / float(self.wd_incr_steps)
        assert incr_ratio >= 0.0
        assert incr_ratio <= 1.0
        delta_wd = self.end_wd - self.start_wd

        if self.wd_incr_style == 'linear':
            coeff = incr_ratio
        elif self.wd_incr_style == 'cosine':
            coeff = 0.5 * (math.cos(math.pi * (1 - incr_ratio)) + 1.0)
        else:
            raise Exception('{} weight decay increment style is not supported.'.format(
                self.wd_incr_style))

        return self.start_wd + coeff * delta_wd
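
Quick numeric check with illustrative values (the (1 - incr_ratio) flip makes this cosine curve increase from start_wd to end_wd, mirroring the decreasing cosine used for the learning rate below):

import math

start_wd, end_wd, wd_incr_steps, num_steps = 0.01, 0.1, 1000, 500
incr_ratio = num_steps / wd_incr_steps                      # 0.5
coeff = 0.5 * (math.cos(math.pi * (1 - incr_ratio)) + 1.0)  # 0.5
print(start_wd + coeff * (end_wd - start_wd))               # 0.055
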
    def get_lr(self):
@@ -60,33 +92,33 @@ class AnnealingLR(object):
        https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""

        # Use linear warmup for the initial part.
        if self.lr_warmup_steps > 0 and self.num_steps <= self.lr_warmup_steps:
            return self.max_lr * float(self.num_steps) / \
                float(self.lr_warmup_steps)

        # If the learning rate is constant, just return the initial value.
        if self.lr_decay_style == 'constant':
            return self.max_lr

        # For any steps larger than `self.lr_decay_steps`, use `self.min_lr`.
        if self.num_steps > self.lr_decay_steps:
            return self.min_lr

        # If we are done with the warmup period, use the decay style.
        num_steps_ = self.num_steps - self.lr_warmup_steps
        decay_steps_ = self.lr_decay_steps - self.lr_warmup_steps
        decay_ratio = float(num_steps_) / float(decay_steps_)
        assert decay_ratio >= 0.0
        assert decay_ratio <= 1.0
        delta_lr = self.max_lr - self.min_lr

        if self.lr_decay_style == 'linear':
            coeff = (1.0 - decay_ratio)
        elif self.lr_decay_style == 'cosine':
            coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
        else:
            raise Exception('{} decay style is not supported.'.format(
                self.lr_decay_style))

        return self.min_lr + coeff * delta_lr

@@ -95,18 +127,24 @@ class AnnealingLR(object):
        """Set lr for all parameters groups."""
        self.num_steps += increment
        new_lr = self.get_lr()
        new_wd = self.get_wd()
        for group in self.optimizer.param_groups:
            group['lr'] = new_lr * group.get('lr_mult', 1.0)
            group['weight_decay'] = new_wd * group.get('wd_mult', 1.0)

    def state_dict(self):
        state_dict = {
            'max_lr': self.max_lr,
            'lr_warmup_steps': self.lr_warmup_steps,
            'num_steps': self.num_steps,
            'lr_decay_style': self.lr_decay_style,
            'lr_decay_steps': self.lr_decay_steps,
            'min_lr': self.min_lr,
            'start_wd': self.start_wd,
            'end_wd': self.end_wd,
            'wd_incr_style': self.wd_incr_style,
            'wd_incr_steps': self.wd_incr_steps
        }
        return state_dict

@@ -114,13 +152,13 @@ class AnnealingLR(object):
    def _check_and_set(self, cls_value, sd_value, name):
        """Auxiliary function for checking the values in the checkpoint and
        setting them."""
        if self.override_opt_param_scheduler:
            print_rank_0(' > overriding {} value to {}'.format(name, cls_value))
            return cls_value

        if not self.use_checkpoint_opt_param_scheduler:
            assert cls_value == sd_value, \
                f'OptimizerParamScheduler: class input value {cls_value} and checkpoint' \
                f'value {sd_value} for {name} do not match'
        print_rank_0(' > using checkpoint value {} for {}'.format(sd_value,
                                                                  name))

@@ -140,25 +178,57 @@ class AnnealingLR(object):
                                      'minimum learning rate')

        if 'warmup_iter' in sd:
            lr_warmup_steps_ = sd['warmup_iter']
        elif 'warmup_steps' in sd:
            lr_warmup_steps_ = sd['warmup_steps']
        else:
            lr_warmup_steps_ = sd['lr_warmup_steps']
        self.lr_warmup_steps = self._check_and_set(self.lr_warmup_steps,
                                                   lr_warmup_steps_,
                                                   'warmup iterations')

        if 'end_iter' in sd:
            lr_decay_steps_ = sd['end_iter']
        elif 'decay_steps' in sd:
            lr_decay_steps_ = sd['decay_steps']
        else:
            lr_decay_steps_ = sd['lr_decay_steps']
        self.lr_decay_steps = self._check_and_set(self.lr_decay_steps, lr_decay_steps_,
                                                  'total number of iterations')

        if 'decay_style' in sd:
            lr_decay_style_ = sd['decay_style']
        else:
            lr_decay_style_ = sd['lr_decay_style']
        self.lr_decay_style = self._check_and_set(self.lr_decay_style,
                                                  lr_decay_style_,
                                                  'learning rate decay style')

        if 'num_iters' in sd:
            num_steps = sd['num_iters']
        else:
            num_steps = sd['num_steps']
        self.step(increment=num_steps)

        if 'start_wd' in sd:
            self.start_wd = self._check_and_set(self.start_wd,
                                                sd['start_wd'],
                                                "start weight decay")
            self.end_wd = self._check_and_set(self.end_wd,
                                              sd['end_wd'],
                                              "end weight decay")
            self.wd_incr_steps = self._check_and_set(self.wd_incr_steps,
                                                     sd['wd_incr_steps'],
                                                     "total number of weight decay iterations")
            self.wd_incr_style = self._check_and_set(self.wd_incr_style,
                                                     sd['wd_incr_style'],
                                                     "weight decay incr style")
@@ -142,10 +142,16 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
    if recv_prev:
        tensor_recv_prev = mpu.gather_split_1d_tensor(
            tensor_recv_prev).view(tensor_shape).requires_grad_()
        tensor_recv_prev = mpu.make_viewless_tensor(tensor_recv_prev,
                                                    requires_grad = True,
                                                    keep_graph = False)

    if recv_next:
        tensor_recv_next = mpu.gather_split_1d_tensor(
            tensor_recv_next).view(tensor_shape).requires_grad_()
        tensor_recv_next = mpu.make_viewless_tensor(tensor_recv_next,
                                                    requires_grad = True,
                                                    keep_graph = False)

    return tensor_recv_prev, tensor_recv_next
...
@@ -15,6 +15,7 @@
from contextlib import contextmanager
import torch
from torch.autograd.variable import Variable
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron import get_args
@@ -33,17 +34,80 @@ def get_forward_backward_func():
    if mpu.get_pipeline_model_parallel_world_size() > 1:
        if args.virtual_pipeline_model_parallel_size is not None:
            forward_backward_func = forward_backward_pipelining_with_interleaving
            assert get_num_microbatches() % \
                args.pipeline_model_parallel_size == 0, \
                'number of microbatches (%d) is not divisible by pipeline-' \
                'model-parallel-size (%d) when using interleaved schedule' % (
                    get_num_microbatches(),
                    args.pipeline_model_parallel_size,
                )
        else:
            forward_backward_func = forward_backward_pipelining_without_interleaving
    else:
        forward_backward_func = forward_backward_no_pipelining
    return forward_backward_func
def deallocate_output_tensor(out):
    '''Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field.
This method should be called right after the output tensor has been
sent to the next pipeline stage. At this point, the output tensor is
only useful for its '.grad_fn' field, and not its '.data'.
'''
if out is None:
return
assert isinstance(out, torch.Tensor), \
"expected Tensor, found %s." % type(out).__name__
assert out._base is None, \
"counter-productive to free a view of another tensor."
out.data = torch.empty(
(1,),
device = out.device,
dtype = out.dtype,
)
def custom_backward(output, grad_output):
'''Directly call C++ autograd engine.
To make the 'deallocate_output_tensor' (above) optimization work, the C++
autograd engine must be called directly, bypassing Pytorch's
torch.autograd.backward. Pytorch's 'backward' checks that the output and
grad have the same shape, while C++'s 'backward' does not.
'''
assert output.numel() == 1, \
"output should be pseudo-'freed' in schedule, to optimize memory"
assert isinstance(output, torch.Tensor), \
"output == '%s'." % type(output).__name__
assert isinstance(grad_output, (torch.Tensor, type(None))), \
"grad_output == '%s'." % type(grad_output).__name__
# Handle scalar output
if grad_output is None:
assert output.numel() == 1, "implicit grad requires scalar output."
grad_output = torch.ones_like(
output,
memory_format = torch.preserve_format,
)
# Call c++ engine [ see torch/csrc/autograd/python_engine.cpp ]
Variable._execution_engine.run_backward(
tensors = (output,),
grad_tensors = (grad_output,),
keep_graph = False,
create_graph = False,
inputs = tuple(),
allow_unreachable=True,
accumulate_grad=True,
)
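
A condensed sketch (not part of this commit) of how these two helpers are meant to be paired by the schedules below; the toy tensors are illustrative only:

import torch

x = torch.randn(4, 4, requires_grad=True)
out = x * 2                       # stand-in for a pipeline-stage output, shape (4, 4)

# 1) After 'out' has been sent to the next stage, free its storage; only
#    out.grad_fn is still needed locally.
deallocate_output_tensor(out)
assert out.numel() == 1           # .data is now a 1-element placeholder

# 2) The next stage later sends back d(loss)/d(out) with the *original* shape.
grad_of_out = torch.full((4, 4), 0.5)

# torch.autograd.backward would reject the shape mismatch between 'out' (now a
# scalar placeholder) and grad_of_out, so backward goes through the C++ engine.
custom_backward(out, grad_of_out)
print(x.grad)                     # 1.0 everywhere (0.5 * d(2x)/dx)
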
def forward_step(forward_step_func,
                 data_iterator,
                 model,
                 input_tensor,
                 forward_data_store,
                 collect_non_loss_data=False):
    """Forward step for passed-in model.

    If first stage, input tensor is obtained from data_iterator, otherwise
@@ -65,10 +129,15 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_r
    unwrapped_model.set_input_tensor(input_tensor)
    output_tensor, loss_func = forward_step_func(data_iterator, model)
    if mpu.is_pipeline_last_stage():
        if not collect_non_loss_data:
            output_tensor = loss_func(output_tensor)
            loss, loss_reduced = output_tensor
            output_tensor = loss / get_num_microbatches()
            forward_data_store.append(loss_reduced)
        else:
            data = loss_func(output_tensor, non_loss_data=True)
            forward_data_store.append(data)

    timers('forward-compute').stop()

    # If T5 model (or other model with encoder and decoder)
@@ -116,7 +185,7 @@ def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
    # Backward pass.
    if output_tensor_grad[0] is None:
        output_tensor = optimizer.scale_loss(output_tensor[0])
    custom_backward(output_tensor[0], output_tensor_grad[0])

    # Collect the grad of the input_tensor.
    input_tensor_grad = [None]
@@ -151,8 +220,12 @@ def dummy_handler():
    pass


def forward_backward_no_pipelining(forward_step_func,
                                   data_iterator, model,
                                   optimizer,
                                   timers,
                                   forward_only,
                                   collect_non_loss_data=False):
    """Run forward and backward passes with no pipeline parallelism
    (no inter-stage communication).
@@ -164,35 +237,41 @@ def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
    if isinstance(model, torchDDP):
        context_handler = model.no_sync

    forward_data_store = []
    input_tensor, output_tensor_grad = None, None
    with context_handler():
        for i in range(get_num_microbatches() - 1):
            output_tensor = forward_step(forward_step_func, data_iterator,
                                         model, input_tensor, forward_data_store,
                                         collect_non_loss_data)
            if not forward_only:
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad)

    # Run computation for last microbatch out of context handler (want to
    # synchronize gradients).
    output_tensor = forward_step(forward_step_func, data_iterator,
                                 model, input_tensor, forward_data_store,
                                 collect_non_loss_data)
    if not forward_only:
        backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad)

    return forward_data_store


def forward_backward_pipelining_with_interleaving(forward_step_func,
                                                  data_iterator, model,
                                                  optimizer,
                                                  timers,
                                                  forward_only,
                                                  collect_non_loss_data=False):
    """Run interleaved 1F1B schedule (model split into model chunks), with
    communication between pipeline stages as needed.

    Returns dictionary with losses if the last stage, empty dict otherwise."""
    input_tensors = [[] for _ in range(len(model))]
    output_tensors = [[] for _ in range(len(model))]
    forward_data_store = []
    if not forward_only:
        output_tensor_grads = [[] for _ in range(len(model))]
@@ -252,7 +331,9 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
        output_tensor = forward_step(forward_step_func,
                                     data_iterator[model_chunk_id],
                                     model[model_chunk_id],
                                     input_tensor,
                                     forward_data_store,
                                     collect_non_loss_data)
        output_tensors[model_chunk_id].append(output_tensor)

        # if forward-only, no need to save tensors for a backward pass
@@ -325,6 +406,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
                tensor_shape=tensor_shape,
                timers=timers)
        input_tensors[next_forward_model_chunk_id].append(input_tensor)
        deallocate_output_tensor(output_tensor)

    # Run 1F1B in steady state.
    for k in range(num_microbatches_remaining):
@@ -388,6 +470,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
                output_tensor, input_tensor_grad,
                recv_prev=recv_prev, recv_next=recv_next,
                tensor_shape=tensor_shape, timers=timers)
        deallocate_output_tensor(output_tensor)

        # Put input_tensor and output_tensor_grad in data structures in the
        # right location.
@@ -417,7 +500,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
                    tensor_shape=tensor_shape,
                    timers=timers))

    return forward_data_store


def get_tensor_shapes(rank, model_type):
@@ -514,13 +597,18 @@ def send_backward_recv_forward(input_tensor_grads, tensor_shapes, timers):
    return input_tensors


def forward_backward_pipelining_without_interleaving(forward_step_func,
                                                     data_iterator,
                                                     model,
                                                     optimizer,
                                                     timers,
                                                     forward_only,
                                                     collect_non_loss_data=False):
    """Run non-interleaved 1F1B schedule, with communication between pipeline
    stages.

    Returns dictionary with losses if the last stage, empty dict otherwise."""
    args = get_args()
    timers = get_timers()

    assert len(model) == 1
@@ -550,18 +638,20 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
    if not forward_only:
        input_tensors = []
        output_tensors = []
    forward_data_store = []

    # Run warmup forward passes.
    for i in range(num_warmup_microbatches):
        input_tensor = recv_forward(recv_tensor_shapes, timers=timers)
        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, forward_data_store,
                                     collect_non_loss_data)
        send_forward(output_tensor, send_tensor_shapes, timers=timers)

        if not forward_only:
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            deallocate_output_tensor(output_tensor[0])

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
@@ -574,7 +664,8 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
        last_iteration = (i == (num_microbatches_remaining - 1))

        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, forward_data_store,
                                     collect_non_loss_data)
        if forward_only:
            send_forward(output_tensor, send_tensor_shapes, timers=timers)
@@ -590,6 +681,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
            # Add input_tensor and output_tensor to end of list.
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            deallocate_output_tensor(output_tensor[0])

            # Pop input_tensor and output_tensor from the start of the list for
            # the backward pass.
@@ -622,4 +714,4 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
            send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers)

    return forward_data_store
<!-- coding=utf-8-->
<!-- Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.-->
<!---->
<!-- Licensed under the Apache License, Version 2.0 (the "License");-->
<!-- you may not use this file except in compliance with the License.-->
<!-- You may obtain a copy of the License at-->
<!---->
<!-- http://www.apache.org/licenses/LICENSE-2.0-->
<!---->
<!-- Unless required by applicable law or agreed to in writing, software-->
<!-- distributed under the License is distributed on an "AS IS" BASIS,-->
<!-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.-->
<!-- See the License for the specific language governing permissions and-->
<!-- limitations under the License.-->
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8" />
<title>Megatron</title>
<style>
.wrapper {
max-width: 75%;
margin: auto;
}
h1 {
margin: 3rem 0 1rem 0;
padding: 0;
font-size: 1.5rem;
}
textarea {
width: 100%;
min-height: 300px;
resize: none;
border-radius: 8px;
border: 1px solid #ddd;
padding: 0.5rem;
box-shadow: inset 0 0 0.25rem #ddd;
&:focus {
outline: none;
border: 1px solid darken(#ddd, 5%);
box-shadow: inset 0 0 0.5rem darken(#ddd, 5%);
}
}
#the-count {
float: right;
padding: 0.1rem 0 0 0;
font-size: 0.875rem;
}
/* Chat containers */
.container {
font-family: 'Arial', sans-serif;
font-size: 16px;
border: 2px solid #dedede;
background-color: #f1f1f1;
border-radius: 5px;
padding: 15px;
margin: 10px 0;
}
/* Clear floats */
.container::after {
content: "";
clear: both;
display: table;
}
/* Style images */
.container img {
float: left;
max-width: 60px;
width: 100%;
margin-right: 20px;
border-radius: 50%;
}
</style>
</head>
<body>
<div class="wrapper">
<h1>Prompt Megatron</h1>
<textarea name="prompt" id="prompt" maxlength="1024" placeholder="Add prompt" autofocus></textarea>
<label for="tokens_to_generate">Number of tokens to generate (10-256):</label>
<input type="number" id="tokens_to_generate" name="tokens_to_generate" min="10" max="256" value="32">
<button onclick="submit_query()">Submit</button>
<div id="the-count">
<span id="current">0</span>
<span id="maximum">/ 1000</span>
</div>
<textarea name="response" id="response" maxlength="2048" placeholder="Megatron response..."></textarea>
</div>
<script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.5.1/jquery.min.js"></script>
<script type="text/javascript">
function submit_query() {
$("#response").val("Waiting for Megatron response...");
$.ajax({
url:"api",
type:"PUT",
data:JSON.stringify({prompts: [$("#prompt").val()], tokens_to_generate: parseInt($("#tokens_to_generate").val(),10)}),
contentType:"application/json; charset=utf-8",
dataType:"json",
success: function(data){
data.max_len=35;
$("#response").val(data.text);
}
});
}
$('textarea').keyup(function() {
var characterCount = $(this).val().length,
current = $('#current'),
maximum = $('#maximum'),
theCount = $('#the-count');
current.text(characterCount);
if (characterCount >= 800) {
maximum.css('color', '#8f0001');
current.css('color', '#8f0001');
theCount.css('font-weight','bold');
} else {
maximum.css('color','#666');
theCount.css('font-weight','normal');
}
});
</script>
</body>
</html>
@@ -35,7 +35,10 @@ def generate_and_post_process(model,
                              top_p_sampling=0.0,
                              temperature=1.0,
                              add_BOS=False,
                              use_eod_token_for_early_termination=True,
                              stop_on_double_eol=False,
                              stop_on_eol=False,
                              random_seed=-1):
    """Run inference and post-process outputs, i.e., detokenize,
    move to cpu and convert to list."""
@@ -49,7 +52,10 @@ def generate_and_post_process(model,
                      top_p_sampling=top_p_sampling,
                      temperature=temperature,
                      add_BOS=add_BOS,
                      use_eod_token_for_early_termination=use_eod_token_for_early_termination,
                      stop_on_double_eol=stop_on_double_eol,
                      stop_on_eol=stop_on_eol,
                      random_seed=random_seed)

    # Only post-process on first stage.
    if mpu.is_pipeline_first_stage():
@@ -74,7 +80,10 @@ def generate(model,
             top_p_sampling=0.0,
             temperature=1.0,
             add_BOS=False,
             use_eod_token_for_early_termination=True,
             stop_on_double_eol=False,
             stop_on_eol=False,
             random_seed=-1):
    """Given prompts and input parameters, run inference and return:
       tokens: prompts plus the generated tokens.
       lengths: length of the prompt + generations. Note that we can
@@ -87,8 +96,11 @@ def generate(model,
    values = [tokens_to_generate,
              return_output_log_probs,
              top_k_sampling, top_p_sampling,
              temperature, add_BOS, use_eod_token_for_early_termination,
              stop_on_double_eol,
              stop_on_eol,
              random_seed]
    values_float_tensor = broadcast_float_list(10, float_list=values)
    tokens_to_generate = int(values_float_tensor[0].item())
    return_output_log_probs = bool(values_float_tensor[1].item())
    top_k_sampling = int(values_float_tensor[2].item())
@@ -96,6 +108,12 @@ def generate(model,
    temperature = values_float_tensor[4].item()
    add_BOS = bool(values_float_tensor[5].item())
    use_eod_token_for_early_termination = bool(values_float_tensor[6].item())
    stop_on_double_eol = bool(values_float_tensor[7].item())
    stop_on_eol = bool(values_float_tensor[8].item())
    random_seed = int(values_float_tensor[9].item())

    if random_seed != -1:
        torch.random.manual_seed(random_seed)

    # Tokenize prompts and get the batch.
    # Note that these tensors are broadcasted to all ranks.
@@ -108,7 +126,7 @@ def generate(model,
    if tokens_to_generate == 0:
        return score_and_return_on_first_stage(
            model, context_tokens_tensor, context_length_tensor)

    # Main inference function.
    # Note that the outputs are available on the first stage.
    return generate_tokens_probs_and_return_on_first_stage(
@@ -117,4 +135,6 @@ def generate(model,
        top_k=top_k_sampling,
        top_p=top_p_sampling,
        temperature=temperature,
        use_eod_token_for_early_termination=use_eod_token_for_early_termination,
        stop_on_double_eol=stop_on_double_eol,
        stop_on_eol=stop_on_eol)
@@ -96,7 +96,10 @@ def generate_tokens_probs_and_return_on_first_stage(
        return_output_log_probs=False,
        top_k=0, top_p=0.0,
        temperature=1.0,
        use_eod_token_for_early_termination=True,
        stop_on_double_eol=False,
        stop_on_eol=False
        ):
    """Main token generation function.

    Arguments:
        model: no interleaving is supported.
@@ -130,6 +133,10 @@ def generate_tokens_probs_and_return_on_first_stage(
    min_prompt_length = lengths.min().item()
    max_sequence_length = tokens.size(1)
    max_sequence_length = min(max_sequence_length, args.max_position_embeddings)

    # If the context is too big, this happens
    if min_prompt_length >= max_sequence_length:
        raise ValueError("context length + tokens_to_generate too large")

    # forward step.
    forward_step = ForwardStep(model, batch_size, max_sequence_length)
@@ -227,8 +234,20 @@ def generate_tokens_probs_and_return_on_first_stage(
            # Check if all the sequences have hit the termination_id.
            done = None
            if mpu.is_pipeline_last_stage():
                # TODO(rprenger) These stopping methods are tokenizer dependent
                # instead tokenization should be in the inference loop so stop sequences can be used
                if stop_on_double_eol:
                    hit_double_eol = (new_sample == 628).byte() & started.byte()
                    hit_two_eols = (new_sample == 198).byte() & (tokens[:, context_length-1] == 198).byte() & started.byte()
                    done_token = hit_double_eol | hit_two_eols
                elif stop_on_eol:
                    hit_double_eol = (new_sample == 628).byte() & started.byte()
                    hit_eol = (new_sample == 198).byte() & started.byte()
                    done_token = hit_double_eol | hit_eol
                else:
                    done_token = (new_sample == termination_id).byte() & \
                        started.byte()

                just_finished = (done_token & ~is_generation_done).bool()
                generated_sequence_lengths[just_finished.view(-1)] = \
                    context_length + 1
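
Note: the hardcoded ids 628 and 198 appear to be the GPT-2 BPE tokens for "\n\n" and "\n", which is what the TODO above is pointing at. A hypothetical, tokenizer-agnostic lookup might look like the sketch below (get_tokenizer/tokenize usage is an assumption, not part of this commit):

# Hypothetical helper: derive the newline ids from the active tokenizer
# instead of hardcoding GPT-2 BPE values.
from megatron import get_tokenizer

def newline_token_ids():
    tokenizer = get_tokenizer()
    eol_ids = tokenizer.tokenize("\n")           # e.g. [198] for GPT-2 BPE
    double_eol_ids = tokenizer.tokenize("\n\n")  # e.g. [628]
    return eol_ids, double_eol_ids
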
...
@@ -36,9 +36,6 @@ class MegatronGenerate(Resource):
    def put(self):
        args = get_args()

        if not "prompts" in request.get_json():
            return "prompts argument required", 400
@@ -101,20 +98,60 @@ class MegatronGenerate(Resource):
            add_BOS = request.get_json()["add_BOS"]
            if not isinstance(add_BOS, bool):
                return "add_BOS must be a boolean value"
if any([len(prompt) == 0 for prompt in prompts]) and not add_BOS:
return "Empty prompts require add_BOS=true"
stop_on_double_eol = False
if "stop_on_double_eol" in request.get_json():
stop_on_double_eol = request.get_json()["stop_on_double_eol"]
if not isinstance(stop_on_double_eol, bool):
return "stop_on_double_eol must be a boolean value"
stop_on_eol = False
if "stop_on_eol" in request.get_json():
stop_on_eol = request.get_json()["stop_on_eol"]
if not isinstance(stop_on_eol, bool):
return "stop_on_eol must be a boolean value"
random_seed = -1
if "random_seed" in request.get_json():
random_seed = request.get_json()["random_seed"]
if not isinstance(random_seed, int):
return "random_seed must be integer"
if random_seed < 0:
return "random_seed must be a positive integer"
no_log = False
if "no_log" in request.get_json():
no_log = request.get_json()["no_log"]
if not isinstance(no_log, bool):
return "no_log must be a boolean value"
        with lock:  # Need to get lock to keep multiple threads from hitting code
            if not no_log:
                print("request IP: " + str(request.remote_addr))
                print(json.dumps(request.get_json()), flush=True)
                print("start time: ", datetime.datetime.now())

            MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
            try:
                response, response_seg, response_logprobs, _ = \
                    generate_and_post_process(
                        self.model,
                        prompts=prompts,
                        tokens_to_generate=tokens_to_generate,
                        return_output_log_probs=logprobs,
                        top_k_sampling=top_k,
                        top_p_sampling=top_p,
                        temperature=temperature,
                        add_BOS=add_BOS,
                        use_eod_token_for_early_termination=True,
                        stop_on_double_eol=stop_on_double_eol,
                        stop_on_eol=stop_on_eol,
                        random_seed=random_seed)
            except ValueError as ve:
                return "Length of prompt + tokens_to_generate longer than allowed"

            print("end time: ", datetime.datetime.now())
            return jsonify({"text": response,
                            "segments": response_seg,
...
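
For reference, a minimal client for the generation endpoint above (not part of this commit; host, port and the /api route are placeholders inferred from the HTML page and server code):

import json
import urllib.request

payload = json.dumps({
    "prompts": ["Hello, my name is"],
    "tokens_to_generate": 32,
    "stop_on_eol": True,    # new flag added in this commit
    "random_seed": 1234,    # new flag: seeds sampling for reproducibility
    "no_log": True,         # new flag: suppress server-side request logging
}).encode("utf-8")

req = urllib.request.Request("http://localhost:5000/api", data=payload,
                             headers={"Content-Type": "application/json"},
                             method="PUT")
with urllib.request.urlopen(req) as resp:
    print(json.loads(resp.read())["text"])
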
@@ -21,7 +21,6 @@ import sys
import time
# The earliest we can measure the start time.
_TRAIN_START_TIME = time.time()
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
@@ -43,7 +42,7 @@ from megatron.model import ModelType
from megatron.optimizer import get_megatron_optimizer
from megatron.initialize import initialize_megatron
from megatron.initialize import write_args_to_tensorboard
from megatron.optimizer_param_scheduler import OptimizerParamScheduler
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import unwrap_model
@@ -51,7 +50,7 @@ from megatron.data.data_samplers import build_pretraining_data_loader
from megatron.utils import calc_params_l2_norm
from megatron.schedules import get_forward_backward_func
from megatron.utils import report_memory
from megatron.model.vision.knn_monitor import compute_feature_bank


def print_datetime(string):
@@ -65,6 +64,7 @@ def pretrain(train_valid_test_dataset_provider,
             model_provider,
             model_type,
             forward_step_func,
             process_non_loss_data_func=None,
             extra_args_provider=None,
             args_defaults={}):
    """Main training program.
@@ -86,6 +86,10 @@ def pretrain(train_valid_test_dataset_provider,
        the info we would like to monitor during training, for example
        `lm-loss: value`. We also require that this function add
        `batch generator` to the timers class.
    process_non_loss_data_func: a function to post process outputs of the
        network. It can be used for dumping output tensors (e.g. images) to
        tensorboard. It takes `collected data` (list of tensors),
        `current iteration index` and `tensorboard writer` as arguments.
    extra_args_provider: a function that takes a parser and adds arguments
        to it. It is used for programs to add their own arguments.
    args_defaults: a dictionary from argument-name to argument-value. It
@@ -113,7 +117,7 @@ def pretrain(train_valid_test_dataset_provider,
    # Model, optimizer, and learning rate.
    timers('model-and-optimizer-setup').start()
    model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider,
                                                                      model_type)
    timers('model-and-optimizer-setup').stop()
    print_datetime('after model, optimizer, and learning rate '
@@ -144,25 +148,28 @@ def pretrain(train_valid_test_dataset_provider,
    iteration = 0
    if args.do_train and args.train_iters > 0:
        iteration = train(forward_step_func,
                          model, optimizer, opt_param_scheduler,
                          train_data_iterator, valid_data_iterator,
                          process_non_loss_data_func)
    print_datetime('after training is done')

    if args.do_valid:
        prefix = 'the end of training for val data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   valid_data_iterator, model,
                                   iteration, process_non_loss_data_func,
                                   False)

    if args.save and iteration != 0:
        save_checkpoint(iteration, model, optimizer, opt_param_scheduler)

    if args.do_test:
        # Run on test data.
        prefix = 'the end of training for test data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   test_data_iterator, model,
                                   0, process_non_loss_data_func,
                                   True)


def update_train_iters(args):
@@ -285,7 +292,10 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
                                   args.accumulate_allreduce_grads_in_fp32,
                                   args.use_contiguous_buffers_in_local_ddp)
                 for model_module in model]
        # broadcast params from data parallel src rank to other data parallel ranks
        if args.data_parallel_random_init:
            for model_module in model:
                model_module.broadcast_params()
    else:
        raise NotImplementedError('Unknown DDP implementation specified: '
                                  '{}. Exiting.'.format(args.DDP_impl))
@@ -293,7 +303,7 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
    return model


def get_optimizer_param_scheduler(optimizer):
    """Build the learning rate scheduler."""
    args = get_args()
@@ -301,11 +311,12 @@ def get_learning_rate_scheduler(optimizer):
    if args.train_iters:
        if args.lr_decay_iters is None:
            args.lr_decay_iters = args.train_iters
        lr_decay_steps = args.lr_decay_iters * args.global_batch_size
        wd_incr_steps = args.train_iters * args.global_batch_size
        if args.lr_warmup_fraction is not None:
            lr_warmup_steps = args.lr_warmup_fraction * lr_decay_steps
        else:
            lr_warmup_steps = args.lr_warmup_iters * args.global_batch_size
    # Sample-based training.
    elif args.train_samples:
        # We need to set training iters for later use. Technically
@@ -314,29 +325,38 @@ def get_learning_rate_scheduler(optimizer):
        update_train_iters(args)
        if args.lr_decay_samples is None:
            args.lr_decay_samples = args.train_samples
        lr_decay_steps = args.lr_decay_samples
        wd_incr_steps = args.train_samples
        if args.lr_warmup_fraction is not None:
            lr_warmup_steps = args.lr_warmup_fraction * lr_decay_steps
        else:
            lr_warmup_steps = args.lr_warmup_samples
    else:
        raise Exception(
            'either train-iters or train-samples should be provided.')

    opt_param_scheduler = OptimizerParamScheduler(
        optimizer,
        max_lr=args.lr,
        min_lr=args.min_lr,
        lr_warmup_steps=lr_warmup_steps,
        lr_decay_steps=lr_decay_steps,
        lr_decay_style=args.lr_decay_style,
        start_wd=args.start_weight_decay,
        end_wd=args.end_weight_decay,
        wd_incr_steps=wd_incr_steps,
        wd_incr_style=args.weight_decay_incr_style,
        use_checkpoint_opt_param_scheduler=args.use_checkpoint_opt_param_scheduler,
        override_opt_param_scheduler=args.override_opt_param_scheduler)

    return opt_param_scheduler
def setup_model_and_optimizer(model_provider_func,
model_type,
no_wd_decay_cond=None,
scale_lr_cond=None,
lr_mult=1.0):
"""Setup model and optimizer.""" """Setup model and optimizer."""
args = get_args() args = get_args()
...@@ -344,9 +364,10 @@ def setup_model_and_optimizer(model_provider_func, model_type): ...@@ -344,9 +364,10 @@ def setup_model_and_optimizer(model_provider_func, model_type):
unwrapped_model = unwrap_model(model, unwrapped_model = unwrap_model(model,
(torchDDP, LocalDDP, Float16Module)) (torchDDP, LocalDDP, Float16Module))
optimizer = get_megatron_optimizer(unwrapped_model) optimizer = get_megatron_optimizer(unwrapped_model, no_wd_decay_cond,
scale_lr_cond, lr_mult)
lr_scheduler = get_learning_rate_scheduler(optimizer) opt_param_scheduler = get_optimizer_param_scheduler(optimizer)
if args.load is not None: if args.load is not None:
timers = get_timers() timers = get_timers()
...@@ -354,7 +375,7 @@ def setup_model_and_optimizer(model_provider_func, model_type): ...@@ -354,7 +375,7 @@ def setup_model_and_optimizer(model_provider_func, model_type):
# max time. # max time.
torch.distributed.barrier() torch.distributed.barrier()
timers('load-checkpoint').start() timers('load-checkpoint').start()
args.iteration = load_checkpoint(model, optimizer, lr_scheduler) args.iteration = load_checkpoint(model, optimizer, opt_param_scheduler)
torch.distributed.barrier() torch.distributed.barrier()
timers('load-checkpoint').stop() timers('load-checkpoint').stop()
timers.log(['load-checkpoint']) timers.log(['load-checkpoint'])
...@@ -373,11 +394,11 @@ def setup_model_and_optimizer(model_provider_func, model_type): ...@@ -373,11 +394,11 @@ def setup_model_and_optimizer(model_provider_func, model_type):
if args.fp16: if args.fp16:
optimizer.reload_model_params() optimizer.reload_model_params()
return model, optimizer, lr_scheduler return model, optimizer, opt_param_scheduler
def train_step(forward_step_func, data_iterator,
               model, optimizer, opt_param_scheduler):
    """Single training step."""
    args = get_args()
    timers = get_timers()
@@ -427,19 +448,45 @@ def train_step(forward_step_func, data_iterator,
            else:
                grad = word_embeddings_weight.grad
            torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())

    # All-reduce position_embeddings grad across first (encoder) and split (decoder)
    # stages to ensure that position embeddings parameters stay in sync.
    # This should only run for T5 models with pipeline parallelism
    if mpu.is_rank_in_position_embedding_group() and \
            mpu.get_pipeline_model_parallel_world_size() > 1 and \
            args.pipeline_model_parallel_split_rank is not None:
        unwrapped_model = model[0]
        unwrapped_model = unwrap_model(
            unwrapped_model, (torchDDP, LocalDDP, Float16Module))
        assert args.DDP_impl == 'local', \
            'T5 model is only supported with local DDP mode'
        grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
        torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
    timers('backward-embedding-all-reduce').stop()

    if args.vision_pretraining and args.vision_pretraining_type == "dino":
        unwrapped_model = unwrap_model(model[0],
                                       (torchDDP, LocalDDP, Float16Module))
        unwrapped_model.cancel_gradients_last_layer(args.curr_iteration)

    # Update parameters.
    timers('optimizer').start()
    update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
    timers('optimizer').stop()

    if args.vision_pretraining and args.vision_pretraining_type == "dino":
        unwrapped_model = unwrap_model(model[0],
                                       (torchDDP, LocalDDP, Float16Module))
        unwrapped_model.update_momentum(args.curr_iteration)

    # Update learning rate.
    if update_successful:
        increment = get_num_microbatches() * \
                    args.micro_batch_size * \
                    args.data_parallel_size
        opt_param_scheduler.step(increment=increment)
        skipped_iter = 0
    else:
        skipped_iter = 1
@@ -629,20 +676,21 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
    return report_memory_flag


def save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler):
    timers = get_timers()
    # Extra barrier is added to make sure
    # all ranks report the max time.
    torch.distributed.barrier()
    timers('save-checkpoint').start()
    save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
    torch.distributed.barrier()
    timers('save-checkpoint').stop()
    timers.log(['save-checkpoint'])


def train(forward_step_func, model, optimizer, opt_param_scheduler,
          train_data_iterator, valid_data_iterator,
          process_non_loss_data_func):
    """Train the model function."""
    args = get_args()
    timers = get_timers()
@@ -665,12 +713,13 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
    report_memory_flag = True
    while iteration < args.train_iters:
        update_num_microbatches(args.consumed_train_samples)
        args.curr_iteration = iteration
        loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \
            train_step(forward_step_func,
                       train_data_iterator,
                       model,
                       optimizer,
                       opt_param_scheduler)
        iteration += 1
        args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
                                       args.micro_batch_size * \
@@ -691,7 +740,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
        if args.adlr_autoresume and \
           (iteration % args.adlr_autoresume_interval == 0):
            check_adlr_autoresume_termination(iteration, model, optimizer,
                                              opt_param_scheduler)

        # Evaluation
        if args.eval_interval and iteration % args.eval_interval == 0 and \
@@ -699,7 +748,8 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
            prefix = 'iteration {}'.format(iteration)
            evaluate_and_print_results(prefix, forward_step_func,
                                       valid_data_iterator, model,
                                       iteration, process_non_loss_data_func,
                                       False)

        # Checkpointing
        saved_checkpoint = False
@@ -707,14 +757,14 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
            signal_handler = get_signal_handler()
            if any(signal_handler.signals_received()):
                save_checkpoint_and_time(iteration, model, optimizer,
                                         opt_param_scheduler)
                print_datetime('exiting program after receiving SIGTERM.')
                sys.exit()

        if args.save and args.save_interval and \
           iteration % args.save_interval == 0:
            save_checkpoint_and_time(iteration, model, optimizer,
                                     opt_param_scheduler)
            saved_checkpoint = True

        # Exiting based on duration
@@ -728,7 +778,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
            if done:
                if not saved_checkpoint:
                    save_checkpoint_and_time(iteration, model, optimizer,
                                             opt_param_scheduler)
                print_datetime('exiting program after {} minutes'.format(train_time))
                sys.exit()

@@ -736,7 +786,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
        if args.exit_interval and iteration % args.exit_interval == 0:
            if not saved_checkpoint:
                save_checkpoint_and_time(iteration, model, optimizer,
                                         opt_param_scheduler)
            torch.distributed.barrier()
            print_datetime('exiting program at iteration {}'.format(iteration))
            sys.exit()
@@ -745,10 +795,17 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
    return iteration
def evaluate(forward_step_func,
             data_iterator,
             model,
             process_non_loss_data_func,
             verbose=False):
    """Evaluation."""
    args = get_args()

    if args.vision_pretraining and args.vision_pretraining_type == "dino":
        compute_feature_bank(model)

    # Turn on evaluation mode which disables dropout.
    for model_module in model:
        model_module.eval()
@@ -782,6 +839,12 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
            args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
                                           * args.micro_batch_size \
                                           * get_num_microbatches()

        collected_non_loss_data = None
        if process_non_loss_data_func is not None and is_last_rank():
            collected_non_loss_data = forward_backward_func(
                forward_step_func, data_iterator, model, optimizer=None,
                timers=None, forward_only=True, collect_non_loss_data=True)

    # Move model back to the train mode.
    for model_module in model:
        model_module.train()
@@ -789,16 +852,19 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
    for key in total_loss_dict:
        total_loss_dict[key] /= args.eval_iters * get_num_microbatches()

    return total_loss_dict, collected_non_loss_data


def evaluate_and_print_results(prefix, forward_step_func,
                               data_iterator, model,
                               iteration, process_non_loss_data_func,
                               verbose=False):
    """Helper function to evaluate and dump results on screen."""
    args = get_args()
    writer = get_tensorboard_writer()

    total_loss_dict, collected_non_loss_data = evaluate(
        forward_step_func, data_iterator, model,
        process_non_loss_data_func, verbose)
    string = ' validation loss at {} | '.format(prefix)
    for key in total_loss_dict:
        string += '{} value: {:.6E} | '.format(key, total_loss_dict[key].item())
@@ -817,6 +883,9 @@ def evaluate_and_print_results(prefix, forward_step_func,
                writer.add_scalar('{} validation ppl vs samples'.format(key),
                                  ppl, args.consumed_train_samples)

    if process_non_loss_data_func is not None and writer and is_last_rank():
        process_non_loss_data_func(collected_non_loss_data, iteration, writer)

    length = len(string) + 1
    print_rank_last('-' * length)
    print_rank_last(string)
@@ -895,7 +964,6 @@ def build_train_valid_test_data_iterators(
    args.do_valid = flags[1].item()
    args.do_test = flags[2].item()

    # Build iterators.
    dl_type = args.dataloader_type
    assert dl_type in ['single', 'cyclic']
...
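
The scheduler introduced above steps in units of samples rather than iterations: each successful `train_step` advances it by `get_num_microbatches() * micro_batch_size * data_parallel_size`, which is why the iteration-based branch of `get_optimizer_param_scheduler` multiplies `lr_decay_iters` and `train_iters` by `global_batch_size`. A minimal sketch of that bookkeeping, with example values chosen purely for illustration (not defaults of the repo):

```python
# Illustrative numbers only; in Megatron these come from the parsed args.
micro_batch_size = 4
data_parallel_size = 8
num_microbatches = 32

# Samples consumed per iteration -- the increment passed to
# opt_param_scheduler.step(increment=...) after a successful optimizer step.
increment = num_microbatches * micro_batch_size * data_parallel_size
global_batch_size = increment  # assumes the usual Megatron relation

# Iteration-based branch: decay/warmup/weight-decay horizons in samples.
train_iters = 500_000
lr_decay_iters = train_iters
lr_warmup_fraction = 0.01

lr_decay_steps = lr_decay_iters * global_batch_size    # 512,000,000 samples
wd_incr_steps = train_iters * global_batch_size        # 512,000,000 samples
lr_warmup_steps = lr_warmup_fraction * lr_decay_steps  # 5,120,000 samples

print(increment, lr_decay_steps, lr_warmup_steps, wd_incr_steps)
```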
@@ -126,7 +126,7 @@ def print_params_min_max_norm(optimizer, iteration):

def check_adlr_autoresume_termination(iteration, model,
                                      optimizer, opt_param_scheduler):
    """Check for autoresume signal and exit if it is received."""
    from megatron.checkpointing import save_checkpoint
@@ -136,7 +136,7 @@ def check_adlr_autoresume_termination(iteration, model,
    torch.distributed.barrier()
    if autoresume.termination_requested():
        if args.save:
            save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
        print_rank_0(">>> autoresume termination request found!")
        if torch.distributed.get_rank() == 0:
            autoresume.request_resume()
...
@@ -21,21 +21,33 @@ from functools import partial
from megatron import get_args, get_timers, mpu, print_rank_0
from megatron.data.vit_dataset import build_train_valid_datasets
from megatron.model import ModelType
from megatron.model.vision.classification import VitClassificationModel
from megatron.model.vision.classification import MitClassificationModel
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group


def model_provider(pre_process=True, post_process=True):
    """Build the model."""
    args = get_args()

    if args.vision_backbone_type == 'vit':
        print_rank_0("building VIT model ...")
        model = VitClassificationModel(num_classes=args.num_classes,
                                       pre_process=pre_process,
                                       post_process=post_process)
    elif args.vision_backbone_type == 'mit':
        print_rank_0("building MIT model ...")
        model = MitClassificationModel(num_classes=args.num_classes,
                                       pre_process=pre_process,
                                       post_process=post_process)
    else:
        raise Exception('{} vision backbone is not supported.'.format(
            args.vision_backbone_type))
    return model


def get_batch(data_iterator):
    """Build the batch."""
    data = next(data_iterator)
@@ -46,6 +58,7 @@ def get_batch(data_iterator):
    return images, labels


def loss_func(labels, output_tensor):
    logits = output_tensor.contiguous().float()
    loss = F.cross_entropy(logits, labels)
@@ -58,6 +71,7 @@ def loss_func(labels, output_tensor):
    return loss, {"loss": averaged_loss[0], "accuracy": averaged_loss[1]}


def forward_step(data_iterator, model):
    """Forward step."""
    timers = get_timers()
@@ -82,7 +96,10 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
    print_rank_0(
        "> building train, validation, and test datasets " "for VIT ..."
    )
    train_ds, valid_ds = build_train_valid_datasets(
        data_path=args.data_path,
        image_size=(args.img_h, args.img_w)
    )
    print_rank_0("> finished creating VIT datasets ...")
    return train_ds, valid_ds, None
@@ -95,5 +112,5 @@ if __name__ == "__main__":
        model_provider,
        ModelType.encoder_or_decoder,
        forward_step,
        args_defaults={'dataloader_type': 'cyclic', 'vision_pretraining': True}
    )
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import torch.distributed as dist
from functools import partial
from megatron import get_args, get_timers, mpu, print_rank_0
from megatron.data.vit_dataset import build_train_valid_datasets
from megatron.model.vision.dino import DINOPretrainModel
from megatron.model.vision.knn_monitor import knn_predict, get_feature_bank
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group, unwrap_model
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
from megatron.model import ModelType
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
return DINOPretrainModel(pre_process=pre_process, post_process=post_process)
def get_batch(data_iterator):
"""Build the batch."""
data = next(data_iterator)
# only data parallelism; no need for broadcast
if isinstance(data[0], list):
images = [aug.cuda() for aug in data[0]]
else:
images = data[0].cuda()
labels = data[1].cuda()
return images, labels
def loss_func(model, labels, output_tensor, collect_data=False):
args = get_args()
model = unwrap_model(
model,
(torchDDP, LocalDDP, Float16Module)
)
if model.training:
student_output, teacher_output = output_tensor
loss = model.dino_loss(student_output, teacher_output, args.curr_iteration)
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {"loss": averaged_loss[0]}
else:
_, teacher_feature = output_tensor
feature_bank, feature_labels, classes = get_feature_bank()
feature = F.normalize(teacher_feature.float(), dim=1)
knn_accs = []
for k in [10, 20, 100, 200]:
pred_labels = knn_predict(feature, feature_bank,
feature_labels, classes, k, 0.07)
knn_acc = (pred_labels[:, 0] == labels).float().mean()
knn_accs.append(knn_acc)
averaged_loss = average_losses_across_data_parallel_group(knn_accs)
return 0, {"knn_acc_10": averaged_loss[0],
"knn_acc_20": averaged_loss[1],
"knn_acc_100": averaged_loss[2],
"knn_acc_200": averaged_loss[3]}
def forward_step(data_iterator, model):
"""Forward step."""
timers = get_timers()
# Get the batch.
timers("batch-generator").start()
(
images,
labels,
) = get_batch(data_iterator)
timers("batch-generator").stop()
return model(images), partial(loss_func, model, labels)
def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build train, valid, and test datasets."""
args = get_args()
print_rank_0(
"> building train, validation, and test datasets " "for VIT ..."
)
train_ds, valid_ds = build_train_valid_datasets(
data_path=args.data_path,
image_size=(args.img_h, args.img_w)
)
print_rank_0("> finished creating VIT datasets ...")
return train_ds, valid_ds, None
if __name__ == "__main__":
pretrain(
train_valid_test_datasets_provider,
model_provider,
ModelType.encoder_or_decoder,
forward_step,
args_defaults={'dataloader_type': 'cyclic', 'vision_pretraining': True}
)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain VIT"""
import torch
import torch.nn.functional as F
from functools import partial
from megatron import get_args, get_timers, mpu, print_rank_0, print_rank_last
from megatron.data.vit_dataset import build_train_valid_datasets
from megatron.model.vision.inpainting import VitInpaintingModel
from megatron.model.vision.inpainting import MitInpaintingModel
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
from tasks.vision.metrics import SSIM, PSNR
from megatron.model import ModelType
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
args = get_args()
if args.vision_backbone_type == 'vit':
model = VitInpaintingModel(pre_process=pre_process,
post_process=post_process)
elif args.vision_backbone_type == 'mit':
model = MitInpaintingModel(pre_process=pre_process,
post_process=post_process)
else:
raise Exception('{} vision backbone is not supported.'.format(
args.vision_backbone_type))
return model
def get_batch(data_iterator):
"""Build the batch."""
data = next(data_iterator)
# only data parallelism; no need for broadcast
images = data[0][0].cuda()
masks = data[0][1].cuda()
return images, masks
def loss_func(images, masks, masked_images, outputs, collect_data=False):
outputs = outputs.contiguous().float()
masks_flip = 1-masks
flip_masked_outputs = outputs.masked_fill(masks_flip.bool(), 0)
flip_masked_images = images.masked_fill(masks_flip.bool(), 0)
ssim_fun = SSIM()
psnr_fun = PSNR()
if not collect_data:
mask_count = torch.count_nonzero(masks)
loss = F.mse_loss(
flip_masked_outputs,
flip_masked_images.float(),
reduction="sum"
)
loss = loss/mask_count
ssim = ssim_fun(flip_masked_outputs, flip_masked_images.float())
psnr = psnr_fun(flip_masked_outputs, flip_masked_images.float())
averaged_loss = average_losses_across_data_parallel_group(
[loss, psnr, ssim]
)
return loss, {"loss": averaged_loss[0],
"psnr": averaged_loss[1],
'ssim': averaged_loss[2]}
else:
synth_images = masked_images.float() + flip_masked_outputs
ssim = ssim_fun(synth_images, images.float())
psnr = psnr_fun(synth_images, images.float())
return torch.cat((images, masked_images, synth_images), dim=2), ssim, psnr
def forward_step(data_iterator, model):
"""Forward step."""
timers = get_timers()
# Get the batch.
timers("batch-generator").start()
(
images,
masks,
) = get_batch(data_iterator)
timers("batch-generator").stop()
masked_images = images.masked_fill(masks.bool(), 0)
outputs = model(masked_images)
# Forward mode
return outputs, partial(loss_func, images, masks, masked_images)
def process_non_loss_data(data, iteration, writer):
psnr_sum = 0
ssim_sum = 0
for (output_tb, ssim, psnr) in data:
output_tb[output_tb < 0] = 0
output_tb[output_tb > 1] = 1
writer.add_images("gt-input-output-vald", output_tb,
global_step=iteration, walltime=None,
dataformats='NCHW')
psnr_sum = psnr_sum + psnr.item()
ssim_sum = ssim_sum + ssim.item()
psnr = psnr_sum/len(data)
ssim = ssim_sum/len(data)
writer.add_scalar('PSNR generate value-validation', psnr, iteration)
writer.add_scalar('SSIM generate value-validation', ssim, iteration)
def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build train, valid, and test datasets."""
args = get_args()
print_rank_0(
"> building train, validation, and test datasets " "for VIT ..."
)
train_ds, valid_ds = build_train_valid_datasets(
data_path=args.data_path,
image_size=(args.img_h, args.img_w)
)
print_rank_0("> finished creating VIT datasets ...")
return train_ds, valid_ds, None
if __name__ == "__main__":
pretrain(
train_valid_test_datasets_provider,
model_provider,
ModelType.encoder_or_decoder,
forward_step,
process_non_loss_data,
args_defaults={'dataloader_type': 'cyclic', 'vision_pretraining': True}
)
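
`process_non_loss_data` above is an example of the `process_non_loss_data_func` hook that the updated `training.py` threads through `pretrain`, `train`, and `evaluate_and_print_results`: on the last rank, the data collected during evaluation is passed to the hook together with the current iteration and the TensorBoard writer. A minimal sketch of a custom hook under the same `(data, iteration, writer)` calling convention; the metric name and the assumed structure of `data` are made up for illustration:

```python
# Hypothetical hook with the (data, iteration, writer) signature used above.
# 'data' is whatever the forward/loss functions return when non-loss data is
# collected; here it is assumed to be a list of scalar tensors.
def my_process_non_loss_data(data, iteration, writer):
    values = [d.item() for d in data]
    mean_value = sum(values) / max(len(values), 1)
    writer.add_scalar('custom-eval-metric', mean_value, iteration)

# It would be passed to pretrain() in the same position as process_non_loss_data
# in the __main__ block above, e.g.:
# pretrain(train_valid_test_datasets_provider, model_provider,
#          ModelType.encoder_or_decoder, forward_step,
#          my_process_non_loss_data,
#          args_defaults={'dataloader_type': 'cyclic', 'vision_pretraining': True})
```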
@@ -154,7 +154,7 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset,
    return train_dataloader, valid_dataloader


def _train(model, optimizer, opt_param_scheduler, forward_step,
           train_dataloader, valid_dataloader, end_of_epoch_callback):
    """Train the model."""
    args = get_args()
@@ -195,7 +195,7 @@ def _train(model, optimizer, lr_scheduler, forward_step,
            start_iteration = 0

            # Train for one step.
            out = train_step(forward_step, batch, model, optimizer, opt_param_scheduler)
            losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = out
            iteration += 1
@@ -215,13 +215,13 @@ def _train(model, optimizer, lr_scheduler, forward_step,
            if args.adlr_autoresume and \
               (iteration % args.adlr_autoresume_interval == 0):
                check_adlr_autoresume_termination(iteration, model,
                                                  optimizer, opt_param_scheduler)

            # Checkpointing
            saved_checkpoint = False
            if args.save and args.save_interval and \
               iteration % args.save_interval == 0:
                save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
                saved_checkpoint = True

            # Evaluation
@@ -234,14 +234,14 @@ def _train(model, optimizer, lr_scheduler, forward_step,
            # Exiting based on iterations
            if args.exit_interval and iteration % args.exit_interval == 0:
                if not saved_checkpoint:
                    save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
                torch.distributed.barrier()
                print_rank_0('exiting program at iteration {}'.format(iteration))
                sys.exit()

        # Checkpointing at the end of each epoch.
        if args.save:
            save_checkpoint(iteration, model, optimizer, opt_param_scheduler)

        # Callback at the end of each epoch.
        if end_of_epoch_callback is not None:
@@ -279,7 +279,7 @@ def finetune(train_valid_datasets_provider, model_provider,
    # Build model, optimizer and learning rate scheduler.
    timers('model and optimizer').start()
    model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider, model_type)
    timers('model and optimizer').stop()

    # If pretrained checkpoint is provided and we have not trained for
@@ -307,7 +307,7 @@ def finetune(train_valid_datasets_provider, model_provider,
    # Finetune the model.
    if args.epochs > 0:
        _train(model, optimizer, opt_param_scheduler, forward_step,
               train_dataloader, valid_dataloader, end_of_epoch_callback)
    # Or just evaluate.
    else:
...
# Multi-Stage Prompting for Knowledgeable Dialogue Generation
Below we present the steps to run our multi-stage dialogue prompting (MSDP) framework.
## Multi-Stage Dialogue Prompting
### Data Preparation
1. Dataset Download: [Wizard of Wikipedia](https://parl.ai/projects/wizard_of_wikipedia/) and [Wizard of Internet](https://parl.ai/projects/sea/)
2. Data Processing: We provide the script to run the [`data processing`](../../examples/msdp/data_processing.sh) of the datasets (the processed file format is sketched at the end of this README).
### Stage-1: Prompting for Knowledge Generation
1. We provide the script to perform the [`first-stage prompting`](../../examples/msdp/prompt_knwl_gen.sh) for the knowledge generation.
2. We provide the [`evaluation script`](../../examples/msdp/eval_knwl_generation.sh) for the automatic evaluation (i.e., F1, BLEU, METEOR, and ROUGE-L) of the knowledge generation.
### Stage-2: Prompting for Response Generation
1. We provide the script to [`prepare the input file`](../../examples/msdp/prep_resp_gen.sh) for the response generation (based on the previously generated knowledge file).
2. We provide the script to perform the [`second-stage prompting`](../../examples/msdp/prompt_resp_gen.sh) for the response generation.
3. We provide the [`evaluation script`](../../examples/msdp/eval_resp_generation.sh) for the automatic evaluation (i.e., F1, KF1, BLEU, METEOR, and ROUGE-L) of the response generation.
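
Both preprocessing functions in `tasks/msdp/preprocessing.py` (included later in this commit) write one tab-separated example per line, in the form `topic \t dialogue context \t golden knowledge \t golden response`, with dialogue turns joined by `[SEP]`. A minimal sketch of reading such a file back; the path is a placeholder:

```python
# Sketch only: iterate over a processed Wizard-of-Wikipedia/Internet file.
# 'processed_file' is a placeholder path, not something shipped with the repo.
processed_file = "wow_test_processed.txt"

with open(processed_file, "r") as f:
    for line in f:
        topic, dialog_context, knowledge, response = line.strip().split("\t")
        turns = dialog_context.split(" [SEP] ")  # individual dialogue turns
        # ... build the stage-1 / stage-2 prompts from (topic, turns, knowledge) ...
```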
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model evaluation"""
from megatron import get_args
from megatron import print_rank_0
from tasks.msdp.metrics import F1Metric
from tqdm import tqdm
def evaluate_f1(guess_file, answer_file):
"""Evaluating F1 Score"""
guess_list = []
print_rank_0('reading %s' % guess_file)
with open(guess_file, "r") as f:
for i, line in enumerate(tqdm(f)):
line = line.strip()
if "<|endoftext|>" in line:
line = line.replace("<|endoftext|>", "")
guess_list.append(line)
answer_list = []
print_rank_0('reading %s' % answer_file)
with open(answer_file, "r") as f:
for i, line in enumerate(tqdm(f)):
line = line.strip()
if line == "no_passages_used":
line = ""
answer_list.append(line)
assert len(guess_list) == len(answer_list), \
"lengths of guess and answer are different!"
precision, recall, f1 = F1Metric.compute_all_pairs(guess_list, answer_list)
print_rank_0('Precision: %.4f; recall: %.4f; f1: %.4f' % (precision, recall, f1))
print_rank_0('done :-)')
def main():
args = get_args()
evaluate_f1(args.guess_file, args.answer_file)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run multi-stage dialogue prompting (MSDP)."""
import os
import sys
sys.path.append(os.path.abspath(os.path.join(
os.path.join(os.path.dirname(__file__), os.path.pardir), os.path.pardir)))
from megatron import get_args
from megatron.initialize import initialize_megatron
def get_tasks_args(parser):
"""Provide extra arguments required for tasks."""
group = parser.add_argument_group(title='tasks')
# parameters for the knowledgeable dialogue generation
group.add_argument('--task', type=str, required=True,
help='Task name.')
group.add_argument("--sample-input-file", type=str, default=None,
help='Get input from file instead of interactive mode, '
'each line is an input.')
group.add_argument("--sample-output-file", type=str, default=None,
help='Output file got from --sample-input-file')
group.add_argument('--prompt-file', type=str, default=None,
help='prompting file')
group.add_argument('--prompt-type', type=str, default=None,
choices=['knowledge', 'response'],
help='prompt type (knowledge or response)')
group.add_argument('--num-prompt-examples', type=int, default=10,
help='number of prompt examples')
group.add_argument('--guess-file', type=str, default=None,
help='datapath for generated sentences')
group.add_argument('--answer-file', type=str, default=None,
help='datapath for golden sentences')
group.add_argument('--out-seq-length', type=int, default=100,
help='output sequence length')
group.add_argument('--api-prompt', default=False, action="store_true",
help='setup model api for prompting')
group.add_argument('--megatron-api-url', type=str, default=None,
help='url of the megatron api')
return parser
if __name__ == '__main__':
initialize_megatron(extra_args_provider=get_tasks_args)
args = get_args()
if args.num_layers_per_virtual_pipeline_stage is not None:
print("Interleaved pipeline schedule is not yet supported for downstream tasks.")
exit()
if args.task == 'MSDP-PROMPT':
from tasks.msdp.prompt import main
elif args.task == 'MSDP-EVAL-F1':
from tasks.msdp.evaluate import main
else:
raise NotImplementedError('Task {} is not implemented.'.format(
args.task))
main()
# The following code is adapted from
# https://github.com/facebookresearch/ParlAI/blob/master/parlai/core/metrics.py,
# which is licensed under the MIT license. More details on the license can be
# found at https://github.com/facebookresearch/ParlAI/blob/master/LICENSE.
"""Provides standard metric evaluations for dialog."""
from collections import Counter
from typing import List
import numpy as np
import re
re_art = re.compile(r'\b(a|an|the)\b')
re_punc = re.compile(r'[!"#$%&()*+,-./:;<=>?@\[\]\\^`{|}~_\']')
def normalize_answer(s):
"""
Lower text and remove punctuation, articles and extra whitespace.
"""
s = s.lower()
s = re_punc.sub(' ', s)
s = re_art.sub(' ', s)
s = ' '.join(s.split())
return s
class F1Metric:
"""
Helper class which computes token-level F1.
"""
@staticmethod
def _prec_recall_f1_score(pred_items, gold_items):
"""
Compute precision, recall and f1 given a set of gold and prediction items.
:param pred_items: iterable of predicted values
:param gold_items: iterable of gold values
:return: tuple (p, r, f1) for precision, recall, f1
"""
common = Counter(gold_items) & Counter(pred_items)
num_same = sum(common.values())
if num_same == 0:
return 0, 0, 0
precision = 1.0 * num_same / len(pred_items)
recall = 1.0 * num_same / len(gold_items)
f1 = (2 * precision * recall) / (precision + recall)
return precision, recall, f1
@staticmethod
def compute_each_pair(guess: str, answer: str):
if answer == "":
return None, None, None
if guess == "":
return 0, 0, 0
g_tokens = normalize_answer(guess).split()
a_tokens = normalize_answer(answer).split()
precision, recall, f1 = F1Metric._prec_recall_f1_score(g_tokens, a_tokens)
return precision, recall, f1
@staticmethod
def compute_all_pairs(guesses: List[str], answers: List[str]):
# guesses and answers must be aligned one-to-one
assert len(guesses) == len(answers)
precision_list, recall_list, f1_list = [], [], []
for guess, answer in zip(guesses, answers):
precision, recall, f1 = F1Metric.compute_each_pair(guess, answer)
if precision is None or recall is None or f1 is None:
continue
precision_list.append(precision)
recall_list.append(recall)
f1_list.append(f1)
return np.mean(precision_list), np.mean(recall_list), np.mean(f1_list)
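
A small usage sketch of the metric above; the strings are toy examples, and the expected numbers follow from the token overlap left after `normalize_answer` strips articles and punctuation:

```python
# Toy example: 3 overlapping tokens out of 5 guess tokens and 4 answer tokens,
# so precision = 0.6, recall = 0.75, f1 ~= 0.667.
guesses = ["The capital of France is Paris"]
answers = ["paris is in france"]

precision, recall, f1 = F1Metric.compute_all_pairs(guesses, answers)
print('Precision: %.4f; recall: %.4f; f1: %.4f' % (precision, recall, f1))
```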
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Preprocessing for Wizard of Wikipedia and Wizard of Internet datasets"""
import torch
import argparse
from nltk import word_tokenize
from tqdm import tqdm
import numpy as np
import json
def get_args():
parser = argparse.ArgumentParser(description="Preprocessing")
parser.add_argument("--func", type=str, default=None,
help="choose to run which function")
parser.add_argument("--raw_file", type=str, default=None,
help="path of the input file")
parser.add_argument("--processed_file", type=str, default=None,
help="path of the output file")
parser.add_argument("--knwl_ref_file", type=str, default=None,
help="path of the knowledge reference file")
parser.add_argument("--resp_ref_file", type=str, default=None,
help="path of the knowledge reference file")
parser.add_argument("--knwl_gen_file", type=str, default=None,
help="path of the generated knowledge file")
parser.add_argument("--test_file", type=str, default=None,
help="path of the test file")
parser.add_argument("--train_file", type=str, default=None,
help="path of the train file")
parser.add_argument("--model_file", type=str, default=None,
help="path of the model file")
parser.add_argument("--data_type", type=str, default=None,
help="data types, choose one out of three types: \
wow_seen, wow_unseen, and woi")
parser.add_argument("--seed", type=int, default=1234,
help="random seed")
args = parser.parse_args()
return args
def process_wow_dataset(raw_file, processed_file, knwl_ref_file, resp_ref_file):
"""
This is a function used for processing the wizard of wikipedia (wow) dataset
Expected processed format:
topic \t dialogue context \t golden knowledge \t golden response
"""
# loading the raw data
print("> Loading data from %s" % raw_file)
with open(raw_file, "r") as fr:
dialog_data = json.load(fr)
print("> Processing data ...")
fproc = open(processed_file, "w")
fknwl = open(knwl_ref_file, "w") if knwl_ref_file else None
fresp = open(resp_ref_file, "w") if resp_ref_file else None
for i, sample in enumerate(tqdm(dialog_data)):
# get all the dialog data for a single dialog sample
dialog = sample["dialog"]
turn_list = [] # collect the dialog history
# processing for each single dialog sample
for j, turn in enumerate(dialog):
# text of each turn
text = turn["text"]
if not (text.endswith("?") or text.endswith(".") or text.endswith("!")):
text = text + "."
if j == 0:
# first turn
turn_list.append(text)
continue
speaker = turn["speaker"].lower()
if "wizard" in speaker:
checked_sentence = list(turn["checked_sentence"].values()) # knowledge
checked_passage = list(turn["checked_passage"].values()) # topic
assert len(checked_sentence) <= 1
# get the ground truth knowledge
if len(checked_sentence) > 0:
checked_sentence = checked_sentence[0]
else:
checked_sentence = "no_passages_used"
if len(checked_passage) == 1:
checked_passage = checked_passage[0]
else:
checked_passage = "no_passages_used"
# get the topic
if checked_passage != "no_passages_used":
topic = checked_passage
else:
topic = sample["chosen_topic"]
dialog_context = " [SEP] ".join(turn_list)
knowledge = checked_sentence
response = text
# add the response into the dialog history
turn_list.append(response)
# write to the output files
fproc.write(topic + "\t" + dialog_context + "\t" + \
knowledge + "\t" + response + "\n")
if fknwl:
fknwl.write(knowledge + "\n")
if fresp:
# tokenize for evaluation
response = " ".join(word_tokenize(response))
fresp.write(response + "\n")
else:
assert "apprentice" in speaker
turn_list.append(text)
fproc.close()
if fknwl:
fknwl.close()
if fresp:
fresp.close()
def process_woi_dataset(raw_file, processed_file, knwl_ref_file, resp_ref_file):
"""
This is a function used for processing the wizard of internet (woi) dataset
Expected processed format:
topic \t dialogue context \t golden knowledge \t golden response
"""
print("> Processing %s" % raw_file)
fproc = open(processed_file, "w")
fknwl = open(knwl_ref_file, "w") if knwl_ref_file else None
fresp = open(resp_ref_file, "w") if resp_ref_file else None
with open(raw_file, "r") as fr:
for i, line in tqdm(enumerate(fr)):
# read line by line, each line uses json format
line = line.strip()
item_dict = json.loads(line)
# item_dict is a dictionary
# its key is the data id, and its value contains all the data content
item_dict = item_dict.values()
item_dict = list(item_dict)[0] # len(item_dict) == 1
# get the whole dialog data for a single dialog sample
dialog_data = item_dict['dialog_history']
length = len(dialog_data)
turn_list = [] # collect the dialog history
search_text = ""
for i in range(length):
item = dialog_data[i]
action = item['action']
if action == "Wizard => SearchAgent":
search_text = item['text']
elif action == "Wizard => Apprentice":
if len(turn_list) == 0:
# first turn
turn = item['text']
turn_list.append(turn)
continue
# get the relevant content
contents = item["context"]["contents"]
selects = item["context"]["selected_contents"]
flag = selects[0][0]
selects = selects[1:]
assert len(selects) == len(contents)
# get the topic
if flag:
# no knowledge sentence is used for the response
topic = "no_topic"
knwl_sent = "no_passages_used"
else:
# we consider the search text as the topic
topic = search_text
# get the knowledge sentence
knwl_sent = ""
for content, select in zip(contents, selects):
content = content['content']
assert len(content) == len(select)
for c, s in zip(content, select):
if s:
knwl_sent = c
break
if knwl_sent == "":
# no knowledge is used for the response
topic = "no_topic"
knwl_sent = "no_passages_used"
# get dialogue context, knowledge, and response
dialog_context = " [SEP] ".join(turn_list)
response = item['text']
# processing
topic = topic.replace("\n", "").replace("\r", \
"").replace("\t", "")
dialog_context = dialog_context.replace("\n", "").replace("\r", \
"").replace("\t", "")
knwl_sent = knwl_sent.replace("\n", "").replace("\r", \
"").replace("\t", "")
response = response.replace("\n", "").replace("\r", \
"").replace("\t", "")
if topic != "no_topic":
# write to the output files
fproc.write(topic + "\t" + dialog_context + "\t" + \
knwl_sent + "\t" + response + "\n")
if fknwl:
fknwl.write(knwl_sent + "\n")
if fresp:
# tokenize for evaluation
response = " ".join(word_tokenize(response))
fresp.write(response + "\n")
turn_list.append(response)
elif action == "Apprentice => Wizard":
turn = item['text']
turn_list.append(turn)
else:
assert action == "SearchAgent => Wizard", \
"Please check whether you have used the correct data!"
fproc.close()
if fknwl:
fknwl.close()
if fresp:
fresp.close()
def get_database(test_datapath, train_datapath, data_type):
"""Get the database by topics"""
assert data_type in ["wow_seen", "wow_unseen", "woi"], \
"Please input a correct data type!!"
# get test data topic dictionary
print("> reading test data from %s" % test_datapath)
test_topics = {}
with open(test_datapath, "r") as f:
for i, line in enumerate(f):
line = line.strip()
splits = line.split("\t")
topic = splits[0]
test_topics[topic] = True
print("> reading data from %s" % train_datapath)
train_data_by_topic = {}
dialog_data_by_topic = {}
dialog_examples = []
with open(train_datapath, "r") as f:
for i, line in enumerate(f):
line = line.strip()
splits = line.split("\t")
topic = splits[0]
turns = splits[1].split(" [SEP] ")[-3:]
knowledge = splits[2]
response = splits[3]
# filtering data samples
if knowledge == "no_passages_used":
# when no knowledge is used
continue
if data_type != "wow_seen" and ("(" in knowledge or ")" in knowledge):
# when bracket exists in the knowledge
continue
if data_type != "wow_seen" and topic not in knowledge:
# when topic does not exist in the knowledge
continue
# get the instance
last_turn = turns[-1]
instance = "( " + last_turn + " ) " + topic + " => " + knowledge
# construct dialog example
dialog_example = ""
if data_type != "wow_seen":
dialog_example += "( " + topic + " ) "
for i, turn in enumerate(turns):
if i != 0:
dialog_example += " "
dialog_example += turn
# check overlaps
if topic in test_topics:
if topic not in train_data_by_topic:
train_data_by_topic[topic] = [instance]
else:
train_data_by_topic[topic].append(instance)
if topic not in dialog_data_by_topic:
dialog_data_by_topic[topic] = [dialog_example]
else:
dialog_data_by_topic[topic].append(dialog_example)
else:
# filtering data samples
if len(knowledge.split()) > 20:
# knowledge is too long
continue
if knowledge.startswith("It") or knowledge.startswith("it") or \
knowledge.startswith("This") or knowledge.startswith("this"):
continue
# append all the data into dialogue examples list
dialog_examples.append((topic, dialog_example, instance))
return train_data_by_topic, dialog_data_by_topic, dialog_examples
emb_dict = {}
def select_prompts_based_on_similarity(
query, dialog_list, prompt_list, topic, tokenizer, encoder, topk):
"""Select samples based on the similarity"""
with torch.no_grad():
# get the query embeddings
query_ids = tokenizer.encode(query)
query_ids = torch.LongTensor([query_ids]).cuda()
query_emb = encoder(input_ids=query_ids).pooler_output
query_emb = query_emb[0]
# calculate embeddings for the samples in the database
if topic in emb_dict:
example_embeddings = emb_dict[topic]
example_embeddings = example_embeddings.cuda()
else:
for idx, example in enumerate(dialog_list):
example_ids = tokenizer.encode(example)
example_ids = torch.LongTensor([example_ids]).cuda()
example_emb = encoder(input_ids=example_ids).pooler_output
if idx == 0:
example_embeddings = example_emb
else:
example_embeddings = torch.cat(
(example_embeddings, example_emb), dim=0)
emb_dict[topic] = example_embeddings.cpu()
# compare the similarity and select the topk samples
similarity_list = example_embeddings.matmul(query_emb)
_, indices = torch.topk(similarity_list, k=topk)
indices = indices.tolist()
indices = indices[::-1] # reverse the order
selected_prompts = []
for index in indices:
# index = index.item()
selected_prompts.append(prompt_list[index])
return selected_prompts
def prompt_selection_for_knowledge_generation(
test_datapath, train_datapath, model_path, output_prompt_path, data_type):
"""Selecting prompts for the knowledge generation"""
print("> Selecting prompts for the knowledge generation")
train_data_by_topic, dialog_data_by_topic, dialog_examples = \
get_database(test_datapath, train_datapath, data_type)
from transformers import DPRQuestionEncoderTokenizer
print("> loading tokenizer and encoder")
tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
'facebook/dpr-question_encoder-single-nq-base')
encoder = torch.load(model_path).cuda()
print("> getting dialog embeddings")
with torch.no_grad():
for idx, example in tqdm(enumerate(dialog_examples)):
dialog = example[1]
dialog_ids = tokenizer.encode(dialog)
dialog_ids = torch.LongTensor([dialog_ids]).cuda()
dialog_emb = encoder(input_ids=dialog_ids).pooler_output
if idx == 0:
dialog_embeddings = dialog_emb
else:
dialog_embeddings = torch.cat((dialog_embeddings, dialog_emb), dim=0)
print("> reading test data from %s" % test_datapath)
prompt_list_for_each_sample = []
with open(test_datapath, "r") as f:
for i, line in tqdm(enumerate(f)):
line = line.strip()
splits = line.split("\t")
topic = splits[0]
turns = splits[1].split(" [SEP] ")[-3:]
# get the query sentence
query_sent = ""
if data_type != "seen":
query_sent += "( " + topic + " ) "
for i, turn in enumerate(turns):
if i != 0:
query_sent += " "
query_sent += turn
if topic not in train_data_by_topic:
# get the query embedding
query_ids = tokenizer.encode(query_sent)
query_ids = torch.LongTensor([query_ids]).cuda()
query_emb = encoder(input_ids=query_ids).pooler_output
query_emb = query_emb[0]
# calculate the similarity
similarity_list = dialog_embeddings.matmul(query_emb)
_, indices = torch.sort(similarity_list)
indices = indices.tolist()
selected_topics = {}
selected_prompts = []
num_prompt = 0
for index in indices:
example = dialog_examples[index]
topic_temp = example[0]
if topic_temp not in selected_topics:
selected_topics[topic_temp] = True
selected_prompts.append(example[2])
num_prompt += 1
if num_prompt == 10:
break
# get the selected samples
example_list = selected_prompts[::-1]
key = topic + " " + turns[-1]
prompt_list_for_each_sample.append({key: example_list})
else:
num_data_sample = min(len(train_data_by_topic[topic]), 10)
total_example_list = train_data_by_topic[topic]
dialog_list = dialog_data_by_topic[topic]
assert len(dialog_list) == len(train_data_by_topic[topic])
# calculate the similarity
example_list = select_prompts_based_on_similarity(
query_sent, dialog_list, total_example_list,
topic, tokenizer, encoder, topk=num_data_sample)
key = topic + " " + turns[-1]
prompt_list_for_each_sample.append({key: example_list})
print("writing to %s" % output_prompt_path)
with open(output_prompt_path, "w") as f:
for instance in tqdm(prompt_list_for_each_sample):
json.dump(instance, f)
f.write("\n")
def prompt_selection_for_response_generation(input_path, output_path, seed):
"""Selecting prompts for the response generation"""
print("> Selecting prompts for the response generation")
print("> set random seed")
np.random.seed(seed)
prompt_example_list = []
print("> reading data from %s" % input_path)
with open(input_path, "r") as f:
for i, line in tqdm(enumerate(f)):
line = line.strip()
splits = line.split("\t")
# get the topic, context, knowledge and response
topic = splits[0]
dialog_context = splits[1]
knowledge = splits[2]
response = splits[3]
turns = dialog_context.split(" [SEP] ")[-3:]
if knowledge == "no_passages_used":
continue
# calculate the overlap ratio
from nltk import word_tokenize
knowledge_sent_token_list = word_tokenize(knowledge)
knowledge_sent_token_dict = {token: True for token in knowledge_sent_token_list}
knowledge_len = len(knowledge_sent_token_list)
response_token_list = word_tokenize(response)
response_len = len(response_token_list)
num_overlap_token = 0
accumulator = 0
for token in response_token_list:
if token in knowledge_sent_token_dict:
accumulator += 1
else:
if accumulator >= 10:
num_overlap_token += accumulator
accumulator = 0
if accumulator >= 10:
num_overlap_token += accumulator
# filtering the data based on the ratio
if num_overlap_token > response_len * 0.9 or num_overlap_token < response_len * 0.6:
continue
if num_overlap_token < knowledge_len * 0.8:
continue
last_turn = " ".join(word_tokenize(turns[-1]))
knowledge = " ".join(word_tokenize(knowledge))
response = " ".join(word_tokenize(response))
prompt_example = ""
# add dialog context
prompt_example += "Topic: " + topic + ". "
prompt_example += "User says: " + last_turn + " "
prompt_example += "We know that: " + knowledge + " "
prompt_example += "System replies: " + response
prompt_example_list.append(prompt_example)
# shuffle the prompt examples
np.random.shuffle(prompt_example_list)
print("> writing to %s" % output_path)
with open(output_path, "w") as f:
# f.write("Generate the System's response based on the knowledge sentence:\n")
for i in tqdm(range(20)):
example = prompt_example_list[i]
f.write(example + "\n")
def prepare_input_for_response_generation(test_file, knwl_gen_file, processed_file):
"""Preparing inputs for the response generation"""
print("> Reading knowledge file from %s" % knwl_gen_file)
# get the knowledge list
with open(knwl_gen_file, "r") as f:
knowledge_list = f.readlines()
print("> Processing ...")
with open(test_file, "r") as fr:
with open(processed_file, "w") as fw:
for line_num, line in enumerate(tqdm(fr)):
line = line.strip()
splits = line.split("\t")
# prepare topic, context, knowledge and response
topic = splits[0]
dialog_context = splits[1]
response = splits[3]
knowledge = knowledge_list[line_num]
knowledge = knowledge.strip()
if "<|endoftext|>" in knowledge:
knowledge = knowledge.replace("<|endoftext|>", "")
# write to the output file
fw.write(topic + "\t" + dialog_context + "\t" \
+ knowledge + "\t" + response + "\n")
if __name__ == "__main__":
args = get_args()
if args.func == "process_wow_dataset":
process_wow_dataset(args.raw_file, args.processed_file, args.knwl_ref_file, args.resp_ref_file)
elif args.func == "process_woi_dataset":
process_woi_dataset(args.raw_file, args.processed_file, args.knwl_ref_file, args.resp_ref_file)
elif args.func == "get_knwl_gen_prompts":
prompt_selection_for_knowledge_generation(
args.test_file, args.train_file, args.model_file,
args.processed_file, args.data_type)
elif args.func == "get_resp_gen_prompts":
prompt_selection_for_response_generation(
args.train_file, args.processed_file, args.seed)
elif args.func == "prepare_input":
prepare_input_for_response_generation(
args.test_file, args.knwl_gen_file, args.processed_file)
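
For completeness, a sketch of calling the WoW preprocessing directly from Python instead of through the `examples/msdp/data_processing.sh` wrapper; every path below is a placeholder:

```python
# Placeholder paths; substitute the downloaded Wizard-of-Wikipedia JSON and
# the desired output locations.
process_wow_dataset(
    raw_file="data/wizard_of_wikipedia/test.json",
    processed_file="data/wow_test_processed.txt",
    knwl_ref_file="data/wow_test_knowledge_reference.txt",
    resp_ref_file="data/wow_test_response_reference.txt",
)
```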