Commit 6495927e authored by Neel Kant's avatar Neel Kant

Address most of the comments from code review w/ Mohammad

parent fcc500d6
......@@ -136,6 +136,8 @@ def _add_network_size_args(parser):
' grouped: [1, 2, 1, 2] and spaced: [1, 1, 2, 2].')
group.add_argument('--hidden-size', type=int, default=None,
help='Transformer hidden size.')
group.add_argument('--ict-head-size', type=int, default=None,
help='Size of block embeddings to be used in ICT and REALM (paper default: 128)')
group.add_argument('--num-attention-heads', type=int, default=None,
help='Number of transformer attention heads.')
group.add_argument('--max-position-embeddings', type=int, default=None,
......@@ -202,8 +204,6 @@ def _add_training_args(parser):
def _add_initialization_args(parser):
group = parser.add_argument_group(title='initialization')
group.add_argument('--debug', action='store_true',
help='Run things in debug mode')
group.add_argument('--seed', type=int, default=1234,
help='Random seed used for python, numpy, '
'pytorch, and cuda.')
......
......@@ -128,13 +128,10 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
torch.distributed.barrier()
def load_checkpoint(model, optimizer, lr_scheduler):
def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
"""Load a model checkpoint and return the iteration."""
args = get_args()
load_dir = args.load
from megatron.model.bert_model import BertModel
if isinstance(model, BertModel) and args.bert_load is not None:
load_dir = args.bert_load
load_dir = getattr(args, load_arg)
if isinstance(model, torchDDP):
model = model.module
......
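The generalized `load_arg` parameter replaces the hard-coded `BertModel`/`args.bert_load` special case above. A minimal sketch of how a caller might select an alternate checkpoint directory, assuming the corresponding attribute (e.g. `args.bert_load`, as in the removed code) is defined by the argument parser:

from megatron.checkpointing import load_checkpoint

# Default: read from the directory given by --load.
iteration = load_checkpoint(model, optimizer, lr_scheduler)
# Hedged example: read from the directory given by --bert-load instead.
iteration = load_checkpoint(model, optimizer, lr_scheduler, load_arg='bert_load')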
......@@ -25,6 +25,11 @@ from torch.utils.data import Dataset
from megatron import get_tokenizer, get_args
from megatron import mpu
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
from megatron.data.dataset_utils import get_a_and_b_segments
from megatron.data.dataset_utils import truncate_segments
from megatron.data.dataset_utils import create_tokens_and_tokentypes
from megatron.data.dataset_utils import pad_and_convert_to_numpy
from megatron.data.dataset_utils import create_masked_lm_predictions
from megatron import print_rank_0
......@@ -61,8 +66,6 @@ class BertDataset(Dataset):
self.sep_id = tokenizer.sep
self.mask_id = tokenizer.mask
self.pad_id = tokenizer.pad
from megatron.data.dataset_utils import build_training_sample
self.build_sample_fn = build_training_sample
def __len__(self):
return self.samples_mapping.shape[0]
......@@ -73,7 +76,7 @@ class BertDataset(Dataset):
# Note that this rng state should be numpy and not python since
# python randint is inclusive whereas the numpy one is exclusive.
np_rng = np.random.RandomState(seed=(self.seed + idx))
return self.build_sample_fn(sample, seq_length,
return build_training_sample(sample, seq_length,
self.max_seq_length, # needed for padding
self.vocab_id_list,
self.vocab_id_to_token_dict,
......@@ -214,3 +217,66 @@ def get_samples_mapping_(indexed_dataset,
samples_mapping.shape[0]))
return samples_mapping
def build_training_sample(sample,
target_seq_length, max_seq_length,
vocab_id_list, vocab_id_to_token_dict,
cls_id, sep_id, mask_id, pad_id,
masked_lm_prob, np_rng):
"""Biuld training sample.
Arguments:
sample: A list of sentences in which each sentence is a list token ids.
target_seq_length: Desired sequence length.
max_seq_length: Maximum length of the sequence. All values are padded to
this length.
vocab_id_list: List of vocabulary ids. Used to pick a random id.
vocab_id_to_token_dict: A dictionary from vocab ids to text tokens.
cls_id: Start of example id.
sep_id: Separator id.
mask_id: Mask token id.
pad_id: Padding token id.
masked_lm_prob: Probability to mask tokens.
np_rng: Random number generator. Note that this rng state should be
numpy and not python since python randint is inclusive for
the upper bound whereas the numpy one is exclusive.
"""
# We assume that we have at least two sentences in the sample
assert len(sample) > 1
assert target_seq_length <= max_seq_length
# Divide sample into two segments (A and B).
tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample, np_rng)
# Truncate to `target_seq_length`.
max_num_tokens = target_seq_length
truncated = truncate_segments(tokens_a, tokens_b, len(tokens_a),
len(tokens_b), max_num_tokens, np_rng)
# Build tokens and tokentypes.
tokens, tokentypes = create_tokens_and_tokentypes(tokens_a, tokens_b,
cls_id, sep_id)
# Masking.
max_predictions_per_seq = masked_lm_prob * max_num_tokens
(tokens, masked_positions, masked_labels, _) = create_masked_lm_predictions(
tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng)
# Padding.
tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np \
= pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
masked_labels, pad_id, max_seq_length)
train_sample = {
'text': tokens_np,
'types': tokentypes_np,
'labels': labels_np,
'is_random': int(is_next_random),
'loss_mask': loss_mask_np,
'padding_mask': padding_mask_np,
'truncated': int(truncated)}
return train_sample
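For orientation, a hedged sketch of how `build_training_sample` might be invoked. The `np_rng` construction mirrors `BertDataset.__getitem__` above; it assumes the tokenizer exposes `inv_vocab` (an id-to-token dict) and the `cls`/`sep`/`mask`/`pad` properties used earlier in this file, and the concrete `sample` token ids are illustrative only:

import numpy as np
from megatron import get_tokenizer

tokenizer = get_tokenizer()
sample = [[2023, 2003, 1037, 7099], [2178, 6251, 2182]]  # two "sentences" of token ids (toy values)
np_rng = np.random.RandomState(seed=1234)                # numpy rng: randint is exclusive on the upper bound

train_sample = build_training_sample(
    sample,
    target_seq_length=16, max_seq_length=32,
    vocab_id_list=list(tokenizer.inv_vocab.keys()),
    vocab_id_to_token_dict=tokenizer.inv_vocab,
    cls_id=tokenizer.cls, sep_id=tokenizer.sep,
    mask_id=tokenizer.mask, pad_id=tokenizer.pad,
    masked_lm_prob=0.15, np_rng=np_rng)
# Returned keys: 'text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask', 'truncated'.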
......@@ -23,9 +23,11 @@ import itertools
import numpy as np
from megatron import print_rank_0, get_args
from megatron.data.bert_dataset import get_indexed_dataset_, get_train_valid_test_split_
DATASET_TYPES = ['standard_bert', 'ict']
DSET_TYPE_STD = 'standard_bert'
DSET_TYPE_ICT = 'ict'
DSET_TYPES = [DSET_TYPE_ICT, DSET_TYPE_STD]
def compile_helper():
"""Compile helper function ar runtime. Make sure this
......@@ -40,68 +42,6 @@ def compile_helper():
sys.exit(1)
def build_training_sample(sample,
target_seq_length, max_seq_length,
vocab_id_list, vocab_id_to_token_dict,
cls_id, sep_id, mask_id, pad_id,
masked_lm_prob, np_rng):
"""Biuld training sample.
Arguments:
sample: A list of sentences in which each sentence is a list token ids.
target_seq_length: Desired sequence length.
max_seq_length: Maximum length of the sequence. All values are padded to
this length.
vocab_id_list: List of vocabulary ids. Used to pick a random id.
vocab_id_to_token_dict: A dictionary from vocab ids to text tokens.
cls_id: Start of example id.
sep_id: Separator id.
mask_id: Mask token id.
pad_id: Padding token id.
masked_lm_prob: Probability to mask tokens.
np_rng: Random number genenrator. Note that this rng state should be
numpy and not python since python randint is inclusive for
the opper bound whereas the numpy one is exclusive.
"""
# We assume that we have at least two sentences in the sample
assert len(sample) > 1
assert target_seq_length <= max_seq_length
# Divide sample into two segments (A and B).
tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample, np_rng)
# Truncate to `target_sequence_length`.
max_num_tokens = target_seq_length
truncated = truncate_segments(tokens_a, tokens_b, len(tokens_a),
len(tokens_b), max_num_tokens, np_rng)
# Build tokens and toketypes.
tokens, tokentypes = create_tokens_and_tokentypes(tokens_a, tokens_b,
cls_id, sep_id)
# Masking.
max_predictions_per_seq = masked_lm_prob * max_num_tokens
(tokens, masked_positions, masked_labels, _) = create_masked_lm_predictions(
tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng)
# Padding.
tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np \
= pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
masked_labels, pad_id, max_seq_length)
train_sample = {
'text': tokens_np,
'types': tokentypes_np,
'labels': labels_np,
'is_random': int(is_next_random),
'loss_mask': loss_mask_np,
'padding_mask': padding_mask_np,
'truncated': int(truncated)}
return train_sample
def get_a_and_b_segments(sample, np_rng):
"""Divide sample into a and b segments."""
......@@ -418,7 +358,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
short_seq_prob, seed, skip_warmup,
dataset_type='standard_bert'):
if dataset_type not in DATASET_TYPES:
if dataset_type not in DSET_TYPES:
raise ValueError("Invalid dataset_type: ", dataset_type)
# Indexed dataset.
......@@ -426,7 +366,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
data_impl,
skip_warmup)
if dataset_type in ['ict']:
if dataset_type == DSET_TYPE_ICT:
args = get_args()
title_dataset = get_indexed_dataset_(args.titles_data_path,
data_impl,
......@@ -479,7 +419,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
seed=seed
)
if dataset_type == 'ict':
if dataset_type == DSET_TYPE_ICT:
args = get_args()
dataset = ICTDataset(
block_dataset=indexed_dataset,
......
......@@ -452,10 +452,12 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
// Current map index.
uint64_t map_index = 0;
int32_t block_id = 0;
// For each epoch:
for (int32_t epoch=0; epoch<num_epochs; ++epoch) {
// assign every block a unique id
int32_t block_id = 0;
if (map_index >= max_num_samples) {
if (verbose && (!second)) {
cout << " reached " << max_num_samples << " samples after "
......@@ -516,6 +518,10 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
// Populate the map.
if (second) {
const auto map_index_0 = 4 * map_index;
// Each sample has 4 items: the starting sentence index, ending sentence index,
// the index of the document from which the block comes (used for fetching titles)
// and the unique id of the block (used for creating block indexes)
maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
maps[map_index_0 + 2] = static_cast<DocIdx>(doc);
......
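The four-item layout documented in the comment above is what the Python side unpacks per sample. A small sketch, consistent with `ICTDataset.__getitem__` below (the array shape is an assumption for illustration):

# samples_mapping: int64 numpy array of shape [num_samples, 4] produced by build_blocks_mapping_impl.
start_idx, end_idx, doc_idx, block_idx = samples_mapping[idx]
# start_idx/end_idx: sentence range of the block; doc_idx: source document (used to fetch its title);
# block_idx: unique block id (used when building block indexes).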
......@@ -41,14 +41,15 @@ class ICTDataset(Dataset):
"""Get an ICT example of a pseudo-query and the block of text from which it was extracted"""
start_idx, end_idx, doc_idx, block_idx = self.samples_mapping[idx]
if self.use_titles:
title = list(self.title_dataset[int(doc_idx)])
title = self.title_dataset[int(doc_idx)]
title_pad_offset = 3 + len(title)
else:
title = None
title_pad_offset = 2
block = [list(self.block_dataset[i]) for i in range(start_idx, end_idx)]
block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
assert len(block) > 1
# randint() is inclusive for Python rng
rand_sent_idx = self.rng.randint(0, len(block) - 1)
# keep the query in the context query_in_block_prob fraction of the time.
......@@ -64,53 +65,47 @@ class ICTDataset(Dataset):
query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
block_data = np.array([start_idx, end_idx, doc_idx, block_idx]).astype(np.int64)
sample = {
'query_tokens': np.array(query_tokens),
'query_pad_mask': np.array(query_pad_mask),
'block_tokens': np.array(block_tokens),
'block_pad_mask': np.array(block_pad_mask),
'block_data': np.array([start_idx, end_idx, doc_idx, block_idx]).astype(np.int64)
'query_tokens': query_tokens,
'query_pad_mask': query_pad_mask,
'block_tokens': block_tokens,
'block_pad_mask': block_pad_mask,
'block_data': block_data,
}
return sample
def encode_text(self, text):
return self.tokenizer.tokenize(text)
def decode_tokens(self, token_ids):
"""Utility function to help with debugging mostly"""
tokens = self.tokenizer.tokenizer.convert_ids_to_tokens(token_ids)
exclude_list = ['[PAD]', '[CLS]']
non_pads = [t for t in tokens if t not in exclude_list]
joined_strs = join_str_list(non_pads)
def get_block(self, start_idx, end_idx, doc_idx):
"""Get the IDs for an evidence block plus the title of the corresponding document"""
block = [list(self.block_dataset[i]) for i in range(start_idx, end_idx)]
title = list(self.title_dataset[int(doc_idx)])
block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
title = self.title_dataset[int(doc_idx)]
block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]
block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
return (block_tokens, block_pad_mask)
return block_tokens, block_pad_mask
def get_null_block(self):
"""Get empty block and title - used in REALM pretraining"""
block, title = [], []
block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
return (block_tokens, block_pad_mask)
return block_tokens, block_pad_mask
def concat_and_pad_tokens(self, tokens, title=None):
"""Concat with special tokens and pad sequence to self.max_seq_length"""
tokens = list(tokens)
if title is None:
tokens = [self.cls_id] + tokens + [self.sep_id]
else:
title = list(title)
tokens = [self.cls_id] + title + [self.sep_id] + tokens + [self.sep_id]
assert len(tokens) <= self.max_seq_length, len(tokens)
assert len(tokens) <= self.max_seq_length
num_pad = self.max_seq_length - len(tokens)
pad_mask = [1] * len(tokens) + [0] * num_pad
tokens += [self.pad_id] * num_pad
return tokens, pad_mask
return np.array(tokens), np.array(pad_mask)
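A worked toy example of `concat_and_pad_tokens`, using made-up special-token ids and `max_seq_length=8` purely to show the layout (the values are assumptions, not real vocab ids):

# Suppose self.cls_id=101, self.sep_id=102, self.pad_id=0, self.max_seq_length=8.
# With tokens=[7, 8, 9] and title=None:
#   tokens   -> [101, 7, 8, 9, 102, 0, 0, 0]
#   pad_mask -> [  1, 1, 1, 1,   1, 0, 0, 0]
# With tokens=[7, 8] and title=[5]:
#   tokens   -> [101, 5, 102, 7, 8, 102, 0, 0]
#   pad_mask -> [  1, 1,   1, 1, 1,   1, 0, 0]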
......@@ -20,6 +20,8 @@ def join_str_list(str_list):
def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epochs,
max_num_samples, max_seq_length, seed, name):
"""Get samples mapping for a dataset over fixed size blocks. This function also requires
a dataset of the titles for the source documents since their lengths must be taken into account."""
if not num_epochs:
if not max_num_samples:
raise ValueError("Need to specify either max_num_samples "
......@@ -40,7 +42,7 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo
indexmap_filename += '.npy'
# Build the indexed mapping if not exist.
if torch.distributed.get_rank() == 0 and \
if mpu.get_data_parallel_rank() == 0 and \
not os.path.isfile(indexmap_filename):
print(' > WARNING: could not find index map file {}, building '
'the indices on rank 0 ...'.format(indexmap_filename))
......
......@@ -25,46 +25,12 @@ from megatron.model.utils import openai_gelu
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from megatron.model.utils import bert_attention_mask_func
from megatron.model.utils import bert_extended_attention_mask
from megatron.model.utils import bert_position_ids
from megatron.module import MegatronModule
def bert_attention_mask_func(attention_scores, attention_mask):
attention_scores = attention_scores + attention_mask
return attention_scores
def bert_extended_attention_mask(attention_mask, dtype):
# We create a 3D attention mask from a 2D tensor mask.
# [b, 1, s]
attention_mask_b1s = attention_mask.unsqueeze(1)
# [b, s, 1]
attention_mask_bs1 = attention_mask.unsqueeze(2)
# [b, s, s]
attention_mask_bss = attention_mask_b1s * attention_mask_bs1
# [b, 1, s, s]
extended_attention_mask = attention_mask_bss.unsqueeze(1)
# Since attention_mask is 1.0 for positions we want to attend and 0.0
# for masked positions, this operation will create a tensor which is
# 0.0 for positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
# fp16 compatibility
extended_attention_mask = extended_attention_mask.to(dtype=dtype)
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
return extended_attention_mask
def bert_position_ids(token_ids):
# Create position ids
seq_length = token_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long,
device=token_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(token_ids)
return position_ids
class BertLMHead(MegatronModule):
"""Masked LM head for Bert
......@@ -110,29 +76,23 @@ class BertModel(MegatronModule):
"""Bert Language model."""
def __init__(self, num_tokentypes=2, add_binary_head=True,
ict_head_size=None, parallel_output=True):
parallel_output=True):
super(BertModel, self).__init__()
args = get_args()
self.add_binary_head = add_binary_head
self.ict_head_size = ict_head_size
self.add_ict_head = ict_head_size is not None
assert not (self.add_binary_head and self.add_ict_head)
self.parallel_output = parallel_output
init_method = init_method_normal(args.init_method_std)
add_pooler = self.add_binary_head or self.add_ict_head
scaled_init_method = scaled_init_method_normal(args.init_method_std,
args.num_layers)
self.language_model, self._language_model_key = get_language_model(
attention_mask_func=bert_attention_mask_func,
num_tokentypes=num_tokentypes,
add_pooler=add_pooler,
add_pooler=self.add_binary_head,
init_method=init_method,
scaled_init_method=scaled_init_method)
if not self.add_ict_head:
self.lm_head = BertLMHead(
self.language_model.embedding.word_embeddings.weight.size(0),
args.hidden_size, init_method, args.layernorm_epsilon, parallel_output)
......@@ -141,9 +101,6 @@ class BertModel(MegatronModule):
self.binary_head = get_linear_layer(args.hidden_size, 2,
init_method)
self._binary_head_key = 'binary_head'
elif self.add_ict_head:
self.ict_head = get_linear_layer(args.hidden_size, ict_head_size, init_method)
self._ict_head_key = 'ict_head'
def forward(self, input_ids, attention_mask, tokentype_ids=None):
......@@ -151,7 +108,7 @@ class BertModel(MegatronModule):
attention_mask, next(self.language_model.parameters()).dtype)
position_ids = bert_position_ids(input_ids)
if self.add_binary_head or self.add_ict_head:
if self.add_binary_head:
lm_output, pooled_output = self.language_model(
input_ids,
position_ids,
......@@ -165,12 +122,9 @@ class BertModel(MegatronModule):
tokentype_ids=tokentype_ids)
# Output.
if self.add_ict_head:
ict_logits = self.ict_head(pooled_output)
return ict_logits, None
lm_logits = self.lm_head(
lm_output, self.language_model.embedding.word_embeddings.weight)
if self.add_binary_head:
binary_logits = self.binary_head(pooled_output)
return lm_logits, binary_logits
......@@ -186,16 +140,12 @@ class BertModel(MegatronModule):
state_dict_[self._language_model_key] \
= self.language_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
if not self.add_ict_head:
state_dict_[self._lm_head_key] \
= self.lm_head.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
if self.add_binary_head:
state_dict_[self._binary_head_key] \
= self.binary_head.state_dict(destination, prefix, keep_vars)
elif self.add_ict_head:
state_dict_[self._ict_head_key] \
= self.ict_head.state_dict(destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
......@@ -203,14 +153,10 @@ class BertModel(MegatronModule):
self.language_model.load_state_dict(
state_dict[self._language_model_key], strict=strict)
if not self.add_ict_head:
self.lm_head.load_state_dict(
state_dict[self._lm_head_key], strict=strict)
if self.add_binary_head:
self.binary_head.load_state_dict(
state_dict[self._binary_head_key], strict=strict)
elif self.add_ict_head:
self.ict_head.load_state_dict(
state_dict[self._ict_head_key], strict=strict)
......@@ -6,6 +6,13 @@ from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoi
from megatron.model import BertModel
from megatron.module import MegatronModule
from megatron import mpu
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.language_model import get_language_model
from megatron.model.utils import scaled_init_method_normal
from megatron.model.utils import bert_attention_mask_func
from megatron.model.utils import bert_extended_attention_mask
from megatron.model.utils import bert_position_ids
class ICTBertModel(MegatronModule):
......@@ -17,10 +24,9 @@ class ICTBertModel(MegatronModule):
only_query_model=False,
only_block_model=False):
super(ICTBertModel, self).__init__()
bert_args = dict(
num_tokentypes=num_tokentypes,
add_binary_head=False,
bert_kwargs = dict(
ict_head_size=ict_head_size,
num_tokentypes=num_tokentypes,
parallel_output=parallel_output
)
assert not (only_block_model and only_query_model)
......@@ -29,12 +35,12 @@ class ICTBertModel(MegatronModule):
if self.use_query_model:
# this model embeds (pseudo-)queries - Embed_input in the paper
self.query_model = BertModel(**bert_args)
self.query_model = IREncoderBertModel(**bert_kwargs)
self._query_key = 'question_model'
if self.use_block_model:
# this model embeds evidence blocks - Embed_doc in the paper
self.block_model = BertModel(**bert_args)
self.block_model = IREncoderBertModel(**bert_kwargs)
self._block_key = 'context_model'
def forward(self, query_tokens, query_attention_mask, block_tokens, block_attention_mask):
......@@ -116,3 +122,64 @@ class ICTBertModel(MegatronModule):
# give each model the same ict_head to begin with as well
query_ict_head_state_dict = self.state_dict_for_save_checkpoint()[self._query_key]['ict_head']
self.block_model.ict_head.load_state_dict(query_ict_head_state_dict)
class IREncoderBertModel(MegatronModule):
"""Bert Language model."""
def __init__(self, ict_head_size, num_tokentypes=2, parallel_output=True):
super(IREncoderBertModel, self).__init__()
args = get_args()
self.ict_head_size = ict_head_size
self.parallel_output = parallel_output
init_method = init_method_normal(args.init_method_std)
scaled_init_method = scaled_init_method_normal(args.init_method_std,
args.num_layers)
self.language_model, self._language_model_key = get_language_model(
attention_mask_func=bert_attention_mask_func,
num_tokentypes=num_tokentypes,
add_pooler=True,
init_method=init_method,
scaled_init_method=scaled_init_method)
self.ict_head = get_linear_layer(args.hidden_size, ict_head_size, init_method)
self._ict_head_key = 'ict_head'
def forward(self, input_ids, attention_mask, tokentype_ids=None):
extended_attention_mask = bert_extended_attention_mask(
attention_mask, next(self.language_model.parameters()).dtype)
position_ids = bert_position_ids(input_ids)
lm_output, pooled_output = self.language_model(
input_ids,
position_ids,
extended_attention_mask,
tokentype_ids=tokentype_ids)
# Output.
ict_logits = self.ict_head(pooled_output)
return ict_logits, None
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""For easy load when model is combined with other heads,
add an extra key."""
state_dict_ = {}
state_dict_[self._language_model_key] \
= self.language_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
state_dict_[self._ict_head_key] \
= self.ict_head.state_dict(destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
self.language_model.load_state_dict(
state_dict[self._language_model_key], strict=strict)
self.ict_head.load_state_dict(
state_dict[self._ict_head_key], strict=strict)
......@@ -78,3 +78,42 @@ def get_params_for_weight_decay_optimization(module):
if p is not None and n == 'bias'])
return weight_decay_params, no_weight_decay_params
def bert_attention_mask_func(attention_scores, attention_mask):
attention_scores = attention_scores + attention_mask
return attention_scores
def bert_extended_attention_mask(attention_mask, dtype):
# We create a 3D attention mask from a 2D tensor mask.
# [b, 1, s]
attention_mask_b1s = attention_mask.unsqueeze(1)
# [b, s, 1]
attention_mask_bs1 = attention_mask.unsqueeze(2)
# [b, s, s]
attention_mask_bss = attention_mask_b1s * attention_mask_bs1
# [b, 1, s, s]
extended_attention_mask = attention_mask_bss.unsqueeze(1)
# Since attention_mask is 1.0 for positions we want to attend and 0.0
# for masked positions, this operation will create a tensor which is
# 0.0 for positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
# fp16 compatibility
extended_attention_mask = extended_attention_mask.to(dtype=dtype)
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
return extended_attention_mask
def bert_position_ids(token_ids):
# Create position ids
seq_length = token_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long,
device=token_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(token_ids)
return position_ids
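Since these helpers now live in `megatron.model.utils`, a minimal sketch of how a caller wires them together, mirroring `BertModel.forward` above (the fp16 dtype and tensor shapes are just example values):

import torch
from megatron.model.utils import bert_extended_attention_mask, bert_position_ids

input_ids = torch.randint(0, 30522, (2, 16))   # [batch, seq] token ids (toy values)
attention_mask = torch.ones(2, 16)             # 1 = attend, 0 = masked

extended_mask = bert_extended_attention_mask(attention_mask, torch.float16)  # [b, 1, s, s]
position_ids = bert_position_ids(input_ids)                                  # [b, s]
# extended_mask is later added to the raw attention scores via bert_attention_mask_func.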
......@@ -20,6 +20,7 @@ from abc import abstractmethod
from .bert_tokenization import FullTokenizer as FullBertTokenizer
from .gpt2_tokenization import GPT2Tokenizer
from megatron.data.realm_dataset_utils import join_str_list
def build_tokenizer(args):
......@@ -155,6 +156,13 @@ class _BertWordPieceTokenizer(AbstractTokenizer):
text_tokens = self.tokenizer.tokenize(text)
return self.tokenizer.convert_tokens_to_ids(text_tokens)
def decode_token_ids(self, token_ids):
tokens = self.tokenizer.convert_ids_to_tokens(token_ids)
exclude_list = ['[PAD]', '[CLS]']
non_pads = [t for t in tokens if t not in exclude_list]
joined_strs = join_str_list(non_pads)
return joined_strs
@property
def cls(self):
return self.cls_id
......
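A hedged usage sketch of the new `decode_token_ids` helper; it assumes the tokenizer is retrieved via `get_tokenizer()` as in the dataset code, and the token ids are illustrative:

from megatron import get_tokenizer

tokenizer = get_tokenizer()
token_ids = [101, 7592, 2088, 102, 0, 0]   # toy ids, e.g. "[CLS] hello world [SEP] [PAD] [PAD]"
text = tokenizer.decode_token_ids(token_ids)
# '[PAD]' and '[CLS]' tokens are dropped before the rest are joined by join_str_list.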
......@@ -218,9 +218,10 @@ def setup_model_and_optimizer(model_provider_func):
else:
args.iteration = 0
if args.iteration == 0 and isinstance(model.module.module, ICTBertModel):
unwrapped_model = model.module.module
if args.iteration == 0 and hasattr(unwrapped_model, 'init_state_dict_from_bert'):
print("Initializing ICT from pretrained BERT model", flush=True)
model.module.module.init_state_dict_from_bert()
unwrapped_model.init_state_dict_from_bert()
return model, optimizer, lr_scheduler
......
......@@ -31,14 +31,17 @@ from megatron.utils import reduce_losses
num_batches = 0
def model_provider(only_query_model=False, only_block_model=False):
def general_model_provider(only_query_model=False, only_block_model=False):
"""Build the model."""
args = get_args()
print_rank_0('building BERT models ...')
if args.ict_head_size is None:
raise ValueError("Need to specify --ict-head-size to provide an ICTBertModel")
print_rank_0('building ICTBertModel...')
# simpler to just keep using 2 tokentypes since the LM we initialize with has 2 tokentypes
model = ICTBertModel(
ict_head_size=128,
ict_head_size=args.ict_head_size,
num_tokentypes=2,
parallel_output=True,
only_query_model=only_query_model,
......@@ -47,6 +50,10 @@ def model_provider(only_query_model=False, only_block_model=False):
return model
def model_provider():
return general_model_provider(False, False)
def get_batch(data_iterator):
# Items and their type.
keys = ['query_tokens', 'query_pad_mask',
......@@ -98,6 +105,7 @@ def forward_step(data_iterator, model):
all_query_logits[args.rank * batch_size:(args.rank + 1) * batch_size] = query_logits
all_block_logits[args.rank * batch_size:(args.rank + 1) * batch_size] = block_logits
# currently this assumes model parallel size == 1.
dist.all_reduce(all_query_logits)
dist.all_reduce(all_block_logits)
......
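For context, a sketch of the in-batch retrieval objective these gathered logits typically feed. The loss itself is not shown in this diff; the variable names follow the snippet above and the cross-entropy over the score matrix is an assumption:

import torch
import torch.nn.functional as F

# all_query_logits, all_block_logits: [global_batch, ict_head_size] after the all_reduce above.
retrieval_scores = torch.matmul(all_query_logits, all_block_logits.t())   # [global_batch, global_batch]
# Each query's true block sits on the diagonal; the other blocks in the batch act as negatives.
labels = torch.arange(retrieval_scores.size(0), device=retrieval_scores.device)
retrieval_loss = F.cross_entropy(retrieval_scores, labels)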