import os
import torch

from megatron import get_args, print_rank_0
from megatron.checkpointing import get_checkpoint_tracker_filename, \
    get_checkpoint_name
from megatron.module import MegatronModule
from megatron import mpu, get_tokenizer
from megatron.model.bert_model import bert_attention_mask_func
from megatron.model.bert_model import bert_extended_attention_mask
from megatron.model.bert_model import bert_position_ids
from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal


def biencoder_model_provider(only_query_model=False,
                             only_context_model=False,
                             shared_query_context_model=False):
    """Build the model."""
    args = get_args()

    assert mpu.get_tensor_model_parallel_world_size() == 1 and \
        mpu.get_pipeline_model_parallel_world_size() == 1, \
        "Model parallel size > 1 not supported for ICT"

    print_rank_0('building BiEncoderModel...')

    # simpler to just keep using 2 tokentypes since
    # the LM we initialize with has 2 tokentypes
    model = BiEncoderModel(
        num_tokentypes=2,
        parallel_output=True,
        only_query_model=only_query_model,
        only_context_model=only_context_model,
        shared_query_context_model=shared_query_context_model)

    return model


class BiEncoderModel(MegatronModule):
    """Bert-based module for Biencoder model."""

    def __init__(self,
                 num_tokentypes=1,
                 parallel_output=True,
                 only_query_model=False,
                 only_context_model=False,
                 shared_query_context_model=False):
        super(BiEncoderModel, self).__init__()
        args = get_args()

        bert_kwargs = dict(
            num_tokentypes=num_tokentypes,
            parallel_output=parallel_output)

        self.shared_query_context_model = shared_query_context_model
        assert not (only_context_model and only_query_model)
        self.use_context_model = not only_query_model
        self.use_query_model = not only_context_model
        self.projection_dim = args.projection_dim

        if self.shared_query_context_model:
            self.model = PretrainedBertModel(**bert_kwargs)
            self._model_key = 'shared_model'
            self.query_model, self.context_model = self.model, self.model
        else:
            if self.use_query_model:
                # this model embeds (pseudo-)queries - Embed_input in the paper
                self.query_model = PretrainedBertModel(**bert_kwargs)
                self._query_key = 'query_model'

            if self.use_context_model:
                # this model embeds evidence blocks - Embed_doc in the paper
                self.context_model = PretrainedBertModel(**bert_kwargs)
                self._context_key = 'context_model'

    def forward(self, query_tokens, query_attention_mask, query_types,
                context_tokens, context_attention_mask, context_types):
        """Run a forward pass for each of the models and
        return the respective embeddings."""

        if self.use_query_model:
            query_logits = self.embed_text(self.query_model,
                                           query_tokens,
                                           query_attention_mask,
                                           query_types)
        else:
            raise ValueError("Cannot embed query without the query model.")

        if self.use_context_model:
            context_logits = self.embed_text(self.context_model,
                                             context_tokens,
                                             context_attention_mask,
                                             context_types)
        else:
            raise ValueError("Cannot embed context without the context model.")

        return query_logits, context_logits

    @staticmethod
    def embed_text(model, tokens, attention_mask, token_types):
        """Embed a batch of tokens using the model."""
        logits = model(tokens, attention_mask, token_types)
        return logits

    def state_dict_for_save_checkpoint(self, destination=None,
                                       prefix='', keep_vars=False):
        """Save dict with state dicts of each of the models."""
        state_dict_ = {}
        if self.shared_query_context_model:
            state_dict_[self._model_key] = \
                self.model.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars)
        else:
            if self.use_query_model:
                state_dict_[self._query_key] = \
                    self.query_model.state_dict_for_save_checkpoint(
                        destination, prefix, keep_vars)

            if self.use_context_model:
                state_dict_[self._context_key] = \
                    self.context_model.state_dict_for_save_checkpoint(
                        destination, prefix, keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Load the state dicts of each of the models."""
        if self.shared_query_context_model:
            print_rank_0("Loading shared query-context model")
            self.model.load_state_dict(state_dict[self._model_key],
                                       strict=strict)
        else:
            if self.use_query_model:
                print_rank_0("Loading query model")
                self.query_model.load_state_dict(
                    state_dict[self._query_key], strict=strict)

            if self.use_context_model:
                print_rank_0("Loading context model")
                self.context_model.load_state_dict(
                    state_dict[self._context_key], strict=strict)

    def init_state_dict_from_bert(self):
        """Initialize the state from a pretrained BERT model
        on iteration zero of ICT pretraining."""
        args = get_args()

        if args.bert_load is None:
            print_rank_0("bert-load argument is None")
            return

        tracker_filename = get_checkpoint_tracker_filename(args.bert_load)
        if not os.path.isfile(tracker_filename):
            raise FileNotFoundError("Could not find BERT checkpoint")
        with open(tracker_filename, 'r') as f:
            iteration = int(f.read().strip())
            assert iteration > 0

        checkpoint_name = get_checkpoint_name(args.bert_load, iteration, False)
        if mpu.get_data_parallel_rank() == 0:
            print('global rank {} is loading BERT checkpoint {}'.format(
                torch.distributed.get_rank(), checkpoint_name))

        try:
            state_dict = torch.load(checkpoint_name, map_location='cpu')
        except BaseException:
            raise ValueError("Could not load BERT checkpoint")

        # load the LM state dict into each model
        model_dict = state_dict['model']['language_model']

        if self.shared_query_context_model:
            self.model.language_model.load_state_dict(model_dict)
        else:
            if self.use_query_model:
                self.query_model.language_model.load_state_dict(model_dict)
                # give each model the same projection head to begin with as well
                if self.projection_dim > 0:
                    query_proj_state_dict = \
                        self.state_dict_for_save_checkpoint()\
                        [self._query_key]['projection_enc']

            if self.use_context_model:
                self.context_model.language_model.load_state_dict(model_dict)
                if self.query_model is not None and self.projection_dim > 0:
                    self.context_model.projection_enc.load_state_dict(
                        query_proj_state_dict)


class PretrainedBertModel(MegatronModule):
    """BERT-based encoder for queries or contexts
    used for learned information retrieval."""

    def __init__(self, num_tokentypes=2, parallel_output=True):
        super(PretrainedBertModel, self).__init__()

        args = get_args()
        tokenizer = get_tokenizer()
        self.pad_id = tokenizer.pad
        self.pool_type = args.pool_type
        self.projection_dim = args.projection_dim
        self.parallel_output = parallel_output
        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(
            args.init_method_std, args.num_layers)

        self.language_model, self._language_model_key = get_language_model(
            attention_mask_func=bert_attention_mask_func,
            num_tokentypes=num_tokentypes,
            add_pooler=False,
            init_method=init_method,
            scaled_init_method=scaled_init_method)

        if args.projection_dim > 0:
            self.projection_enc = get_linear_layer(args.hidden_size,
                                                   args.projection_dim,
                                                   init_method)
            self._projection_enc_key = 'projection_enc'

    def forward(self, input_ids, attention_mask, tokentype_ids=None):
        # Use a simple broadcastable mask here rather than
        # bert_extended_attention_mask(attention_mask).
        extended_attention_mask = attention_mask.unsqueeze(1)
        position_ids = bert_position_ids(input_ids)

        lm_output = self.language_model(input_ids,
                                        position_ids,
                                        extended_attention_mask,
                                        tokentype_ids=tokentype_ids)

        # This mask will be used in average-pooling and max-pooling
        pool_mask = (input_ids == self.pad_id).unsqueeze(2)

        if self.pool_type == "cls-token":
            # Taking the representation of the [CLS] token of BERT
            pooled_output = lm_output[:, 0, :]
        elif self.pool_type == "avg":
            # Average pooling over non-pad tokens
            pooled_output = lm_output.masked_fill(pool_mask, 0)
            pooled_output = pooled_output.sum(1) / (
                pool_mask.size(1) - pool_mask.float().sum(1))
        elif self.pool_type == "max":
            # Max pooling over non-pad tokens
            pooled_output = lm_output.masked_fill(pool_mask, -1000)
            pooled_output = torch.max(pooled_output, 1)[0]

        # Convert back to the language model's dtype (e.g. float16)
        pooled_output = pooled_output.to(lm_output.dtype)

        # Output.
        if self.projection_dim:
            pooled_output = self.projection_enc(pooled_output)

        return pooled_output

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key."""
        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)

        if self.projection_dim > 0:
            state_dict_[self._projection_enc_key] = \
                self.projection_enc.state_dict(destination, prefix, keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""
        print_rank_0("loading BERT weights")
        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)

        if self.projection_dim > 0:
            print_rank_0("loading projection head weights")
            self.projection_enc.load_state_dict(
                state_dict[self._projection_enc_key], strict=strict)
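
# Usage sketch (not part of the original module): the classes above assume
# Megatron has already been initialized elsewhere, so that get_args(), mpu,
# and get_tokenizer() are available. The input tensors named below are
# hypothetical placeholders of shape [batch, seq_len].
#
#   model = biencoder_model_provider(shared_query_context_model=True)
#   model.init_state_dict_from_bert()  # warm-start encoders from --bert-load
#   query_logits, context_logits = model(
#       query_tokens, query_mask, query_types,
#       context_tokens, context_mask, context_types)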