OpenDAS / Megatron-LM
Commit 5942af97 (parent c2ea914f), authored Sep 23, 2022 by Jared Casper; 63 files changed.

Alias core.parallel_state as mpu and use it throughout code. RIP mpu.
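This page lists only the call-site updates; the file that actually creates the alias (presumably megatron/core/__init__.py) is among the changed files on the diff's other pages. Based on the commit title alone, the aliasing plausibly amounts to the following sketch; the exact lines are an assumption:

```python
# megatron/core/__init__.py -- hypothetical sketch; this file is not shown
# on this page of the diff, so its exact contents are assumed from the
# commit title ("Alias core.parallel_state as mpu").
from megatron.core import parallel_state
from megatron.core import tensor_parallel

# Expose parallel_state under its legacy name so call sites can write
# `from megatron.core import mpu` and keep using e.g.
# mpu.is_pipeline_last_stage().
mpu = parallel_state
```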
Showing 20 of the 63 changed files, with 26 additions and 26 deletions on this page; the remaining files are on subsequent pages of the diff.
pretrain_bert.py                                  +2  -2
pretrain_gpt.py                                   +2  -3
pretrain_ict.py                                   +1  -1
pretrain_t5.py                                    +2  -2
pretrain_vision_classify.py                       +1  -1
pretrain_vision_dino.py                           +1  -1
pretrain_vision_inpaint.py                        +1  -1
tasks/eval_utils.py                               +1  -1
tasks/finetune_utils.py                           +1  -1
tasks/glue/finetune.py                            +0  -1
tasks/msdp/prompt.py                              +1  -1
tasks/orqa/supervised/eval_utils.py               +1  -1
tasks/orqa/supervised/finetune.py                 +2  -2
tasks/orqa/unsupervised/nq.py                     +1  -1
tasks/race/finetune.py                            +0  -1
tasks/vision/classification/eval_utils.py         +1  -1
tasks/vision/finetune_utils.py                    +2  -1
tasks/vision/segmentation/finetune_segformer.py   +2  -1
tasks/vision/segmentation/finetune_setr.py        +2  -1
tasks/zeroshot_gpt/evaluate.py                    +2  -2
pretrain_bert.py

```diff
@@ -10,7 +10,7 @@ import torch.nn.functional as F
 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_timers
-from megatron import mpu
+from megatron.core import tensor_parallel
 from megatron.data.dataset_utils import build_train_valid_test_datasets
 from megatron.model import BertModel, ModelType
 from megatron.training import pretrain
@@ -46,7 +46,7 @@ def get_batch(data_iterator):
         data = next(data_iterator)
     else:
         data = None
-    data_b = mpu.broadcast_data(keys, data, datatype)
+    data_b = tensor_parallel.broadcast_data(keys, data, datatype)
 
     # Unpack.
     tokens = data_b['text'].long()
```
pretrain_gpt.py

```diff
@@ -8,8 +8,7 @@ from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_timers
 from megatron import get_tokenizer
-from megatron import mpu
-from megatron import core
+from megatron.core import tensor_parallel
 from megatron.data.gpt_dataset import build_train_valid_test_datasets
 from megatron.model import GPTModel, ModelType
 from megatron.training import pretrain
@@ -43,7 +42,7 @@ def get_batch(data_iterator):
         data = next(data_iterator)
     else:
         data = None
-    data_b = core.tensor_parallel.broadcast_data(keys, data, datatype)
+    data_b = tensor_parallel.broadcast_data(keys, data, datatype)
 
     # Unpack.
     tokens_ = data_b['text'].long()
```
pretrain_ict.py

```diff
@@ -12,7 +12,7 @@ import torch.nn.functional as F
 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_timers
-from megatron import mpu
+from megatron.core import mpu
 from megatron.data.biencoder_dataset_utils import get_ict_batch
 from megatron.data.dataset_utils import build_train_valid_test_datasets
 from megatron.model import ModelType
```
pretrain_t5.py

```diff
@@ -9,9 +9,9 @@ import torch
 from megatron import (
     get_args,
     get_timers,
-    mpu,
     print_rank_0
 )
+from megatron.core import tensor_parallel
 from megatron.data.dataset_utils import build_train_valid_test_datasets
 from megatron.model import T5Model, ModelType
 from megatron.training import pretrain
@@ -80,7 +80,7 @@ def get_batch(data_iterator):
         data = next(data_iterator)
     else:
         data = None
-    data_b = mpu.broadcast_data(keys, data, datatype)
+    data_b = tensor_parallel.broadcast_data(keys, data, datatype)
 
     # Unpack.
     tokens_enc = data_b['text_enc'].long()
```
pretrain_vision_classify.py

```diff
@@ -5,7 +5,7 @@
 import torch
 import torch.nn.functional as F
 from functools import partial
-from megatron import get_args, get_timers, mpu, print_rank_0
+from megatron import get_args, get_timers, print_rank_0
 from megatron.data.vit_dataset import build_train_valid_datasets
 from megatron.model import ModelType
 from megatron.model.vision.classification import VitClassificationModel
```
pretrain_vision_dino.py

```diff
@@ -6,7 +6,7 @@ import torch.nn as nn
 import numpy as np
 import torch.distributed as dist
 from functools import partial
-from megatron import get_args, get_timers, mpu, print_rank_0
+from megatron import get_args, get_timers, print_rank_0
 from megatron.data.vit_dataset import build_train_valid_datasets
 from megatron.model.vision.dino import DINOPretrainModel
 from megatron.model.vision.knn_monitor import knn_predict, get_feature_bank
```
pretrain_vision_inpaint.py

```diff
@@ -5,7 +5,7 @@
 import torch
 import torch.nn.functional as F
 from functools import partial
-from megatron import get_args, get_timers, mpu, print_rank_0, print_rank_last
+from megatron import get_args, get_timers, print_rank_0, print_rank_last
 from megatron.data.vit_dataset import build_train_valid_datasets
 from megatron.model.vision.inpainting import VitInpaintingModel
 from megatron.model.vision.inpainting import MitInpaintingModel
```
tasks/eval_utils.py

```diff
@@ -10,7 +10,7 @@ import torch
 from megatron import get_args
 from megatron import print_rank_last, is_last_rank
-from megatron import mpu
+from megatron.core import mpu
 from megatron.schedules import get_forward_backward_func
 from tasks.finetune_utils import build_data_loader
 from tasks.finetune_utils import process_batch
```
tasks/finetune_utils.py

```diff
@@ -9,7 +9,7 @@ import torch
 from megatron import get_args, get_num_microbatches
 from megatron import print_rank_0
 from megatron import get_timers
-from megatron import mpu
+from megatron.core import mpu
 from megatron.checkpointing import load_checkpoint
 from megatron.checkpointing import save_checkpoint
 from megatron.model import ModelType
```
tasks/glue/finetune.py

```diff
@@ -5,7 +5,6 @@
 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_tokenizer
-from megatron import mpu
 from megatron.model.classification import Classification
 from tasks.eval_utils import accuracy_func_provider
 from tasks.finetune_utils import finetune
```
tasks/msdp/prompt.py

```diff
@@ -6,10 +6,10 @@ import json
 import torch
 import requests
 from nltk import word_tokenize
-from megatron import mpu
 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_tokenizer
+from megatron.core import mpu
 from megatron.model import GPTModel
 from megatron.training import get_model
 from megatron.checkpointing import load_checkpoint
```
tasks/orqa/supervised/eval_utils.py

```diff
@@ -10,7 +10,7 @@ import torch.nn.functional as F
 from torch.utils.data import DataLoader
 from megatron import get_args, print_rank_0
-from megatron import mpu
+from megatron.core import mpu
 from megatron.utils import average_losses_across_data_parallel_group
 from tasks.finetune_utils import build_data_loader
```
tasks/orqa/supervised/finetune.py

```diff
@@ -9,8 +9,8 @@ import math
 import torch
 import torch.nn.functional as F
-from megatron import get_args, get_timers, get_tokenizer
-from megatron import mpu, print_rank_0
+from megatron import get_args, get_timers, get_tokenizer, print_rank_0
+from megatron.core import mpu
 from megatron.indexer import IndexBuilder
 from megatron.model.biencoder_model import biencoder_model_provider
 from megatron.utils import average_losses_across_data_parallel_group
```
tasks/orqa/unsupervised/nq.py

```diff
@@ -13,7 +13,7 @@ import torch
 from torch.utils.data import DataLoader
 from torch.utils.data import Dataset, BatchSampler
-from megatron import print_rank_0, get_args, get_tokenizer, mpu
+from megatron import print_rank_0, get_args, get_tokenizer
 from megatron.data.biencoder_dataset_utils import make_attention_mask
 
 def get_nq_dataset(qa_data, split):
```
tasks/race/finetune.py

```diff
@@ -5,7 +5,6 @@
 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_tokenizer
-from megatron import mpu
 from megatron.model.multiple_choice import MultipleChoice
 from tasks.eval_utils import accuracy_func_provider
 from tasks.finetune_utils import finetune
```
tasks/vision/classification/eval_utils.py

```diff
@@ -9,7 +9,7 @@ import torch
 from megatron import get_args
 from megatron import print_rank_0, print_rank_last
-from megatron import mpu
+from megatron.core import mpu
 from megatron.schedules import get_forward_backward_func
 from tasks.vision.finetune_utils import build_data_loader
 from tasks.vision.finetune_utils import process_batch
```
tasks/vision/finetune_utils.py

```diff
@@ -7,7 +7,8 @@ import torch.nn.functional as F
 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_timers
-from megatron import mpu, utils
+from megatron import utils
+from megatron.core import mpu
 from megatron.checkpointing import load_checkpoint
 from megatron.checkpointing import save_checkpoint
 from megatron.training import evaluate_and_print_results
```
tasks/vision/segmentation/finetune_segformer.py

```diff
@@ -7,7 +7,8 @@ import torch
 import torch.nn.functional as F
 from functools import partial
 from megatron import get_args, get_timers
-from megatron import mpu, print_rank_0, print_rank_last
+from megatron import print_rank_0, print_rank_last
+from megatron.core import mpu
 from tasks.vision.finetune_utils import finetune
 from tasks.vision.finetune_utils import build_data_loader
 from megatron.utils import average_losses_across_data_parallel_group
```
tasks/vision/segmentation/finetune_setr.py

```diff
@@ -6,7 +6,8 @@ import torch
 import torch.nn.functional as F
 from functools import partial
 from megatron import get_args, get_timers
-from megatron import mpu, print_rank_0, print_rank_last
+from megatron import print_rank_0, print_rank_last
+from megatron.core import mpu
 from tasks.vision.finetune_utils import finetune
 from tasks.vision.finetune_utils import build_data_loader
 from megatron.utils import average_losses_across_data_parallel_group
```
tasks/zeroshot_gpt/evaluate.py

```diff
@@ -9,7 +9,7 @@ import torch
 from megatron import get_args
 from megatron import print_rank_0, is_last_rank
 from megatron import get_tokenizer
-from megatron import mpu
+from megatron.core import mpu
 from megatron.checkpointing import load_checkpoint
 from megatron.model import GPTModel
 from megatron.training import get_model
@@ -93,7 +93,7 @@ def forward_step(batch, model, eval_metric):
     if mpu.is_pipeline_last_stage():
         # For loss, return the unreduced loss.
         if eval_metric == 'loss':
-            losses = mpu.vocab_parallel_cross_entropy(
+            losses = mpu.tensor_parallel.vocab_parallel_cross_entropy(
                 output.contiguous().float(), labels.contiguous())
             loss = torch.sum(
                 losses.view(-1) * loss_mask.contiguous().view(-1).float())
```
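Across the twenty files above, the migration converges on two import styles: `from megatron.core import mpu` where parallel-state queries are needed, and `from megatron.core import tensor_parallel` where `broadcast_data` is called. A minimal usage sketch, assuming an installed and initialized Megatron-LM and using only the calls exercised in the hunks above:

```python
# Post-commit import style, as converged on by the hunks above.
# Assumes Megatron-LM is installed and its distributed state is initialized.
from megatron.core import mpu              # alias of core.parallel_state
from megatron.core import tensor_parallel

def get_batch_sketch(keys, data, datatype):
    # broadcast_data moved from the old megatron.mpu module to
    # megatron.core.tensor_parallel.
    return tensor_parallel.broadcast_data(keys, data, datatype)

def is_last_stage_sketch():
    # Pipeline-topology queries keep the mpu spelling, which now resolves
    # to megatron.core.parallel_state.
    return mpu.is_pipeline_last_stage()
```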