Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
7b95825d
Unverified
Commit
7b95825d
authored
May 11, 2022
by
Antoni Baum
Committed by
GitHub
May 11, 2022
Browse files
Remove columns before passing to data collator (#17187)
parent
934e21cd
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
7 additions
and
5 deletions
+7
-5
src/transformers/trainer.py
src/transformers/trainer.py
+4
-3
src/transformers/trainer_utils.py
src/transformers/trainer_utils.py
+3
-2
No files found.
src/transformers/trainer.py
View file @
7b95825d
...
@@ -607,13 +607,14 @@ class Trainer:
...
@@ -607,13 +607,14 @@ class Trainer:
# Inspect model forward signature to keep only the arguments it accepts.
# Inspect model forward signature to keep only the arguments it accepts.
signature
=
inspect
.
signature
(
self
.
model
.
forward
)
signature
=
inspect
.
signature
(
self
.
model
.
forward
)
self
.
_signature_columns
=
list
(
signature
.
parameters
.
keys
())
self
.
_signature_columns
=
list
(
signature
.
parameters
.
keys
())
# Labels may be named label or label_ids, the default data collator handles that.
self
.
_signature_columns
+=
list
(
set
([
"label"
,
"label_ids"
]
+
self
.
label_names
))
def
_remove_unused_columns
(
self
,
dataset
:
"datasets.Dataset"
,
description
:
Optional
[
str
]
=
None
):
def
_remove_unused_columns
(
self
,
dataset
:
"datasets.Dataset"
,
description
:
Optional
[
str
]
=
None
):
if
not
self
.
args
.
remove_unused_columns
:
if
not
self
.
args
.
remove_unused_columns
:
return
dataset
return
dataset
self
.
_set_signature_columns_if_needed
()
self
.
_set_signature_columns_if_needed
()
# Labels may be named label or label_ids, the default data collator handles that.
signature_columns
=
self
.
_signature_columns
signature_columns
=
self
.
_signature_columns
+
[
"label"
,
"label_ids"
]
ignored_columns
=
list
(
set
(
dataset
.
column_names
)
-
set
(
signature_columns
))
ignored_columns
=
list
(
set
(
dataset
.
column_names
)
-
set
(
signature_columns
))
if
len
(
ignored_columns
)
>
0
:
if
len
(
ignored_columns
)
>
0
:
...
@@ -642,7 +643,7 @@ class Trainer:
...
@@ -642,7 +643,7 @@ class Trainer:
if
not
self
.
args
.
remove_unused_columns
:
if
not
self
.
args
.
remove_unused_columns
:
return
data_collator
return
data_collator
self
.
_set_signature_columns_if_needed
()
self
.
_set_signature_columns_if_needed
()
signature_columns
=
self
.
_signature_columns
+
self
.
label_names
signature_columns
=
self
.
_signature_columns
remove_columns_collator
=
RemoveColumnsCollator
(
remove_columns_collator
=
RemoveColumnsCollator
(
data_collator
=
data_collator
,
data_collator
=
data_collator
,
...
...
src/transformers/trainer_utils.py
View file @
7b95825d
...
@@ -658,7 +658,7 @@ class FSDPOption(ExplicitEnum):
...
@@ -658,7 +658,7 @@ class FSDPOption(ExplicitEnum):
class
RemoveColumnsCollator
:
class
RemoveColumnsCollator
:
"""Wrap the data collator to remove unused columns
from its output
."""
"""Wrap the data collator to remove unused columns
before they are passed to the collator
."""
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -690,4 +690,5 @@ class RemoveColumnsCollator:
...
@@ -690,4 +690,5 @@ class RemoveColumnsCollator:
return
{
k
:
v
for
k
,
v
in
feature
.
items
()
if
k
in
self
.
signature_columns
}
return
{
k
:
v
for
k
,
v
in
feature
.
items
()
if
k
in
self
.
signature_columns
}
def
__call__
(
self
,
features
:
List
[
dict
]):
def
__call__
(
self
,
features
:
List
[
dict
]):
return
self
.
_remove_columns
(
self
.
data_collator
(
features
))
features
=
[
self
.
_remove_columns
(
feature
)
for
feature
in
features
]
return
self
.
data_collator
(
features
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment