Unverified Commit ba1b24e0 authored by NielsRogge's avatar NielsRogge Committed by GitHub
Browse files

[Trainer] Fix default data collator (#30142)

* Fix data collator

* Support feature extractors as well
parent ec59a421
...@@ -58,6 +58,7 @@ from . import __version__ ...@@ -58,6 +58,7 @@ from . import __version__
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from .debug_utils import DebugOption, DebugUnderflowOverflow from .debug_utils import DebugOption, DebugUnderflowOverflow
from .feature_extraction_sequence_utils import SequenceFeatureExtractor
from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend
from .image_processing_utils import BaseImageProcessor from .image_processing_utils import BaseImageProcessor
from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
...@@ -492,7 +493,11 @@ class Trainer: ...@@ -492,7 +493,11 @@ class Trainer:
): ):
self.place_model_on_device = False self.place_model_on_device = False
default_collator = DataCollatorWithPadding(tokenizer) if tokenizer is not None else default_data_collator default_collator = (
DataCollatorWithPadding(tokenizer)
if tokenizer is not None and isinstance(tokenizer, (PreTrainedTokenizerBase, SequenceFeatureExtractor))
else default_data_collator
)
self.data_collator = data_collator if data_collator is not None else default_collator self.data_collator = data_collator if data_collator is not None else default_collator
self.train_dataset = train_dataset self.train_dataset = train_dataset
self.eval_dataset = eval_dataset self.eval_dataset = eval_dataset
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment