chenpangpang / transformers / Commits

Unverified commit 124c3d6a, authored Aug 25, 2020 by Sylvain Gugger, committed via GitHub on Aug 25, 2020

Add tokenizer to Trainer (#6689)
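In practice, this change lets a tokenizer be handed directly to the Trainer, which then pads each batch on the fly and saves the tokenizer next to the model at every checkpoint. The following is a minimal sketch of the intended usage after this commit; the checkpoint name and the tiny in-memory dataset are illustrative, not part of the change:

from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # illustrative checkpoint
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")

# Encode without padding; the Trainer's new tokenizer-based default collator
# (DataCollatorWithPadding) pads each batch to its longest sequence at training time.
# A tiny in-memory list stands in for a real torch Dataset here.
texts = ["a short example", "a noticeably longer example sentence"]
train_dataset = [dict(tokenizer(text), labels=label) for text, label in zip(texts, [0, 1])]

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="tmp_trainer"),
    train_dataset=train_dataset,
    tokenizer=tokenizer,  # the new argument added by this commit
)
trainer.train()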
Parent: abc02021

Showing 1 changed file with 17 additions and 4 deletions.

src/transformers/trainer.py (+17, -4), view file @ 124c3d6a
@@ -19,7 +19,7 @@ from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.sampler import RandomSampler, Sampler, SequentialSampler
 from tqdm.auto import tqdm, trange

-from .data.data_collator import DataCollator, default_data_collator
+from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
 from .file_utils import is_nlp_available, is_torch_tpu_available
 from .integrations import (
     default_hp_search_backend,
@@ -31,6 +31,7 @@ from .integrations import (
 )
 from .modeling_utils import PreTrainedModel
 from .optimization import AdamW, get_linear_schedule_with_warmup
+from .tokenization_utils_base import PreTrainedTokenizerBase
 from .trainer_utils import (
     PREFIX_CHECKPOINT_DIR,
     BestRun,
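For orientation, the two imports added above carry the rest of the change: DataCollatorWithPadding builds dynamically padded batches from a tokenizer, and PreTrainedTokenizerBase is the shared base class of the Python and Rust-backed tokenizers, so the new type hint accepts either kind. A small illustrative check, using the BERT tokenizer classes purely as examples:

from transformers import BertTokenizer, BertTokenizerFast
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

# Both the "slow" (pure Python) and "fast" (Rust-backed) tokenizers share one base
# class, which is why the Trainer can be annotated with PreTrainedTokenizerBase.
assert issubclass(BertTokenizer, PreTrainedTokenizerBase)
assert issubclass(BertTokenizerFast, PreTrainedTokenizerBase)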
@@ -168,15 +169,20 @@ class Trainer:
         args (:class:`~transformers.TrainingArguments`, `optional`):
             The arguments to tweak for training. Will default to a basic instance of :class:`~transformers.TrainingArguments`
             with the ``output_dir`` set to a directory named `tmp_trainer` in the current directory if not provided.
-        data_collator (:obj:`DataCollator`, `optional`, defaults to :func:`~transformers.default_data_collator`):
+        data_collator (:obj:`DataCollator`, `optional`):
             The function to use to form a batch from a list of elements of :obj:`train_dataset` or
-            :obj:`eval_dataset`.
+            :obj:`eval_dataset`. Will default to :func:`~transformers.default_data_collator` if no ``tokenizer`` is
+            provided, an instance of :func:`~transformers.DataCollatorWithPadding` otherwise.
         train_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
             The dataset to use for training. If it is an :obj:`nlp.Dataset`, columns not accepted by the
             ``model.forward()`` method are automatically removed.
         eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
             The dataset to use for evaluation. If it is an :obj:`nlp.Dataset`, columns not accepted by the
             ``model.forward()`` method are automatically removed.
+        tokenizer (:class:`PreTrainedTokenizerBase`, `optional`):
+            The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs the
+            maximum length when batching inputs, and it will be saved along the model to make it easier to rerun an
+            interrupted training or reuse the fine-tuned model.
         model_init (:obj:`Callable[[], PreTrainedModel]`, `optional`):
             A function that instantiates the model to be used. If provided, each call to
             :meth:`~transformers.Trainer.train` will start from a new instance of the model as given by this function.
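The docstring above describes the new default behaviour: when a tokenizer is supplied, batches are padded dynamically rather than requiring pre-padded inputs. A hedged sketch of the collator the Trainer now builds internally in that case (the checkpoint name is illustrative):

from transformers import AutoTokenizer
from transformers.data.data_collator import DataCollatorWithPadding

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # illustrative checkpoint
collator = DataCollatorWithPadding(tokenizer)

# Two encodings of different lengths, tokenized without any padding.
features = [dict(tokenizer("short text")), dict(tokenizer("a noticeably longer piece of text"))]

batch = collator(features)
print(batch["input_ids"].shape)  # both rows are padded to the longest sequence in this batch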
@@ -200,6 +206,7 @@ class Trainer:
         data_collator: Optional[DataCollator] = None,
         train_dataset: Optional[Dataset] = None,
         eval_dataset: Optional[Dataset] = None,
+        tokenizer: Optional["PreTrainedTokenizerBase"] = None,
         model_init: Callable[[], PreTrainedModel] = None,
         compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
         tb_writer: Optional["SummaryWriter"] = None,
@@ -218,9 +225,11 @@ class Trainer:
         if model is None and model_init is not None:
             model = model_init()
         self.model = model.to(args.device) if model is not None else None
-        self.data_collator = data_collator if data_collator is not None else default_data_collator
+        default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
+        self.data_collator = data_collator if data_collator is not None else default_collator
         self.train_dataset = train_dataset
         self.eval_dataset = eval_dataset
+        self.tokenizer = tokenizer
         self.model_init = model_init
         self.compute_metrics = compute_metrics
         self.optimizer, self.lr_scheduler = optimizers
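The two added lines implement a simple precedence rule: an explicitly passed data_collator always wins; otherwise the presence of a tokenizer decides between dynamic padding and the plain default collator. A hypothetical helper mirroring this selection outside the Trainer, for illustration only:

from typing import Optional

from transformers.data.data_collator import (
    DataCollator,
    DataCollatorWithPadding,
    default_data_collator,
)
from transformers.tokenization_utils_base import PreTrainedTokenizerBase


def pick_collator(data_collator: Optional[DataCollator], tokenizer: Optional[PreTrainedTokenizerBase]):
    """Return the collator a Trainer built with these arguments would end up using."""
    default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
    return data_collator if data_collator is not None else default_collator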
@@ -1091,6 +1100,8 @@ class Trainer:
         xm.rendezvous("saving_checkpoint")
         self.model.save_pretrained(output_dir)
+        if self.tokenizer is not None:
+            self.tokenizer.save_pretrained(output_dir)

     def _save(self, output_dir: Optional[str] = None):
         output_dir = output_dir if output_dir is not None else self.args.output_dir
@@ -1101,6 +1112,8 @@ class Trainer:
         if not isinstance(self.model, PreTrainedModel):
             raise ValueError("Trainer.model appears to not be a PreTrainedModel")
         self.model.save_pretrained(output_dir)
+        if self.tokenizer is not None:
+            self.tokenizer.save_pretrained(output_dir)

         # Good practice: save your training arguments together with the trained model
         torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
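Because both save paths now call tokenizer.save_pretrained alongside model.save_pretrained, a directory written by Trainer.save_model carries everything needed to resume an interrupted run or reuse the fine-tuned model. A hedged reload sketch; the path and the Auto classes are illustrative:

from transformers import AutoModelForSequenceClassification, AutoTokenizer

# "my_finetuned_model" stands in for whatever directory trainer.save_model() wrote.
output_dir = "my_finetuned_model"
tokenizer = AutoTokenizer.from_pretrained(output_dir)  # tokenizer files saved thanks to this commit
model = AutoModelForSequenceClassification.from_pretrained(output_dir)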