chenpangpang / transformers

Commit f94f1c60, authored Aug 19, 2019 by Lysandre
Distributed training + tokenizer agnostic mask token
parent 5652f54a
Showing 2 changed files with 29 additions and 12 deletions
examples/run_generative_finetuning.py  +3 -11
examples/utils_lm.py                   +26 -1
examples/run_generative_finetuning.py  (view file @ f94f1c60)
@@ -39,12 +39,10 @@ from pytorch_transformers import (WEIGHTS_NAME, GPT2Config, GPT2LMHeadModel, GPT
                                   BertConfig, BertForMaskedLM, BertTokenizer, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
                                   RobertaConfig, RobertaForMaskedLM, RobertaTokenizer, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
 from pytorch_transformers import AdamW, WarmupLinearSchedule

-logger = logging.getLogger(__name__)
 from utils_lm import WikiTextDataset

 logger = logging.getLogger(__name__)

 ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config,)), ())

 MODEL_CLASSES = {
     'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
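For orientation, the (config, model, tokenizer) triples registered in MODEL_CLASSES are typically unpacked with from_pretrained elsewhere in the script. The snippet below is an illustrative sketch of that pattern, not code from this commit.

from pytorch_transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer

# Mirror of the registry shown in the hunk above (GPT-2 entry only).
MODEL_CLASSES = {'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer)}

config_class, model_class, tokenizer_class = MODEL_CLASSES['gpt2']
config = config_class.from_pretrained('gpt2')        # load the pretrained config
tokenizer = tokenizer_class.from_pretrained('gpt2')  # matching tokenizer
model = model_class.from_pretrained('gpt2', config=config)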
@@ -68,10 +66,7 @@ def mask_tokens(inputs, tokenizer, args):
     labels[~masked_indices.bool()] = -1  # We only compute loss on masked tokens

     indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).byte() & masked_indices
-    if args.model_name == "bert":
-        inputs[indices_replaced.bool()] = tokenizer.vocab["[MASK]"]  # 80% of the time, replace masked input tokens with [MASK]
-    elif args.model_name == "roberta":
-        inputs[indices_replaced.bool()] = tokenizer.encoder["<mask>"]  # 80% of the time, replace masked input tokens with <mask>
+    inputs[indices_replaced.bool()] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)  # 80% of the time, replace masked input tokens with [MASK]

     indices_random = (torch.bernoulli(torch.full(labels.shape, 0.5)).byte() & masked_indices & ~indices_replaced).bool()
     random_words = torch.randint(args.num_embeddings, labels.shape, dtype=torch.long)
     inputs[indices_random] = random_words[
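This hunk is the "tokenizer agnostic mask token" part of the commit: the model-specific lookups (tokenizer.vocab["[MASK]"] for BERT, tokenizer.encoder["<mask>"] for RoBERTa) are replaced by the generic tokenizer.mask_token attribute. A minimal sketch of the full 80/10/10 masking scheme written around that attribute follows; the mlm_probability argument and the vocab_size lookup are assumptions for illustration, not taken from this commit.

import torch

def mask_tokens_sketch(inputs, tokenizer, mlm_probability=0.15):
    # Clone the inputs so the originals can serve as labels.
    labels = inputs.clone()

    # Sample which positions are masked (15% by default).
    masked_indices = torch.bernoulli(torch.full(labels.shape, mlm_probability)).bool()
    labels[~masked_indices] = -1  # only masked positions contribute to the loss

    # 80% of the masked positions become the tokenizer's own mask token.
    indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
    inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

    # 10% (half of the remainder) become a random vocabulary token.
    indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
    random_words = torch.randint(tokenizer.vocab_size, labels.shape, dtype=torch.long)
    inputs[indices_random] = random_words[indices_random]

    # The remaining 10% keep their original token.
    return inputs, labels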
@@ -246,10 +241,7 @@ def evaluate(args, model, tokenizer, prefix=""):
 def load_and_cache_examples(args, tokenizer, evaluate=False):
     if args.local_rank not in [-1, 0]:
         torch.distributed.barrier()  # Make sure only the first process in distributed training process the dataset, and the others will use the cache
-    dataset = WikiTextDataset(tokenizer, file="test" if evaluate else "train", directory=args.data_dir)
+    dataset = WikiTextDataset(args, tokenizer, file="test" if evaluate else "train", directory=args.data_dir)
     return dataset
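The torch.distributed.barrier() guard above, and the matching calls added to WikiTextDataset below, implement the usual "first process preprocesses, everyone else waits and then reads the cache" idiom. A hedged sketch of that idiom in isolation, with a hypothetical build_fn standing in for the actual dataset construction:

import torch

def load_dataset_once_per_node(args, build_fn):
    # Non-first processes block here until local_rank 0 releases the barrier.
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()

    # local_rank 0 runs the expensive preprocessing and writes the cache;
    # the other ranks reach this line later and read the cached result.
    dataset = build_fn()

    # local_rank 0 enters the barrier only after the cache exists,
    # which releases the waiting ranks.
    if args.local_rank == 0:
        torch.distributed.barrier()

    return dataset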
examples/utils_lm.py  (view file @ f94f1c60)
@@ -3,10 +3,27 @@ import os
 import random
 import torch
 import torch.nn.functional as F
 import logging
 import pickle

 logger = logging.getLogger(__name__)


 class WikiTextDataset(Dataset):
-    def __init__(self, tokenizer, file='train', directory='wikitext', max_context_length=1024):
+    def __init__(self, args, tokenizer, file='train', directory='wikitext', max_context_length=512, cache=None):
+        if args.local_rank not in [-1, 0]:
+            torch.distributed.barrier()  # Make sure only the first process in distributed training process the dataset, and the others will use the cache
+
+        cached_features_file = os.path.join(args.data_dir, f'cached_lm_{file}_{args.max_seq_length}')
+
+        if os.path.exists(cached_features_file):
+            logger.info("Loading features from cached file %s", cached_features_file)
+            with open(cached_features_file, 'rb') as handle:
+                self.examples = pickle.load(handle)
+        else:
+            logger.info("Creating features from dataset file at %s", args.data_dir)
+
+            self.max_context_length = max_context_length
+            self.examples = []
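The lines elided between this hunk and the next are where the raw corpus is read, tokenized, and cut into fixed-length blocks; the next hunk's context lines show the tail of that loop. A generic sketch of that chunking step, not the author's actual code (the chunk_corpus name and raw_text argument are assumptions):

def chunk_corpus(tokenizer, raw_text, max_context_length):
    # Turn the raw corpus into one long stream of token ids,
    # then slice it into non-overlapping blocks of max_context_length ids.
    tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(raw_text))
    examples = []
    while len(tokenized_text) >= max_context_length:
        examples.append(tokenized_text[:max_context_length])
        tokenized_text = tokenized_text[max_context_length:]
    return examples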
@@ -19,6 +36,14 @@ class WikiTextDataset(Dataset):
                 self.examples.append(tokenized_text[:max_context_length])
                 tokenized_text = tokenized_text[max_context_length:]
+
+            if args.local_rank in [-1, 0]:
+                logger.info("Saving features into cached file %s", cached_features_file)
+                with open(cached_features_file, 'wb') as handle:
+                    pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
+
+        if args.local_rank == 0:
+            torch.distributed.barrier()  # Make sure only the first process in distributed training process the dataset, and the others will use the cache

     def __len__(self):
         return len(self.examples)
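Once the dataset has built or loaded its cached features, it is consumed like any other torch Dataset. The helper below is a hypothetical usage sketch, not code from this commit; args.train_batch_size is an assumed parameter name, while args.local_rank and args.data_dir appear in the diff above.

from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler

from utils_lm import WikiTextDataset  # the class modified in the diff above


def make_train_dataloader(args, tokenizer):
    # Build (or load from cache) the tokenized WikiText examples.
    dataset = WikiTextDataset(args, tokenizer, file='train', directory=args.data_dir)

    # Shard examples across processes when running distributed training,
    # otherwise shuffle within the single process.
    sampler = RandomSampler(dataset) if args.local_rank == -1 else DistributedSampler(dataset)
    return DataLoader(dataset, sampler=sampler, batch_size=args.train_batch_size)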