Unverified Commit d1c039e3 authored by Sourab Mangrulkar's avatar Sourab Mangrulkar Committed by GitHub
Browse files

fix accelerator prepare during eval only mode (#24014)

* fix mixed precision prep during eval only mode

* update to address comments

* update to reflect the changes in accelerate
parent 2c887cf8
...@@ -3141,14 +3141,30 @@ class Trainer:
prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only
# if eval is called w/o train init deepspeed here # if eval is called w/o train, handle model prep here
if self.is_deepspeed_enabled and self.model_wrapped is self.model: if self.is_deepspeed_enabled and self.model_wrapped is self.model:
_, _ = deepspeed_init(self, num_training_steps=0, inference=True) _, _ = deepspeed_init(self, num_training_steps=0, inference=True)
model = self.accelerator.prepare(self.model)
self.model_wrapped = self.deepspeed = model
model = self._wrap_model(self.model, training=False, dataloader=dataloader) model = self._wrap_model(self.model, training=False, dataloader=dataloader)
if len(self.accelerator._models) == 0 and model is self.model:
model = (
self.accelerator.prepare(model)
if self.is_deepspeed_enabled
else self.accelerator.prepare_model(model, evaluation_mode=True)
)
if self.is_fsdp_enabled:
self.model = model
# for the rest of this function `model` is the outside model, whether it was wrapped or not
if model is not self.model:
self.model_wrapped = model
# backward compatibility
if self.is_deepspeed_enabled:
self.deepspeed = self.model_wrapped
# if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
# while ``train`` is running, cast it to the right dtype first and then put on device # while ``train`` is running, cast it to the right dtype first and then put on device
if not self.is_in_train: if not self.is_in_train:
...@@ -3736,14 +3752,30 @@ class Trainer:
prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only
# if eval is called w/o train init deepspeed here # if eval is called w/o train, handle model prep here
if self.is_deepspeed_enabled and self.model_wrapped is self.model: if self.is_deepspeed_enabled and self.model_wrapped is self.model:
_, _ = deepspeed_init(self, num_training_steps=0, inference=True) _, _ = deepspeed_init(self, num_training_steps=0, inference=True)
model = self.accelerator.prepare(self.model)
self.model_wrapped = self.deepspeed = model
model = self._wrap_model(self.model, training=False, dataloader=dataloader) model = self._wrap_model(self.model, training=False, dataloader=dataloader)
if len(self.accelerator._models) == 0 and model is self.model:
model = (
self.accelerator.prepare(model)
if self.is_deepspeed_enabled
else self.accelerator.prepare_model(model, evaluation_mode=True)
)
if self.is_fsdp_enabled:
self.model = model
# for the rest of this function `model` is the outside model, whether it was wrapped or not
if model is not self.model:
self.model_wrapped = model
# backward compatibility
if self.is_deepspeed_enabled:
self.deepspeed = self.model_wrapped
# if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
# while ``train`` is running, cast it to the right dtype first and then put on device # while ``train`` is running, cast it to the right dtype first and then put on device
if not self.is_in_train: if not self.is_in_train:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment