Unverified commit 29537c96 authored by vthumbe1503, committed by GitHub

[PyTorch] FSDP2 Support for TE (#2245)



* fix for float8 tensor fsdp2 training
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* zeros_like should return fp32 for fsdp2 to work
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* minor cleanup
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* fix unsharded weights not releasing memory
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* implement using FSDP pre-allgather and post-allgather functions
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* FSDP2 works on Hopper/L40
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor comment
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* some fixes for fp8 + handwavy changes for mxfp8
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* only transpose saved for backward-pass allgather in case of L40/Hopper
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* missed minor change to Hopper use-case
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* communicate only required data in mxfp8; fix for updating weight usages when required instead of doing it upfront in the fwd pass
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* changes for meta DTensors for weights and better all-gather data handling in FSDP hook functions
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* better solution to figure out forward pass in FSDP2
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* address review comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* Update transformer_engine/pytorch/tensor/mxfp8_tensor.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* everything functioning except hack for TransformerLayer
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* fix merge conflict
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* revert change of commit id for cudnn-frontend
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* unnecessary change
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor issues with linting, add some comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor stuff
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* revert space removal

Add default usage handling for rowwise and columnwise data.
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>

* fix the FSDP state collection issue, and address minor review comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* revert change for redundant dgrad computation
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* bug: get FSDP param group's training state instead of root training state; address review comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* address review comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* address review comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* address coderabbit review comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* address review comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* address review comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* address review comments; fix fp8 allgather test to run after FSDP lazy init
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* address review comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* remove detach
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* do what makes sense
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* Update transformer_engine/pytorch/tensor/float8_tensor.py
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>

* Update transformer_engine/pytorch/tensor/mxfp8_tensor.py
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>

* Update transformer_engine/pytorch/tensor/mxfp8_tensor.py
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>

* Update transformer_engine/pytorch/tensor/mxfp8_tensor.py
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>

* Update transformer_engine/pytorch/tensor/mxfp8_tensor.py
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>

* Update transformer_engine/pytorch/tensor/mxfp8_tensor.py
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>

* address review comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* address review comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* address review comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* address review comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* address review comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* have better dtype for fsdp_post_all_gather arguments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* minor comment
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* improve comment
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* fix the error in CI
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* minor comment add
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* accidentally removed view function
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* fix minor bug for H100
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* minor addition
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* implement padding removal/addition for allgather
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* address review comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* address review comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* Update transformer_engine/pytorch/tensor/mxfp8_tensor.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix lint error
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* address review comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* improve the reset-parameter logic for DTensors
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* other cosmetic changes
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* cosmetic changes
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* cosmetic changes
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* Update transformer_engine/pytorch/module/layernorm_linear.py
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

---------
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
parent 7a585983
......@@ -9,57 +9,73 @@ import sys
import argparse
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import (
Format,
DelayedScaling,
Float8CurrentScaling,
MXFP8BlockScaling,
)
import torch
import torch.distributed as dist
from torch.distributed.tensor import DTensor
import torch.nn.functional as F
from torch import nn, optim
from torch.distributed import DeviceMesh
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh
from transformer_engine.pytorch import QuantizedTensor
from contextlib import nullcontext
LOCAL_RANK = None
def dist_print(msg):
if LOCAL_RANK == 0:
print(msg)
def _parse_args(argv=None, namespace=None):
parser = argparse.ArgumentParser(description="Toy example for debugging fully_shard()")
parser.add_argument("--input-size", type=int, default=2048, help="Input size for the model")
parser.add_argument("--hidden-size", type=int, default=2048, help="Hidden layer size")
parser.add_argument("--output-size", type=int, default=2048, help="Output size for the model")
parser.add_argument("--batch-size", type=int, default=2048, help="Output size for the model")
parser.add_argument("--num-heads", type=int, default=8, help="Number of attn. heads")
parser.add_argument("--head-dim", type=int, default=64, help="Attention head size")
parser.add_argument("--batch-size", type=int, default=16, help="Batch size of input")
parser.add_argument("--seq-length", type=int, default=128, help="Sequence length of input")
parser.add_argument("--params-dtype", type=str, default="float32", help="Parameter dtype.")
parser.add_argument(
"--fp8-init", action="store_true", default=False, help="Initialize primary weights in FP8."
)
parser.add_argument(
"--recipe",
type=str,
default="mx_fp8_block_scaling",
help="Quantizer type.",
choices=["delayed_scaling", "current_scaling", "mx_fp8_block_scaling"],
)
parser.add_argument(
"--layer-type",
type=str,
default="TransformerLayer",
choices=[
"Linear",
"LayerNormLinear",
"LayerNormMLP",
"MultiheadAttention",
"TransformerLayer",
],
help="Transformer Engine layer type",
)
parser.add_argument("--num-layers", type=int, default=4, help="Number of layers in the model")
parser.add_argument(
"--iter", type=int, default=10, help="Number of iterations for forward pass"
)
parser.add_argument(
"--device",
type=str,
default="meta",
help="Device to run the model on.",
choices=["cuda", "meta"],
)
parser.add_argument("--seed", type=int, default=42, help="RNG seed.")
# Adding hsdp_dim as a list argument, comma-separated
parser.add_argument(
......@@ -74,10 +90,170 @@ def _parse_args(argv=None, namespace=None):
return args
## Methods to help initialize the TE model in an FSDP2 setting
## with required configurations based on command line args
def get_te_layer_from_string(layer_name):
te_layer_types = [
te.Linear,
te.LayerNormLinear,
te.LayerNormMLP,
te.MultiheadAttention,
te.TransformerLayer,
]
te_layer_names = [layer.__name__ for layer in te_layer_types]
te_layer_map = dict(zip([name.lower() for name in te_layer_names], te_layer_types))
if layer_name.lower() not in te_layer_map.keys():
raise argparse.ArgumentTypeError(
f'"{layer_name}" is not a valid Transformer Engine layer, '
f"please choose layer from {te_layer_names}."
)
return te_layer_map[layer_name.lower()]
def get_recipe_from_string(recipe, fp8_format=Format.HYBRID):
if recipe == "delayed_scaling":
return DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")
elif recipe == "current_scaling":
return Float8CurrentScaling(fp8_format=fp8_format)
elif recipe == "mx_fp8_block_scaling":
return MXFP8BlockScaling(fp8_format=fp8_format)
else:
raise ValueError(f"Unknown quantizer type: {recipe}")
def init_te_model(config):
hidden_size = config.num_heads * config.head_dim
args = [hidden_size, hidden_size]
inp_shape = [config.seq_length, config.batch_size, hidden_size]
out_shape = [config.seq_length, config.batch_size, hidden_size]
if config.params_dtype == "float16":
params_dtype = torch.float16
elif config.params_dtype == "bfloat16":
params_dtype = torch.bfloat16
else:
params_dtype = torch.float32
kwargs = {
"params_dtype": params_dtype,
}
kwargs["device"] = config.device
layer_type = get_te_layer_from_string(config.layer_type)
# We create the model so that both the reshard_after_forward=True and False cases can be tested;
# more details below.
if layer_type in [te.MultiheadAttention, te.TransformerLayer]:
# For this case, we create a model that resembles production use-cases,
# wherein there are multiple TransformerLayers in the model and each one
# needs to be sharded. Since each transformer layer is not a root module,
# FSDP2's fully_shard assigns reshard_after_forward=False for all parameters of the model.
args[1] *= 4 # FFN hidden size
args.append(config.num_heads)
kwargs["fuse_qkv_params"] = True
if layer_type is te.MultiheadAttention:
kwargs["input_layernorm"] = True
model = nn.Sequential(*[layer_type(*args, **kwargs) for _ in range(config.num_layers)])
elif layer_type == te.LayerNormLinear:
# For this case, we create a model with just one LayerNormLinear layer
# so that the model itself is a root module, and FSDP2's fully_shard assigns
# reshard_after_forward=True for the parameters of this model.
args[1] *= 3 # QKV projection
out_shape[-1] *= 3
model = layer_type(*args, **kwargs)
else:
model = layer_type(*args, **kwargs)
return model, inp_shape, out_shape
def get_device_mesh(world_size, sharding_dims):
dist_print(f"sharding-dims:{sharding_dims}")
device_ids = list(range(world_size))
if sharding_dims is None: # FSDP
mesh = DeviceMesh("cuda", device_ids)
elif len(sharding_dims) == 1:
assert sharding_dims[0] == world_size
mesh = DeviceMesh("cuda", device_ids)
elif len(sharding_dims) == 2: # HSDP
assert sharding_dims[0] * sharding_dims[1] == world_size
mesh = init_device_mesh(
"cuda",
(sharding_dims[0], sharding_dims[1]),
mesh_dim_names=("replicate", "shard"),
)
else:
assert False
return mesh
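# Illustrative mesh shapes (editor's sketch, not part of the original diff):
# with 8 GPUs, sharding_dims=[8] builds a flat 1-D FSDP mesh, while
# sharding_dims=[2, 4] builds a 2x4 HSDP mesh with ("replicate", "shard") dims:
#   mesh = get_device_mesh(8, [2, 4])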
def shard_model_with_fsdp2(model, mesh):
for child in model.children():
fully_shard(child, mesh=mesh)
fully_shard(model, mesh=mesh)
return model
#### Methods to save the custom attributes of QuantizedTensors before sharding
#### them with FSDP2, and restore them after sharding.
def save_custom_attrs(module):
custom_attrs = {}
for name, param in module.named_parameters():
if isinstance(param, QuantizedTensor):
# Ignore FP8 metadata attributes. Otherwise we would save duplicate
# references to the FP8 data/transpose tensors on top of those FSDP2 itself saves.
ignore_keys = [key for key in param.__dict__.keys() if key.startswith("_")]
else:
ignore_keys = []
attrs = vars(param)
custom_attrs[name] = {k: v for k, v in attrs.items() if k not in ignore_keys}
return custom_attrs
def restore_custom_attrs(module, custom_attrs):
for name, param in module.named_parameters():
if name in custom_attrs:
for attr_name, attr_value in custom_attrs[name].items():
setattr(param, attr_name, attr_value)
@torch.no_grad()
def test_fp8_fsdp2_allgather(model):
# Do manual allgather in fp32 and match against fp8 allgather done
# with fsdp2
# FP32 manual weight allgather
fp32_allgathered_params = {}
for name, param in model.named_parameters():
assert isinstance(param, DTensor)
local_tensor = param._local_tensor
device_mesh = param.device_mesh
dist_group = (
device_mesh.get_group(mesh_dim="shard")
if device_mesh.ndim > 1
else device_mesh.get_group()
)
# Perform a manual allgather on local_tensor. zeros_like creates a high-precision
# tensor, since torch_dispatch for local_tensor goes down the dequantization route.
gathered_tensor = [
torch.zeros_like(local_tensor) for _ in range(dist.get_world_size(group=dist_group))
]
dist.all_gather(gathered_tensor, local_tensor.dequantize(), group=dist_group)
full_tensor = torch.cat(gathered_tensor, dim=0)
fp32_allgathered_params[name] = full_tensor
# FP8 allgather using FSDP2
for module in model.modules():
# Not all modules are wrapped/sharded with FSDP2.
if hasattr(module, "unshard"):
module.unshard()
# Make sure allgathered parameters match exactly
for name, param in model.named_parameters():
assert torch.allclose(param.dequantize(), fp32_allgathered_params[name])
# Revert model to original sharded state
for module in model.modules():
# Not all modules are wrapped/sharded with FSDP2.
if hasattr(module, "reshard"):
module.reshard()
def _train(args):
global LOCAL_RANK
assert "TORCHELASTIC_RUN_ID" in os.environ
WORLD_RANK = int(os.getenv("RANK", "0"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
......@@ -103,74 +279,69 @@ def _train(args):
# FP8 Configuration
fp8_format = Format.HYBRID
fp8_recipe = get_recipe_from_string(args.recipe, fp8_format)

build_model_context_args = {}
if not args.fp8_init:
# Build model context (FP8 init)
build_model_context = nullcontext
else:
from transformer_engine.pytorch import fp8_model_init
build_model_context = fp8_model_init
build_model_context_args["enabled"] = True
build_model_context_args["recipe"] = fp8_recipe

dist_print(f"Memory before model init: {torch.cuda.memory_allocated(device)/1e6} MB")
# Create the model on the meta/cuda device as per args
with build_model_context(**build_model_context_args):
model, inp_shape, out_shape = init_te_model(args)
dist_print(
f"Memory after model init on device {args.device}:"
f" {torch.cuda.memory_allocated(device)/1e6} MB"
)
if LOCAL_RANK == 0:
print(f"Rank {LOCAL_RANK}: Applying FSDP fully_shard() to the model...")
# Creating a DeviceMesh for fully_shard
world_size = int(WORLD_SIZE)
# Setup the sharding mesh for FSDP/HSDP
mesh = get_device_mesh(world_size, args.sharding_dims)

# Apply FSDP/HSDP
custom_attrs = save_custom_attrs(model)
model = shard_model_with_fsdp2(model, mesh)
restore_custom_attrs(model, custom_attrs)
# model now has DTensors as its parameters
if args.device == "meta":
# After FSDP2 has been applied, materialize and initialize the sharded parameters
# TE base.py's reset_parameters() handles DTensors with FP8 initialization
for module in model.modules():
if hasattr(module, "reset_parameters"):
module.reset_parameters()
dist_print(f" Sharded parameters materialized and initialized on cuda device.")
dist_print(
f"FSDP2 model in cuda, memory allocated: {torch.cuda.memory_allocated(device)/1e6} MB"
)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
for iteration in range(args.iter):
# Zero the parameter gradients
optimizer.zero_grad()
input_data = torch.randn(inp_shape).to(device)
with te.autocast(enabled=True, recipe=fp8_recipe):
output = model(input_data)
target = torch.randn(out_shape).to(device)
loss = F.mse_loss(output, target)
loss.backward()
optimizer.step()
dist_print(f"Iteration {iteration} completed with loss {loss.item()}")
# Some of the FSDP states are lazy initialized during FSDP forward pass
# so testing fp8 allgather at the end of the training loop.
if args.fp8_init:
test_fp8_fsdp2_allgather(model)
dist.destroy_process_group()
if LOCAL_RANK == 0:
print(f"Rank {LOCAL_RANK}: Done...")
return 0
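# Example launch (editor's sketch; the flags correspond to the argparse options above):
#   torchrun --nproc_per_node=4 run_fsdp2_model.py --fp8-init \
#       --recipe current_scaling --layer-type TransformerLayer --sharding-dims 2 2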
......
......@@ -12,22 +12,26 @@ import torch
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
NUM_PROCS: int = torch.cuda.device_count()
def _run_test(fp_init, sharding_dims, recipe, layer_type):
test_path = Path(__file__).parent.resolve() / "run_fsdp2_model.py"
test_cmd = ["torchrun", f"--nproc_per_node={NUM_PROCS}", str(test_path)]
if fp_init:
test_cmd += ["--fp8-init"]
if len(sharding_dims) == 1:
test_cmd += ["--sharding-dims", str(sharding_dims[0])]
elif len(sharding_dims) == 2:
test_cmd += ["--sharding-dims", str(sharding_dims[0]), str(sharding_dims[1])]
else:
assert False
test_cmd += ["--recipe", recipe]
test_cmd += ["--layer-type", layer_type]
result = subprocess.run(test_cmd, env=os.environ, check=True)
......@@ -36,16 +40,20 @@ def _run_test(fp_init, sharding_dims):
@pytest.mark.skipif(not te.torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+")
@pytest.mark.parametrize("sharding_dims", ([NUM_PROCS], [2, NUM_PROCS // 2]))
@pytest.mark.parametrize("fp8_init", (False, True))
@pytest.mark.parametrize("recipe", ("delayed_scaling", "current_scaling", "mx_fp8_block_scaling"))
@pytest.mark.parametrize("layer_type", ("LayerNormLinear", "TransformerLayer"))
def test_distributed(fp8_init, sharding_dims, recipe, layer_type):
# Skip invalid configurations
if torch.cuda.device_count() < 4:
pytest.skip("FSDP2 test requires at least 4 GPUs")
if recipe == "mx_fp8_block_scaling" and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
elif not fp8_available:
pytest.skip(reason_for_no_fp8)
_run_test(fp8_init, sharding_dims, recipe, layer_type)
def test_dummy() -> None:
......
......@@ -1886,6 +1886,43 @@ def allreduce(
return inp, handle
def _get_module_fsdp_state(module):
"""
If module is an FSDP module, return its _FSDPState.
Otherwise, return the _FSDPState of the closest parent FSDP module
in the module hierarchy the module belongs to.
"""
if hasattr(module, "_get_fsdp_state"):
# this will return correct fsdp state if module itself is an fsdp module
fsdp_state = module._get_fsdp_state()
elif getattr(module, "_te_cached_parent_fsdp_state", None) is not None:
# See if we have cached the parent fsdp state of the module
fsdp_state = module._te_cached_parent_fsdp_state
else:
from torch.distributed._composable_state import _module_state_mapping
# Otherwise find the FSDP state of the module's closest FSDP-wrapped ancestor
# (the FSDP module containing it with the fewest total submodules).
min_nodes_in_parent = float("inf")
closest_parent_fsdp_mod = None
for fsdp_mod in _module_state_mapping.keys():
all_submodules = list(fsdp_mod.modules())
for submodule in all_submodules:
if submodule is module:
if min_nodes_in_parent > len(all_submodules):
closest_parent_fsdp_mod = fsdp_mod
min_nodes_in_parent = len(all_submodules)
if closest_parent_fsdp_mod is None:
raise RuntimeError(
"Module is not FSDP-wrapped and does not have any FSDP-wrapped parent modules."
)
fsdp_state = closest_parent_fsdp_mod._get_fsdp_state()
# Cache the parent fsdp state of the module to avoid recomputing
# the closest parent fsdp module.
module._te_cached_parent_fsdp_state = fsdp_state
return fsdp_state
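# Example (editor's sketch): for a TE submodule nested somewhere inside an
# FSDP-wrapped layer, the first call walks _module_state_mapping to find the
# innermost FSDP ancestor, while later calls reuse the cached
# _te_cached_parent_fsdp_state attribute:
#   state = _get_module_fsdp_state(some_te_submodule)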
def _fsdp_scatter_tensors(
fsdp_group: dist_group_type,
*tensors: torch.Tensor,
......
......@@ -17,6 +17,7 @@ from types import MethodType
import torch
import torch.nn.functional as F
from torch.distributed.tensor import DTensor
import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
......@@ -1244,6 +1245,11 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
metadata used in deferred initialization.
"""
super().register_parameter(name, param)
# Initialize param_init_meta exactly once during init. FSDP2 can call
# register_parameter again to swap parameters for DTensors, and it does so
# without the custom FP8-specific kwargs that we need, so we don't want
# to reset/lose our FP8 init attributes.
if hasattr(self, "param_init_meta") and name not in self.param_init_meta:
self.param_init_meta[name] = _ParameterInitMeta(**kwargs)
def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
......@@ -1256,10 +1262,14 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
return
for name, param in self.named_parameters(recurse=False):
# Check if parameter is a DTensor (FSDP2) or regular tensor
is_dtensor = isinstance(param, DTensor)
dtensor_param = param if is_dtensor else None
# Need to update/quantize local tensor in case of DTensor
param = param._local_tensor if is_dtensor else param
# Ensure parameter is on a real device
if param.device == torch.device("meta"):
param = torch.empty_like(param, device="cuda")
# Initialize the parameter values on device
init_fn = self.param_init_meta[name].init_fn
get_rng_state_tracker = self.param_init_meta[name].get_rng_state_tracker
......@@ -1288,7 +1298,15 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
raise RuntimeError("Weight quantizer has not been initialized")
quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled())
quantizer.internal = False
if is_dtensor and isinstance(quantizer, Float8CurrentScalingQuantizer):
device_mesh = dtensor_param.device_mesh
amax_reduction_group = (
device_mesh.get_group(mesh_dim="shard")
if device_mesh.ndim > 1
else device_mesh.get_group()
)
quantizer.amax_reduction_group = amax_reduction_group
quantizer.with_amax_reduction = True
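# Editor's note on the reduction configured above: with current scaling each
# rank quantizes only its local DTensor shard, so without reducing amax over
# the shard group every shard would compute its own scale_inv, and the uint8
# all-gather performed by FSDP2 would combine incompatibly scaled shards.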
# Quantize parameter
param = quantizer(param)
......@@ -1296,6 +1314,17 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# NOTE: Currently this can only be broken when primary weights are in Fp8 but
# re-applying the nn.Parameter() wrap is a no-op when the input is already
# a parameter so we always re-apply it just for extra safety.
if is_dtensor:
# recreate the DTensor from the parameter.
dtensor_param = DTensor.from_local(
param,
device_mesh=dtensor_param.device_mesh,
placements=dtensor_param.placements,
shape=dtensor_param.size(),
stride=dtensor_param.stride(),
)
dtensor_param = torch.nn.Parameter(dtensor_param)
else:
param = torch.nn.Parameter(param)
# Keep high-precision values on CPU if needed
......@@ -1324,8 +1353,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
param._high_precision_init_val = high_precision_init_val
param.get_high_precision_init_val = MethodType(get, param)
param.clear_high_precision_init_val = MethodType(clear, param)
# Update the parameter based on its type
if not is_dtensor:
setattr(self, name, param)
else:
setattr(self, name, dtensor_param)
@abstractmethod
def forward(self):
......
......@@ -108,9 +108,15 @@ class _GroupedLinear(torch.autograd.Function):
is_fp8_activation_recompute_enabled()
and not in_fp8_activation_recompute_phase()
)
# No need to set the quantizer states if weight is already quantized
if weight_quantizers[0] is not None and not isinstance(
weights[0], QuantizedTensorStorage
):
for weight_quantizer in weight_quantizers:
weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
elif isinstance(weights[0], QuantizedTensorStorage):
# If weights are already quantized, no need to set quantizer states
weight_quantizers = [weight._quantizer for weight in weights]
if output_quantizers[0] is not None:
for output_quantizer in output_quantizers:
output_quantizer.set_usage(rowwise=True, columnwise=False)
......@@ -205,10 +211,6 @@ class _GroupedLinear(torch.autograd.Function):
inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
else:
inputmats = [None] * num_gemms
if cpu_offloading:
ctx.grad_added_to_main_grad = hasattr(weights[0], "grad_added_to_main_grad")
......@@ -354,13 +356,11 @@ class _GroupedLinear(torch.autograd.Function):
dtype=ctx.activation_dtype,
device=ctx.device,
)
# Make sure weights are available in column-wise format
# for dgrad computation.
for weight in weights:
if isinstance(weight, QuantizedTensorStorage):
weight.update_usage(columnwise_usage=True)
general_grouped_gemm(
weights,
grad_output,
......
......@@ -276,12 +276,15 @@ class _LayerNormLinear(torch.autograd.Function):
# Prepare weight tensor
# ------------------------------------------------------
weightmat = weight
is_weight_param_quantized = False
if fp8 or debug:
is_weight_param_quantized = isinstance(weight, QuantizedTensorStorage)
# Configure quantizer
# If weight is already quantized, no need to set quantizer states
if is_weight_param_quantized:
weight_quantizer = weight._quantizer
elif weight_quantizer is not None:
weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)
# Get quantized weight
......@@ -413,10 +416,6 @@ class _LayerNormLinear(torch.autograd.Function):
):
ln_out.update_usage(rowwise_usage=False)
if cpu_offloading:
mark_activation_offload(inputmat, mu, rsigma, ln_out)
......@@ -429,7 +428,7 @@ class _LayerNormLinear(torch.autograd.Function):
fsdp_group,
mu,
rsigma,
weightmat if fp8 and not is_weight_param_quantized else None,
ln_out if weight.requires_grad else None,
)
nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")
......@@ -459,7 +458,7 @@ class _LayerNormLinear(torch.autograd.Function):
ctx.tensor_objects = tensor_objects
ctx.requires_dgrad = inp_requires_grad
ctx.requires_wgrad = weight.requires_grad
ctx.is_weight_param_quantized = is_weight_param_quantized
if fuse_wgrad_accumulation and weight.requires_grad:
# This check is needed to ensure that main_grad is not created
# during the forward pass when using MCore FSDP as it creates
......@@ -563,7 +562,7 @@ class _LayerNormLinear(torch.autograd.Function):
ctx.fsdp_shapes,
mu,
rsigma,
weight if ctx.fp8 and not ctx.is_weight_param_quantized else None,
ln_out,
)
nvtx_range_pop(f"{nvtx_label}.fsdp_gather")
......
......@@ -351,8 +351,17 @@ class _LayerNormMLP(torch.autograd.Function):
# which handles weight caching etc.
# FP8 cast to workspace buffer
update_workspace = is_first_microbatch is None or is_first_microbatch
# No need to set the quantizer states if weights are already quantized
if isinstance(fc1_weight, QuantizedTensorStorage):
fc1_weight_quantizer = fc1_weight._quantizer
elif fc1_weight_quantizer is not None:
fc1_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)
if isinstance(fc2_weight, QuantizedTensorStorage):
fc2_weight_quantizer = fc2_weight._quantizer
elif fc2_weight_quantizer is not None:
fc2_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)
fc1_weight_final = module.get_weight_workspace(
tensor=fc1_weight,
quantizer=fc1_weight_quantizer,
......@@ -538,13 +547,6 @@ class _LayerNormMLP(torch.autograd.Function):
# Cache state for backward pass
if is_grad_enabled:
if cpu_offloading:
mark_activation_offload(
inputmat, mu, rsigma, ln_out, fc1_out, fc1_out_without_bias, act_out
......
......@@ -240,7 +240,8 @@ class _Linear(torch.autograd.Function):
weightmat = weight
if fp8 or debug:
# Configure quantizer
# No need to set the quantizer states if weight is already quantized
if weight_quantizer is not None and not isinstance(weight, QuantizedTensor):
columnwise_usage = is_grad_enabled and inp.requires_grad
if not columnwise_usage:
columnwise_usage = (
......@@ -248,7 +249,9 @@ class _Linear(torch.autograd.Function):
and not in_fp8_activation_recompute_phase()
)
weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
elif isinstance(weight, QuantizedTensor):
# If weight is already quantized, no need to set quantizer states
weight_quantizer = weight._quantizer
# Get quantized weight
update_workspace = is_first_microbatch is None or is_first_microbatch
weightmat = module.get_weight_workspace(
......@@ -389,11 +392,6 @@ class _Linear(torch.autograd.Function):
if backward_needs_input:
saved_inputmat = inputmat
if cpu_offloading and saved_inputmat is not None:
mark_activation_offload(saved_inputmat)
......
......@@ -433,6 +433,10 @@ class QuantizedTensor(torch.Tensor):
and schema_arg.alias_info.is_write
):
arg.quantize_(new_arg)
elif isinstance(arg, list) and isinstance(new_arg, list):
# Recursively handle update for lists of tensors
for a, na in zip(arg, new_arg):
maybe_update_inplace(a, na, schema_arg)
# In-place op: dequantize, perform op, and quantize
if func._schema.is_mutable:
......@@ -489,20 +493,16 @@ class QuantizedTensor(torch.Tensor):
shape: Optional[Iterable[int]] = None,
dtype: Optional[torch.dtype] = None,
requires_grad: bool = False,
) -> QuantizedTensor:
"""Create new quantized tensor
By default, new tensor has the same attributes and underlying
data.
data. This function is intended to create view of tensors.
"""
if shape is None:
shape = data.shape if data is not None else tensor.shape
shape = shape if shape is not None else tensor.shape
dtype = dtype if dtype is not None else tensor.dtype
kwargs = tensor.get_metadata()
return cls(shape=shape, dtype=dtype, requires_grad=requires_grad, **kwargs)
def to_dtype(self, dtype: torch.dtype) -> QuantizedTensor:
......
......@@ -4,10 +4,10 @@
"""Tensor class with FP8 data"""
from __future__ import annotations
from typing import Any, Optional, Tuple, Iterable, Union
import warnings
import torch
from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState
import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType
......@@ -299,14 +299,12 @@ class Float8CurrentScalingQuantizer(Quantizer):
# Allocate FP8 data transpose if needed
data_transpose = None
if self.columnwise_usage:
transpose_shape = [data.size(-1)] + list(data.shape[:-1])
data_transpose = torch.empty(
transpose_shape,
dtype=torch.uint8,
device=device,
)
# Construct FP8 tensor
return Float8Tensor(
shape=shape,
......@@ -534,9 +532,36 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
self._transpose = None
@classmethod
def make_like(
cls,
tensor: QuantizedTensor,
*,
shape: Optional[Iterable[int]] = None,
dtype: Optional[torch.dtype] = None,
requires_grad: bool = False,
data: Optional[torch.Tensor] = None,
data_transpose: Optional[torch.Tensor] = None,
) -> QuantizedTensor:
"""Create new quantized tensor
# View op
By default, new tensor has the same attributes and underlying
data.
"""
if shape is None and data is not None:
shape = data.shape
new_tensor = super().make_like(
tensor, shape=shape, dtype=dtype, requires_grad=requires_grad
)
if data is not None:
new_tensor._data = data
if data_transpose is not None:
new_tensor._transpose = data_transpose
new_tensor._transpose_invalid = False
return new_tensor
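# Usage sketch (editor's illustration): wrap an existing uint8 buffer, e.g. a
# freshly all-gathered one, in a Float8Tensor that reuses ref's FP8 metadata:
#   out = Float8Tensor.make_like(ref, data=gathered_u8, shape=gathered_u8.shape)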
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
# View op
if func == aten.view.default:
tensor = args[0]
data = tensor._data
......@@ -555,6 +580,9 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
or out_transpose_shape[1:] != out_shape[:-1]
):
out_transpose = None
else:
view_shape_for_transpose = [out_shape[-1]] + list(out_shape[:-1])
out_transpose = out_transpose.view(*view_shape_for_transpose)
return Float8Tensor(
shape=out_shape,
dtype=tensor.dtype,
......@@ -587,11 +615,37 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
[data] + list(args[1:]),
kwargs,
)
t_func_out = [None] * len(func_out)
# Compute corresponding split of the transpose cache if available
if tensor._transpose is not None and not tensor._transpose_invalid:
transpose = tensor._transpose
ndim = data.dim()
# Figure out the original split dim
if "dim" in kwargs:
dim_to_split = kwargs["dim"]
else:
dim_to_split = args[2] if len(args) > 2 else 0
# Dimension along which transpose needs to be split
t_dim = 0 if dim_to_split == ndim - 1 else dim_to_split + 1
t_func_out = transpose.__torch_dispatch__(
func,
types,
[transpose, args[1], t_dim],
kwargs,
)
outs = [
Float8Tensor.make_like(
tensor,
data=split_tensor,
data_transpose=split_transpose_tensor,
shape=split_tensor.shape,
)
for split_tensor, split_transpose_tensor in zip(func_out, t_func_out)
]
return outs
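# Worked example (editor's sketch): sharding a CUDA-initialized (1024, 1024)
# Float8Tensor over 4 ranks hits aten.split with split_size=256 along dim 0;
# dim_to_split=0 with ndim=2 gives t_dim=1, so the cached (1024, 1024)
# transpose is split along dim 1 and each shard keeps a (1024, 256) transpose.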
if func == aten.new_zeros.default:
# create fresh new tensor with zeros.
tensor = args[0]
data = tensor._data
func_out = data.__torch_dispatch__(
......@@ -600,17 +654,63 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
[data] + list(args[1:]),
kwargs,
)
func_transposed_out = None
if tensor._transpose is not None and not tensor._transpose_invalid:
transpose = tensor._transpose
size = args[1]
t_shape = [size[-1]] + list(size[:-1])
func_transposed_out = transpose.__torch_dispatch__(
func,
types,
[transpose, t_shape] + list(args[2:]),
kwargs,
)
# deep copy the scale inverse tensor and quantizer as well.
scale_inv = tensor._scale_inv.detach().clone()
quantizer = tensor._quantizer.copy()
out_tensor = Float8Tensor(
data=func_out,
shape=func_out.shape,
dtype=tensor.dtype,
fp8_dtype=tensor._fp8_dtype,
fp8_scale_inv=scale_inv,
data_transpose=func_transposed_out,
quantizer=quantizer,
)
return out_tensor
if func == torch.ops.aten.as_strided.default:
tensor = args[0]
data = tensor._data
# Apply as_strided to the primary uint8 data
func_out = data.__torch_dispatch__(
func,
types,
[data] + list(args[1:]),
kwargs,
)
func_transposed_out = None
if tensor._transpose is not None and not tensor._transpose_invalid:
transpose = tensor._transpose
size = args[1]
stride = args[2]
if "storage_offset" in kwargs:
storage_offset = kwargs["storage_offset"]
else:
storage_offset = args[3] if len(args) > 3 else 0
# Shape and stride needed for the transpose matrix
t_size = [size[-1]] + list(size[:-1])
t_stride = [stride[-1]] + list(stride[:-1])
func_transposed_out = transpose.__torch_dispatch__(
func,
types,
[transpose, t_size, t_stride, storage_offset] + list(args[4:]),
kwargs,
)
return Float8Tensor.make_like(
tensor, data=func_out, data_transpose=func_transposed_out, shape=func_out.shape
)
if func == torch.ops.aten.detach.default:
return cls.detach(args[0])
if func == torch.ops.aten.clone.default:
......@@ -632,9 +732,105 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
)
else:
pass
return super().__torch_dispatch__(func, types, args, kwargs)
def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, mp_policy):
"""Functions FSDP2 calls before all-gather of the
weights for both forward and backward passes.
Args:
mesh (torch.distributed.DeviceMesh): DeviceMesh used by FSDP2
to shard the weights.
orig_size (torch.Size): Original size of the weight tensor.(For us same as self.shape)
contiguous_orig_stride (Tuple[int]): Original stride of the weight tensor
(For us same as self.stride())
module (FSDPModule): FSDP module. FSDP wrapped module wrapped using fully_shard
that contains this FP8 tensor.
mp_policy (MixedPrecisionPolicy): Mixed precision policy used by FSDP2.
Returns:
shareded_tensors: Tuple[torch.Tensor, ...]: Tuple of tensors
that need to be all-gathered.(In this case uint8 data tensor)
metadata: Tuple[Any]: Metadata needed for reconstructing the
Float8Tensor after all-gather.
"""
# pylint: disable=unused-argument
# Importing here to avoid circular imports
from transformer_engine.pytorch.distributed import _get_module_fsdp_state
if isinstance(self._quantizer, Float8CurrentScalingQuantizer) and mesh is not None:
# When sharded weight is updated after reduce scattering the gradients in FSDP2,
# we need to do amax reduction across the mesh to make sure all weight shards are
# updated with same scale inverse. Setting the state below in the quantizer will make
# sure that updated Quantized weight tensor have same scale inverse across all shards.
self._quantizer.amax_reduction_group = mesh.get_group()
self._quantizer.with_amax_reduction = True
quantizer = self._quantizer.copy() # quantizer to be used for allgathered weights
fsdp_state = _get_module_fsdp_state(module)
reshard_after_forward = fsdp_state._fsdp_param_group._reshard_after_forward
# If weights are resharded after the forward pass, it is enough to set the quantizer usages
# for the allgathered weights based on whether this is the forward or backward pass.
# If not resharded after forward, the same weights allgathered in forward
# are used again in backward, so we leave the quantizer usages unchanged since
# both rowwise and columnwise usages may be needed.
if reshard_after_forward:
training_state = fsdp_state._fsdp_param_group._training_state
is_backward_pass = training_state == TrainingState.PRE_BACKWARD
# On Hopper/L40, only one of data/transpose is needed,
# depending on forward vs. backward pass, so set the quantizer usages appropriately.
quantizer.set_usage(rowwise=not is_backward_pass, columnwise=is_backward_pass)
sharded_tensors = (self._data,)
metadata = (self._scale_inv, self._fp8_dtype, quantizer)
return sharded_tensors, metadata
def fsdp_post_all_gather(
self,
all_gather_outputs: Tuple[torch.Tensor, ...],
metadata: Any,
param_dtype: torch.dtype,
*,
out: Optional[Float8Tensor] = None,
):
"""Functions FSDP2 calls after all-gather of the
weights for both forward and backward passes.
Args:
all_gather_outputs (Tuple[torch.Tensor, ...]): sharded_tensors sent out in fsdp_pre_all_gather from each rank
are all-gathered and received here as a tuple.
metadata (Any): metadata sent out in fsdp_pre_all_gather used for reconstructing the Float8Tensor.
param_dtype (torch.dtype): high precision dtype of the Float8Tensor.
out (Optional[torch.Tensor], optional): _description_. Defaults to None.
Returns:
Tuple[Float8Tensor, Tuple[torch.Tensor, ...]]: Allgathered Float8Tensor and tuple of internal tensors
used by the Float8Tensor that was being computed after allgather.
"""
(data,) = all_gather_outputs
(fp8_scale_inv, fp8_dtype, quantizer) = metadata
orig_shape = data.size()
# The quantizer has only columnwise usage set for the backward pass.
# On Blackwell+ architectures no transpose is needed at all, even if
# columnwise usage is set; this is handled internally by update_usage.
if out is not None:
out._data = data
else:
fp8_args = {
"shape": orig_shape,
"dtype": param_dtype,
"fp8_scale_inv": fp8_scale_inv,
"fp8_dtype": fp8_dtype,
"quantizer": quantizer,
"requires_grad": False,
"data": data,
}
out = Float8Tensor(**fp8_args)
out.update_usage(
rowwise_usage=quantizer.rowwise_usage,
columnwise_usage=quantizer.columnwise_usage,
)
return out, all_gather_outputs
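# Control-flow sketch (editor's illustration of how FSDP2 drives these two
# hooks; simplified, not TE's actual call sites):
#   sharded, meta = param.fsdp_pre_all_gather(mesh, size, stride, module, mp)
#   gathered = all_gather(sharded)   # uint8 shards travel as-is, no dequantize
#   out, _ = param.fsdp_post_all_gather(gathered, meta, param_dtype)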
@classmethod
def _make_in_reduce_ex(
cls,
......@@ -752,6 +948,9 @@ class _ViewFunc(torch.autograd.Function):
out_transpose_shape = out_transpose.size()
if out_transpose_shape[0] != out_shape[-1] or out_transpose_shape[1:] != out_shape[:-1]:
out_transpose = None
else:
view_shape_for_transpose = [shape[-1]] + list(shape[:-1])
out_transpose = out_transpose.view(*view_shape_for_transpose)
return Float8Tensor(
shape=out_shape,
dtype=tensor.dtype,
......@@ -796,6 +995,9 @@ class _ReshapeFunc(torch.autograd.Function):
out_transpose_shape = out_transpose.size()
if out_transpose_shape[0] != out_shape[-1] or out_transpose_shape[1:] != out_shape[:-1]:
out_transpose = None
else:
reshape_shape_for_transpose = [shape[-1]] + list(shape[:-1])
out_transpose = out_transpose.reshape(*reshape_shape_for_transpose)
return Float8Tensor(
shape=out_shape,
dtype=tensor.dtype,
......
......@@ -6,16 +6,17 @@
from __future__ import annotations
from collections.abc import Iterable
import math
from typing import Optional, Tuple, Union, Any
import warnings
import torch
from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState
import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType
from transformer_engine.common.recipe import MXFP8BlockScaling, Recipe
from ..constants import MXFP8_BLOCK_SCALING_SIZE
from ..utils import devices_match, round_up_to_nearest_multiple
from .storage.mxfp8_tensor_storage import MXFP8TensorStorage, _FromMXFP8Func
from ..quantized_tensor import QuantizedTensor, Quantizer
from ._quantization_helpers import _IdentityFunc
......@@ -298,7 +299,6 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
memory_format: torch.memory_format = torch.contiguous_format,
) -> MXFP8Tensor:
"""Returns tensor with data in provided memory format
Returns `self` if data is already in correct memory format.
"""
......@@ -314,7 +314,6 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
# View op
if func == aten.view.default:
tensor = args[0]
......@@ -338,9 +337,335 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
fp8_dtype=tensor._fp8_dtype,
)
if func == torch.ops.aten.copy_.default:
dst, src = args[0], args[1]
# Check that src provides every usage (rowwise/columnwise) that dst holds,
# so dst's quantizer usages are respected. If not, fall back to the base-class behavior.
rowwise_matches = src._rowwise_data is not None or dst._rowwise_data is None
columnwise_matches = src._columnwise_data is not None or dst._columnwise_data is None
if (
isinstance(src, MXFP8Tensor)
and isinstance(dst, MXFP8Tensor)
and rowwise_matches
and columnwise_matches
):
if dst._rowwise_data is not None:
dst._rowwise_data.copy_(src._rowwise_data.detach())
dst._rowwise_scale_inv.copy_(src._rowwise_scale_inv.detach())
if dst._columnwise_data is not None:
dst._columnwise_data.copy_(src._columnwise_data.detach())
dst._columnwise_scale_inv.copy_(src._columnwise_scale_inv.detach())
return dst
# FSDP2 related functions.
if func == aten.split.Tensor:
# This is called if the entire model is initialized on a CUDA device and
# then split. Each process keeps the shard it needs and the
# other splits are discarded.
if "dim" in kwargs:
dim_to_split = kwargs["dim"]
else:
dim_to_split = args[2] if len(args) > 2 else 0
tensor = args[0]
split_size = args[1]
dim0_size = tensor.size(0)
dimlast_size = math.prod(tensor.shape[1:])
if (
dim0_size % split_size != 0
or dim_to_split != 0
or split_size % MXFP8_BLOCK_SCALING_SIZE != 0
or dimlast_size % MXFP8_BLOCK_SCALING_SIZE != 0
):
# Handle splitting by dequantizing and splitting the high-precision tensor
return super().__torch_dispatch__(func, types, args, kwargs)
out_data = []
for data in [tensor._rowwise_data, tensor._columnwise_data]:
func_out = (
data.__torch_dispatch__(
func,
types,
[data] + list(args[1:]),
kwargs,
)
if data is not None
else None
)
out_data.append(func_out)
scale_invs = [tensor._rowwise_scale_inv, tensor._columnwise_scale_inv]
split_sizes_for_scale = [split_size, split_size // MXFP8_BLOCK_SCALING_SIZE]
# Padding requirements: rowwise dim0 should be divisible by 128, columnwise dim0 should be divisible by 4
padding_multiples = [128, 4]
for scale_inv, scale_split_size, pad_multiple in zip(
scale_invs, split_sizes_for_scale, padding_multiples
):
scale_inv_out = (
scale_inv.__torch_dispatch__(
func,
types,
[scale_inv, scale_split_size] + list(args[2:]),
kwargs,
)
if scale_inv is not None
else None
)
# Pad scale_inv_out to be a multiple of pad_multiple
if scale_inv_out is not None:
current_shape = scale_inv_out.shape
pad_dim0 = (pad_multiple - current_shape[0] % pad_multiple) % pad_multiple
if pad_dim0 > 0:
scale_inv_out = torch.nn.functional.pad(scale_inv_out, (0, 0, 0, pad_dim0))
out_data.append(scale_inv_out)
return [
MXFP8Tensor(
shape=(
splitted_tensor_data[0].size()
if splitted_tensor_data[0] is not None
else splitted_tensor_data[1].size()
),
dtype=tensor.dtype,
rowwise_data=splitted_tensor_data[0],
rowwise_scale_inv=splitted_tensor_data[2],
columnwise_data=splitted_tensor_data[1],
columnwise_scale_inv=splitted_tensor_data[3],
quantizer=tensor._quantizer,
requires_grad=False,
fp8_dtype=tensor._fp8_dtype,
)
for splitted_tensor_data in zip(*out_data)
]
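# Worked example (editor's sketch, assuming MXFP8_BLOCK_SCALING_SIZE == 32):
# splitting a (256, 128) MXFP8 weight over 4 ranks uses split_size=64; each
# shard's rowwise scale_inv split is (64, 128 // 32) = (64, 4), padded on dim 0
# to a multiple of 128 -> (128, 4), while the columnwise scale_inv split is
# (64 // 32, 128) = (2, 128), padded to a multiple of 4 -> (4, 128).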
if func == torch.ops.aten.as_strided.default:
# Applied to the unsharded param in FSDP2; in our case this should be a no-op.
# It is needed when some MXFP8 shards require padding, i.e. dimension 0
# of the unsharded param is not a multiple of the world size. In that case,
# we go down the dequantization route and the weights are allgathered in high precision.
# If the weight doesn't need padding, this is just a no-op.
shape = args[1]
strides = args[2]
tensor = args[0]
if (
len(shape) != 2
or len(strides) != 2
or strides[1] != 1
or shape[0] != tensor.shape[0]
or shape[1] != tensor.shape[1]
):
return super().__torch_dispatch__(func, types, args, kwargs)
return MXFP8Tensor.make_like(tensor)
if func == aten.slice.Tensor:
# FSDP2 needed function.
# We need slicing for the case where some MXFP8 weight shards need padding, i.e. dimension 0
# of the unsharded param is not a multiple of the world size. In that case,
# we go down the dequantization route and the weights are allgathered in high precision instead.
# If the sharded weight has no padding, this is just a no-op.
dim = args[1]
start = args[2]
length = args[3]
tensor = args[0]
if (
dim != 0
or length != tensor.shape[0]
or start != 0
or length % MXFP8_BLOCK_SCALING_SIZE != 0
or start % MXFP8_BLOCK_SCALING_SIZE != 0
):
return super().__torch_dispatch__(func, types, args, kwargs)
return MXFP8Tensor.make_like(tensor)
if func == aten.new_zeros.default:
rowwise_data = None
columnwise_data = None
rowwise_scale_inv = None
columnwise_scale_inv = None
tensor = args[0]
shape = args[1]
first_dim = math.prod(shape[:-1])
last_dim = shape[-1]
if (
first_dim % MXFP8_BLOCK_SCALING_SIZE != 0
or last_dim % MXFP8_BLOCK_SCALING_SIZE != 0
):
return super().__torch_dispatch__(func, types, args, kwargs)
rowwise_scale_inv_shape = [first_dim, last_dim // MXFP8_BLOCK_SCALING_SIZE]
columnwise_scale_inv_shape = [
first_dim // MXFP8_BLOCK_SCALING_SIZE,
last_dim,
]
if tensor._rowwise_data is not None:
rowwise_data = tensor._rowwise_data.__torch_dispatch__(
func,
types,
[tensor._rowwise_data] + list(args[1:]),
kwargs,
)
rowwise_scale_inv = tensor._rowwise_scale_inv.__torch_dispatch__(
func,
types,
[tensor._rowwise_scale_inv, rowwise_scale_inv_shape] + list(args[2:]),
kwargs,
)
if tensor._columnwise_data is not None:
columnwise_data = tensor._columnwise_data.__torch_dispatch__(
func,
types,
[tensor._columnwise_data] + list(args[1:]),
kwargs,
)
columnwise_scale_inv = tensor._columnwise_scale_inv.__torch_dispatch__(
func,
types,
[tensor._columnwise_scale_inv, columnwise_scale_inv_shape] + list(args[2:]),
kwargs,
)
return MXFP8Tensor(
shape=args[1],
dtype=tensor.dtype,
rowwise_data=rowwise_data,
rowwise_scale_inv=rowwise_scale_inv,
columnwise_data=columnwise_data,
columnwise_scale_inv=columnwise_scale_inv,
quantizer=tensor._quantizer.copy(),
requires_grad=False,
fp8_dtype=tensor._fp8_dtype,
)
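# Worked example (editor's sketch, assuming MXFP8_BLOCK_SCALING_SIZE == 32):
# for tensor.new_zeros((128, 1024)) the branch above allocates zeroed rowwise
# data of shape (128, 1024) with a (128, 1024 // 32) = (128, 32) rowwise
# scale_inv, and columnwise data of (128, 1024) with a
# (128 // 32, 1024) = (4, 1024) columnwise scale_inv.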
# Default case
return super().__torch_dispatch__(func, types, args, kwargs)
def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, mp_policy):
"""Functions FSDP2 calls before all-gather of the
weights for both forward and backward passes.
Args:
mesh (torch.distributed.DeviceMesh): DeviceMesh used by FSDP2
to shard the weights.
orig_size (torch.Size): Original size of the weight tensor.(For us same as self.shape)
contiguous_orig_stride (Tuple[int]): Original stride of the weight tensor
(For us same as self.stride()).
module (FSDPModule): FSDP module. FSDP wrapped module wrapped using fully_shard
that contains this MXFP8 tensor.
mp_policy (MixedPrecisionPolicy): Mixed precision policy used by FSDP2.
Returns:
sharded_tensors: Tuple[torch.Tensor, ...]: Tuple of tensors
that need to be all-gathered.
metadata: Tuple[Any]: Metadata needed for reconstructing the
MXFP8Tensor after all-gather.
"""
# pylint: disable=unused-argument
from transformer_engine.pytorch.distributed import _get_module_fsdp_state
fsdp_state = _get_module_fsdp_state(module)
reshard_after_forward = fsdp_state._fsdp_param_group._reshard_after_forward
quantizer = self._quantizer.copy()
# Remove padding from the scale inverses before the allgather.
# Rowwise scale_inv is padded to dimension multiples of (128, 4), columnwise to (4, 128).
rowwise_scale_inv = self._rowwise_scale_inv
columnwise_scale_inv = self._columnwise_scale_inv
shape = self.shape
if rowwise_scale_inv is not None:
# Remove padding from rowwise scale_inv
flattened_in_shape0 = math.prod(shape[:-1])
if rowwise_scale_inv.size(0) != flattened_in_shape0:
rowwise_scale_inv = rowwise_scale_inv[:flattened_in_shape0]
if columnwise_scale_inv is not None:
# Remove padding from columnwise scale_inv
flattened_in_shape0 = math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE
if columnwise_scale_inv.size(0) != flattened_in_shape0:
columnwise_scale_inv = columnwise_scale_inv[:flattened_in_shape0]
sharded_tensors = (self._rowwise_data, rowwise_scale_inv)
# If weights are resharded after the forward pass, it is enough to set the quantizer usages
# for the allgathered weights based on whether this is the forward or backward pass.
# If not resharded after forward, the same weights allgathered in forward
# are used again in backward, so if the columnwise data/scale_inv are needed,
# they must be sent along for the allgather in the forward pass itself.
if reshard_after_forward:
training_state = fsdp_state._fsdp_param_group._training_state
is_backward_pass = training_state == TrainingState.PRE_BACKWARD
# Allgather only the necessary tensors based on forward/backward pass
quantizer.set_usage(rowwise=not is_backward_pass, columnwise=is_backward_pass)
sharded_tensors = (
(self._columnwise_data, columnwise_scale_inv)
if is_backward_pass
else sharded_tensors
)
else:
if quantizer.columnwise_usage:
# If weights are not resharded after forward, then both
# rowwise and columnwise data/scale_inv need to be allgathered.
sharded_tensors += (self._columnwise_data, columnwise_scale_inv)
metadata = (self._fp8_dtype, quantizer)
return sharded_tensors, metadata
def fsdp_post_all_gather(
self,
all_gather_outputs: Tuple[torch.Tensor, ...],
metadata: Any,
param_dtype: torch.dtype,
*,
out: Optional[MXFP8Tensor] = None,
):
"""Functions FSDP2 calls after all-gather of the
weights for both forward and backward passes.
Args:
all_gather_outputs (Tuple[torch.Tensor, ...]): sharded_tensors sent out in fsdp_pre_all_gather from each rank
are all-gathered and received here as a tuple.
metadata (Any): metadata sent out in fsdp_pre_all_gather used for reconstructing the MXFP8Tensor.
param_dtype (torch.dtype): high precision dtype of the MXFP8Tensor.
out (Optional[torch.Tensor], optional): _description_. Defaults to None.
Returns:
Tuple[MXFP8Tensor, Tuple[torch.Tensor, ...]]: Allgathered MXFP8Tensor and tuple of internal tensors
used by the MXFP8Tensor that was being computed after allgather.
"""
fp8_dtype, quantizer = metadata
rowwise_data, rowwise_scale_inv = (
all_gather_outputs[:2] if quantizer.rowwise_usage else (None, None)
)
columnwise_data, columnwise_scale_inv = (
all_gather_outputs[-2:] if quantizer.columnwise_usage else (None, None)
)
# Pad scale_inv tensors to dimension multiples of (128, 4) for rowwise and (4, 128) for columnwise
if rowwise_scale_inv is not None:
# Pad rowwise_scale_inv to be a multiple of [128, 4]
current_shape = rowwise_scale_inv.shape
pad_dim0 = (128 - current_shape[0] % 128) % 128
if pad_dim0 > 0:
rowwise_scale_inv = torch.nn.functional.pad(rowwise_scale_inv, (0, 0, 0, pad_dim0))
if columnwise_scale_inv is not None:
# Pad columnwise_scale_inv to be a multiple of [4, 128]
current_shape = columnwise_scale_inv.shape
pad_dim0 = (4 - current_shape[0] % 4) % 4
if pad_dim0 > 0:
columnwise_scale_inv = torch.nn.functional.pad(
columnwise_scale_inv, (0, 0, 0, pad_dim0)
)
if out is not None:
out._rowwise_data = rowwise_data
out._rowwise_scale_inv = rowwise_scale_inv
out._columnwise_data = columnwise_data
out._columnwise_scale_inv = columnwise_scale_inv
out._quantizer = quantizer
else:
out = MXFP8Tensor(
rowwise_data=rowwise_data,
rowwise_scale_inv=rowwise_scale_inv,
columnwise_data=columnwise_data,
columnwise_scale_inv=columnwise_scale_inv,
fp8_dtype=fp8_dtype,
dtype=param_dtype,
shape=rowwise_data.shape if rowwise_data is not None else columnwise_data.shape,
quantizer=quantizer,
)
return out, all_gather_outputs
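# Worked example of the unpad/repad round trip (editor's sketch, assuming
# MXFP8_BLOCK_SCALING_SIZE == 32): a (96, 1024) shard stores its rowwise
# scale_inv padded to (128, 32); fsdp_pre_all_gather slices it back to (96, 32)
# so only real rows are communicated, and with 2 ranks the gathered (192, 32)
# result is re-padded here to the next multiple of 128 rows -> (256, 32).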
@classmethod
def _make_in_reduce_ex(
cls,
......@@ -478,10 +803,14 @@ class _ViewFunc(torch.autograd.Function):
shape[i] = d_inferred
break
if shape[-1] != ctx.shape[-1]:
warnings.warn(
"MXFP8Tensor does not support reshaping the inner dimension "
f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)}). "
"If you hit this via FSDP2 without compiled_autograd enabled, "
"this warning can be ignored, since the resulting view is never used.",
stacklevel=2,
)
return tensor.dequantize().view(*shape)
# Construct new tensor if shape is provided
new_rowwise_data = None
......