Unverified Commit 01b55779 authored by Sourab Mangrulkar's avatar Sourab Mangrulkar Committed by GitHub
Browse files

deepspeed init during eval fix (#24298)



* deepspeed init during eval fix

* commit suggestions
Co-Authored-By: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

---------
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 6a081c51
......@@ -340,6 +340,7 @@ class Trainer:
# Seed must be set before instantiating the model when using model
enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
self.hp_name = None
self.deepspeed = None
self.is_in_train = False
self.create_accelerator_and_postprocess()
......@@ -3041,7 +3042,7 @@ class Trainer:
prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only
# if eval is called w/o train, handle model prep here
if self.is_deepspeed_enabled and self.model_wrapped is self.model:
if self.is_deepspeed_enabled and self.deepspeed is None:
_, _ = deepspeed_init(self, num_training_steps=0, inference=True)
model = self._wrap_model(self.model, training=False, dataloader=dataloader)
......@@ -3634,7 +3635,7 @@ class Trainer:
prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only
# if eval is called w/o train, handle model prep here
if self.is_deepspeed_enabled and self.model_wrapped is self.model:
if self.is_deepspeed_enabled and self.deepspeed is None:
_, _ = deepspeed_init(self, num_training_steps=0, inference=True)
model = self._wrap_model(self.model, training=False, dataloader=dataloader)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment