Commit ebc95c35 authored by Mostofa Patwary

Merge branch 'main_retriver_merge' into main_retriver_merge_ict_eval

parents 360fb6b7 a4b628ab
@@ -27,7 +27,7 @@ python pretrain_ict.py \
        --log-interval 100 \
        --eval-interval 1000 \
        --eval-iters 10 \
-       --report-topk-accuracies 1 5 10 20 100 \
+       --retriever-report-topk-accuracies 1 5 10 20 100 \
        --retriever-score-scaling \
        --load $CHECKPOINT_PATH \
        --save $CHECKPOINT_PATH \
@@ -649,16 +649,12 @@ def _add_biencoder_args(parser):
     group.add_argument('--ict-head-size', type=int, default=None,
                        help='Size of block embeddings to be used in ICT and '
                        'REALM (paper default: 128)')
-    group.add_argument('--projection-dim', type=int, default=0,
+    group.add_argument('--biencoder-projection-dim', type=int, default=0,
                        help='Size of projection head used in biencoder (paper'
                        ' default: 128)')
-    group.add_argument('--shared-query-context-model', action='store_true',
+    group.add_argument('--biencoder-shared-query-context-model', action='store_true',
                        help='Whether to share the parameters of the query '
                        'and context models or not')
-    group.add_argument('--pool-type', type=str, default='cls-token',
-                       choices=['avg', 'cls-token', 'max'],
-                       help='different options are: avg | cls-token | max, '
-                       'default=cls-token')

     # checkpointing
     group.add_argument('--ict-load', type=str, default=None,
@@ -677,7 +673,7 @@ def _add_biencoder_args(parser):
                        help='Whether to use one sentence documents in ICT')

     # training
-    group.add_argument('--report-topk-accuracies', nargs='+', type=int,
+    group.add_argument('--retriever-report-topk-accuracies', nargs='+', type=int,
                        default=[], help="Which top-k accuracies to report "
                        "(e.g. '1 5 20')")
     group.add_argument('--retriever-score-scaling', action='store_true',
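For reference, the renamed options above parse just like the old ones. Below is a minimal standalone argparse sketch, not the Megatron arguments module itself, with a hypothetical command line, showing the attribute names the model code now reads (args.biencoder_projection_dim, args.biencoder_shared_query_context_model, args.retriever_report_topk_accuracies):

# Standalone sketch of the renamed biencoder/retriever flags; the sample
# command-line values below are illustrative only.
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group('biencoder')
group.add_argument('--biencoder-projection-dim', type=int, default=0,
                   help='Size of projection head used in biencoder '
                        '(paper default: 128)')
group.add_argument('--biencoder-shared-query-context-model',
                   action='store_true',
                   help='Whether to share the parameters of the query '
                        'and context models or not')
group.add_argument('--retriever-report-topk-accuracies', nargs='+', type=int,
                   default=[], help="Which top-k accuracies to report "
                                    "(e.g. '1 5 20')")

args = parser.parse_args([
    '--biencoder-projection-dim', '128',
    '--biencoder-shared-query-context-model',
    '--retriever-report-topk-accuracies', '1', '5', '10', '20', '100',
])
# argparse converts dashes to underscores, giving the attribute names used
# throughout the rest of this diff.
print(args.biencoder_projection_dim)              # 128
print(args.biencoder_shared_query_context_model)  # True
print(args.retriever_report_topk_accuracies)      # [1, 5, 10, 20, 100]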
@@ -17,7 +17,7 @@ from .module import MegatronModule

 def biencoder_model_provider(only_query_model=False,
                              only_context_model=False,
-                             shared_query_context_model=False):
+                             biencoder_shared_query_context_model=False):
     """Build the model."""

     args = get_args()
@@ -31,10 +31,11 @@ def biencoder_model_provider(only_query_model=False,
     # the LM we initialize with has 2 tokentypes
     model = BiEncoderModel(
         num_tokentypes=2,
-        parallel_output=True,
+        parallel_output=False,
         only_query_model=only_query_model,
         only_context_model=only_context_model,
-        shared_query_context_model=shared_query_context_model)
+        biencoder_shared_query_context_model=\
+            biencoder_shared_query_context_model)

     return model
@@ -47,7 +48,7 @@ class BiEncoderModel(MegatronModule):
                  parallel_output=True,
                  only_query_model=False,
                  only_context_model=False,
-                 shared_query_context_model=False):
+                 biencoder_shared_query_context_model=False):
         super(BiEncoderModel, self).__init__()
         args = get_args()
@@ -55,13 +56,14 @@ class BiEncoderModel(MegatronModule):
             num_tokentypes=num_tokentypes,
             parallel_output=parallel_output)

-        self.shared_query_context_model = shared_query_context_model
+        self.biencoder_shared_query_context_model = \
+            biencoder_shared_query_context_model
         assert not (only_context_model and only_query_model)
         self.use_context_model = not only_query_model
         self.use_query_model = not only_context_model
-        self.projection_dim = args.projection_dim
+        self.biencoder_projection_dim = args.biencoder_projection_dim

-        if self.shared_query_context_model:
+        if self.biencoder_shared_query_context_model:
             self.model = PretrainedBertModel(**bert_kwargs)
             self._model_key = 'shared_model'
             self.query_model, self.context_model = self.model, self.model
@@ -109,7 +111,7 @@ class BiEncoderModel(MegatronModule):
                                        prefix='', keep_vars=False):
         """Save dict with state dicts of each of the models."""
         state_dict_ = {}
-        if self.shared_query_context_model:
+        if self.biencoder_shared_query_context_model:
             state_dict_[self._model_key] = \
                 self.model.state_dict_for_save_checkpoint(destination,
                                                           prefix,
@@ -129,7 +131,7 @@ class BiEncoderModel(MegatronModule):
     def load_state_dict(self, state_dict, strict=True):
         """Load the state dicts of each of the models"""
-        if self.shared_query_context_model:
+        if self.biencoder_shared_query_context_model:
             print_rank_0("Loading shared query-context model")
             self.model.load_state_dict(state_dict[self._model_key], \
                 strict=strict)
@@ -188,14 +190,14 @@ class BiEncoderModel(MegatronModule):
         # load the LM state dict into each model
         model_dict = state_dict['model']['language_model']

-        if self.shared_query_context_model:
+        if self.biencoder_shared_query_context_model:
             self.model.language_model.load_state_dict(model_dict)
             fix_query_key_value_ordering(self.model, checkpoint_version)
         else:
             if self.use_query_model:
                 self.query_model.language_model.load_state_dict(model_dict)
                 # give each model the same ict_head to begin with as well
-                if self.projection_dim > 0:
+                if self.biencoder_projection_dim > 0:
                     query_proj_state_dict = \
                         self.state_dict_for_save_checkpoint()\
                         [self._query_key]['projection_enc']
@@ -203,7 +205,8 @@ class BiEncoderModel(MegatronModule):
             if self.use_context_model:
                 self.context_model.language_model.load_state_dict(model_dict)
-                if self.query_model is not None and self.projection_dim > 0:
+                if self.query_model is not None and \
+                    self.biencoder_projection_dim > 0:
                     self.context_model.projection_enc.load_state_dict\
                         (query_proj_state_dict)
                 fix_query_key_value_ordering(self.context_model, checkpoint_version)
@@ -220,8 +223,7 @@ class PretrainedBertModel(MegatronModule):
         args = get_args()
         tokenizer = get_tokenizer()
         self.pad_id = tokenizer.pad
-        self.pool_type = args.pool_type
-        self.projection_dim = args.projection_dim
+        self.biencoder_projection_dim = args.biencoder_projection_dim
         self.parallel_output = parallel_output
         init_method = init_method_normal(args.init_method_std)
         scaled_init_method = scaled_init_method_normal(
@@ -234,9 +236,9 @@ class PretrainedBertModel(MegatronModule):
             init_method=init_method,
             scaled_init_method=scaled_init_method)

-        if args.projection_dim > 0:
+        if args.biencoder_projection_dim > 0:
             self.projection_enc = get_linear_layer(args.hidden_size,
-                                                   args.projection_dim,
+                                                   args.biencoder_projection_dim,
                                                    init_method)
             self._projection_enc_key = 'projection_enc'
@@ -253,22 +255,14 @@ class PretrainedBertModel(MegatronModule):
         # This mask will be used in average-pooling and max-pooling
         pool_mask = (input_ids == self.pad_id).unsqueeze(2)

         # Taking the representation of the [CLS] token of BERT
-        if self.pool_type == "cls-token":
-            pooled_output = lm_output[:, 0, :]
-        elif self.pool_type == "avg": # Average Pooling
-            pooled_output = lm_output.masked_fill(pool_mask, 0)
-            pooled_output = pooled_output.sum(1) / (pool_mask.size(1) \
-                - pool_mask.float().sum(1))
-        elif self.pool_type == "max": # Max-Pooling
-            pooled_output = lm_output.masked_fill(pool_mask, -1000)
-            pooled_output = torch.max(pooled_output, 1)[0]
+        pooled_output = lm_output[:, 0, :]

         # Converting to float16 dtype
         pooled_output = pooled_output.to(lm_output.dtype)

         # Output.
-        if self.projection_dim:
+        if self.biencoder_projection_dim:
             pooled_output = self.projection_enc(pooled_output)

         return pooled_output
@@ -283,7 +277,7 @@ class PretrainedBertModel(MegatronModule):
             = self.language_model.state_dict_for_save_checkpoint(
                 destination, prefix, keep_vars)

-        if self.projection_dim > 0:
+        if self.biencoder_projection_dim > 0:
             state_dict_[self._projection_enc_key] = \
                 self.projection_enc.state_dict(destination, prefix, keep_vars)
@@ -295,7 +289,7 @@ class PretrainedBertModel(MegatronModule):
         self.language_model.load_state_dict(
             state_dict[self._language_model_key], strict=strict)

-        if self.projection_dim > 0:
+        if self.biencoder_projection_dim > 0:
             print_rank_0("loading projection head weights")
             self.projection_enc.load_state_dict(
                 state_dict[self._projection_enc_key], strict=strict)
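With --pool-type removed above, the bi-encoder always pools the [CLS] token and then applies the optional projection head. Below is a minimal PyTorch sketch of that retained path; the tensor names and sizes are illustrative, not the Megatron module:

# Sketch of CLS-token pooling followed by the optional projection head.
import torch
import torch.nn as nn

hidden_size, projection_dim = 768, 128          # assumed sizes
batch, seq_len = 4, 32

lm_output = torch.randn(batch, seq_len, hidden_size)      # stand-in for BERT output
projection_enc = nn.Linear(hidden_size, projection_dim)   # built when projection_dim > 0

# The representation is always the [CLS] (first) token ...
pooled_output = lm_output[:, 0, :]
# ... cast back to the LM output dtype (relevant under fp16 training) ...
pooled_output = pooled_output.to(lm_output.dtype)
# ... and projected only when --biencoder-projection-dim > 0.
if projection_dim > 0:
    pooled_output = projection_enc(pooled_output)

print(pooled_output.shape)  # torch.Size([4, 128])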
@@ -36,7 +36,8 @@ def pretrain_ict_model_provider():
     model = biencoder_model_provider(
         only_context_model=False,
         only_query_model=False,
-        shared_query_context_model=args.shared_query_context_model)
+        biencoder_shared_query_context_model=\
+            args.biencoder_shared_query_context_model)

     return model


 def get_group_world_size_rank():
@@ -120,7 +121,7 @@ def forward_step(data_iterator, model, input_tensor):
         return torch.cuda.FloatTensor([sum([int(i in sorted_indices[i, :k]) \
             for i in range(global_batch_size)]) / global_batch_size])

-    topk_accs = [topk_accuracy(int(k)) for k in args.report_topk_accuracies]
+    topk_accs = [topk_accuracy(int(k)) for k in args.retriever_report_topk_accuracies]

     labels = torch.arange(global_batch_size).long().cuda()
     loss = F.nll_loss(softmax_scores, labels, reduction='mean')
@@ -131,7 +132,7 @@ def forward_step(data_iterator, model, input_tensor):
     # create stats_dict with retrieval loss and all specified top-k accuracies
     topk_acc_dict = {'top{}_acc'.format(k): v * 100 for k, v in \
-        zip(args.report_topk_accuracies, reduced_losses[1:])}
+        zip(args.retriever_report_topk_accuracies, reduced_losses[1:])}
     stats_dict = dict(loss=reduced_losses[0], **topk_acc_dict)

     return loss, stats_dict
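The renamed --retriever-report-topk-accuracies feeds the in-batch top-k accuracy computed in forward_step above. The following CPU-only sketch reproduces that metric under assumed shapes; the real code works on CUDA tensors and averages the results across data-parallel ranks, and the names below are illustrative:

# Sketch of in-batch retrieval loss and top-k accuracy: each query's positive
# context is the one at the same batch index.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
global_batch_size, dim = 8, 16
query_logits = torch.randn(global_batch_size, dim)     # stand-in query embeddings
context_logits = torch.randn(global_batch_size, dim)   # stand-in context embeddings

# Score every query against every in-batch context; row i's positive is column i.
retrieval_scores = torch.matmul(query_logits, context_logits.t())
softmax_scores = F.log_softmax(retrieval_scores, dim=1)
_, sorted_indices = torch.sort(softmax_scores, descending=True, dim=1)

def topk_accuracy(k):
    # Fraction of queries whose own context appears among their top-k results.
    return sum(int(i in sorted_indices[i, :k])
               for i in range(global_batch_size)) / global_batch_size

report_topk_accuracies = [1, 5]   # e.g. values passed to the renamed flag
topk_accs = [topk_accuracy(k) for k in report_topk_accuracies]

labels = torch.arange(global_batch_size)
loss = F.nll_loss(softmax_scores, labels, reduction='mean')
print(topk_accs, float(loss))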