Commit 1afe3541 authored by Jared Casper

Merge branch 'state_dict_fix' into 'main'

Update state_dict arguments for recent PyTorch versions.

See merge request ADLR/megatron-lm!432
parents 8b686288 928a200c
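
The diff below drops the `destination` positional argument from the various `state_dict`-style overrides and passes `prefix` and `keep_vars` by keyword, matching the signature of `torch.nn.Module.state_dict` in recent PyTorch releases. As a rough, hedged sketch of the before/after pattern (ToyHead/ToyModel are illustrative stand-ins, not Megatron-LM classes):

# Minimal sketch of the calling-convention change applied throughout this MR.
# ToyHead/ToyModel are hypothetical; only the argument style mirrors the diff.
import torch

class ToyHead(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.dense = torch.nn.Linear(4, 2)

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.head = ToyHead()

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        # Before: self.head.state_dict(destination, prefix, keep_vars)
        # After:  drop `destination` and pass the remaining args by keyword.
        state_dict_ = {}
        state_dict_['head'] = self.head.state_dict(prefix=prefix,
                                                   keep_vars=keep_vars)
        return state_dict_

print(list(ToyModel().state_dict_for_save_checkpoint(prefix='model.').keys()))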
@@ -208,26 +208,25 @@ class BertModel(MegatronModule):
 
         return lm_output
 
-    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
-                                       keep_vars=False):
+    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
         """For easy load when model is combined with other heads,
         add an extra key."""
 
         state_dict_ = {}
         state_dict_[self._language_model_key] \
-            = self.language_model.state_dict_for_save_checkpoint(
-                destination, prefix, keep_vars)
+            = self.language_model.state_dict_for_save_checkpoint(prefix=prefix,
+                                                                 keep_vars=keep_vars)
         if self.post_process:
             state_dict_[self._lm_head_key] \
-                = self.lm_head.state_dict_for_save_checkpoint(
-                    destination, prefix, keep_vars)
+                = self.lm_head.state_dict_for_save_checkpoint(prefix=prefix,
+                                                              keep_vars=keep_vars)
         if self.post_process and self.add_binary_head:
             state_dict_[self._binary_head_key] \
-                = self.binary_head.state_dict(destination, prefix, keep_vars)
+                = self.binary_head.state_dict(prefix=prefix, keep_vars=keep_vars)
         # Save word_embeddings.
         if self.post_process and not self.pre_process:
             state_dict_[self._word_embeddings_for_head_key] \
-                = self.word_embeddings.state_dict(destination, prefix, keep_vars)
+                = self.word_embeddings.state_dict(prefix=prefix, keep_vars=keep_vars)
         return state_dict_
 
     def load_state_dict(self, state_dict, strict=True):
@@ -139,25 +139,23 @@ class BiEncoderModel(MegatronModule):
                                        token_types)
         return logits
 
-    def state_dict_for_save_checkpoint(self, destination=None, \
-        prefix='', keep_vars=False):
+    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
         """Save dict with state dicts of each of the models."""
 
         state_dict_ = {}
 
         if self.biencoder_shared_query_context_model:
             state_dict_[self._model_key] = \
-                self.model.state_dict_for_save_checkpoint(destination,
-                                                          prefix,
-                                                          keep_vars)
+                self.model.state_dict_for_save_checkpoint(
+                    prefix=prefix, keep_vars=keep_vars)
         else:
             if self.use_query_model:
                 state_dict_[self._query_key] = \
                     self.query_model.state_dict_for_save_checkpoint(
-                        destination, prefix, keep_vars)
+                        prefix=prefix, keep_vars=keep_vars)
 
             if self.use_context_model:
                 state_dict_[self._context_key] = \
                     self.context_model.state_dict_for_save_checkpoint(
-                        destination, prefix, keep_vars)
+                        prefix=prefix, keep_vars=keep_vars)
 
         return state_dict_
@@ -302,19 +300,19 @@ class PretrainedBertModel(MegatronModule):
 
         return pooled_output
 
-    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
-                                       keep_vars=False):
+    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
         """For easy load when model is combined with other heads,
         add an extra key."""
 
         state_dict_ = {}
         state_dict_[self._language_model_key] \
             = self.language_model.state_dict_for_save_checkpoint(
-                destination, prefix, keep_vars)
+                prefix=prefix, keep_vars=keep_vars)
 
         if self.biencoder_projection_dim > 0:
             state_dict_[self._projection_enc_key] = \
-                self.projection_enc.state_dict(destination, prefix, keep_vars)
+                self.projection_enc.state_dict(prefix=prefix,
+                                               keep_vars=keep_vars)
 
         return state_dict_
@@ -89,19 +89,17 @@ class Classification(MegatronModule):
             return classification_logits
         return lm_output
 
-    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
-                                       keep_vars=False):
+    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
         """For easy load when model is combined with other heads,
         add an extra key."""
 
         state_dict_ = {}
         state_dict_[self._language_model_key] \
-            = self.language_model.state_dict_for_save_checkpoint(
-                destination, prefix, keep_vars)
+            = self.language_model.state_dict_for_save_checkpoint(prefix=prefix,
+                                                                 keep_vars=keep_vars)
         if self.post_process:
             state_dict_[self._classification_head_key] \
-                = self.classification_head.state_dict(
-                    destination, prefix, keep_vars)
+                = self.classification_head.state_dict(prefix=prefix, keep_vars=keep_vars)
         return state_dict_
 
     def load_state_dict(self, state_dict, strict=True):
@@ -71,14 +71,13 @@ class DistributedDataParallelBase(MegatronModule, ABC):
 
         return self.module(*inputs, **kwargs)
 
-    def state_dict(self, destination=None, prefix='', keep_vars=False):
-        return self.module.state_dict(destination, prefix, keep_vars)
+    def state_dict(self, prefix='', keep_vars=False):
+        return self.module.state_dict(prefix=prefix, keep_vars=keep_vars)
 
-    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
-                                       keep_vars=False):
-        return self.module.state_dict_for_save_checkpoint(destination, prefix,
-                                                           keep_vars)
+    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
+        return self.module.state_dict_for_save_checkpoint(prefix=prefix,
+                                                          keep_vars=keep_vars)
 
     def load_state_dict(self, state_dict, strict=True):
@@ -105,17 +105,17 @@ class GPTModel(MegatronModule):
         else:
             return lm_output
 
-    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
-                                       keep_vars=False):
+    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
 
         state_dict_ = {}
         state_dict_[self._language_model_key] \
             = self.language_model.state_dict_for_save_checkpoint(
-                destination, prefix, keep_vars)
+                prefix=prefix, keep_vars=keep_vars)
         # Save word_embeddings.
         if self.post_process and not self.pre_process:
             state_dict_[self._word_embeddings_for_head_key] \
-                = self.word_embeddings.state_dict(destination, prefix, keep_vars)
+                = self.word_embeddings.state_dict(prefix=prefix,
+                                                  keep_vars=keep_vars)
         return state_dict_
 
     def load_state_dict(self, state_dict, strict=True):
@@ -243,20 +243,20 @@ class Embedding(MegatronModule):
 
         return embeddings
 
-    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
-                                       keep_vars=False):
+    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
         """For easy load."""
 
         state_dict_ = {}
         state_dict_[self._word_embeddings_key] \
-            = self.word_embeddings.state_dict(destination, prefix, keep_vars)
+            = self.word_embeddings.state_dict(prefix=prefix,
+                                              keep_vars=keep_vars)
         state_dict_[self._position_embeddings_key] \
-            = self.position_embeddings.state_dict(
-                destination, prefix, keep_vars)
+            = self.position_embeddings.state_dict(prefix=prefix,
+                                                  keep_vars=keep_vars)
         if self.num_tokentypes > 0:
             state_dict_[self._tokentype_embeddings_key] \
-                = self.tokentype_embeddings.state_dict(
-                    destination, prefix, keep_vars)
+                = self.tokentype_embeddings.state_dict(prefix=prefix,
+                                                       keep_vars=keep_vars)
 
         return state_dict_
@@ -478,28 +478,27 @@ class TransformerLanguageModel(MegatronModule):
         else:
             return decoder_output, encoder_output
 
-    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
-                                       keep_vars=False):
+    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
         """For easy load."""
 
         state_dict_ = {}
         if self.pre_process:
             state_dict_[self._embedding_key] \
-                = self.embedding.state_dict_for_save_checkpoint(
-                    destination, prefix, keep_vars)
+                = self.embedding.state_dict_for_save_checkpoint(prefix=prefix,
+                                                                keep_vars=keep_vars)
         if self.add_encoder:
             state_dict_[self._encoder_key] \
-                = self.encoder.state_dict_for_save_checkpoint(
-                    destination, prefix, keep_vars)
+                = self.encoder.state_dict_for_save_checkpoint(prefix=prefix,
+                                                              keep_vars=keep_vars)
         if self.post_process:
             if self.add_pooler:
                 state_dict_[self._pooler_key] \
-                    = self.pooler.state_dict_for_save_checkpoint(
-                        destination, prefix, keep_vars)
+                    = self.pooler.state_dict_for_save_checkpoint(prefix=prefix,
+                                                                 keep_vars=keep_vars)
         if self.add_decoder:
             state_dict_[self._decoder_key] \
-                = self.decoder.state_dict_for_save_checkpoint(
-                    destination, prefix, keep_vars)
+                = self.decoder.state_dict_for_save_checkpoint(prefix=prefix,
+                                                              keep_vars=keep_vars)
 
         return state_dict_
@@ -43,11 +43,10 @@ class MegatronModule(torch.nn.Module):
         self.share_word_embeddings = share_word_embeddings
 
-    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
-                                       keep_vars=False):
+    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
         """Use this function to override the state dict for
         saving checkpoints."""
-        return self.state_dict(destination, prefix, keep_vars)
+        return self.state_dict(prefix=prefix, keep_vars=keep_vars)
 
     def word_embeddings_weight(self):
@@ -198,14 +197,13 @@ class Float16Module(MegatronModule):
 
         return outputs
 
-    def state_dict(self, destination=None, prefix='', keep_vars=False):
-        return self.module.state_dict(destination, prefix, keep_vars)
+    def state_dict(self, prefix='', keep_vars=False):
+        return self.module.state_dict(prefix=prefix, keep_vars=keep_vars)
 
-    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
-                                       keep_vars=False):
-        return self.module.state_dict_for_save_checkpoint(destination, prefix,
-                                                           keep_vars)
+    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
+        return self.module.state_dict_for_save_checkpoint(prefix=prefix,
+                                                          keep_vars=keep_vars)
 
     def load_state_dict(self, state_dict, strict=True):
@@ -100,19 +100,17 @@ class MultipleChoice(MegatronModule):
             return multichoice_logits
         return lm_output
 
-    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
-                                       keep_vars=False):
+    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
         """For easy load when model is combined with other heads,
         add an extra key."""
 
         state_dict_ = {}
         state_dict_[self._language_model_key] \
-            = self.language_model.state_dict_for_save_checkpoint(
-                destination, prefix, keep_vars)
+            = self.language_model.state_dict_for_save_checkpoint(prefix=prefix,
+                                                                 keep_vars=keep_vars)
         if self.post_process:
             state_dict_[self._multichoice_head_key] \
-                = self.multichoice_head.state_dict(
-                    destination, prefix, keep_vars)
+                = self.multichoice_head.state_dict(prefix=prefix, keep_vars=keep_vars)
         return state_dict_
 
     def load_state_dict(self, state_dict, strict=True):
@@ -87,18 +87,18 @@ class ICTBertModel(MegatronModule):
         else:
             raise ValueError("Cannot embed block without block model.")
 
-    def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
+    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
         """Save dict with state dicts of each of the models."""
         state_dict_ = {}
 
         if self.use_query_model:
             state_dict_[self._query_key] \
                 = self.query_model.state_dict_for_save_checkpoint(
-                    destination, prefix, keep_vars)
+                    prefix=prefix, keep_vars=keep_vars)
 
         if self.use_block_model:
             state_dict_[self._block_key] \
                 = self.block_model.state_dict_for_save_checkpoint(
-                    destination, prefix, keep_vars)
+                    prefix=prefix, keep_vars=keep_vars)
 
         return state_dict_
@@ -181,17 +181,17 @@ class IREncoderBertModel(MegatronModule):
         ict_logits = self.ict_head(pooled_output)
         return ict_logits, None
 
-    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
-                                       keep_vars=False):
+    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
         """For easy load when model is combined with other heads,
         add an extra key."""
 
         state_dict_ = {}
         state_dict_[self._language_model_key] \
-            = self.language_model.state_dict_for_save_checkpoint(
-                destination, prefix, keep_vars)
+            = self.language_model.state_dict_for_save_checkpoint(prefix=prefix,
+                                                                 keep_vars=keep_vars)
         state_dict_[self._ict_head_key] \
-            = self.ict_head.state_dict(destination, prefix, keep_vars)
+            = self.ict_head.state_dict(prefix=prefix,
+                                       keep_vars=keep_vars)
         return state_dict_
 
     def load_state_dict(self, state_dict, strict=True):
@@ -178,23 +178,23 @@ class T5Model(MegatronModule):
             encoder_output = lm_output
             return encoder_output
 
-    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
-                                       keep_vars=False):
+    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
         """For easy load when model is combined with other heads,
         add an extra key."""
 
         state_dict_ = {}
         state_dict_[self._language_model_key] \
-            = self.language_model.state_dict_for_save_checkpoint(
-                destination, prefix, keep_vars)
+            = self.language_model.state_dict_for_save_checkpoint(prefix=prefix,
+                                                                 keep_vars=keep_vars)
         if self.post_process and self.add_decoder:
             state_dict_[self._lm_head_key] \
-                = self.lm_head.state_dict_for_save_checkpoint(
-                    destination, prefix, keep_vars)
+                = self.lm_head.state_dict_for_save_checkpoint(prefix=prefix,
+                                                              keep_vars=keep_vars)
         # Save word_embeddings.
         if self.post_process and not self.pre_process and self.add_decoder:
             state_dict_[self._word_embeddings_for_head_key] \
-                = self.word_embeddings.state_dict(destination, prefix, keep_vars)
+                = self.word_embeddings.state_dict(prefix=prefix,
+                                                  keep_vars=keep_vars)
         return state_dict_
 
     def load_state_dict(self, state_dict, strict=True):
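
The wrapper classes touched above (DistributedDataParallelBase, Float16Module) keep the same delegation shape, only now forwarding by keyword. A rough standalone sketch of that pattern follows; the Wrapper class and the getattr fallback are assumptions for illustration, not the Megatron-LM implementation:

# Illustrative wrapper showing keyword-only forwarding of state_dict calls.
import torch

class Wrapper(torch.nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

    def state_dict(self, prefix='', keep_vars=False):
        # Forward by keyword; `destination` is no longer accepted or passed on.
        return self.module.state_dict(prefix=prefix, keep_vars=keep_vars)

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        # Fall back to the plain state_dict when the wrapped module has no
        # checkpoint-specific override (an assumption for this sketch).
        fn = getattr(self.module, 'state_dict_for_save_checkpoint',
                     self.module.state_dict)
        return fn(prefix=prefix, keep_vars=keep_vars)

wrapped = Wrapper(torch.nn.Linear(4, 2))
print(list(wrapped.state_dict_for_save_checkpoint(prefix='module.').keys()))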