Unverified commit f4eb459e authored by Sourab Mangrulkar, committed by GitHub

fsdp fixes and enhancements (#24980)

* fix fsdp prepare to remove the warnings and fix excess memory usage

* Update training_args.py

* parity for FSDP+XLA

* Update trainer.py
parent ec3dfe5e
@@ -441,7 +441,7 @@ as the model saving with FSDP activated is only available with recent fixes.
 - Remaining FSDP config is passed via `--fsdp_config <path_to_fsdp_config.json>`. It is either a location of
   FSDP json config file (e.g., `fsdp_config.json`) or an already loaded json file as `dict`.
 - If auto wrapping is enabled, you can either use transformer based auto wrap policy or size based auto wrap policy.
-  - For transformer based auto wrap policy, please specify `fsdp_transformer_layer_cls_to_wrap` in the config file.
+  - For transformer based auto wrap policy, it is recommended to specify `fsdp_transformer_layer_cls_to_wrap` in the config file. If not specified, the default value is `model._no_split_modules` when available.
     This specifies the list of transformer layer class name (case-sensitive) to wrap ,e.g, [`BertLayer`], [`GPTJBlock`], [`T5Block`] ....
     This is important because submodules that share weights (e.g., embedding layer) should not end up in different FSDP wrapped units.
     Using this policy, wrapping happens for each block containing Multi-Head Attention followed by couple of MLP layers.
@@ -482,7 +482,7 @@ Pass `--fsdp "full shard"` along with following changes to be made in `--fsdp_co
   This setting can only be used when the xla flag is set to true, and an auto wrapping policy is specified through
   `fsdp_min_num_params` or `fsdp_transformer_layer_cls_to_wrap`.
 - You can either use transformer based auto wrap policy or size based auto wrap policy.
-  - For transformer based auto wrap policy, please specify `fsdp_transformer_layer_cls_to_wrap` in the config file.
+  - For transformer based auto wrap policy, it is recommended to specify `fsdp_transformer_layer_cls_to_wrap` in the config file. If not specified, the default value is `model._no_split_modules` when available.
     This specifies the list of transformer layer class name (case-sensitive) to wrap ,e.g, [`BertLayer`], [`GPTJBlock`], [`T5Block`] ....
     This is important because submodules that share weights (e.g., embedding layer) should not end up in different FSDP wrapped units.
     Using this policy, wrapping happens for each block containing Multi-Head Attention followed by couple of MLP layers.
...
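For context, here is a minimal sketch of how the options documented above might be passed to `TrainingArguments` as an already loaded `dict`. The concrete values (output directory, prefetch policy) are illustrative assumptions, not part of this commit; the point is that, after this change, `fsdp_transformer_layer_cls_to_wrap` can be omitted and the Trainer falls back to `model._no_split_modules` when the model defines it. The script would normally be run under a distributed launcher such as `torchrun` or `accelerate launch`.

```python
from transformers import TrainingArguments

# Sketch only: values below are assumptions, not taken from this commit.
training_args = TrainingArguments(
    output_dir="outputs",                  # hypothetical output directory
    fsdp="full_shard auto_wrap",           # enable FSDP sharding plus auto wrapping
    fsdp_config={
        "fsdp_min_num_params": 0,          # 0 -> size based auto wrap policy is not used
        # "fsdp_transformer_layer_cls_to_wrap": ["BertLayer"],  # optional after this change;
        # when omitted, the default is model._no_split_modules (if available)
        "fsdp_backward_prefetch": "backward_pre",
    },
)
```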
@@ -1377,18 +1377,24 @@ class Trainer:
                 raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.")
             auto_wrap_policy = None
             auto_wrapper_callable = None
+            default_transformer_cls_names_to_wrap = getattr(model, "_no_split_modules", None)
+            fsdp_transformer_layer_cls_to_wrap = self.args.fsdp_config.get(
+                "fsdp_transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap
+            )
             if self.args.fsdp_config["fsdp_min_num_params"] > 0:
                 auto_wrap_policy = functools.partial(
                     size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
                 )
-            elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
+            elif fsdp_transformer_layer_cls_to_wrap is not None:
                 transformer_cls_to_wrap = set()
-                for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
+                for layer_class in fsdp_transformer_layer_cls_to_wrap:
                     transformer_cls = get_module_class_from_name(model, layer_class)
                     if transformer_cls is None:
                         raise Exception("Could not find the transformer layer class to wrap in the model.")
                     else:
                         transformer_cls_to_wrap.add(transformer_cls)
                 auto_wrap_policy = functools.partial(
                     transformer_auto_wrap_policy,
                     # Transformer layer class to wrap
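The hunk above adds the `model._no_split_modules` fallback when resolving the transformer based auto wrap policy. As a stand-alone illustration, here is a minimal sketch of that resolution logic; the name-based class lookup is a simplified stand-in for the Trainer's `get_module_class_from_name` helper, and the `gpt2` checkpoint is only an assumed example.

```python
import functools

from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers import AutoModelForCausalLM


def resolve_layer_classes(model, fsdp_config):
    """Mirror the fallback: an explicit config entry wins, else use _no_split_modules."""
    default_names = getattr(model, "_no_split_modules", None)
    layer_names = fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", default_names)
    if layer_names is None:
        return None
    classes = set()
    for name in layer_names:
        # Walk the module tree looking for a class whose name matches (simplified lookup).
        cls = next((type(m) for m in model.modules() if type(m).__name__ == name), None)
        if cls is None:
            raise ValueError(f"Could not find the transformer layer class {name} to wrap.")
        classes.add(cls)
    return classes


model = AutoModelForCausalLM.from_pretrained("gpt2")  # assumed example checkpoint
layer_classes = resolve_layer_classes(model, fsdp_config={})  # empty config -> use the new default
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy, transformer_layer_cls=layer_classes
)
```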
@@ -1600,6 +1606,7 @@ class Trainer:
             and self.sharded_ddp != ShardedDDPOption.SIMPLE
             or is_sagemaker_mp_enabled()
             or self.fsdp is not None
+            or self.is_fsdp_enabled
         )
         # We need to reset the scheduler, as its parameters may be different on subsequent calls
@@ -1631,6 +1638,8 @@ class Trainer:
         use_accelerator_prepare = True if model is self.model else False
         if delay_optimizer_creation:
+            if use_accelerator_prepare:
+                self.model = self.accelerator.prepare(self.model)
             self.create_optimizer_and_scheduler(num_training_steps=max_steps)
         # prepare using `accelerator` prepare
...
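The two hunks above make the Trainer delay optimizer creation when FSDP is enabled through Accelerate and prepare the model before the optimizer is built, so the optimizer only ever sees the flattened, sharded parameters. This ordering is what removes the warnings and the extra memory copy mentioned in the commit message. A minimal sketch of the same ordering with plain Accelerate APIs (the checkpoint and learning rate are assumptions, and an FSDP-enabled `accelerate launch` config is assumed):

```python
import torch
from accelerate import Accelerator
from transformers import AutoModelForCausalLM

accelerator = Accelerator()  # assumed to be launched with an FSDP config via `accelerate launch`
model = AutoModelForCausalLM.from_pretrained("gpt2")  # assumed example checkpoint

# 1. Prepare (FSDP-wrap) the model first.
model = accelerator.prepare(model)

# 2. Only now build the optimizer from the wrapped model's parameters.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
optimizer = accelerator.prepare(optimizer)
```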
@@ -1567,6 +1567,7 @@ class TrainingArguments:
             elif fsdp_option == FSDPOption.OFFLOAD:
                 os.environ["FSDP_OFFLOAD_PARAMS"] = "true"
             elif fsdp_option == FSDPOption.AUTO_WRAP:
+                os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[0]
                 if self.fsdp_config["fsdp_min_num_params"] > 0:
                     os.environ["FSDP_MIN_NUM_PARAMS"] = str(self.fsdp_config["fsdp_min_num_params"])
                     os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[1]
@@ -1574,7 +1575,6 @@
                     os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = ",".join(
                         self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]
                     )
-                    os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[0]
         prefetch_policy = self.fsdp_config.get("fsdp_backward_prefetch", "NO_PREFETCH")
         os.environ["FSDP_BACKWARD_PREFETCH"] = prefetch_policy.upper()
...
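The training_args.py hunks above make the transformer based auto wrap policy the default for `--fsdp auto_wrap`, overriding it with the size based policy only when `fsdp_min_num_params > 0`, instead of setting it only when `fsdp_transformer_layer_cls_to_wrap` is explicitly configured. A condensed stand-alone sketch of that selection order (the `FSDP_AUTO_WRAP_POLICY` values below are assumed to match Accelerate's constant):

```python
import os

# Assumption: mirrors accelerate's FSDP_AUTO_WRAP_POLICY constant.
FSDP_AUTO_WRAP_POLICY = ["TRANSFORMER_BASED_WRAP", "SIZE_BASED_WRAP", "NO_WRAP"]


def set_auto_wrap_env(fsdp_config: dict) -> None:
    # Default: transformer based auto wrap. This now works even without an explicit
    # fsdp_transformer_layer_cls_to_wrap, thanks to the _no_split_modules fallback.
    os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[0]
    if fsdp_config.get("fsdp_min_num_params", 0) > 0:
        # A positive threshold switches to the size based policy.
        os.environ["FSDP_MIN_NUM_PARAMS"] = str(fsdp_config["fsdp_min_num_params"])
        os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[1]
    elif fsdp_config.get("fsdp_transformer_layer_cls_to_wrap") is not None:
        # Explicit layer classes still get exported for the transformer based policy.
        os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = ",".join(
            fsdp_config["fsdp_transformer_layer_cls_to_wrap"]
        )


set_auto_wrap_env({"fsdp_min_num_params": 0})  # -> TRANSFORMER_BASED_WRAP by default
```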