Unverified Commit 29537c96 authored by vthumbe1503, committed by GitHub

[PyTorch] FSDP2 Support for TE (#2245)
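This PR adds FSDP2 (fully_shard) support for Transformer Engine modules, including FP8 primary weights, FP8 all-gather, and meta-device initialization. As a quick orientation, here is a minimal sketch of the workflow the changes below enable, to be run under torchrun; the layer sizes and the current-scaling recipe are illustrative choices, not requirements of this PR:

    # Hedged sketch; run with: torchrun --nproc_per_node=<ngpus> example.py
    import os
    import torch
    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed._composable.fsdp import fully_shard

    import transformer_engine.pytorch as te
    from transformer_engine.pytorch import fp8_model_init
    from transformer_engine.common.recipe import Format, Float8CurrentScaling

    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    recipe = Float8CurrentScaling(fp8_format=Format.HYBRID)

    # Keep primary weights in FP8 so FSDP2 all-gathers uint8 shards, and create
    # the module on the meta device so parameters are materialized after sharding.
    with fp8_model_init(enabled=True, recipe=recipe):
        model = te.TransformerLayer(512, 2048, 8, fuse_qkv_params=True, device="meta")

    mesh = init_device_mesh("cuda", (dist.get_world_size(),))
    fully_shard(model, mesh=mesh)  # parameters become DTensors over the mesh

    # Materialize and initialize the sharded (meta) parameters on the GPU.
    for m in model.modules():
        if hasattr(m, "reset_parameters"):
            m.reset_parameters()

    x = torch.randn(128, 16, 512, device="cuda")
    with te.autocast(enabled=True, recipe=recipe):
        out = model(x)
    out.sum().backward()
    dist.destroy_process_group()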



* fix for float8 tensor FSDP2 training
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* zeros_like should return fp32 for FSDP2 to work
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* minor cleanup
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* fix unsharded weights not releasing memory
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* implement using the FSDP pre-all-gather and post-all-gather functions
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* FSDP2 works on Hopper/L40
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor comment
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* some fixes for FP8 + handwavy changes for MXFP8
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* only the transpose is saved for the backward-pass all-gather in the L40/Hopper case
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* missed minor change for the Hopper use case
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* communicate only the required data in MXFP8; update weight usages when required instead of doing it upfront in the forward pass
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* changes for meta DTensors for weights and better all-gather data handling in the FSDP hook functions
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* better solution to figure out the forward pass in FSDP2
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* address review comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* Update transformer_engine/pytorch/tensor/mxfp8_tensor.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* everything functioning except the hack for TransformerLayer
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* fix merge conflict
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* revert change of commit id for cudnn-frontend
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* unnecessary change
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor linting issues, add some comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor stuff
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* revert space removal

Add default usage handling for rowwise and columnwise data.
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>

* fix the FSDP state collection issue; address minor review comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* revert change for redundant dgrad computation
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* bug: get the FSDP param group's training state instead of the root training state; address review comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* address review comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* address review comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* address coderabbit review comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* address review comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* address review comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* address review comments; fix the FP8 all-gather test to run after FSDP lazy init
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* address review comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* remove detach
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* do what makes sense
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* Update transformer_engine/pytorch/tensor/float8_tensor.py
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>

* Update transformer_engine/pytorch/tensor/mxfp8_tensor.py
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>

* Update transformer_engine/pytorch/tensor/mxfp8_tensor.py
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>

* Update transformer_engine/pytorch/tensor/mxfp8_tensor.py
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>

* Update transformer_engine/pytorch/tensor/mxfp8_tensor.py
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>

* Update transformer_engine/pytorch/tensor/mxfp8_tensor.py
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>

* address review comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* address review comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* address review comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* address review comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* address review comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* use a better dtype for the fsdp_post_all_gather arguments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* minor comment
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* improve comment
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* fix the error in CI
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* add a minor comment
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* restore accidentally removed view function
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* fix minor bug for H100
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* minor addition
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* implement padding removal/addition for all-gather
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* address review comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* address review comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* Update transformer_engine/pytorch/tensor/mxfp8_tensor.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix lint error
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* address review comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* improve the reset-parameters logic for DTensors
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* other cosmetic changes
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* cosmetic changes
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* cosmetic changes
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* Update transformer_engine/pytorch/module/layernorm_linear.py
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

---------

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
parent 7a585983
@@ -9,57 +9,73 @@ import sys
 import argparse
 import transformer_engine.pytorch as te
-from transformer_engine.common.recipe import Format, DelayedScaling
+from transformer_engine.common.recipe import (
+    Format,
+    DelayedScaling,
+    Float8CurrentScaling,
+    MXFP8BlockScaling,
+)
 import torch
 import torch.distributed as dist
+from torch.distributed.tensor import DTensor
 import torch.nn.functional as F
 from torch import nn, optim
 from torch.distributed import DeviceMesh
 from torch.distributed._composable.fsdp import fully_shard
 from torch.distributed.device_mesh import init_device_mesh
+from transformer_engine.pytorch import QuantizedTensor
 from contextlib import nullcontext
 
+LOCAL_RANK = None
+
-class SimpleNet(nn.Module):
-    def __init__(self, input_size, hidden_size, output_size):
-        super(SimpleNet, self).__init__()
-        self.fc1 = te.Linear(input_size, hidden_size)
-        self.fc2 = te.Linear(hidden_size, output_size)
-
-    def forward(self, x):
-        x = F.relu(self.fc1(x))
-        x = self.fc2(x)
-        return x
+def dist_print(msg):
+    if LOCAL_RANK == 0:
+        print(msg)
 
-def save_custom_attrs(module):
-    custom_attrs = {}
-    for name, param in module.named_parameters():
-        attrs = vars(param)
-        custom_attrs[name] = {k: v for k, v in attrs.items()}
-    return custom_attrs
-
-def restore_custom_attrs(module, custom_attrs):
-    for name, param in module.named_parameters():
-        if name in custom_attrs:
-            for attr_name, attr_value in custom_attrs[name].items():
-                setattr(param, attr_name, attr_value)
 
 def _parse_args(argv=None, namespace=None):
     parser = argparse.ArgumentParser(description="Toy example for debugging fully_shard()")
-    parser.add_argument("--input-size", type=int, default=2048, help="Input size for the model")
-    parser.add_argument("--hidden-size", type=int, default=2048, help="Hidden layer size")
-    parser.add_argument("--output-size", type=int, default=2048, help="Output size for the model")
-    parser.add_argument("--batch-size", type=int, default=2048, help="Output size for the model")
+    parser.add_argument("--num-heads", type=int, default=8, help="Number of attn. heads")
+    parser.add_argument("--head-dim", type=int, default=64, help="Attention head size")
+    parser.add_argument("--batch-size", type=int, default=16, help="Batch size of input")
+    parser.add_argument("--seq-length", type=int, default=128, help="Sequence length of input")
+    parser.add_argument("--params-dtype", type=str, default="float32", help="Parameter dtype.")
     parser.add_argument(
         "--fp8-init", action="store_true", default=False, help="Initialize primary weights in FP8."
     )
+    parser.add_argument(
+        "--recipe",
+        type=str,
+        default="mx_fp8_block_scaling",
+        help="Quantizer type.",
+        choices=["delayed_scaling", "current_scaling", "mx_fp8_block_scaling"],
+    )
+    parser.add_argument(
+        "--layer-type",
+        type=str,
+        default="TransformerLayer",
+        choices=[
+            "Linear",
+            "LayerNormLinear",
+            "LayerNormMLP",
+            "MultiheadAttention",
+            "TransformerLayer",
+        ],
+        help="Transformer Engine layer type",
+    )
+    parser.add_argument("--num-layers", type=int, default=4, help="Number of layers in the model")
     parser.add_argument(
         "--iter", type=int, default=10, help="Number of iterations for forward pass"
     )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="meta",
+        help="Device to run the model on.",
+        choices=["cuda", "meta"],
+    )
     parser.add_argument("--seed", type=int, default=42, help="RNG seed.")
     # Adding hsdp_dim as a list argument, comma-separated
     parser.add_argument(
@@ -74,10 +90,170 @@ def _parse_args(argv=None, namespace=None):
     return args
 
-sub_modules_to_wrap = [te.Linear]
+## Methods to help initialize the TE model in an FSDP2 setting
+## with the required configuration based on command-line args
+def get_te_layer_from_string(layer_name):
+    te_layer_types = [
+        te.Linear,
+        te.LayerNormLinear,
+        te.LayerNormMLP,
+        te.MultiheadAttention,
+        te.TransformerLayer,
+    ]
+    te_layer_names = [layer.__name__ for layer in te_layer_types]
+    te_layer_map = dict(zip([name.lower() for name in te_layer_names], te_layer_types))
+    if layer_name.lower() not in te_layer_map.keys():
+        raise argparse.ArgumentTypeError(
+            f'"{layer_name}" is not a valid Transformer Engine layer, '
+            f"please choose a layer from {te_layer_names}."
+        )
+    return te_layer_map[layer_name.lower()]
+
+
+def get_recipe_from_string(recipe, fp8_format=Format.HYBRID):
+    if recipe == "delayed_scaling":
+        return DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")
+    elif recipe == "current_scaling":
+        return Float8CurrentScaling(fp8_format=fp8_format)
+    elif recipe == "mx_fp8_block_scaling":
+        return MXFP8BlockScaling(fp8_format=fp8_format)
+    else:
+        raise ValueError(f"Unknown quantizer type: {recipe}")
+
+
+def init_te_model(config):
+    hidden_size = config.num_heads * config.head_dim
+    args = [hidden_size, hidden_size]
+    inp_shape = [config.seq_length, config.batch_size, hidden_size]
+    out_shape = [config.seq_length, config.batch_size, hidden_size]
+    if config.params_dtype == "float16":
+        params_dtype = torch.float16
+    elif config.params_dtype == "bfloat16":
+        params_dtype = torch.bfloat16
+    else:
+        params_dtype = torch.float32
+    kwargs = {
+        "params_dtype": params_dtype,
+    }
+    kwargs["device"] = config.device
+    layer_type = get_te_layer_from_string(config.layer_type)
+    # We create the model so that we can test both the
+    # reshard_after_forward=True and reshard_after_forward=False cases;
+    # more details below.
+    if layer_type in [te.MultiheadAttention, te.TransformerLayer]:
+        # For this case, we create a model that resembles production use cases,
+        # where the model contains multiple TransformerLayers and each
+        # transformer layer needs to be sharded. Since each transformer layer
+        # is not a root module, FSDP2's fully_shard assigns
+        # reshard_after_forward=False for all parameters of the model.
+        args[1] *= 4  # FFN hidden size
+        args.append(config.num_heads)
+        kwargs["fuse_qkv_params"] = True
+        if layer_type is te.MultiheadAttention:
+            kwargs["input_layernorm"] = True
+        model = nn.Sequential(*[layer_type(*args, **kwargs) for _ in range(config.num_layers)])
+    elif layer_type == te.LayerNormLinear:
+        # For this case, we create a model with just one LayerNormLinear layer,
+        # so the model itself is a root module and FSDP2's fully_shard assigns
+        # reshard_after_forward=True for its parameters.
+        args[1] *= 3  # QKV projection
+        out_shape[-1] *= 3
+        model = layer_type(*args, **kwargs)
+    else:
+        model = layer_type(*args, **kwargs)
+    return model, inp_shape, out_shape
+
+
+def get_device_mesh(world_size, sharding_dims):
+    dist_print(f"sharding-dims:{sharding_dims}")
+    device_ids = list(range(world_size))
+    if sharding_dims is None:  # FSDP
+        mesh = DeviceMesh("cuda", device_ids)
+    elif len(sharding_dims) == 1:
+        assert sharding_dims[0] == world_size
+        mesh = DeviceMesh("cuda", device_ids)
+    elif len(sharding_dims) == 2:  # HSDP
+        assert sharding_dims[0] * sharding_dims[1] == world_size
+        mesh = init_device_mesh(
+            "cuda",
+            (sharding_dims[0], sharding_dims[1]),
+            mesh_dim_names=("replicate", "shard"),
+        )
+    else:
+        assert False
+    return mesh
+
+
+def shard_model_with_fsdp2(model, mesh):
+    for child in model.children():
+        fully_shard(child, mesh=mesh)
+    fully_shard(model, mesh=mesh)
+    return model
+
+
+#### Methods to save the custom attributes of QuantizedTensors before sharding
+#### them with FSDP2, and to restore them after sharding.
+def save_custom_attrs(module):
+    custom_attrs = {}
+    for name, param in module.named_parameters():
+        if isinstance(param, QuantizedTensor):
+            # Ignore FP8 metadata attributes. Otherwise we would save duplicate
+            # copies of the data/transpose FP8 tensors on top of the FP8 tensors
+            # that FSDP2 saves.
+            ignore_keys = [key for key in param.__dict__.keys() if key.startswith("_")]
+        else:
+            ignore_keys = []
+        attrs = vars(param)
+        custom_attrs[name] = {k: v for k, v in attrs.items() if k not in ignore_keys}
+    return custom_attrs
+
+
+def restore_custom_attrs(module, custom_attrs):
+    for name, param in module.named_parameters():
+        if name in custom_attrs:
+            for attr_name, attr_value in custom_attrs[name].items():
+                setattr(param, attr_name, attr_value)
+
+
+@torch.no_grad()
+def test_fp8_fsdp2_allgather(model):
+    # Do a manual all-gather in fp32 and match it against the FP8 all-gather
+    # done by FSDP2.
+    # FP32 manual weight all-gather
+    fp32_allgathered_params = {}
+    for name, param in model.named_parameters():
+        assert isinstance(param, DTensor)
+        local_tensor = param._local_tensor
+        device_mesh = param.device_mesh
+        dist_group = (
+            device_mesh.get_group(mesh_dim="shard")
+            if device_mesh.ndim > 1
+            else device_mesh.get_group()
+        )
+        # Perform a manual all-gather on local_tensor. zeros_like creates a
+        # high-precision tensor since torch_dispatch for local_tensor goes
+        # down the dequantization route.
+        gathered_tensor = [
+            torch.zeros_like(local_tensor) for _ in range(dist.get_world_size(group=dist_group))
+        ]
+        dist.all_gather(gathered_tensor, local_tensor.dequantize(), group=dist_group)
+        full_tensor = torch.cat(gathered_tensor, dim=0)
+        fp32_allgathered_params[name] = full_tensor
+    # FP8 all-gather using FSDP2
+    for module in model.modules():
+        # Not all modules are wrapped/sharded with FSDP2.
+        if hasattr(module, "unshard"):
+            module.unshard()
+    # Make sure the all-gathered parameters match exactly
+    for name, param in model.named_parameters():
+        assert torch.allclose(param.dequantize(), fp32_allgathered_params[name])
+    # Revert the model to its original sharded state
+    for module in model.modules():
+        # Not all modules are wrapped/sharded with FSDP2.
+        if hasattr(module, "reshard"):
+            module.reshard()
+
+
 def _train(args):
+    global LOCAL_RANK
     assert "TORCHELASTIC_RUN_ID" in os.environ
     WORLD_RANK = int(os.getenv("RANK", "0"))
     WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
@@ -103,74 +279,69 @@ def _train(args):
 
     # FP8 Configuration
     fp8_format = Format.HYBRID
-    fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")
+    fp8_recipe = get_recipe_from_string(args.recipe, fp8_format)
 
-    # Create build context manager
-    if args.fp8_init:
-        from transformer_engine.pytorch import quantized_model_init
-
-        build_model_context = quantized_model_init()
+    build_model_context_args = {}
+    if not args.fp8_init:
+        build_model_context = nullcontext
     else:
-        build_model_context = nullcontext()
-
-    # Build the model with the specified context
-    with build_model_context:
-        model = SimpleNet(args.input_size, args.hidden_size, args.output_size)
-    # Move the model to the correct device
-    model.to(device)
+        # Build model context (FP8 init)
+        from transformer_engine.pytorch import fp8_model_init
+
+        build_model_context = fp8_model_init
+        build_model_context_args["enabled"] = True
+        build_model_context_args["recipe"] = fp8_recipe
 
-    if LOCAL_RANK == 0:
-        print(f"Rank {LOCAL_RANK}: Applying FSDP fully_shard() to the model...")
+    dist_print(f"Memory before model init: {torch.cuda.memory_allocated(device)/1e6} MB")
+    # Create the model on the meta/cuda device as per args
+    with build_model_context(**build_model_context_args):
+        model, inp_shape, out_shape = init_te_model(args)
+    dist_print(
+        f"Memory after model init on device {args.device}:"
+        f" {torch.cuda.memory_allocated(device)/1e6} MB"
+    )
+
     # Creating a DeviceMesh for fully_shard
     world_size = int(WORLD_SIZE)
-    device_ids = list(range(world_size))
-    if LOCAL_RANK == 0:
-        print(f"sharding-dims:{args.sharding_dims}")
     # Setup the sharding mesh for FSDP/HSDP
-    if args.sharding_dims == None:  # FSDP
-        mesh = DeviceMesh("cuda", device_ids)
-    elif len(args.sharding_dims) == 1:
-        assert args.sharding_dims[0] == device_ids[-1] + 1
-        mesh = DeviceMesh("cuda", device_ids)
-    elif len(args.sharding_dims) == 2:  # HSDP
-        assert args.sharding_dims[0] * args.sharding_dims[1] == device_ids[-1] + 1
-        mesh = init_device_mesh(
-            "cuda",
-            (args.sharding_dims[0], args.sharding_dims[1]),
-            mesh_dim_names=("replicate", "shard"),
-        )
-    else:
-        assert False
+    mesh = get_device_mesh(world_size, args.sharding_dims)
 
-    # Apply FSDP/HSDP
     custom_attrs = save_custom_attrs(model)
-    for sub_module in model.modules():
-        if any(
-            isinstance(sub_module, sub_module_to_wrap) for sub_module_to_wrap in sub_modules_to_wrap
-        ):
-            fully_shard(sub_module, mesh=mesh)
-    fully_shard(model, mesh=mesh)
+    model = shard_model_with_fsdp2(model, mesh)
     restore_custom_attrs(model, custom_attrs)
 
+    # The model now has DTensors as its parameters
+    if args.device == "meta":
+        # After FSDP2 has been applied, materialize and initialize the sharded
+        # parameters. TE base.py's reset_parameters() handles DTensors with FP8
+        # initialization.
+        for module in model.modules():
+            if hasattr(module, "reset_parameters"):
+                module.reset_parameters()
+        dist_print("Sharded parameters materialized and initialized on cuda device.")
+    dist_print(
+        f"FSDP2 model in cuda, memory allocated: {torch.cuda.memory_allocated(device)/1e6} MB"
+    )
+
     optimizer = optim.Adam(model.parameters(), lr=1e-3)
 
     for iteration in range(args.iter):
         # Zero the parameter gradients
         optimizer.zero_grad()
-        input_data = torch.randn(args.batch_size, args.input_size).to(device)
+        input_data = torch.randn(inp_shape).to(device)
         with te.autocast(enabled=True, recipe=fp8_recipe):
             output = model(input_data)
-        target = torch.randn(args.batch_size, args.output_size).to(device)
+        target = torch.randn(out_shape).to(device)
         loss = F.mse_loss(output, target)
         loss.backward()
         optimizer.step()
-        if LOCAL_RANK == 0:
-            print(f"Rank {LOCAL_RANK}: Iteration {iteration} completed.")
+        dist_print(f"Iteration {iteration} completed with loss {loss.item()}")
 
+    # Some of the FSDP states are lazily initialized during the FSDP forward
+    # pass, so test the FP8 all-gather at the end of the training loop.
+    if args.fp8_init:
+        test_fp8_fsdp2_allgather(model)
     dist.destroy_process_group()
-    if LOCAL_RANK == 0:
-        print(f"Rank {LOCAL_RANK}: Done...")
     return 0
...
@@ -12,22 +12,26 @@ import torch
 
 fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
+mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
 NUM_PROCS: int = torch.cuda.device_count()
 
-def _run_test(fp_init, sharding_dims):
+def _run_test(fp_init, sharding_dims, recipe, layer_type):
     test_path = Path(__file__).parent.resolve() / "run_fsdp2_model.py"
     test_cmd = ["torchrun", f"--nproc_per_node={NUM_PROCS}", str(test_path)]
 
     if fp_init:
         test_cmd += ["--fp8-init"]
     if len(sharding_dims) == 1:
         test_cmd += ["--sharding-dims", str(sharding_dims[0])]
     elif len(sharding_dims) == 2:
         test_cmd += ["--sharding-dims", str(sharding_dims[0]), str(sharding_dims[1])]
     else:
         assert False
+    test_cmd += ["--recipe", recipe]
+    test_cmd += ["--layer-type", layer_type]
     result = subprocess.run(test_cmd, env=os.environ, check=True)

@@ -36,16 +40,20 @@ def _run_test(fp_init, sharding_dims):
 @pytest.mark.skipif(not te.torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+")
 @pytest.mark.parametrize("sharding_dims", ([NUM_PROCS], [2, NUM_PROCS // 2]))
 @pytest.mark.parametrize("fp8_init", (False, True))
-def test_distributed(fp8_init, sharding_dims):
+@pytest.mark.parametrize("recipe", ("delayed_scaling", "current_scaling", "mx_fp8_block_scaling"))
+@pytest.mark.parametrize("layer_type", ("LayerNormLinear", "TransformerLayer"))
+def test_distributed(fp8_init, sharding_dims, recipe, layer_type):
 
     # Skip invalid configurations
     if torch.cuda.device_count() < 4:
         pytest.skip("FSDP2 test requires at least 4 GPUs")
 
-    if fp8_init and not fp8_available:
+    if recipe == "mx_fp8_block_scaling" and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
+    elif not fp8_available:
         pytest.skip(reason_for_no_fp8)
 
-    _run_test(fp8_init, sharding_dims)
+    _run_test(fp8_init, sharding_dims, recipe, layer_type)
 
 
 def test_dummy() -> None:
...
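For reference, one parametrization of this test expands to a command equivalent to the following (a sketch; the script path is relative to the test directory):

    # Equivalent of _run_test(fp_init=True, sharding_dims=[2, 2],
    #                         recipe="current_scaling", layer_type="TransformerLayer")
    # on a node with 4 GPUs:
    import os
    import subprocess

    subprocess.run(
        [
            "torchrun",
            "--nproc_per_node=4",
            "run_fsdp2_model.py",  # path relative to the test directory
            "--fp8-init",
            "--sharding-dims", "2", "2",
            "--recipe", "current_scaling",
            "--layer-type", "TransformerLayer",
        ],
        env=os.environ,
        check=True,
    )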
@@ -1886,6 +1886,43 @@ def allreduce(
     return inp, handle
 
 
+def _get_module_fsdp_state(module):
+    """
+    If module is an FSDP module, return its _FSDPState.
+    Otherwise, return the _FSDPState of the closest parent FSDP module
+    in the module hierarchy that the module belongs to.
+    """
+    if hasattr(module, "_get_fsdp_state"):
+        # This returns the correct FSDP state if the module itself is an FSDP module.
+        fsdp_state = module._get_fsdp_state()
+    elif getattr(module, "_te_cached_parent_fsdp_state", None) is not None:
+        # See if we have cached the parent FSDP state of the module.
+        fsdp_state = module._te_cached_parent_fsdp_state
+    else:
+        from torch.distributed._composable_state import _module_state_mapping
+
+        # Otherwise, get the FSDP state of the closest FSDP-wrapped ancestor of
+        # the module in the module hierarchy (the wrapped module with the
+        # fewest submodules that still contains this module).
+        min_nodes_in_parent = float("inf")
+        closest_parent_fsdp_mod = None
+        for fsdp_mod in _module_state_mapping.keys():
+            all_submodules = list(fsdp_mod.modules())
+            for submodule in all_submodules:
+                if submodule is module:
+                    if min_nodes_in_parent > len(all_submodules):
+                        closest_parent_fsdp_mod = fsdp_mod
+                        min_nodes_in_parent = len(all_submodules)
+        if closest_parent_fsdp_mod is None:
+            raise RuntimeError(
+                "Module is not FSDP-wrapped and does not have any FSDP-wrapped parent modules."
+            )
+        fsdp_state = closest_parent_fsdp_mod._get_fsdp_state()
+        # Cache the parent FSDP state of the module to avoid recomputing
+        # the closest parent FSDP module.
+        module._te_cached_parent_fsdp_state = fsdp_state
+    return fsdp_state
+
+
 def _fsdp_scatter_tensors(
     fsdp_group: dist_group_type,
     *tensors: torch.Tensor,
...
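A rough usage sketch of the lookup above (hedged; this must run under torchrun with a process group initialized, since fully_shard needs one, and the helper is generic over module types):

    # The inner Linear is not itself FSDP-wrapped, so _get_module_fsdp_state walks
    # torch's module-state mapping and returns the state of the closest wrapped
    # ancestor (the wrapped module with the fewest submodules containing it).
    import torch.nn as nn
    from torch.distributed._composable.fsdp import fully_shard
    from transformer_engine.pytorch.distributed import _get_module_fsdp_state

    block = nn.Sequential(nn.Linear(256, 256), nn.Linear(256, 256))
    model = nn.Sequential(block)
    fully_shard(block)   # block now has an FSDP state
    fully_shard(model)   # so does the root

    state = _get_module_fsdp_state(block[0])  # resolves to block's state
    assert state is block._get_fsdp_state()
    # The result is cached on the module for subsequent calls:
    assert block[0]._te_cached_parent_fsdp_state is state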
@@ -17,6 +17,7 @@ from types import MethodType
 
 import torch
 import torch.nn.functional as F
+from torch.distributed.tensor import DTensor
 
 import transformer_engine_torch as tex
 from transformer_engine.common.recipe import Recipe

@@ -1244,6 +1245,11 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         metadata used in deferred initialization.
         """
         super().register_parameter(name, param)
+        # Initialize param_init_meta exactly once during init. FSDP2 can call
+        # register_parameter again to swap parameters for DTensors, and it does
+        # so without the custom FP8-specific kwargs that we need, so we don't
+        # want to reset/lose our FP8 init attributes.
+        if hasattr(self, "param_init_meta") and name not in self.param_init_meta:
             self.param_init_meta[name] = _ParameterInitMeta(**kwargs)
 
     def reset_parameters(self, defer_init: Optional[bool] = False) -> None:

@@ -1256,10 +1262,14 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             return
 
         for name, param in self.named_parameters(recurse=False):
+            # Check if the parameter is a DTensor (FSDP2) or a regular tensor
+            is_dtensor = isinstance(param, DTensor)
+            dtensor_param = param if is_dtensor else None
+            # Need to update/quantize the local tensor in the DTensor case
+            param = param._local_tensor if is_dtensor else param
             # Ensure parameter is on a real device
             if param.device == torch.device("meta"):
                 param = torch.empty_like(param, device="cuda")
 
             # Initialize the parameter values on device
             init_fn = self.param_init_meta[name].init_fn
             get_rng_state_tracker = self.param_init_meta[name].get_rng_state_tracker

@@ -1288,7 +1298,15 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
                     raise RuntimeError("Weight quantizer has not been initialized")
                 quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled())
                 quantizer.internal = False
+                if is_dtensor and isinstance(quantizer, Float8CurrentScalingQuantizer):
+                    device_mesh = dtensor_param.device_mesh
+                    amax_reduction_group = (
+                        device_mesh.get_group(mesh_dim="shard")
+                        if device_mesh.ndim > 1
+                        else device_mesh.get_group()
+                    )
+                    quantizer.amax_reduction_group = amax_reduction_group
+                    quantizer.with_amax_reduction = True
                 # Quantize parameter
                 param = quantizer(param)

@@ -1296,6 +1314,17 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             # NOTE: Currently this can only be broken when primary weights are in Fp8 but
             # re-applying the nn.Parameter() wrap is a no-op when the input is already
             # a parameter so we always re-apply it just for extra safety.
+            if is_dtensor:
+                # Recreate the DTensor from the parameter.
+                dtensor_param = DTensor.from_local(
+                    param,
+                    device_mesh=dtensor_param.device_mesh,
+                    placements=dtensor_param.placements,
+                    shape=dtensor_param.size(),
+                    stride=dtensor_param.stride(),
+                )
+                dtensor_param = torch.nn.Parameter(dtensor_param)
+            else:
                 param = torch.nn.Parameter(param)
 
             # Keep high-precision values on CPU if needed

@@ -1324,8 +1353,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
                 param._high_precision_init_val = high_precision_init_val
                 param.get_high_precision_init_val = MethodType(get, param)
                 param.clear_high_precision_init_val = MethodType(clear, param)
 
+            # Update the parameter based on its type
+            if not is_dtensor:
                 setattr(self, name, param)
+            else:
+                setattr(self, name, dtensor_param)
 
     @abstractmethod
     def forward(self):
...
@@ -108,9 +108,15 @@ class _GroupedLinear(torch.autograd.Function):
                 is_fp8_activation_recompute_enabled()
                 and not in_fp8_activation_recompute_phase()
             )
-        if weight_quantizers[0] is not None:
+        # No need to set the quantizer states if the weights are already quantized
+        if weight_quantizers[0] is not None and not isinstance(
+            weights[0], QuantizedTensorStorage
+        ):
             for weight_quantizer in weight_quantizers:
                 weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
+        elif isinstance(weights[0], QuantizedTensorStorage):
+            # If the weights are already quantized, use their existing quantizers
+            weight_quantizers = [weight._quantizer for weight in weights]
         if output_quantizers[0] is not None:
             for output_quantizer in output_quantizers:
                 output_quantizer.set_usage(rowwise=True, columnwise=False)

@@ -205,10 +211,6 @@ class _GroupedLinear(torch.autograd.Function):
                     inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
         else:
             inputmats = [None] * num_gemms
-        if inp.requires_grad:
-            for weight in weights_fp8:
-                if isinstance(weight, QuantizedTensorStorage):
-                    weight.update_usage(columnwise_usage=True)
 
         if cpu_offloading:
             ctx.grad_added_to_main_grad = hasattr(weights[0], "grad_added_to_main_grad")

@@ -354,13 +356,11 @@ class _GroupedLinear(torch.autograd.Function):
                     dtype=ctx.activation_dtype,
                     device=ctx.device,
                 )
-            for weight, quantizer in zip(weights, ctx.weight_quantizers):
-                if quantizer is not None and isinstance(weight, QuantizedTensorStorage):
-                    weight.update_usage(
-                        rowwise_usage=quantizer.rowwise_usage,
-                        columnwise_usage=quantizer.columnwise_usage,
-                    )
+            # Make sure the weights are available in column-wise format
+            # for the dgrad computation.
+            for weight in weights:
+                if isinstance(weight, QuantizedTensorStorage):
+                    weight.update_usage(columnwise_usage=True)
 
             general_grouped_gemm(
                 weights,
                 grad_output,
...
@@ -276,12 +276,15 @@ class _LayerNormLinear(torch.autograd.Function):
         # Prepare weight tensor
         # ------------------------------------------------------
         weightmat = weight
-        quantized_weight = False
+        is_weight_param_quantized = False
         if fp8 or debug:
-            quantized_weight = not isinstance(weight, QuantizedTensorStorage)
+            is_weight_param_quantized = isinstance(weight, QuantizedTensorStorage)
 
             # Configure quantizer
-            if weight_quantizer is not None:
+            # If the weight is already quantized, there is no need to set quantizer states
+            if is_weight_param_quantized:
+                weight_quantizer = weight._quantizer
+            elif weight_quantizer is not None:
                 weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)
 
             # Get quantized weight

@@ -413,10 +416,6 @@ class _LayerNormLinear(torch.autograd.Function):
             ):
                 ln_out.update_usage(rowwise_usage=False)
 
-            # Weight with column-wise usage is needed for dgrad GEMM.
-            if isinstance(weightmat, QuantizedTensorStorage):
-                weightmat.update_usage(columnwise_usage=True)
-
             if cpu_offloading:
                 mark_activation_offload(inputmat, mu, rsigma, ln_out)

@@ -429,7 +428,7 @@ class _LayerNormLinear(torch.autograd.Function):
                 fsdp_group,
                 mu,
                 rsigma,
-                weightmat if quantized_weight else None,
+                weightmat if fp8 and not is_weight_param_quantized else None,
                 ln_out if weight.requires_grad else None,
             )
             nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")

@@ -459,7 +458,7 @@ class _LayerNormLinear(torch.autograd.Function):
             ctx.tensor_objects = tensor_objects
 
             ctx.requires_dgrad = inp_requires_grad
             ctx.requires_wgrad = weight.requires_grad
-            ctx.quantized_weight = quantized_weight
+            ctx.is_weight_param_quantized = is_weight_param_quantized
             if fuse_wgrad_accumulation and weight.requires_grad:
                 # This check is needed to ensure that main_grad is not created
                 # during the forward pass when using MCore FSDP as it creates

@@ -563,7 +562,7 @@ class _LayerNormLinear(torch.autograd.Function):
             ctx.fsdp_shapes,
             mu,
             rsigma,
-            weight if ctx.fp8 and ctx.quantized_weight else None,
+            weight if ctx.fp8 and not ctx.is_weight_param_quantized else None,
             ln_out,
         )
         nvtx_range_pop(f"{nvtx_label}.fsdp_gather")
...
@@ -351,8 +351,17 @@ class _LayerNormMLP(torch.autograd.Function):
             # which handles weight caching etc.
             # FP8 cast to workspace buffer
             update_workspace = is_first_microbatch is None or is_first_microbatch
+            # No need to set the quantizer states if the weights are already quantized
+            if isinstance(fc1_weight, QuantizedTensorStorage):
+                fc1_weight_quantizer = fc1_weight._quantizer
+            elif fc1_weight_quantizer is not None:
                 fc1_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)
+            if isinstance(fc2_weight, QuantizedTensorStorage):
+                fc2_weight_quantizer = fc2_weight._quantizer
+            elif fc2_weight_quantizer is not None:
                 fc2_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)
             fc1_weight_final = module.get_weight_workspace(
                 tensor=fc1_weight,
                 quantizer=fc1_weight_quantizer,

@@ -538,13 +547,6 @@ class _LayerNormMLP(torch.autograd.Function):
 
         # Cache state for backward pass
         if is_grad_enabled:
-            # Weight with column-wise usage is needed for dgrad GEMM.
-            if isinstance(fc1_weight_final, QuantizedTensorStorage):
-                fc1_weight_final.update_usage(columnwise_usage=True)
-            if isinstance(fc2_weight_final, QuantizedTensorStorage):
-                fc2_weight_final.update_usage(columnwise_usage=True)
-
             if cpu_offloading:
                 mark_activation_offload(
                     inputmat, mu, rsigma, ln_out, fc1_out, fc1_out_without_bias, act_out
...
@@ -240,7 +240,8 @@ class _Linear(torch.autograd.Function):
         weightmat = weight
         if fp8 or debug:
             # Configure quantizer
-            if weight_quantizer is not None:
+            # No need to set the quantizer states if the weight is already quantized
+            if weight_quantizer is not None and not isinstance(weight, QuantizedTensor):
                 columnwise_usage = is_grad_enabled and inp.requires_grad
                 if not columnwise_usage:
                     columnwise_usage = (

@@ -248,7 +249,9 @@ class _Linear(torch.autograd.Function):
                         and not in_fp8_activation_recompute_phase()
                     )
                 weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
+            elif isinstance(weight, QuantizedTensor):
+                # If the weight is already quantized, use its existing quantizer
+                weight_quantizer = weight._quantizer
 
             # Get quantized weight
             update_workspace = is_first_microbatch is None or is_first_microbatch
             weightmat = module.get_weight_workspace(

@@ -389,11 +392,6 @@ class _Linear(torch.autograd.Function):
             if backward_needs_input:
                 saved_inputmat = inputmat
 
-            # Weight with column-wise usage is needed for dgrad GEMM.
-            if inp.requires_grad:
-                if isinstance(weightmat, QuantizedTensorStorage):
-                    weightmat.update_usage(columnwise_usage=True)
-
             if cpu_offloading and saved_inputmat is not None:
                 mark_activation_offload(saved_inputmat)
...
@@ -433,6 +433,10 @@ class QuantizedTensor(torch.Tensor):
                     and schema_arg.alias_info.is_write
                 ):
                     arg.quantize_(new_arg)
+                elif isinstance(arg, list) and isinstance(new_arg, list):
+                    # Recursively handle in-place updates for lists of tensors
+                    for a, na in zip(arg, new_arg):
+                        maybe_update_inplace(a, na, schema_arg)
 
         # In-place op: dequantize, perform op, and quantize
         if func._schema.is_mutable:

@@ -489,20 +493,16 @@ class QuantizedTensor(torch.Tensor):
         shape: Optional[Iterable[int]] = None,
         dtype: Optional[torch.dtype] = None,
         requires_grad: bool = False,
-        data: Optional[torch.Tensor] = None,
     ) -> QuantizedTensor:
         """Create new quantized tensor
 
         By default, the new tensor has the same attributes and underlying
-        data.
+        data. This function is intended for creating views of tensors.
 
         """
-        if shape is None:
-            shape = data.shape if data is not None else tensor.shape
+        shape = shape if shape is not None else tensor.shape
         dtype = dtype if dtype is not None else tensor.dtype
         kwargs = tensor.get_metadata()
-        if data is not None:
-            kwargs["data"] = data
         return cls(shape=shape, dtype=dtype, requires_grad=requires_grad, **kwargs)
 
     def to_dtype(self, dtype: torch.dtype) -> QuantizedTensor:
...
...@@ -4,10 +4,10 @@ ...@@ -4,10 +4,10 @@
"""Tensor class with FP8 data""" """Tensor class with FP8 data"""
from __future__ import annotations from __future__ import annotations
from typing import Optional, Tuple, Iterable, Union from typing import Any, Optional, Tuple, Iterable, Union
import warnings import warnings
import torch import torch
from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType from transformer_engine_torch import DType as TE_DType
...@@ -299,14 +299,12 @@ class Float8CurrentScalingQuantizer(Quantizer): ...@@ -299,14 +299,12 @@ class Float8CurrentScalingQuantizer(Quantizer):
# Allocate FP8 data transpose if needed # Allocate FP8 data transpose if needed
data_transpose = None data_transpose = None
if self.columnwise_usage: if self.columnwise_usage:
inner_dim = data.size(-1) transpose_shape = [data.size(-1)] + list(data.shape[:-1])
data_transpose = torch.empty( data_transpose = torch.empty(
inner_dim, transpose_shape,
data.numel() // inner_dim,
dtype=torch.uint8, dtype=torch.uint8,
device=device, device=device,
) )
# Construct FP8 tensor # Construct FP8 tensor
return Float8Tensor( return Float8Tensor(
shape=shape, shape=shape,
...@@ -534,9 +532,36 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor): ...@@ -534,9 +532,36 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
self._transpose = None self._transpose = None
@classmethod @classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None): def make_like(
cls,
tensor: QuantizedTensor,
*,
shape: Optional[Iterable[int]] = None,
dtype: Optional[torch.dtype] = None,
requires_grad: bool = False,
data: Optional[torch.Tensor] = None,
data_transpose: Optional[torch.Tensor] = None,
) -> QuantizedTensor:
"""Create new quantized tensor
# View op By default, new tensor has the same attributes and underlying
data.
"""
if shape is None and data is not None:
shape = data.shape
new_tensor = super().make_like(
tensor, shape=shape, dtype=dtype, requires_grad=requires_grad
)
if data is not None:
new_tensor._data = data
if data_transpose is not None:
new_tensor._transpose = data_transpose
new_tensor._transpose_invalid = False
return new_tensor
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
if func == aten.view.default: if func == aten.view.default:
tensor = args[0] tensor = args[0]
data = tensor._data data = tensor._data
...@@ -555,6 +580,9 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor): ...@@ -555,6 +580,9 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
or out_transpose_shape[1:] != out_shape[:-1] or out_transpose_shape[1:] != out_shape[:-1]
): ):
out_transpose = None out_transpose = None
else:
view_shape_for_transpose = [out_shape[-1]] + list(out_shape[:-1])
out_transpose = out_transpose.view(*view_shape_for_transpose)
return Float8Tensor( return Float8Tensor(
shape=out_shape, shape=out_shape,
dtype=tensor.dtype, dtype=tensor.dtype,
...@@ -587,11 +615,37 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor): ...@@ -587,11 +615,37 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
[data] + list(args[1:]), [data] + list(args[1:]),
kwargs, kwargs,
) )
return [ t_func_out = [None] * len(func_out)
Float8Tensor.make_like(tensor, data=split_tensor, shape=split_tensor.shape) # Compute corresponding split of the transpose cache if available
for split_tensor in func_out if tensor._transpose is not None and not tensor._transpose_invalid:
transpose = tensor._transpose
ndim = data.dim()
# Figure out the original split dim
if "dim" in kwargs:
dim_to_split = kwargs["dim"]
else:
dim_to_split = args[2] if len(args) > 2 else 0
# Dimension along which transpose needs to be split
t_dim = 0 if dim_to_split == ndim - 1 else dim_to_split + 1
t_func_out = transpose.__torch_dispatch__(
func,
types,
[transpose, args[1], t_dim],
kwargs,
)
outs = [
Float8Tensor.make_like(
tensor,
data=split_tensor,
data_transpose=split_transpose_tensor,
shape=split_tensor.shape,
)
for split_tensor, split_transpose_tensor in zip(func_out, t_func_out)
] ]
return outs
if func == aten.new_zeros.default: if func == aten.new_zeros.default:
# create fresh new tensor with zeros.
tensor = args[0] tensor = args[0]
data = tensor._data data = tensor._data
func_out = data.__torch_dispatch__( func_out = data.__torch_dispatch__(
...@@ -600,17 +654,63 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor): ...@@ -600,17 +654,63 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
[data] + list(args[1:]), [data] + list(args[1:]),
kwargs, kwargs,
) )
return Float8Tensor.make_like(tensor, data=func_out, shape=func_out.shape) func_transposed_out = None
if tensor._transpose is not None and not tensor._transpose_invalid:
transpose = tensor._transpose
size = args[1]
t_shape = [size[-1]] + list(size[:-1])
func_transposed_out = transpose.__torch_dispatch__(
func,
types,
[transpose, t_shape] + list(args[2:]),
kwargs,
)
# deep copy the scale inverse tensor and quantizer as well.
scale_inv = tensor._scale_inv.detach().clone()
quantizer = tensor._quantizer.copy()
out_tensor = Float8Tensor(
data=func_out,
shape=func_out.shape,
dtype=tensor.dtype,
fp8_dtype=tensor._fp8_dtype,
fp8_scale_inv=scale_inv,
data_transpose=func_transposed_out,
quantizer=quantizer,
)
return out_tensor
if func == torch.ops.aten.as_strided.default: if func == torch.ops.aten.as_strided.default:
tensor = args[0] tensor = args[0]
data = tensor._data data = tensor._data
# Apply as_strided to the primary uint8 data
func_out = data.__torch_dispatch__( func_out = data.__torch_dispatch__(
func, func,
types, types,
[data] + list(args[1:]), [data] + list(args[1:]),
kwargs, kwargs,
) )
return Float8Tensor.make_like(tensor, data=func_out, shape=func_out.shape) func_transposed_out = None
if tensor._transpose is not None and not tensor._transpose_invalid:
transpose = tensor._transpose
size = args[1]
stride = args[2]
if "storage_offset" in kwargs:
storage_offset = kwargs["storage_offset"]
else:
storage_offset = args[3] if len(args) > 3 else 0
# Shape and strided needed for transpose matrix
t_size = [size[-1]] + list(size[:-1])
t_stride = [stride[-1]] + list(stride[:-1])
func_transposed_out = transpose.__torch_dispatch__(
func,
types,
[transpose, t_size, t_stride, storage_offset] + list(args[4:]),
kwargs,
)
return Float8Tensor.make_like(
tensor, data=func_out, data_transpose=func_transposed_out, shape=func_out.shape
)
if func == torch.ops.aten.detach.default: if func == torch.ops.aten.detach.default:
return cls.detach(args[0]) return cls.detach(args[0])
if func == torch.ops.aten.clone.default: if func == torch.ops.aten.clone.default:
...@@ -632,9 +732,105 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor): ...@@ -632,9 +732,105 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
) )
else: else:
pass pass
return super().__torch_dispatch__(func, types, args, kwargs) return super().__torch_dispatch__(func, types, args, kwargs)
def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, mp_policy):
"""Functions FSDP2 calls before all-gather of the
weights for both forward and backward passes.
Args:
mesh (torch.distributed.DeviceMesh): DeviceMesh used by FSDP2
to shard the weights.
orig_size (torch.Size): Original size of the weight tensor.(For us same as self.shape)
contiguous_orig_stride (Tuple[int]): Original stride of the weight tensor
(For us same as self.stride())
module (FSDPModule): FSDP module. FSDP wrapped module wrapped using fully_shard
that contains this FP8 tensor.
mp_policy (MixedPrecisionPolicy): Mixed precision policy used by FSDP2.
Returns:
shareded_tensors: Tuple[torch.Tensor, ...]: Tuple of tensors
that need to be all-gathered.(In this case uint8 data tensor)
metadata: Tuple[Any]: Metadata needed for reconstructing the
Float8Tensor after all-gather.
"""
# pylint: disable=unused-argument
# Importing here to avoid circular imports
from transformer_engine.pytorch.distributed import _get_module_fsdp_state
if isinstance(self._quantizer, Float8CurrentScalingQuantizer) and mesh is not None:
# When sharded weight is updated after reduce scattering the gradients in FSDP2,
# we need to do amax reduction across the mesh to make sure all weight shards are
# updated with same scale inverse. Setting the state below in the quantizer will make
# sure that updated Quantized weight tensor have same scale inverse across all shards.
self._quantizer.amax_reduction_group = mesh.get_group()
self._quantizer.with_amax_reduction = True
quantizer = self._quantizer.copy() # quantizer to be used for allgathered weights
fsdp_state = _get_module_fsdp_state(module)
reshard_after_forward = fsdp_state._fsdp_param_group._reshard_after_forward
# If weights are resharded after forward pass, then its enough to set the quantizer usages
# based on whether its forward or backward pass for the allgathered weights.
# If not resharded after forward pass, the same weights allgathered in forward
# are used again in backward and so we dont change the quantizer usages which might need
# both rowwise and columnwise usages.
if reshard_after_forward:
training_state = fsdp_state._fsdp_param_group._training_state
is_backward_pass = training_state == TrainingState.PRE_BACKWARD
# In case of hopper/L40, only one of data/transpose is needed
# based on forward or backward pass. So setting the quantizer usages appropriately.
quantizer.set_usage(rowwise=not is_backward_pass, columnwise=is_backward_pass)
sharded_tensors = (self._data,)
metadata = (self._scale_inv, self._fp8_dtype, quantizer)
return sharded_tensors, metadata
def fsdp_post_all_gather(
self,
all_gather_outputs: Tuple[torch.Tensor, ...],
metadata: Any,
param_dtype: torch.dtype,
*,
out: Optional[Float8Tensor] = None,
):
"""Functions FSDP2 calls after all-gather of the
weights for both forward and backward passes.
Args:
all_gather_outputs (Tuple[torch.Tensor, ...]): sharded_tensors sent out in fsdp_pre_all_gather from each rank
are all-gathered and received here as a tuple.
metadata (Any): metadata sent out in fsdp_pre_all_gather used for reconstructing the Float8Tensor.
param_dtype (torch.dtype): high precision dtype of the Float8Tensor.
out (Optional[torch.Tensor], optional): _description_. Defaults to None.
Returns:
Tuple[Float8Tensor, Tuple[torch.Tensor, ...]]: Allgathered Float8Tensor and tuple of internal tensors
used by the Float8Tensor that was being computed after allgather.
"""
(data,) = all_gather_outputs
(fp8_scale_inv, fp8_dtype, quantizer) = metadata
orig_shape = data.size()
        # For the backward pass the quantizer has only columnwise usage set.
        # On Blackwell+ architectures no transpose is needed even then; this is
        # handled internally by the update_usage method.
if out is not None:
out._data = data
else:
fp8_args = {
"shape": orig_shape,
"dtype": param_dtype,
"fp8_scale_inv": fp8_scale_inv,
"fp8_dtype": fp8_dtype,
"quantizer": quantizer,
"requires_grad": False,
"data": data,
}
out = Float8Tensor(**fp8_args)
out.update_usage(
rowwise_usage=quantizer.rowwise_usage,
columnwise_usage=quantizer.columnwise_usage,
)
return out, all_gather_outputs
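Putting the two hooks together, a hedged end-to-end sketch (assuming a distributed run launched via torchrun, FP8-capable GPUs, and a recent PyTorch that exposes `fully_shard` under `torch.distributed.fsdp`; the layer size and recipe are illustrative):

```python
import torch
import transformer_engine.pytorch as te
from torch.distributed.fsdp import fully_shard
from transformer_engine.common.recipe import Float8CurrentScaling

model = te.Linear(1024, 1024, device="cuda")
fully_shard(model)  # weight all-gathers now route through the hooks above
with te.fp8_autocast(enabled=True, fp8_recipe=Float8CurrentScaling()):
    out = model(torch.randn(16, 1024, device="cuda"))
# With reshard_after_forward=True, backward triggers a second, columnwise-only
# all-gather of the weights.
out.sum().backward()
```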
    @classmethod
    def _make_in_reduce_ex(
        cls,
@@ -752,6 +948,9 @@ class _ViewFunc(torch.autograd.Function):
        out_transpose_shape = out_transpose.size()
        if out_transpose_shape[0] != out_shape[-1] or out_transpose_shape[1:] != out_shape[:-1]:
            out_transpose = None
        else:
            view_shape_for_transpose = [shape[-1]] + list(shape[:-1])
            out_transpose = out_transpose.view(*view_shape_for_transpose)
        return Float8Tensor(
            shape=out_shape,
            dtype=tensor.dtype,
@@ -796,6 +995,9 @@ class _ReshapeFunc(torch.autograd.Function):
        out_transpose_shape = out_transpose.size()
        if out_transpose_shape[0] != out_shape[-1] or out_transpose_shape[1:] != out_shape[:-1]:
            out_transpose = None
        else:
            reshape_shape_for_transpose = [shape[-1]] + list(shape[:-1])
            out_transpose = out_transpose.reshape(*reshape_shape_for_transpose)
        return Float8Tensor(
            shape=out_shape,
            dtype=tensor.dtype,
...
@@ -6,16 +6,17 @@
from __future__ import annotations
from collections.abc import Iterable
import math
from typing import Optional, Tuple, Union, Any
import warnings

import torch
from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState
import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType
from transformer_engine.common.recipe import MXFP8BlockScaling, Recipe
from ..constants import MXFP8_BLOCK_SCALING_SIZE
from ..utils import devices_match, round_up_to_nearest_multiple
from .storage.mxfp8_tensor_storage import MXFP8TensorStorage, _FromMXFP8Func
from ..quantized_tensor import QuantizedTensor, Quantizer
from ._quantization_helpers import _IdentityFunc
@@ -298,7 +299,6 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
        memory_format: torch.memory_format = torch.contiguous_format,
    ) -> MXFP8Tensor:
        """Returns tensor with data in provided memory format
        Returns `self` if data is already in correct memory format.
        """
@@ -314,7 +314,6 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        # View op
        if func == aten.view.default:
            tensor = args[0]
@@ -338,9 +337,335 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
                fp8_dtype=tensor._fp8_dtype,
            )
if func == torch.ops.aten.copy_.default:
dst, src = args[0], args[1]
            # Check that src provides every usage dst holds, to respect dst's
            # quantizer usages. If not, fall back to the base class behavior.
rowwise_matches = src._rowwise_data is not None or dst._rowwise_data is None
columnwise_matches = src._columnwise_data is not None or dst._columnwise_data is None
if (
isinstance(src, MXFP8Tensor)
and isinstance(dst, MXFP8Tensor)
and rowwise_matches
and columnwise_matches
):
if dst._rowwise_data is not None:
dst._rowwise_data.copy_(src._rowwise_data.detach())
dst._rowwise_scale_inv.copy_(src._rowwise_scale_inv.detach())
if dst._columnwise_data is not None:
dst._columnwise_data.copy_(src._columnwise_data.detach())
dst._columnwise_scale_inv.copy_(src._columnwise_scale_inv.detach())
return dst
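The usage check above reduces to "dst may only lack what src also lacks". A small standalone illustration of the predicate (`_fast_copy_ok` is a hypothetical name, not part of the class):

```python
def _fast_copy_ok(src_row: bool, src_col: bool, dst_row: bool, dst_col: bool) -> bool:
    # Mirrors rowwise_matches / columnwise_matches: every buffer dst holds
    # must be available in src for the fast in-place copy to run.
    return (src_row or not dst_row) and (src_col or not dst_col)

assert _fast_copy_ok(True, True, True, False)      # src superset of dst: fast path
assert not _fast_copy_ok(True, False, True, True)  # dst needs columnwise data src lacks
```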
        # FSDP2-related functions.
        if func == aten.split.Tensor:
            # This is called when the entire model is initialized on a CUDA
            # device and then split. Each process keeps only the shard it
            # needs and discards the remaining splits.
if "dim" in kwargs:
dim_to_split = kwargs["dim"]
else:
dim_to_split = args[2] if len(args) > 2 else 0
tensor = args[0]
split_size = args[1]
dim0_size = tensor.size(0)
dimlast_size = math.prod(tensor.shape[1:])
if (
dim0_size % split_size != 0
or dim_to_split != 0
or split_size % MXFP8_BLOCK_SCALING_SIZE != 0
or dimlast_size % MXFP8_BLOCK_SCALING_SIZE != 0
):
                # Handle splitting by dequantizing and splitting the
                # high-precision tensor.
return super().__torch_dispatch__(func, types, args, kwargs)
out_data = []
for data in [tensor._rowwise_data, tensor._columnwise_data]:
func_out = (
data.__torch_dispatch__(
func,
types,
[data] + list(args[1:]),
kwargs,
)
if data is not None
else None
)
out_data.append(func_out)
scale_invs = [tensor._rowwise_scale_inv, tensor._columnwise_scale_inv]
split_sizes_for_scale = [split_size, split_size // MXFP8_BLOCK_SCALING_SIZE]
            # Padding requirements: rowwise dim 0 must be divisible by 128,
            # columnwise dim 0 by 4.
padding_multiples = [128, 4]
for scale_inv, scale_split_size, pad_multiple in zip(
scale_invs, split_sizes_for_scale, padding_multiples
):
scale_inv_out = (
scale_inv.__torch_dispatch__(
func,
types,
[scale_inv, scale_split_size] + list(args[2:]),
kwargs,
)
if scale_inv is not None
else None
)
# Pad scale_inv_out to be a multiple of pad_multiple
if scale_inv_out is not None:
current_shape = scale_inv_out.shape
pad_dim0 = (pad_multiple - current_shape[0] % pad_multiple) % pad_multiple
if pad_dim0 > 0:
scale_inv_out = torch.nn.functional.pad(scale_inv_out, (0, 0, 0, pad_dim0))
out_data.append(scale_inv_out)
return [
MXFP8Tensor(
shape=(
splitted_tensor_data[0].size()
if splitted_tensor_data[0] is not None
else splitted_tensor_data[1].size()
),
dtype=tensor.dtype,
rowwise_data=splitted_tensor_data[0],
rowwise_scale_inv=splitted_tensor_data[2],
columnwise_data=splitted_tensor_data[1],
columnwise_scale_inv=splitted_tensor_data[3],
quantizer=tensor._quantizer,
requires_grad=False,
fp8_dtype=tensor._fp8_dtype,
)
for splitted_tensor_data in zip(*out_data)
]
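A worked example of when this fast split path applies, assuming `MXFP8_BLOCK_SCALING_SIZE == 32`:

```python
# A [4096, 1024] weight split across 8 ranks gives 512-row shards; 512 and
# 1024 are both multiples of 32, so FP8 data and scale_inv tensors can be
# split directly without dequantizing.
dim0, dimlast, world_size = 4096, 1024, 8
split_size = dim0 // world_size  # 512 rows per shard
assert dim0 % split_size == 0
assert split_size % 32 == 0 and dimlast % 32 == 0  # fast-path conditions hold
```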
if func == torch.ops.aten.as_strided.default:
            # Applied to the unsharded param in FSDP2; in our case this should
            # be a no-op. It is needed when some MXFP8 shards require padding,
            # i.e. dimension 0 of the unsharded param is not a multiple of the
            # world size. In that case we go down the dequantization route and
            # the weights are all-gathered in high precision. If the weight
            # needs no padding, this is just a no-op.
shape = args[1]
strides = args[2]
tensor = args[0]
if (
len(shape) != 2
or len(strides) != 2
or strides[1] != 1
or shape[0] != tensor.shape[0]
or shape[1] != tensor.shape[1]
):
return super().__torch_dispatch__(func, types, args, kwargs)
return MXFP8Tensor.make_like(tensor)
if func == aten.slice.Tensor:
            # Needed by FSDP2. Slicing matters when some MXFP8 weight shards
            # require padding, i.e. dimension 0 of the unsharded param is not a
            # multiple of the world size. In that case we go down the
            # dequantization route and the weights are all-gathered in high
            # precision instead. If the sharded weight has no padding, this is
            # just a no-op.
dim = args[1]
start = args[2]
length = args[3]
tensor = args[0]
if (
dim != 0
or length != tensor.shape[0]
or start != 0
or length % MXFP8_BLOCK_SCALING_SIZE != 0
or start % MXFP8_BLOCK_SCALING_SIZE != 0
):
return super().__torch_dispatch__(func, types, args, kwargs)
return MXFP8Tensor.make_like(tensor)
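A numeric illustration of the fallback that both the `as_strided` and `slice` checks guard against, with an illustrative weight size:

```python
rows, world_size = 100, 8                # dim 0 not divisible by world size
shard_rows = -(-rows // world_size)      # ceil division: 13 rows per padded shard
assert shard_rows * world_size != rows   # 104 != 100 -> shape checks fail, so the
# weight is dequantized and all-gathered in high precision instead.
```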
if func == aten.new_zeros.default:
rowwise_data = None
columnwise_data = None
rowwise_scale_inv = None
columnwise_scale_inv = None
tensor = args[0]
shape = args[1]
first_dim = math.prod(shape[:-1])
last_dim = shape[-1]
if (
first_dim % MXFP8_BLOCK_SCALING_SIZE != 0
or last_dim % MXFP8_BLOCK_SCALING_SIZE != 0
):
return super().__torch_dispatch__(func, types, args, kwargs)
rowwise_scale_inv_shape = [first_dim, last_dim // MXFP8_BLOCK_SCALING_SIZE]
columnwise_scale_inv_shape = [
first_dim // MXFP8_BLOCK_SCALING_SIZE,
last_dim,
]
if tensor._rowwise_data is not None:
rowwise_data = tensor._rowwise_data.__torch_dispatch__(
func,
types,
[tensor._rowwise_data] + list(args[1:]),
kwargs,
)
rowwise_scale_inv = tensor._rowwise_scale_inv.__torch_dispatch__(
func,
types,
[tensor._rowwise_scale_inv, rowwise_scale_inv_shape] + list(args[2:]),
kwargs,
)
if tensor._columnwise_data is not None:
columnwise_data = tensor._columnwise_data.__torch_dispatch__(
func,
types,
[tensor._columnwise_data] + list(args[1:]),
kwargs,
)
columnwise_scale_inv = tensor._columnwise_scale_inv.__torch_dispatch__(
func,
types,
[tensor._columnwise_scale_inv, columnwise_scale_inv_shape] + list(args[2:]),
kwargs,
)
return MXFP8Tensor(
shape=args[1],
dtype=tensor.dtype,
rowwise_data=rowwise_data,
rowwise_scale_inv=rowwise_scale_inv,
columnwise_data=columnwise_data,
columnwise_scale_inv=columnwise_scale_inv,
quantizer=tensor._quantizer.copy(),
requires_grad=False,
fp8_dtype=tensor._fp8_dtype,
)
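A worked shape example for this `new_zeros` branch, again assuming `MXFP8_BLOCK_SCALING_SIZE == 32`:

```python
import math

# new_zeros with shape (256, 1024) allocates scale inverses of
# [256, 1024 // 32] == [256, 32] (rowwise) and [256 // 32, 1024] == [8, 1024]
# (columnwise).
shape = (256, 1024)
first_dim, last_dim = math.prod(shape[:-1]), shape[-1]
assert [first_dim, last_dim // 32] == [256, 32]
assert [first_dim // 32, last_dim] == [8, 1024]
```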
        # Default case
        return super().__torch_dispatch__(func, types, args, kwargs)
def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, mp_policy):
"""Functions FSDP2 calls before all-gather of the
weights for both forward and backward passes.
Args:
mesh (torch.distributed.DeviceMesh): DeviceMesh used by FSDP2
to shard the weights.
orig_size (torch.Size): Original size of the weight tensor.(For us same as self.shape)
contiguous_orig_stride (Tuple[int]): Original stride of the weight tensor
(For us same as self.stride()).
module (FSDPModule): FSDP module. FSDP wrapped module wrapped using fully_shard
that contains this MXFP8 tensor.
mp_policy (MixedPrecisionPolicy): Mixed precision policy used by FSDP2.
Returns:
sharded_tensors: Tuple[torch.Tensor, ...]: Tuple of tensors
that need to be all-gathered.
metadata: Tuple[Any]: Metadata needed for reconstructing the
MXFP8Tensor after all-gather.
"""
# pylint: disable=unused-argument
from transformer_engine.pytorch.distributed import _get_module_fsdp_state
fsdp_state = _get_module_fsdp_state(module)
reshard_after_forward = fsdp_state._fsdp_param_group._reshard_after_forward
quantizer = self._quantizer.copy()
        # Remove padding from the scale inverses before all-gather: rowwise
        # scale_inv is stored padded to multiples of [128, 4] and columnwise to
        # [4, 128], but only the unpadded rows need to be communicated.
rowwise_scale_inv = self._rowwise_scale_inv
columnwise_scale_inv = self._columnwise_scale_inv
shape = self.shape
if rowwise_scale_inv is not None:
# Remove padding from rowwise scale_inv
flattened_in_shape0 = math.prod(shape[:-1])
if rowwise_scale_inv.size(0) != flattened_in_shape0:
rowwise_scale_inv = rowwise_scale_inv[:flattened_in_shape0]
if columnwise_scale_inv is not None:
# Remove padding from columnwise scale_inv
flattened_in_shape0 = math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE
if columnwise_scale_inv.size(0) != flattened_in_shape0:
columnwise_scale_inv = columnwise_scale_inv[:flattened_in_shape0]
sharded_tensors = (self._rowwise_data, rowwise_scale_inv)
        # If weights are resharded after the forward pass, it is enough to set
        # the quantizer usages for the all-gathered weights based on whether we
        # are in the forward or backward pass. If not resharded, the weights
        # all-gathered in forward are reused in backward, so if columnwise
        # data/scale_inv are needed, they must be all-gathered in the forward
        # pass as well.
if reshard_after_forward:
training_state = fsdp_state._fsdp_param_group._training_state
is_backward_pass = training_state == TrainingState.PRE_BACKWARD
# Allgather only the necessary tensors based on forward/backward pass
quantizer.set_usage(rowwise=not is_backward_pass, columnwise=is_backward_pass)
sharded_tensors = (
(self._columnwise_data, columnwise_scale_inv)
if is_backward_pass
else sharded_tensors
)
else:
if quantizer.columnwise_usage:
# If weights are not resharded after forward, then both
# rowwise and columnwise data/scale_inv need to be allgathered.
sharded_tensors += (self._columnwise_data, columnwise_scale_inv)
metadata = (self._fp8_dtype, quantizer)
return sharded_tensors, metadata
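A numeric sketch of the unpadding above, assuming a 160-row shard, `MXFP8_BLOCK_SCALING_SIZE == 32`, and dim 0 of the stored rowwise scale_inv padded to a multiple of 128:

```python
rows = 160                            # flattened leading dims of this shard
padded_dim0 = -(-rows // 128) * 128   # rowwise scale_inv rows as stored: 256
assert padded_dim0 == 256
# Only scale_inv[:160] is all-gathered; the padding is re-applied on the
# gathered tensor in fsdp_post_all_gather.
```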
def fsdp_post_all_gather(
self,
all_gather_outputs: Tuple[torch.Tensor, ...],
metadata: Any,
param_dtype: torch.dtype,
*,
out: Optional[MXFP8Tensor] = None,
):
"""Functions FSDP2 calls after all-gather of the
weights for both forward and backward passes.
Args:
all_gather_outputs (Tuple[torch.Tensor, ...]): sharded_tensors sent out in fsdp_pre_all_gather from each rank
are all-gathered and received here as a tuple.
metadata (Any): metadata sent out in fsdp_pre_all_gather used for reconstructing the MXFP8Tensor.
param_dtype (torch.dtype): high precision dtype of the MXFP8Tensor.
out (Optional[torch.Tensor], optional): _description_. Defaults to None.
Returns:
Tuple[MXFP8Tensor, Tuple[torch.Tensor, ...]]: Allgathered MXFP8Tensor and tuple of internal tensors
used by the MXFP8Tensor that was being computed after allgather.
"""
fp8_dtype, quantizer = metadata
rowwise_data, rowwise_scale_inv = (
all_gather_outputs[:2] if quantizer.rowwise_usage else (None, None)
)
columnwise_data, columnwise_scale_inv = (
all_gather_outputs[-2:] if quantizer.columnwise_usage else (None, None)
)
        # Re-pad scale_inv tensors to multiples of [128, 4] for rowwise and
        # [4, 128] for columnwise.
if rowwise_scale_inv is not None:
# Pad rowwise_scale_inv to be a multiple of [128, 4]
current_shape = rowwise_scale_inv.shape
pad_dim0 = (128 - current_shape[0] % 128) % 128
if pad_dim0 > 0:
rowwise_scale_inv = torch.nn.functional.pad(rowwise_scale_inv, (0, 0, 0, pad_dim0))
if columnwise_scale_inv is not None:
# Pad columnwise_scale_inv to be a multiple of [4, 128]
current_shape = columnwise_scale_inv.shape
pad_dim0 = (4 - current_shape[0] % 4) % 4
if pad_dim0 > 0:
columnwise_scale_inv = torch.nn.functional.pad(
columnwise_scale_inv, (0, 0, 0, pad_dim0)
)
if out is not None:
out._rowwise_data = rowwise_data
out._rowwise_scale_inv = rowwise_scale_inv
out._columnwise_data = columnwise_data
out._columnwise_scale_inv = columnwise_scale_inv
out._quantizer = quantizer
else:
out = MXFP8Tensor(
rowwise_data=rowwise_data,
rowwise_scale_inv=rowwise_scale_inv,
columnwise_data=columnwise_data,
columnwise_scale_inv=columnwise_scale_inv,
fp8_dtype=fp8_dtype,
dtype=param_dtype,
shape=rowwise_data.shape if rowwise_data is not None else columnwise_data.shape,
quantizer=quantizer,
)
return out, all_gather_outputs
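Mirroring the Float8 example, a hedged usage sketch for the MXFP8 path (assuming Blackwell-class GPUs, an initialized process group, and illustrative sizes kept as multiples of the block size):

```python
import torch
import transformer_engine.pytorch as te
from torch.distributed.fsdp import fully_shard
from transformer_engine.common.recipe import MXFP8BlockScaling

model = te.Linear(4096, 4096, device="cuda")
fully_shard(model)  # rowwise/columnwise data and scale_inv route through the hooks above
with te.fp8_autocast(enabled=True, fp8_recipe=MXFP8BlockScaling()):
    out = model(torch.randn(32, 4096, device="cuda"))
out.sum().backward()
```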
    @classmethod
    def _make_in_reduce_ex(
        cls,
@@ -478,10 +803,14 @@ class _ViewFunc(torch.autograd.Function):
                    shape[i] = d_inferred
                    break
        if shape[-1] != ctx.shape[-1]:
            warnings.warn(
                "MXFP8Tensor does not support reshaping the inner dimension "
                f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)}). "
                "If you are using this with FSDP2 without compiled_autograd_enabled, "
                "this warning can be ignored, since the resulting view is never used.",
                stacklevel=2,
            )
            return tensor.dequantize().view(*shape)
        # Construct new tensor if shape is provided
        new_rowwise_data = None
...