OpenDAS / Megatron-LM / Commits

Commit 601b19b7, authored Mar 30, 2020 by Mohammad
Commit message: tasks tested
Parent: 259062c2
Showing 5 changed files with 139 additions and 85 deletions (+139 −85).
tasks/eval_utils.py        +10  −8
tasks/finetune_utils.py    +48  −40
tasks/glue/finetune.py     +28  −15
tasks/main.py              +19  −10
tasks/race/finetune.py     +34  −12
tasks/eval_utils.py @ 601b19b7

@@ -20,26 +20,28 @@ import time

 import torch

+from megatron import get_args
 from megatron import mpu
 from megatron import print_rank_0
-from .finetune_utils import build_data_loader
-from .finetune_utils import process_batch
+from tasks.finetune_utils import build_data_loader
+from tasks.finetune_utils import process_batch


-def accuracy_func_provider(args, single_dataset_provider):
+def accuracy_func_provider(single_dataset_provider):
     """Provide function that calculates accuracies."""
+    args = get_args()

     # Build dataloaders.
     datapaths = args.valid_data
     dataloaders = []
     for datapath in datapaths:
-        dataset = single_dataset_provider(datapath, args)
+        dataset = single_dataset_provider(datapath)
         dataloader = build_data_loader(dataset, args.batch_size,
                                        num_workers=args.num_workers,
                                        drop_last=(mpu.get_data_parallel_world_size() > 1))
         dataloaders.append((dataset.dataset_name, dataloader))

-    def metrics_func(model, args_, epoch, output_predictions=False):
+    def metrics_func(model, epoch, output_predictions=False):
         print_rank_0('calculating metrics ...')
         correct = 0
         total = 0

@@ -48,7 +50,7 @@ def accuracy_func_provider(args, single_dataset_provider):
             named_predictions = []
             names = 'predictions'
         for name, dataloader in dataloaders:
-            output = calculate_correct_answers(name, model, dataloader, args_,
+            output = calculate_correct_answers(name, model, dataloader,
                                                epoch, output_predictions)
             if not output_predictions:
                 correct_ans, total_count = output

@@ -70,7 +72,7 @@ def accuracy_func_provider(args, single_dataset_provider):
     return metrics_func


-def calculate_correct_answers(name, model, dataloader, args,
+def calculate_correct_answers(name, model, dataloader,
                               epoch, output_predictions):
     """Calculate correct over total answers and return prediction if the
     `output_predictions` is true."""

@@ -89,7 +91,7 @@ def calculate_correct_answers(name, model, dataloader, args,
         ids = []
         for _, batch in enumerate(dataloader):
             # Run the model forward.
-            tokens, types, labels_, attention_mask = process_batch(batch, args)
+            tokens, types, labels_, attention_mask = process_batch(batch)
             logits = model(tokens, attention_mask, types)

             # Add output predictions.
             if output_predictions:
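The change repeated throughout this commit: instead of threading args through every call, each function now pulls the parsed arguments (and the tokenizer) from Megatron's global accessors. Below is a minimal sketch of how the new accuracy_func_provider is driven. MyDataset is a hypothetical placeholder, not part of the diff; the real providers appear in tasks/glue/finetune.py and tasks/race/finetune.py further down.

# --- sketch, not part of the commit ---
from megatron import get_args, get_tokenizer
from tasks.eval_utils import accuracy_func_provider


def single_dataset_provider(datapath):
    # No `args` parameter any more: configuration comes from the global store.
    args = get_args()
    tokenizer = get_tokenizer()
    # MyDataset is a stand-in for a task dataset such as MNLIDataset/RaceDataset.
    return MyDataset('validation', [datapath], tokenizer, args.seq_length)


# Build the end-of-epoch callback; it reads args.valid_data internally.
metrics_func = accuracy_func_provider(single_dataset_provider)

# The returned callback also lost its `args_` parameter:
#     metrics_func(model, epoch, output_predictions=False)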
tasks/finetune_utils.py @ 601b19b7

@@ -17,22 +17,23 @@

 import torch

+from megatron import get_args
+from megatron import get_timers
 from megatron import mpu
-from megatron.data.tokenizer import add_tokenizer_to_args
+from megatron import print_rank_0
+from megatron.checkpointing import load_checkpoint
+from megatron.checkpointing import save_checkpoint
 from megatron.training import evaluate_and_print_results
-from megatron.training import initialize_megatron
 from megatron.training import setup_model_and_optimizer
 from megatron.training import train_step
 from megatron.training import training_log
 from megatron.utils import check_adlr_autoresume_termination
-from megatron.utils import load_checkpoint
-from megatron import print_rank_0
 from megatron.utils import reduce_losses
-from megatron.utils import save_checkpoint


-def process_batch(batch, args):
+def process_batch(batch):
     """Process batch and produce inputs for the model."""
+    args = get_args()

     tokens = batch['text'].long().cuda().contiguous()
     types = batch['types'].long().cuda().contiguous()

@@ -44,8 +45,9 @@ def process_batch(batch, args):

     return tokens, types, labels, attention_mask


-def _cross_entropy_forward_step(batch, model, args, timers):
+def _cross_entropy_forward_step(batch, model):
     """Simple forward step with cross-entropy loss."""
+    timers = get_timers()

     # Get the batch.
     timers('batch generator').start()

@@ -53,7 +55,7 @@ def _cross_entropy_forward_step(batch, model, args, timers):
         batch_ = next(batch)
     except:
         batch_ = batch
-    tokens, types, labels, attention_mask = process_batch(batch_, args)
+    tokens, types, labels, attention_mask = process_batch(batch_)
     timers('batch generator').stop()

     # Forward model.

@@ -101,8 +103,9 @@ def _build_infinite_size_dataloader(dataloader):

     iterator = dataloader.__iter__()


-def _build_train_valid_dataloaders(train_dataset, valid_dataset, args):
+def _build_train_valid_dataloaders(train_dataset, valid_dataset):
     """Traing and validation dataloaders."""
+    args = get_args()
     print_rank_0('building train and validation dataloaders ...')

     # Training dataset.

@@ -121,9 +124,10 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset, args):

 def _train(model, optimizer, lr_scheduler, forward_step,
-           train_dataloader, valid_dataloader, end_of_epoch_callback,
-           timers, args, writer):
+           train_dataloader, valid_dataloader, end_of_epoch_callback):
     """Train the model."""
+    args = get_args()
+    timers = get_timers()

     # Turn on training mode which enables dropout.
     model.train()

@@ -157,95 +161,99 @@ def _train(model, optimizer, lr_scheduler, forward_step,
             start_iteration = 0

             # Train for one step.
-            losses_dict, _ = train_step(forward_step, batch, model,
-                                        optimizer, lr_scheduler, args, timers)
+            losses_dict, _ = train_step(forward_step, batch, model,
+                                        optimizer, lr_scheduler)
             iteration += 1

             # Logging.
             report_memory_flag = training_log(losses_dict, losses_dict_sum,
                                               optimizer.param_groups[0]['lr'],
                                               iteration, optimizer.loss_scale,
-                                              report_memory_flag, writer,
-                                              args, timers)
+                                              report_memory_flag)

             # Autoresume
             if args.adlr_autoresume and \
                (iteration % args.adlr_autoresume_interval == 0):
-                check_adlr_autoresume_termination(iteration, model, optimizer,
-                                                  lr_scheduler, args)
+                check_adlr_autoresume_termination(iteration, model, optimizer,
                                                   lr_scheduler)

             # Checkpointing
             if args.save and args.save_interval and \
                iteration % args.save_interval == 0:
-                save_checkpoint(iteration, model, optimizer, lr_scheduler, args)
+                save_checkpoint(iteration, model, optimizer, lr_scheduler)

             # Evaluation
             if args.eval_interval and iteration % args.eval_interval == 0:
                 prefix = 'iteration {}'.format(iteration)
                 evaluate_and_print_results(prefix, forward_step,
-                                           valid_dataloader, model, args,
-                                           writer, iteration, timers, False)
+                                           valid_dataloader, model,
+                                           iteration, False)

         # Checkpointing at the end of each epoch.
         if args.save:
-            save_checkpoint(iteration, model, optimizer, lr_scheduler, args)
+            save_checkpoint(iteration, model, optimizer, lr_scheduler)

         # Callback at the end of each epoch.
         if end_of_epoch_callback is not None:
-            end_of_epoch_callback(model, args, epoch)
+            end_of_epoch_callback(model, epoch)


-def finetune(args, train_valid_datasets_provider, model_provider,
+def finetune(train_valid_datasets_provider, model_provider,
              forward_step=_cross_entropy_forward_step,
              end_of_epoch_callback_provider=None):
     """Main finetune function used across all tasks."""

-    # Initialize megatron and get args, timers, and Tensorboard writer.
-    timers, writer = initialize_megatron(
-        'finetune model for {} ...'.format(args.task), args)
-
-    # Add tokenizer to the args.
-    add_tokenizer_to_args(args, args.tokenizer_type)
+    args = get_args()
+    timers = get_timers()

     # Train and validation data loaders.
     timers('train/valid/test dataset/dataloder').start()
     if args.epochs > 0:
-        train_dataset, valid_dataset = train_valid_datasets_provider(args)
+        train_dataset, valid_dataset = train_valid_datasets_provider()
         train_dataloader, valid_dataloader = _build_train_valid_dataloaders(
-            train_dataset, valid_dataset, args)
+            train_dataset, valid_dataset)
     timers('train/valid/test dataset/dataloder').stop()

     # Build calback function.
     timers('callback function').start()
     end_of_epoch_callback = None
     if end_of_epoch_callback_provider is not None:
-        end_of_epoch_callback = end_of_epoch_callback_provider(args)
+        end_of_epoch_callback = end_of_epoch_callback_provider()
     timers('callback function').stop()

     # Build model, optimizer and learning rate scheduler.
-    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider,
-                                                               args)
+    timers('model and optimizer').start()
+    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
     timers('model and optimizer').stop()

     # If pretrained checkpoint is provided and we have not trained for
     # any iteration (i.e., iteration is zero), then load the pretrained
     # checkpoint.
     timers('pretrained checkpoint').start()
     if args.iteration == 0 and args.pretrained_checkpoint is not None:
         original_load = args.load
         args.load = args.pretrained_checkpoint
-        _ = load_checkpoint(model, None, None, args)
+        _ = load_checkpoint(model, None, None)
         args.load = original_load
         # This is critical when only model is loaded. We should make sure
         # master parameters are also updated.
         if args.fp16:
             optimizer._model_params_to_master_params()
     timers('pretrained checkpoint').stop()

     # Print setup timing.
     print_rank_0('done with setups ...')
     timers.log(['train/valid/test dataset/dataloder', 'callback function',
                 'model and optimizer', 'pretrained checkpoint'])
     print_rank_0('training ...')

     # Finetune the model.
     if args.epochs > 0:
         _train(model, optimizer, lr_scheduler, forward_step,
-               train_dataloader, valid_dataloader, end_of_epoch_callback,
-               timers, args, writer)
+               train_dataloader, valid_dataloader, end_of_epoch_callback)
     # Or just evaluate.
     else:
         if end_of_epoch_callback is not None:
             print_rank_0('evaluation only mode, setting epoch to -1')
-            end_of_epoch_callback(model, args, epoch=-1, output_predictions=True)
+            end_of_epoch_callback(model, epoch=-1, output_predictions=True)

     print_rank_0('done :-)')
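finetune() no longer initializes Megatron or attaches a tokenizer to args; it assumes initialize_megatron has already run (see tasks/main.py below) and reads everything it needs from get_args() and get_timers(). A task module written against the new interface therefore reduces to argument-free providers, roughly as sketched here. The MyTaskDataset and MyTaskModel names are hypothetical placeholders; the concrete GLUE and RACE versions follow in this commit.

# --- sketch, not part of the commit ---
from megatron import get_args, get_tokenizer, print_rank_0
from tasks.finetune_utils import finetune


def train_valid_datasets_provider():
    """Argument-free provider: reads the global args and tokenizer."""
    args = get_args()
    tokenizer = get_tokenizer()
    train_ds = MyTaskDataset('training', args.train_data,
                             tokenizer, args.seq_length)
    valid_ds = MyTaskDataset('validation', args.valid_data,
                             tokenizer, args.seq_length)
    return train_ds, valid_ds


def model_provider():
    """Argument-free model builder (note padded_vocab_size, as in this commit)."""
    args = get_args()
    print_rank_0('building model for {} ...'.format(args.task))
    return MyTaskModel(vocab_size=args.padded_vocab_size,
                       hidden_size=args.hidden_size)


def main():
    finetune(train_valid_datasets_provider, model_provider)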
tasks/glue/finetune.py @ 601b19b7

@@ -15,32 +15,41 @@

 """GLUE finetuning/evaluation."""

+from megatron import get_args
+from megatron import get_tokenizer
 from megatron import print_rank_0
 from megatron.model.classification import Classification
 from tasks.eval_utils import accuracy_func_provider
 from tasks.finetune_utils import finetune


-def glue_classification(args, num_classes, Dataset,
+def glue_classification(num_classes, Dataset,
                         name_from_datapath_func):

-    def train_valid_datasets_provider(args):
+    def train_valid_datasets_provider():
         """Build train and validation dataset."""
+        args = get_args()
+        tokenizer = get_tokenizer()

         train_dataset = Dataset('training', args.train_data,
-                                args.tokenizer, args.seq_length)
+                                tokenizer, args.seq_length)
         valid_dataset = Dataset('validation', args.valid_data,
-                                args.tokenizer, args.seq_length)
+                                tokenizer, args.seq_length)

         return train_dataset, valid_dataset

-    def model_provider(args):
+    def model_provider():
         """Build the model."""
+        args = get_args()

         print_rank_0('building classification model for {} ...'.format(
             args.task))

         return Classification(num_classes=num_classes,
                               num_layers=args.num_layers,
-                              vocab_size=args.vocab_size,
+                              vocab_size=args.padded_vocab_size,
                               hidden_size=args.hidden_size,
                               num_attention_heads=args.num_attention_heads,
                               embedding_dropout_prob=args.hidden_dropout,

@@ -50,25 +59,29 @@ def glue_classification(args, num_classes, Dataset,
                               checkpoint_activations=args.checkpoint_activations)

-    def metrics_func_provider(args):
+    def metrics_func_provider():
         """Privde metrics callback function."""

-        def single_dataset_provider(datapath, args):
+        def single_dataset_provider(datapath):
+            args = get_args()
+            tokenizer = get_tokenizer()
             name = name_from_datapath_func(datapath)
-            return Dataset(name, [datapath], args.tokenizer, args.seq_length)
-        return accuracy_func_provider(args, single_dataset_provider)
+            return Dataset(name, [datapath], tokenizer, args.seq_length)
+        return accuracy_func_provider(single_dataset_provider)

     """Finetune/evaluate."""
-    finetune(args, train_valid_datasets_provider, model_provider,
+    finetune(train_valid_datasets_provider, model_provider,
              end_of_epoch_callback_provider=metrics_func_provider)


-def main(args):
+def main():
+    args = get_args()

     if args.task == 'MNLI':
         num_classes = 3
-        from .mnli import MNLIDataset as Dataset
+        from tasks.glue.mnli import MNLIDataset as Dataset

         def name_from_datapath(datapath):
             return datapath.split('MNLI')[-1].strip(
                 '.tsv').strip('/').replace('_', '-')

@@ -76,7 +89,7 @@ def main(args):

     elif args.task == 'QQP':
         num_classes = 2
-        from .qqp import QQPDataset as Dataset
+        from tasks.glue.qqp import QQPDataset as Dataset

         def name_from_datapath(datapath):
             return datapath.split('QQP')[-1].strip(
                 '.tsv').strip('/').replace('_', '-')

@@ -85,4 +98,4 @@ def main(args):
         raise NotImplementedError('GLUE task {} is not implemented.'.format(
             args.task))

-    glue_classification(args, num_classes, Dataset, name_from_datapath)
+    glue_classification(num_classes, Dataset, name_from_datapath)
tasks/main.py @ 601b19b7

@@ -20,29 +20,38 @@ import sys

 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
                                              os.path.pardir)))

-from arguments import get_args
+from megatron import get_args
+from megatron.initialize import initialize_megatron


 def get_tasks_args(parser):
     """Provide extra arguments required for tasks."""
-    group = parser.add_argument_group('tasks', 'tasks configurations')
-    parser.add_argument('--task', type=str, required=True,
-                        help='task name.')
+    group = parser.add_argument_group(title='tasks')
+    group.add_argument('--task', type=str, required=True,
+                       help='Task name.')
     group.add_argument('--epochs', type=int, required=True,
-                       help='number of finetunning epochs. Zero results in '
+                       help='Number of finetunning epochs. Zero results in '
                        'evaluation only.')
-    parser.add_argument('--pretrained-checkpoint', type=str, default=None,
-                        help='pretrained checkpoint used for finetunning.')
+    group.add_argument('--pretrained-checkpoint', type=str, default=None,
+                       help='Pretrained checkpoint used for finetunning.')
     group.add_argument('--keep-last', action='store_true',
-                       help='keep the last batch (maybe incomplete) in'
+                       help='Keep the last batch (maybe incomplete) in'
                        'the data loader')
+    group.add_argument('--train-data', nargs='+', default=None,
+                       help='Whitespace separated paths or corpora names '
+                       'for training.')
+    group.add_argument('--valid-data', nargs='*', default=None,
+                       help='path(s) to the validation data.')

     return parser


 if __name__ == '__main__':

-    args = get_args(extra_args_provider=get_tasks_args)
+    initialize_megatron(extra_args_provider=get_tasks_args)
+    args = get_args()

     if args.task == 'RACE':
         from race.finetune import main
     elif args.task in ['MNLI', 'QQP']:

@@ -51,4 +60,4 @@ if __name__ == '__main__':
         raise NotImplementedError('Task {} is not implemented.'.format(
             args.task))

-    main(args)
+    main()
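The entry point now goes through megatron.initialize.initialize_megatron, which parses the command line (including the extra task arguments) and stores the result globally; get_args() then returns that same namespace everywhere, which is what lets the args parameters above disappear. The pattern in miniature, as a standalone sketch (simplified; Megatron's real implementation lives in the megatron package and does considerably more, e.g. distributed setup):

# --- standalone sketch of the global-args pattern, not Megatron's actual code ---
import argparse

_GLOBAL_ARGS = None


def initialize(extra_args_provider=None, argv=None):
    """Parse arguments once and stash them in a module-level global."""
    global _GLOBAL_ARGS
    parser = argparse.ArgumentParser()
    if extra_args_provider is not None:
        parser = extra_args_provider(parser)
    _GLOBAL_ARGS = parser.parse_args(argv)


def get_args():
    """Accessor used everywhere instead of threading `args` through calls."""
    assert _GLOBAL_ARGS is not None, 'call initialize() first'
    return _GLOBAL_ARGS


def get_tasks_args(parser):
    group = parser.add_argument_group(title='tasks')
    group.add_argument('--task', type=str, required=True)
    group.add_argument('--epochs', type=int, required=True)
    return parser


if __name__ == '__main__':
    initialize(extra_args_provider=get_tasks_args,
               argv=['--task', 'RACE', '--epochs', '3'])
    print(get_args().task)  # prints 'RACE'; no args object passed between functions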
tasks/race/finetune.py @ 601b19b7

 # coding=utf-8
 # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.

 """Race."""

-from megatron.model.multiple_choice import MultipleChoice
+from megatron import get_args
+from megatron import get_tokenizer
+from megatron import print_rank_0
+from megatron.model.multiple_choice import MultipleChoice
 from tasks.eval_utils import accuracy_func_provider
 from tasks.finetune_utils import finetune
 from tasks.race.data import RaceDataset


-def train_valid_datasets_provider(args):
+def train_valid_datasets_provider():
     """Provide train and validation datasets."""
+    args = get_args()
+    tokenizer = get_tokenizer()

     train_dataset = RaceDataset('training', args.train_data,
-                                args.tokenizer, args.seq_length)
+                                tokenizer, args.seq_length)
     valid_dataset = RaceDataset('validation', args.valid_data,
-                                args.tokenizer, args.seq_length)
+                                tokenizer, args.seq_length)

     return train_dataset, valid_dataset


-def model_provider(args):
+def model_provider():
     """Build the model."""
+    args = get_args()

     print_rank_0('building multichoice model for RACE ...')

     return MultipleChoice(num_layers=args.num_layers,
-                          vocab_size=args.vocab_size,
+                          vocab_size=args.padded_vocab_size,
                           hidden_size=args.hidden_size,
                           num_attention_heads=args.num_attention_heads,
                           embedding_dropout_prob=args.hidden_dropout,

@@ -35,17 +55,19 @@ def model_provider(args):
                           checkpoint_activations=args.checkpoint_activations)


-def metrics_func_provider(args):
+def metrics_func_provider():
     """Privde metrics callback function."""
+    args = get_args()
+    tokenizer = get_tokenizer()

-    def single_dataset_provider(datapath, args):
+    def single_dataset_provider(datapath):
         name = datapath.split('RACE')[-1].strip('/').replace('/', '-')
-        return RaceDataset(name, [datapath], args.tokenizer, args.seq_length)
+        return RaceDataset(name, [datapath], tokenizer, args.seq_length)

-    return accuracy_func_provider(args, single_dataset_provider)
+    return accuracy_func_provider(single_dataset_provider)


-def main(args):
+def main():

-    finetune(args, train_valid_datasets_provider, model_provider,
+    finetune(train_valid_datasets_provider, model_provider,
              end_of_epoch_callback_provider=metrics_func_provider)