Unverified Commit 09a9888f authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

[`bnb`] 8bit models should not be converted to `DDP` (#22628)

add safety checker
parent d0b83fe2
...@@ -1406,8 +1406,8 @@ class Trainer: ...@@ -1406,8 +1406,8 @@ class Trainer:
if self.use_apex and training: if self.use_apex and training:
model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level) model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
# Multi-gpu training (should be after apex fp16 initialization) # Multi-gpu training (should be after apex fp16 initialization) / 8bit models does not support DDP
if self.args.n_gpu > 1: if self.args.n_gpu > 1 and not getattr(model, "is_loaded_in_8bit", False):
model = nn.DataParallel(model) model = nn.DataParallel(model)
if self.args.jit_mode_eval: if self.args.jit_mode_eval:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment