Unverified commit 0196ed44, authored by Youngeun Kwon, committed by GitHub
Browse files

Enabling FP8 all-gather for TE Float8Tensor when using Torch FSDP2 (#1358)



* draft implementation of fsdp2 fp8 all gather
Signed-off-by: default avatarYoungeun Kwon <youngeunk@nvidia.com>

* fix the convergence issue
Signed-off-by: default avatarYoungeun Kwon <youngeunk@nvidia.com>

* Add warning
Signed-off-by: default avatarYoungeun Kwon <youngeunk@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* disable lint error
Signed-off-by: default avatarYoungeun Kwon <youngeunk@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix the lint error
Signed-off-by: default avatarYoungeun Kwon <youngeunk@nvidia.com>

* fix lint error
Signed-off-by: default avatarYoungeun Kwon <youngeunk@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix lint error
Signed-off-by: default avatarYoungeun Kwon <youngeunk@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix lint error
Signed-off-by: default avatarYoungeun Kwon <youngeunk@nvidia.com>

* add comments
Signed-off-by: default avatarYoungeun Kwon <youngeunk@nvidia.com>

* add ref
Signed-off-by: default avatarYoungeun Kwon <youngeunk@nvidia.com>

* add related tests
Signed-off-by: default avatarYoungeun Kwon <youngeunk@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 1975ace4
......@@ -11,4 +11,5 @@ pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py
pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py
#!/usr/bin/python3
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import sys
import argparse
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn, optim
from torch.distributed import DeviceMesh
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh
from contextlib import nullcontext
class SimpleNet(nn.Module):
    """Two-layer network built from Transformer Engine ``Linear`` modules.

    ``fc1`` maps ``input_size`` to ``hidden_size`` and ``fc2`` maps
    ``hidden_size`` to ``output_size``; a ReLU is applied between them.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = te.Linear(input_size, hidden_size)
        self.fc2 = te.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return ``fc2(relu(fc1(x)))``."""
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)
def save_custom_attrs(module):
    """Snapshot the instance attributes of every parameter of ``module``.

    Returns a dict keyed by the names yielded by ``module.named_parameters()``;
    each value is a shallow copy of that parameter's ``vars()`` mapping.
    """
    return {
        param_name: dict(vars(parameter))
        for param_name, parameter in module.named_parameters()
    }
def restore_custom_attrs(module, custom_attrs):
    """Re-apply saved attributes onto the parameters of ``module``.

    ``custom_attrs`` maps parameter names to ``{attribute: value}`` dicts, in
    the shape produced by ``save_custom_attrs``. Parameters whose name is not
    a key of ``custom_attrs`` are left untouched.
    """
    for param_name, parameter in module.named_parameters():
        if param_name not in custom_attrs:
            continue
        for key, value in custom_attrs[param_name].items():
            setattr(parameter, key, value)
def _parse_args(argv=None, namespace=None):
parser = argparse.ArgumentParser(description="Toy example for debugging fully_shard()")
parser.add_argument("--input-size", type=int, default=2048, help="Input size for the model")
parser.add_argument("--hidden-size", type=int, default=2048, help="Hidden layer size")
parser.add_argument("--output-size", type=int, default=2048, help="Output size for the model")
parser.add_argument("--batch-size", type=int, default=2048, help="Output size for the model")
parser.add_argument(
"--fp8-init", action="store_true", default=False, help="Initialize primary weights in FP8."
)
parser.add_argument(
"--iter", type=int, default=10, help="Number of iterations for forward pass"
)
parser.add_argument("--seed", type=int, default=42, help="RNG seed.")
# Adding hsdp_dim as a list argument, comma-separated
parser.add_argument(
"--sharding-dims",
type=int,
nargs="+",
help='FSDP/HSDP sharding dimensions ("replicate", "shard")',
)
args = parser.parse_args(argv, namespace)
if args.sharding_dims:
assert len(args.sharding_dims) <= 2
return args
# Module types that receive their own fully_shard() call in _train() before the
# whole model is sharded.
sub_modules_to_wrap = [te.Linear]
def _train(args):
    """Shard ``SimpleNet`` with FSDP2 ``fully_shard`` and run a short training loop.

    The function reads the ``RANK``/``WORLD_SIZE``/``LOCAL_RANK``/
    ``LOCAL_WORLD_SIZE`` environment variables, requires ``TORCHELASTIC_RUN_ID``
    to be set, and requires all ranks to be local (single node).

    Args:
        args: Namespace produced by ``_parse_args``.

    Returns:
        0 once all iterations have run and the process group is destroyed.
    """
    assert "TORCHELASTIC_RUN_ID" in os.environ
    WORLD_RANK = int(os.getenv("RANK", "0"))
    WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
    LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
    LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
    # Single-node only: every rank of the job must be a local rank.
    assert LOCAL_SIZE == WORLD_SIZE

    # Set device and initialize RNG states
    torch.cuda.set_device(WORLD_RANK)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # Initialize the torch.distributed global process group
    dist_init_kwargs = {
        "backend": "nccl",
        "rank": WORLD_RANK,
        "world_size": WORLD_SIZE,
    }
    assert dist.is_nccl_available()
    dist.init_process_group(**dist_init_kwargs)
    nccl_world = dist.new_group(backend="nccl")
    device = torch.device(f"cuda:{LOCAL_RANK}")

    # FP8 Configuration
    fp8_format = Format.HYBRID
    fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")

    if args.fp8_init:
        # Build the model under fp8_model_init only when --fp8-init was requested.
        from transformer_engine.pytorch import fp8_model_init

        with fp8_model_init(enabled=True):
            model = SimpleNet(args.input_size, args.hidden_size, args.output_size)
    else:
        model = SimpleNet(args.input_size, args.hidden_size, args.output_size)

    # Move the model to the correct device
    model.to(device)

    if LOCAL_RANK == 0:
        print(f"Rank {LOCAL_RANK}: Applying FSDP fully_shard() to the model...")

    # Creating a DeviceMesh for fully_shard
    device_ids = list(range(WORLD_SIZE))
    if LOCAL_RANK == 0:
        print(f"sharding-dims:{args.sharding_dims}")

    # Setup the sharding mesh for FSDP/HSDP
    if args.sharding_dims is None:  # FSDP
        mesh = DeviceMesh("cuda", device_ids)
    elif len(args.sharding_dims) == 1:
        assert args.sharding_dims[0] == WORLD_SIZE
        mesh = DeviceMesh("cuda", device_ids)
    elif len(args.sharding_dims) == 2:  # HSDP
        assert args.sharding_dims[0] * args.sharding_dims[1] == WORLD_SIZE
        mesh = init_device_mesh(
            "cuda",
            (args.sharding_dims[0], args.sharding_dims[1]),
            mesh_dim_names=("replicate", "shard"),
        )
    else:
        assert False

    # Apply FSDP/HSDP. Parameter attributes are snapshotted before sharding and
    # re-applied to the parameters afterwards.
    custom_attrs = save_custom_attrs(model)
    for sub_module in model.modules():
        if isinstance(sub_module, tuple(sub_modules_to_wrap)):
            fully_shard(sub_module, mesh=mesh)
    fully_shard(model, mesh=mesh)
    restore_custom_attrs(model, custom_attrs)

    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    for iteration in range(args.iter):
        # Zero the parameter gradients
        optimizer.zero_grad()
        input_data = torch.randn(args.batch_size, args.input_size).to(device)
        output = model(input_data)
        target = torch.randn(args.batch_size, args.output_size).to(device)
        loss = F.mse_loss(output, target)
        loss.backward()
        optimizer.step()
        if LOCAL_RANK == 0:
            print(f"Rank {LOCAL_RANK}: Iteration {iteration} completed.")

    dist.destroy_process_group()
    if LOCAL_RANK == 0:
        print(f"Rank {LOCAL_RANK}: Done...")
    return 0
# Script entry point: the value returned by _train() is used as the process exit code.
if __name__ == "__main__":
    sys.exit(_train(_parse_args()))
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import pytest
import subprocess
from pathlib import Path
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import torch
from packaging.version import Version as PkgVersion
def get_torch_version():
    """Return ``torch.__version__`` parsed into a ``packaging`` ``Version``."""
    import torch

    version_text = str(torch.__version__)
    return PkgVersion(version_text)
# These checks run at import/collection time. A module-level pytest.skip() must
# pass allow_module_level=True; without it pytest reports a collection error
# instead of skipping the module.
if torch.cuda.device_count() < 4:
    pytest.skip("FSDP2 test requires at least 4 GPUs.", allow_module_level=True)
if torch.cuda.device_count() % 2 != 0:
    pytest.skip("Number of device should be divided by 2.", allow_module_level=True)
if not get_torch_version() >= PkgVersion("2.4"):
    pytest.skip("FSDP2 requires PyTorch >= 2.4.0 with FSDP 2 support.", allow_module_level=True)

# FP8 availability flag and the reason string used when skipping FP8 cases.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()

# Directory containing this test file; the launched script is resolved against it.
TEST_ROOT = Path(__file__).parent.resolve()
# One torchrun process per visible CUDA device.
NUM_PROCS: int = torch.cuda.device_count()
LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"]
def _run_test(fp_init, sharding_dims):
    """Launch ``run_fsdp2_model.py`` through ``LAUNCH_CMD`` and fail on a non-zero exit.

    Args:
        fp_init: When truthy, ``--fp8-init`` is added to the command line.
        sharding_dims: Sequence of one or two values forwarded after
            ``--sharding-dims``.

    Raises:
        AssertionError: If ``sharding_dims`` has any other length, or if the
            subprocess exits with a non-zero return code (the message is the
            decoded stderr of the subprocess).
    """
    script = TEST_ROOT / "run_fsdp2_model.py"
    command = [*LAUNCH_CMD, str(script)]
    if fp_init:
        command.append("--fp8-init")
    if len(sharding_dims) in (1, 2):
        command.append("--sharding-dims")
        command.extend(str(dim) for dim in sharding_dims)
    else:
        assert False
    completed = subprocess.run(command, env=os.environ, capture_output=True, check=False)
    if completed.returncode != 0:
        raise AssertionError(completed.stderr.decode())
# Parameter values for test_distributed: FP8 init on/off, and two sharding
# layouts (a single dimension of NUM_PROCS, and a two-dimensional
# [2, NUM_PROCS // 2] layout).
all_boolean = [True, False]
sharding_dims = [[NUM_PROCS], [2, NUM_PROCS // 2]]
@pytest.mark.parametrize("sharding_dims", sharding_dims)
@pytest.mark.parametrize("fp8_init", all_boolean)
def test_distributed(fp8_init, sharding_dims):
    """Run the FSDP2 toy script for one ``(fp8_init, sharding_dims)`` combination.

    The case is skipped, with ``reason_for_no_fp8`` as the reason, when FP8
    initialization is requested but FP8 is reported as unavailable.
    """
    fp8_missing = fp8_init and not fp8_available
    if fp8_missing:
        pytest.skip(reason_for_no_fp8)
    _run_test(fp8_init, sharding_dims)
......@@ -24,6 +24,19 @@ from .quantized_tensor import QuantizedTensor
aten = torch.ops.aten
updated_fp8_params = {}
# Aten ops that are recommended to return a Float8Tensor for torch FSDP2 to work.
# __torch_dispatch__ emits a warning when one of these ops (other than copy_)
# reaches the generic superclass fallback instead of a dedicated handler.
_ops_to_preserve_subclass_in_fsdp2 = {
    torch.ops.aten.empty_like.default,
    torch.ops.aten.new_zeros.default,
    torch.ops.aten.slice.Tensor,
    torch.ops.aten.copy_.default,
    torch.ops.aten.view.default,
    torch.ops.aten.as_strided.default,
    torch.ops.aten._to_copy.default,
    torch.ops.aten._pin_memory.default,
    torch.ops.aten.split.Tensor,
    torch.ops.aten.clone.default,
}
def _make_fp8_attr_property_funcs(name: str) -> Any:
"""Make accessors for an FP8 attribute
......@@ -430,6 +443,37 @@ class Float8Tensor(QuantizedTensor):
return self
def fsdp_pre_all_gather(self, mesh):  # pylint: disable=unused-argument
    """
    Hook used by torch FSDP2, called before the all-gather.

    Returns two one-element tuples: the all-gather input (this tensor's
    ``_data``) and the metadata (this tensor itself). ``mesh`` is not used.
    Ref: https://github.com/pytorch/pytorch/pull/122908
    """
    all_gather_inputs = (self._data,)
    metadata = (self,)
    return all_gather_inputs, metadata
def fsdp_post_all_gather(
    self,
    all_gather_outputs: Tuple[torch.Tensor, ...],
    metadata: Any,
    param_dtype: torch.dtype,  # pylint: disable=unused-argument
    *,
    out: Optional[torch.Tensor] = None,
):
    """
    Hook used by torch FSDP2, called after the all-gather.

    ``all_gather_outputs`` and ``metadata`` must each hold exactly one element.
    When ``out`` is None, returns a new Float8Tensor made like the metadata
    tensor from the gathered data, together with ``all_gather_outputs``
    (things to free after forward). When ``out`` is given, it is only checked
    to be a Float8Tensor and None is returned.
    Ref: https://github.com/pytorch/pytorch/pull/122908
    """
    (gathered_data,) = all_gather_outputs
    (template,) = metadata
    if out is None:
        rebuilt = Float8Tensor.make_like(template, data=gathered_data)
        return rebuilt, all_gather_outputs
    assert isinstance(out, Float8Tensor), f"{type(out)}"
    return None
@classmethod
def make_like(
cls,
......@@ -902,7 +946,53 @@ class Float8Tensor(QuantizedTensor):
)
return Float8Tensor.make_like(tensor, data=data_view)
# Default case
# Related to FSDP2
if func == aten.split.Tensor:
tensor = args[0]
data = tensor._data
func_out = data.__torch_dispatch__(
func,
types,
[data] + list(args[1:]),
kwargs,
)
return [Float8Tensor.make_like(tensor, data=split_tensor) for split_tensor in func_out]
if func == aten.new_zeros.default:
tensor = args[0]
data = tensor._data
func_out = data.__torch_dispatch__(
func,
types,
[data] + list(args[1:]),
kwargs,
)
return Float8Tensor.make_like(tensor, data=func_out)
if func == torch.ops.aten.as_strided.default:
tensor = args[0]
data = tensor._data
func_out = data.__torch_dispatch__(
func,
types,
[data] + list(args[1:]),
kwargs,
)
return Float8Tensor.make_like(tensor, data=func_out)
if func == torch.ops.aten.detach.default:
return cls.detach(args[0])
if func == torch.ops.aten.clone.default:
return cls.clone(args[0])
if func == torch.ops.aten.copy_.default:
# Implementation in the superclass (QuantizedTensor) returns a proper output
pass
elif func in _ops_to_preserve_subclass_in_fsdp2:
# Ops in _ops_to_preserve_subclass_in_fsdp2 are recommended to return the same class instance so that they work with torch FSDP2
warnings.warn(
f"A function call({func}) in {cls} may not return {cls} tensor as an output. It"
" might cause an error in torch FSDP2!"
)
else:
pass
return super().__torch_dispatch__(func, types, args, kwargs)
@classmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment