Unverified Commit 636acc75 authored by Sourab Mangrulkar's avatar Sourab Mangrulkar Committed by GitHub
Browse files

fix z3 init when using accelerate launcher (#25589)

parent 8d2f953f
...@@ -1467,6 +1467,15 @@ class TrainingArguments: ...@@ -1467,6 +1467,15 @@ class TrainingArguments:
torch.backends.cudnn.allow_tf32 = False torch.backends.cudnn.allow_tf32 = False
# no need to assert on else # no need to assert on else
# if training args is specified, it will override the one specified in the accelerate config
if self.half_precision_backend != "apex" and len(self.sharded_ddp) == 0:
mixed_precision_dtype = os.environ.get("ACCELERATE_MIXED_PRECISION", "no")
if self.fp16:
mixed_precision_dtype = "fp16"
elif self.bf16:
mixed_precision_dtype = "bf16"
os.environ["ACCELERATE_MIXED_PRECISION"] = mixed_precision_dtype
if self.report_to is None: if self.report_to is None:
logger.info( logger.info(
"The default value for the training argument `--report_to` will change in v5 (from all installed " "The default value for the training argument `--report_to` will change in v5 (from all installed "
...@@ -1655,6 +1664,8 @@ class TrainingArguments: ...@@ -1655,6 +1664,8 @@ class TrainingArguments:
from accelerate.utils import DeepSpeedPlugin from accelerate.utils import DeepSpeedPlugin
self.deepspeed_plugin = DeepSpeedPlugin() self.deepspeed_plugin = DeepSpeedPlugin()
mixed_precision = os.environ.get("ACCELERATE_MIXED_PRECISION", "no")
self.deepspeed_plugin.set_mixed_precision(mixed_precision)
self.deepspeed_plugin.set_deepspeed_weakref() self.deepspeed_plugin.set_deepspeed_weakref()
if self.push_to_hub_token is not None: if self.push_to_hub_token is not None:
...@@ -1692,15 +1703,6 @@ class TrainingArguments: ...@@ -1692,15 +1703,6 @@ class TrainingArguments:
FutureWarning, FutureWarning,
) )
# if training args is specified, it will override the one specified in the accelerate config
if self.half_precision_backend != "apex" and len(self.sharded_ddp) == 0:
mixed_precision_dtype = os.environ.get("ACCELERATE_MIXED_PRECISION", "no")
if self.fp16:
mixed_precision_dtype = "fp16"
elif self.bf16:
mixed_precision_dtype = "bf16"
os.environ["ACCELERATE_MIXED_PRECISION"] = mixed_precision_dtype
# Finally set the `TrainingArguments` to be immutable # Finally set the `TrainingArguments` to be immutable
self._frozen = True self._frozen = True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment