Commit 92561822 authored by Benjamin Fattori
Browse files

fix FSDP error with .prepare_model()

parent 4cda3a1c
...@@ -20,7 +20,7 @@ from lm_eval.api.registry import register_model ...@@ -20,7 +20,7 @@ from lm_eval.api.registry import register_model
from lm_eval.utils import MultiTokenEOSCriteria, stop_sequences_criteria from lm_eval.utils import MultiTokenEOSCriteria, stop_sequences_criteria
from accelerate import Accelerator, find_executable_batch_size from accelerate import Accelerator, find_executable_batch_size, DistributedType
from typing import List, Optional, Union from typing import List, Optional, Union
...@@ -288,9 +288,15 @@ class HFLM(LM): ...@@ -288,9 +288,15 @@ class HFLM(LM):
eval_logger.info( eval_logger.info(
"Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes`. If the desired GPU is being used, this message is safe to ignore." "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes`. If the desired GPU is being used, this message is safe to ignore."
) )
else:
assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU], "Unsupported distributed type provided. Only DDP and FSDP are supported."
if accelerator.distributed_type == DistributedType.FSDP:
self._model = accelerator.prepare(
self.model
)
else: else:
self._model = accelerator.prepare_model( self._model = accelerator.prepare_model(
self.model, evaluation_mode=True self.model, evaluation_mode = True
) )
self._device = torch.device(f"cuda:{accelerator.local_process_index}") self._device = torch.device(f"cuda:{accelerator.local_process_index}")
self.accelerator = accelerator self.accelerator = accelerator
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment