Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
f1b938fd
Unverified
Commit
f1b938fd
authored
Apr 20, 2021
by
Sylvain Gugger
Committed by
GitHub
Apr 20, 2021
Browse files
Update to use datasets remove_cloumns method (#11343)
* Update to use datasets remove_cloumns method * Quality
parent
cfd2eaa8
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
21 additions
and
27 deletions
+21
-27
examples/question-answering/requirements.txt
examples/question-answering/requirements.txt
+1
-1
examples/question-answering/trainer_qa.py
examples/question-answering/trainer_qa.py
+1
-12
src/transformers/trainer.py
src/transformers/trainer.py
+19
-14
No files found.
examples/question-answering/requirements.txt
View file @
f1b938fd
datasets >= 1.
2.1
datasets >= 1.
4.0
examples/question-answering/trainer_qa.py
View file @
f1b938fd
...
...
@@ -16,13 +16,10 @@
A subclass of `Trainer` specific to Question-Answering tasks
"""
from
transformers
import
Trainer
,
is_datasets_available
,
is_torch_tpu_available
from
transformers
import
Trainer
,
is_torch_tpu_available
from
transformers.trainer_utils
import
PredictionOutput
if
is_datasets_available
():
import
datasets
if
is_torch_tpu_available
():
import
torch_xla.core.xla_model
as
xm
import
torch_xla.debug.metrics
as
met
...
...
@@ -54,10 +51,6 @@ class QuestionAnsweringTrainer(Trainer):
finally
:
self
.
compute_metrics
=
compute_metrics
# We might have removed columns from the dataset so we put them back.
if
isinstance
(
eval_dataset
,
datasets
.
Dataset
):
eval_dataset
.
set_format
(
type
=
eval_dataset
.
format
[
"type"
],
columns
=
list
(
eval_dataset
.
features
.
keys
()))
if
self
.
post_process_function
is
not
None
and
self
.
compute_metrics
is
not
None
:
eval_preds
=
self
.
post_process_function
(
eval_examples
,
eval_dataset
,
output
.
predictions
)
metrics
=
self
.
compute_metrics
(
eval_preds
)
...
...
@@ -94,10 +87,6 @@ class QuestionAnsweringTrainer(Trainer):
if
self
.
post_process_function
is
None
or
self
.
compute_metrics
is
None
:
return
output
# We might have removed columns from the dataset so we put them back.
if
isinstance
(
test_dataset
,
datasets
.
Dataset
):
test_dataset
.
set_format
(
type
=
test_dataset
.
format
[
"type"
],
columns
=
list
(
test_dataset
.
features
.
keys
()))
eval_preds
=
self
.
post_process_function
(
test_examples
,
test_dataset
,
output
.
predictions
,
"test"
)
metrics
=
self
.
compute_metrics
(
eval_preds
)
...
...
src/transformers/trainer.py
View file @
f1b938fd
...
...
@@ -394,11 +394,6 @@ class Trainer:
raise
ValueError
(
"train_dataset does not implement __len__, max_steps has to be specified"
)
self
.
_signature_columns
=
None
if
is_datasets_available
():
if
isinstance
(
train_dataset
,
datasets
.
Dataset
):
self
.
_remove_unused_columns
(
self
.
train_dataset
,
description
=
"training"
)
if
isinstance
(
eval_dataset
,
datasets
.
Dataset
):
self
.
_remove_unused_columns
(
self
.
eval_dataset
,
description
=
"evaluation"
)
# Mixed precision setup
self
.
use_apex
=
False
...
...
@@ -503,7 +498,13 @@ class Trainer:
f
"`
{
self
.
model
.
__class__
.
__name__
}
.forward` and have been ignored:
{
', '
.
join
(
ignored_columns
)
}
."
)
dataset
.
set_format
(
type
=
dataset
.
format
[
"type"
],
columns
=
columns
,
format_kwargs
=
dataset
.
format
[
"format_kwargs"
])
if
version
.
parse
(
datasets
.
__version__
)
<
version
.
parse
(
"1.4.0"
):
dataset
.
set_format
(
type
=
dataset
.
format
[
"type"
],
columns
=
columns
,
format_kwargs
=
dataset
.
format
[
"format_kwargs"
]
)
return
dataset
else
:
return
dataset
.
remove_columns
(
ignored_columns
)
def
_get_train_sampler
(
self
)
->
Optional
[
torch
.
utils
.
data
.
sampler
.
Sampler
]:
if
not
isinstance
(
self
.
train_dataset
,
collections
.
abc
.
Sized
):
...
...
@@ -565,17 +566,20 @@ class Trainer:
if
self
.
train_dataset
is
None
:
raise
ValueError
(
"Trainer: training requires a train_dataset."
)
if
isinstance
(
self
.
train_dataset
,
torch
.
utils
.
data
.
dataset
.
IterableDataset
):
train_dataset
=
self
.
train_dataset
if
is_datasets_available
()
and
isinstance
(
train_dataset
,
datasets
.
Dataset
):
train_dataset
=
self
.
_remove_unused_columns
(
train_dataset
,
description
=
"training"
)
if
isinstance
(
train_dataset
,
torch
.
utils
.
data
.
dataset
.
IterableDataset
):
if
self
.
args
.
world_size
>
1
:
train_dataset
=
IterableDatasetShard
(
self
.
train_dataset
,
train_dataset
,
batch_size
=
self
.
args
.
train_batch_size
,
drop_last
=
self
.
args
.
dataloader_drop_last
,
num_processes
=
self
.
args
.
world_size
,
process_index
=
self
.
args
.
process_index
,
)
else
:
train_dataset
=
self
.
train_dataset
return
DataLoader
(
train_dataset
,
batch_size
=
self
.
args
.
train_batch_size
,
...
...
@@ -587,7 +591,7 @@ class Trainer:
train_sampler
=
self
.
_get_train_sampler
()
return
DataLoader
(
self
.
train_dataset
,
train_dataset
,
batch_size
=
self
.
args
.
train_batch_size
,
sampler
=
train_sampler
,
collate_fn
=
self
.
data_collator
,
...
...
@@ -638,10 +642,11 @@ class Trainer:
"""
if
eval_dataset
is
None
and
self
.
eval_dataset
is
None
:
raise
ValueError
(
"Trainer: evaluation requires an eval_dataset."
)
elif
is_datasets_available
()
and
isinstance
(
eval_dataset
,
datasets
.
Dataset
):
self
.
_remove_unused_columns
(
eval_dataset
,
description
=
"evaluation"
)
eval_dataset
=
eval_dataset
if
eval_dataset
is
not
None
else
self
.
eval_dataset
if
is_datasets_available
()
and
isinstance
(
eval_dataset
,
datasets
.
Dataset
):
eval_dataset
=
self
.
_remove_unused_columns
(
eval_dataset
,
description
=
"evaluation"
)
if
isinstance
(
eval_dataset
,
torch
.
utils
.
data
.
dataset
.
IterableDataset
):
if
self
.
args
.
world_size
>
1
:
eval_dataset
=
IterableDatasetShard
(
...
...
@@ -683,7 +688,7 @@ class Trainer:
``model.forward()`` method are automatically removed. It must implement :obj:`__len__`.
"""
if
is_datasets_available
()
and
isinstance
(
test_dataset
,
datasets
.
Dataset
):
self
.
_remove_unused_columns
(
test_dataset
,
description
=
"test"
)
test_dataset
=
self
.
_remove_unused_columns
(
test_dataset
,
description
=
"test"
)
if
isinstance
(
test_dataset
,
torch
.
utils
.
data
.
dataset
.
IterableDataset
):
if
self
.
args
.
world_size
>
1
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment