Unverified Commit 5f06053d authored by Jiewen Tan's avatar Jiewen Tan Committed by GitHub
Browse files

[TPU] Support PyTorch/XLA FSDP via SPMD (#28949)

* Initial commit

* Add guards for the global mesh

* Address more comments

* Move the dataloader into integrations/tpu.py

* Fix linters

* Make karg more explicitly

* Remove the move device logic

* Fix the CI

* Fix linters

* Re-enable checkpointing
parent 0199a484
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torch.utils.data import DataLoader
from ..utils import is_torch_tpu_available
def tpu_spmd_dataloader(dataloader: DataLoader):
if is_torch_tpu_available():
import torch_xla.distributed.parallel_loader as pl
assert isinstance(
dataloader, pl.MpDeviceLoader
), "The dataloader must be a `torch_xla.distributed.parallel_loader.MpDeviceLoader`."
# This is to support PyTorch/XLA FSDP via SPMD.
# Here we shard the input data's 0th dim across the fsdp axis.
import torch_xla.distributed.spmd as xs
sharding_spec = xs.ShardingSpec(xs.get_global_mesh(), ("fsdp", None))
dataloader._parallel_loader_kwargs["input_sharding"] = sharding_spec
return dataloader
else:
return dataloader
......@@ -60,6 +60,7 @@ from .data.data_collator import DataCollator, DataCollatorWithPadding, default_d
from .debug_utils import DebugOption, DebugUnderflowOverflow
from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend
from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
from .integrations.tpu import tpu_spmd_dataloader
from .modelcard import TrainingSummary
from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
......@@ -170,6 +171,8 @@ if is_datasets_available():
if is_torch_tpu_available(check_device=False):
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
if is_sagemaker_mp_enabled():
......@@ -635,6 +638,13 @@ class Trainer:
if args.torch_compile and not is_torch_compile_available():
raise RuntimeError("Using torch.compile requires PyTorch 2.0 or higher.")
self.is_fsdp_xla_v2_enabled = args.fsdp_config["xla_fsdp_v2"]
if self.is_fsdp_xla_v2_enabled:
# Prepare the SPMD mesh that is going to be used by the data loader and the FSDPv2 wrapper.
# Tensor axis is just a placeholder where it will not be used in FSDPv2.
num_devices = xr.global_runtime_device_count()
xs.set_global_mesh(xs.Mesh(np.array(range(num_devices)), (num_devices, 1), axis_names=("fsdp", "tensor")))
def _activate_neftune(self, model):
r"""
Activates the neftune as presented in this code: https://github.com/neelsjain/NEFTune and paper:
......@@ -1385,6 +1395,11 @@ class Trainer:
size_based_auto_wrap_policy,
transformer_auto_wrap_policy,
)
if self.is_fsdp_xla_v2_enabled:
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
SpmdFullyShardedDataParallel as FSDPv2,
)
except ImportError:
raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.")
auto_wrap_policy = None
......@@ -1416,15 +1431,40 @@ class Trainer:
if self.args.fsdp_config["xla_fsdp_grad_ckpt"]:
# Apply gradient checkpointing to auto-wrapped sub-modules if specified
def auto_wrapper_callable(m, *args, **kwargs):
return FSDP(checkpoint_module(m), *args, **kwargs)
target_cls = FSDP if not self.is_fsdp_xla_v2_enabled else FSDPv2
return target_cls(checkpoint_module(m), *args, **kwargs)
# Wrap the base model with an outer FSDP wrapper
self.model = model = FSDP(
model,
auto_wrap_policy=auto_wrap_policy,
auto_wrapper_callable=auto_wrapper_callable,
**fsdp_kwargs,
)
if self.is_fsdp_xla_v2_enabled:
def shard_output(output, mesh):
from .modeling_outputs import CausalLMOutputWithPast
real_output = None
if isinstance(output, torch.Tensor):
real_output = output
elif isinstance(output, tuple):
real_output = output[0]
elif isinstance(output, CausalLMOutputWithPast):
real_output = output.logits
if real_output is None:
raise ValueError("Something went wrong, the output of the model shouldn't be `None`")
xs.mark_sharding(real_output, mesh, ("fsdp", None, None))
self.model = model = FSDPv2(
model,
shard_output=shard_output,
auto_wrap_policy=auto_wrap_policy,
auto_wrapper_callable=auto_wrapper_callable,
)
else:
self.model = model = FSDP(
model,
auto_wrap_policy=auto_wrap_policy,
auto_wrapper_callable=auto_wrapper_callable,
**fsdp_kwargs,
)
# Patch `xm.optimizer_step` should not reduce gradients in this case,
# as FSDP does not need gradient reduction over sharded parameters.
......@@ -1593,6 +1633,8 @@ class Trainer:
logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
# Data loader and number of training steps
train_dataloader = self.get_train_dataloader()
if self.is_fsdp_xla_v2_enabled:
train_dataloader = tpu_spmd_dataloader(train_dataloader)
# Setting up training control variables:
# number of training epochs: num_train_epochs
......@@ -1962,6 +2004,11 @@ class Trainer:
self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
if self.control.should_epoch_stop or self.control.should_training_stop:
# PyTorch/XLA relies on the data loader to insert the mark_step for
# each step. Since we are breaking the loop early, we need to manually
# insert the mark_step here.
if is_torch_tpu_available():
xm.mark_step()
break
if step < 0:
logger.warning(
......@@ -2945,6 +2992,7 @@ class Trainer:
def _save_tpu(self, output_dir: Optional[str] = None):
output_dir = output_dir if output_dir is not None else self.args.output_dir
logger.info(f"Saving model checkpoint to {output_dir}")
model = self.model
model.to("cpu")
......@@ -3143,6 +3191,9 @@ class Trainer:
self._memory_tracker.start()
eval_dataloader = self.get_eval_dataloader(eval_dataset)
if self.is_fsdp_xla_v2_enabled:
eval_dataloader = tpu_spmd_dataloader(eval_dataloader)
start_time = time.time()
eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
......
......@@ -1684,6 +1684,7 @@ class TrainingArguments:
):
raise ValueError("`min_num_params` and `transformer_layer_cls_to_wrap` are mutually exclusive.")
self.fsdp_config["xla"] = self.fsdp_config.get("xla", False)
self.fsdp_config["xla_fsdp_v2"] = self.fsdp_config.get("xla_fsdp_v2", False)
self.fsdp_config["xla_fsdp_grad_ckpt"] = self.fsdp_config.get("xla_fsdp_grad_ckpt", False)
if self.fsdp_config["xla"]:
if len(self.fsdp) > 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment