Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
5652f54a
Commit
5652f54a
authored
Aug 16, 2019
by
Lysandre
Browse files
Simplified data generator + better perplexity calculator
GPT-2 now obtains ~20 perplexity on WikiText-2
parent
71553480
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
10 additions
and
22 deletions
+10
-22
examples/run_generative_finetuning.py
examples/run_generative_finetuning.py
+5
-4
examples/utils_lm.py
examples/utils_lm.py
+5
-18
No files found.
examples/run_generative_finetuning.py
View file @
5652f54a
...
...
@@ -85,7 +85,7 @@ def train(args, train_dataset, model, tokenizer):
args
.
train_batch_size
=
args
.
per_gpu_train_batch_size
*
max
(
1
,
args
.
n_gpu
)
train_sampler
=
SequentialSampler
(
train_dataset
)
if
args
.
local_rank
==
-
1
else
DistributedSampler
(
train_dataset
)
train_dataloader
=
DataLoader
(
train_dataset
,
sampler
=
train_sampler
,
batch_size
=
args
.
train_batch_size
,
collate_fn
=
WikiTextDataset
.
collate
)
train_dataloader
=
DataLoader
(
train_dataset
,
sampler
=
train_sampler
,
batch_size
=
args
.
train_batch_size
)
if
args
.
max_steps
>
0
:
t_total
=
args
.
max_steps
...
...
@@ -209,7 +209,7 @@ def evaluate(args, model, tokenizer, prefix=""):
args
.
eval_batch_size
=
args
.
per_gpu_eval_batch_size
*
max
(
1
,
args
.
n_gpu
)
# Note that DistributedSampler samples randomly
eval_sampler
=
SequentialSampler
(
eval_dataset
)
if
args
.
local_rank
==
-
1
else
DistributedSampler
(
eval_dataset
)
eval_dataloader
=
DataLoader
(
eval_dataset
,
sampler
=
eval_sampler
,
batch_size
=
args
.
eval_batch_size
,
collate_fn
=
WikiTextDataset
.
collate
)
eval_dataloader
=
DataLoader
(
eval_dataset
,
sampler
=
eval_sampler
,
batch_size
=
args
.
eval_batch_size
)
# Eval!
logger
.
info
(
"***** Running evaluation {} *****"
.
format
(
prefix
))
...
...
@@ -217,12 +217,13 @@ def evaluate(args, model, tokenizer, prefix=""):
logger
.
info
(
" Batch size = %d"
,
args
.
eval_batch_size
)
eval_loss
=
0.0
nb_eval_steps
=
0
for
batch
in
tqdm
(
eval_dataloader
,
desc
=
"Evaluating"
):
model
.
eval
()
for
batch
in
tqdm
(
eval_dataloader
,
desc
=
"Evaluating"
):
batch
=
batch
.
to
(
args
.
device
)
with
torch
.
no_grad
():
outputs
=
model
(
batch
)
outputs
=
model
(
batch
,
masked_lm_labels
=
batch
)
if
args
.
mlm
else
model
(
batch
,
labels
=
batch
)
lm_loss
=
outputs
[
0
]
eval_loss
+=
lm_loss
.
mean
().
item
()
nb_eval_steps
+=
1
...
...
examples/utils_lm.py
View file @
5652f54a
...
...
@@ -6,34 +6,21 @@ import torch.nn.functional as F
class
WikiTextDataset
(
Dataset
):
def
__init__
(
self
,
tokenizer
,
file
=
'train'
,
directory
=
'wikitext'
,
max_context_length
=
512
):
def
__init__
(
self
,
tokenizer
,
file
=
'train'
,
directory
=
'wikitext'
,
max_context_length
=
1024
):
self
.
max_context_length
=
max_context_length
self
.
examples
=
[]
with
open
(
os
.
path
.
join
(
directory
,
f
"wiki.
{
file
}
.raw"
),
encoding
=
"utf-8"
)
as
f
:
text
=
f
.
read
()
spans
=
list
(
filter
(
lambda
item
:
len
(
item
)
>
120
,
text
.
split
(
"
\n
"
)
))
tokenized_text
=
tokenizer
.
convert_tokens_to_ids
(
tokenizer
.
tokenize
(
text
))
for
span
in
spans
:
span
=
tokenizer
.
encode
(
span
)
while
len
(
span
)
>
0
:
self
.
examples
.
append
(
span
[:
max_context_length
])
span
=
span
[
max_context_length
:]
# Randomly shuffle the examples array
random
.
shuffle
(
self
.
examples
)
# Sort the array by example length.
self
.
examples
.
sort
(
key
=
len
)
while
len
(
tokenized_text
)
>
max_context_length
:
self
.
examples
.
append
(
tokenized_text
[:
max_context_length
])
tokenized_text
=
tokenized_text
[
max_context_length
:]
def
__len__
(
self
):
return
len
(
self
.
examples
)
def
__getitem__
(
self
,
item
):
return
torch
.
tensor
(
self
.
examples
[
item
])
@
staticmethod
def
collate
(
values
):
stack
=
torch
.
stack
([
F
.
pad
(
value
,
(
len
(
values
[
-
1
])
-
value
.
size
(
0
),
0
),
"constant"
,
0
)
for
value
in
values
])
return
stack
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment