OpenDAS / Megatron-LM · Commit 5e56e563

Merge master into realm-mlm

Authored Apr 28, 2020 by Neel Kant
Parents: 6c0a5bd8, 569b3dab
Changes: 107 files in the full commit; this page shows 20 changed files with 121 additions and 132 deletions (+121, -132).
megatron/training.py                 +8   -8
megatron/utils.py                    +7   -9
pretrain_bert.py                     +1   -1
pretrain_gpt2.py                     +2   -3
tasks/data_utils.py                  +8   -8
tasks/ensemble_classifier.py         +17  -13
tasks/eval_utils.py                  +1   -1
tasks/finetune_utils.py              +5   -5
tasks/glue/data.py                   +1   -4
tasks/glue/finetune.py               +3   -4
tasks/glue/mnli.py                   +1   -2
tasks/glue/qqp.py                    +1   -2
tasks/main.py                        +2   -2
tasks/race/data.py                   +0   -3
tasks/race/finetune.py               +1   -1
tasks/zeroshot_gpt2/datasets.py      +13  -13
tasks/zeroshot_gpt2/detokenizer.py   +46  -48
tasks/zeroshot_gpt2/evaluate.py      +2   -3
tools/generate_samples_gpt2.py       +1   -1
tools/merge_mp_partitions.py         +1   -1
megatron/training.py

```diff
 # coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
```
```diff
@@ -96,7 +96,6 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
                           model, optimizer, lr_scheduler,
                           train_data_iterator, valid_data_iterator)
-
     if args.do_valid:
         prefix = 'the end of training for val data'
         evaluate_and_print_results(prefix, forward_step_func,
```
```diff
@@ -173,7 +172,7 @@ def get_optimizer(model):
                               dynamic_loss_scale=args.dynamic_loss_scale,
                               dynamic_loss_args={
                                   'scale_window': args.loss_scale_window,
-                                  'min_scale': args.min_scale,
+                                  'min_scale': args.min_scale,
                                   'delayed_shift': args.hysteresis})

     return optimizer
```

(The paired `-`/`+` lines that read identically here and in several hunks below differ only in whitespace.)
```diff
@@ -228,7 +227,7 @@ def backward_step(optimizer, model, loss):
         torch.cuda.synchronize()

     # Backward pass.
-    optimizer.zero_grad()
+    optimizer.zero_grad(set_grads_to_None=True)
     if args.fp16:
         optimizer.backward(loss, update_master_grads=False)
     else:
```
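For context on this change: clearing gradients by setting `.grad` to `None` skips a redundant memset and frees the tensors until the next backward pass. `set_grads_to_None` is the spelling used by apex's `FP16_Optimizer`; plain `torch.optim` optimizers gained the equivalent flag as `set_to_none` in PyTorch 1.7. A minimal sketch of the semantics, assuming a recent PyTorch:

```python
import torch

model = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

model(torch.randn(8, 4)).sum().backward()
# Plain zeroing keeps the .grad tensors allocated and writes zeros into them.
opt.zero_grad()
print(model.weight.grad)   # tensor of zeros

model(torch.randn(8, 4)).sum().backward()
# Setting grads to None frees them; the next backward allocates fresh tensors,
# avoiding the extra memset (torch.optim spelling; apex uses set_grads_to_None).
opt.zero_grad(set_to_none=True)
print(model.weight.grad)   # None
```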
```diff
@@ -297,6 +296,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,

     # Logging.
     timers_to_log = []
+
     def add_to_logging(name):
         if name in timers.timers:
             timers_to_log.append(name)
```
```diff
@@ -431,7 +431,7 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
             # Reduce across processes.
             for key in loss_dict:
                 total_loss_dict[key] = total_loss_dict.get(key, 0.) + \
-                    loss_dict[key]
+                    loss_dict[key]

     # Move model back to the train mode.
     model.train()
```
```diff
@@ -521,14 +521,14 @@ def build_train_valid_test_data_iterators(
     # Shift the start iterations.
     if train_dataloader is not None:
         train_dataloader.batch_sampler.start_iter = args.iteration % \
-            len(train_dataloader)
+            len(train_dataloader)
         print_rank_0('setting training data start iteration to {}'.
                      format(train_dataloader.batch_sampler.start_iter))
     if valid_dataloader is not None:
         start_iter_val = (args.iteration // args.eval_interval) * \
-            args.eval_iters
+            args.eval_iters
         valid_dataloader.batch_sampler.start_iter = start_iter_val % \
-            len(valid_dataloader)
+            len(valid_dataloader)
         print_rank_0('setting validation data start iteration to {}'.
                      format(valid_dataloader.batch_sampler.start_iter))
```
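The resume logic in this last hunk maps the global iteration count back onto positions inside the epoch-length dataloaders: the train sampler resumes mid-epoch via a modulo, and the validation sampler resumes at however many eval passes have already run. A toy check of the arithmetic (all values made up):

```python
# Hypothetical resume state.
iteration = 12345        # iterations already completed
train_len = 1000         # len(train_dataloader)
eval_interval = 100      # args.eval_interval
eval_iters = 10          # args.eval_iters
valid_len = 50           # len(valid_dataloader)

train_start = iteration % train_len                          # 345
start_iter_val = (iteration // eval_interval) * eval_iters   # 1230 eval steps done
valid_start = start_iter_val % valid_len                     # 30
print(train_start, valid_start)
```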
megatron/utils.py

(license header year bumped from 2019 to 2020, as in megatron/training.py)
```diff
@@ -48,7 +48,7 @@ def report_memory(name):
         torch.cuda.max_memory_allocated() / mega_bytes)
     string += ' | cached: {}'.format(torch.cuda.memory_cached() / mega_bytes)
     string += ' | max cached: {}'.format(
-        torch.cuda.max_memory_cached() / mega_bytes)
+        torch.cuda.max_memory_cached() / mega_bytes)
     print_rank_0(string)
```
```diff
@@ -119,8 +119,7 @@ def get_ltor_masks_and_position_ids(data,
                                     eod_token,
                                     reset_position_ids,
                                     reset_attention_mask,
-                                    eod_mask_loss,
-                                    fp16):
+                                    eod_mask_loss):
     """Build masks and position id for left to right model."""

     # Extract batch size and sequence length.
```
```diff
@@ -164,14 +163,13 @@ def get_ltor_masks_and_position_ids(data,
                 i = eod_index[j]
                 # Mask attention loss.
                 if reset_attention_mask:
-                    attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
+                    attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
                 # Reset positions.
                 if reset_position_ids:
-                    position_ids[b, (i + 1):] -= (i + 1 - prev_index)
+                    position_ids[b, (i + 1):] -= (i + 1 - prev_index)
                     prev_index = i + 1

-    # Convert
-    if fp16:
-        attention_mask = attention_mask.half()
+    # Convert attention mask to binary:
+    attention_mask = (attention_mask < 0.5)

     return attention_mask, loss_mask, position_ids
```
pretrain_bert.py

(license header year bumped from 2019 to 2020, as in megatron/training.py)
pretrain_gpt2.py

(license header year bumped from 2019 to 2020, as in megatron/training.py)
```diff
@@ -65,8 +65,7 @@ def get_batch(data_iterator):
         tokenizer.eod,
         args.reset_position_ids,
         args.reset_attention_mask,
-        args.eod_mask_loss,
-        args.fp16)
+        args.eod_mask_loss)

     return tokens, labels, loss_mask, attention_mask, position_ids
```
tasks/data_utils.py

(license header year bumped from 2019 to 2020, as in megatron/training.py)
```diff
@@ -75,8 +75,8 @@ def build_tokens_types_paddings_from_ids(text_a_ids, text_b_ids, max_seq_length,
     # A.
     len_text_a = len(text_a_ids)
     ids.extend(text_a_ids)
-    types.extend([0] * len_text_a)
-    paddings.extend([1] * len_text_a)
+    types.extend([0] * len_text_a)
+    paddings.extend([1] * len_text_a)

     # [SEP].
     ids.append(sep_id)
```
```diff
@@ -87,8 +87,8 @@ def build_tokens_types_paddings_from_ids(text_a_ids, text_b_ids, max_seq_length,
     if text_b_ids is not None:
         len_text_b = len(text_b_ids)
         ids.extend(text_b_ids)
-        types.extend([1] * len_text_b)
-        paddings.extend([1] * len_text_b)
+        types.extend([1] * len_text_b)
+        paddings.extend([1] * len_text_b)

     # Cap the size.
     trimmed = False
```
```diff
@@ -111,8 +111,8 @@ def build_tokens_types_paddings_from_ids(text_a_ids, text_b_ids, max_seq_length,
     # Padding.
     padding_length = max_seq_length - len(ids)
     if padding_length > 0:
-        ids.extend([pad_id] * padding_length)
-        types.extend([pad_id] * padding_length)
-        paddings.extend([0] * padding_length)
+        ids.extend([pad_id] * padding_length)
+        types.extend([pad_id] * padding_length)
+        paddings.extend([0] * padding_length)

     return ids, types, paddings
```
tasks/ensemble_classifier.py
```diff
@@ -5,6 +5,7 @@ import collections
 import numpy as np
+import torch


 def process_files(args):
     all_predictions = collections.OrderedDict()
     all_labels = collections.OrderedDict()
```
```diff
@@ -40,12 +41,12 @@ def get_threshold(all_predictions, all_labels, one_threshold=False):
     for dataset in all_predictions:
         preds = all_predictions[dataset]
         labels = all_labels[dataset]
-        out_thresh.append(calc_threshold(preds, labels))
+        out_thresh.append(calc_threshold(preds, labels))
     return out_thresh


 def calc_threshold(p, l):
-    trials = [(i) * (1. / 100.) for i in range(100)]
+    trials = [(i) * (1. / 100.) for i in range(100)]
     best_acc = float('-inf')
     best_thresh = 0
     for t in trials:
```
```diff
@@ -58,7 +59,7 @@ def calc_threshold(p, l):

 def apply_threshold(preds, t):
     assert (np.allclose(preds.sum(-1), np.ones(preds.shape[0])))
-    prob = preds[:, -1]
+    prob = preds[:, -1]
     thresholded = (prob >= t).astype(int)
     preds = np.zeros_like(preds)
     preds[np.arange(len(thresholded)), thresholded.reshape(-1)] = 1
```
```diff
@@ -66,8 +67,8 @@ def apply_threshold(preds, t):

 def threshold_predictions(all_predictions, threshold):
     if len(threshold) != len(all_predictions):
-        threshold = [threshold[-1]] * (len(all_predictions) - len(threshold))
+        threshold = [threshold[-1]] * (len(all_predictions) - len(threshold))
     for i, dataset in enumerate(all_predictions):
         thresh = threshold[i]
         preds = all_predictions[dataset]
```
```diff
@@ -77,7 +78,7 @@ def threshold_predictions(all_predictions, threshold):

 def postprocess_predictions(all_predictions, all_labels, args):
     for d in all_predictions:
-        all_predictions[d] = all_predictions[d] / len(args.paths)
+        all_predictions[d] = all_predictions[d] / len(args.paths)

     if args.calc_threshold:
         args.threshold = get_threshold(all_predictions, all_labels,
                                        args.one_threshold)
```
```diff
@@ -98,19 +99,22 @@ def write_predictions(all_predictions, all_labels, all_uid, args):
         if args.eval:
             correct = (preds == all_labels[dataset]).sum()
             num = len(all_labels[dataset])
-            accuracy = correct / num
+            accuracy = correct / num
             count += num
             all_correct += correct
+            accuracy = (preds == all_labels[dataset]).mean()
+            print(accuracy)
         if not os.path.exists(os.path.join(args.outdir, dataset)):
             os.makedirs(os.path.join(args.outdir, dataset))
-        outpath = os.path.join(
-            args.outdir, dataset,
-            os.path.splitext(args.prediction_name)[0] + '.tsv')
+        outpath = os.path.join(
+            args.outdir, dataset,
+            os.path.splitext(args.prediction_name)[0] + '.tsv')
         with open(outpath, 'w') as f:
             f.write('id\tlabel\n')
-            f.write('\n'.join(str(uid) + '\t' + str(args.labels[p])
-                              for uid, p in zip(all_uid[dataset],
-                                                preds.tolist())))
+            f.write('\n'.join(str(uid) + '\t' + str(args.labels[p])
+                              for uid, p in zip(all_uid[dataset],
+                                                preds.tolist())))

     if args.eval:
-        print(all_correct / count)
+        print(all_correct / count)


 def ensemble_predictions(args):
```
```diff
@@ -119,7 +123,7 @@ def ensemble_predictions(args):
     write_predictions(all_predictions, all_labels, all_uid, args)


-def main():
+def main():
     parser = argparse.ArgumentParser()
     parser.add_argument('--paths', required=True, nargs='+',
                         help='paths to checkpoint directories used in ensemble')
```
```diff
@@ -135,11 +139,11 @@ def main():
                         help='use on threshold for all subdatasets')
     parser.add_argument('--threshold', nargs='+', default=None, type=float,
                         help='user supplied threshold for classification')
-    parser.add_argument('--labels', nargs='+', default=None,
+    parser.add_argument('--labels', nargs='+', default=None,
                         help='whitespace separated list of label names')
     args = parser.parse_args()
     ensemble_predictions(args)


 if __name__ == '__main__':
-    main()
\ No newline at end of file
+    main()
```
tasks/eval_utils.py

(license header year bumped from 2019 to 2020, as in megatron/training.py)
tasks/finetune_utils.py

(license header year bumped from 2019 to 2020, as in megatron/training.py)
```diff
@@ -21,7 +21,7 @@ from megatron import get_args
 from megatron import get_timers
 from megatron import mpu
 from megatron import print_rank_0
-from megatron.checkpointing import load_checkpoint
+from megatron.checkpointing import load_checkpoint
 from megatron.checkpointing import save_checkpoint
 from megatron.training import evaluate_and_print_results
 from megatron.training import setup_model_and_optimizer
```
```diff
@@ -53,7 +53,7 @@ def _cross_entropy_forward_step(batch, model):
     timers('batch generator').start()
     try:
         batch_ = next(batch)
-    except:
+    except BaseException:
         batch_ = batch
     tokens, types, labels, attention_mask = process_batch(batch_)
     timers('batch generator').stop()
```
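Replacing the bare `except:` with `except BaseException:` keeps the original catch-everything behavior while satisfying linters (flake8 flags bare excepts as E722). The surrounding pattern distinguishes "batch is an iterator" from "batch is already the data"; a minimal sketch with a hypothetical helper name:

```python
def resolve_batch(batch):
    """Return the next item if batch is an iterator, else batch itself."""
    try:
        return next(batch)
    except BaseException:  # non-iterators raise TypeError; behaves like a bare except
        return batch

print(resolve_batch(iter([{'tokens': [1, 2]}])))  # {'tokens': [1, 2]}
print(resolve_batch({'tokens': [3, 4]}))          # {'tokens': [3, 4]}
```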
```diff
@@ -146,7 +146,7 @@ def _train(model, optimizer, lr_scheduler, forward_step,
     # For each remaining epoch
     timers('interval time').start()
     for epoch in range(start_epoch, args.epochs):
-        print_rank_0('working on epoch {} ...'.format(epoch + 1))
+        print_rank_0('working on epoch {} ...'.format(epoch + 1))

         # Set the data loader epoch to shuffle the index iterator.
         train_dataloader.sampler.set_epoch(args.seed + epoch)
```
```diff
@@ -172,7 +172,7 @@ def _train(model, optimizer, lr_scheduler, forward_step,
                          report_memory_flag)

         # Autoresume
-        if args.adlr_autoresume and \
+        if args.adlr_autoresume and \
            (iteration % args.adlr_autoresume_interval == 0):
             check_adlr_autoresume_termination(iteration, model, optimizer,
                                               lr_scheduler)
```
tasks/glue/data.py

(license header year bumped from 2019 to 2020, as in megatron/training.py)
```diff
@@ -48,11 +48,9 @@ class GLUEAbstractDataset(ABC, Dataset):
         print_rank_0('  >> total number of samples: {}'.format(
                      len(self.samples)))

-
     def __len__(self):
         return len(self.samples)

-
     def __getitem__(self, idx):
         raw_sample = self.samples[idx]
         ids, types, paddings = build_tokens_types_paddings_from_text(
```
```diff
@@ -62,7 +60,6 @@ class GLUEAbstractDataset(ABC, Dataset):
                                  raw_sample['label'], raw_sample['uid'])
         return sample

-
     @abstractmethod
     def process_samples_from_single_path(self, datapath):
         """Abstract method that takes a single path / filename and
```
tasks/glue/finetune.py

(license header year bumped from 2019 to 2020, as in megatron/training.py)
```diff
@@ -38,7 +38,6 @@ def glue_classification(num_classes, Dataset,
         return train_dataset, valid_dataset

-
    def model_provider():
        """Build the model."""
        args = get_args()
```
```diff
@@ -48,7 +47,6 @@ def glue_classification(num_classes, Dataset,
        return Classification(num_classes=num_classes, num_tokentypes=2)

-
    def metrics_func_provider():
        """Privde metrics callback function."""
        def single_dataset_provider(datapath):
```
```diff
@@ -59,7 +57,6 @@ def glue_classification(num_classes, Dataset,
            return Dataset(name, [datapath], tokenizer, args.seq_length)
        return accuracy_func_provider(single_dataset_provider)

-
    """Finetune/evaluate."""
    finetune(train_valid_datasets_provider, model_provider,
             end_of_epoch_callback_provider=metrics_func_provider)
```
```diff
@@ -72,6 +69,7 @@ def main():
        num_classes = 3
        from tasks.glue.mnli import MNLIDataset as Dataset
+
        def name_from_datapath(datapath):
            return datapath.split('MNLI')[-1].strip(
                '.tsv').strip('/').replace('_', '-')
```
```diff
@@ -80,6 +78,7 @@ def main():
        num_classes = 2
        from tasks.glue.qqp import QQPDataset as Dataset
+
        def name_from_datapath(datapath):
            return datapath.split('QQP')[-1].strip(
                '.tsv').strip('/').replace('_', '-')
```
tasks/glue/mnli.py

(license header year bumped from 2019 to 2020, as in megatron/training.py)

```diff
@@ -31,7 +31,6 @@ class MNLIDataset(GLUEAbstractDataset):
        super().__init__('MNLI', name, datapaths,
                         tokenizer, max_seq_length)

-
    def process_samples_from_single_path(self, filename):
        """"Implement abstract method."""
        print_rank_0(' > Processing {} ...'.format(filename))
```
tasks/glue/qqp.py

(license header year bumped from 2019 to 2020, as in megatron/training.py)

```diff
@@ -31,7 +31,6 @@ class QQPDataset(GLUEAbstractDataset):
        super().__init__('QQP', name, datapaths,
                         tokenizer, max_seq_length)

-
    def process_samples_from_single_path(self, filename):
        """"Implement abstract method."""
        print_rank_0(' > Processing {} ...'.format(filename))
```
tasks/main.py

(license header year bumped from 2019 to 2020, as in megatron/training.py)

```diff
@@ -46,7 +46,7 @@ def get_tasks_args(parser):
    group.add_argument('--overlapping-eval', type=int, default=32,
                       help='Sliding window for overlapping evaluation.')
    group.add_argument('--strict-lambada', action='store_true',
-                      help='Use more difficult formulation of lambada.')
+                      help='Use more difficult formulation of lambada.')

    return parser
```
tasks/race/data.py

```diff
@@ -39,16 +39,13 @@ class RaceDataset(Dataset):
        print_rank_0('  >> total number of samples: {}'.format(
                     len(self.samples)))

-
    def __len__(self):
        return len(self.samples)

-
    def __getitem__(self, idx):
        return self.samples[idx]


-
def process_single_datapath(datapath, tokenizer, max_qa_length,
                            max_seq_length):
    """Read in RACE files, combine, clean-up, tokenize, and convert to
    samples."""
```
tasks/race/finetune.py

(license header year bumped from 2019 to 2020, as in megatron/training.py)
tasks/zeroshot_gpt2/datasets.py

(license header year bumped from 2019 to 2020, as in megatron/training.py)
```diff
@@ -64,12 +64,12 @@ class _LMDataset(torch.utils.data.Dataset):
    def __getitem__(self, idx):
        start_idx = idx * self.overalapping_eval
        end_idx = start_idx + self.seq_len
-       tokens = self.tokens[start_idx:end_idx + 1]
+       tokens = self.tokens[start_idx:end_idx + 1]
        num_tokens = len(tokens)
-       pad_mask = [1] * num_tokens
-       if num_tokens < self.seq_len + 1:
-           num_pad = (self.seq_len + 1 - num_tokens)
-           pad_mask += [0] * (num_pad)
+       pad_mask = [1] * num_tokens
+       if num_tokens < self.seq_len + 1:
+           num_pad = (self.seq_len + 1 - num_tokens)
+           pad_mask += [0] * (num_pad)
            tokens += [self.pad_idx] * num_pad
        pad_mask = np.array(pad_mask[1:])
        if self.overalapping_eval != self.seq_len and idx != 0:
```
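For overlapping evaluation, each sample is a `seq_len + 1` window whose start advances by `overalapping_eval` tokens (the attribute's spelling carries a typo from the source), so consecutive windows share context; the mask later zeroes out the overlapped prefix so each token is scored exactly once. A toy illustration of the windowing with made-up sizes:

```python
tokens = list(range(20))   # stand-in for a tokenized document
seq_len = 8
overlapping_eval = 4       # stride between window starts

for idx in range(4):
    start_idx = idx * overlapping_eval
    end_idx = start_idx + seq_len
    window = tokens[start_idx:end_idx + 1]   # up to seq_len + 1 ids (input + shifted target)
    print(idx, window)
# Windows after the first overlap the previous one by seq_len - overlapping_eval
# tokens; the dataset masks that overlap so only the new targets count.
# The final short window would be padded with pad_idx, as in the hunk above.
```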
```diff
@@ -103,7 +103,7 @@ class _LambadaDataset(torch.utils.data.Dataset):
        last_token = text.split()[-1]
        start_idx = text.rfind(last_token)
        beginning_tokens = self.tokenizer.tokenize(text[:start_idx].strip())
-       last_token = self.tokenizer.tokenize(' ' + last_token)
+       last_token = self.tokenizer.tokenize(' ' + last_token)
        return beginning_tokens, last_token

    def __len__(self):
```
```diff
@@ -112,14 +112,14 @@ class _LambadaDataset(torch.utils.data.Dataset):
    def __getitem__(self, idx):
        tokens = self.tokens[idx]
        num_tokens = len(tokens)
-       pad_mask = [0] * num_tokens
+       pad_mask = [0] * num_tokens
        labels = self.labels[idx]
-       pad_mask += [1] * len(labels)
-       tokens = tokens + labels
+       pad_mask += [1] * len(labels)
+       tokens = tokens + labels
        num_tokens = len(tokens)
-       if num_tokens < self.seq_len + 1:
-           num_pad = (self.seq_len + 1 - num_tokens)
-           pad_mask += [0] * (num_pad)
+       if num_tokens < self.seq_len + 1:
+           num_pad = (self.seq_len + 1 - num_tokens)
+           pad_mask += [0] * (num_pad)
            tokens += [self.pad_idx] * num_pad
        pad_mask = np.array(pad_mask[1:])
```
tasks/zeroshot_gpt2/detokenizer.py

(license header year bumped from 2019 to 2020, as in megatron/training.py)
```diff
@@ -19,64 +19,62 @@ import re
```

Every removed/added pair in this hunk is textually identical: the whole file body was reindented (a whitespace-only cleanup, apart from two dropped blank lines, which accounts for the -48/+46 counts). The resulting code:

```python
def ptb_detokenizer(string):
    string = string.replace(" '", "'")
    string = string.replace(" \n", "\n")
    string = string.replace("\n ", "\n")
    string = string.replace(" n't", "n't")
    string = string.replace(" N ", "1 ")
    string = string.replace("$ 1", "$1")
    string = string.replace("# 1", "#1")
    return string


def wikitext_detokenizer(string):
    # contractions
    string = string.replace("s '", "s'")
    string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
    # number separators
    string = string.replace(" @-@ ", "-")
    string = string.replace(" @,@ ", ",")
    string = string.replace(" @.@ ", ".")
    # punctuation
    string = string.replace(" : ", ": ")
    string = string.replace(" ; ", "; ")
    string = string.replace(" . ", ". ")
    string = string.replace(" ! ", "! ")
    string = string.replace(" ? ", "? ")
    string = string.replace(" , ", ", ")
    # double brackets
    string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
    string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
    string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
    string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
    string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
    # miscellaneous
    string = string.replace("= = = =", "====")
    string = string.replace("= = =", "===")
    string = string.replace("= =", "==")
    string = string.replace(" " + chr(176) + " ", chr(176))
    string = string.replace(" \n", "\n")
    string = string.replace("\n ", "\n")
    string = string.replace(" N ", " 1 ")
    string = string.replace(" 's", "'s")
    return string


def lambada_detokenizer(string):
    return string


_DETOKENIZERS = {
    'ptb': ptb_detokenizer,
    'wikitext': wikitext_detokenizer,
    'lambada': lambada_detokenizer,
}


def get_detokenizer(path):
    for key in _DETOKENIZERS.keys():
        if key in path:
            return _DETOKENIZERS[key]
```
tasks/zeroshot_gpt2/evaluate.py

(license header year bumped from 2019 to 2020, as in megatron/training.py)

```diff
@@ -71,8 +71,7 @@ def process_batch(batch):
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
-       args.eod_mask_loss,
-       args.fp16)
+       args.eod_mask_loss)

    return tokens, labels, attention_mask, position_ids, loss_mask
```
tools/generate_samples_gpt2.py

(license header year bumped from 2019 to 2020, as in megatron/training.py)
tools/merge_mp_partitions.py

(license header year bumped from 2019 to 2020, as in megatron/training.py)