OpenDAS / Megatron-LM / Commits

Commit 01fc0833, authored Jun 01, 2021 by Jared Casper

Merge branch 'vit_pipeline_fixes' into 'main'

vit pipeline fixes

See merge request ADLR/megatron-lm!276

Parents: 217f54b3, ccae9dbd
Changes: 6 changed files with 197 additions and 111 deletions.

  megatron/checkpointing.py        +1   -1
  megatron/model/vit_model.py      +76  -52
  pretrain_vit.py                  +20  -14
  tasks/vision/classification.py   +3   -2
  tasks/vision/eval_utils.py       +56  -20
  tasks/vision/finetune_utils.py   +41  -22
megatron/checkpointing.py

@@ -60,8 +60,8 @@ def check_checkpoint_args(checkpoint_args):
     _compare('num_layers')
     _compare('hidden_size')
     _compare('num_attention_heads')
-    _compare('max_position_embeddings')
     if args.vocab_file:
+        _compare('max_position_embeddings')
         _compare('make_vocab_size_divisible_by')
         _compare('padded_vocab_size')
         _compare('tokenizer_type')
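The only change here moves the max_position_embeddings comparison under the args.vocab_file branch, presumably so that vision checkpoints, which carry no vocabulary file, are not forced to match a text model's position-embedding count. For context, a minimal sketch of what a _compare-style checkpoint/argument consistency check typically looks like is below; the helper body and error message are assumptions for illustration, not code copied from the repository.

from megatron import get_args


def check_checkpoint_args_sketch(checkpoint_args):
    """Sketch (assumed shape): assert that critical saved arguments match
    the arguments of the current run before loading a checkpoint."""
    args = get_args()

    def _compare(arg_name):
        checkpoint_value = getattr(checkpoint_args, arg_name)
        args_value = getattr(args, arg_name)
        assert checkpoint_value == args_value, (
            '{} mismatch: checkpoint has {}, current run has {}'.format(
                arg_name, checkpoint_value, args_value))

    _compare('num_layers')
    _compare('hidden_size')
    _compare('num_attention_heads')
    if args.vocab_file:
        # Vocabulary-dependent settings only matter for text models.
        _compare('max_position_embeddings')
        _compare('padded_vocab_size')
        _compare('tokenizer_type')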
megatron/model/vit_model.py

@@ -50,11 +50,11 @@ class VitMlpHead(MegatronModule):
     def forward(self, hidden_states, sequence_index=0):
         # hidden_states: [b, s, h]
         # sequence_index: index of the token to pool.
-        x = hidden_states[:, sequence_index, :]
-        x = self.dense_in(x)
-        x = torch.tanh(x)
-        x = self.dense_out(x)
-        return x
+        hidden_state = hidden_states[:, sequence_index, :]
+        dense_in_result = self.dense_in(hidden_state)
+        tanh_result = torch.tanh(dense_in_result)
+        dense_out_result = self.dense_out(tanh_result)
+        return dense_out_result


 def twod_interpolate_position_embeddings_hook(

@@ -122,8 +122,12 @@ def twod_interpolate_position_embeddings_hook(
 class VitModel(MegatronModule):
     """Vision Transformer Model."""

-    def __init__(self, num_classes, finetune=False):
-        super(VitModel, self).__init__()
+    def __init__(self,
+                 num_classes,
+                 finetune=False,
+                 pre_process=True,
+                 post_process=True):
+        super(VitModel, self).__init__(share_word_embeddings=False)
         args = get_args()
         self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy

@@ -136,6 +140,8 @@ class VitModel(MegatronModule):
             args.init_method_std, args.num_layers
         )
+        self.pre_process = pre_process
+        self.post_process = post_process
         self.hidden_size = args.hidden_size
         self.num_classes = num_classes
         self.patch_dim = args.patch_dim

@@ -148,63 +154,81 @@ class VitModel(MegatronModule):
         self.seq_length = self.num_patches + 1
         self.flatten_dim = self.patch_dim * self.patch_dim * args.num_channels

-        # cls_token
-        self.cls_token = torch.nn.Parameter(
-            torch.randn(1, 1, self.hidden_size))
-        torch.nn.init.zeros_(self.cls_token)
-
-        # Linear encoder
-        self.linear_encoder = torch.nn.Linear(
-            self.flatten_dim, self.hidden_size)
-
-        # embedding
-        self.position_embeddings = torch.nn.Embedding(
-            self.seq_length, self.hidden_size)
-        init_method_normal(args.init_method_std)(
-            self.position_embeddings.weight)
-        self.position_ids = torch.arange(
-            self.seq_length).expand(1, -1).cuda()
-
-        self.position_embeddings._register_load_state_dict_pre_hook(
-            twod_interpolate_position_embeddings_hook)
-
-        self.embedding_dropout = torch.nn.Dropout(args.hidden_dropout)
+        if self.pre_process:
+            # cls_token
+            self.cls_token = torch.nn.Parameter(
+                torch.randn(1, 1, self.hidden_size))
+            torch.nn.init.zeros_(self.cls_token)
+
+            # Linear encoder
+            self.linear_encoder = torch.nn.Linear(
+                self.flatten_dim, self.hidden_size)
+
+            # embedding
+            self.position_embeddings = torch.nn.Embedding(
+                self.seq_length, self.hidden_size)
+            init_method_normal(args.init_method_std)(
+                self.position_embeddings.weight)
+            self.position_ids = torch.arange(
+                self.seq_length).expand(1, -1).cuda()
+
+            self.position_embeddings._register_load_state_dict_pre_hook(
+                twod_interpolate_position_embeddings_hook)
+
+            self.embedding_dropout = torch.nn.Dropout(args.hidden_dropout)

         # Transformer
         self.transformer = ParallelTransformer(
-            self.init_method, self.scaled_init_method
+            self.init_method,
+            self.scaled_init_method,
+            pre_process=self.pre_process,
+            post_process=self.post_process
         )

-        # MLP head
-        if not self.finetune:
-            self.mlp_head = VitMlpHead(self.hidden_size, self.num_classes)
-        else:
-            self.class_head = get_linear_layer(
-                self.hidden_size, num_classes, torch.nn.init.zeros_
-            )
+        if self.post_process:
+            # MLP head
+            if not self.finetune:
+                self.mlp_head = VitMlpHead(self.hidden_size, self.num_classes)
+            else:
+                self.class_head = get_linear_layer(
+                    self.hidden_size, num_classes, torch.nn.init.zeros_
+                )
+
+    def set_input_tensor(self, input_tensor):
+        """See megatron.model.transformer.set_input_tensor()"""
+        self.transformer.set_input_tensor(input_tensor)

-    def forward(self, x):
-        x = einops.rearrange(
-            x,
-            "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
-            p1=self.patch_dim,
-            p2=self.patch_dim,
-        )
-
-        assert x.dtype == torch.half
-        x = self.linear_encoder(x)
-        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
-        x = torch.cat((cls_tokens, x), dim=1)
-
-        x = x + self.position_embeddings(self.position_ids)
-        x = self.embedding_dropout(x)
-        x = self.transformer(x, None)
-
-        if not self.finetune:
-            x = self.mlp_head(x)
-        else:
-            x = self.class_head(x[:, 0, :])
-
-        return x
+    def forward(self, input):
+        if self.pre_process:
+            rearranged_input = einops.rearrange(
+                input,
+                "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
+                p1=self.patch_dim,
+                p2=self.patch_dim,
+            )
+
+            assert rearranged_input.dtype == torch.half
+            encoder_output = self.linear_encoder(rearranged_input)
+
+            cls_tokens = self.cls_token.expand(encoder_output.shape[0], -1, -1)
+            concatenated_tokens = torch.cat((cls_tokens, encoder_output), dim=1)

+            token_embeddings = concatenated_tokens + \
+                self.position_embeddings(self.position_ids)
+            hidden_states = self.embedding_dropout(token_embeddings)
+        else:
+            hidden_states = input
+
+        hidden_states = self.transformer(hidden_states, None)
+
+        if self.post_process:
+            if not self.finetune:
+                hidden_states = self.mlp_head(hidden_states)
+            else:
+                hidden_states = self.class_head(hidden_states[:, 0, :])
+
+        return hidden_states
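Taken together, these changes make VitModel splittable across pipeline-parallel stages: the patch embedding, cls token, and position embeddings are built only when pre_process is true (first stage), the MLP/classification head only when post_process is true (last stage), and set_input_tensor() lets the schedule hand an intermediate stage its predecessor's activations. A minimal, self-contained sketch of the same pattern is below; the Stage class, its layer sizes, and its names are illustrative assumptions, not Megatron code.

import torch


class Stage(torch.nn.Module):
    """Toy version of the pre_process/post_process split used by VitModel."""

    def __init__(self, hidden_size, num_classes, pre_process, post_process):
        super().__init__()
        self.pre_process = pre_process
        self.post_process = post_process
        self.input_tensor = None                 # set by the pipeline schedule
        if self.pre_process:
            # Only the first stage embeds the raw (flattened-patch) input.
            self.embed = torch.nn.Linear(16, hidden_size)
        # Stand-in for the ParallelTransformer body that every stage owns.
        self.body = torch.nn.Linear(hidden_size, hidden_size)
        if self.post_process:
            # Only the last stage owns the classification head.
            self.head = torch.nn.Linear(hidden_size, num_classes)

    def set_input_tensor(self, input_tensor):
        # Intermediate/last stages receive activations from the previous stage.
        self.input_tensor = input_tensor

    def forward(self, batch):
        if self.pre_process:
            hidden = self.embed(batch)           # raw input -> hidden states
        else:
            hidden = self.input_tensor           # activations from prior stage
        hidden = self.body(hidden)
        if self.post_process:
            hidden = self.head(hidden)           # logits only on the last stage
        return hidden

On a two-stage pipeline, stage 0 would be Stage(hidden_size, num_classes, pre_process=True, post_process=False) and stage 1 Stage(hidden_size, num_classes, pre_process=False, post_process=True); the schedule calls set_input_tensor() on stage 1 with stage 0's output before running its forward.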
pretrain_vit.py

@@ -17,19 +17,22 @@
 import torch
 import torch.nn.functional as F
+from functools import partial
 from megatron import get_args, get_timers, mpu, print_rank_0
 from megatron.data.vit_dataset import build_train_valid_datasets
 from megatron.model.vit_model import VitModel
 from megatron.training import pretrain
 from megatron.utils import average_losses_across_data_parallel_group


-def model_provider():
+def model_provider(pre_process=True, post_process=True):
     """Build the model."""

     print_rank_0("building VIT model ...")
     args = get_args()

-    model = VitModel(num_classes=args.num_classes)
+    model = VitModel(num_classes=args.num_classes,
+                     pre_process=pre_process,
+                     post_process=post_process)
     return model


 def get_batch(data_iterator):

@@ -42,10 +45,21 @@ def get_batch(data_iterator):
     return images, labels


-def forward_step(data_iterator, model, input_tensor):
+def loss_func(labels, output_tensor):
+    logits = output_tensor.contiguous().float()
+    loss = F.cross_entropy(logits, labels)
+
+    outputs = torch.argmax(logits, -1)
+    correct = (outputs == labels).float()
+    accuracy = torch.mean(correct)
+
+    averaged_loss = average_losses_across_data_parallel_group([loss, accuracy])
+
+    return loss, {"loss": averaged_loss[0], "accuracy": averaged_loss[1]}
+
+
+def forward_step(data_iterator, model):
     """Forward step."""
     timers = get_timers()

-    assert input_tensor is None
-
     # Get the batch.
     timers("batch-generator").start()

@@ -56,17 +70,9 @@ def forward_step(data_iterator, model, input_tensor):
     timers("batch-generator").stop()

     # Forward model. lm_labels
-    logits = model(images).contiguous().float()
-    loss = F.cross_entropy(logits, labels)
-
-    outputs = torch.argmax(logits, -1)
-    correct = (outputs == labels).float()
-    accuracy = torch.mean(correct)
-
-    averaged_loss = average_losses_across_data_parallel_group([loss, accuracy])
-
-    return loss, {"loss": averaged_loss[0], "accuracy": averaged_loss[1]}
+    output_tensor = model(images)
+
+    return output_tensor, partial(loss_func, labels)


 def train_valid_test_datasets_provider(train_val_test_num_samples):
     """Build train, valid, and test datasets."""
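The new forward_step no longer computes the loss itself; it returns the stage's output tensor together with partial(loss_func, labels), and the pipeline schedule invokes that closure only where logits actually exist (the last stage). A rough sketch of how a schedule consumes this (output_tensor, loss_callable) pair is below; toy_forward_backward is an illustrative stand-in, not Megatron's get_forward_backward_func, and the forward_step signature is simplified to take a batch directly.

from functools import partial

import torch
import torch.nn.functional as F


def loss_func(labels, output_tensor):
    # Same contract as the patched pretrain_vit.py: labels are bound via
    # partial, the schedule supplies the output tensor later.
    logits = output_tensor.contiguous().float()
    loss = F.cross_entropy(logits, labels)
    return loss, {"loss": loss.detach()}


def forward_step(batch, model):
    images, labels = batch
    output_tensor = model(images)
    return output_tensor, partial(loss_func, labels)


def toy_forward_backward(forward_step_func, batch, model, is_last_stage):
    """Illustrative stand-in for a pipeline schedule (assumption)."""
    output_tensor, loss_callable = forward_step_func(batch, model)
    if is_last_stage:
        loss, metrics = loss_callable(output_tensor)  # logits exist only here
        loss.backward()
        return metrics
    # Earlier stages just pass activations onward; no loss is computed.
    return {}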
tasks/vision/classification.py

@@ -34,13 +34,14 @@ def classification():
         )
         return train_ds, valid_ds

-    def model_provider():
+    def model_provider(pre_process=True, post_process=True):
         """Build the model."""

         args = get_args()

         print_rank_0("building classification model for ImageNet ...")

-        return VitModel(num_classes=args.num_classes, finetune=True)
+        return VitModel(num_classes=args.num_classes, finetune=True,
+                        pre_process=pre_process, post_process=post_process)

     """Finetune/evaluate."""
     finetune(
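As with pretrain_vit.py, the finetuning model_provider now forwards pre_process/post_process so the framework can build just the slice of the model each pipeline stage needs. A hedged sketch of how a caller might use that signature (this helper is an assumption, not Megatron's actual model-construction logic):

def build_pipeline_stage(model_provider, stage_rank, num_stages):
    """Assumed helper: instantiate only this pipeline rank's part of the model."""
    pre_process = (stage_rank == 0)                 # first stage embeds raw images
    post_process = (stage_rank == num_stages - 1)   # last stage owns the head
    return model_provider(pre_process=pre_process, post_process=post_process)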
tasks/vision/eval_utils.py

@@ -16,10 +16,14 @@
 """Evaluation utilities."""

 import os
+from functools import partial

 import torch

 from megatron import get_args
-from megatron import print_rank_0
+from megatron import print_rank_0, print_rank_last
 from megatron import mpu
+from megatron.schedules import get_forward_backward_func
 from tasks.vision.finetune_utils import build_data_loader
 from tasks.vision.finetune_utils import process_batch
 from torchvision import datasets, transforms

@@ -56,7 +60,7 @@ def accuracy_func_provider():
         print_rank_0("calculating metrics ...")
         correct, total = calculate_correct_answers(model, dataloader, epoch)
         percent = float(correct) * 100.0 / float(total)
-        print_rank_0(
+        print_rank_last(
             " >> |epoch: {}| overall: correct / total = {} / {} = "
             "{:.4f} %".format(epoch, correct, total, percent)
         )

@@ -67,29 +71,61 @@ def accuracy_func_provider():
 def calculate_correct_answers(model, dataloader, epoch):
     """Calculate correct over total answers"""

-    model.eval()
+    args = get_args()
+    forward_backward_func = get_forward_backward_func()
+    for m in model:
+        m.eval()
+
+    def loss_func(labels, output_tensor):
+        logits = output_tensor
+
+        loss_dict = {}
+        # Compute the correct answers.
+        predicted = torch.argmax(logits, dim=-1)
+        corrects = (predicted == labels).float()
+        # Add to the counters.
+        loss_dict['total'] = labels.size(0)
+        loss_dict['correct'] = corrects.sum().item()
+
+        return 0, loss_dict
+
+    # defined inside to capture output_predictions
+    def correct_answers_forward_step(batch, model):
+        try:
+            batch_ = next(batch)
+        except BaseException:
+            batch_ = batch
+        images, labels = process_batch(batch_)
+
+        # Forward model.
+        args = get_args()
+        output_tensor = model(images)
+
+        return output_tensor, partial(loss_func, labels)

     with torch.no_grad():
         # For all the batches in the dataset.
         total = 0
         correct = 0
         for _, batch in enumerate(dataloader):
-            # Run the model forward.
-            images, labels = process_batch(batch)
-            logits = model(images).contiguous().float()
-            # Add output predictions.
-            # Compute the correct answers.
-            predicted = torch.argmax(logits, dim=-1)
-            corrects = (predicted == labels).float()
-            # Add to the counters.
-            total += labels.size(0)
-            correct += corrects.sum().item()
-
-    model.train()
+            loss_dicts = forward_backward_func(correct_answers_forward_step,
+                                               batch, model, optimizer=None,
+                                               timers=None, forward_only=True)
+
+            for loss_dict in loss_dicts:
+                total += loss_dict['total']
+                correct += loss_dict['correct']
+
+    for m in model:
+        m.train()

     # Reduce.
-    unreduced = torch.cuda.LongTensor([correct, total])
-    torch.distributed.all_reduce(unreduced,
-                                 group=mpu.get_data_parallel_group())
-
-    # Print on screen.
-    correct_ans = unreduced[0].item()
-    total_count = unreduced[1].item()
-    return correct_ans, total_count
+    if mpu.is_pipeline_last_stage():
+        unreduced = torch.cuda.LongTensor([correct, total])
+        torch.distributed.all_reduce(unreduced,
+                                     group=mpu.get_data_parallel_group())
+
+        # Print on screen.
+        correct_ans = unreduced[0].item()
+        total_count = unreduced[1].item()
+        return correct_ans, total_count
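Evaluation now runs each batch through the same schedule as training (forward_backward_func(..., forward_only=True)), accumulates the per-microbatch 'correct'/'total' counters returned by the loss closure, and only the last pipeline stage all-reduces them across data-parallel ranks. The counter reduction itself is just a summed all_reduce; a small sketch is below, using a plain torch.distributed group rather than mpu.get_data_parallel_group() and assuming the process group is already initialized.

import torch
import torch.distributed as dist


def reduce_accuracy(correct, total, group=None):
    """Sum per-rank (correct, total) counters and return them as Python ints.

    Sketch only: assumes torch.distributed is initialized; the patched code
    restricts this to the last pipeline stage and uses the data-parallel group.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    counters = torch.tensor([correct, total], dtype=torch.long, device=device)
    dist.all_reduce(counters, group=group)        # default reduction op is SUM
    return counters[0].item(), counters[1].item()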
tasks/vision/finetune_utils.py

@@ -17,6 +17,7 @@
 import torch
 import torch.nn.functional as F
+from functools import partial
 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_timers

@@ -38,10 +39,21 @@ def process_batch(batch):
     return images, labels


-def _cross_entropy_forward_step(batch, model, input_tensor):
+def cross_entropy_loss_func(labels, output_tensor):
+    logits = output_tensor
+
+    # Cross-entropy loss.
+    loss = F.cross_entropy(logits.contiguous().float(), labels)
+
+    # Reduce loss for logging.
+    averaged_loss = average_losses_across_data_parallel_group([loss])
+
+    return loss, {'lm loss': averaged_loss[0]}
+
+
+def _cross_entropy_forward_step(batch, model):
     """Simple forward step with cross-entropy loss."""
     timers = get_timers()

-    assert input_tensor is None
-
     # Get the batch.
     timers("batch generator").start()

@@ -52,16 +64,10 @@ def _cross_entropy_forward_step(batch, model, input_tensor):
     images, labels = process_batch(batch_)
     timers("batch generator").stop()

     # Forward model.
-    logits = model(images).contiguous().float()
-
-    # Cross-entropy loss.
-    loss = F.cross_entropy(logits, labels)
-
-    # Reduce loss for logging.
-    average_loss = average_losses_across_data_parallel_group([loss])
-
-    return loss, {"lm loss": average_loss[0]}
+    output_tensor = model(images)
+
+    return output_tensor, partial(cross_entropy_loss_func, labels)


 def build_data_loader(dataset, micro_batch_size, num_workers, drop_last):

@@ -103,23 +109,28 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset):
     """Traing and validation dataloaders."""
     args = get_args()

-    print_rank_0("building train and validation dataloaders ...")
+    print_rank_0('building train and validation dataloaders ...')

     # Training dataset.
     train_dataloader = build_data_loader(train_dataset, args.micro_batch_size,
                                          args.num_workers, not args.keep_last)
     # Set the training iterations.
     args.train_iters_per_epoch = len(train_dataloader)
     args.train_iters = args.epochs * args.train_iters_per_epoch
     # Validation dataset. For this dataset, we do not need to set up
     # shuffling so we can just use a simple infinite loop.
     valid_dataloader_ = build_data_loader(valid_dataset, args.micro_batch_size,
                                           args.num_workers, not args.keep_last)
     valid_dataloader = _build_infinite_size_dataloader(valid_dataloader_)

+    # Now that we've built the data loaders, set batch_size arguments
+    # to the actual batch size the model will see for this dataset.
+    # This is necessary so pipeline transfers know what size they are
+    # and the LR schedule, which is based on samples seen, gets set
+    # correctly.
+    args.orig_micro_batch_size = args.micro_batch_size
+    args.orig_global_batch_size = args.global_batch_size
+
     return train_dataloader, valid_dataloader

@@ -135,7 +146,8 @@ def _train(
     timers = get_timers()

     # Turn on training mode which enables dropout.
-    model.train()
+    for m in model:
+        m.train()

     # Tracking loss.
     losses_dict_sum = {}

@@ -166,12 +178,16 @@ def _train(
                 start_iteration = 0

             # Train for one step.
-            losses_dict, skipped_iter = train_step(
+            losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = train_step(
                 forward_step, batch, model, optimizer, lr_scheduler)

             iteration += 1

             # Logging.
+            params_norm = None
+            if args.log_params_norm:
+                params_norm = calc_params_l2_norm(model)
             report_memory_flag = training_log(
                 losses_dict,
                 losses_dict_sum,

@@ -180,6 +196,9 @@ def _train(
                 optimizer.get_loss_scale().item(),
                 report_memory_flag,
                 skipped_iter,
+                grad_norm,
+                params_norm,
+                num_zeros_in_grad
             )

             # Autoresume
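Two patterns recur in this file: train_step now returns (losses_dict, skipped_iter, grad_norm, num_zeros_in_grad), which feed the extended training_log call, and the model argument is now a list of pipeline model chunks rather than a single module, so mode switches iterate over it (for m in model: m.train()). A tiny hedged helper that tolerates both forms, for illustration only:

import torch


def set_train_mode(model, train=True):
    """Sketch: accept either a single module or a list of pipeline chunks,
    as the updated finetune utilities now expect, and flip train/eval mode."""
    chunks = model if isinstance(model, (list, tuple)) else [model]
    for m in chunks:
        m.train(mode=train)    # m.train(mode=False) is equivalent to m.eval()
    return chunks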