Unverified Commit fb0a38b4 authored by Antoni Viros, committed by GitHub


Move torch.compile() wrapping after DDP/FSDP wrapping to ensure correct graph breaks during training (#22279)
parent 8ac29fe0
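
To illustrate the ordering this commit enforces, here is a minimal, self-contained sketch. It is not taken from the PR; the function name `wrap_for_training` and the `local_rank` handling are placeholders. The point is that the model is wrapped with DDP first and `torch.compile()` is called last, so the compiler sees the wrapped module and can insert the graph breaks the wrapper requires.

```python
# Minimal sketch of the wrapping order enforced by this commit.
# Assumptions: a distributed process group is already initialized and
# `local_rank` identifies the current GPU; the names here are illustrative.
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def wrap_for_training(model: nn.Module, local_rank: int, use_compile: bool = True) -> nn.Module:
    # 1) Wrap with DDP first so the compiled graph includes the DDP wrapper
    #    and the graph breaks it requires.
    model = DDP(model.to(local_rank), device_ids=[local_rank])

    # 2) Compile last; compiling before DDP would capture only the bare
    #    module and miss the wrapper added around it.
    if use_compile:
        model = torch.compile(model)

    return model
```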
@@ -1361,9 +1361,6 @@ class Trainer:
         return model
 
     def _wrap_model(self, model, training=True, dataloader=None):
-        if self.args.torch_compile:
-            model = torch.compile(model, backend=self.args.torch_compile_backend, mode=self.args.torch_compile_mode)
-
         if self.args.use_ipex:
             dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32
             model = self.ipex_optimize_model(model, training, dtype=dtype)
@@ -1550,6 +1547,11 @@ class Trainer:
                 **kwargs,
             )
 
+        # torch.compile() needs to be called after wrapping the model with FSDP or DDP
+        # to ensure that it accounts for the graph breaks required by those wrappers
+        if self.args.torch_compile:
+            model = torch.compile(model, backend=self.args.torch_compile_backend, mode=self.args.torch_compile_mode)
+
         return model
 
     def train(
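
For context, a hedged usage sketch of how this code path is reached from user code: setting `torch_compile=True` (optionally with `torch_compile_backend` / `torch_compile_mode`, the same fields referenced by `self.args` in the diff above) makes `Trainer._wrap_model()` apply `torch.compile()` after the DDP/FSDP wrapping. The output directory and the specific backend/mode values below are illustrative only.

```python
# Usage sketch (assumed typical values; only the torch_compile* fields are
# the ones referenced by the diff above).
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="out",                   # placeholder path
    torch_compile=True,                 # enables the torch.compile() branch in _wrap_model()
    torch_compile_backend="inductor",   # assumed backend choice
    torch_compile_mode="default",       # assumed mode choice
)

# trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
# trainer.train()  # DDP/FSDP wrapping happens first, torch.compile() is applied afterwards
```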