Commit 1afe3541 authored by Jared Casper

Merge branch 'state_dict_fix' into 'main'

Update state_dict arguments for recent PyTorch versions.

See merge request ADLR/megatron-lm!432
parents 8b686288 928a200c
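Background for the change: newer PyTorch releases discourage passing `destination`, `prefix`, and `keep_vars` to `torch.nn.Module.state_dict()` positionally, so this merge drops the unused `destination` parameter from Megatron's checkpoint helpers and forwards the remaining arguments by keyword. A minimal sketch of the calling convention adopted throughout the diff below; `ToyHead` and `ToyModel` are hypothetical stand-ins, not classes from this repository:

    # Minimal sketch of the pattern applied in this merge request.
    # ToyHead / ToyModel are illustrative only, not Megatron classes.
    import torch

    class ToyHead(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.dense = torch.nn.Linear(4, 4)

    class ToyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.head = ToyHead()
            self._head_key = 'head'

        def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
            # Forward prefix/keep_vars by keyword and let state_dict()
            # allocate its own destination dict.
            state_dict_ = {}
            state_dict_[self._head_key] = self.head.state_dict(
                prefix=prefix, keep_vars=keep_vars)
            return state_dict_

    if __name__ == '__main__':
        model = ToyModel()
        print(model.state_dict_for_save_checkpoint().keys())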
@@ -208,26 +208,25 @@ class BertModel(MegatronModule):
         return lm_output

-    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
-                                        keep_vars=False):
+    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
         """For easy load when model is combined with other heads,
         add an extra key."""
         state_dict_ = {}
         state_dict_[self._language_model_key] \
-            = self.language_model.state_dict_for_save_checkpoint(
-                destination, prefix, keep_vars)
+            = self.language_model.state_dict_for_save_checkpoint(prefix=prefix,
+                                                                 keep_vars=keep_vars)
         if self.post_process:
             state_dict_[self._lm_head_key] \
-                = self.lm_head.state_dict_for_save_checkpoint(
-                    destination, prefix, keep_vars)
+                = self.lm_head.state_dict_for_save_checkpoint(prefix=prefix,
+                                                              keep_vars=keep_vars)
         if self.post_process and self.add_binary_head:
             state_dict_[self._binary_head_key] \
-                = self.binary_head.state_dict(destination, prefix, keep_vars)
+                = self.binary_head.state_dict(prefix=prefix, keep_vars=keep_vars)
         # Save word_embeddings.
         if self.post_process and not self.pre_process:
             state_dict_[self._word_embeddings_for_head_key] \
-                = self.word_embeddings.state_dict(destination, prefix, keep_vars)
+                = self.word_embeddings.state_dict(prefix=prefix, keep_vars=keep_vars)
         return state_dict_

     def load_state_dict(self, state_dict, strict=True):
@@ -139,25 +139,23 @@ class BiEncoderModel(MegatronModule):
                                    token_types)
         return logits

-    def state_dict_for_save_checkpoint(self, destination=None, \
-                                       prefix='', keep_vars=False):
+    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
         """Save dict with state dicts of each of the models."""
         state_dict_ = {}
         if self.biencoder_shared_query_context_model:
             state_dict_[self._model_key] = \
-                self.model.state_dict_for_save_checkpoint(destination,
-                                                          prefix,
-                                                          keep_vars)
+                self.model.state_dict_for_save_checkpoint(
+                    prefix=prefix, keep_vars=keep_vars)
         else:
             if self.use_query_model:
                 state_dict_[self._query_key] = \
                     self.query_model.state_dict_for_save_checkpoint(
-                        destination, prefix, keep_vars)
+                        prefix=prefix, keep_vars=keep_vars)
             if self.use_context_model:
                 state_dict_[self._context_key] = \
                     self.context_model.state_dict_for_save_checkpoint(
-                        destination, prefix, keep_vars)
+                        prefix=prefix, keep_vars=keep_vars)
         return state_dict_
@@ -302,19 +300,19 @@ class PretrainedBertModel(MegatronModule):
         return pooled_output

-    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
-                                       keep_vars=False):
+    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
         """For easy load when model is combined with other heads,
         add an extra key."""
         state_dict_ = {}
         state_dict_[self._language_model_key] \
             = self.language_model.state_dict_for_save_checkpoint(
-                destination, prefix, keep_vars)
+                prefix=prefix, keep_vars=keep_vars)
         if self.biencoder_projection_dim > 0:
             state_dict_[self._projection_enc_key] = \
-                self.projection_enc.state_dict(destination, prefix, keep_vars)
+                self.projection_enc.state_dict(prefix=prefix,
+                                               keep_vars=keep_vars)
         return state_dict_
@@ -89,19 +89,17 @@ class Classification(MegatronModule):
             return classification_logits
         return lm_output

-    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
-                                       keep_vars=False):
+    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
         """For easy load when model is combined with other heads,
         add an extra key."""
         state_dict_ = {}
         state_dict_[self._language_model_key] \
-            = self.language_model.state_dict_for_save_checkpoint(
-                destination, prefix, keep_vars)
+            = self.language_model.state_dict_for_save_checkpoint(prefix=prefix,
+                                                                 keep_vars=keep_vars)
         if self.post_process:
             state_dict_[self._classification_head_key] \
-                = self.classification_head.state_dict(
-                    destination, prefix, keep_vars)
+                = self.classification_head.state_dict(prefix=prefix, keep_vars=keep_vars)
         return state_dict_

     def load_state_dict(self, state_dict, strict=True):
@@ -71,14 +71,13 @@ class DistributedDataParallelBase(MegatronModule, ABC):
         return self.module(*inputs, **kwargs)

-    def state_dict(self, destination=None, prefix='', keep_vars=False):
-        return self.module.state_dict(destination, prefix, keep_vars)
+    def state_dict(self, prefix='', keep_vars=False):
+        return self.module.state_dict(prefix=prefix, keep_vars=keep_vars)

-    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
-                                       keep_vars=False):
-        return self.module.state_dict_for_save_checkpoint(destination, prefix,
-                                                          keep_vars)
+    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
+        return self.module.state_dict_for_save_checkpoint(prefix=prefix,
+                                                          keep_vars=keep_vars)

     def load_state_dict(self, state_dict, strict=True):
@@ -105,17 +105,17 @@ class GPTModel(MegatronModule):
         else:
             return lm_output

-    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
-                                       keep_vars=False):
+    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
         state_dict_ = {}
         state_dict_[self._language_model_key] \
             = self.language_model.state_dict_for_save_checkpoint(
-                destination, prefix, keep_vars)
+                prefix=prefix, keep_vars=keep_vars)
         # Save word_embeddings.
         if self.post_process and not self.pre_process:
             state_dict_[self._word_embeddings_for_head_key] \
-                = self.word_embeddings.state_dict(destination, prefix, keep_vars)
+                = self.word_embeddings.state_dict(prefix=prefix,
+                                                  keep_vars=keep_vars)
         return state_dict_

     def load_state_dict(self, state_dict, strict=True):
@@ -243,20 +243,20 @@ class Embedding(MegatronModule):
         return embeddings

-    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
-                                       keep_vars=False):
+    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
         """For easy load."""
         state_dict_ = {}
         state_dict_[self._word_embeddings_key] \
-            = self.word_embeddings.state_dict(destination, prefix, keep_vars)
+            = self.word_embeddings.state_dict(prefix=prefix,
+                                              keep_vars=keep_vars)
         state_dict_[self._position_embeddings_key] \
-            = self.position_embeddings.state_dict(
-                destination, prefix, keep_vars)
+            = self.position_embeddings.state_dict(prefix=prefix,
+                                                  keep_vars=keep_vars)
         if self.num_tokentypes > 0:
             state_dict_[self._tokentype_embeddings_key] \
-                = self.tokentype_embeddings.state_dict(
-                    destination, prefix, keep_vars)
+                = self.tokentype_embeddings.state_dict(prefix=prefix,
+                                                       keep_vars=keep_vars)
         return state_dict_
@@ -478,28 +478,27 @@ class TransformerLanguageModel(MegatronModule):
         else:
             return decoder_output, encoder_output

-    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
-                                       keep_vars=False):
+    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
         """For easy load."""
         state_dict_ = {}
         if self.pre_process:
             state_dict_[self._embedding_key] \
-                = self.embedding.state_dict_for_save_checkpoint(
-                    destination, prefix, keep_vars)
+                = self.embedding.state_dict_for_save_checkpoint(prefix=prefix,
+                                                                keep_vars=keep_vars)
         if self.add_encoder:
             state_dict_[self._encoder_key] \
-                = self.encoder.state_dict_for_save_checkpoint(
-                    destination, prefix, keep_vars)
+                = self.encoder.state_dict_for_save_checkpoint(prefix=prefix,
+                                                              keep_vars=keep_vars)
         if self.post_process:
             if self.add_pooler:
                 state_dict_[self._pooler_key] \
-                    = self.pooler.state_dict_for_save_checkpoint(
-                        destination, prefix, keep_vars)
+                    = self.pooler.state_dict_for_save_checkpoint(prefix=prefix,
+                                                                 keep_vars=keep_vars)
         if self.add_decoder:
             state_dict_[self._decoder_key] \
-                = self.decoder.state_dict_for_save_checkpoint(
-                    destination, prefix, keep_vars)
+                = self.decoder.state_dict_for_save_checkpoint(prefix=prefix,
+                                                              keep_vars=keep_vars)
         return state_dict_
@@ -43,11 +43,10 @@ class MegatronModule(torch.nn.Module):
         self.share_word_embeddings = share_word_embeddings

-    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
-                                       keep_vars=False):
+    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
         """Use this function to override the state dict for
         saving checkpoints."""
-        return self.state_dict(destination, prefix, keep_vars)
+        return self.state_dict(prefix=prefix, keep_vars=keep_vars)

     def word_embeddings_weight(self):
@@ -198,14 +197,13 @@ class Float16Module(MegatronModule):
         return outputs

-    def state_dict(self, destination=None, prefix='', keep_vars=False):
-        return self.module.state_dict(destination, prefix, keep_vars)
+    def state_dict(self, prefix='', keep_vars=False):
+        return self.module.state_dict(prefix=prefix, keep_vars=keep_vars)

-    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
-                                       keep_vars=False):
-        return self.module.state_dict_for_save_checkpoint(destination, prefix,
-                                                          keep_vars)
+    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
+        return self.module.state_dict_for_save_checkpoint(prefix=prefix,
+                                                          keep_vars=keep_vars)

     def load_state_dict(self, state_dict, strict=True):
@@ -100,19 +100,17 @@ class MultipleChoice(MegatronModule):
             return multichoice_logits
         return lm_output

-    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
-                                       keep_vars=False):
+    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
         """For easy load when model is combined with other heads,
         add an extra key."""
         state_dict_ = {}
         state_dict_[self._language_model_key] \
-            = self.language_model.state_dict_for_save_checkpoint(
-                destination, prefix, keep_vars)
+            = self.language_model.state_dict_for_save_checkpoint(prefix=prefix,
+                                                                 keep_vars=keep_vars)
         if self.post_process:
             state_dict_[self._multichoice_head_key] \
-                = self.multichoice_head.state_dict(
-                    destination, prefix, keep_vars)
+                = self.multichoice_head.state_dict(prefix=prefix, keep_vars=keep_vars)
         return state_dict_

     def load_state_dict(self, state_dict, strict=True):
@@ -87,18 +87,18 @@ class ICTBertModel(MegatronModule):
         else:
             raise ValueError("Cannot embed block without block model.")

-    def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
+    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
         """Save dict with state dicts of each of the models."""
         state_dict_ = {}
         if self.use_query_model:
             state_dict_[self._query_key] \
                 = self.query_model.state_dict_for_save_checkpoint(
-                    destination, prefix, keep_vars)
+                    prefix=prefix, keep_vars=keep_vars)
         if self.use_block_model:
             state_dict_[self._block_key] \
                 = self.block_model.state_dict_for_save_checkpoint(
-                    destination, prefix, keep_vars)
+                    prefix=prefix, keep_vars=keep_vars)
         return state_dict_
@@ -181,17 +181,17 @@ class IREncoderBertModel(MegatronModule):
         ict_logits = self.ict_head(pooled_output)
         return ict_logits, None

-    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
-                                       keep_vars=False):
+    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
         """For easy load when model is combined with other heads,
         add an extra key."""
         state_dict_ = {}
         state_dict_[self._language_model_key] \
-            = self.language_model.state_dict_for_save_checkpoint(
-                destination, prefix, keep_vars)
+            = self.language_model.state_dict_for_save_checkpoint(prefix=prefix,
+                                                                 keep_vars=keep_vars)
         state_dict_[self._ict_head_key] \
-            = self.ict_head.state_dict(destination, prefix, keep_vars)
+            = self.ict_head.state_dict(prefix=prefix,
+                                       keep_vars=keep_vars)
         return state_dict_

     def load_state_dict(self, state_dict, strict=True):
@@ -178,23 +178,23 @@ class T5Model(MegatronModule):
             encoder_output = lm_output
             return encoder_output

-    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
-                                       keep_vars=False):
+    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
         """For easy load when model is combined with other heads,
         add an extra key."""
         state_dict_ = {}
         state_dict_[self._language_model_key] \
-            = self.language_model.state_dict_for_save_checkpoint(
-                destination, prefix, keep_vars)
+            = self.language_model.state_dict_for_save_checkpoint(prefix=prefix,
+                                                                 keep_vars=keep_vars)
         if self.post_process and self.add_decoder:
             state_dict_[self._lm_head_key] \
-                = self.lm_head.state_dict_for_save_checkpoint(
-                    destination, prefix, keep_vars)
+                = self.lm_head.state_dict_for_save_checkpoint(prefix=prefix,
+                                                              keep_vars=keep_vars)
         # Save word_embeddings.
         if self.post_process and not self.pre_process and self.add_decoder:
             state_dict_[self._word_embeddings_for_head_key] \
-                = self.word_embeddings.state_dict(destination, prefix, keep_vars)
+                = self.word_embeddings.state_dict(prefix=prefix,
+                                                  keep_vars=keep_vars)
         return state_dict_

     def load_state_dict(self, state_dict, strict=True):