Commit 5fc301aa authored by Mohammad Shoeybi

Merge branch 'master' into mo_albert_data_loader

parents 8179ebd3 1b290993
@@ -28,7 +28,7 @@ from arguments import get_args
 from megatron.utils import Timers
 from megatron.utils import initialize_distributed
 from megatron.utils import set_random_seed
-from pretrain_gpt2 import get_masks_and_position_ids
+from megatron.utils import get_ltor_masks_and_position_ids
 from megatron.utils import load_checkpoint
 from megatron.data_utils import make_tokenizer
 from configure_data import configure_data
@@ -91,7 +91,7 @@ def get_batch(context_tokens, args):
     tokens = tokens.to(device)

     # Get the masks and position ids.
-    attention_mask, loss_mask, position_ids = get_masks_and_position_ids(
+    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
         tokens,
         args.eod_token,
         args.reset_position_ids,
...
@@ -31,6 +31,63 @@ from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import get_params_for_weight_decay_optimization


+def get_ltor_masks_and_position_ids(data,
+                                    eod_token,
+                                    reset_position_ids,
+                                    reset_attention_mask,
+                                    eod_mask_loss):
+    """Build masks and position id for left to right model."""
+
+    # Extract batch size and sequence length.
+    batch_size, seq_length = data.size()
+
+    # Attention mask (lower triangular).
+    if reset_attention_mask:
+        att_mask_batch = batch_size
+    else:
+        att_mask_batch = 1
+    attention_mask = torch.tril(torch.ones(
+        (att_mask_batch, seq_length, seq_length), device=data.device)).view(
+            att_mask_batch, 1, seq_length, seq_length)
+
+    # Loss mask.
+    loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
+    if eod_mask_loss:
+        loss_mask[data == eod_token] = 0.0
+
+    # Position ids.
+    position_ids = torch.arange(seq_length, dtype=torch.long,
+                                device=data.device)
+    position_ids = position_ids.unsqueeze(0).expand_as(data)
+    # We need to clone as the ids will be modified based on batch index.
+    if reset_position_ids:
+        position_ids = position_ids.clone()
+
+    if reset_position_ids or reset_attention_mask:
+        # Loop through the batches:
+        for b in range(batch_size):
+
+            # Find indices where EOD token is.
+            eod_index = position_ids[b, data[b] == eod_token]
+            # Detach indices from positions if going to modify positions.
+            if reset_position_ids:
+                eod_index = eod_index.clone()
+
+            # Loop through EOD indices:
+            prev_index = 0
+            for j in range(eod_index.size()[0]):
+                i = eod_index[j]
+                # Mask attention loss.
+                if reset_attention_mask:
+                    attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
+                # Reset positions.
+                if reset_position_ids:
+                    position_ids[b, (i + 1):] -= (i + 1 - prev_index)
+                    prev_index = i + 1
+
+    return attention_mask, loss_mask, position_ids
+
+
 def reduce_losses(losses):
     reduced_losses = torch.cat(
         [loss.clone().detach().view(1) for loss in losses])
...
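The new get_ltor_masks_and_position_ids utility is the only substantive code added by this merge, so a quick illustration of its behaviour may help reviewers. Below is a minimal sketch, assuming the megatron.utils module from this commit is importable; the 8-token batch and the EOD id 0 are made up for the example:

import torch
from megatron.utils import get_ltor_masks_and_position_ids

# Hypothetical toy batch: one sequence of 8 tokens, with the EOD token (id 0)
# at positions 3 and 6, i.e. three packed documents.
eod = 0
data = torch.tensor([[5, 6, 7, eod, 8, 9, eod, 4]])

attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
    data,
    eod,
    reset_position_ids=True,
    reset_attention_mask=True,
    eod_mask_loss=True)

# attention_mask has shape (1, 1, 8, 8): a lower-triangular causal mask in
# which rows after each EOD are additionally blocked from attending to the
# preceding document (columns 0..3 are zeroed for rows 4..7, and so on).
# loss_mask is 1.0 everywhere except at the EOD positions 3 and 6.
# position_ids restart after each EOD.
print(position_ids)   # tensor([[0, 1, 2, 3, 0, 1, 2, 0]])
print(loss_mask)      # tensor([[1., 1., 1., 0., 1., 1., 0., 1.]])

With all three flags set to False the function simply returns a shared causal mask, an all-ones loss mask, and 0..seq_length-1 position ids for every sequence in the batch.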
@@ -21,6 +21,7 @@ from configure_data import configure_data
 from gpt2_data_loader import make_gpt2_dataloaders
 from megatron import mpu
 from megatron.model import GPT2Model
+from megatron.utils import get_ltor_masks_and_position_ids
 from megatron.utils import print_rank_0
 from megatron.utils import reduce_losses
 from megatron.utils import vocab_size_with_padding
@@ -47,63 +48,6 @@ def model_provider(args):
     return model


-def get_masks_and_position_ids(data,
-                               eod_token,
-                               reset_position_ids,
-                               reset_attention_mask,
-                               eod_mask_loss):
-    """Build masks and position id."""
-
-    # Extract batch size and sequence length.
-    batch_size, seq_length = data.size()
-
-    # Attention mask (lower triangular).
-    if reset_attention_mask:
-        att_mask_batch = batch_size
-    else:
-        att_mask_batch = 1
-    attention_mask = torch.tril(torch.ones(
-        (att_mask_batch, seq_length, seq_length), device=data.device)).view(
-            att_mask_batch, 1, seq_length, seq_length)
-
-    # Loss mask.
-    loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
-    if eod_mask_loss:
-        loss_mask[data == eod_token] = 0.0
-
-    # Position ids.
-    position_ids = torch.arange(seq_length, dtype=torch.long,
-                                device=data.device)
-    position_ids = position_ids.unsqueeze(0).expand_as(data)
-    # We need to clone as the ids will be modified based on batch index.
-    if reset_position_ids:
-        position_ids = position_ids.clone()
-
-    if reset_position_ids or reset_attention_mask:
-        # Loop through the batches:
-        for b in range(batch_size):
-
-            # Find indices where EOD token is.
-            eod_index = position_ids[b, data[b] == eod_token]
-            # Detach indices from positions if going to modify positions.
-            if reset_position_ids:
-                eod_index = eod_index.clone()
-
-            # Loop through EOD indices:
-            prev_index = 0
-            for j in range(eod_index.size()[0]):
-                i = eod_index[j]
-                # Mask attention loss.
-                if reset_attention_mask:
-                    attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
-                # Reset positions.
-                if reset_position_ids:
-                    position_ids[b, (i + 1):] -= (i + 1 - prev_index)
-                    prev_index = i + 1
-
-    return attention_mask, loss_mask, position_ids
-
-
 def get_batch(data_iterator, args, timers):
     """Generate a batch"""
@@ -126,7 +70,7 @@ def get_batch(data_iterator, args, timers):
     tokens = tokens_[:, :-1].contiguous()

     # Get the masks and position ids.
-    attention_mask, loss_mask, position_ids = get_masks_and_position_ids(
+    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
         tokens,
         args.eod_token,
         args.reset_position_ids,
...
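The truncated call site above follows the usual left-to-right language-model batching pattern: the packed token tensor is shifted by one so that inputs and labels line up, and the shared utility then builds the masks from the inputs. A minimal sketch of that pattern is shown below; build_lm_batch is a hypothetical helper (not part of this commit), and the remaining keyword arguments are assumed to mirror the utility's signature:

from megatron.utils import get_ltor_masks_and_position_ids

def build_lm_batch(tokens_, eod_token,
                   reset_position_ids=False,
                   reset_attention_mask=False,
                   eod_mask_loss=False):
    # Hypothetical helper: inputs are every token but the last, labels are
    # every token but the first (the usual shift-by-one for a causal LM loss).
    tokens = tokens_[:, :-1].contiguous()
    labels = tokens_[:, 1:].contiguous()
    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
        tokens, eod_token, reset_position_ids,
        reset_attention_mask, eod_mask_loss)
    return tokens, labels, attention_mask, loss_mask, position_ids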