Commit 41276b6c authored by Vijay Korthikanti's avatar Vijay Korthikanti
Browse files

Merge branch 'main' into nmt-main

parents a44360ed fc7f4f03
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Transformer based language model."""
......@@ -243,20 +230,20 @@ class Embedding(MegatronModule):
return embeddings
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
"""For easy load."""
state_dict_ = {}
state_dict_[self._word_embeddings_key] \
= self.word_embeddings.state_dict(destination, prefix, keep_vars)
= self.word_embeddings.state_dict(prefix=prefix,
keep_vars=keep_vars)
state_dict_[self._position_embeddings_key] \
= self.position_embeddings.state_dict(
destination, prefix, keep_vars)
= self.position_embeddings.state_dict(prefix=prefix,
keep_vars=keep_vars)
if self.num_tokentypes > 0:
state_dict_[self._tokentype_embeddings_key] \
= self.tokentype_embeddings.state_dict(
destination, prefix, keep_vars)
= self.tokentype_embeddings.state_dict(prefix=prefix,
keep_vars=keep_vars)
return state_dict_
......@@ -478,28 +465,27 @@ class TransformerLanguageModel(MegatronModule):
else:
return decoder_output, encoder_output
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
"""For easy load."""
state_dict_ = {}
if self.pre_process:
state_dict_[self._embedding_key] \
= self.embedding.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
= self.embedding.state_dict_for_save_checkpoint(prefix=prefix,
keep_vars=keep_vars)
if self.add_encoder:
state_dict_[self._encoder_key] \
= self.encoder.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
= self.encoder.state_dict_for_save_checkpoint(prefix=prefix,
keep_vars=keep_vars)
if self.post_process:
if self.add_pooler:
state_dict_[self._pooler_key] \
= self.pooler.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
= self.pooler.state_dict_for_save_checkpoint(prefix=prefix,
keep_vars=keep_vars)
if self.add_decoder:
state_dict_[self._decoder_key] \
= self.decoder.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
= self.decoder.state_dict_for_save_checkpoint(prefix=prefix,
keep_vars=keep_vars)
return state_dict_
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Megatron Module"""
......@@ -43,11 +30,10 @@ class MegatronModule(torch.nn.Module):
self.share_word_embeddings = share_word_embeddings
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
"""Use this function to override the state dict for
saving checkpoints."""
return self.state_dict(destination, prefix, keep_vars)
return self.state_dict(prefix=prefix, keep_vars=keep_vars)
def word_embeddings_weight(self):
......@@ -198,14 +184,13 @@ class Float16Module(MegatronModule):
return outputs
def state_dict(self, destination=None, prefix='', keep_vars=False):
return self.module.state_dict(destination, prefix, keep_vars)
def state_dict(self, prefix='', keep_vars=False):
return self.module.state_dict(prefix=prefix, keep_vars=keep_vars)
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
return self.module.state_dict_for_save_checkpoint(destination, prefix,
keep_vars)
def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
return self.module.state_dict_for_save_checkpoint(prefix=prefix,
keep_vars=keep_vars)
def load_state_dict(self, state_dict, strict=True):
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Multiple choice model."""
......@@ -100,19 +87,17 @@ class MultipleChoice(MegatronModule):
return multichoice_logits
return lm_output
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
"""For easy load when model is combined with other heads,
add an extra key."""
state_dict_ = {}
state_dict_[self._language_model_key] \
= self.language_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
= self.language_model.state_dict_for_save_checkpoint(prefix=prefix,
keep_vars=keep_vars)
if self.post_process:
state_dict_[self._multichoice_head_key] \
= self.multichoice_head.state_dict(
destination, prefix, keep_vars)
= self.multichoice_head.state_dict(prefix=prefix, keep_vars=keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
......
......@@ -87,18 +87,18 @@ class ICTBertModel(MegatronModule):
else:
raise ValueError("Cannot embed block without block model.")
def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
"""Save dict with state dicts of each of the models."""
state_dict_ = {}
if self.use_query_model:
state_dict_[self._query_key] \
= self.query_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
prefix=prefix, keep_vars=keep_vars)
if self.use_block_model:
state_dict_[self._block_key] \
= self.block_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
prefix=prefix, keep_vars=keep_vars)
return state_dict_
......@@ -181,17 +181,17 @@ class IREncoderBertModel(MegatronModule):
ict_logits = self.ict_head(pooled_output)
return ict_logits, None
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
"""For easy load when model is combined with other heads,
add an extra key."""
state_dict_ = {}
state_dict_[self._language_model_key] \
= self.language_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
= self.language_model.state_dict_for_save_checkpoint(prefix=prefix,
keep_vars=keep_vars)
state_dict_[self._ict_head_key] \
= self.ict_head.state_dict(destination, prefix, keep_vars)
= self.ict_head.state_dict(prefix=prefix,
keep_vars=keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""T5 model."""
......@@ -178,23 +165,23 @@ class T5Model(MegatronModule):
encoder_output = lm_output
return encoder_output
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
"""For easy load when model is combined with other heads,
add an extra key."""
state_dict_ = {}
state_dict_[self._language_model_key] \
= self.language_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
= self.language_model.state_dict_for_save_checkpoint(prefix=prefix,
keep_vars=keep_vars)
if self.post_process and self.add_decoder:
state_dict_[self._lm_head_key] \
= self.lm_head.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
= self.lm_head.state_dict_for_save_checkpoint(prefix=prefix,
keep_vars=keep_vars)
# Save word_embeddings.
if self.post_process and not self.pre_process and self.add_decoder:
state_dict_[self._word_embeddings_for_head_key] \
= self.word_embeddings.state_dict(destination, prefix, keep_vars)
= self.word_embeddings.state_dict(prefix=prefix,
keep_vars=keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Transformer."""
import math
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Utilities for models."""
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Vision Transformer(VIT) model."""
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Vision Transformer(VIT) model."""
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Model parallel utility interface."""
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import torch
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import torch
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Model and data parallel groups."""
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Parts of the code here are adapted from PyTorch
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import torch
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Parts of the code here are adapted from PyTorch
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import argparse
import os
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
from commons import set_random_seed
from commons import IdentityLayer
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
from commons import print_separator
from commons import initialize_distributed
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment