Unverified commit 80126f98, authored by Howard Liberty, committed by GitHub
Browse files

Allow boolean FSDP options in fsdp_config (#30439)

* Allow boolean FSDP options in fsdp_config

* Use lower() to be safe
parent 73014b56
...@@ -1840,12 +1840,12 @@ class TrainingArguments: ...@@ -1840,12 +1840,12 @@ class TrainingArguments:
) )
prefetch_policy = self.fsdp_config.get("backward_prefetch", "NO_PREFETCH") prefetch_policy = self.fsdp_config.get("backward_prefetch", "NO_PREFETCH")
os.environ[f"{prefix}BACKWARD_PREFETCH"] = prefetch_policy.upper() os.environ[f"{prefix}BACKWARD_PREFETCH"] = prefetch_policy.upper()
os.environ[f"{prefix}FORWARD_PREFETCH"] = self.fsdp_config.get("forward_prefetch", "false") os.environ[f"{prefix}FORWARD_PREFETCH"] = str(self.fsdp_config.get("forward_prefetch", "false")).lower()
sync_module_states = self.fsdp_config.get("sync_module_states", "true") sync_module_states = str(self.fsdp_config.get("sync_module_states", "true")).lower()
cpu_ram_efficient_loading = self.fsdp_config.get("cpu_ram_efficient_loading", "false") cpu_ram_efficient_loading = str(self.fsdp_config.get("cpu_ram_efficient_loading", "false")).lower()
if str(sync_module_states).lower() == "false" and str(cpu_ram_efficient_loading).lower() == "true": if sync_module_states == "false" and cpu_ram_efficient_loading == "true":
# In this case, all the processes except the main process would have random weights leading # In this case, all the processes except the main process would have random weights leading
# to unexpected behaviour during training, thus throwing error here to prevent it. # to unexpected behaviour during training, thus throwing error here to prevent it.
raise ValueError('`sync_module_states` must be `"True"` if `cpu_ram_efficient_loading` is `"True"`') raise ValueError('`sync_module_states` must be `"True"` if `cpu_ram_efficient_loading` is `"True"`')
...@@ -1853,7 +1853,7 @@ class TrainingArguments: ...@@ -1853,7 +1853,7 @@ class TrainingArguments:
os.environ[f"{prefix}SYNC_MODULE_STATES"] = sync_module_states os.environ[f"{prefix}SYNC_MODULE_STATES"] = sync_module_states
os.environ[f"{prefix}CPU_RAM_EFFICIENT_LOADING"] = cpu_ram_efficient_loading os.environ[f"{prefix}CPU_RAM_EFFICIENT_LOADING"] = cpu_ram_efficient_loading
os.environ[f"{prefix}USE_ORIG_PARAMS"] = self.fsdp_config.get("use_orig_params", "true") os.environ[f"{prefix}USE_ORIG_PARAMS"] = str(self.fsdp_config.get("use_orig_params", "true")).lower()
if is_accelerate_available(): if is_accelerate_available():
if not isinstance(self.accelerator_config, (AcceleratorConfig)): if not isinstance(self.accelerator_config, (AcceleratorConfig)):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment