OpenDAS / Megatron-LM · Commits

Commit 41276b6c
Authored Oct 03, 2022 by Vijay Korthikanti
Merge branch 'main' into nmt-main
Parents: a44360ed fc7f4f03
Changes: 135 changed files in this merge. This page shows 20 changed files with 53 additions and 392 deletions (+53, -392).

megatron/fused_kernels/scaled_softmax.cpp                           +1   -15
megatron/fused_kernels/scaled_softmax_cuda.cu                       +1   -15
megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp       +1   -15
megatron/fused_kernels/scaled_upper_triang_masked_softmax.h         +1   -15
megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu   +1   -15
megatron/fused_kernels/type_shim.h                                  +1   -15
megatron/global_vars.py                                             +6   -99
megatron/initialize.py                                              +1   -14
megatron/memory.py                                                  +1   -14
megatron/microbatches.py                                            +1   -14
megatron/model/__init__.py                                          +1   -14
megatron/model/bert_model.py                                        +8   -22
megatron/model/biencoder_model.py                                   +9   -11
megatron/model/classification.py                                    +5   -20
megatron/model/distributed.py                                       +6   -20
megatron/model/enums.py                                             +1   -14
megatron/model/fused_bias_gelu.py                                   +1   -14
megatron/model/fused_layer_norm.py                                  +1   -14
megatron/model/fused_softmax.py                                     +1   -14
megatron/model/gpt_model.py                                         +5   -18
megatron/fused_kernels/scaled_softmax.cpp

-/* coding=utf-8
- * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
+/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */
 #include <cuda_fp16.h>
 #include <torch/extension.h>
 ...
megatron/fused_kernels/scaled_softmax_cuda.cu

-/* coding=utf-8
- * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
- * … (remaining Apache License, Version 2.0 header lines, identical to the file above)
- */
+/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */
 #include <ATen/ATen.h>
 #include <cuda.h>
 ...
megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp

-/* coding=utf-8
- * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
- * … (remaining Apache License, Version 2.0 header lines, identical to the file above)
- */
+/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */
 #include <cuda_fp16.h>
 #include <torch/extension.h>
 ...
megatron/fused_kernels/scaled_upper_triang_masked_softmax.h

-/* coding=utf-8
- * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
- * … (remaining Apache License, Version 2.0 header lines, identical to the file above)
- */
+/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */
 #pragma once
 ...
megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu

-/* coding=utf-8
- * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
- * … (remaining Apache License, Version 2.0 header lines, identical to the file above)
- */
+/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */
 #include <ATen/ATen.h>
 #include <cuda.h>
 ...
megatron/fused_kernels/type_shim.h

-/* coding=utf-8
- * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
- * … (remaining Apache License, Version 2.0 header lines, identical to the file above)
- */
+/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */
 #include <ATen/ATen.h>
 ...
megatron/global_vars.py

-# coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

 """Megatron global variables."""

 import os
 import sys
-import time
 from functools import reduce
 import operator

 import torch
 ...

@@ -25,6 +11,7 @@ import torch
 from megatron import dist_signal_handler
 from megatron.tokenizer import build_tokenizer
 from .microbatches import build_num_microbatches_calculator
+from .timers import Timers

 _GLOBAL_ARGS = None
 _GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
 ...

@@ -108,7 +95,7 @@ def set_global_variables(args):
     _ = _build_tokenizer(args)
     _set_tensorboard_writer(args)
     _set_adlr_autoresume(args)
-    _set_timers()
+    _set_timers(args)
     _set_global_memory_buffer()

     if args.exit_signal_handler:
 ...

@@ -182,11 +169,12 @@ def _set_adlr_autoresume(args):
         _GLOBAL_ADLR_AUTORESUME = AutoResume


-def _set_timers():
+def _set_timers(args):
     """Initialize timers."""
     global _GLOBAL_TIMERS
     _ensure_var_is_not_initialized(_GLOBAL_TIMERS, 'timers')
-    _GLOBAL_TIMERS = Timers()
+    _GLOBAL_TIMERS = Timers(args.timing_log_level, args.timing_log_option)


 def _set_global_memory_buffer():
     """Initialize global buffer"""
 ...

@@ -205,87 +193,6 @@ def _ensure_var_is_not_initialized(var, name):
     assert var is None, '{} is already initialized.'.format(name)


-class _Timer:
-    """Timer."""
-
-    def __init__(self, name):
-        self.name_ = name
-        self.elapsed_ = 0.0
-        self.started_ = False
-        self.start_time = time.time()
-
-    def start(self):
-        """Start the timer."""
-        assert not self.started_, 'timer has already been started'
-        torch.cuda.synchronize()
-        self.start_time = time.time()
-        self.started_ = True
-
-    def stop(self):
-        """Stop the timer."""
-        assert self.started_, 'timer is not started'
-        torch.cuda.synchronize()
-        self.elapsed_ += (time.time() - self.start_time)
-        self.started_ = False
-
-    def reset(self):
-        """Reset timer."""
-        self.elapsed_ = 0.0
-        self.started_ = False
-
-    def elapsed(self, reset=True):
-        """Calculate the elapsed time."""
-        started_ = self.started_
-        # If the timing in progress, end it first.
-        if self.started_:
-            self.stop()
-        # Get the elapsed time.
-        elapsed_ = self.elapsed_
-        # Reset the elapsed time
-        if reset:
-            self.reset()
-        # If timing was in progress, set it back.
-        if started_:
-            self.start()
-        return elapsed_
-
-
-class Timers:
-    """Group of timers."""
-
-    def __init__(self):
-        self.timers = {}
-
-    def __call__(self, name):
-        if name not in self.timers:
-            self.timers[name] = _Timer(name)
-        return self.timers[name]
-
-    def write(self, names, writer, iteration, normalizer=1.0, reset=False):
-        """Write timers to a tensorboard writer"""
-        # currently when using add_scalars,
-        # torch.utils.add_scalars makes each timer its own run, which
-        # polutes the runs list, so we just add each as a scalar
-        assert normalizer > 0.0
-        for name in names:
-            value = self.timers[name].elapsed(reset=reset) / normalizer
-            writer.add_scalar(name + '-time', value, iteration)
-
-    def log(self, names, normalizer=1.0, reset=True):
-        """Log a group of timers."""
-        assert normalizer > 0.0
-        string = 'time (ms)'
-        for name in names:
-            elapsed_time = self.timers[name].elapsed(
-                reset=reset) * 1000.0 / normalizer
-            string += ' | {}: {:.2f}'.format(name, elapsed_time)
-        if torch.distributed.is_initialized():
-            if torch.distributed.get_rank() == (
-                    torch.distributed.get_world_size() - 1):
-                print(string, flush=True)
-        else:
-            print(string, flush=True)
-
-
 class GlobalMemoryBuffer:
     """Global buffer to avoid dynamic memory allocations.
 ...
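Note: the timer classes removed from global_vars.py are now expected to come from a separate megatron/timers.py module (see the new `from .timers import Timers` import), and the global instance is constructed with the timing verbosity taken from the parsed arguments. That module is not shown on this page of the diff, so the following is only a minimal sketch of what a compatible Timers class could look like, assuming it keeps roughly the interface of the removed class; only the import path and the call `Timers(args.timing_log_level, args.timing_log_option)` are confirmed by the commit, everything else here is illustrative.

import time

import torch


class _Timer:
    """Single named timer (mirrors the class removed from global_vars.py)."""

    def __init__(self, name):
        self.name_ = name
        self.elapsed_ = 0.0
        self.started_ = False
        self.start_time = time.time()

    def start(self):
        assert not self.started_, 'timer has already been started'
        # Guarded here so the sketch runs without a GPU; the original code
        # calls torch.cuda.synchronize() unconditionally.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        self.start_time = time.time()
        self.started_ = True

    def stop(self):
        assert self.started_, 'timer is not started'
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        self.elapsed_ += time.time() - self.start_time
        self.started_ = False

    def elapsed(self, reset=True):
        """Return accumulated seconds, optionally resetting the timer."""
        was_started = self.started_
        if was_started:
            self.stop()
        elapsed = self.elapsed_
        if reset:
            self.elapsed_ = 0.0
        if was_started:
            self.start()
        return elapsed


class Timers:
    """Group of timers, now configured at construction time."""

    def __init__(self, log_level, log_option):
        # Hypothetical handling of the two new constructor arguments: a
        # verbosity threshold and a formatting option for the timing log.
        self.log_level = log_level
        self.log_option = log_option
        self.timers = {}

    def __call__(self, name):
        if name not in self.timers:
            self.timers[name] = _Timer(name)
        return self.timers[name]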
megatron/initialize.py

-# coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
-# … (remaining Apache License, Version 2.0 header lines, identical to the file above)
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

 """Megatron initialization."""
 ...
megatron/memory.py

-# coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
-# … (remaining Apache License, Version 2.0 header lines, identical to the files above)
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

 import torch
 ...
megatron/microbatches.py

-# coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
-# … (remaining Apache License, Version 2.0 header lines, identical to the files above)
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

 """Megatron number of micro-batches calculators."""
 ...
megatron/model/__init__.py

-# coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
-# … (remaining Apache License, Version 2.0 header lines, identical to the files above)
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

 from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm
 ...
megatron/model/bert_model.py

-# coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
-# … (remaining Apache License, Version 2.0 header lines, identical to the files above)
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

 """BERT model."""
 ...

@@ -208,26 +195,25 @@ class BertModel(MegatronModule):

         return lm_output

-    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
-                                       keep_vars=False):
+    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
         """For easy load when model is combined with other heads,
         add an extra key."""

         state_dict_ = {}
         state_dict_[self._language_model_key] \
             = self.language_model.state_dict_for_save_checkpoint(
-                destination, prefix,
-                keep_vars)
+                prefix=prefix,
+                keep_vars=keep_vars)
         if self.post_process:
             state_dict_[self._lm_head_key] \
                 = self.lm_head.state_dict_for_save_checkpoint(
-                    destination, prefix,
-                    keep_vars)
+                    prefix=prefix,
+                    keep_vars=keep_vars)
         if self.post_process and self.add_binary_head:
             state_dict_[self._binary_head_key] \
-                = self.binary_head.state_dict(destination, prefix, keep_vars)
+                = self.binary_head.state_dict(prefix=prefix, keep_vars=keep_vars)
         # Save word_embeddings.
         if self.post_process and not self.pre_process:
             state_dict_[self._word_embeddings_for_head_key] \
-                = self.word_embeddings.state_dict(destination, prefix, keep_vars)
+                = self.word_embeddings.state_dict(prefix=prefix, keep_vars=keep_vars)
         return state_dict_

     def load_state_dict(self, state_dict, strict=True):
 ...
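The same signature change repeats across the model classes below: the `destination` parameter is dropped from `state_dict_for_save_checkpoint`, and `prefix`/`keep_vars` are forwarded as keyword arguments, matching the keyword-style calls expected by `torch.nn.Module.state_dict` in recent PyTorch releases. A minimal, self-contained sketch of the resulting pattern; the `ToyModel`/`ToyHead` classes are illustrative stand-ins, not Megatron code.

import torch


class ToyHead(torch.nn.Module):
    """Stand-in for a model head such as lm_head or binary_head."""

    def __init__(self):
        super().__init__()
        self.dense = torch.nn.Linear(4, 4)


class ToyModel(torch.nn.Module):
    """Stand-in showing the post-commit checkpoint-collection pattern."""

    _head_key = 'head'

    def __init__(self):
        super().__init__()
        self.head = ToyHead()

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        # No `destination` argument any more: each sub-module returns its
        # own dict, and prefix/keep_vars are passed by keyword.
        state_dict_ = {}
        state_dict_[self._head_key] = self.head.state_dict(
            prefix=prefix, keep_vars=keep_vars)
        return state_dict_


if __name__ == '__main__':
    model = ToyModel()
    checkpoint = model.state_dict_for_save_checkpoint(prefix='model.')
    # Keys of the nested dict carry the requested prefix, e.g.
    # 'model.dense.weight' and 'model.dense.bias'.
    print(sorted(checkpoint['head'].keys()))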
megatron/model/biencoder_model.py

 ...

@@ -139,25 +139,23 @@ class BiEncoderModel(MegatronModule):
                                        token_types)
         return logits

-    def state_dict_for_save_checkpoint(self, destination=None, \
-        prefix='', keep_vars=False):
+    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
         """Save dict with state dicts of each of the models."""
         state_dict_ = {}
         if self.biencoder_shared_query_context_model:
             state_dict_[self._model_key] = \
-                self.model.state_dict_for_save_checkpoint(destination,
-                                                          prefix,
-                                                          keep_vars)
+                self.model.state_dict_for_save_checkpoint(
+                    prefix=prefix, keep_vars=keep_vars)
         else:
             if self.use_query_model:
                 state_dict_[self._query_key] = \
                     self.query_model.state_dict_for_save_checkpoint(
-                        destination, prefix, keep_vars)
+                        prefix=prefix, keep_vars=keep_vars)
             if self.use_context_model:
                 state_dict_[self._context_key] = \
                     self.context_model.state_dict_for_save_checkpoint(
-                        destination, prefix, keep_vars)
+                        prefix=prefix, keep_vars=keep_vars)

         return state_dict_
 ...

@@ -302,19 +300,19 @@ class PretrainedBertModel(MegatronModule):

         return pooled_output

-    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
-                                       keep_vars=False):
+    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
         """For easy load when model is combined with other heads,
         add an extra key."""

         state_dict_ = {}
         state_dict_[self._language_model_key] \
             = self.language_model.state_dict_for_save_checkpoint(
-                destination, prefix, keep_vars)
+                prefix=prefix, keep_vars=keep_vars)
         if self.biencoder_projection_dim > 0:
             state_dict_[self._projection_enc_key] = \
-                self.projection_enc.state_dict(destination, prefix, keep_vars)
+                self.projection_enc.state_dict(prefix=prefix, keep_vars=keep_vars)

         return state_dict_
 ...
megatron/model/classification.py

-# coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
-# … (remaining Apache License, Version 2.0 header lines, identical to the files above)
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

 """Classification model."""
 ...

@@ -89,19 +76,17 @@ class Classification(MegatronModule):
             return classification_logits
         return lm_output

-    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
-                                       keep_vars=False):
+    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
         """For easy load when model is combined with other heads,
         add an extra key."""

         state_dict_ = {}
         state_dict_[self._language_model_key] \
             = self.language_model.state_dict_for_save_checkpoint(
-                destination, prefix,
-                keep_vars)
+                prefix=prefix,
+                keep_vars=keep_vars)
         if self.post_process:
             state_dict_[self._classification_head_key] \
-                = self.classification_head.state_dict(destination, prefix, keep_vars)
+                = self.classification_head.state_dict(prefix=prefix, keep_vars=keep_vars)
         return state_dict_

     def load_state_dict(self, state_dict, strict=True):
 ...
megatron/model/distributed.py

-# coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
-# … (remaining Apache License, Version 2.0 header lines, identical to the files above)
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

 from abc import ABC
 from abc import abstractmethod
 ...

@@ -71,14 +58,13 @@ class DistributedDataParallelBase(MegatronModule, ABC):
         return self.module(*inputs, **kwargs)

-    def state_dict(self, destination=None, prefix='', keep_vars=False):
-        return self.module.state_dict(destination, prefix, keep_vars)
+    def state_dict(self, prefix='', keep_vars=False):
+        return self.module.state_dict(prefix=prefix, keep_vars=keep_vars)

-    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
-                                       keep_vars=False):
-        return self.module.state_dict_for_save_checkpoint(destination, prefix,
-                                                          keep_vars)
+    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
+        return self.module.state_dict_for_save_checkpoint(prefix=prefix,
+                                                          keep_vars=keep_vars)

     def load_state_dict(self, state_dict, strict=True):
 ...
megatron/model/enums.py

-# coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
-# … (remaining Apache License, Version 2.0 header lines, identical to the files above)
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

 import enum
 ...
megatron/model/fused_bias_gelu.py

-# coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
-# … (remaining Apache License, Version 2.0 header lines, identical to the files above)
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

 import torch
 ...
megatron/model/fused_layer_norm.py

-# coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
-# … (remaining Apache License, Version 2.0 header lines, identical to the files above)
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

 """This code is copied fron NVIDIA apex:
    https://github.com/NVIDIA/apex
 ...
megatron/model/fused_softmax.py

-# coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
-# … (remaining Apache License, Version 2.0 header lines, identical to the files above)
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

 import torch
 ...
megatron/model/gpt_model.py

-# coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
-# … (remaining Apache License, Version 2.0 header lines, identical to the files above)
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

 """GPT-2 model."""
 ...

@@ -105,17 +92,17 @@ class GPTModel(MegatronModule):
         else:
             return lm_output

-    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
-                                       keep_vars=False):
+    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):

         state_dict_ = {}
         state_dict_[self._language_model_key] \
             = self.language_model.state_dict_for_save_checkpoint(
-                destination, prefix, keep_vars)
+                prefix=prefix, keep_vars=keep_vars)
         # Save word_embeddings.
         if self.post_process and not self.pre_process:
             state_dict_[self._word_embeddings_for_head_key] \
-                = self.word_embeddings.state_dict(destination, prefix, keep_vars)
+                = self.word_embeddings.state_dict(prefix=prefix, keep_vars=keep_vars)
         return state_dict_

     def load_state_dict(self, state_dict, strict=True):
 ...