Commit 447c1171 authored by Mostofa Patwary

Addressed the review comments given by Mohammad.

parent 22a3d81a
@@ -646,16 +646,12 @@ def _add_biencoder_args(parser):
     group.add_argument('--ict-head-size', type=int, default=None,
                        help='Size of block embeddings to be used in ICT and '
                        'REALM (paper default: 128)')
-    group.add_argument('--projection-dim', type=int, default=0,
+    group.add_argument('--biencoder-projection-dim', type=int, default=0,
                        help='Size of projection head used in biencoder (paper'
                        ' default: 128)')
-    group.add_argument('--shared-query-context-model', action='store_true',
+    group.add_argument('--biencoder-shared-query-context-model', action='store_true',
                        help='Whether to share the parameters of the query '
                        'and context models or not')
-    group.add_argument('--pool-type', type=str, default='cls-token',
-                       choices=['avg', 'cls-token', 'max'],
-                       help='different options are: avg | cls-token | max, '
-                       'default=cls-token')
 
     # checkpointing
     group.add_argument('--ict-load', type=str, default=None,
@@ -674,7 +670,7 @@ def _add_biencoder_args(parser):
                        help='Whether to use one sentence documents in ICT')
 
     # training
-    group.add_argument('--report-topk-accuracies', nargs='+', type=int,
+    group.add_argument('--retriever-report-topk-accuracies', nargs='+', type=int,
                        default=[], help="Which top-k accuracies to report "
                        "(e.g. '1 5 20')")
     group.add_argument('--retriever-score-scaling', action='store_true',
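For quick reference, the renamed options behave like ordinary argparse flags. The sketch below is a minimal standalone parser, not Megatron's argument builder, assuming only the names and defaults shown in the hunks above; prefixing the options with `biencoder-`/`retriever-` makes it clear which module consumes them.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--biencoder-projection-dim', type=int, default=0)
parser.add_argument('--biencoder-shared-query-context-model',
                    action='store_true')
parser.add_argument('--retriever-report-topk-accuracies', nargs='+',
                    type=int, default=[])

# argparse converts dashes to underscores in the attribute names.
args = parser.parse_args(['--biencoder-projection-dim', '128',
                          '--retriever-report-topk-accuracies', '1', '5', '20'])
print(args.biencoder_projection_dim)          # 128
print(args.retriever_report_topk_accuracies)  # [1, 5, 20]
```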
@@ -17,7 +17,7 @@ from .module import MegatronModule
 
 def biencoder_model_provider(only_query_model=False,
                              only_context_model=False,
-                             shared_query_context_model=False):
+                             biencoder_shared_query_context_model=False):
     """Build the model."""
     args = get_args()
@@ -31,10 +31,11 @@ def biencoder_model_provider(only_query_model=False,
     # the LM we initialize with has 2 tokentypes
     model = BiEncoderModel(
         num_tokentypes=2,
-        parallel_output=True,
+        parallel_output=False,
         only_query_model=only_query_model,
         only_context_model=only_context_model,
-        shared_query_context_model=shared_query_context_model)
+        biencoder_shared_query_context_model=\
+            biencoder_shared_query_context_model)
 
     return model
@@ -47,7 +48,7 @@ class BiEncoderModel(MegatronModule):
                  parallel_output=True,
                  only_query_model=False,
                  only_context_model=False,
-                 shared_query_context_model=False):
+                 biencoder_shared_query_context_model=False):
         super(BiEncoderModel, self).__init__()
         args = get_args()
@@ -55,13 +56,14 @@ class BiEncoderModel(MegatronModule):
             num_tokentypes=num_tokentypes,
             parallel_output=parallel_output)
-        self.shared_query_context_model = shared_query_context_model
+        self.biencoder_shared_query_context_model = \
+            biencoder_shared_query_context_model
 
         assert not (only_context_model and only_query_model)
         self.use_context_model = not only_query_model
         self.use_query_model = not only_context_model
-        self.projection_dim = args.projection_dim
+        self.biencoder_projection_dim = args.biencoder_projection_dim
 
-        if self.shared_query_context_model:
+        if self.biencoder_shared_query_context_model:
             self.model = PretrainedBertModel(**bert_kwargs)
             self._model_key = 'shared_model'
             self.query_model, self.context_model = self.model, self.model
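The constructor change above keeps the shared-weights behavior: when `biencoder_shared_query_context_model` is set, one `PretrainedBertModel` instance serves as both query and context encoder. A minimal sketch of that wiring, using a stand-in `torch.nn.Module` instead of Megatron's model classes:

```python
import torch

class Encoder(torch.nn.Module):
    """Stand-in for PretrainedBertModel (illustrative only)."""
    def __init__(self, hidden=8):
        super().__init__()
        self.linear = torch.nn.Linear(hidden, hidden)

    def forward(self, x):
        return self.linear(x)

biencoder_shared_query_context_model = True  # mirrors the renamed flag

if biencoder_shared_query_context_model:
    model = Encoder()
    # Query and context encoders are literally the same module / parameters.
    query_model, context_model = model, model
else:
    # Otherwise each side gets its own independently initialized encoder.
    query_model, context_model = Encoder(), Encoder()

print(query_model is context_model)  # True when weights are shared
```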
@@ -109,7 +111,7 @@ class BiEncoderModel(MegatronModule):
                                        prefix='', keep_vars=False):
         """Save dict with state dicts of each of the models."""
         state_dict_ = {}
-        if self.shared_query_context_model:
+        if self.biencoder_shared_query_context_model:
             state_dict_[self._model_key] = \
                 self.model.state_dict_for_save_checkpoint(destination,
                                                           prefix,
@@ -129,7 +131,7 @@ class BiEncoderModel(MegatronModule):
     def load_state_dict(self, state_dict, strict=True):
         """Load the state dicts of each of the models"""
-        if self.shared_query_context_model:
+        if self.biencoder_shared_query_context_model:
             print_rank_0("Loading shared query-context model")
             self.model.load_state_dict(state_dict[self._model_key], \
                 strict=strict)
@@ -188,14 +190,14 @@ class BiEncoderModel(MegatronModule):
         # load the LM state dict into each model
         model_dict = state_dict['model']['language_model']
 
-        if self.shared_query_context_model:
+        if self.biencoder_shared_query_context_model:
             self.model.language_model.load_state_dict(model_dict)
             fix_query_key_value_ordering(self.model, checkpoint_version)
         else:
             if self.use_query_model:
                 self.query_model.language_model.load_state_dict(model_dict)
                 # give each model the same ict_head to begin with as well
-                if self.projection_dim > 0:
+                if self.biencoder_projection_dim > 0:
                     query_proj_state_dict = \
                         self.state_dict_for_save_checkpoint()\
                         [self._query_key]['projection_enc']
@@ -203,7 +205,8 @@ class BiEncoderModel(MegatronModule):
             if self.use_context_model:
                 self.context_model.language_model.load_state_dict(model_dict)
-                if self.query_model is not None and self.projection_dim > 0:
+                if self.query_model is not None and \
+                    self.biencoder_projection_dim > 0:
                     self.context_model.projection_enc.load_state_dict\
                         (query_proj_state_dict)
                 fix_query_key_value_ordering(self.context_model, checkpoint_version)
@@ -220,8 +223,7 @@ class PretrainedBertModel(MegatronModule):
         args = get_args()
         tokenizer = get_tokenizer()
         self.pad_id = tokenizer.pad
-        self.pool_type = args.pool_type
-        self.projection_dim = args.projection_dim
+        self.biencoder_projection_dim = args.biencoder_projection_dim
         self.parallel_output = parallel_output
         init_method = init_method_normal(args.init_method_std)
         scaled_init_method = scaled_init_method_normal(
@@ -234,9 +236,9 @@ class PretrainedBertModel(MegatronModule):
             init_method=init_method,
             scaled_init_method=scaled_init_method)
 
-        if args.projection_dim > 0:
+        if args.biencoder_projection_dim > 0:
             self.projection_enc = get_linear_layer(args.hidden_size,
-                                                   args.projection_dim,
+                                                   args.biencoder_projection_dim,
                                                    init_method)
             self._projection_enc_key = 'projection_enc'
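The projection head itself is unchanged apart from the renamed size argument: a single linear layer mapping `hidden_size` to `biencoder_projection_dim`. A rough equivalent using plain `torch.nn.Linear` in place of Megatron's `get_linear_layer`; the `hidden_size` value below is illustrative, not Megatron's default.

```python
import torch

hidden_size = 1024               # illustrative; Megatron reads args.hidden_size
biencoder_projection_dim = 128   # value passed via --biencoder-projection-dim

# Stand-in for get_linear_layer(hidden_size, biencoder_projection_dim, init_method)
projection_enc = torch.nn.Linear(hidden_size, biencoder_projection_dim)

pooled_output = torch.randn(2, hidden_size)   # [batch, hidden]
embeddings = projection_enc(pooled_output)    # [batch, projection_dim]
print(embeddings.shape)                       # torch.Size([2, 128])
```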
@@ -253,22 +255,14 @@ class PretrainedBertModel(MegatronModule):
         # This mask will be used in average-pooling and max-pooling
         pool_mask = (input_ids == self.pad_id).unsqueeze(2)
 
-        # Taking the representation of the [CLS] token of BERT
-        if self.pool_type == "cls-token":
-            pooled_output = lm_output[:, 0, :]
-        elif self.pool_type == "avg": # Average Pooling
-            pooled_output = lm_output.masked_fill(pool_mask, 0)
-            pooled_output = pooled_output.sum(1) / (pool_mask.size(1) \
-                - pool_mask.float().sum(1))
-        elif self.pool_type == "max": # Max-Pooling
-            pooled_output = lm_output.masked_fill(pool_mask, -1000)
-            pooled_output = torch.max(pooled_output, 1)[0]
+        # Taking the representation of the [CLS] token of BERT
+        pooled_output = lm_output[:, 0, :]
 
         # Converting to float16 dtype
         pooled_output = pooled_output.to(lm_output.dtype)
 
         # Output.
-        if self.projection_dim:
+        if self.biencoder_projection_dim:
             pooled_output = self.projection_enc(pooled_output)
 
         return pooled_output
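With `--pool-type` removed, `forward` now always pools with the `[CLS]` (first) token; the masked average and max pooling branches in the removed lines are gone. The sketch below replays the retained CLS pooling and the removed masked-average variant on dummy tensors; the shapes and `pad_id` are illustrative, not Megatron's values.

```python
import torch

pad_id = 0                                                  # illustrative pad token id
input_ids = torch.tensor([[101, 7592, 2088, 102, 0, 0]])    # [batch, seq]
lm_output = torch.randn(1, 6, 8)                            # [batch, seq, hidden]

# Retained behavior: representation of the [CLS] (first) token.
cls_pooled = lm_output[:, 0, :]                             # [batch, hidden]

# Removed alternative: average pooling over non-pad positions.
pool_mask = (input_ids == pad_id).unsqueeze(2)              # True at padding
masked = lm_output.masked_fill(pool_mask, 0)
avg_pooled = masked.sum(1) / (pool_mask.size(1) - pool_mask.float().sum(1))

print(cls_pooled.shape, avg_pooled.shape)                   # both torch.Size([1, 8])
```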
@@ -283,7 +277,7 @@ class PretrainedBertModel(MegatronModule):
             = self.language_model.state_dict_for_save_checkpoint(
                 destination, prefix, keep_vars)
 
-        if self.projection_dim > 0:
+        if self.biencoder_projection_dim > 0:
             state_dict_[self._projection_enc_key] = \
                 self.projection_enc.state_dict(destination, prefix, keep_vars)
@@ -295,7 +289,7 @@ class PretrainedBertModel(MegatronModule):
         self.language_model.load_state_dict(
             state_dict[self._language_model_key], strict=strict)
 
-        if self.projection_dim > 0:
+        if self.biencoder_projection_dim > 0:
             print_rank_0("loading projection head weights")
             self.projection_enc.load_state_dict(
                 state_dict[self._projection_enc_key], strict=strict)
@@ -36,7 +36,8 @@ def pretrain_ict_model_provider():
     model = biencoder_model_provider(
         only_context_model=False,
         only_query_model=False,
-        shared_query_context_model=args.shared_query_context_model)
+        biencoder_shared_query_context_model=\
+            args.biencoder_shared_query_context_model)
 
     return model
 
 
 def get_group_world_size_rank():
@@ -120,7 +121,7 @@ def forward_step(data_iterator, model, input_tensor):
         return torch.cuda.FloatTensor([sum([int(i in sorted_indices[i, :k]) \
             for i in range(global_batch_size)]) / global_batch_size])
 
-    topk_accs = [topk_accuracy(int(k)) for k in args.report_topk_accuracies]
+    topk_accs = [topk_accuracy(int(k)) for k in args.retriever_report_topk_accuracies]
 
     labels = torch.arange(global_batch_size).long().cuda()
     loss = F.nll_loss(softmax_scores, labels, reduction='mean')
@@ -131,7 +132,7 @@ def forward_step(data_iterator, model, input_tensor):
     # create stats_dict with retrieval loss and all specified top-k accuracies
     topk_acc_dict = {'top{}_acc'.format(k): v * 100 for k, v in \
-        zip(args.report_topk_accuracies, reduced_losses[1:])}
+        zip(args.retriever_report_topk_accuracies, reduced_losses[1:])}
     stats_dict = dict(loss=reduced_losses[0], **topk_acc_dict)
 
     return loss, stats_dict
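`topk_accuracy` measures in-batch retrieval accuracy: with one positive context per query and labels on the diagonal, query `i` counts as correct when index `i` appears among its top-k scored contexts. A self-contained sketch on random scores follows; the descending `argsort` is an assumption about how `sorted_indices` is built upstream of this hunk.

```python
import torch

global_batch_size = 4
retriever_report_topk_accuracies = [1, 2]   # e.g. values of the renamed flag

# Query-context similarity scores; the correct context for query i is i.
scores = torch.randn(global_batch_size, global_batch_size)
sorted_indices = torch.argsort(scores, dim=1, descending=True)

def topk_accuracy(k):
    hits = sum(int(i in sorted_indices[i, :k]) for i in range(global_batch_size))
    return hits / global_batch_size

topk_accs = [topk_accuracy(k) for k in retriever_report_topk_accuracies]
print(topk_accs)                            # e.g. [0.25, 0.5]
```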