chenpangpang / transformers, commit 3f2e6368 (unverified)
Authored Mar 01, 2022 by Joao Gante; committed by GitHub on Mar 01, 2022
Update TF LM examples (#15855)
Parent: 54f0db40

Showing 2 changed files with 46 additions and 132 deletions (+46, -132):
examples/tensorflow/language-modeling/run_clm.py (+22, -49)
examples/tensorflow/language-modeling/run_mlm.py (+24, -83)
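In both TF language-modeling examples, this commit replaces the hand-written `sample_generator` / `tf.data.Dataset.from_generator` pipeline with `Dataset.to_tf_dataset()` driven by a data collator (`DefaultDataCollator` for causal LM, `DataCollatorForLanguageModeling` for masked LM), and drops the `dummy_loss` workaround in favor of the model's internal loss. A minimal sketch of the new pattern, not part of the diff: the checkpoint, dataset, and padding below are illustrative stand-ins, and the real scripts build everything from command-line arguments and chunk the corpus into fixed-length blocks instead of padding.

import tensorflow as tf
from datasets import load_dataset
from transformers import AutoTokenizer, DefaultDataCollator, TFAutoModelForCausalLM

# Illustrative setup; the example scripts take model, dataset, and batch sizes from the CLI.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

raw = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]")

def tokenize(batch):
    # Pad to a fixed length for brevity; the real script concatenates and chunks texts instead.
    out = tokenizer(batch["text"], truncation=True, padding="max_length", max_length=128)
    out["labels"] = out["input_ids"].copy()  # causal LM: labels are the inputs
    return out

train_dataset = raw.map(tokenize, batched=True, remove_columns=["text"])

# The pattern introduced by this commit: to_tf_dataset + a collator instead of a Python generator.
data_collator = DefaultDataCollator(return_tensors="tf")
tf_train_dataset = train_dataset.to_tf_dataset(
    columns=["input_ids", "attention_mask", "labels"],
    shuffle=True,
    batch_size=8,
    collate_fn=data_collator,
    drop_remainder=True,
)

model = TFAutoModelForCausalLM.from_pretrained("gpt2")
model.compile(optimizer=tf.keras.optimizers.Adam(5e-5))  # no loss argument: the model's internal loss is used
model.fit(tf_train_dataset, epochs=1)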
examples/tensorflow/language-modeling/run_clm.py
@@ -29,13 +29,11 @@ import os
 import random
 import sys
 from dataclasses import dataclass, field
-from functools import partial
 from itertools import chain
 from pathlib import Path
 from typing import Optional

 import datasets
 import numpy as np
 import tensorflow as tf
 from datasets import load_dataset
-from sklearn.model_selection import train_test_split

@@ -48,6 +46,7 @@ from transformers import (
     TF_MODEL_FOR_CAUSAL_LM_MAPPING,
     AutoConfig,
     AutoTokenizer,
+    DefaultDataCollator,
     HfArgumentParser,
     TFAutoModelForCausalLM,
     TFTrainingArguments,
@@ -160,9 +159,6 @@ class DataTrainingArguments:
         default=None,
         metadata={"help": "The number of processes to use for the preprocessing."},
     )
-    mlm_probability: float = field(
-        default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
-    )
     line_by_line: bool = field(
         default=False,
         metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
@@ -212,20 +208,6 @@ class SavePretrainedCallback(tf.keras.callbacks.Callback):
         self.model.save_pretrained(self.output_dir)


 # endregion

-
-# region Data generator
-def sample_generator(dataset, tokenizer):
-    # Trim off the last partial batch if present
-    sample_ordering = np.random.permutation(len(dataset))
-    for sample_idx in sample_ordering:
-        example = dataset[int(sample_idx)]
-        # Handle dicts with proper padding and conversion to tensor.
-        example = {key: tf.convert_to_tensor(arr, dtype_hint=tf.int64) for key, arr in example.items()}
-        yield example, example["labels"]  # TF needs some kind of labels, even if we don't use them
-    return
-
-
-# endregion
@@ -457,34 +439,27 @@ def main():
     # region TF Dataset preparation
     num_replicas = training_args.strategy.num_replicas_in_sync
-    train_generator = partial(sample_generator, train_dataset, tokenizer)
-    train_signature = {
-        feature: tf.TensorSpec(shape=(None,), dtype=tf.int64)
-        for feature in train_dataset.features
-        if feature != "special_tokens_mask"
-    }
-    train_sig = (train_signature, train_signature["labels"])
+    data_collator = DefaultDataCollator(return_tensors="tf")
     options = tf.data.Options()
     options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
-    tf_train_dataset = (
-        tf.data.Dataset.from_generator(train_generator, output_signature=train_sig)
-        .with_options(options)
-        .batch(batch_size=num_replicas * training_args.per_device_train_batch_size, drop_remainder=True)
-        .repeat(int(training_args.num_train_epochs))
-    )
-    eval_generator = partial(sample_generator, eval_dataset, tokenizer)
-    eval_signature = {
-        feature: tf.TensorSpec(shape=(None,), dtype=tf.int64)
-        for feature in eval_dataset.features
-        if feature != "special_tokens_mask"
-    }
-    eval_sig = (eval_signature, eval_signature["labels"])
-    tf_eval_dataset = (
-        tf.data.Dataset.from_generator(eval_generator, output_signature=eval_sig)
-        .with_options(options)
-        .batch(batch_size=num_replicas * training_args.per_device_eval_batch_size, drop_remainder=True)
-        .repeat(int(training_args.num_train_epochs))
-    )
+    tf_train_dataset = train_dataset.to_tf_dataset(
+        # labels are passed as input, as we will use the model's internal loss
+        columns=[col for col in train_dataset.features if col != "special_tokens_mask"],
+        shuffle=True,
+        batch_size=num_replicas * training_args.per_device_train_batch_size,
+        collate_fn=data_collator,
+        drop_remainder=True,
+    ).with_options(options)
+    tf_eval_dataset = eval_dataset.to_tf_dataset(
+        # labels are passed as input, as we will use the model's internal loss
+        columns=[col for col in eval_dataset.features if col != "special_tokens_mask"],
+        shuffle=False,
+        batch_size=num_replicas * training_args.per_device_train_batch_size,
+        collate_fn=data_collator,
+        drop_remainder=True,
+    ).with_options(options)
     # endregion

     # region Optimizer and loss
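For orientation (not part of the commit): with the collator-backed pipeline above, every element of `tf_train_dataset` is a dict of fixed-shape integer tensors that already contains the `labels` key, which is why no separate label tensor is handed to Keras. A small illustrative check, assuming `tf_train_dataset` was built as in the new code:

# Peek at one batch from the new to_tf_dataset pipeline (illustrative only;
# assumes tf_train_dataset was created with to_tf_dataset + DefaultDataCollator as above).
for batch in tf_train_dataset.take(1):
    # Expect something like {"input_ids": (B, L), "attention_mask": (B, L), "labels": (B, L)}
    print({name: tuple(tensor.shape) for name, tensor in batch.items()})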
@@ -500,10 +475,8 @@ def main():
         weight_decay_rate=training_args.weight_decay,
     )

-    def dummy_loss(y_true, y_pred):
-        return tf.reduce_mean(y_pred)
-
-    model.compile(optimizer=optimizer, loss={"loss": dummy_loss})
+    # no user-specified loss = will use the model internal loss
+    model.compile(optimizer=optimizer)
     # endregion

     # region Training and validation
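Why the dummy loss could be dropped: the TF models already compute and return their own loss whenever `labels` are present in the inputs, and in transformers versions current at the time of this commit a model compiled without a `loss` argument falls back to that internal loss, so the wrapper that merely averaged the returned value is no longer needed. A minimal contrast sketch (the "gpt2" checkpoint and optimizer below are illustrative stand-ins for what the script builds from its arguments):

import tensorflow as tf
from transformers import TFAutoModelForCausalLM

# Assumed stand-ins; the example script constructs these from CLI arguments.
model = TFAutoModelForCausalLM.from_pretrained("gpt2")
optimizer = tf.keras.optimizers.Adam(learning_rate=5e-5)

# Old pattern (removed above): Keras required *some* loss, so the script averaged
# the loss tensor the model had already computed and returned under the "loss" key.
# def dummy_loss(y_true, y_pred):
#     return tf.reduce_mean(y_pred)
# model.compile(optimizer=optimizer, loss={"loss": dummy_loss})

# New pattern: compile with no user-specified loss; Keras then uses the model's
# internal loss, computed from the "labels" key of each input batch.
model.compile(optimizer=optimizer)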
examples/tensorflow/language-modeling/run_mlm.py
@@ -31,13 +31,11 @@ import os
 import random
 import sys
 from dataclasses import dataclass, field
-from functools import partial
 from itertools import chain
 from pathlib import Path
 from typing import Optional

 import datasets
 import numpy as np
 import tensorflow as tf
 from datasets import load_dataset
-from sklearn.model_selection import train_test_split

@@ -50,6 +48,7 @@ from transformers import (
     TF_MODEL_FOR_MASKED_LM_MAPPING,
     AutoConfig,
     AutoTokenizer,
+    DataCollatorForLanguageModeling,
     HfArgumentParser,
     TFAutoModelForMaskedLM,
     TFTrainingArguments,
@@ -217,56 +216,6 @@ class SavePretrainedCallback(tf.keras.callbacks.Callback):
         self.model.save_pretrained(self.output_dir)


 # endregion

-
-# region Data generator
-def sample_generator(dataset, tokenizer, mlm_probability=0.15, pad_to_multiple_of=None):
-    if tokenizer.mask_token is None:
-        raise ValueError(
-            "This tokenizer does not have a mask token which is necessary for masked language modeling. "
-        )
-    # Trim off the last partial batch if present
-    sample_ordering = np.random.permutation(len(dataset))
-    for sample_idx in sample_ordering:
-        example = dataset[int(sample_idx)]
-        # Handle dicts with proper padding and conversion to tensor.
-        example = tokenizer.pad(example, return_tensors="np", pad_to_multiple_of=pad_to_multiple_of)
-        special_tokens_mask = example.pop("special_tokens_mask", None)
-        example["input_ids"], example["labels"] = mask_tokens(
-            example["input_ids"], mlm_probability, tokenizer, special_tokens_mask=special_tokens_mask
-        )
-        if tokenizer.pad_token_id is not None:
-            example["labels"][example["labels"] == tokenizer.pad_token_id] = -100
-        example = {key: tf.convert_to_tensor(arr) for key, arr in example.items()}
-
-        yield example, example["labels"]  # TF needs some kind of labels, even if we don't use them
-    return
-
-
-def mask_tokens(inputs, mlm_probability, tokenizer, special_tokens_mask):
-    """
-    Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
-    """
-    labels = np.copy(inputs)
-    # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
-    probability_matrix = np.random.random_sample(labels.shape)
-    special_tokens_mask = special_tokens_mask.astype(np.bool_)
-    probability_matrix[special_tokens_mask] = 0.0
-    masked_indices = probability_matrix > (1 - mlm_probability)
-    labels[~masked_indices] = -100  # We only compute loss on masked tokens
-
-    # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
-    indices_replaced = (np.random.random_sample(labels.shape) < 0.8) & masked_indices
-    inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
-
-    # 10% of the time, we replace masked input tokens with random word
-    indices_random = (np.random.random_sample(labels.shape) < 0.5) & masked_indices & ~indices_replaced
-    random_words = np.random.randint(low=0, high=len(tokenizer), size=np.count_nonzero(indices_random), dtype=np.int64)
-    inputs[indices_random] = random_words
-
-    # The rest of the time (10% of the time) we keep the masked input tokens unchanged
-    return inputs, labels
-
-
-# endregion
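The hand-rolled `sample_generator`/`mask_tokens` pair removed above duplicates what `DataCollatorForLanguageModeling` provides out of the box: padding, the 80% [MASK] / 10% random token / 10% unchanged replacement scheme, and `-100` labels on every non-masked position. A minimal sketch of the collator on its own; the "distilroberta-base" checkpoint is an assumed stand-in, since the example script takes the model name from the CLI:

from transformers import AutoTokenizer, DataCollatorForLanguageModeling

# Assumed checkpoint for illustration only.
tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm_probability=0.15, return_tensors="tf"
)

# The collator pads the features, masks input_ids with the 80/10/10 scheme, and
# emits a "labels" tensor that is -100 everywhere except the masked positions.
features = [tokenizer("Masked language modeling"), tokenizer("with a library collator")]
batch = data_collator(features)
print(batch["input_ids"].shape, batch["labels"].shape)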
@@ -531,35 +480,29 @@ def main():
     # region TF Dataset preparation
     num_replicas = training_args.strategy.num_replicas_in_sync
-    train_generator = partial(sample_generator, train_dataset, tokenizer)
-    train_signature = {
-        feature: tf.TensorSpec(shape=(None,), dtype=tf.int64)
-        for feature in train_dataset.features
-        if feature != "special_tokens_mask"
-    }
-    train_signature["labels"] = train_signature["input_ids"]
-    train_signature = (train_signature, train_signature["labels"])
+    data_collator = DataCollatorForLanguageModeling(
+        tokenizer=tokenizer, mlm_probability=data_args.mlm_probability, return_tensors="tf"
+    )
     options = tf.data.Options()
     options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
-    tf_train_dataset = (
-        tf.data.Dataset.from_generator(train_generator, output_signature=train_signature)
-        .with_options(options)
-        .batch(batch_size=num_replicas * training_args.per_device_train_batch_size, drop_remainder=True)
-        .repeat(int(training_args.num_train_epochs))
-    )
-    eval_generator = partial(sample_generator, eval_dataset, tokenizer)
-    eval_signature = {
-        feature: tf.TensorSpec(shape=(None,), dtype=tf.int64)
-        for feature in eval_dataset.features
-        if feature != "special_tokens_mask"
-    }
-    eval_signature["labels"] = eval_signature["input_ids"]
-    eval_signature = (eval_signature, eval_signature["labels"])
-    tf_eval_dataset = (
-        tf.data.Dataset.from_generator(eval_generator, output_signature=eval_signature)
-        .with_options(options)
-        .batch(batch_size=num_replicas * training_args.per_device_eval_batch_size, drop_remainder=True)
-    )
+    tf_train_dataset = train_dataset.to_tf_dataset(
+        # labels are passed as input, as we will use the model's internal loss
+        columns=[col for col in train_dataset.features if col != "special_tokens_mask"] + ["labels"],
+        shuffle=True,
+        batch_size=num_replicas * training_args.per_device_train_batch_size,
+        collate_fn=data_collator,
+        drop_remainder=True,
+    ).with_options(options)
+    tf_eval_dataset = eval_dataset.to_tf_dataset(
+        # labels are passed as input, as we will use the model's internal loss
+        columns=[col for col in eval_dataset.features if col != "special_tokens_mask"] + ["labels"],
+        shuffle=False,
+        batch_size=num_replicas * training_args.per_device_train_batch_size,
+        collate_fn=data_collator,
+        drop_remainder=True,
+    ).with_options(options)
     # endregion

     # region Optimizer and loss
@@ -575,10 +518,8 @@ def main():
         weight_decay_rate=training_args.weight_decay,
     )

-    def dummy_loss(y_true, y_pred):
-        return tf.reduce_mean(y_pred)
-
-    model.compile(optimizer=optimizer, loss={"loss": dummy_loss})
+    # no user-specified loss = will use the model internal loss
+    model.compile(optimizer=optimizer)
     # endregion

     # region Training and validation