Commit 41276b6c authored by Vijay Korthikanti
Browse files

Merge branch 'main' into nmt-main

parents a44360ed fc7f4f03
# coding=utf-8 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer based language model.""" """Transformer based language model."""
...@@ -243,20 +230,20 @@ class Embedding(MegatronModule): ...@@ -243,20 +230,20 @@ class Embedding(MegatronModule):
return embeddings return embeddings
def state_dict_for_save_checkpoint(self, destination=None, prefix='', def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
keep_vars=False):
"""For easy load.""" """For easy load."""
state_dict_ = {} state_dict_ = {}
state_dict_[self._word_embeddings_key] \ state_dict_[self._word_embeddings_key] \
= self.word_embeddings.state_dict(destination, prefix, keep_vars) = self.word_embeddings.state_dict(prefix=prefix,
keep_vars=keep_vars)
state_dict_[self._position_embeddings_key] \ state_dict_[self._position_embeddings_key] \
= self.position_embeddings.state_dict( = self.position_embeddings.state_dict(prefix=prefix,
destination, prefix, keep_vars) keep_vars=keep_vars)
if self.num_tokentypes > 0: if self.num_tokentypes > 0:
state_dict_[self._tokentype_embeddings_key] \ state_dict_[self._tokentype_embeddings_key] \
= self.tokentype_embeddings.state_dict( = self.tokentype_embeddings.state_dict(prefix=prefix,
destination, prefix, keep_vars) keep_vars=keep_vars)
return state_dict_ return state_dict_
...@@ -478,28 +465,27 @@ class TransformerLanguageModel(MegatronModule): ...@@ -478,28 +465,27 @@ class TransformerLanguageModel(MegatronModule):
else: else:
return decoder_output, encoder_output return decoder_output, encoder_output
def state_dict_for_save_checkpoint(self, destination=None, prefix='', def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
keep_vars=False):
"""For easy load.""" """For easy load."""
state_dict_ = {} state_dict_ = {}
if self.pre_process: if self.pre_process:
state_dict_[self._embedding_key] \ state_dict_[self._embedding_key] \
= self.embedding.state_dict_for_save_checkpoint( = self.embedding.state_dict_for_save_checkpoint(prefix=prefix,
destination, prefix, keep_vars) keep_vars=keep_vars)
if self.add_encoder: if self.add_encoder:
state_dict_[self._encoder_key] \ state_dict_[self._encoder_key] \
= self.encoder.state_dict_for_save_checkpoint( = self.encoder.state_dict_for_save_checkpoint(prefix=prefix,
destination, prefix, keep_vars) keep_vars=keep_vars)
if self.post_process: if self.post_process:
if self.add_pooler: if self.add_pooler:
state_dict_[self._pooler_key] \ state_dict_[self._pooler_key] \
= self.pooler.state_dict_for_save_checkpoint( = self.pooler.state_dict_for_save_checkpoint(prefix=prefix,
destination, prefix, keep_vars) keep_vars=keep_vars)
if self.add_decoder: if self.add_decoder:
state_dict_[self._decoder_key] \ state_dict_[self._decoder_key] \
= self.decoder.state_dict_for_save_checkpoint( = self.decoder.state_dict_for_save_checkpoint(prefix=prefix,
destination, prefix, keep_vars) keep_vars=keep_vars)
return state_dict_ return state_dict_
......
# coding=utf-8 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megatron Module""" """Megatron Module"""
...@@ -43,11 +30,10 @@ class MegatronModule(torch.nn.Module): ...@@ -43,11 +30,10 @@ class MegatronModule(torch.nn.Module):
self.share_word_embeddings = share_word_embeddings self.share_word_embeddings = share_word_embeddings
def state_dict_for_save_checkpoint(self, destination=None, prefix='', def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
keep_vars=False):
"""Use this function to override the state dict for """Use this function to override the state dict for
saving checkpoints.""" saving checkpoints."""
return self.state_dict(destination, prefix, keep_vars) return self.state_dict(prefix=prefix, keep_vars=keep_vars)
def word_embeddings_weight(self): def word_embeddings_weight(self):
...@@ -198,14 +184,13 @@ class Float16Module(MegatronModule): ...@@ -198,14 +184,13 @@ class Float16Module(MegatronModule):
return outputs return outputs
def state_dict(self, destination=None, prefix='', keep_vars=False): def state_dict(self, prefix='', keep_vars=False):
return self.module.state_dict(destination, prefix, keep_vars) return self.module.state_dict(prefix=prefix, keep_vars=keep_vars)
def state_dict_for_save_checkpoint(self, destination=None, prefix='', def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
keep_vars=False): return self.module.state_dict_for_save_checkpoint(prefix=prefix,
return self.module.state_dict_for_save_checkpoint(destination, prefix, keep_vars=keep_vars)
keep_vars)
def load_state_dict(self, state_dict, strict=True): def load_state_dict(self, state_dict, strict=True):
......
# coding=utf-8 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multiple choice model.""" """Multiple choice model."""
...@@ -100,19 +87,17 @@ class MultipleChoice(MegatronModule): ...@@ -100,19 +87,17 @@ class MultipleChoice(MegatronModule):
return multichoice_logits return multichoice_logits
return lm_output return lm_output
def state_dict_for_save_checkpoint(self, destination=None, prefix='', def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
keep_vars=False):
"""For easy load when model is combined with other heads, """For easy load when model is combined with other heads,
add an extra key.""" add an extra key."""
state_dict_ = {} state_dict_ = {}
state_dict_[self._language_model_key] \ state_dict_[self._language_model_key] \
= self.language_model.state_dict_for_save_checkpoint( = self.language_model.state_dict_for_save_checkpoint(prefix=prefix,
destination, prefix, keep_vars) keep_vars=keep_vars)
if self.post_process: if self.post_process:
state_dict_[self._multichoice_head_key] \ state_dict_[self._multichoice_head_key] \
= self.multichoice_head.state_dict( = self.multichoice_head.state_dict(prefix=prefix, keep_vars=keep_vars)
destination, prefix, keep_vars)
return state_dict_ return state_dict_
def load_state_dict(self, state_dict, strict=True): def load_state_dict(self, state_dict, strict=True):
......
...@@ -87,18 +87,18 @@ class ICTBertModel(MegatronModule): ...@@ -87,18 +87,18 @@ class ICTBertModel(MegatronModule):
else: else:
raise ValueError("Cannot embed block without block model.") raise ValueError("Cannot embed block without block model.")
def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False): def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
"""Save dict with state dicts of each of the models.""" """Save dict with state dicts of each of the models."""
state_dict_ = {} state_dict_ = {}
if self.use_query_model: if self.use_query_model:
state_dict_[self._query_key] \ state_dict_[self._query_key] \
= self.query_model.state_dict_for_save_checkpoint( = self.query_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars) prefix=prefix, keep_vars=keep_vars)
if self.use_block_model: if self.use_block_model:
state_dict_[self._block_key] \ state_dict_[self._block_key] \
= self.block_model.state_dict_for_save_checkpoint( = self.block_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars) prefix=prefix, keep_vars=keep_vars)
return state_dict_ return state_dict_
...@@ -181,17 +181,17 @@ class IREncoderBertModel(MegatronModule): ...@@ -181,17 +181,17 @@ class IREncoderBertModel(MegatronModule):
ict_logits = self.ict_head(pooled_output) ict_logits = self.ict_head(pooled_output)
return ict_logits, None return ict_logits, None
def state_dict_for_save_checkpoint(self, destination=None, prefix='', def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
keep_vars=False):
"""For easy load when model is combined with other heads, """For easy load when model is combined with other heads,
add an extra key.""" add an extra key."""
state_dict_ = {} state_dict_ = {}
state_dict_[self._language_model_key] \ state_dict_[self._language_model_key] \
= self.language_model.state_dict_for_save_checkpoint( = self.language_model.state_dict_for_save_checkpoint(prefix=prefix,
destination, prefix, keep_vars) keep_vars=keep_vars)
state_dict_[self._ict_head_key] \ state_dict_[self._ict_head_key] \
= self.ict_head.state_dict(destination, prefix, keep_vars) = self.ict_head.state_dict(prefix=prefix,
keep_vars=keep_vars)
return state_dict_ return state_dict_
def load_state_dict(self, state_dict, strict=True): def load_state_dict(self, state_dict, strict=True):
......
# coding=utf-8 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""T5 model.""" """T5 model."""
...@@ -178,23 +165,23 @@ class T5Model(MegatronModule): ...@@ -178,23 +165,23 @@ class T5Model(MegatronModule):
encoder_output = lm_output encoder_output = lm_output
return encoder_output return encoder_output
def state_dict_for_save_checkpoint(self, destination=None, prefix='', def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
keep_vars=False):
"""For easy load when model is combined with other heads, """For easy load when model is combined with other heads,
add an extra key.""" add an extra key."""
state_dict_ = {} state_dict_ = {}
state_dict_[self._language_model_key] \ state_dict_[self._language_model_key] \
= self.language_model.state_dict_for_save_checkpoint( = self.language_model.state_dict_for_save_checkpoint(prefix=prefix,
destination, prefix, keep_vars) keep_vars=keep_vars)
if self.post_process and self.add_decoder: if self.post_process and self.add_decoder:
state_dict_[self._lm_head_key] \ state_dict_[self._lm_head_key] \
= self.lm_head.state_dict_for_save_checkpoint( = self.lm_head.state_dict_for_save_checkpoint(prefix=prefix,
destination, prefix, keep_vars) keep_vars=keep_vars)
# Save word_embeddings. # Save word_embeddings.
if self.post_process and not self.pre_process and self.add_decoder: if self.post_process and not self.pre_process and self.add_decoder:
state_dict_[self._word_embeddings_for_head_key] \ state_dict_[self._word_embeddings_for_head_key] \
= self.word_embeddings.state_dict(destination, prefix, keep_vars) = self.word_embeddings.state_dict(prefix=prefix,
keep_vars=keep_vars)
return state_dict_ return state_dict_
def load_state_dict(self, state_dict, strict=True): def load_state_dict(self, state_dict, strict=True):
......
# coding=utf-8 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer.""" """Transformer."""
import math import math
......
# coding=utf-8 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for models.""" """Utilities for models."""
......
# coding=utf-8 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Vision Transformer(VIT) model.""" """Vision Transformer(VIT) model."""
......
# coding=utf-8 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# #
# This source code is licensed under the BSD license found in the # This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
......
# coding=utf-8 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Vision Transformer(VIT) model.""" """Vision Transformer(VIT) model."""
......
# coding=utf-8 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model parallel utility interface.""" """Model parallel utility interface."""
......
# coding=utf-8 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch import torch
......
# coding=utf-8 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch import torch
......
# coding=utf-8 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model and data parallel groups.""" """Model and data parallel groups."""
......
# coding=utf-8 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Parts of the code here are adapted from PyTorch # Parts of the code here are adapted from PyTorch
......
# coding=utf-8 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch import torch
......
# coding=utf-8 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Parts of the code here are adapted from PyTorch # Parts of the code here are adapted from PyTorch
......
# coding=utf-8 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse import argparse
import os import os
......
# coding=utf-8 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from commons import set_random_seed from commons import set_random_seed
from commons import IdentityLayer from commons import IdentityLayer
......
# coding=utf-8 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from commons import print_separator from commons import print_separator
from commons import initialize_distributed from commons import initialize_distributed
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment