Commit 41276b6c authored by Vijay Korthikanti

Merge branch 'main' into nmt-main

parents a44360ed fc7f4f03
-/* coding=utf-8
- * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
+/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */
 #include <cuda_fp16.h>
 #include <torch/extension.h>
......
-/* coding=utf-8
- * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
+/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */
 #include <ATen/ATen.h>
 #include <cuda.h>
......
-/* coding=utf-8
- * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
+/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */
 #include <cuda_fp16.h>
 #include <torch/extension.h>
......
-/* coding=utf-8
- * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
+/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */
 #pragma once
......
-/* coding=utf-8
- * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
+/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */
 #include <ATen/ATen.h>
 #include <cuda.h>
......
-/* coding=utf-8
- * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
+/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */
 #include <ATen/ATen.h>
......
-# coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 """Megatron global variables."""

 import os
 import sys
-import time
 from functools import reduce
 import operator

 import torch
@@ -25,6 +11,7 @@ import torch
 from megatron import dist_signal_handler
 from megatron.tokenizer import build_tokenizer
 from .microbatches import build_num_microbatches_calculator
+from .timers import Timers

 _GLOBAL_ARGS = None
 _GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
@@ -108,7 +95,7 @@ def set_global_variables(args):
     _ = _build_tokenizer(args)
     _set_tensorboard_writer(args)
     _set_adlr_autoresume(args)
-    _set_timers()
+    _set_timers(args)
     _set_global_memory_buffer()

     if args.exit_signal_handler:
@@ -182,11 +169,12 @@ def _set_adlr_autoresume(args):
     _GLOBAL_ADLR_AUTORESUME = AutoResume


-def _set_timers():
+def _set_timers(args):
     """Initialize timers."""
     global _GLOBAL_TIMERS
     _ensure_var_is_not_initialized(_GLOBAL_TIMERS, 'timers')
-    _GLOBAL_TIMERS = Timers()
+    _GLOBAL_TIMERS = Timers(args.timing_log_level, args.timing_log_option)


 def _set_global_memory_buffer():
     """Initialize global buffer"""
@@ -205,87 +193,6 @@ def _ensure_var_is_not_initialized(var, name):
     assert var is None, '{} is already initialized.'.format(name)
-
-class _Timer:
-    """Timer."""
-
-    def __init__(self, name):
-        self.name_ = name
-        self.elapsed_ = 0.0
-        self.started_ = False
-        self.start_time = time.time()
-
-    def start(self):
-        """Start the timer."""
-        assert not self.started_, 'timer has already been started'
-        torch.cuda.synchronize()
-        self.start_time = time.time()
-        self.started_ = True
-
-    def stop(self):
-        """Stop the timer."""
-        assert self.started_, 'timer is not started'
-        torch.cuda.synchronize()
-        self.elapsed_ += (time.time() - self.start_time)
-        self.started_ = False
-
-    def reset(self):
-        """Reset timer."""
-        self.elapsed_ = 0.0
-        self.started_ = False
-
-    def elapsed(self, reset=True):
-        """Calculate the elapsed time."""
-        started_ = self.started_
-        # If the timing in progress, end it first.
-        if self.started_:
-            self.stop()
-        # Get the elapsed time.
-        elapsed_ = self.elapsed_
-        # Reset the elapsed time
-        if reset:
-            self.reset()
-        # If timing was in progress, set it back.
-        if started_:
-            self.start()
-        return elapsed_
-
-
-class Timers:
-    """Group of timers."""
-
-    def __init__(self):
-        self.timers = {}
-
-    def __call__(self, name):
-        if name not in self.timers:
-            self.timers[name] = _Timer(name)
-        return self.timers[name]
-
-    def write(self, names, writer, iteration, normalizer=1.0, reset=False):
-        """Write timers to a tensorboard writer"""
-        # currently when using add_scalars,
-        # torch.utils.add_scalars makes each timer its own run, which
-        # polutes the runs list, so we just add each as a scalar
-        assert normalizer > 0.0
-        for name in names:
-            value = self.timers[name].elapsed(reset=reset) / normalizer
-            writer.add_scalar(name + '-time', value, iteration)
-
-    def log(self, names, normalizer=1.0, reset=True):
-        """Log a group of timers."""
-        assert normalizer > 0.0
-        string = 'time (ms)'
-        for name in names:
-            elapsed_time = self.timers[name].elapsed(
-                reset=reset) * 1000.0 / normalizer
-            string += ' | {}: {:.2f}'.format(name, elapsed_time)
-        if torch.distributed.is_initialized():
-            if torch.distributed.get_rank() == (
-                    torch.distributed.get_world_size() - 1):
-                print(string, flush=True)
-        else:
-            print(string, flush=True)
 class GlobalMemoryBuffer:
     """Global buffer to avoid dynamic memory allocations.
......
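Note on the deletion above: the `_Timer`/`Timers` pair leaves `global_vars.py` and is now imported from `megatron/timers.py` (see `from .timers import Timers`), with the constructor gaining two arguments, as the new `_set_timers(args)` shows. That file is not part of this excerpt, so the sketch below is only a hypothetical reconstruction of the interface the diff implies; the `log_level`/`log_option` names simply mirror `args.timing_log_level` and `args.timing_log_option`, and the original timers also call `torch.cuda.synchronize()` around start/stop.

```python
import time


class _Timer:
    """Per-name timer with the same start/stop semantics as the class
    removed above (minus the CUDA synchronization, for self-containedness)."""

    def __init__(self, name):
        self.name_ = name
        self.elapsed_ = 0.0
        self.started_ = False
        self.start_time = time.time()

    def start(self):
        assert not self.started_, 'timer has already been started'
        self.start_time = time.time()
        self.started_ = True

    def stop(self):
        assert self.started_, 'timer is not started'
        self.elapsed_ += time.time() - self.start_time
        self.started_ = False


class Timers:
    """Hypothetical sketch of the new constructor seen in _set_timers(args)."""

    def __init__(self, log_level, log_option):
        self.log_level = log_level    # from args.timing_log_level
        self.log_option = log_option  # from args.timing_log_option
        self.timers = {}

    def __call__(self, name):
        # Lazily create a timer per name, as the removed Timers.__call__ did.
        if name not in self.timers:
            self.timers[name] = _Timer(name)
        return self.timers[name]


# Usage mirrors the new call site in _set_timers(args):
timers = Timers(log_level=1, log_option='minmax')
timers('forward').start()
timers('forward').stop()
```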
-# coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 """Megatron initialization."""
......
-# coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 import torch
......
-# coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 """Megatron number of micro-batches calculators."""
......
-# coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm
......
-# coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 """BERT model."""
@@ -208,26 +195,25 @@ class BertModel(MegatronModule):

         return lm_output

-    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
-                                       keep_vars=False):
+    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
         """For easy load when model is combined with other heads,
         add an extra key."""

         state_dict_ = {}
         state_dict_[self._language_model_key] \
-            = self.language_model.state_dict_for_save_checkpoint(
-                destination, prefix, keep_vars)
+            = self.language_model.state_dict_for_save_checkpoint(prefix=prefix,
+                                                                 keep_vars=keep_vars)
         if self.post_process:
             state_dict_[self._lm_head_key] \
-                = self.lm_head.state_dict_for_save_checkpoint(
-                    destination, prefix, keep_vars)
+                = self.lm_head.state_dict_for_save_checkpoint(prefix=prefix,
+                                                              keep_vars=keep_vars)
         if self.post_process and self.add_binary_head:
             state_dict_[self._binary_head_key] \
-                = self.binary_head.state_dict(destination, prefix, keep_vars)
+                = self.binary_head.state_dict(prefix=prefix, keep_vars=keep_vars)
         # Save word_embeddings.
         if self.post_process and not self.pre_process:
             state_dict_[self._word_embeddings_for_head_key] \
-                = self.word_embeddings.state_dict(destination, prefix, keep_vars)
+                = self.word_embeddings.state_dict(prefix=prefix, keep_vars=keep_vars)
         return state_dict_

     def load_state_dict(self, state_dict, strict=True):
......
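Every model file touched in this commit follows the same pattern as the BERT hunk above: `state_dict_for_save_checkpoint` drops the unused `destination` parameter and forwards `prefix` and `keep_vars` as keyword arguments, matching the keyword style of `torch.nn.Module.state_dict`. A minimal sketch of the delegation, using illustrative names (`Wrapper`, `inner`) that do not appear in the commit:

```python
import torch


class Wrapper(torch.nn.Module):
    """Illustrative module that delegates checkpointing to a submodule."""

    def __init__(self):
        super().__init__()
        self.inner = torch.nn.Linear(4, 4)

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        # New style: no `destination` argument; prefix/keep_vars are
        # forwarded by keyword, as in the hunks above and below.
        return {'inner': self.inner.state_dict(prefix=prefix,
                                               keep_vars=keep_vars)}


sd = Wrapper().state_dict_for_save_checkpoint(prefix='lm.')
print(list(sd['inner'].keys()))  # ['lm.weight', 'lm.bias']
```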
@@ -139,25 +139,23 @@ class BiEncoderModel(MegatronModule):
                              token_types)
         return logits

-    def state_dict_for_save_checkpoint(self, destination=None, \
-                                       prefix='', keep_vars=False):
+    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
         """Save dict with state dicts of each of the models."""
         state_dict_ = {}
         if self.biencoder_shared_query_context_model:
             state_dict_[self._model_key] = \
-                self.model.state_dict_for_save_checkpoint(destination,
-                                                          prefix,
-                                                          keep_vars)
+                self.model.state_dict_for_save_checkpoint(
+                    prefix=prefix, keep_vars=keep_vars)
         else:
             if self.use_query_model:
                 state_dict_[self._query_key] = \
                     self.query_model.state_dict_for_save_checkpoint(
-                        destination, prefix, keep_vars)
+                        prefix=prefix, keep_vars=keep_vars)
             if self.use_context_model:
                 state_dict_[self._context_key] = \
                     self.context_model.state_dict_for_save_checkpoint(
-                        destination, prefix, keep_vars)
+                        prefix=prefix, keep_vars=keep_vars)

         return state_dict_
@@ -302,19 +300,19 @@ class PretrainedBertModel(MegatronModule):

         return pooled_output

-    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
-                                       keep_vars=False):
+    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
         """For easy load when model is combined with other heads,
         add an extra key."""
         state_dict_ = {}
         state_dict_[self._language_model_key] \
             = self.language_model.state_dict_for_save_checkpoint(
-                destination, prefix, keep_vars)
+                prefix=prefix, keep_vars=keep_vars)
         if self.biencoder_projection_dim > 0:
             state_dict_[self._projection_enc_key] = \
-                self.projection_enc.state_dict(destination, prefix, keep_vars)
+                self.projection_enc.state_dict(prefix=prefix,
+                                               keep_vars=keep_vars)
         return state_dict_
......
-# coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 """Classification model."""
@@ -89,19 +76,17 @@ class Classification(MegatronModule):
             return classification_logits
         return lm_output

-    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
-                                       keep_vars=False):
+    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
         """For easy load when model is combined with other heads,
         add an extra key."""

         state_dict_ = {}
         state_dict_[self._language_model_key] \
-            = self.language_model.state_dict_for_save_checkpoint(
-                destination, prefix, keep_vars)
+            = self.language_model.state_dict_for_save_checkpoint(prefix=prefix,
+                                                                 keep_vars=keep_vars)
         if self.post_process:
             state_dict_[self._classification_head_key] \
-                = self.classification_head.state_dict(
-                    destination, prefix, keep_vars)
+                = self.classification_head.state_dict(prefix=prefix, keep_vars=keep_vars)
         return state_dict_

     def load_state_dict(self, state_dict, strict=True):
......
-# coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 from abc import ABC
 from abc import abstractmethod
@@ -71,14 +58,13 @@ class DistributedDataParallelBase(MegatronModule, ABC):
         return self.module(*inputs, **kwargs)

-    def state_dict(self, destination=None, prefix='', keep_vars=False):
-        return self.module.state_dict(destination, prefix, keep_vars)
+    def state_dict(self, prefix='', keep_vars=False):
+        return self.module.state_dict(prefix=prefix, keep_vars=keep_vars)

-    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
-                                       keep_vars=False):
-        return self.module.state_dict_for_save_checkpoint(destination, prefix,
-                                                          keep_vars)
+    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
+        return self.module.state_dict_for_save_checkpoint(prefix=prefix,
+                                                          keep_vars=keep_vars)

     def load_state_dict(self, state_dict, strict=True):
......
-# coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 import enum
......
-# coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 import torch
......
-# coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 """This code is copied fron NVIDIA apex:
       https://github.com/NVIDIA/apex
......
-# coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 import torch
......
-# coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 """GPT-2 model."""
@@ -105,17 +92,17 @@ class GPTModel(MegatronModule):
         else:
             return lm_output

-    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
-                                       keep_vars=False):
+    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):

         state_dict_ = {}
         state_dict_[self._language_model_key] \
             = self.language_model.state_dict_for_save_checkpoint(
-                destination, prefix, keep_vars)
+                prefix=prefix, keep_vars=keep_vars)
         # Save word_embeddings.
         if self.post_process and not self.pre_process:
             state_dict_[self._word_embeddings_for_head_key] \
-                = self.word_embeddings.state_dict(destination, prefix, keep_vars)
+                = self.word_embeddings.state_dict(prefix=prefix,
+                                                  keep_vars=keep_vars)
         return state_dict_

     def load_state_dict(self, state_dict, strict=True):
......