OpenDAS / Megatron-LM · Commits

Commit b938ec51, authored Mar 23, 2021 by Jared Casper
Parent: 3b91262e

    Tasks seems to be working.

Showing 10 changed files with 154 additions and 163 deletions.
Changed files:

    megatron/model/multiple_choice.py   +2   -3
    megatron/schedules.py               +13  -0
    megatron/text_generation_utils.py   +11  -41
    tasks/eval_utils.py                 +69  -51
    tasks/finetune_utils.py             +25  -25
    tasks/glue/finetune.py              +4   -15
    tasks/main.py                       +5   -0
    tasks/race/finetune.py              +5   -12
    tasks/zeroshot_gpt/evaluate.py      +11  -15
    tools/generate_samples_gpt.py       +9   -1
megatron/model/multiple_choice.py (+2, -3)

@@ -30,7 +30,7 @@ from .module import MegatronModule

 class MultipleChoice(MegatronModule):

     def __init__(self,
                  num_tokentypes=2,
                  pre_process=True,
                  post_process=True):

@@ -58,7 +58,7 @@ class MultipleChoice(MegatronModule):
                                     init_method)
         self._multichoice_head_key = 'multichoice_head'

     def set_input_tensor(self, input_tensor):
         self.language_model.set_input_tensor(input_tensor)

     def forward(self, model_input, attention_mask, tokentype_ids=None):

@@ -127,4 +127,3 @@ class MultipleChoice(MegatronModule):
             print_rank_last('***WARNING*** could not find {} in the checkpoint, '
                             'initializing to random'.format(
                                 self._multichoice_head_key))
-
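
The set_input_tensor hook kept here is the seam the new schedules.py code relies on: a stage that is not first in the pipeline computes from an activation injected by the schedule rather than from its own data-loader input. A minimal sketch of that contract with a toy module (not Megatron's MultipleChoice; the names other than set_input_tensor are illustrative):

import torch
import torch.nn as nn

class ToyPipelineStage(nn.Module):
    def __init__(self, hidden_size=8, pre_process=True):
        super().__init__()
        self.pre_process = pre_process          # True only on the first stage
        self.embed = nn.Linear(4, hidden_size)  # stand-in for the embedding
        self.layer = nn.Linear(hidden_size, hidden_size)
        self.input_tensor = None

    def set_input_tensor(self, input_tensor):
        # Called by the schedule with the activation received from the
        # previous stage (None on the first stage).
        self.input_tensor = input_tensor

    def forward(self, tokens):
        if self.pre_process:
            hidden = self.embed(tokens)     # first stage: use real inputs
        else:
            hidden = self.input_tensor      # later stage: use injected tensor
        return self.layer(hidden)

stage = ToyPipelineStage(pre_process=False)
stage.set_input_tensor(torch.randn(2, 8))
out = stage(tokens=None)  # tokens are ignored on non-first stages
print(out.shape)          # torch.Size([2, 8])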

megatron/schedules.py (+13, -0)

@@ -24,6 +24,18 @@ from megatron import mpu
 from megatron import p2p_communication


+def get_forward_backward_func():
+    args = get_args()
+    if mpu.get_pipeline_model_parallel_world_size() > 1:
+        if args.virtual_pipeline_model_parallel_size is not None:
+            forward_backward_func = forward_backward_pipelining_with_interleaving
+        else:
+            forward_backward_func = forward_backward_pipelining_without_interleaving
+    else:
+        forward_backward_func = forward_backward_no_pipelining
+    return forward_backward_func
+
+
 def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
     """Forward step for passed-in model.

@@ -34,6 +46,7 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_r
     timers = get_timers()

     timers('forward-compute').start()
+    # TODO
     model.module.module.set_input_tensor(input_tensor)
     output_tensor, loss_func = forward_step_func(data_iterator, model)
     if mpu.is_pipeline_last_stage():
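
get_forward_backward_func() centralizes the choice of schedule so callers (like the reworked tasks/eval_utils.py below) no longer branch on the pipeline configuration themselves. A self-contained sketch of the same dispatch shape, with stub schedules standing in for the real ones:

# Stub schedules; Megatron's real versions overlap and interleave stages.
def run_no_pipelining(step_fn, batches):
    return [step_fn(b) for b in batches]

def run_pipelining(step_fn, batches):
    return [step_fn(b) for b in batches]

def run_pipelining_interleaved(step_fn, batches):
    return [step_fn(b) for b in batches]

def get_forward_backward_func(pipeline_world_size, virtual_pipeline_size):
    # Same decision tree as the diff above: interleaving only applies when a
    # virtual pipeline is configured, pipelining only when stages > 1.
    if pipeline_world_size > 1:
        if virtual_pipeline_size is not None:
            return run_pipelining_interleaved
        return run_pipelining
    return run_no_pipelining

fwd_bwd = get_forward_backward_func(pipeline_world_size=1,
                                    virtual_pipeline_size=None)
print(fwd_bwd(lambda b: b * 2, [1, 2, 3]))  # [2, 4, 6]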

megatron/text_generation_utils.py (+11, -41)

@@ -26,9 +26,8 @@ import torch.nn.functional as F
 from megatron import get_args
 from megatron import get_tokenizer
 from megatron import mpu
-from megatron.training import communicate
 from megatron.utils import get_ltor_masks_and_position_ids
+from megatron.p2p_communication import recv_forward, send_forward


 def get_batch(context_tokens):
     """Generate batch from context tokens."""

@@ -395,55 +394,26 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
                  layer_past=None, get_key_value=None,
                  forward_method_parallel_output=None):

-    # Hidden size changes when not using recompute, need to tell communicate()
-    # the correct size
+    # Hidden size changes when not using recompute, need to tell p2p_communicate
+    # functions the correct size
     args = get_args()
     orig_seq_length = args.seq_length
     args.seq_length = tokens.shape[1]

-    if not mpu.is_pipeline_first_stage():
-        input_tensor, _ = communicate(tensor_send_next=None,
-                                      tensor_send_prev=None,
-                                      recv_forward=True,
-                                      recv_backward=False)
-    else:
-        input_tensor = None
+    input_tensor = recv_forward()

     # Forward pass through the model.
-    if mpu.is_pipeline_first_stage():
-        assert input_tensor is None
-        if mpu.is_pipeline_last_stage():
-            output_tensor = model(tokens, position_ids, attention_mask,
-                                  tokentype_ids=tokentype_ids,
-                                  layer_past=layer_past,
-                                  get_key_value=get_key_value,
-                                  forward_method_parallel_output=forward_method_parallel_output)
-        else:
-            output_tensor = model(tokens, position_ids, attention_mask,
-                                  tokentype_ids=tokentype_ids,
-                                  layer_past=layer_past,
-                                  get_key_value=get_key_value)
-    elif mpu.is_pipeline_last_stage():
-        assert input_tensor is not None
-        output_tensor = model(input_tensor, attention_mask,
-                              layer_past=layer_past,
-                              get_key_value=get_key_value,
-                              forward_method_parallel_output=forward_method_parallel_output)
-    else:
-        assert input_tensor is not None
-        output_tensor = model(input_tensor, attention_mask,
-                              layer_past=layer_past,
-                              get_key_value=get_key_value)
+    model.set_input_tensor(input_tensor)
+    output_tensor = model(tokens, position_ids, attention_mask,
+                          tokentype_ids=tokentype_ids,
+                          layer_past=layer_past,
+                          get_key_value=get_key_value,
+                          forward_method_parallel_output=forward_method_parallel_output)

     if get_key_value:
         output_tensor, layer_past = output_tensor

-    if not mpu.is_pipeline_last_stage():
-        communicate(tensor_send_next=output_tensor,
-                    tensor_send_prev=None,
-                    recv_forward=False,
-                    recv_backward=False)
+    send_forward(output_tensor)

     args.seq_length = orig_seq_length
     if get_key_value:
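
The rewrite drops the explicit first/last-stage branching around communicate() in favor of recv_forward()/send_forward(), which the diff implies are safe to call unconditionally: presumably they no-op at the pipeline boundaries. A toy model of that behavior (the transport and stage bookkeeping here are stand-ins, not megatron.p2p_communication):

class FakePipeline:
    def __init__(self, rank, world_size):
        self.rank, self.world_size = rank, world_size
        self.wire = {}  # pretend transport between adjacent stages

    def is_first_stage(self):
        return self.rank == 0

    def is_last_stage(self):
        return self.rank == self.world_size - 1

    def recv_forward(self):
        # First stage has no predecessor: return None instead of receiving,
        # so the caller needs no is_pipeline_first_stage() check.
        if self.is_first_stage():
            return None
        return self.wire.get(self.rank - 1)

    def send_forward(self, output):
        # Last stage has no successor: sending is a no-op.
        if self.is_last_stage():
            return
        self.wire[self.rank] = output

pipe = FakePipeline(rank=0, world_size=1)
assert pipe.recv_forward() is None  # single stage: both directions no-op
pipe.send_forward("activations")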

tasks/eval_utils.py (+69, -51)

@@ -17,13 +17,14 @@
 import os
 import time
+from functools import partial

 import torch

 from megatron import get_args
 from megatron import print_rank_last, is_last_rank
 from megatron import mpu
-from megatron.training import communicate
+from megatron.schedules import get_forward_backward_func
 from tasks.finetune_utils import build_data_loader
 from tasks.finetune_utils import process_batch

@@ -38,7 +39,7 @@ def accuracy_func_provider(single_dataset_provider):
     for datapath in datapaths:
         dataset = single_dataset_provider(datapath)
         dataloader = build_data_loader(
-            dataset, args.micro_batch_size, num_workers=args.num_workers,
+            dataset, args.orig_micro_batch_size, num_workers=args.num_workers,
             drop_last=(mpu.get_data_parallel_world_size() > 1))
         dataloaders.append((dataset.dataset_name, dataloader))

@@ -73,14 +74,61 @@ def accuracy_func_provider(single_dataset_provider):
     return metrics_func


 def calculate_correct_answers(name, model, dataloader,
                               epoch, output_predictions):
     """Calculate correct over total answers and return prediction if the
     `output_predictions` is true."""
     args = get_args()
+    forward_backward_func = get_forward_backward_func()
     start_time = time.time()
-    model.eval()
-    saved_batch_size = args.micro_batch_size
+    for m in model:
+        m.eval()
+    saved_micro_batch_size = args.micro_batch_size
+    saved_global_batch_size = args.global_batch_size
+
+    ds = dataloader.dataset
+    if hasattr(ds, 'sample_multiplier'):
+        sample_multiplier = ds.sample_multiplier
+    else:
+        sample_multiplier = 1
+    micro_batch_size_times_data_parallel = args.orig_micro_batch_size * args.data_parallel_size
+    num_micro_batches = args.orig_global_batch_size // micro_batch_size_times_data_parallel
+
+    def loss_func(output_predictions, labels, output_tensor):
+        logits = output_tensor
+
+        loss_dict = {}
+        # Add output predictions.
+        if output_predictions:
+            assert False
+            loss_dict['softmaxes'] = torch.nn.Softmax(dim=-1)(
+                logits.float()).data.cpu().numpy().tolist()
+            loss_dict['labels'] = labels.data.cpu().numpy().tolist()
+            loss_dict['ids'] = batch['uid'].cpu().numpy().tolist()
+        # Compute the correct answers.
+        predicted = torch.argmax(logits, dim=-1)
+        corrects = (predicted == labels)
+        # Add to the counters.
+        loss_dict['total'] = labels.size(0)
+        loss_dict['correct'] = corrects.sum().item()
+
+        return 0, loss_dict
+
+    # defined inside to capture output_predictions
+    def correct_answers_forward_step(batch, model):
+        try:
+            batch_ = next(batch)
+        except BaseException:
+            batch_ = batch
+        tokens, types, labels, attention_mask = process_batch(batch_)
+
+        # Forward model.
+        args = get_args()
+        output_tensor = model(tokens, attention_mask, tokentype_ids=types)
+
+        return output_tensor, partial(loss_func, output_predictions, labels)

     with torch.no_grad():
         # For all the batches in the dataset.
         total = 0

@@ -92,60 +140,30 @@ def calculate_correct_answers(name, model, dataloader,
             labels = []
             ids = []
         for _, batch in enumerate(dataloader):
-            # Run the model forward.
-            tokens, types, labels_, attention_mask = process_batch(batch)
             # For evaluation only mode we use drop_last = False to get all the
             # samples, which means we might not have a full batch, so we
             # adjust batch_size here to actual batch size of data
-            actual_batch_size = len(labels_)
+            actual_batch_size = len(batch['label'])
             # ... applying sample_multiplier if necessary
-            ds = dataloader.dataset
-            if hasattr(ds, 'sample_multiplier'):
-                actual_batch_size *= ds.sample_multiplier
-            args.micro_batch_size = actual_batch_size
+            args.micro_batch_size = actual_batch_size * sample_multiplier
+            args.global_batch_size = actual_batch_size * sample_multiplier * num_micro_batches

-            if not mpu.is_pipeline_first_stage():
-                input_tensor, _ = communicate(
-                    tensor_send_next=None,
-                    tensor_send_prev=None,
-                    recv_forward=True,
-                    recv_backward=False)
-            else:
-                input_tensor = None
-
-            # Forward model.
-            if mpu.is_pipeline_first_stage():
-                assert input_tensor is None
-                output_tensor = model(tokens, attention_mask, tokentype_ids=types)
-            else:
-                assert input_tensor is not None
-                output_tensor = model(input_tensor, attention_mask)
+            loss_dicts = forward_backward_func(correct_answers_forward_step, batch, model,
+                                               optimizer=None, timers=None, forward_only=True)

-            if mpu.is_pipeline_last_stage():
-                logits = output_tensor
-
-                # Add output predictions.
-                if output_predictions:
-                    softmaxes.extend(torch.nn.Softmax(dim=-1)(
-                        logits.float()).data.cpu().numpy().tolist())
-                    labels.extend(labels_.data.cpu().numpy().tolist())
-                    ids.extend(batch['uid'].cpu().numpy().tolist())
-                # Compute the correct answers.
-                predicted = torch.argmax(logits, dim=-1)
-                corrects = (predicted == labels_)
-                # Add to the counters.
-                total += labels_.size(0)
-                correct += corrects.sum().item()
-            else:
-                communicate(tensor_send_next=output_tensor,
-                            tensor_send_prev=None,
-                            recv_forward=False,
-                            recv_backward=False)
+            for loss_dict in loss_dicts:
+                if output_predictions:
+                    softmaxes.extend(loss_dict['softmaxes'])
+                    labels.extend(loss_dict['labels'])
+                    ids.extend(loss_dict['ids'])
+                total += loss_dict['total']
+                correct += loss_dict['correct']

-    model.train()
-    args.micro_batch_size = saved_batch_size
+    for m in model:
+        m.train()
+    args.micro_batch_size = saved_micro_batch_size
+    args.global_batch_size = saved_global_batch_size

     # Reduce.
     if mpu.is_pipeline_last_stage():
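
The core pattern introduced here is that a forward step now returns a pair (output_tensor, loss_func), with batch-specific values such as labels bound into loss_func via functools.partial; the schedule applies loss_func on the last pipeline stage and hands back one dict per micro-batch for aggregation. A runnable sketch of that flow, with a stand-in for forward_backward_func(..., forward_only=True):

from functools import partial

import torch

def loss_func(labels, output_tensor):
    # Runs later, once the schedule has the last stage's output tensor.
    predicted = torch.argmax(output_tensor, dim=-1)
    corrects = (predicted == labels)
    return 0, {'total': labels.size(0), 'correct': corrects.sum().item()}

def forward_step(batch, model):
    tokens, labels = batch
    output_tensor = model(tokens)
    # Bind labels now; the schedule only passes the output tensor later.
    return output_tensor, partial(loss_func, labels)

def run_forward_only(forward_step, batches, model):
    # Stand-in for forward_backward_func(..., forward_only=True):
    # collect one loss dict per micro-batch.
    loss_dicts = []
    for batch in batches:
        output_tensor, loss_fn = forward_step(batch, model)
        _, loss_dict = loss_fn(output_tensor)
        loss_dicts.append(loss_dict)
    return loss_dicts

model = torch.nn.Linear(4, 3)
batches = [(torch.randn(5, 4), torch.randint(0, 3, (5,))) for _ in range(2)]
dicts = run_forward_only(forward_step, batches, model)
print(sum(d['correct'] for d in dicts), '/', sum(d['total'] for d in dicts))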

tasks/finetune_utils.py (+25, -25)

@@ -15,6 +15,8 @@
 """Finetune utilities."""

+from functools import partial
+
 import torch

 from megatron import get_args

@@ -46,7 +48,20 @@ def process_batch(batch):
     return tokens, types, labels, attention_mask


-def _cross_entropy_forward_step(batch, model, input_tensor):
+def cross_entropy_loss_func(labels, output_tensor):
+    logits = output_tensor
+
+    # Cross-entropy loss.
+    loss_func = torch.nn.CrossEntropyLoss()
+    loss = loss_func(logits.contiguous().float(), labels)
+
+    # Reduce loss for logging.
+    averaged_loss = average_losses_across_data_parallel_group([loss])
+
+    return loss, {'lm loss': averaged_loss[0]}
+
+
+def _cross_entropy_forward_step(batch, model):
     """Simple forward step with cross-entropy loss."""
     timers = get_timers()

@@ -60,25 +75,9 @@ def _cross_entropy_forward_step(batch, model, input_tensor):
     timers('batch-generator').stop()

     # Forward model.
-    if mpu.is_pipeline_first_stage():
-        assert input_tensor is None
-        output_tensor = model(tokens, attention_mask, tokentype_ids=types)
-    else:
-        assert input_tensor is not None
-        output_tensor = model(input_tensor, attention_mask)
+    output_tensor = model(tokens, attention_mask, tokentype_ids=types)

-    if mpu.is_pipeline_last_stage():
-        logits = output_tensor
-
-        # Cross-entropy loss.
-        loss_func = torch.nn.CrossEntropyLoss()
-        loss = loss_func(logits.contiguous().float(), labels)
-
-        # Reduce loss for logging.
-        averaged_loss = average_losses_across_data_parallel_group([loss])
-
-        return loss, {'lm loss': averaged_loss[0]}
-
-    return output_tensor
+    return output_tensor, partial(cross_entropy_loss_func, labels)


 def build_data_loader(dataset, micro_batch_size, num_workers, drop_last):

@@ -135,6 +134,8 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset):
     # This is necessary so pipeline transfers know what size they are
     # and the LR schedule, which is based on samples seen, gets set
     # correctly.
+    args.orig_micro_batch_size = args.micro_batch_size
+    args.orig_global_batch_size = args.global_batch_size
     if hasattr(train_dataset, 'sample_multiplier'):
         args.micro_batch_size *= train_dataset.sample_multiplier
         args.global_batch_size *= train_dataset.sample_multiplier

@@ -149,7 +150,8 @@ def _train(model, optimizer, lr_scheduler, forward_step,
     timers = get_timers()

     # Turn on training mode which enables dropout.
-    model.train()
+    for m in model:
+        m.train()

     # Tracking loss.
     losses_dict_sum = {}

@@ -180,10 +182,8 @@ def _train(model, optimizer, lr_scheduler, forward_step,
         start_iteration = 0

         # Train for one step.
-        losses_dict, skipped_iter, grad_norm = train_step(forward_step,
-                                                          batch, model,
-                                                          optimizer, lr_scheduler)
+        out = train_step(forward_step,
+                         batch, model, optimizer, lr_scheduler)
+        losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = out
         iteration += 1

         # Logging.

@@ -195,7 +195,7 @@ def _train(model, optimizer, lr_scheduler, forward_step,
                 iteration,
                 optimizer.get_loss_scale().item(),
                 report_memory_flag, skipped_iter,
-                grad_norm, params_norm)
+                grad_norm, params_norm, num_zeros_in_grad)

         # Autoresume
         if args.adlr_autoresume and \
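
The new args.orig_micro_batch_size / args.orig_global_batch_size snapshots are taken before the sample_multiplier scaling, so evaluation can later reconstruct the schedule's micro-batch count from the unscaled values. A worked example, under the assumption that a multiple-choice dataset expands each sample into four rows (one per answer option); the args object here is a stand-in for Megatron's:

class Args:
    micro_batch_size = 8
    global_batch_size = 64

args = Args()
sample_multiplier = 4  # assumed: one forward row per answer option

# Snapshot the user-visible sizes first (the new orig_* attributes) ...
args.orig_micro_batch_size = args.micro_batch_size
args.orig_global_batch_size = args.global_batch_size

# ... then scale the effective sizes the pipeline actually moves around.
args.micro_batch_size *= sample_multiplier    # 8 -> 32 rows per micro-batch
args.global_batch_size *= sample_multiplier   # 64 -> 256 rows per step

# Evaluation can now rebuild the schedule's micro-batch count from orig_*,
# exactly as calculate_correct_answers() does above:
data_parallel_size = 2
num_micro_batches = args.orig_global_batch_size // (
    args.orig_micro_batch_size * data_parallel_size)
print(num_micro_batches)  # 64 // (8 * 2) = 4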

tasks/glue/finetune.py (+4, -15)

@@ -19,7 +19,7 @@ from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_tokenizer
 from megatron import mpu
-from megatron.model.classification import Classification, ClassificationFirstStage, ClassificationIntermediateStage, ClassificationLastStage
+from megatron.model.classification import Classification
 from tasks.eval_utils import accuracy_func_provider
 from tasks.finetune_utils import finetune

@@ -39,25 +39,14 @@ def glue_classification(num_classes, Dataset,
         return train_dataset, valid_dataset

-    def model_provider():
+    def model_provider(pre_process=True, post_process=True):
         """Build the model."""
         args = get_args()

         print_rank_0('building classification model for {} ...'.format(
             args.task))
-        if mpu.get_pipeline_model_parallel_world_size() > 1:
-            # Determine model based on position of stage in pipeline.
-            if mpu.is_pipeline_first_stage():
-                model = ClassificationFirstStage(
-                    num_classes=num_classes, num_tokentypes=2)
-            elif mpu.is_pipeline_last_stage():
-                model = ClassificationLastStage(
-                    num_classes=num_classes, num_tokentypes=2)
-            else:
-                model = ClassificationIntermediateStage(
-                    num_classes=num_classes, num_tokentypes=2)
-        else:
-            model = Classification(num_classes=num_classes, num_tokentypes=2)
+        model = Classification(num_classes=num_classes, num_tokentypes=2,
+                               pre_process=pre_process, post_process=post_process)

         return model
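
Passing pre_process/post_process into a single Classification class replaces the four per-stage classes; presumably the caller that builds the model (megatron.training.get_model) now derives the flags from the rank's position in the pipeline. A sketch of that division of labor with a hypothetical builder; tasks/race/finetune.py below makes the same move for MultipleChoice:

def model_provider(pre_process=True, post_process=True):
    # Stand-in for the task's provider: return what each stage would own.
    return {'embedding': pre_process,   # only the first stage embeds
            'head': post_process}       # only the last stage classifies

def build_for_stage(stage_rank, num_stages):
    # The caller, not the provider, knows where this rank sits in the pipe.
    return model_provider(pre_process=(stage_rank == 0),
                          post_process=(stage_rank == num_stages - 1))

for rank in range(3):
    print(rank, build_for_stage(rank, 3))
# 0 {'embedding': True, 'head': False}
# 1 {'embedding': False, 'head': False}
# 2 {'embedding': False, 'head': True}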

tasks/main.py (+5, -0)

@@ -70,6 +70,11 @@ if __name__ == '__main__':
     initialize_megatron(extra_args_provider=get_tasks_args)

     args = get_args()
+
+    if args.num_layers_per_virtual_pipeline_stage is not None:
+        print("Interleaved pipeline schedule is not yet supported for downstream tasks.")
+        exit()
+
     if args.task == 'RACE':
         from race.finetune import main
     elif args.task in ['MNLI', 'QQP']:

tasks/race/finetune.py (+5, -12)

@@ -19,7 +19,7 @@ from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_tokenizer
 from megatron import mpu
-from megatron.model.multiple_choice import MultipleChoice, MultipleChoiceFirstStage, MultipleChoiceIntermediateStage, MultipleChoiceLastStage
+from megatron.model.multiple_choice import MultipleChoice
 from tasks.eval_utils import accuracy_func_provider
 from tasks.finetune_utils import finetune
 from tasks.race.data import RaceDataset

@@ -38,20 +38,13 @@ def train_valid_datasets_provider():
     return train_dataset, valid_dataset


-def model_provider():
+def model_provider(pre_process=True, post_process=True):
     """Build the model."""

     print_rank_0('building multichoice model for RACE ...')
-    if mpu.get_pipeline_model_parallel_world_size() > 1:
-        # Determine model based on position of stage in pipeline.
-        if mpu.is_pipeline_first_stage():
-            model = MultipleChoiceFirstStage(num_tokentypes=2)
-        elif mpu.is_pipeline_last_stage():
-            model = MultipleChoiceLastStage(num_tokentypes=2)
-        else:
-            model = MultipleChoiceIntermediateStage(num_tokentypes=2)
-    else:
-        model = MultipleChoice(num_tokentypes=2)
+    model = MultipleChoice(num_tokentypes=2,
+                           pre_process=pre_process,
+                           post_process=post_process)

     return model

tasks/zeroshot_gpt/evaluate.py (+11, -15)

@@ -25,8 +25,9 @@ from megatron import get_tokenizer
 from megatron import mpu
 from megatron.checkpointing import load_checkpoint
 from megatron.model import GPTModel, GPTModelFirstStage, GPTModelLastStage, GPTModelIntermediateStage
-from megatron.training import get_model, communicate
+from megatron.training import get_model
 from megatron.utils import get_ltor_masks_and_position_ids
+from megatron.p2p_communication import recv_forward, send_forward
 from tasks.finetune_utils import build_data_loader
 from .datasets import build_dataset

@@ -98,14 +99,7 @@ def forward_step(batch, model, eval_metric):
     args.micro_batch_size = len(labels)

     # Forward model.
-    if not mpu.is_pipeline_first_stage():
-        input_tensor, _ = communicate(
-            tensor_send_next=None,
-            tensor_send_prev=None,
-            recv_forward=True,
-            recv_backward=False)
-    else:
-        input_tensor = None
+    input_tensor = recv_forward()

     # Forward pass through the model.
     if mpu.is_pipeline_first_stage():

@@ -118,12 +112,7 @@ def forward_step(batch, model, eval_metric):
         assert input_tensor is not None
         output = model(input_tensor, attention_mask)

-    if not mpu.is_pipeline_last_stage():
-        communicate(tensor_send_next=output,
-                    tensor_send_prev=None,
-                    recv_forward=False,
-                    recv_backward=False)
-        return None
+    send_forward(output)

     if mpu.is_pipeline_last_stage():
         # For loss, return the unreduced loss.

@@ -214,6 +203,10 @@ def main():
     """Main program."""
     args = get_args()

+    if args.num_layers_per_virtual_pipeline_stage is not None:
+        print("Interleaved pipeline schedule is not yet supported for text generation.")
+        exit()
+
     if args.task == 'LAMBADA':
         eval_metric = 'accuracy'
     elif args.task == 'WIKITEXT103':

@@ -227,6 +220,9 @@ def main():
     if args.load is not None:
         _ = load_checkpoint(model, None, None)

+    assert len(model) == 1, "Above condition should have caught this"
+    model = model[0]
+
     # Data stuff.
     dataset = build_dataset(args.task)
     dataloader = build_data_loader(dataset, args.micro_batch_size,

tools/generate_samples_gpt.py (+9, -1)

@@ -96,12 +96,20 @@ def main():
                                        'no_load_rng': True,
                                        'no_load_optim': True})

+    args = get_args()
+    if args.num_layers_per_virtual_pipeline_stage is not None:
+        print("Interleaved pipeline schedule is not yet supported for text generation.")
+        exit()
+
     # Set up model and load checkpoint.
     model = get_model(model_provider)
-    args = get_args()
+
     if args.load is not None:
         _ = load_checkpoint(model, None, None)

+    assert len(model) == 1, "Above condition should have caught this"
+    model = model[0]
+
     # Generate samples.
     if args.num_samples == 0:
         args.micro_batch_size = 1
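
The new assert len(model) == 1 followed by model = model[0] suggests get_model now returns a list of model chunks, with more than one entry only under the interleaved schedule that the guard above rules out. A toy sketch of that unwrap, with a hypothetical stand-in for get_model:

def toy_get_model(num_layers_per_virtual_pipeline_stage, num_layers=12):
    # Hypothetical: one chunk normally, several under an interleaved schedule.
    if num_layers_per_virtual_pipeline_stage is None:
        return [object()]
    chunks = num_layers // num_layers_per_virtual_pipeline_stage
    return [object() for _ in range(chunks)]

num_layers_per_virtual_pipeline_stage = None
if num_layers_per_virtual_pipeline_stage is not None:
    print("Interleaved pipeline schedule is not yet supported for text generation.")
    raise SystemExit

model = toy_get_model(num_layers_per_virtual_pipeline_stage)
assert len(model) == 1, "Above condition should have caught this"
model = model[0]  # the rest of the script works with a single module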