OpenDAS / Megatron-LM · Commits

Commit 41276b6c
Authored Oct 03, 2022 by Vijay Korthikanti

    Merge branch 'main' into nmt-main

Parents: a44360ed, fc7f4f03
Changes: 135
Showing 20 changed files with 61 additions and 301 deletions (+61 -301)
megatron/model/language_model.py          +17  -31
megatron/model/module.py                   +8  -23
megatron/model/multiple_choice.py          +5  -20
megatron/model/realm_model.py              +8   -8
megatron/model/t5_model.py                 +8  -21
megatron/model/transformer.py              +1  -14
megatron/model/utils.py                    +1  -14
megatron/model/vision/classification.py    +1  -14
megatron/model/vision/inpainting.py        +1   -2
megatron/model/vision/vit_backbone.py      +1  -14
megatron/mpu/__init__.py                   +1  -14
megatron/mpu/cross_entropy.py              +1  -14
megatron/mpu/data.py                       +1  -14
megatron/mpu/initialize.py                 +1  -14
megatron/mpu/layers.py                     +1  -14
megatron/mpu/mappings.py                   +1  -14
megatron/mpu/random.py                     +1  -14
megatron/mpu/tests/commons.py              +1  -14
megatron/mpu/tests/test_cross_entropy.py   +1  -14
megatron/mpu/tests/test_data.py            +1  -14

megatron/model/language_model.py

-# coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 """Transformer based language model."""
 ...

@@ -243,20 +230,20 @@ class Embedding(MegatronModule):
         return embeddings

-    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
+    def state_dict_for_save_checkpoint(self, prefix='',
                                        keep_vars=False):
         """For easy load."""

         state_dict_ = {}
         state_dict_[self._word_embeddings_key] \
-            = self.word_embeddings.state_dict(destination, prefix, keep_vars)
+            = self.word_embeddings.state_dict(prefix=prefix,
+                                              keep_vars=keep_vars)
         state_dict_[self._position_embeddings_key] \
-            = self.position_embeddings.state_dict(
-                destination, prefix, keep_vars)
+            = self.position_embeddings.state_dict(prefix=prefix,
+                                                  keep_vars=keep_vars)
         if self.num_tokentypes > 0:
             state_dict_[self._tokentype_embeddings_key] \
-                = self.tokentype_embeddings.state_dict(
-                    destination, prefix, keep_vars)
+                = self.tokentype_embeddings.state_dict(prefix=prefix,
+                                                       keep_vars=keep_vars)

         return state_dict_
 ...

@@ -478,28 +465,27 @@ class TransformerLanguageModel(MegatronModule):
         else:
             return decoder_output, encoder_output

-    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
+    def state_dict_for_save_checkpoint(self, prefix='',
                                        keep_vars=False):
         """For easy load."""

         state_dict_ = {}
         if self.pre_process:
             state_dict_[self._embedding_key] \
-                = self.embedding.state_dict_for_save_checkpoint(
-                    destination, prefix, keep_vars)
+                = self.embedding.state_dict_for_save_checkpoint(
+                    prefix=prefix, keep_vars=keep_vars)
         if self.add_encoder:
             state_dict_[self._encoder_key] \
-                = self.encoder.state_dict_for_save_checkpoint(
-                    destination, prefix, keep_vars)
+                = self.encoder.state_dict_for_save_checkpoint(
+                    prefix=prefix, keep_vars=keep_vars)
         if self.post_process:
             if self.add_pooler:
                 state_dict_[self._pooler_key] \
-                    = self.pooler.state_dict_for_save_checkpoint(
-                        destination, prefix, keep_vars)
+                    = self.pooler.state_dict_for_save_checkpoint(
+                        prefix=prefix, keep_vars=keep_vars)
         if self.add_decoder:
             state_dict_[self._decoder_key] \
-                = self.decoder.state_dict_for_save_checkpoint(
-                    destination, prefix, keep_vars)
+                = self.decoder.state_dict_for_save_checkpoint(
+                    prefix=prefix, keep_vars=keep_vars)

         return state_dict_
 ...
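
Across the model files in this commit the edit is the same mechanical one shown above: the positional `destination` argument is dropped from `state_dict_for_save_checkpoint`, and `prefix` and `keep_vars` are passed on to `state_dict` as keyword arguments rather than positionally. A minimal sketch of the new calling pattern, using a made-up `EmbeddingSketch` class rather than the real Megatron `Embedding`:

# Minimal sketch of the keyword-argument pattern used after this commit.
# EmbeddingSketch is illustrative only; it is not the real Megatron class.
import torch


class EmbeddingSketch(torch.nn.Module):

    def __init__(self, vocab_size=16, hidden_size=8):
        super().__init__()
        self.word_embeddings = torch.nn.Embedding(vocab_size, hidden_size)
        self._word_embeddings_key = 'word_embeddings'

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """Collect sub-module state dicts under explicit keys; note the new
        signature has no positional `destination` argument."""
        state_dict_ = {}
        state_dict_[self._word_embeddings_key] \
            = self.word_embeddings.state_dict(prefix=prefix,
                                              keep_vars=keep_vars)
        return state_dict_


sd = EmbeddingSketch().state_dict_for_save_checkpoint()
print(list(sd['word_embeddings'].keys()))  # prints ['weight']

Passing `prefix` and `keep_vars` by name lines up with the keyword parameters of `torch.nn.Module.state_dict` and no longer depends on the position of `destination` in that signature.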

megatron/model/module.py

(Same license-header replacement as in language_model.py above: the "# coding=utf-8" line and the 2020 Apache header are replaced by the single 2022 copyright line.)
 """Megatron Module"""
 ...

@@ -43,11 +30,10 @@ class MegatronModule(torch.nn.Module):
         self.share_word_embeddings = share_word_embeddings

-    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
+    def state_dict_for_save_checkpoint(self, prefix='',
                                        keep_vars=False):
         """Use this function to override the state dict for
         saving checkpoints."""
-        return self.state_dict(destination, prefix, keep_vars)
+        return self.state_dict(prefix=prefix, keep_vars=keep_vars)

     def word_embeddings_weight(self):
 ...

@@ -198,14 +184,13 @@ class Float16Module(MegatronModule):
         return outputs

-    def state_dict(self, destination=None, prefix='', keep_vars=False):
-        return self.module.state_dict(destination, prefix, keep_vars)
+    def state_dict(self, prefix='', keep_vars=False):
+        return self.module.state_dict(prefix=prefix, keep_vars=keep_vars)

-    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
+    def state_dict_for_save_checkpoint(self, prefix='',
                                        keep_vars=False):
-        return self.module.state_dict_for_save_checkpoint(
-            destination, prefix, keep_vars)
+        return self.module.state_dict_for_save_checkpoint(prefix=prefix,
+                                                          keep_vars=keep_vars)

     def load_state_dict(self, state_dict, strict=True):
 ...
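
The `Float16Module` methods above are pure delegation: the wrapper forwards `prefix` and `keep_vars` to the wrapped module so the checkpoint keys look the same whether or not the wrapper is present. A rough, self-contained sketch of that delegation pattern (the `HalfWrapperSketch` name is invented for illustration and is not the real `Float16Module`):

# Rough sketch of a half-precision wrapper that delegates state_dict handling
# to the wrapped module with keyword arguments, in the spirit of the
# Float16Module diff above. HalfWrapperSketch is a made-up name.
import torch


class HalfWrapperSketch(torch.nn.Module):

    def __init__(self, module):
        super().__init__()
        # Register the wrapped module under the name 'module' and convert
        # its parameters to fp16.
        self.add_module('module', module.half())

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

    def state_dict(self, prefix='', keep_vars=False):
        # Forward to the wrapped module so the checkpoint looks as if the
        # wrapper were not there.
        return self.module.state_dict(prefix=prefix, keep_vars=keep_vars)

    def load_state_dict(self, state_dict, strict=True):
        self.module.load_state_dict(state_dict, strict=strict)


wrapped = HalfWrapperSketch(torch.nn.Linear(4, 4))
print(sorted(wrapped.state_dict().keys()))  # prints ['bias', 'weight']

The printed keys carry no extra wrapper prefix, which is the behavior the delegation is meant to preserve.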

megatron/model/multiple_choice.py

(Same license-header replacement as above.)
 """Multiple choice model."""
 ...

@@ -100,19 +87,17 @@ class MultipleChoice(MegatronModule):
             return multichoice_logits
         return lm_output

-    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
+    def state_dict_for_save_checkpoint(self, prefix='',
                                        keep_vars=False):
         """For easy load when model is combined with other heads,
         add an extra key."""

         state_dict_ = {}
         state_dict_[self._language_model_key] \
-            = self.language_model.state_dict_for_save_checkpoint(
-                destination, prefix, keep_vars)
+            = self.language_model.state_dict_for_save_checkpoint(
+                prefix=prefix, keep_vars=keep_vars)
         if self.post_process:
             state_dict_[self._multichoice_head_key] \
-                = self.multichoice_head.state_dict(
-                    destination, prefix, keep_vars)
+                = self.multichoice_head.state_dict(prefix=prefix,
+                                                   keep_vars=keep_vars)
         return state_dict_

     def load_state_dict(self, state_dict, strict=True):
 ...

megatron/model/realm_model.py
 ...

@@ -87,18 +87,18 @@ class ICTBertModel(MegatronModule):
         else:
             raise ValueError("Cannot embed block without block model.")

-    def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
+    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
         """Save dict with state dicts of each of the models."""

         state_dict_ = {}
         if self.use_query_model:
             state_dict_[self._query_key] \
-                = self.query_model.state_dict_for_save_checkpoint(
-                    destination, prefix, keep_vars)
+                = self.query_model.state_dict_for_save_checkpoint(
+                    prefix=prefix, keep_vars=keep_vars)

         if self.use_block_model:
             state_dict_[self._block_key] \
-                = self.block_model.state_dict_for_save_checkpoint(
-                    destination, prefix, keep_vars)
+                = self.block_model.state_dict_for_save_checkpoint(
+                    prefix=prefix, keep_vars=keep_vars)

         return state_dict_
 ...

@@ -181,17 +181,17 @@ class IREncoderBertModel(MegatronModule):
         ict_logits = self.ict_head(pooled_output)
         return ict_logits, None

-    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
+    def state_dict_for_save_checkpoint(self, prefix='',
                                        keep_vars=False):
         """For easy load when model is combined with other heads,
         add an extra key."""

         state_dict_ = {}
         state_dict_[self._language_model_key] \
-            = self.language_model.state_dict_for_save_checkpoint(
-                destination, prefix, keep_vars)
+            = self.language_model.state_dict_for_save_checkpoint(
+                prefix=prefix, keep_vars=keep_vars)
         state_dict_[self._ict_head_key] \
-            = self.ict_head.state_dict(destination, prefix, keep_vars)
+            = self.ict_head.state_dict(prefix=prefix, keep_vars=keep_vars)
         return state_dict_

     def load_state_dict(self, state_dict, strict=True):
 ...

megatron/model/t5_model.py

(Same license-header replacement as above.)
 """T5 model."""
 ...

@@ -178,23 +165,23 @@ class T5Model(MegatronModule):
             encoder_output = lm_output
             return encoder_output

-    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
+    def state_dict_for_save_checkpoint(self, prefix='',
                                        keep_vars=False):
         """For easy load when model is combined with other heads,
         add an extra key."""

         state_dict_ = {}
         state_dict_[self._language_model_key] \
-            = self.language_model.state_dict_for_save_checkpoint(
-                destination, prefix, keep_vars)
+            = self.language_model.state_dict_for_save_checkpoint(
+                prefix=prefix, keep_vars=keep_vars)
         if self.post_process and self.add_decoder:
             state_dict_[self._lm_head_key] \
-                = self.lm_head.state_dict_for_save_checkpoint(
-                    destination, prefix, keep_vars)
+                = self.lm_head.state_dict_for_save_checkpoint(
+                    prefix=prefix, keep_vars=keep_vars)
         # Save word_embeddings.
         if self.post_process and not self.pre_process and self.add_decoder:
             state_dict_[self._word_embeddings_for_head_key] \
-                = self.word_embeddings.state_dict(destination, prefix, keep_vars)
+                = self.word_embeddings.state_dict(prefix=prefix,
+                                                  keep_vars=keep_vars)
         return state_dict_

     def load_state_dict(self, state_dict, strict=True):
 ...

The remaining files shown on this page contain only the license-header change and no code changes:

megatron/model/transformer.py
megatron/model/utils.py
megatron/model/vision/classification.py
megatron/model/vision/inpainting.py
megatron/model/vision/vit_backbone.py
megatron/mpu/__init__.py
megatron/mpu/cross_entropy.py
megatron/mpu/data.py
megatron/mpu/initialize.py
megatron/mpu/layers.py
megatron/mpu/mappings.py
megatron/mpu/random.py
megatron/mpu/tests/commons.py
megatron/mpu/tests/test_cross_entropy.py
megatron/mpu/tests/test_data.py

In each of the files listed above with +1 -14, the old "# coding=utf-8" line and the 2020 Apache 2.0 license header are replaced by the single line "# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.", exactly as shown for megatron/model/language_model.py. In megatron/model/vision/inpainting.py (+1 -2) only the "# coding=utf-8" and 2020 copyright lines are replaced with the 2022 copyright line; its BSD license notice ("This source code is licensed under the BSD license found in the LICENSE file in the root directory of this source tree.") is unchanged.