Unverified Commit 1ab71364 authored by NielsRogge's avatar NielsRogge Committed by GitHub
Browse files

[Trainer] Allow passing image processor (#29896)

* Add image processor to trainer

* Replace tokenizer=image_processor everywhere
parent d704c0b6
...@@ -189,6 +189,8 @@ class TrainerCallback: ...@@ -189,6 +189,8 @@ class TrainerCallback:
The model being trained. The model being trained.
tokenizer ([`PreTrainedTokenizer`]): tokenizer ([`PreTrainedTokenizer`]):
The tokenizer used for encoding the data. The tokenizer used for encoding the data.
image_processor ([`BaseImageProcessor`]):
The image processor used for encoding the images.
optimizer (`torch.optim.Optimizer`): optimizer (`torch.optim.Optimizer`):
The optimizer used for the training steps. The optimizer used for the training steps.
lr_scheduler (`torch.optim.lr_scheduler.LambdaLR`): lr_scheduler (`torch.optim.lr_scheduler.LambdaLR`):
...@@ -307,12 +309,13 @@ class TrainerCallback: ...@@ -307,12 +309,13 @@ class TrainerCallback:
class CallbackHandler(TrainerCallback): class CallbackHandler(TrainerCallback):
"""Internal class that just calls the list of callbacks in order.""" """Internal class that just calls the list of callbacks in order."""
def __init__(self, callbacks, model, tokenizer, optimizer, lr_scheduler): def __init__(self, callbacks, model, tokenizer, image_processor, optimizer, lr_scheduler):
self.callbacks = [] self.callbacks = []
for cb in callbacks: for cb in callbacks:
self.add_callback(cb) self.add_callback(cb)
self.model = model self.model = model
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.image_processor = image_processor
self.optimizer = optimizer self.optimizer = optimizer
self.lr_scheduler = lr_scheduler self.lr_scheduler = lr_scheduler
self.train_dataloader = None self.train_dataloader = None
...@@ -417,6 +420,7 @@ class CallbackHandler(TrainerCallback): ...@@ -417,6 +420,7 @@ class CallbackHandler(TrainerCallback):
control, control,
model=self.model, model=self.model,
tokenizer=self.tokenizer, tokenizer=self.tokenizer,
image_processor=self.image_processor,
optimizer=self.optimizer, optimizer=self.optimizer,
lr_scheduler=self.lr_scheduler, lr_scheduler=self.lr_scheduler,
train_dataloader=self.train_dataloader, train_dataloader=self.train_dataloader,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment