chenpangpang / transformers / Commits

Commit 1a8e87be
Authored Jan 18, 2020 by Julien Chaumond

    Line-by-line text dataset (including padding)

Parent: b94cf7fa
Showing 1 changed file with 48 additions and 11 deletions:

    examples/run_lm_finetuning.py (+48, -11)
diff --git a/examples/run_lm_finetuning.py b/examples/run_lm_finetuning.py
--- a/examples/run_lm_finetuning.py
+++ b/examples/run_lm_finetuning.py
@@ -32,6 +32,7 @@ from typing import Dict, List, Tuple
 
 import numpy as np
 import torch
+from torch.nn.utils.rnn import pad_sequence
 from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
 from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm, trange
@@ -83,7 +84,7 @@ MODEL_CLASSES = {
 
 class TextDataset(Dataset):
-    def __init__(self, tokenizer: PreTrainedTokenizer, args, file_path="train", block_size=512):
+    def __init__(self, tokenizer: PreTrainedTokenizer, args, file_path: str, block_size=512):
         assert os.path.isfile(file_path)
 
         directory, filename = os.path.split(file_path)
         cached_features_file = os.path.join(
@@ -120,13 +121,32 @@ class TextDataset(Dataset):
         return torch.tensor(self.examples[item])
 
 
+class LineByLineTextDataset(Dataset):
+    def __init__(self, tokenizer: PreTrainedTokenizer, args, file_path: str, block_size=512):
+        assert os.path.isfile(file_path)
+        # Here, we do not cache the features, operating under the assumption
+        # that we will soon use fast multithreaded tokenizers from the
+        # `tokenizers` repo everywhere =)
+        logger.info("Creating features from dataset file at %s", file_path)
+
+        with open(file_path, encoding="utf-8") as f:
+            lines = [line for line in f.read().splitlines() if len(line) > 0]
+
+        self.examples = tokenizer.batch_encode_plus(lines, max_length=block_size)["input_ids"]
+
+    def __len__(self):
+        return len(self.examples)
+
+    def __getitem__(self, i):
+        return torch.tensor(self.examples[i])
+
+
 def load_and_cache_examples(args, tokenizer, evaluate=False):
-    return TextDataset(
-        tokenizer,
-        args,
-        file_path=args.eval_data_file if evaluate else args.train_data_file,
-        block_size=args.block_size,
-    )
+    file_path = args.eval_data_file if evaluate else args.train_data_file
+    if args.line_by_line:
+        return LineByLineTextDataset(tokenizer, args, file_path=file_path, block_size=args.block_size)
+    else:
+        return TextDataset(tokenizer, args, file_path=file_path, block_size=args.block_size)
 
 
 def set_seed(args):
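The new dataset returns one variable-length tensor per non-empty input line; padding is deferred to the DataLoader's collate function added in the train/evaluate hunks below. A minimal usage sketch, assuming LineByLineTextDataset and its logger from run_lm_finetuning.py are in scope; the file contents and tokenizer name here are placeholders, not part of the commit:

    # Hypothetical usage sketch (not part of the commit).
    # Assumes LineByLineTextDataset and `logger` from run_lm_finetuning.py are in scope.
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    with open("tiny.txt", "w", encoding="utf-8") as f:
        f.write("Hello world\n\nA second, longer line of text\n")

    dataset = LineByLineTextDataset(tokenizer, args=None, file_path="tiny.txt", block_size=512)
    print(len(dataset))      # 2 -- the empty line is dropped
    print(dataset[0].shape)  # e.g. torch.Size([4]): [CLS] hello world [SEP]
    print(dataset[1].shape)  # longer; examples are NOT padded at this stage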
@@ -182,6 +202,8 @@ def mask_tokens(inputs: torch.Tensor, tokenizer: PreTrainedTokenizer, args) -> T
         tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
     ]
     probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
+    padding_mask = labels.eq(tokenizer.pad_token_id)
+    probability_matrix.masked_fill_(padding_mask, value=0.0)
     masked_indices = torch.bernoulli(probability_matrix).bool()
     labels[~masked_indices] = -100  # We only compute loss on masked tokens
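The two added lines zero out the masking probability wherever the labels equal the pad token, so padded positions are never selected by torch.bernoulli and, remaining at -100, never contribute to the MLM loss. A toy sketch of the effect, with made-up token ids and an assumed pad_token_id of 0:

    # Toy illustration of the padding-mask fix (ids and pad value are made up).
    import torch

    pad_token_id = 0
    labels = torch.tensor([[101, 7592, 102, 0, 0]])  # last two positions are padding
    probability_matrix = torch.full(labels.shape, 0.15)

    padding_mask = labels.eq(pad_token_id)
    probability_matrix.masked_fill_(padding_mask, value=0.0)
    print(probability_matrix)  # tensor([[0.1500, 0.1500, 0.1500, 0.0000, 0.0000]])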
@@ -204,8 +226,14 @@ def train(args, train_dataset, model: PreTrainedModel, tokenizer: PreTrainedToke
         tb_writer = SummaryWriter()
 
     args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
+
+    def collate(examples: List[torch.Tensor]):
+        return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)
+
     train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
-    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
+    train_dataloader = DataLoader(
+        train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, collate_fn=collate
+    )
 
     if args.max_steps > 0:
         t_total = args.max_steps
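Because LineByLineTextDataset yields tensors of different lengths, the DataLoader's default collate could not stack them into one batch tensor; the new collate pads each batch up to its longest example. A standalone sketch of that behavior (toy ids, pad id 0 assumed):

    # Standalone sketch of the collate behavior (toy tensors, assumed pad id 0).
    import torch
    from torch.nn.utils.rnn import pad_sequence

    examples = [torch.tensor([101, 7592, 102]), torch.tensor([101, 2023, 2003, 102])]
    batch = pad_sequence(examples, batch_first=True, padding_value=0)
    print(batch)
    # tensor([[ 101, 7592,  102,    0],
    #         [ 101, 2023, 2003,  102]])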
@@ -391,8 +419,14 @@ def evaluate(args, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prefi
     args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
     # Note that DistributedSampler samples randomly
 
+    def collate(examples: List[torch.Tensor]):
+        return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)
+
     eval_sampler = SequentialSampler(eval_dataset)
-    eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
+    eval_dataloader = DataLoader(
+        eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=collate
+    )
 
     # multi-gpu evaluate
     if args.n_gpu > 1:
@@ -456,11 +490,14 @@ def main():
         type=str,
         help="An optional input evaluation data file to evaluate the perplexity on (a text file).",
     )
+    parser.add_argument(
+        "--line_by_line",
+        action="store_true",
+        help="Whether distinct lines of text in the dataset are to be handled as distinct sequences.",
+    )
     parser.add_argument(
         "--should_continue", action="store_true", help="Whether to continue from latest checkpoint in output_dir"
     )
     parser.add_argument(
         "--model_name_or_path",
         default=None,
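A hypothetical invocation exercising the new flag; the model name and paths are placeholders, while the other flags already exist in the script:

    python examples/run_lm_finetuning.py \
        --model_type=bert \
        --model_name_or_path=bert-base-uncased \
        --train_data_file=/path/to/one-sequence-per-line.txt \
        --output_dir=/tmp/finetuned \
        --mlm \
        --line_by_line \
        --do_train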