"vscode:/vscode.git/clone" did not exist on "5282a4735f5b0bc99303bc57838fa926aa437066"
Commit aed2f75e authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'main' into github-main

parents 8aa4619f f32a638d
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import reduce
import operator
import torch
from megatron import get_args
from megatron import mpu
def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
use_ring_exchange=False):
"""Communicate tensors between stages. Used as helper method in other
communication methods that are used in megatron/schedules.py.
Takes the following arguments:
tensor_send_next: tensor to send to next rank (no tensor sent if
set to None).
tensor_send_prev: tensor to send to prev rank (no tensor sent if
set to None).
recv_prev: boolean for whether tensor should be received from
previous rank.
recv_next: boolean for whether tensor should be received from
next rank.
use_ring_exchange: boolean for whether torch.distributed.ring_exchange()
API should be used.
Returns:
(tensor_recv_prev, tensor_recv_next)
"""
args = get_args()
# Create placeholder tensors for receive in forward and backward directions
# if needed.
tensor_recv_prev = None
tensor_recv_next = None
tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
if args.scatter_gather_tensors_in_pipeline:
tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1) // \
mpu.get_tensor_model_parallel_world_size()
else:
tensor_chunk_shape = tensor_shape
dtype = args.params_dtype
if args.fp32_residual_connection:
dtype = torch.float
if recv_prev:
tensor_recv_prev = torch.empty(tensor_chunk_shape,
requires_grad=True,
device=torch.cuda.current_device(),
dtype=dtype)
if recv_next:
tensor_recv_next = torch.empty(tensor_chunk_shape,
requires_grad=True,
device=torch.cuda.current_device(),
dtype=dtype)
# Split tensor into smaller chunks if using scatter-gather optimization.
if args.scatter_gather_tensors_in_pipeline:
if tensor_send_next is not None:
tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)
if tensor_send_prev is not None:
tensor_send_prev = mpu.split_tensor_into_1d_equal_chunks(tensor_send_prev)
# Send tensors in both the forward and backward directions as appropriate.
if use_ring_exchange:
torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
tensor_recv_prev=tensor_recv_prev,
tensor_send_next=tensor_send_next,
tensor_recv_next=tensor_recv_next,
group=mpu.get_pipeline_model_parallel_group())
else:
ops = []
if tensor_send_prev is not None:
send_prev_op = torch.distributed.P2POp(
torch.distributed.isend, tensor_send_prev,
mpu.get_pipeline_model_parallel_prev_rank())
ops.append(send_prev_op)
if tensor_recv_prev is not None:
recv_prev_op = torch.distributed.P2POp(
torch.distributed.irecv, tensor_recv_prev,
mpu.get_pipeline_model_parallel_prev_rank())
ops.append(recv_prev_op)
if tensor_send_next is not None:
send_next_op = torch.distributed.P2POp(
torch.distributed.isend, tensor_send_next,
mpu.get_pipeline_model_parallel_next_rank())
ops.append(send_next_op)
if tensor_recv_next is not None:
recv_next_op = torch.distributed.P2POp(
torch.distributed.irecv, tensor_recv_next,
mpu.get_pipeline_model_parallel_next_rank())
ops.append(recv_next_op)
if len(ops) > 0:
reqs = torch.distributed.batch_isend_irecv(ops)
for req in reqs:
req.wait()
# To protect against race condition when using batch_isend_irecv().
torch.cuda.synchronize()
# If using scatter-gather optimization, gather smaller chunks.
if args.scatter_gather_tensors_in_pipeline:
if recv_prev:
tensor_recv_prev = mpu.gather_split_1d_tensor(
tensor_recv_prev).view(tensor_shape).requires_grad_()
if recv_next:
tensor_recv_next = mpu.gather_split_1d_tensor(
tensor_recv_next).view(tensor_shape).requires_grad_()
return tensor_recv_prev, tensor_recv_next
def recv_forward(timers=None):
"""Receive tensor from previous rank in pipeline (forward receive)."""
if mpu.is_pipeline_first_stage():
input_tensor = None
else:
if timers is not None:
timers('forward-recv').start()
input_tensor, _ = _communicate(
tensor_send_next=None,
tensor_send_prev=None,
recv_prev=True,
recv_next=False)
if timers is not None:
timers('forward-recv').stop()
return input_tensor
def recv_backward(timers=None):
"""Receive tensor from next rank in pipeline (backward receive)."""
if mpu.is_pipeline_last_stage():
output_tensor_grad = None
else:
if timers is not None:
timers('backward-recv').start()
_, output_tensor_grad = _communicate(
tensor_send_next=None,
tensor_send_prev=None,
recv_prev=False,
recv_next=True)
if timers is not None:
timers('backward-recv').stop()
return output_tensor_grad
def send_forward(output_tensor, timers=None):
"""Send tensor to next rank in pipeline (forward send)."""
if not mpu.is_pipeline_last_stage():
if timers is not None:
timers('forward-send').start()
_communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=False,
recv_next=False)
if timers is not None:
timers('forward-send').stop()
def send_backward(input_tensor_grad, timers=None):
"""Send tensor to previous rank in pipeline (backward send)."""
if not mpu.is_pipeline_first_stage():
if timers is not None:
timers('backward-send').start()
_communicate(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=False,
recv_next=False)
if timers is not None:
timers('backward-send').stop()
def send_forward_recv_backward(output_tensor, timers=None):
"""Batched send and recv with next rank in pipeline."""
if mpu.is_pipeline_last_stage():
output_tensor_grad = None
else:
if timers is not None:
timers('forward-send-backward-recv').start()
_, output_tensor_grad = _communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=False,
recv_next=True)
if timers is not None:
timers('forward-send-backward-recv').stop()
return output_tensor_grad
def send_backward_recv_forward(input_tensor_grad, timers=None):
"""Batched send and recv with previous rank in pipeline."""
if mpu.is_pipeline_first_stage():
input_tensor = None
else:
if timers is not None:
timers('backward-send-forward-recv').start()
input_tensor, _ = _communicate(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=True,
recv_next=False)
if timers is not None:
timers('backward-send-forward-recv').stop()
return input_tensor
def send_forward_recv_forward(output_tensor, recv_prev, timers=None):
"""Batched recv from previous rank and send to next rank in pipeline."""
if timers is not None:
timers('forward-send-forward-recv').start()
input_tensor, _ = _communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=recv_prev,
recv_next=False)
if timers is not None:
timers('forward-send-forward-recv').stop()
return input_tensor
def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None):
"""Batched recv from next rank and send to previous rank in pipeline."""
if timers is not None:
timers('backward-send-backward-recv').start()
_, output_tensor_grad = _communicate(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=False,
recv_next=recv_next)
if timers is not None:
timers('backward-send-backward-recv').stop()
return output_tensor_grad
def send_forward_backward_recv_forward_backward(
output_tensor, input_tensor_grad, recv_prev,
recv_next, timers=None):
"""Batched send and recv with previous and next ranks in pipeline."""
if timers is not None:
timers('forward-backward-send-forward-backward-recv').start()
input_tensor, output_tensor_grad = _communicate(
tensor_send_next=output_tensor,
tensor_send_prev=input_tensor_grad,
recv_prev=recv_prev,
recv_next=recv_next)
if timers is not None:
timers('forward-backward-send-forward-backward-recv').stop()
return input_tensor, output_tensor_grad
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron import get_args
from megatron import get_num_microbatches
from megatron import get_timers
from megatron import mpu
from megatron import p2p_communication
from megatron.utils import unwrap_model
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
def get_forward_backward_func():
args = get_args()
if mpu.get_pipeline_model_parallel_world_size() > 1:
if args.virtual_pipeline_model_parallel_size is not None:
forward_backward_func = forward_backward_pipelining_with_interleaving
else:
forward_backward_func = forward_backward_pipelining_without_interleaving
else:
forward_backward_func = forward_backward_no_pipelining
return forward_backward_func
def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
"""Forward step for passed-in model.
If first stage, input tensor is obtained from data_iterator, otherwise
passed-in input_tensor is used.
Returns output tensor."""
timers = get_timers()
timers('forward-compute').start()
unwrapped_model = unwrap_model(
model, (torchDDP, LocalDDP, Float16Module))
unwrapped_model.set_input_tensor(input_tensor)
output_tensor, loss_func = forward_step_func(data_iterator, model)
if mpu.is_pipeline_last_stage():
output_tensor = loss_func(output_tensor)
loss, loss_reduced = output_tensor
output_tensor = loss / get_num_microbatches()
losses_reduced.append(loss_reduced)
timers('forward-compute').stop()
return output_tensor
def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
"""Backward step through passed-in output tensor.
If last stage, output_tensor_grad is None, otherwise gradient of loss
with respect to stage's output tensor.
Returns gradient of loss with respect to input tensor (None if first
stage)."""
args = get_args()
timers = get_timers()
timers('backward-compute').start()
# Retain the grad on the input_tensor.
if input_tensor is not None:
input_tensor.retain_grad()
# Backward pass.
if output_tensor_grad is None:
output_tensor = optimizer.scale_loss(output_tensor)
torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
# Collect the grad of the input_tensor.
input_tensor_grad = None
if input_tensor is not None:
input_tensor_grad = input_tensor.grad
timers('backward-compute').stop()
return input_tensor_grad
@contextmanager
def dummy_handler():
try:
yield
finally:
pass
def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
optimizer, timers, forward_only):
"""Run forward and backward passes with no pipeline parallelism
(no inter-stage communication).
Returns dictionary with losses."""
assert len(model) == 1
model = model[0]
context_handler = dummy_handler
if isinstance(model, torchDDP):
context_handler = model.no_sync
losses_reduced = []
input_tensor, output_tensor_grad = None, None
with context_handler():
for i in range(get_num_microbatches() - 1):
output_tensor = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced)
if not forward_only:
backward_step(optimizer, input_tensor, output_tensor,
output_tensor_grad)
# Run computation for last microbatch out of context handler (want to
# synchronize gradients).
output_tensor = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced)
if not forward_only:
backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad)
return losses_reduced
def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator, model,
optimizer, timers, forward_only):
"""Run interleaved 1F1B schedule (model split into model chunks), with
communication between pipeline stages as needed.
Returns dictionary with losses if the last stage, empty dict otherwise."""
input_tensors = [[] for _ in range(len(model))]
output_tensors = [[] for _ in range(len(model))]
losses_reduced = []
if not forward_only:
output_tensor_grads = [[] for _ in range(len(model))]
pipeline_parallel_size = mpu.get_pipeline_model_parallel_world_size()
pipeline_parallel_rank = mpu.get_pipeline_model_parallel_rank()
# Compute number of warmup and remaining microbatches.
num_model_chunks = len(model)
num_microbatches = get_num_microbatches() * num_model_chunks
all_warmup_microbatches = False
if forward_only:
num_warmup_microbatches = num_microbatches
else:
# Run all forward passes and then all backward passes if number of
# microbatches is just the number of pipeline stages.
# Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on
# all workers, followed by more microbatches after depending on
# stage ID (more forward passes for earlier stages, later stages can
# immediately start with 1F1B).
if get_num_microbatches() == pipeline_parallel_size:
num_warmup_microbatches = num_microbatches
all_warmup_microbatches = True
else:
num_warmup_microbatches = \
(pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
num_warmup_microbatches += (
num_model_chunks - 1) * pipeline_parallel_size
num_warmup_microbatches = min(num_warmup_microbatches,
num_microbatches)
num_microbatches_remaining = \
num_microbatches - num_warmup_microbatches
def get_model_chunk_id(microbatch_id, forward):
"""Helper method to get the model chunk ID given the iteration number."""
microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
if not forward:
model_chunk_id = (num_model_chunks - model_chunk_id - 1)
return model_chunk_id
def forward_step_helper(microbatch_id):
"""Helper method to run forward step with model split into chunks
(run set_virtual_pipeline_model_parallel_rank() before calling
forward_step())."""
model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
if mpu.is_pipeline_first_stage():
if len(input_tensors[model_chunk_id]) == \
len(output_tensors[model_chunk_id]):
input_tensors[model_chunk_id].append(None)
input_tensor = input_tensors[model_chunk_id][-1]
output_tensor = forward_step(forward_step_func,
data_iterator[model_chunk_id],
model[model_chunk_id],
input_tensor, losses_reduced)
output_tensors[model_chunk_id].append(output_tensor)
return output_tensor
def backward_step_helper(microbatch_id):
"""Helper method to run backward step with model split into chunks
(run set_virtual_pipeline_model_parallel_rank() before calling
backward_step())."""
model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
if mpu.is_pipeline_last_stage():
if len(output_tensor_grads[model_chunk_id]) == 0:
output_tensor_grads[model_chunk_id].append(None)
input_tensor = input_tensors[model_chunk_id].pop(0)
output_tensor = output_tensors[model_chunk_id].pop(0)
output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
input_tensor_grad = \
backward_step(optimizer,
input_tensor,
output_tensor,
output_tensor_grad)
return input_tensor_grad
# Run warmup forward passes.
mpu.set_virtual_pipeline_model_parallel_rank(0)
input_tensors[0].append(
p2p_communication.recv_forward(timers))
for k in range(num_warmup_microbatches):
output_tensor = forward_step_helper(k)
# Determine if tensor should be received from previous stage.
next_forward_model_chunk_id = get_model_chunk_id(k+1, forward=True)
recv_prev = True
if mpu.is_pipeline_first_stage(ignore_virtual=True):
if next_forward_model_chunk_id == 0:
recv_prev = False
if k == (num_microbatches - 1):
recv_prev = False
# Don't send tensor downstream if on last stage.
if mpu.is_pipeline_last_stage():
output_tensor = None
# Send and receive tensors as appropriate (send tensors computed
# in this iteration; receive tensors for next iteration).
if k == (num_warmup_microbatches - 1) and not forward_only and \
not all_warmup_microbatches:
input_tensor_grad = None
recv_next = True
if mpu.is_pipeline_last_stage(ignore_virtual=True):
recv_next = False
input_tensor, output_tensor_grad = \
p2p_communication.send_forward_backward_recv_forward_backward(
output_tensor, input_tensor_grad,
recv_prev=recv_prev, recv_next=recv_next,
timers=timers)
output_tensor_grads[num_model_chunks-1].append(output_tensor_grad)
else:
input_tensor = \
p2p_communication.send_forward_recv_forward(
output_tensor, recv_prev, timers)
input_tensors[next_forward_model_chunk_id].append(input_tensor)
# Run 1F1B in steady state.
for k in range(num_microbatches_remaining):
# Forward pass.
forward_k = k + num_warmup_microbatches
output_tensor = forward_step_helper(forward_k)
# Backward pass.
backward_k = k
input_tensor_grad = backward_step_helper(backward_k)
# Send output_tensor and input_tensor_grad, receive input_tensor
# and output_tensor_grad.
# Determine if current stage has anything to send in either direction,
# otherwise set tensor to None.
forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
mpu.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
if mpu.is_pipeline_last_stage():
output_tensor = None
backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
mpu.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
if mpu.is_pipeline_first_stage():
input_tensor_grad = None
# Determine if peers are sending, and where in data structure to put
# received tensors.
recv_prev = True
if mpu.is_pipeline_first_stage(ignore_virtual=True):
# First stage is ahead of last stage by (pipeline_parallel_size - 1).
next_forward_model_chunk_id = get_model_chunk_id(
forward_k - (pipeline_parallel_size - 1), forward=True)
if next_forward_model_chunk_id == (num_model_chunks - 1):
recv_prev = False
next_forward_model_chunk_id += 1
else:
next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1,
forward=True)
recv_next = True
if mpu.is_pipeline_last_stage(ignore_virtual=True):
# Last stage is ahead of first stage by (pipeline_parallel_size - 1).
next_backward_model_chunk_id = get_model_chunk_id(
backward_k - (pipeline_parallel_size - 1), forward=False)
if next_backward_model_chunk_id == 0:
recv_next = False
next_backward_model_chunk_id -= 1
else:
next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1,
forward=False)
# If last iteration, don't receive; we already received one extra
# before the start of the for loop.
if k == (num_microbatches_remaining - 1):
recv_prev = False
# Communicate tensors.
input_tensor, output_tensor_grad = \
p2p_communication.send_forward_backward_recv_forward_backward(
output_tensor, input_tensor_grad,
recv_prev=recv_prev, recv_next=recv_next,
timers=timers)
# Put input_tensor and output_tensor_grad in data structures in the
# right location.
if recv_prev:
input_tensors[next_forward_model_chunk_id].append(input_tensor)
if recv_next:
output_tensor_grads[next_backward_model_chunk_id].append(
output_tensor_grad)
# Run cooldown backward passes (flush out pipeline).
if not forward_only:
if all_warmup_microbatches:
output_tensor_grads[num_model_chunks-1].append(
p2p_communication.recv_backward(timers))
for k in range(num_microbatches_remaining, num_microbatches):
input_tensor_grad = backward_step_helper(k)
next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False)
recv_next = True
if mpu.is_pipeline_last_stage(ignore_virtual=True):
if next_backward_model_chunk_id == (num_model_chunks - 1):
recv_next = False
if k == (num_microbatches - 1):
recv_next = False
output_tensor_grads[next_backward_model_chunk_id].append(
p2p_communication.send_backward_recv_backward(
input_tensor_grad, recv_next, timers))
return losses_reduced
def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator,
model, optimizer, timers,
forward_only):
"""Run non-interleaved 1F1B schedule, with communication between pipeline
stages.
Returns dictionary with losses if the last stage, empty dict otherwise."""
timers = get_timers()
assert len(model) == 1
model = model[0]
# Compute number of warmup microbatches.
num_microbatches = get_num_microbatches()
num_warmup_microbatches = \
(mpu.get_pipeline_model_parallel_world_size() -
mpu.get_pipeline_model_parallel_rank() - 1)
num_warmup_microbatches = min(
num_warmup_microbatches,
num_microbatches)
num_microbatches_remaining = \
num_microbatches - num_warmup_microbatches
input_tensors = []
output_tensors = []
losses_reduced = []
# Run warmup forward passes.
for i in range(num_warmup_microbatches):
input_tensor = p2p_communication.recv_forward(timers)
output_tensor = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced)
p2p_communication.send_forward(output_tensor, timers)
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
# Before running 1F1B, need to receive first forward tensor.
# If all microbatches are run in warmup / cooldown phase, then no need to
# receive this tensor here.
if num_microbatches_remaining > 0:
input_tensor = p2p_communication.recv_forward(timers)
# Run 1F1B in steady state.
for i in range(num_microbatches_remaining):
last_iteration = (i == (num_microbatches_remaining - 1))
output_tensor = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced)
if forward_only:
p2p_communication.send_forward(output_tensor, timers)
else:
output_tensor_grad = \
p2p_communication.send_forward_recv_backward(output_tensor,
timers)
# Add input_tensor and output_tensor to end of list, then pop from the
# start of the list for backward pass.
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
if forward_only:
if not last_iteration:
input_tensor = p2p_communication.recv_forward(timers)
else:
input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
input_tensor_grad = \
backward_step(optimizer, input_tensor, output_tensor,
output_tensor_grad)
if last_iteration:
input_tensor = None
p2p_communication.send_backward(input_tensor_grad, timers)
else:
input_tensor = \
p2p_communication.send_backward_recv_forward(
input_tensor_grad, timers)
# Run cooldown backward passes.
if not forward_only:
for i in range(num_warmup_microbatches):
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
output_tensor_grad = p2p_communication.recv_backward(timers)
input_tensor_grad = \
backward_step(optimizer, input_tensor, output_tensor,
output_tensor_grad)
p2p_communication.send_backward(input_tensor_grad, timers)
return losses_reduced
......@@ -26,9 +26,13 @@ import torch.nn.functional as F
from megatron import get_args
from megatron import get_tokenizer
from megatron import mpu
from megatron.training import communicate
from megatron.utils import get_ltor_masks_and_position_ids
from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model
from megatron.p2p_communication import recv_forward, send_forward
# These are needed to unwrap the model, would be nice to put these in megatron.utils if possible?
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
def get_batch(context_tokens):
"""Generate batch from context tokens."""
......@@ -395,55 +399,28 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
layer_past=None, get_key_value=None,
forward_method_parallel_output=None):
# Hidden size changes when not using recompute, need to tell communicate()
# the correct size
# Hidden size changes when not using recompute, need to tell p2p_communicate
# functions the correct size
args = get_args()
orig_seq_length = args.seq_length
args.seq_length = tokens.shape[1]
if not mpu.is_pipeline_first_stage():
input_tensor, _ = communicate(
tensor_send_next=None,
tensor_send_prev=None,
recv_forward=True,
recv_backward=False)
else:
input_tensor = None
input_tensor = recv_forward()
# Forward pass through the model.
if mpu.is_pipeline_first_stage():
assert input_tensor is None
if mpu.is_pipeline_last_stage():
output_tensor = model(tokens, position_ids, attention_mask,
tokentype_ids=tokentype_ids,
layer_past=layer_past,
get_key_value=get_key_value,
forward_method_parallel_output=forward_method_parallel_output)
else:
output_tensor = model(tokens, position_ids, attention_mask,
tokentype_ids=tokentype_ids,
layer_past=layer_past,
get_key_value=get_key_value)
elif mpu.is_pipeline_last_stage():
assert input_tensor is not None
output_tensor = model(input_tensor, attention_mask,
layer_past=layer_past,
get_key_value=get_key_value,
forward_method_parallel_output=forward_method_parallel_output)
else:
assert input_tensor is not None
output_tensor = model(input_tensor, attention_mask,
layer_past=layer_past,
get_key_value=get_key_value)
unwrapped_model = unwrap_model(
model, (torchDDP, LocalDDP, Float16Module))
unwrapped_model.set_input_tensor(input_tensor)
output_tensor = model(tokens, position_ids, attention_mask,
tokentype_ids=tokentype_ids,
layer_past=layer_past,
get_key_value=get_key_value,
forward_method_parallel_output=forward_method_parallel_output)
if get_key_value:
output_tensor, layer_past = output_tensor
if not mpu.is_pipeline_last_stage():
communicate(tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_forward=False,
recv_backward=False)
send_forward(output_tensor)
args.seq_length = orig_seq_length
if get_key_value:
......
......@@ -37,19 +37,23 @@ from megatron import print_rank_0
from megatron import print_rank_last
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
from megatron.model import FP16Module
from megatron.model import Float16Module
from megatron.optimizer import get_megatron_optimizer
from megatron.initialize import initialize_megatron
from megatron.initialize import write_args_to_tensorboard
from megatron.learning_rates import AnnealingLR
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model.realm_model import ICTBertModel
from megatron.utils import check_adlr_autoresume_termination
from megatron.data.data_loaders import build_pretraining_data_loader
from megatron.utils import unwrap_model
from megatron.data.data_samplers import build_pretraining_data_loader
from megatron.utils import calc_params_l2_norm
from megatron.schedules import forward_backward_no_pipelining
from megatron.schedules import forward_backward_pipelining_without_interleaving
from megatron.schedules import forward_backward_pipelining_with_interleaving
from megatron.utils import report_memory
def print_datetime(string):
"""Note that this call will sync across all ranks."""
torch.distributed.barrier()
......@@ -57,8 +61,11 @@ def print_datetime(string):
print_rank_0('[' + string + '] datetime: {} '.format(time_str))
def pretrain(train_valid_test_dataset_provider, model_provider,
forward_step_func, extra_args_provider=None, args_defaults={}):
def pretrain(train_valid_test_dataset_provider,
model_provider,
forward_step_func,
extra_args_provider=None,
args_defaults={}):
"""Main training program.
This function will run the followings in the order provided:
......@@ -103,23 +110,32 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
timers = get_timers()
# Model, optimizer, and learning rate.
timers('model and optimizer').start()
timers('model-and-optimizer-setup').start()
model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
timers('model and optimizer').stop()
timers('model-and-optimizer-setup').stop()
print_datetime('after model, optimizer, and learning rate '
'scheduler are built')
# Data stuff.
timers('train/valid/test data iterators').start()
train_data_iterator, valid_data_iterator, test_data_iterator \
= build_train_valid_test_data_iterators(
train_valid_test_dataset_provider)
timers('train/valid/test data iterators').stop()
timers('train/valid/test-data-iterators-setup').start()
if args.virtual_pipeline_model_parallel_size is not None:
all_data_iterators = [
build_train_valid_test_data_iterators(train_valid_test_dataset_provider)
for _ in range(len(model))
]
train_data_iterator = [data_iterators[0] for data_iterators in all_data_iterators]
valid_data_iterator = [data_iterators[1] for data_iterators in all_data_iterators]
test_data_iterator = [data_iterators[2] for data_iterators in all_data_iterators]
else:
train_data_iterator, valid_data_iterator, test_data_iterator \
= build_train_valid_test_data_iterators(
train_valid_test_dataset_provider)
timers('train/valid/test-data-iterators-setup').stop()
print_datetime('after dataloaders are built')
# Print setup timing.
print_rank_0('done with setups ...')
timers.log(['model and optimizer', 'train/valid/test data iterators'])
print_rank_0('done with setup ...')
timers.log(['model-and-optimizer-setup', 'train/valid/test-data-iterators-setup'])
print_rank_0('training ...')
iteration = 0
......@@ -179,15 +195,38 @@ def get_model(model_provider_func):
"""Build the model."""
args = get_args()
# Build model on cpu.
model = model_provider_func()
# Build model.
if mpu.get_pipeline_model_parallel_world_size() > 1 and \
args.virtual_pipeline_model_parallel_size is not None:
model = []
for i in range(args.virtual_pipeline_model_parallel_size):
mpu.set_virtual_pipeline_model_parallel_rank(i)
# Set pre_process and post_process only after virtual rank is set.
pre_process = mpu.is_pipeline_first_stage()
post_process = mpu.is_pipeline_last_stage()
this_model = model_provider_func(
pre_process=pre_process,
post_process=post_process
)
model.append(this_model)
else:
pre_process = mpu.is_pipeline_first_stage()
post_process = mpu.is_pipeline_last_stage()
model = model_provider_func(
pre_process=pre_process,
post_process=post_process
)
if not isinstance(model, list):
model = [model]
# Set tensor model parallel attributes if not set.
# Only parameters that are already tensor model parallel have these
# attributes set for them. We should make sure the default attributes
# are set for all params so the optimizer can use them.
for param in model.parameters():
mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param)
for model_module in model:
for param in model_module.parameters():
mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param)
# Print number of parameters.
if mpu.get_data_parallel_rank() == 0:
......@@ -195,22 +234,29 @@ def get_model(model_provider_func):
'model parallel rank ({}, {}): {}'.format(
mpu.get_tensor_model_parallel_rank(),
mpu.get_pipeline_model_parallel_rank(),
sum([p.nelement() for p in model.parameters()])), flush=True)
sum([sum([p.nelement() for p in model_module.parameters()])
for model_module in model])), flush=True)
# GPU allocation.
model.cuda(torch.cuda.current_device())
for model_module in model:
model_module.cuda(torch.cuda.current_device())
# Fp16 conversion.
if args.fp16:
model = FP16Module(model)
if args.fp16 or args.bf16:
model = [Float16Module(model_module, args) for model_module in model]
if args.DDP_impl == 'torch':
i = torch.cuda.current_device()
model = torchDDP(model, device_ids=[i], output_device=i,
process_group=mpu.get_data_parallel_group())
model = [torchDDP(model_module, device_ids=[i], output_device=i,
process_group=mpu.get_data_parallel_group())
for model_module in model]
return model
if args.DDP_impl == 'local':
model = LocalDDP(model)
model = [LocalDDP(model_module,
args.accumulate_allreduce_grads_in_fp32,
args.use_contiguous_buffers_in_ddp)
for model_module in model]
return model
raise NotImplementedError('Unknown DDP implementation specified: {}. '
......@@ -266,9 +312,8 @@ def setup_model_and_optimizer(model_provider_func):
model = get_model(model_provider_func)
unwrapped_model = model
while isinstance(unwrapped_model, (torchDDP, LocalDDP, FP16Module)):
unwrapped_model = unwrapped_model.module
unwrapped_model = unwrap_model(model,
(torchDDP, LocalDDP, Float16Module))
optimizer = get_megatron_optimizer(unwrapped_model)
lr_scheduler = get_learning_rate_scheduler(optimizer)
......@@ -278,305 +323,29 @@ def setup_model_and_optimizer(model_provider_func):
# Extra barrier is added to make sure all ranks report the
# max time.
torch.distributed.barrier()
timers('load checkpoint').start()
timers('load-checkpoint').start()
args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
torch.distributed.barrier()
timers('load checkpoint').stop()
timers.log(['load checkpoint'])
timers('load-checkpoint').stop()
timers.log(['load-checkpoint'])
else:
args.iteration = 0
# We only support local DDP with multiple micro-batches.
if get_num_microbatches() > 1:
if len(model) > 1 or mpu.get_pipeline_model_parallel_world_size() > 1:
assert args.DDP_impl == 'local'
# get model without FP16 and/or TorchDDP wrappers
unwrapped_model = model
while hasattr(unwrapped_model, 'module'):
unwrapped_model = unwrapped_model.module
if args.iteration == 0 and hasattr(unwrapped_model,
'init_state_dict_from_bert'):
print("Initializing ICT from pretrained BERT model", flush=True)
unwrapped_model.init_state_dict_from_bert()
if args.iteration == 0 and len(unwrapped_model) == 1 \
and hasattr(unwrapped_model[0], 'init_state_dict_from_bert'):
print_rank_0("Initializing ICT from pretrained BERT model")
unwrapped_model[0].init_state_dict_from_bert()
if args.fp16:
optimizer.reload_model_params()
return model, optimizer, lr_scheduler
def communicate(tensor_send_next, tensor_send_prev, recv_forward, recv_backward):
"""Communicate tensors between stages."""
args = get_args()
# Create placeholder tensors for receive in forward and backward directions
# if needed.
tensor_recv_prev = None
tensor_recv_next = None
tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
dtype = args.params_dtype
if args.fp32_residual_connection:
dtype = torch.float
if recv_forward:
tensor_recv_prev = torch.empty(tensor_shape,
requires_grad=True,
device=torch.cuda.current_device(),
dtype=dtype)
if recv_backward:
tensor_recv_next = torch.empty(tensor_shape,
requires_grad=True,
device=torch.cuda.current_device(),
dtype=dtype)
# Send tensors in both the forward and backward directions as appropriate.
ops = []
if tensor_send_prev is not None:
send_prev_op = torch.distributed.P2POp(torch.distributed.isend, tensor_send_prev,
mpu.get_pipeline_model_parallel_prev_rank())
ops.append(send_prev_op)
if tensor_recv_prev is not None:
recv_prev_op = torch.distributed.P2POp(torch.distributed.irecv, tensor_recv_prev,
mpu.get_pipeline_model_parallel_prev_rank())
ops.append(recv_prev_op)
if tensor_send_next is not None:
send_next_op = torch.distributed.P2POp(torch.distributed.isend, tensor_send_next,
mpu.get_pipeline_model_parallel_next_rank())
ops.append(send_next_op)
if tensor_recv_next is not None:
recv_next_op = torch.distributed.P2POp(torch.distributed.irecv, tensor_recv_next,
mpu.get_pipeline_model_parallel_next_rank())
ops.append(recv_next_op)
reqs = torch.distributed.batch_isend_irecv(ops)
for req in reqs:
req.wait()
# Temporary workaround for batch_isend_irecv() race condition.
torch.cuda.synchronize()
return tensor_recv_prev, tensor_recv_next
def backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad):
"""Backward step."""
args = get_args()
timers = get_timers()
# Retain the grad on the input_tensor.
if input_tensor is not None:
input_tensor.retain_grad()
# Backward pass.
if output_tensor_grad is None:
output_tensor = optimizer.scale_loss(output_tensor)
torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
# Collect the grad of the input_tensor.
input_tensor_grad = None
if input_tensor is not None:
input_tensor_grad = input_tensor.grad
return input_tensor_grad
def forward_step_with_communication(forward_step_func, data_iterator, model,
input_tensors, output_tensors,
losses_reduced, timers):
args = get_args()
if not mpu.is_pipeline_first_stage():
timers('forward-recv').start()
input_tensor, _ = communicate(
tensor_send_next=None,
tensor_send_prev=None,
recv_forward=True,
recv_backward=False)
timers('forward-recv').stop()
else:
input_tensor = None
# Forward model for one step.
timers('forward-compute').start()
output_tensor = forward_step_func(data_iterator, model, input_tensor)
timers('forward-compute').stop()
if mpu.is_pipeline_last_stage():
loss, loss_reduced = output_tensor
output_tensor = loss / get_num_microbatches()
losses_reduced.append(loss_reduced)
else:
timers('forward-send').start()
communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_forward=False,
recv_backward=False)
timers('forward-send').stop()
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
def backward_step_with_communication(optimizer, model, input_tensors, output_tensors, timers):
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
if mpu.is_pipeline_last_stage():
output_tensor_grad = None
else:
timers('backward-recv').start()
_, output_tensor_grad = communicate(
tensor_send_next=None,
tensor_send_prev=None,
recv_forward=False,
recv_backward=True)
timers('backward-recv').stop()
# Backward pass for one step.
timers('backward-compute').start()
input_grad_tensor = \
backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad)
timers('backward-compute').stop()
if not mpu.is_pipeline_first_stage():
timers('backward-send').start()
communicate(
tensor_send_next=None,
tensor_send_prev=input_grad_tensor,
recv_forward=False,
recv_backward=False)
timers('backward-send').stop()
def forward_and_backward_steps_with_communication(forward_step_func, data_iterator, model,
optimizer,
input_tensor, last_microbatch,
input_tensors, output_tensors,
losses_reduced, timers):
args = get_args()
# Forward model for one step.
timers('forward-compute').start()
output_tensor = forward_step_func(data_iterator, model, input_tensor)
timers('forward-compute').stop()
if mpu.is_pipeline_last_stage():
loss, loss_reduced = output_tensor
output_tensor = loss / get_num_microbatches()
output_tensor_grad = None
losses_reduced.append(loss_reduced)
else:
timers('forward-send-backward-recv').start()
_, output_tensor_grad = communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_forward=False,
recv_backward=True)
timers('forward-send-backward-recv').stop()
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
# Backward pass for one step.
timers('backward-compute').start()
input_grad_tensor = \
backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad)
timers('backward-compute').stop()
if not mpu.is_pipeline_first_stage():
timers('backward-send-forward-recv').start()
input_tensor, _ = communicate(
tensor_send_next=None,
tensor_send_prev=input_grad_tensor,
recv_forward=(not last_microbatch),
recv_backward=False)
timers('backward-send-forward-recv').stop()
else:
input_tensor = None
return input_tensor
def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
optimizer, timers):
"""Run forward and backward passes without inter-stage communication."""
args = get_args()
losses_reduced = []
for i in range(get_num_microbatches()):
timers('forward-compute').start()
loss, loss_reduced = forward_step_func(data_iterator, model, input_tensor=None)
output_tensor = loss / get_num_microbatches()
losses_reduced.append(loss_reduced)
timers('forward-compute').stop()
timers('backward-compute').start()
output_tensor_grad = None
backward_step(optimizer, model, input_tensor=None,
output_tensor=output_tensor, output_tensor_grad=None)
timers('backward-compute').stop()
return losses_reduced
def forward_backward_pipelining(forward_step_func, data_iterator, model,
optimizer, timers):
"""Run 1F1B schedule, with communication and warmup + cooldown microbatches as needed."""
args = get_args()
# Compute number of warmup microbatches.
num_microbatches = get_num_microbatches()
num_warmup_microbatches = \
(mpu.get_pipeline_model_parallel_world_size() -
mpu.get_pipeline_model_parallel_rank() - 1)
num_warmup_microbatches = min(
num_warmup_microbatches,
num_microbatches)
num_microbatches_remaining = \
num_microbatches - num_warmup_microbatches
input_tensors = []
output_tensors = []
losses_reduced = []
# Run warmup forward passes.
for i in range(num_warmup_microbatches):
forward_step_with_communication(
forward_step_func, data_iterator, model,
input_tensors, output_tensors,
losses_reduced, timers)
# Before running 1F1B, need to receive first forward tensor.
# If all microbatches are run in warmup / cooldown phase, then no need to
# receive this tensor here.
if num_microbatches_remaining > 0:
if mpu.is_pipeline_first_stage():
input_tensor = None
else:
timers('forward-recv').start()
input_tensor, _ = communicate(tensor_send_next=None,
tensor_send_prev=None,
recv_forward=True,
recv_backward=False)
timers('forward-recv').stop()
# Run 1F1B.
for i in range(num_microbatches_remaining):
last_iteration = (i == (num_microbatches_remaining - 1))
input_tensor = \
forward_and_backward_steps_with_communication(forward_step_func, data_iterator, model,
optimizer,
input_tensor, last_iteration,
input_tensors, output_tensors,
losses_reduced, timers)
# Run cooldown backward passes.
for i in range(num_warmup_microbatches):
backward_step_with_communication(
optimizer, model, input_tensors, output_tensors, timers)
return losses_reduced
def train_step(forward_step_func, data_iterator,
model, optimizer, lr_scheduler):
"""Single training step."""
......@@ -584,20 +353,31 @@ def train_step(forward_step_func, data_iterator,
timers = get_timers()
# Set grad to zero.
optimizer.zero_grad()
if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_ddp:
for partition in model:
partition.zero_grad_buffer()
else:
optimizer.zero_grad()
if mpu.get_pipeline_model_parallel_world_size() > 1:
losses_reduced = forward_backward_pipelining(
forward_step_func, data_iterator, model, optimizer, timers)
if args.virtual_pipeline_model_parallel_size is not None:
forward_backward_func = forward_backward_pipelining_with_interleaving
assert get_num_microbatches() % args.pipeline_model_parallel_size == 0, \
'number of microbatches is not divisible by pipeline-parallel ' \
'size when using interleaved schedule'
else:
forward_backward_func = forward_backward_pipelining_without_interleaving
else:
losses_reduced = forward_backward_no_pipelining(
forward_step_func, data_iterator, model, optimizer, timers)
forward_backward_func = forward_backward_no_pipelining
losses_reduced = forward_backward_func(
forward_step_func, data_iterator, model,
optimizer, timers, forward_only=False)
# All-reduce if needed.
if args.DDP_impl == 'local':
timers('backward-params-all-reduce').start()
model.allreduce_params(reduce_after=False,
fp32_allreduce=args.fp32_allreduce)
for model_module in model:
model_module.allreduce_gradients()
timers('backward-params-all-reduce').stop()
# All-reduce word_embeddings' grad across first and last stages to ensure
......@@ -605,25 +385,32 @@ def train_step(forward_step_func, data_iterator,
# This should only run for models that support pipelined model parallelism
# (BERT and GPT-2).
timers('backward-embedding-all-reduce').start()
if (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()) and \
if (mpu.is_pipeline_first_stage(ignore_virtual=True) or
mpu.is_pipeline_last_stage(ignore_virtual=True)) and \
mpu.get_pipeline_model_parallel_world_size() > 1:
unwrapped_model = model
while isinstance(unwrapped_model, (torchDDP, LocalDDP, FP16Module)):
unwrapped_model = unwrapped_model.module
if mpu.is_pipeline_first_stage(ignore_virtual=True):
unwrapped_model = model[0]
elif mpu.is_pipeline_last_stage(ignore_virtual=True):
unwrapped_model = model[-1]
unwrapped_model = unwrap_model(
unwrapped_model, (torchDDP, LocalDDP, Float16Module))
if unwrapped_model.share_word_embeddings:
word_embeddings_weight = unwrapped_model.word_embeddings_weight()
torch.distributed.all_reduce(word_embeddings_weight.grad,
group=mpu.get_embedding_group())
if args.DDP_impl == 'local':
grad = word_embeddings_weight.main_grad
else:
grad = word_embeddings_weight.grad
torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
timers('backward-embedding-all-reduce').stop()
# Update parameters.
timers('optimizer').start()
update_successfull = optimizer.step()
update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
timers('optimizer').stop()
# Update learning rate.
if update_successfull:
if update_successful:
increment = get_num_microbatches() * \
args.micro_batch_size * \
args.data_parallel_size
......@@ -632,18 +419,19 @@ def train_step(forward_step_func, data_iterator,
else:
skipped_iter = 1
if mpu.is_pipeline_last_stage():
if mpu.is_pipeline_last_stage(ignore_virtual=True):
# Average loss across microbatches.
loss_reduced = {}
for key in losses_reduced[0]:
losses_reduced_for_key = [x[key] for x in losses_reduced]
loss_reduced[key] = sum(losses_reduced_for_key) / len(losses_reduced_for_key)
return loss_reduced, skipped_iter
return {}, skipped_iter
return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad
return {}, skipped_iter, grad_norm, num_zeros_in_grad
def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
loss_scale, report_memory_flag, skipped_iter):
loss_scale, report_memory_flag, skipped_iter,
grad_norm, params_norm, num_zeros_in_grad):
"""Log training information such as losses, timing, ...."""
args = get_args()
timers = get_timers()
......@@ -687,11 +475,12 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
add_to_logging('forward-compute')
add_to_logging('forward-recv')
add_to_logging('forward-send')
add_to_logging('forward-send-backward-recv')
add_to_logging('forward-backward-send-forward-backward-recv')
add_to_logging('backward-compute')
add_to_logging('backward-recv')
add_to_logging('backward-send')
add_to_logging('backward-send-forward-recv')
add_to_logging('backward-send-backward-recv')
add_to_logging('backward-params-all-reduce')
add_to_logging('backward-embedding-all-reduce')
add_to_logging('optimizer-copy-to-main-grad')
......@@ -709,29 +498,47 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
total_loss_dict[skipped_iters_key]
# Tensorboard values.
if writer and is_last_rank():
writer.add_scalar('learning-rate', learning_rate, iteration)
writer.add_scalar('learning-rate vs samples', learning_rate,
args.consumed_train_samples)
writer.add_scalar('batch-size', batch_size, iteration)
writer.add_scalar('batch-size vs samples', batch_size,
args.consumed_train_samples)
if writer and (iteration % args.tensorboard_log_interval == 0 ) and \
is_last_rank():
if args.log_learning_rate_to_tensorboard:
writer.add_scalar('learning-rate', learning_rate, iteration)
writer.add_scalar('learning-rate vs samples', learning_rate,
args.consumed_train_samples)
if args.log_batch_size_to_tensorboard:
writer.add_scalar('batch-size', batch_size, iteration)
writer.add_scalar('batch-size vs samples', batch_size,
args.consumed_train_samples)
for key in loss_dict:
writer.add_scalar(key , loss_dict[key], iteration)
writer.add_scalar(key + ' vs samples', loss_dict[key],
args.consumed_train_samples)
writer.add_scalar('loss-scale', loss_scale, iteration)
writer.add_scalar('loss-scale vs samples', loss_scale,
args.consumed_train_samples)
timers.write(timers_to_log, writer, iteration,
normalizer=total_iterations)
if args.log_loss_scale_to_tensorboard:
writer.add_scalar('loss-scale', loss_scale, iteration)
writer.add_scalar('loss-scale vs samples', loss_scale,
args.consumed_train_samples)
if grad_norm is not None:
writer.add_scalar('grad-norm', grad_norm, iteration)
writer.add_scalar('grad-norm vs samples', grad_norm,
args.consumed_train_samples)
if num_zeros_in_grad is not None:
writer.add_scalar('num-zeros', num_zeros_in_grad, iteration)
writer.add_scalar('num-zeros vs samples', num_zeros_in_grad,
args.consumed_train_samples)
if params_norm is not None:
writer.add_scalar('params-norm', params_norm, iteration)
writer.add_scalar('params-norm vs samples', params_norm,
args.consumed_train_samples)
if args.log_timers_to_tensorboard:
timers.write(timers_to_log, writer, iteration,
normalizer=total_iterations)
if iteration % args.log_interval == 0:
elapsed_time = timers('interval time').elapsed()
elapsed_time = timers('interval-time').elapsed()
elapsed_time_per_iteration = elapsed_time / total_iterations
if writer and torch.distributed.get_rank() == 0:
writer.add_scalar('iteration-time',
elapsed_time_per_iteration, iteration)
if args.log_timers_to_tensorboard:
writer.add_scalar('iteration-time',
elapsed_time_per_iteration, iteration)
log_string = ' iteration {:8d}/{:8d} |'.format(
iteration, args.train_iters)
log_string += ' consumed samples: {:12d} |'.format(
......@@ -749,6 +556,12 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
log_string += ' {}: {:.6E} |'.format(key, avg)
total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
log_string += ' loss scale: {:.1f} |'.format(loss_scale)
if grad_norm is not None:
log_string += ' grad norm: {:.3f} |'.format(grad_norm)
if num_zeros_in_grad is not None:
log_string += ' num zeros: {:.1f} |'.format(num_zeros_in_grad)
if params_norm is not None:
log_string += ' params norm: {:.3f} |'.format(params_norm)
log_string += ' number of skipped iterations: {:3d} |'.format(
total_loss_dict[skipped_iters_key])
log_string += ' number of nan iterations: {:3d} |'.format(
......@@ -771,11 +584,11 @@ def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):
# Extra barrier is added to make sure
# all ranks report the max time.
torch.distributed.barrier()
timers('save checkpoint').start()
timers('save-checkpoint').start()
save_checkpoint(iteration, model, optimizer, lr_scheduler)
torch.distributed.barrier()
timers('save checkpoint').stop()
timers.log(['save checkpoint'])
timers('save-checkpoint').stop()
timers.log(['save-checkpoint'])
def train(forward_step_func, model, optimizer, lr_scheduler,
......@@ -788,7 +601,8 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
write_args_to_tensorboard()
# Turn on training mode which enables dropout.
model.train()
for model_module in model:
model_module.train()
# Tracking loss.
total_loss_dict = {}
......@@ -796,16 +610,17 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
# Iterations.
iteration = args.iteration
timers('interval time').start()
timers('interval-time').start()
print_datetime('before the start of training step')
report_memory_flag = True
while iteration < args.train_iters:
update_num_microbatches(args.consumed_train_samples)
loss_dict, skipped_iter = train_step(forward_step_func,
train_data_iterator,
model,
optimizer,
lr_scheduler)
loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \
train_step(forward_step_func,
train_data_iterator,
model,
optimizer,
lr_scheduler)
iteration += 1
args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
args.micro_batch_size * \
......@@ -813,10 +628,14 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
# Logging.
loss_scale = optimizer.get_loss_scale().item()
params_norm = None
if args.log_params_norm:
params_norm = calc_params_l2_norm(model)
report_memory_flag = training_log(loss_dict, total_loss_dict,
optimizer.param_groups[0]['lr'],
iteration, loss_scale,
report_memory_flag, skipped_iter)
report_memory_flag, skipped_iter,
grad_norm, params_norm, num_zeros_in_grad)
# Autoresume
if args.adlr_autoresume and \
......@@ -852,16 +671,16 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
if not saved_checkpoint:
save_checkpoint_and_time(iteration, model, optimizer,
lr_scheduler)
print_datetime('exiting program after {} minutes'.format(train_time))
print_datetime('exiting program after {} minutes'.format(train_time))
sys.exit()
# Exiting based on iterations
# Exiting based on iterations
if args.exit_interval and iteration % args.exit_interval == 0:
if not saved_checkpoint:
save_checkpoint_and_time(iteration, model, optimizer,
lr_scheduler)
torch.distributed.barrier()
print_datetime('exiting program at iteration {}'.format(iteration))
print_datetime('exiting program at iteration {}'.format(iteration))
sys.exit()
......@@ -873,7 +692,8 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
args = get_args()
# Turn on evaluation mode which disables dropout.
model.eval()
for model_module in model:
model_module.eval()
total_loss_dict = {}
......@@ -885,37 +705,30 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
print_rank_0('Evaluating iter {}/{}'.format(iteration,
args.eval_iters))
for _ in range(get_num_microbatches()):
if not mpu.is_pipeline_first_stage():
input_tensor, _ = communicate(
tensor_send_next=None,
tensor_send_prev=None,
recv_forward=True,
recv_backward=False)
if mpu.get_pipeline_model_parallel_world_size() > 1:
if args.virtual_pipeline_model_parallel_size is not None:
forward_backward_func = forward_backward_pipelining_with_interleaving
else:
input_tensor = None
# Forward evaluation.
output_tensor = forward_step_func(data_iterator, model, input_tensor)
if mpu.is_pipeline_last_stage():
_, loss_dict = output_tensor
# Reduce across processes.
forward_backward_func = forward_backward_pipelining_without_interleaving
else:
forward_backward_func = forward_backward_no_pipelining
loss_dicts = forward_backward_func(
forward_step_func, data_iterator, model, optimizer=None,
timers=None, forward_only=True)
if mpu.is_pipeline_last_stage(ignore_virtual=True):
# Reduce across processes.
for loss_dict in loss_dicts:
for key in loss_dict:
total_loss_dict[key] = total_loss_dict.get(key, torch.cuda.FloatTensor([0.0])) + \
loss_dict[key]
else:
communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_forward=False,
recv_backward=False)
total_loss_dict[key] = total_loss_dict.get(
key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]
args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
* args.micro_batch_size \
* get_num_microbatches()
# Move model back to the train mode.
model.train()
for model_module in model:
model_module.train()
for key in total_loss_dict:
total_loss_dict[key] /= args.eval_iters * get_num_microbatches()
......@@ -936,15 +749,17 @@ def evaluate_and_print_results(prefix, forward_step_func,
ppl = math.exp(min(20, total_loss_dict[key].item()))
string += '{} PPL: {:.6E} | '.format(key, ppl)
if writer and is_last_rank():
writer.add_scalar('{} value-validation'.format(key),
writer.add_scalar('{} validation'.format(key),
total_loss_dict[key].item(),
iteration)
writer.add_scalar('{} ppl-validation'.format(key), ppl, iteration)
writer.add_scalar('{} value-validation vs samples'.format(key),
writer.add_scalar('{} validation vs samples'.format(key),
total_loss_dict[key].item(),
args.consumed_train_samples)
writer.add_scalar('{} ppl-validation vs samples'.format(key), ppl,
args.consumed_train_samples)
if args.log_validation_ppl_to_tensorboard:
writer.add_scalar('{} validation ppl'.format(key), ppl,
iteration)
writer.add_scalar('{} validation ppl vs samples'.format(key),
ppl, args.consumed_train_samples)
length = len(string) + 1
print_rank_last('-' * length)
......@@ -952,6 +767,11 @@ def evaluate_and_print_results(prefix, forward_step_func,
print_rank_last('-' * length)
def cyclic_iter(iter):
while True:
for x in iter:
yield x
def build_train_valid_test_data_iterators(
build_train_valid_test_datasets_provider):
"""XXX"""
......@@ -1020,19 +840,26 @@ def build_train_valid_test_data_iterators(
args.do_valid = flags[1].item()
args.do_test = flags[2].item()
# Build iterators.
dl_type = args.dataloader_type
assert dl_type in ['single', 'cyclic']
if train_dataloader is not None:
train_data_iterator = iter(train_dataloader)
train_data_iterator = iter(train_dataloader) if dl_type == 'single' \
else iter(cyclic_iter(train_dataloader))
else:
train_data_iterator = None
if valid_dataloader is not None:
valid_data_iterator = iter(valid_dataloader)
valid_data_iterator = iter(valid_dataloader) if dl_type == 'single' \
else iter(cyclic_iter(valid_dataloader))
else:
valid_data_iterator = None
if test_dataloader is not None:
test_data_iterator = iter(test_dataloader)
test_data_iterator = iter(test_dataloader) if dl_type == 'single' \
else iter(cyclic_iter(test_dataloader))
else:
test_data_iterator = None
......
......@@ -18,12 +18,64 @@
import sys
import torch
from torch.nn.parallel import DistributedDataParallel as torchDDP
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
from megatron import get_args
from megatron import print_rank_0
from megatron import get_adlr_autoresume
from megatron import mpu
from megatron.checkpointing import save_checkpoint
from megatron.model.module import param_is_not_shared
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
def unwrap_model(model, module_instances=(torchDDP)):
return_list = True
if not isinstance(model, list):
model = [model]
return_list = False
unwrapped_model = []
for model_module in model:
while isinstance(model_module, module_instances):
model_module = model_module.module
unwrapped_model.append(model_module)
if not return_list:
return unwrapped_model[0]
return unwrapped_model
def calc_params_l2_norm(model):
"""Calculate l2 norm of parameters """
args = get_args()
if not isinstance(model, list):
model = [model]
# Remove duplicate params.
params_data = []
for model_ in model:
for param in model_.parameters():
is_not_shared = param_is_not_shared(param)
is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
if is_not_shared and is_not_tp_duplicate:
if args.bf16:
params_data.append(param.data.float())
else:
params_data.append(param.data)
# Calculate norm
dummy_overflow_buf = torch.cuda.IntTensor([0])
norm, _ = multi_tensor_applier(
amp_C.multi_tensor_l2norm,
dummy_overflow_buf,
[params_data],
False # no per-parameter norm
)
norm_2 = norm * norm
# Sum across all model-parallel GPUs.
torch.distributed.all_reduce(norm_2,
op=torch.distributed.ReduceOp.SUM,
group=mpu.get_model_parallel_group())
return norm_2.item() ** 0.5
def average_losses_across_data_parallel_group(losses):
......@@ -76,6 +128,8 @@ def print_params_min_max_norm(optimizer, iteration):
def check_adlr_autoresume_termination(iteration, model,
optimizer, lr_scheduler):
"""Check for autoresume signal and exit if it is received."""
from megatron.checkpointing import save_checkpoint
args = get_args()
autoresume = get_adlr_autoresume()
# Add barrier to ensure consistnecy.
......
......@@ -17,41 +17,30 @@
import torch
import torch.nn.functional as F
from functools import partial
from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import mpu
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import BertModel, BertModelFirstStage, BertModelIntermediateStage, BertModelLastStage
from megatron.model import BertModel
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
def model_provider():
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
print_rank_0('building BERT model ...')
args = get_args()
if mpu.get_pipeline_model_parallel_world_size() > 1:
# Determine model based on position of stage in pipeline.
if mpu.is_pipeline_first_stage():
model = BertModelFirstStage(
num_tokentypes=2)
elif mpu.is_pipeline_last_stage():
model = BertModelLastStage(
num_tokentypes=2,
add_binary_head=True,
parallel_output=True)
else:
model = BertModelIntermediateStage(
num_tokentypes=2)
else:
model = BertModel(
num_tokentypes=2,
add_binary_head=True,
parallel_output=True)
num_tokentypes = 2 if args.bert_binary_head else 0
model = BertModel(
num_tokentypes=num_tokentypes,
add_binary_head=args.bert_binary_head,
parallel_output=True,
pre_process=pre_process,
post_process=post_process)
return model
......@@ -81,51 +70,51 @@ def get_batch(data_iterator):
return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask
def forward_step(data_iterator, model, input_tensor):
"""Forward step."""
args = get_args()
timers = get_timers()
def loss_func(loss_mask, sentence_order, output_tensor):
lm_loss_, sop_logits = output_tensor
# Get the batch.
timers('batch-generator').start()
tokens, types, sentence_order, loss_mask, lm_labels, padding_mask \
= get_batch(data_iterator)
timers('batch-generator').stop()
# Forward pass through the model.
if mpu.is_pipeline_first_stage():
assert input_tensor is None
if mpu.is_pipeline_last_stage():
output_tensor = model(tokens, padding_mask, tokentype_ids=types,
lm_labels=lm_labels)
else:
output_tensor = model(tokens, padding_mask, tokentype_ids=types)
elif mpu.is_pipeline_last_stage():
assert input_tensor is not None
output_tensor = model(input_tensor, padding_mask, lm_labels=lm_labels)
else:
assert input_tensor is not None
output_tensor = model(input_tensor, padding_mask)
if mpu.is_pipeline_last_stage():
lm_loss_, sop_logits = output_tensor
lm_loss_ = lm_loss_.float()
loss_mask = loss_mask.float()
lm_loss = torch.sum(
lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
if sop_logits is not None:
sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
sentence_order.view(-1),
ignore_index=-1)
sop_loss = sop_loss.float()
loss = lm_loss + sop_loss
averaged_losses = average_losses_across_data_parallel_group(
[lm_loss, sop_loss])
return loss, {'lm loss': averaged_losses[0],
'sop loss': averaged_losses[1]}
lm_loss_ = lm_loss_.float()
loss_mask = loss_mask.float()
lm_loss = torch.sum(
lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
else:
loss = lm_loss
averaged_losses = average_losses_across_data_parallel_group(
[lm_loss])
return loss, {'lm loss': averaged_losses[0]}
loss = lm_loss + sop_loss
averaged_losses = average_losses_across_data_parallel_group([lm_loss, sop_loss])
def forward_step(data_iterator, model):
"""Forward step."""
args = get_args()
timers = get_timers()
# Get the batch.
timers('batch-generator').start()
tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch(
data_iterator)
timers('batch-generator').stop()
if not args.bert_binary_head:
types = None
# Forward pass through the model.
output_tensor = model(tokens, padding_mask, tokentype_ids=types,
lm_labels=lm_labels)
return loss, {'lm loss': averaged_losses[0], 'sop loss': averaged_losses[1]}
return output_tensor
return output_tensor, partial(loss_func, loss_mask, sentence_order)
def train_valid_test_datasets_provider(train_val_test_num_samples):
......@@ -143,7 +132,8 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
masked_lm_prob=args.mask_prob,
short_seq_prob=args.short_seq_prob,
seed=args.seed,
skip_warmup=(not args.mmap_warmup))
skip_warmup=(not args.mmap_warmup),
binary_head=args.bert_binary_head)
print_rank_0("> finished creating BERT datasets ...")
return train_ds, valid_ds, test_ds
......
......@@ -16,39 +16,28 @@
"""Pretrain GPT"""
import torch
from functools import partial
from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import get_tokenizer
from megatron import mpu
from megatron.data.gpt_dataset import build_train_valid_test_datasets
from megatron.model import (GPTModel,
GPTModelFirstStage,
GPTModelIntermediateStage,
GPTModelLastStage)
from megatron.model import GPTModel
from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids
from megatron.utils import average_losses_across_data_parallel_group
def model_provider():
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
print_rank_0('building GPT model ...')
args = get_args()
if mpu.get_pipeline_model_parallel_world_size() > 1:
# Determine model based on position of stage in pipeline.
if mpu.is_pipeline_first_stage():
model = GPTModelFirstStage(num_tokentypes=0)
elif mpu.is_pipeline_last_stage():
model = GPTModelLastStage(
num_tokentypes=0, parallel_output=True)
else:
model = GPTModelIntermediateStage(
num_tokentypes=0)
else:
model = GPTModel(num_tokentypes=0, parallel_output=True)
model = GPTModel(
num_tokentypes=0,
parallel_output=True,
pre_process=pre_process,
post_process=post_process
)
return model
......@@ -83,8 +72,18 @@ def get_batch(data_iterator):
return tokens, labels, loss_mask, attention_mask, position_ids
def loss_func(loss_mask, output_tensor):
losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float()
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])
def forward_step(data_iterator, model, input_tensor):
return loss, {'lm loss': averaged_loss[0]}
def forward_step(data_iterator, model):
"""Forward step."""
args = get_args()
timers = get_timers()
......@@ -95,31 +94,10 @@ def forward_step(data_iterator, model, input_tensor):
data_iterator)
timers('batch-generator').stop()
# Forward pass through the model.
if mpu.is_pipeline_first_stage():
assert input_tensor is None
if mpu.is_pipeline_last_stage():
output_tensor = model(tokens, position_ids, attention_mask,
labels=labels)
else:
output_tensor = model(tokens, position_ids, attention_mask)
elif mpu.is_pipeline_last_stage():
assert input_tensor is not None
output_tensor = model(input_tensor, attention_mask, labels=labels)
else:
assert input_tensor is not None
output_tensor = model(input_tensor, attention_mask)
if mpu.is_pipeline_last_stage():
losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float()
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])
output_tensor = model(tokens, position_ids, attention_mask,
labels=labels)
return loss, {'lm loss': averaged_loss[0]}
return output_tensor
return output_tensor, partial(loss_func, loss_mask)
def train_valid_test_datasets_provider(train_val_test_num_samples):
......@@ -144,5 +122,4 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
if __name__ == "__main__":
pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
args_defaults={'tokenizer_type': 'GPT2BPETokenizer',
'scaled_upper_triang_masked_softmax_fusion': True})
args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
......@@ -14,6 +14,7 @@
# limitations under the License.
"""Pretrain BERT for Inverse Cloze Task"""
import math
import torch
import torch.distributed as dist
......@@ -23,17 +24,21 @@ from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import mpu
from megatron.data.biencoder_dataset_utils import get_ict_batch
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model.biencoder_model import biencoder_model_provider
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
from megatron.model.realm_model import general_ict_model_provider
from megatron.data.realm_dataset_utils import get_ict_batch
def pretrain_ict_model_provider():
args = get_args()
return general_ict_model_provider(False, False)
model = biencoder_model_provider(
only_context_model=False,
only_query_model=False,
biencoder_shared_query_context_model=\
args.biencoder_shared_query_context_model)
return model
def get_group_world_size_rank():
......@@ -72,7 +77,6 @@ class AllgatherFromDataParallelRegion(torch.autograd.Function):
output = output_list[rank].contiguous()
return output
def forward_step(data_iterator, model, input_tensor):
"""Forward step."""
args = get_args()
......@@ -80,37 +84,57 @@ def forward_step(data_iterator, model, input_tensor):
# Get the batch.
timers('batch-generator').start()
query_tokens, query_pad_mask, \
block_tokens, block_pad_mask, block_indices = get_ict_batch(data_iterator)
query_tokens, query_mask, \
context_tokens, context_mask, context_indices = get_ict_batch(data_iterator)
timers('batch-generator').stop()
# Query and Context Types
query_types = torch.cuda.LongTensor(*query_tokens.shape).fill_(0)
context_types = torch.cuda.LongTensor(*context_tokens.shape).fill_(0)
# Forward model.
query_logits, block_logits = model(query_tokens, query_pad_mask, block_tokens, block_pad_mask)
query_logits, context_logits = model(query_tokens, query_mask,
query_types, context_tokens,
context_mask, context_types)
micro_batch_size = query_logits.shape[0]
global_batch_size = dist.get_world_size() * micro_batch_size # recall we assert that tensor_model_parallel_size == 1
# recall we assert that tensor_model_parallel_size == 1
assert mpu.get_tensor_model_parallel_world_size() == 1, \
"Model parallel size > 1 not supported for ICT"
global_batch_size = dist.get_world_size() * micro_batch_size
all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
all_block_logits = AllgatherFromDataParallelRegion.apply(block_logits)
all_context_logits = AllgatherFromDataParallelRegion.apply(context_logits)
# scores are inner products between query and context embeddings
retrieval_scores = torch.matmul(all_query_logits,
torch.transpose(all_context_logits, 0, 1))
# scaling the retriever scores
if args.retriever_score_scaling:
retrieval_scores = retrieval_scores / math.sqrt(args.hidden_size)
# scores are inner products between query and block embeddings
retrieval_scores = all_query_logits.float().matmul(torch.transpose(all_block_logits, 0, 1).float())
softmaxed = F.softmax(retrieval_scores, dim=1)
sorted_vals, sorted_indices = torch.topk(softmaxed, k=softmaxed.shape[1], sorted=True)
softmax_scores = F.log_softmax(retrieval_scores, dim=1)
sorted_vals, sorted_indices = torch.topk(softmax_scores,
k=softmax_scores.shape[1], sorted=True)
def topk_accuracy(k):
return torch.cuda.FloatTensor([sum([int(i in sorted_indices[i, :k]) for i in range(global_batch_size)]) / global_batch_size])
return torch.cuda.FloatTensor([sum([int(i in sorted_indices[i, :k]) \
for i in range(global_batch_size)]) / global_batch_size])
topk_accs = [topk_accuracy(int(k)) for k in args.report_topk_accuracies]
retrieval_loss = torch.nn.CrossEntropyLoss()(retrieval_scores, torch.arange(global_batch_size).long().cuda())
retrieval_loss = retrieval_loss.float()
averaged_losses = average_losses_across_data_parallel_group([retrieval_loss, *topk_accs])
topk_accs = [topk_accuracy(int(k)) for k in args.retriever_report_topk_accuracies]
# create stats_dict with retrieval loss and all specified top-k accuracies
topk_acc_dict = {'top{}_acc'.format(k): v for k, v in zip(args.report_topk_accuracies, averaged_losses[1:])}
stats_dict = dict(retrieval_loss=averaged_losses[0], **topk_acc_dict)
labels = torch.arange(global_batch_size).long().cuda()
loss = F.nll_loss(softmax_scores, labels, reduction='mean')
reduced_losses = average_losses_across_data_parallel_group([loss, *topk_accs])
# Scale the retrieval loss
loss = loss * mpu.get_data_parallel_world_size()
return retrieval_loss, stats_dict
# create stats_dict with retrieval loss and all specified top-k accuracies
topk_acc_dict = {'top{}_acc'.format(k): v * 100 for k, v in \
zip(args.retriever_report_topk_accuracies, reduced_losses[1:])}
stats_dict = dict(loss=reduced_losses[0], **topk_acc_dict)
return loss, stats_dict
def train_valid_test_datasets_provider(train_val_test_num_samples):
......@@ -129,6 +153,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
short_seq_prob=args.short_seq_prob,
seed=args.seed,
skip_warmup=(not args.mmap_warmup),
binary_head=False,
dataset_type='ict')
print_rank_0("> finished creating BERT ICT datasets ...")
......@@ -136,5 +161,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
if __name__ == "__main__":
pretrain(train_valid_test_datasets_provider, pretrain_ict_model_provider, forward_step,
pretrain(train_valid_test_datasets_provider,
pretrain_ict_model_provider,
forward_step,
args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain VIT"""
import torch
import torch.nn.functional as F
from megatron import get_args, get_timers, mpu, print_rank_0
from megatron.data.vit_dataset import build_train_valid_datasets
from megatron.model.vit_model import VitModel
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
def model_provider():
"""Build the model."""
print_rank_0("building VIT model ...")
args = get_args()
model = VitModel(num_classes=args.num_classes)
return model
def get_batch(data_iterator):
"""Build the batch."""
data = next(data_iterator)
# only data parallelism; no need for broadcast
images = data[0].cuda()
labels = data[1].cuda()
return images, labels
def forward_step(data_iterator, model, input_tensor):
"""Forward step."""
timers = get_timers()
assert input_tensor is None
# Get the batch.
timers("batch-generator").start()
(
images,
labels,
) = get_batch(data_iterator)
timers("batch-generator").stop()
# Forward model. lm_labels
logits = model(images).contiguous().float()
loss = F.cross_entropy(logits, labels)
outputs = torch.argmax(logits, -1)
correct = (outputs == labels).float()
accuracy = torch.mean(correct)
averaged_loss = average_losses_across_data_parallel_group([loss, accuracy])
return loss, {"loss": averaged_loss[0], "accuracy": averaged_loss[1]}
def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build train, valid, and test datasets."""
args = get_args()
print_rank_0(
"> building train, validation, and test datasets " "for VIT ..."
)
train_ds, valid_ds = build_train_valid_datasets(data_path=args.data_path)
print_rank_0("> finished creating VIT datasets ...")
return train_ds, valid_ds, None
if __name__ == "__main__":
pretrain(
train_valid_test_datasets_provider,
model_provider,
forward_step,
args_defaults={'dataloader_type': 'cyclic'}
)
......@@ -17,13 +17,14 @@
import os
import time
from functools import partial
import torch
from megatron import get_args
from megatron import print_rank_last, is_last_rank
from megatron import mpu
from megatron.training import communicate
from megatron.schedules import get_forward_backward_func
from tasks.finetune_utils import build_data_loader
from tasks.finetune_utils import process_batch
......@@ -38,7 +39,7 @@ def accuracy_func_provider(single_dataset_provider):
for datapath in datapaths:
dataset = single_dataset_provider(datapath)
dataloader = build_data_loader(
dataset, args.micro_batch_size, num_workers=args.num_workers,
dataset, args.orig_micro_batch_size, num_workers=args.num_workers,
drop_last=(mpu.get_data_parallel_world_size() > 1))
dataloaders.append((dataset.dataset_name, dataloader))
......@@ -73,14 +74,66 @@ def accuracy_func_provider(single_dataset_provider):
return metrics_func
def calculate_correct_answers(name, model, dataloader,
epoch, output_predictions):
"""Calculate correct over total answers and return prediction if the
`output_predictions` is true."""
args = get_args()
forward_backward_func = get_forward_backward_func()
start_time = time.time()
model.eval()
saved_batch_size = args.micro_batch_size
for m in model:
m.eval()
saved_micro_batch_size = args.micro_batch_size
saved_global_batch_size = args.global_batch_size
ds = dataloader.dataset
if hasattr(ds, 'sample_multiplier'):
# If our dataset as a sample_multiplier attribute that means
# each "sample" from the dataset actually has multiple samples
# that will collapse into the batch dimension (for example in
# the RACE dataset that has several options), we need to
# account for that when setting the micro batch size.
sample_multiplier = ds.sample_multiplier
else:
sample_multiplier = 1
micro_batch_size_times_data_parallel = args.orig_micro_batch_size * args.data_parallel_size
num_micro_batches = args.orig_global_batch_size // micro_batch_size_times_data_parallel
def loss_func(output_predictions, labels, output_tensor):
logits = output_tensor
loss_dict = {}
# Add output predictions.
if output_predictions:
assert False
loss_dict['softmaxes'] = torch.nn.Softmax(dim=-1)(
logits.float()).data.cpu().numpy().tolist()
loss_dict['labels'] = labels.data.cpu().numpy().tolist()
loss_dict['ids'] = batch['uid'].cpu().numpy().tolist()
# Compute the correct answers.
predicted = torch.argmax(logits, dim=-1)
corrects = (predicted == labels)
# Add to the counters.
loss_dict['total'] = labels.size(0)
loss_dict['correct'] = corrects.sum().item()
return 0, loss_dict
# defined inside to capture output_predictions
def correct_answers_forward_step(batch, model):
try:
batch_ = next(batch)
except BaseException:
batch_ = batch
tokens, types, labels, attention_mask = process_batch(batch_)
# Forward model.
args = get_args()
output_tensor = model(tokens, attention_mask, tokentype_ids=types)
return output_tensor, partial(loss_func, output_predictions, labels)
with torch.no_grad():
# For all the batches in the dataset.
total = 0
......@@ -92,60 +145,30 @@ def calculate_correct_answers(name, model, dataloader,
labels = []
ids = []
for _, batch in enumerate(dataloader):
# Run the model forward.
tokens, types, labels_, attention_mask = process_batch(batch)
# For evaluation only mode we use drop_last = False to get all the
# samples, which means we might not have a full batch, so we
# adjust batch_size here to actual batch size of data
actual_batch_size = len(labels_)
actual_batch_size = len(batch['label'])
# ... applying sample_multiplier if necessary
ds = dataloader.dataset
if hasattr(ds, 'sample_multiplier'):
actual_batch_size *= ds.sample_multiplier
args.micro_batch_size = actual_batch_size
if not mpu.is_pipeline_first_stage():
input_tensor, _ = communicate(
tensor_send_next=None,
tensor_send_prev=None,
recv_forward=True,
recv_backward=False)
else:
input_tensor = None
args.micro_batch_size = actual_batch_size * sample_multiplier
args.global_batch_size = actual_batch_size * sample_multiplier * num_micro_batches
# Forward model.
if mpu.is_pipeline_first_stage():
assert input_tensor is None
output_tensor = model(tokens, attention_mask, tokentype_ids=types)
else:
assert input_tensor is not None
output_tensor = model(input_tensor, attention_mask)
if mpu.is_pipeline_last_stage():
logits = output_tensor
loss_dicts = forward_backward_func(correct_answers_forward_step, batch, model,
optimizer=None, timers=None, forward_only=True)
# Add output predictions.
for loss_dict in loss_dicts:
if output_predictions:
softmaxes.extend(torch.nn.Softmax(dim=-1)(
logits.float()).data.cpu().numpy().tolist())
labels.extend(labels_.data.cpu().numpy().tolist())
ids.extend(batch['uid'].cpu().numpy().tolist())
# Compute the correct answers.
predicted = torch.argmax(logits, dim=-1)
corrects = (predicted == labels_)
# Add to the counters.
total += labels_.size(0)
correct += corrects.sum().item()
else:
communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_forward=False,
recv_backward=False)
model.train()
args.micro_batch_size = saved_batch_size
softmaxes.extend(loss_dict['softmaxes'])
labels.extend(loss_dict['labels'])
ids.extend(loss_dict['ids'])
total += loss_dict['total']
correct += loss_dict['correct']
for m in model:
m.train()
args.micro_batch_size = saved_micro_batch_size
args.global_batch_size = saved_global_batch_size
# Reduce.
if mpu.is_pipeline_last_stage():
......
......@@ -15,6 +15,8 @@
"""Finetune utilities."""
from functools import partial
import torch
from megatron import get_args
......@@ -27,8 +29,9 @@ from megatron.training import evaluate_and_print_results
from megatron.training import setup_model_and_optimizer
from megatron.training import train_step
from megatron.training import training_log
from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import average_losses_across_data_parallel_group
from megatron.utils import calc_params_l2_norm
from megatron.utils import check_adlr_autoresume_termination
def process_batch(batch):
......@@ -45,7 +48,20 @@ def process_batch(batch):
return tokens, types, labels, attention_mask
def _cross_entropy_forward_step(batch, model, input_tensor):
def cross_entropy_loss_func(labels, output_tensor):
logits = output_tensor
# Cross-entropy loss.
loss_func = torch.nn.CrossEntropyLoss()
loss = loss_func(logits.contiguous().float(), labels)
# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {'lm loss': averaged_loss[0]}
def _cross_entropy_forward_step(batch, model):
"""Simple forward step with cross-entropy loss."""
timers = get_timers()
......@@ -59,25 +75,9 @@ def _cross_entropy_forward_step(batch, model, input_tensor):
timers('batch-generator').stop()
# Forward model.
if mpu.is_pipeline_first_stage():
assert input_tensor is None
output_tensor = model(tokens, attention_mask, tokentype_ids=types)
else:
assert input_tensor is not None
output_tensor = model(input_tensor, attention_mask)
if mpu.is_pipeline_last_stage():
logits = output_tensor
output_tensor = model(tokens, attention_mask, tokentype_ids=types)
# Cross-entropy loss.
loss_func = torch.nn.CrossEntropyLoss()
loss = loss_func(logits.contiguous().float(), labels)
# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {'lm loss': averaged_loss[0]}
return output_tensor
return output_tensor, partial(cross_entropy_loss_func, labels)
def build_data_loader(dataset, micro_batch_size, num_workers, drop_last):
......@@ -134,7 +134,14 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset):
# This is necessary so pipeline transfers know what size they are
# and the LR schedule, which is based on samples seen, gets set
# correctly.
args.orig_micro_batch_size = args.micro_batch_size
args.orig_global_batch_size = args.global_batch_size
if hasattr(train_dataset, 'sample_multiplier'):
# If our dataset as a sample_multiplier attribute that means
# each "sample" from the dataset actually has multiple samples
# that will collapse into the batch dimension (for example in
# the RACE dataset that has several options), we need to
# account for that when setting the micro batch size.
args.micro_batch_size *= train_dataset.sample_multiplier
args.global_batch_size *= train_dataset.sample_multiplier
......@@ -148,7 +155,8 @@ def _train(model, optimizer, lr_scheduler, forward_step,
timers = get_timers()
# Turn on training mode which enables dropout.
model.train()
for m in model:
m.train()
# Tracking loss.
losses_dict_sum = {}
......@@ -162,7 +170,7 @@ def _train(model, optimizer, lr_scheduler, forward_step,
report_memory_flag = True
# For each remaining epoch
timers('interval time').start()
timers('interval-time').start()
for epoch in range(start_epoch, args.epochs):
print_rank_0('working on epoch {} ...'.format(epoch + 1))
......@@ -179,16 +187,20 @@ def _train(model, optimizer, lr_scheduler, forward_step,
start_iteration = 0
# Train for one step.
losses_dict, skipped_iter = train_step(forward_step, batch, model,
optimizer, lr_scheduler)
out = train_step(forward_step, batch, model, optimizer, lr_scheduler)
losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = out
iteration += 1
# Logging.
params_norm = None
if args.log_params_norm:
params_norm = calc_params_l2_norm(model)
report_memory_flag = training_log(losses_dict, losses_dict_sum,
optimizer.param_groups[0]['lr'],
iteration,
optimizer.get_loss_scale().item(),
report_memory_flag, skipped_iter)
report_memory_flag, skipped_iter,
grad_norm, params_norm, num_zeros_in_grad)
# Autoresume
if args.adlr_autoresume and \
......@@ -224,6 +236,9 @@ def finetune(train_valid_datasets_provider, model_provider,
args = get_args()
timers = get_timers()
assert args.rampup_batch_size is None, \
'batch size scaling is not supported for finetuning'
# Train and validation data loaders.
timers('train/valid/test dataset/dataloder').start()
if args.epochs > 0:
......
......@@ -19,7 +19,7 @@ from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron import mpu
from megatron.model.classification import Classification, ClassificationFirstStage, ClassificationIntermediateStage, ClassificationLastStage
from megatron.model.classification import Classification
from tasks.eval_utils import accuracy_func_provider
from tasks.finetune_utils import finetune
......@@ -39,25 +39,14 @@ def glue_classification(num_classes, Dataset,
return train_dataset, valid_dataset
def model_provider():
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
args = get_args()
print_rank_0('building classification model for {} ...'.format(
args.task))
if mpu.get_pipeline_model_parallel_world_size() > 1:
# Determine model based on position of stage in pipeline.
if mpu.is_pipeline_first_stage():
model = ClassificationFirstStage(
num_classes=num_classes, num_tokentypes=2)
elif mpu.is_pipeline_last_stage():
model = ClassificationLastStage(
num_classes=num_classes, num_tokentypes=2)
else:
model = ClassificationIntermediateStage(
num_classes=num_classes, num_tokentypes=2)
else:
model = Classification(num_classes=num_classes, num_tokentypes=2)
model = Classification(num_classes=num_classes, num_tokentypes=2,
pre_process=pre_process, post_process=post_process)
return model
......
......@@ -47,6 +47,20 @@ def get_tasks_args(parser):
help='Sliding window for overlapping evaluation.')
group.add_argument('--strict-lambada', action='store_true',
help='Use more difficult formulation of lambada.')
# Retriever args
group.add_argument('--qa-data-dev', type=str, default=None,
help='Path to the QA dataset dev file.')
group.add_argument('--qa-data-test', type=str, default=None,
help='Path to the QA dataset test file.')
# Faiss arguments for retriever
group.add_argument('--faiss-use-gpu', action='store_true',
help='Whether create the FaissMIPSIndex on GPU')
group.add_argument('--faiss-match', type=str, default='string', \
choices=['regex', 'string'], help="Answer matching '\
'logic type")
group.add_argument('--faiss-topk-retrievals', type=int, default=100,
help='Number of blocks to use as top-k during retrieval')
return parser
......@@ -56,12 +70,19 @@ if __name__ == '__main__':
initialize_megatron(extra_args_provider=get_tasks_args)
args = get_args()
if args.num_layers_per_virtual_pipeline_stage is not None:
print("Interleaved pipeline schedule is not yet supported for downstream tasks.")
exit()
if args.task == 'RACE':
from race.finetune import main
elif args.task in ['MNLI', 'QQP']:
from glue.finetune import main
elif args.task in ['LAMBADA', 'WIKITEXT103']:
from zeroshot_gpt2.evaluate import main
from zeroshot_gpt.evaluate import main
elif args.task in ['ICT-ZEROSHOT-NQ']:
from orqa.evaluate_orqa import main
else:
raise NotImplementedError('Task {} is not implemented.'.format(
args.task))
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Main tasks functionality."""
import os
import sys
from megatron import get_args
from tasks.orqa.evaluate_utils import ORQAEvaluator
def main():
"""
Main program
"""
args = get_args()
# Set up the model and evaluator
evaluator = ORQAEvaluator()
# Run evaluation
if args.qa_data_dev is not None:
evaluator.evaluate(args.qa_data_dev, "DEV")
if args.qa_data_test is not None:
evaluator.evaluate(args.qa_data_test, "TEST")
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from megatron import get_args, print_rank_0
from megatron.checkpointing import load_biencoder_checkpoint
from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
from tasks.orqa.natural_questions.nq import get_nq_dataset
from tasks.orqa.natural_questions.nq import get_one_epoch_nq_dataloader
from tasks.orqa.natural_questions.nq import process_nq_batch
from tasks.orqa.natural_questions.qa_utils import calculate_matches
from megatron.data.realm_index import OpenRetreivalDataStore, FaissMIPSIndex
from megatron.model.biencoder_model import biencoder_model_provider
from megatron.training import get_model
class ORQAEvaluator(object):
def __init__(self):
args = get_args()
self.embedding_size = args.hidden_size
self.faiss_use_gpu = args.faiss_use_gpu
self.evidence_embedder_obj = None
self.evidence_dataset = None
self.mips_index = None
self.eval_dataset = None
# Get Evidence (Wikipedia) dataset
self.get_evidence_dataset()
# Load query encoder checkpoint
only_query_model = True
if args.biencoder_shared_query_context_model:
only_query_model = False
model = get_model(lambda: biencoder_model_provider(only_query_model=\
only_query_model, biencoder_shared_query_context_model=\
args.biencoder_shared_query_context_model))
self.model = load_biencoder_checkpoint(model,
only_query_model=only_query_model)
assert len(self.model) == 1
self.model[0].eval()
# Load faiss indexer
self.faiss_wrapper()
def get_evidence_embedding(self):
# This will load the embedding from the embedding path
self.evidence_embedder_obj = OpenRetreivalDataStore(load_from_path=True)
def get_evidence_dataset(self):
self.evidence_dataset = get_open_retrieval_wiki_dataset()
def faiss_wrapper(self):
# Initialize FAISS wrapper on local rank = 0 as the evidence embeddings
# is distributed over all the GPUs in a node and FAISS is not
# thread-safe
args = get_args()
if args.local_rank == 0:
# Get evidence embeddings computed using context encoder
self.get_evidence_embedding()
assert self.evidence_embedder_obj is not None
self.mips_index = FaissMIPSIndex(embed_size=self.embedding_size,
embed_data=self.evidence_embedder_obj,
use_gpu=self.faiss_use_gpu)
# Wait for the FAISS index to be initialized in all the nodes
torch.distributed.barrier()
def generate_query_vectors(self, qa_data, split):
self.eval_dataset = get_nq_dataset(qa_data, split)
dataloader = get_one_epoch_nq_dataloader(self.eval_dataset)
query_vectors = []
reference_list = []
for batch in dataloader:
# batch also has query_tokens and query_pad_data
query_tokens, query_mask, query_types, \
query_len, reference = process_nq_batch(batch)
assert len(self.model) == 1
unwrapped_model = self.model[0]
while not hasattr(unwrapped_model, 'embed_text'):
unwrapped_model = unwrapped_model.module
with torch.no_grad():
query_logits = unwrapped_model.embed_text(
unwrapped_model.query_model, query_tokens,
query_mask, query_types)
reference_list.extend(reference)
query_vectors.extend(query_logits.split(1, dim=0))
if len(query_vectors) % 100 == 0:
print_rank_0('Encoded queries {}'.format(len(query_vectors)))
query_tensor = torch.cat(query_vectors, dim=0)
print_rank_0('Total encoded queries tensor {}'.format(query_tensor.size()))
assert query_tensor.size(0) == len(self.eval_dataset)
return query_tensor, reference_list
def evaluate(self, qa_data, split):
args = get_args()
query_tensor, reference_list = self.generate_query_vectors(qa_data, \
split)
local_rank = args.local_rank
rank = torch.distributed.get_rank()
device_count = torch.cuda.device_count()
num_nodes = torch.distributed.get_world_size() // device_count
node_id = rank // device_count
for node in range(num_nodes):
start_rank = node * device_count
end_rank = (node + 1) * device_count
ranks_list = list(range(start_rank, end_rank))
node_group = torch.distributed.new_group(ranks=ranks_list)
if node_id == node:
device_start_rank = start_rank
group = node_group
input_ = torch.empty_like(query_tensor).copy_(query_tensor).detach_()
tensor_list = [torch.empty_like(input_) for _ in range(device_count)]
torch.distributed.all_gather(tensor_list, query_tensor, group=group)
if local_rank == 0 and self.mips_index is not None:
all_query_tensor = torch.cat(tensor_list, dim=0).contiguous()
distance, topkindex = self.mips_index.search_mips_index(
all_query_tensor, top_k=args.faiss_topk_retrievals,
reconstruct=False)
distance = torch.from_numpy(distance).cuda()
topkindex = torch.LongTensor(topkindex).cuda()
if local_rank != 0:
distance = torch.empty(device_count * len(query_tensor), \
args.faiss_topk_retrievals, dtype=torch.float32).cuda()
topkindex = torch.empty(device_count * len(query_tensor), \
args.faiss_topk_retrievals, dtype=torch.int64).cuda()
torch.distributed.broadcast(distance, src=device_start_rank, \
group=group)
torch.distributed.broadcast(topkindex, src=device_start_rank, \
group=group)
distance = torch.split(distance, len(query_tensor), dim=0)\
[local_rank]
topkindex = torch.split(topkindex, len(query_tensor), dim=0)\
[local_rank]
top_ids_and_scores = []
for darray, topkarray in zip(distance, topkindex):
top_ids_and_scores.append((topkarray.tolist(), darray.tolist()))
passages = self.evidence_dataset.id2text
match_stats = calculate_matches(passages,
reference_list,
top_ids_and_scores,
workers_num=args.num_workers,
match_type=args.faiss_match)
top_k_hits = match_stats.top_k_hits
print_rank_0("{} SET RESULTS".format(split))
print_rank_0("topk-{} documents hits {}".format(
args.faiss_topk_retrievals, top_k_hits))
top_k_hits = [v / len(top_ids_and_scores) for v in top_k_hits]
print_rank_0("top-k documents hits accuracy {}".format(top_k_hits))
for i in args.retriever_report_topk_accuracies:
print_rank_0("top-{}: {:.2f}".format(i, top_k_hits[i-1] * 100))
return
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Data Loader for Google NQ dataset
"""
from abc import ABC
import csv
from collections import OrderedDict
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset, BatchSampler
from megatron import print_rank_0, get_args, get_tokenizer, mpu
from megatron.data.biencoder_dataset_utils import make_attention_mask
def get_nq_dataset(qa_data, split):
args = get_args()
tokenizer = get_tokenizer()
dataset = NQDataset('Google NQ {} Split'.format(split),
'Google Natural Questions',
qa_data,
tokenizer,
args.retriever_seq_length)
return dataset
def process_nq_batch(batch):
query_tokens = batch['token_ids'].long().cuda()
query_mask = (batch['token_mask'] < 0.5).cuda()
query_types = batch['token_types'].long().cuda()
query_len = batch['seq_len'].long().cuda()
reference = batch['reference']
return query_tokens, query_mask, query_types, query_len, reference
class CustomDataLoader(DataLoader):
def __init__(self, dataset, eval=False, **kwargs):
if kwargs.get('collate_fn', None) is None:
kwargs['collate_fn'] = self._collate_fn
self.eval = eval
super().__init__(dataset, **kwargs)
def _collate_fn(self, batch_data):
# generate batch
batch_size = len(batch_data)
tensorized = OrderedDict()
for d in batch_data:
for k, v in d.items():
tensorized.setdefault(k, []).append(v)
assert len(tensorized) == 5
tensorized['token_ids'] = torch.LongTensor(tensorized['token_ids'])
tensorized['token_mask'] = torch.LongTensor(tensorized['token_mask'])
tensorized['token_types'] = torch.LongTensor(tensorized['token_types'])
tensorized['seq_len'] = torch.LongTensor(tensorized['seq_len'])
return tensorized
def get_one_epoch_nq_dataloader(dataset, micro_batch_size=None):
"""Data loader. Note that batch-size is the local (per GPU) batch-size.
NOTE: This dataloader is not distributed !!!
"""
args = get_args()
if micro_batch_size is None:
micro_batch_size = args.micro_batch_size
num_workers = args.num_workers
sampler = torch.utils.data.SequentialSampler(dataset)
# importantly, drop_last must be False to get all the data.
batch_sampler = BatchSampler(sampler,
batch_size=micro_batch_size,
drop_last=False)
# Data loader. Note that batch size is the per GPU batch size.
data_loader = CustomDataLoader(dataset,
batch_sampler=batch_sampler,
num_workers=num_workers,
pin_memory=True)
return data_loader
def build_tokens_types_paddings_from_text(src_text, tokenizer, max_seq_length):
"""Build token types and paddings, trim if needed, and pad if needed."""
src_text_ids = tokenizer.tokenize(src_text)
return build_tokens_types_paddings_from_ids(src_text_ids,
max_seq_length,
tokenizer.cls,
tokenizer.sep,
tokenizer.pad)
def build_tokens_types_paddings_from_ids(src_ids, max_seq_length, cls_id, \
sep_id, pad_id):
"""
Build token types and paddings, trim if needed, and pad if needed.
TODO: Design modular interface to reuse this function. This is getting
repeated multiple times in different tasks
"""
enc_ids = []
tokentypes_enc = []
# [CLS].
enc_ids.append(cls_id)
tokentypes_enc.append(0)
# A.
len_src = len(src_ids)
enc_ids.extend(src_ids)
tokentypes_enc.extend([0] * len_src)
# Cap the size.
if len(enc_ids) > max_seq_length - 1:
enc_ids = enc_ids[0: max_seq_length - 1]
tokentypes_enc = tokentypes_enc[0: max_seq_length - 1]
# [SEP].
enc_ids.append(sep_id)
tokentypes_enc.append(0)
num_tokens_enc = len(enc_ids)
# Padding.
padding_length = max_seq_length - len(enc_ids)
if padding_length > 0:
enc_ids.extend([pad_id] * padding_length)
tokentypes_enc.extend([pad_id] * padding_length)
return enc_ids, tokentypes_enc, num_tokens_enc
def build_sample(token_ids, token_types, num_tokens, reference):
"""
Convert to numpy and return a sample consumed by the
batch producer.
"""
token_ids = np.array(token_ids, dtype=np.int64)
token_types = np.array(token_types, dtype=np.int64)
token_mask = make_attention_mask(token_ids, token_ids)
sample = ({
'token_ids': token_ids,
'token_mask': token_mask,
'token_types': token_types,
'seq_len': num_tokens,
'reference': reference
})
return sample
class NQDataset(ABC, Dataset):
"""
Open Retrieval Question Answering evaluation using Google NQ dataset.
"""
def __init__(self, task_name, dataset_name, datapath,
tokenizer, max_seq_length):
# Store inputs.
self.task_name = task_name
self.dataset_name = dataset_name
self.tokenizer = tokenizer
self.max_seq_length = max_seq_length
print_rank_0(' > building {} dataset for {}:'.format(self.task_name,
self.dataset_name))
print_rank_0(datapath)
self.samples = self.process_samples_from_single_path(datapath)
print_rank_0(' >> total number of samples: {}'.format(\
len(self.samples)))
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
raw_sample = self.samples[idx]
ques_tokens, tokentypes_enc, num_tokens_ques = \
build_tokens_types_paddings_from_text(raw_sample['question'],
self.tokenizer, self.max_seq_length)
sample = build_sample(ques_tokens,
tokentypes_enc,
num_tokens_ques,
raw_sample['answers'])
return sample
@staticmethod
def process_samples_from_single_path(filename):
print_rank_0(' > Processing {} ...'.format(filename))
samples = []
total = 0
with open(filename, 'r') as ifile:
reader = csv.reader(ifile, delimiter='\t')
for row in reader:
question = row[0]
answers = eval(row[1])
sample = {'question': question, 'answers': answers}
total += 1
samples.append(sample)
if total % 1000 == 0:
print_rank_0(' > processed {} so far ...'.format(total))
print_rank_0(' >> processed {} samples.'.format(len(samples)))
return samples
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# The following code has been taken from
# https://github.com/facebookresearch/DPR, which is CC-BY-NC 4.0
# licensed as of now. More details on the license can be found
# at https://github.com/facebookresearch/DPR/blob/master/LICENSE
"""
Set of utilities for Q&A results validation tasks - Retriver passage
validation and Reader predicted answer validation
"""
import collections
import logging
import string
import unicodedata
from functools import partial
from multiprocessing import Pool as ProcessPool
from typing import Tuple, List, Dict
import regex as re
from tasks.orqa.natural_questions.tokenizers import SimpleTokenizer
logger = logging.getLogger(__name__)
QAMatchStats = collections.namedtuple('QAMatchStats', ['top_k_hits',\
'questions_doc_hits'])
def calculate_matches(all_docs: Dict[object, Tuple[str, str]],
answers: List[List[str]], closest_docs: List[Tuple[List[object],
List[float]]], workers_num: int, match_type: str) -> QAMatchStats:
"""
Evaluates answers presence in the set of documents. This function is
supposed to be used with a large collection of documents and results.
It internally forks multiple sub-processes for evaluation and then
merges results
:param all_docs: dictionary of the entire documents database.
doc_id -> (doc_text, title)
:param answers: list of answers's list. One list per question
:param closest_docs: document ids of the top results along with their
scores
:param workers_num: amount of parallel threads to process data
:param match_type: type of answer matching. Refer to has_answer code for
available options
:return: matching information tuple.
top_k_hits - a list where the index is the amount of top documents retrieved
and the value is the total amount of valid matches across an entire
dataset.
questions_doc_hits - more detailed info with answer matches for every
question and every retrieved document
"""
global dpr_all_documents
dpr_all_documents = all_docs
tok_opts = {}
tokenizer = SimpleTokenizer(**tok_opts)
processes = ProcessPool(
processes=workers_num,
)
logger.info('Matching answers in top docs...')
get_score_partial = partial(check_answer, match_type=match_type,
tokenizer=tokenizer)
questions_answers_docs = zip(answers, closest_docs)
scores = processes.map(get_score_partial, questions_answers_docs)
logger.info('Per question validation results len=%d', len(scores))
n_docs = len(closest_docs[0][0])
top_k_hits = [0] * n_docs
for question_hits in scores:
best_hit = next((i for i, x in enumerate(question_hits) if x), None)
if best_hit is not None:
top_k_hits[best_hit:] = [v + 1 for v in top_k_hits[best_hit:]]
return QAMatchStats(top_k_hits, scores)
def check_answer(questions_answers_docs, tokenizer, match_type) -> List[bool]:
"""
Search through all the top docs to see if they have any of the answers.
"""
answers, (doc_ids, doc_scores) = questions_answers_docs
global dpr_all_documents
hits = []
for i, doc_id in enumerate(doc_ids):
doc = dpr_all_documents[doc_id]
text = doc[0]
answer_found = False
if text is None: # cannot find the document for some reason
logger.warning("no doc in db")
hits.append(False)
continue
if has_answer(answers, text, tokenizer, match_type):
answer_found = True
hits.append(answer_found)
return hits
def has_answer(answers, text, tokenizer, match_type) -> bool:
"""
Check if a document contains an answer string.
If `match_type` is string, token matching is done between the text
and answer.
If `match_type` is regex, we search the whole text with the regex.
"""
text = _normalize(text)
if match_type == 'string':
# Answer is a list of possible strings
text = tokenizer.tokenize(text).words(uncased=True)
for single_answer in answers:
single_answer = _normalize(single_answer)
single_answer = tokenizer.tokenize(single_answer)
single_answer = single_answer.words(uncased=True)
for i in range(0, len(text) - len(single_answer) + 1):
if single_answer == text[i: i + len(single_answer)]:
return True
elif match_type == 'regex':
# Answer is a regex
for single_answer in answers:
single_answer = _normalize(single_answer)
if regex_match(text, single_answer):
return True
return False
def regex_match(text, pattern):
"""Test if a regex pattern is contained within a text."""
try:
pattern = re.compile(
pattern,
flags=re.IGNORECASE + re.UNICODE + re.MULTILINE,
)
except BaseException:
return False
return pattern.search(text) is not None
# function for the reader model answer validation
def exact_match_score(prediction, ground_truth):
return _normalize_answer(prediction) == _normalize_answer(ground_truth)
def _normalize_answer(s):
def remove_articles(text):
return re.sub(r'\b(a|an|the)\b', ' ', text)
def white_space_fix(text):
return ' '.join(text.split())
def remove_punc(text):
exclude = set(string.punctuation)
return ''.join(ch for ch in text if ch not in exclude)
def lower(text):
return text.lower()
return white_space_fix(remove_articles(remove_punc(lower(s))))
def _normalize(text):
return unicodedata.normalize('NFD', text)
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# The following code has been taken from
# https://github.com/facebookresearch/DPR, which is CC-BY-NC 4.0
# licensed as of now. More details on the license can be found
# at https://github.com/facebookresearch/DPR/blob/master/LICENSE
"""
Most of the tokenizers code here is copied from DrQA codebase to avoid adding extra dependency
"""
import copy
import logging
import regex
import spacy
logger = logging.getLogger(__name__)
class Tokens(object):
"""A class to represent a list of tokenized text."""
TEXT = 0
TEXT_WS = 1
SPAN = 2
POS = 3
LEMMA = 4
NER = 5
def __init__(self, data, annotators, opts=None):
self.data = data
self.annotators = annotators
self.opts = opts or {}
def __len__(self):
"""The number of tokens."""
return len(self.data)
def slice(self, i=None, j=None):
"""Return a view of the list of tokens from [i, j)."""
new_tokens = copy.copy(self)
new_tokens.data = self.data[i: j]
return new_tokens
def untokenize(self):
"""Returns the original text (with whitespace reinserted)."""
return ''.join([t[self.TEXT_WS] for t in self.data]).strip()
def words(self, uncased=False):
"""Returns a list of the text of each token
Args:
uncased: lower cases text
"""
if uncased:
return [t[self.TEXT].lower() for t in self.data]
else:
return [t[self.TEXT] for t in self.data]
def offsets(self):
"""Returns a list of [start, end) character offsets of each token."""
return [t[self.SPAN] for t in self.data]
def pos(self):
"""Returns a list of part-of-speech tags of each token.
Returns None if this annotation was not included.
"""
if 'pos' not in self.annotators:
return None
return [t[self.POS] for t in self.data]
def lemmas(self):
"""Returns a list of the lemmatized text of each token.
Returns None if this annotation was not included.
"""
if 'lemma' not in self.annotators:
return None
return [t[self.LEMMA] for t in self.data]
def entities(self):
"""Returns a list of named-entity-recognition tags of each token.
Returns None if this annotation was not included.
"""
if 'ner' not in self.annotators:
return None
return [t[self.NER] for t in self.data]
def ngrams(self, n=1, uncased=False, filter_fn=None, as_strings=True):
"""Returns a list of all ngrams from length 1 to n.
Args:
n: upper limit of ngram length
uncased: lower cases text
filter_fn: user function that takes in an ngram list and returns
True or False to keep or not keep the ngram
as_string: return the ngram as a string vs list
"""
def _skip(gram):
if not filter_fn:
return False
return filter_fn(gram)
words = self.words(uncased)
ngrams = [(s, e + 1)
for s in range(len(words))
for e in range(s, min(s + n, len(words)))
if not _skip(words[s:e + 1])]
# Concatenate into strings
if as_strings:
ngrams = ['{}'.format(' '.join(words[s:e])) for (s, e) in ngrams]
return ngrams
def entity_groups(self):
"""Group consecutive entity tokens with the same NER tag."""
entities = self.entities()
if not entities:
return None
non_ent = self.opts.get('non_ent', 'O')
groups = []
idx = 0
while idx < len(entities):
ner_tag = entities[idx]
# Check for entity tag
if ner_tag != non_ent:
# Chomp the sequence
start = idx
while (idx < len(entities) and entities[idx] == ner_tag):
idx += 1
groups.append((self.slice(start, idx).untokenize(), ner_tag))
else:
idx += 1
return groups
class Tokenizer(object):
"""Base tokenizer class.
Tokenizers implement tokenize, which should return a Tokens class.
"""
def tokenize(self, text):
raise NotImplementedError
def shutdown(self):
pass
def __del__(self):
self.shutdown()
class SimpleTokenizer(Tokenizer):
ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+'
NON_WS = r'[^\p{Z}\p{C}]'
def __init__(self, **kwargs):
"""
Args:
annotators: None or empty set (only tokenizes).
"""
self._regexp = regex.compile(
'(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS),
flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE
)
if len(kwargs.get('annotators', {})) > 0:
logger.warning('%s only tokenizes! Skipping annotators: %s' %
(type(self).__name__, kwargs.get('annotators')))
self.annotators = set()
def tokenize(self, text):
data = []
matches = [m for m in self._regexp.finditer(text)]
for i in range(len(matches)):
# Get text
token = matches[i].group()
# Get whitespace
span = matches[i].span()
start_ws = span[0]
if i + 1 < len(matches):
end_ws = matches[i + 1].span()[0]
else:
end_ws = span[1]
# Format data
data.append((
token,
text[start_ws: end_ws],
span,
))
return Tokens(data, self.annotators)
class SpacyTokenizer(Tokenizer):
def __init__(self, **kwargs):
"""
Args:
annotators: set that can include pos, lemma, and ner.
model: spaCy model to use (either path, or keyword like 'en').
"""
model = kwargs.get('model', 'en')
self.annotators = copy.deepcopy(kwargs.get('annotators', set()))
nlp_kwargs = {'parser': False}
if not any([p in self.annotators for p in ['lemma', 'pos', 'ner']]):
nlp_kwargs['tagger'] = False
if 'ner' not in self.annotators:
nlp_kwargs['entity'] = False
self.nlp = spacy.load(model, **nlp_kwargs)
def tokenize(self, text):
# We don't treat new lines as tokens.
clean_text = text.replace('\n', ' ')
tokens = self.nlp.tokenizer(clean_text)
if any([p in self.annotators for p in ['lemma', 'pos', 'ner']]):
self.nlp.tagger(tokens)
if 'ner' in self.annotators:
self.nlp.entity(tokens)
data = []
for i in range(len(tokens)):
# Get whitespace
start_ws = tokens[i].idx
if i + 1 < len(tokens):
end_ws = tokens[i + 1].idx
else:
end_ws = tokens[i].idx + len(tokens[i].text)
data.append((
tokens[i].text,
text[start_ws: end_ws],
(tokens[i].idx, tokens[i].idx + len(tokens[i].text)),
tokens[i].tag_,
tokens[i].lemma_,
tokens[i].ent_type_,
))
# Set special option for non-entity tag: '' vs 'O' in spaCy
return Tokens(data, self.annotators, opts={'non_ent': ''})
......@@ -39,6 +39,8 @@ class RaceDataset(Dataset):
print_rank_0(' >> total number of samples: {}'.format(
len(self.samples)))
# This indicates that each "sample" has multiple samples that
# will collapse into batch dimension
self.sample_multiplier = NUM_CHOICES
def __len__(self):
......
......@@ -19,7 +19,7 @@ from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron import mpu
from megatron.model.multiple_choice import MultipleChoice, MultipleChoiceFirstStage, MultipleChoiceIntermediateStage, MultipleChoiceLastStage
from megatron.model.multiple_choice import MultipleChoice
from tasks.eval_utils import accuracy_func_provider
from tasks.finetune_utils import finetune
from tasks.race.data import RaceDataset
......@@ -38,20 +38,13 @@ def train_valid_datasets_provider():
return train_dataset, valid_dataset
def model_provider():
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
print_rank_0('building multichoice model for RACE ...')
if mpu.get_pipeline_model_parallel_world_size() > 1:
# Determine model based on position of stage in pipeline.
if mpu.is_pipeline_first_stage():
model = MultipleChoiceFirstStage(num_tokentypes=2)
elif mpu.is_pipeline_last_stage():
model = MultipleChoiceLastStage(num_tokentypes=2)
else:
model = MultipleChoiceIntermediateStage(num_tokentypes=2)
else:
model = MultipleChoice(num_tokentypes=2)
model = MultipleChoice(num_tokentypes=2,
pre_process=pre_process,
post_process=post_process)
return model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment