Unverified Commit aac7b0d2 authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

[Trainer] add error when passing `8bit`models (#20651)

* add error when passing `8bit`models

* fix

* improve message
parent d151a8c5
......@@ -357,6 +357,13 @@ class Trainer:
else:
self.is_model_parallel = False
# At this stage the model is already loaded
if getattr(model, "is_loaded_in_8bit", False):
raise ValueError(
"The model you want to train is loaded in 8-bit precision. "
"Training an 8-bit model is not supported yet. "
)
# Setup Sharded DDP training
self.sharded_ddp = None
if len(args.sharded_ddp) > 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment