Commit 9bca36a9 authored by lintangsutawika

pre-commit

parent a7286607
@@ -296,14 +296,14 @@ class HFLM(LM):
                 )
             else:
                 assert accelerator.distributed_type in [
                     DistributedType.FSDP,
-                    DistributedType.MULTI_GPU
+                    DistributedType.MULTI_GPU,
                 ], "Unsupported distributed type provided. Only DDP and FSDP are supported."
                 if accelerator.distributed_type == DistributedType.FSDP:
                     self._model = accelerator.prepare(self.model)
                 else:
                     self._model = accelerator.prepare_model(
                         self.model, evaluation_mode=True
                     )
                 self._device = torch.device(f"cuda:{accelerator.local_process_index}")
                 self.accelerator = accelerator
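
For context, below is a minimal standalone sketch of how the branch touched by this hunk behaves when preparing a model for evaluation with Hugging Face `accelerate`. It is not part of the commit; the model name and the surrounding setup are illustrative assumptions.

```python
# Illustrative sketch (not from the commit): evaluation-only model preparation
# with accelerate, mirroring the FSDP / DDP branch in the diff above.
import torch
from accelerate import Accelerator
from accelerate.utils import DistributedType
from transformers import AutoModelForCausalLM

accelerator = Accelerator()
model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder model

if accelerator.num_processes > 1:
    assert accelerator.distributed_type in [
        DistributedType.FSDP,
        DistributedType.MULTI_GPU,
    ], "Unsupported distributed type provided. Only DDP and FSDP are supported."
    if accelerator.distributed_type == DistributedType.FSDP:
        # FSDP must go through prepare() so the parameters actually get sharded.
        model = accelerator.prepare(model)
    else:
        # DDP: prepare_model(..., evaluation_mode=True) skips the DDP wrapper,
        # since inference needs no gradient synchronization.
        model = accelerator.prepare_model(model, evaluation_mode=True)
    device = torch.device(f"cuda:{accelerator.local_process_index}")
```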