"tests/test_tokenization_mbart50.py" did not exist on "9336086ab5d232cccd9512333518cf4299528882"
Unverified commit fb0a38b4, authored by Antoni Viros and committed by GitHub

Move torch.compile() wrapping after DDP/FSDP wrapping to ensure correct graph breaks during training (#22279)
parent 8ac29fe0
@@ -1361,9 +1361,6 @@ class Trainer:
         return model
 
     def _wrap_model(self, model, training=True, dataloader=None):
-        if self.args.torch_compile:
-            model = torch.compile(model, backend=self.args.torch_compile_backend, mode=self.args.torch_compile_mode)
-
         if self.args.use_ipex:
             dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32
             model = self.ipex_optimize_model(model, training, dtype=dtype)
@@ -1550,6 +1547,11 @@ class Trainer:
                 **kwargs,
             )
 
+        # torch.compile() needs to be called after wrapping the model with FSDP or DDP
+        # to ensure that it accounts for the graph breaks required by those wrappers
+        if self.args.torch_compile:
+            model = torch.compile(model, backend=self.args.torch_compile_backend, mode=self.args.torch_compile_mode)
+
         return model
 
     def train(
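The ordering matters because DDP (and FSDP) install communication hooks that overlap gradient all-reduce with the backward pass, and torch.compile must trace the already-wrapped module so it can place graph breaks around those hooks. The standalone sketch below mirrors the ordering this commit enforces. It is illustrative, not Trainer code: wrap_for_training is a hypothetical helper, and it assumes torch.distributed is already initialized with one CUDA device per rank.

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_for_training(model: nn.Module, local_rank: int, do_compile: bool = True) -> nn.Module:
    # Hypothetical helper mirroring the ordering in this commit.
    # 1) Distributed wrapping first: DDP registers the bucketed gradient
    #    all-reduce hooks that the compiler needs to account for.
    model = DDP(model.to(local_rank), device_ids=[local_rank])
    # 2) Compile last, so the traced graph includes the DDP wrapper and
    #    can break at bucket boundaries, preserving the overlap of
    #    computation and communication during backward.
    if do_compile:
        model = torch.compile(model)
    return model

Compiling before wrapping, as the old code did, captures a graph of the bare module; per the commit's rationale, the compiled graph then fails to account for the breaks the wrapper requires, which undermines the overlap DDP and FSDP are designed to provide.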