Unverified commit c84d1708 authored by Jianbin Chang, committed by GitHub

Support FP8 primary weight in FSDP training (#1630)



Support fp8 primary weight in fsdp training
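
For context, a minimal, illustrative sketch (not part of this commit) of how the new optional
fsdp_shard_model_weights argument can be driven from an FSDP-style optimizer; cast_local_shards
and shard_fp8_views are hypothetical names for a wrapper and the per-parameter Float8Tensor views
over each rank's local flat-parameter shard:

from transformer_engine.pytorch.tensor import cast_master_weights_to_fp8

def cast_local_shards(model_weights, master_weights, start_offsets, dp_group, shard_fp8_views=None):
    # Hypothetical helper: when shard_fp8_views is None the existing behavior is kept and each
    # FP8 weight's own storage is updated in place over
    # [start_offset, start_offset + master_weight.numel()).  When FSDP shard views are given
    # (one per weight, or None for weights this rank does not own), the quantized results are
    # written into those views instead.
    cast_master_weights_to_fp8(
        model_weights,
        master_weights,
        start_offsets,
        dp_group,
        fsdp_shard_model_weights=shard_fp8_views,
    )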
Signed-off-by: jianbinc <shjwudp@gmail.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent a3ba4dff
@@ -21,7 +21,11 @@ from transformer_engine.common.recipe import (
)
import transformer_engine.pytorch as te
from transformer_engine.pytorch.tensor import QuantizedTensor, cast_master_weights_to_fp8
from transformer_engine.pytorch.tensor.float8_tensor import (
    Float8Tensor,
    Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.tensor.utils import replace_raw_data


def _get_raw_data(quantized_tensor):
@@ -228,6 +232,279 @@ class MiniOptimizer:
            weight.data.copy_(master_weight)


class MiniFSDP:
    def __init__(self, weights, lr, dp_group):
        rank = dist.get_rank(dp_group)
        world_size = dist.get_world_size(dp_group)

        self.weights = weights
        self.lr = lr
        self.dp_group = dp_group

        # Flatten the weights and pad to align with world size
        raw_data_list = [
            _get_raw_data(w).view(-1) if isinstance(w, Float8Tensor) else w.view(-1)
            for w in weights
        ]
        self.flatten_weight, original_length = self._flatten_tensors_with_pad(raw_data_list)

        # Split flattened weights into shards. The gradient shard is kept in FP32 because the
        # flattened weight buffer may hold uint8 raw data for FP8 weights.
        self.local_weight_shard = torch.chunk(self.flatten_weight, world_size)[rank]
        self.local_main_grad_shard = torch.zeros_like(self.local_weight_shard, dtype=torch.float32)
        shard_size = self.flatten_weight.size(0) // world_size
        # Map original tensors to flattened indices
        tensor_indices = []
        cumulative_length = 0
        for tensor in raw_data_list:
            length = tensor.size(0)
            tensor_indices.append((cumulative_length, cumulative_length + length))
            cumulative_length += length

        # Build shard index mappings
        self.weight_indices = []
        self.shard_indices = []
        for idx, (start, end) in enumerate(tensor_indices):
            shard_start = rank * shard_size
            shard_end = shard_start + shard_size
            adjusted_end = min(shard_end, original_length)
            if start <= adjusted_end and end >= shard_start:
                start_idx = max(start, shard_start)
                end_idx = min(end, adjusted_end)
                self.weight_indices.append((start_idx - start, end_idx - start))
                self.shard_indices.append((start_idx - shard_start, end_idx - shard_start))
            else:
                self.weight_indices.append((None, None))
                self.shard_indices.append((None, None))
            # Point each model weight at its slice of the flattened buffer
            if isinstance(weights[idx], Float8Tensor):
                replace_raw_data(
                    weights[idx], self.flatten_weight[start:end].view(weights[idx].shape)
                )
            else:
                weights[idx].data = self.flatten_weight[start:end].view(weights[idx].shape)
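
        # Illustrative walk-through (not from the original test): with world_size = 2 and two
        # weights of 5 and 8 elements, the flat buffer holds 13 elements and is padded to 14,
        # so each rank owns a 7-element shard.  Rank 0 owns flat range [0, 7): all of weight 0
        # (weight_indices (0, 5), shard_indices (0, 5)) and the first two elements of weight 1
        # (weight_indices (0, 2), shard_indices (5, 7)).  Rank 1 owns flat range [7, 14):
        # weight 0 does not overlap it ((None, None)), weight 1 maps to weight_indices (2, 8)
        # and shard_indices (0, 6), and the last flat slot is padding.
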
        # Initialize local model weights and high-precision master weights
        self.local_weights = []
        self.master_weights = []
        for i, weight in enumerate(self.weights):
            weight_start, weight_end = self.weight_indices[i]
            shard_start, shard_end = self.shard_indices[i]
            if shard_start is not None and shard_end is not None:
                local_weight_shard = self.local_weight_shard[shard_start:shard_end]
                self.local_weights.append(local_weight_shard)
                if isinstance(weight, QuantizedTensor):
                    high_precision_init_val = weight.get_high_precision_init_val().view(-1)
                    master_weight_shard = high_precision_init_val.to(weight.device).float()[
                        weight_start:weight_end
                    ]
                else:
                    master_weight_shard = weight.detach().view(-1).float()[weight_start:weight_end]
                self.master_weights.append(master_weight_shard)
            else:
                self.local_weights.append(None)
                self.master_weights.append(None)
            setattr(
                weight, "main_grad", torch.zeros_like(weight, dtype=torch.float32, device="cuda")
            )

    def _flatten_tensors_with_pad(self, tensors):
        """
        Flatten the list of tensors and pad them to align with the world size.

        Args:
            tensors (list): List of tensors to flatten.

        Returns:
            tuple: Flattened tensor and its original length before padding.
        """
        world_size = dist.get_world_size(self.dp_group)
        flatten_tensor = torch.cat(tensors)
        original_length = flatten_tensor.size(0)
        padding_needed = (world_size - original_length % world_size) % world_size
        if padding_needed > 0:
            # Keep the padding on the same device as the weights so torch.cat does not fail.
            padding = torch.zeros(
                padding_needed, dtype=flatten_tensor.dtype, device=flatten_tensor.device
            )
            flatten_tensor = torch.cat([flatten_tensor, padding])
        return flatten_tensor, original_length

    def zero_grad(self):
        for weight in self.weights:
            weight.grad = None
            weight.main_grad.zero_()

    def step(self):
        """
        Perform an optimization step for the distributed sharded model.

        This method includes:
            1. Gradient reduce-scatter: Synchronize gradients across all processes.
            2. Master weight update: Update high-precision master weights using local gradients.
            3. Precision casting: Cast updated master weights to FP8 or BF16 precision.
            4. Weight synchronization: All-gather updated weights across all processes.

        Returns:
            None
        """
        # Step 1: Reduce-scatter the gradients
        main_grad_buffer, _ = self._flatten_tensors_with_pad(
            [weight.main_grad.view(-1) for weight in self.weights]
        )
        main_grad_buffer = main_grad_buffer.to(self.local_main_grad_shard.dtype)
        dist.reduce_scatter_tensor(
            self.local_main_grad_shard, main_grad_buffer, group=self.dp_group
        )

        # Step 2: Update the master weights
        for weight, master_weight, (shard_start, shard_end) in zip(
            self.weights, self.master_weights, self.shard_indices
        ):
            if master_weight is None:
                continue
            # Extract the local gradient shard for this weight
            grad = self.local_main_grad_shard[shard_start:shard_end]
            # Update the master weight using gradient descent
            master_weight -= grad * self.lr

        # Step 3: Cast master weights to FP8 or BF16 precision
        if isinstance(self.weights[0], Float8Tensor):
            local_weights = []
            for model_weight, local_weight in zip(self.weights, self.local_weights):
                if local_weight is None:
                    local_weights.append(None)
                    continue
                quantizer = model_weight._get_quantizer()
                if isinstance(quantizer, Float8CurrentScalingQuantizer):
                    local_weight = quantizer.create_tensor_from_data(
                        local_weight.view(-1),
                        model_weight.dtype,
                    )
                local_weights.append(local_weight)
            cast_master_weights_to_fp8(
                self.weights,
                self.master_weights,
                [idx[0] for idx in self.weight_indices],
                self.dp_group,
                local_weights,
            )
        else:
            for local_weight, master_weight in zip(self.local_weights, self.master_weights):
                if master_weight is None:
                    continue
                # Copy updated master weights to local weights
                local_weight.data.copy_(master_weight)

        # Step 4: All-gather updated weights across processes
        dist.all_gather_into_tensor(
            self.flatten_weight, self.local_weight_shard, group=self.dp_group
        )


def _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group):
    rank = dist.get_rank(dp_group)
    world_size = dist.get_world_size(dp_group)

    # Configuration constants
    NUM_STEPS = 100
    SEED = 12345

    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)

    mock_groups = [dist.new_group(ranks=[i]) for i in range(world_size)]
    mock_group = mock_groups[rank]

    linear_kwargs = {
        "params_dtype": torch.bfloat16,
        "bias": False,
        "fuse_wgrad_accumulation": False,
    }

    # Create model with FP8 weights
    with te.fp8.fp8_model_init(
        enabled=quantization is not None,
        recipe=quantization_recipe(quantization),
        preserve_high_precision_init_val=True,
    ):
        model_fp8 = nn.Sequential(
            te.Linear(128, 256, **linear_kwargs),
            te.Linear(256, 256 * 3, **linear_kwargs),
            te.Linear(256 * 3, 128, **linear_kwargs),
        )

    # Create model with BF16 weights
    model = nn.Sequential(
        te.Linear(128, 256, **linear_kwargs),
        te.Linear(256, 256 * 3, **linear_kwargs),
        te.Linear(256 * 3, 128, **linear_kwargs),
    )

    # Make sure the BF16 model and FP8 model have the same initial weights
    for w_fp8, w in zip(model_fp8.parameters(), model.parameters()):
        high_precision_init_val = w_fp8.get_high_precision_init_val()
        w.data.copy_(high_precision_init_val)

    optimizer_fp8 = MiniFSDP([w for w in model_fp8.parameters()], 10.0, dp_group)
    optimizer = MiniFSDP([w for w in model.parameters()], 10.0, dp_group)

    for _ in range(NUM_STEPS):
        optimizer_fp8.zero_grad()
        optimizer.zero_grad()

        inputs = [
            torch.randn(16, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size)
        ]
        # Choose based on rank to make sure the inputs of different ranks are different.
        x = inputs[rank]

        with te.fp8.fp8_autocast(
            enabled=quantization is not None,
            fp8_recipe=quantization_recipe(quantization),
            fp8_group=mock_group,
        ):
            y_fp8 = model_fp8(x)

        with te.fp8_autocast(
            enabled=quantization is not None,
            fp8_recipe=quantization_recipe(quantization),
            fp8_group=mock_group,
        ):
            y = model(x)

        targets = [torch.randn_like(y) for _ in range(world_size)]
        # Choose based on rank to make sure the targets of different ranks are different.
        target = targets[rank]

        loss_fp8 = nn.MSELoss()(y_fp8, target)
        loss = nn.MSELoss()(y, target)

        loss_fp8.backward()
        loss.backward()

        optimizer_fp8.step()
        optimizer.step()

        torch.testing.assert_close(loss_fp8, loss, atol=0, rtol=0)

    print(
        f"✅ Successfully validated {NUM_STEPS} FSDP training steps with"
        f" {quantization} quantization"
    )


def _test_zero_1(dp_group):
    """Make sure the implementation of zero-1 optimizer is correct"""
    rank = dist.get_rank(dp_group)
@@ -389,6 +666,7 @@ def main(argv=None, namespace=None):
    dp_group = dist.new_group(backend="nccl")

    _test_zero_1(dp_group)
    _test_cast_master_weights_to_fp8(args.quantization, dp_group)
    _test_fsdp_cast_master_weights_to_fp8(args.quantization, dp_group)

    dist.destroy_process_group()
    return 0
...
@@ -38,7 +38,9 @@ def replace_raw_data(tensor: QuantizedTensor, new_raw_data: torch.Tensor):
        raise ValueError(f"replace_raw_data for {type(tensor)} is not supported yet")


def cast_master_weights_to_fp8(
    model_weights, master_weights, start_offsets, group, fsdp_shard_model_weights=None
):
    r"""Helper function to cast master weights to FP8 primary weights.

    This is intended for use with ZeRO/FSDP. Each rank has a shard of
@@ -55,14 +57,23 @@ def cast_master_weights_to_fp8(model_weights, master_weights, start_offsets, gro
        should be updated.
    group : The distributed group to do amax reduction. Typically it's the data parallel
        group.
    fsdp_shard_model_weights : Optional list of FSDP shard model weights. If None, the model
        weights are not sharded and each weight's own storage is updated in place. Otherwise,
        the model weights are sharded and the quantized results are written into the storage
        provided by these shard weights.
""" """
delayed_scaling_params = [] delayed_scaling_params = []
current_scaling_params = [] current_scaling_params = []
for model_weight, master_weight, start_offset in zip( if fsdp_shard_model_weights is None:
model_weights, master_weights, start_offsets use_fsdp_shard_model_weights = False
fsdp_shard_model_weights = [None] * len(model_weights)
else:
use_fsdp_shard_model_weights = True
for model_weight, master_weight, start_offset, fsdp_shard_model_weight in zip(
model_weights, master_weights, start_offsets, fsdp_shard_model_weights
): ):
# Clear `_high_precision_init_val` of model_weight automatically. # Clear `_high_precision_init_val` of model_weight automatically.
# - Master weights are initialized from model weights, if we use fp8 primary weights to # - Master weights are initialized from model weights, if we use fp8 primary weights to
...@@ -88,9 +99,13 @@ def cast_master_weights_to_fp8(model_weights, master_weights, start_offsets, gro ...@@ -88,9 +99,13 @@ def cast_master_weights_to_fp8(model_weights, master_weights, start_offsets, gro
quantizer = model_weight._get_quantizer() quantizer = model_weight._get_quantizer()
if isinstance(quantizer, Float8Quantizer): if isinstance(quantizer, Float8Quantizer):
delayed_scaling_params.append((model_weight, master_weight, start_offset)) delayed_scaling_params.append(
(model_weight, master_weight, start_offset, fsdp_shard_model_weight)
)
elif isinstance(quantizer, Float8CurrentScalingQuantizer): elif isinstance(quantizer, Float8CurrentScalingQuantizer):
current_scaling_params.append((model_weight, master_weight, start_offset)) current_scaling_params.append(
(model_weight, master_weight, start_offset, fsdp_shard_model_weight)
)
elif isinstance(quantizer, MXFP8Quantizer): elif isinstance(quantizer, MXFP8Quantizer):
raise NotImplementedError( raise NotImplementedError(
"cast_master_weights_to_fp8 for MXFP8BlockScaling is not supported yet" "cast_master_weights_to_fp8 for MXFP8BlockScaling is not supported yet"
...@@ -101,12 +116,16 @@ def cast_master_weights_to_fp8(model_weights, master_weights, start_offsets, gro ...@@ -101,12 +116,16 @@ def cast_master_weights_to_fp8(model_weights, master_weights, start_offsets, gro
) )
if len(delayed_scaling_params) > 0: if len(delayed_scaling_params) > 0:
_cast_master_weights_to_fp8_delayed_scaling(delayed_scaling_params, group) _cast_master_weights_to_fp8_delayed_scaling(
delayed_scaling_params, group, use_fsdp_shard_model_weights
)
if len(current_scaling_params) > 0: if len(current_scaling_params) > 0:
_cast_master_weights_to_fp8_current_scaling(current_scaling_params, group) _cast_master_weights_to_fp8_current_scaling(
current_scaling_params, group, use_fsdp_shard_model_weights
)
def _cast_master_weights_to_fp8_delayed_scaling(params, group): def _cast_master_weights_to_fp8_delayed_scaling(params, group, use_fsdp_shard_model_weights=False):
r"""Helper function to cast master weights to FP8 primary weights for delayed scaling. r"""Helper function to cast master weights to FP8 primary weights for delayed scaling.
Parameters Parameters
...@@ -115,13 +134,14 @@ def _cast_master_weights_to_fp8_delayed_scaling(params, group): ...@@ -115,13 +134,14 @@ def _cast_master_weights_to_fp8_delayed_scaling(params, group):
indicating the starting index of the master weight in the model weight. indicating the starting index of the master weight in the model weight.
group : The distributed group to do amax reduction. Typically it's the data parallel group : The distributed group to do amax reduction. Typically it's the data parallel
group. group.
    use_fsdp_shard_model_weights : bool, if True, the model weights are FSDP-sharded and the
        shard views carried in params are used as the cast destinations.
""" """
# Collect amaxes to do reduce-max among dp group. # Collect amaxes to do reduce-max among dp group.
# Collect scales and scale_invs to update scale_invs of the fp8 weights. # Collect scales and scale_invs to update scale_invs of the fp8 weights.
amaxes, scales, scale_invs = [], [], [] amaxes, scales, scale_invs = [], [], []
for model_weight, master_weight, start_offset in params: for model_weight, master_weight, start_offset, shard_model_weight_raw in params:
# Reset transpose cache for all model weights. # Reset transpose cache for all model weights.
# We cannot create transpose cache here because users (like megatron) may want to overlap # We cannot create transpose cache here because users (like megatron) may want to overlap
# the all-gather of model weights and forward process, so the model weight is not updated # the all-gather of model weights and forward process, so the model weight is not updated
...@@ -147,6 +167,7 @@ def _cast_master_weights_to_fp8_delayed_scaling(params, group): ...@@ -147,6 +167,7 @@ def _cast_master_weights_to_fp8_delayed_scaling(params, group):
# master_weight may be smaller than model_weight because it could be distributed across # master_weight may be smaller than model_weight because it could be distributed across
# multiple ranks. So we need to create a dummy weight using the raw data from model_weight. # multiple ranks. So we need to create a dummy weight using the raw data from model_weight.
if not use_fsdp_shard_model_weights:
shard_model_weight_raw = model_weight._data.view(-1)[start_offset:end_offset] shard_model_weight_raw = model_weight._data.view(-1)[start_offset:end_offset]
shard_model_weight_fp8 = quantizer.create_tensor_from_data( shard_model_weight_fp8 = quantizer.create_tensor_from_data(
shard_model_weight_raw.view(1, -1), shard_model_weight_raw.view(1, -1),
...@@ -186,7 +207,7 @@ def _cast_master_weights_to_fp8_delayed_scaling(params, group): ...@@ -186,7 +207,7 @@ def _cast_master_weights_to_fp8_delayed_scaling(params, group):
) )
def _cast_master_weights_to_fp8_current_scaling(params, group): def _cast_master_weights_to_fp8_current_scaling(params, group, use_fsdp_shard_model_weights=False):
r"""Helper function to cast master weights to FP8 primary weights for current scaling. r"""Helper function to cast master weights to FP8 primary weights for current scaling.
Parameters Parameters
...@@ -195,6 +216,7 @@ def _cast_master_weights_to_fp8_current_scaling(params, group): ...@@ -195,6 +216,7 @@ def _cast_master_weights_to_fp8_current_scaling(params, group):
indicating the starting index of the master weight in the model weight. indicating the starting index of the master weight in the model weight.
group : The distributed group to do amax reduction. Typically it's the data parallel group : The distributed group to do amax reduction. Typically it's the data parallel
group. group.
    use_fsdp_shard_model_weights : bool, if True, the model weights are FSDP-sharded and the
        shard views carried in params are used as the cast destinations.
""" """
# Parameter attributes # Parameter attributes
...@@ -219,7 +241,7 @@ def _cast_master_weights_to_fp8_current_scaling(params, group): ...@@ -219,7 +241,7 @@ def _cast_master_weights_to_fp8_current_scaling(params, group):
# amaxes in a contiguous buffer. If the master weight is None, the corresponding amax # amaxes in a contiguous buffer. If the master weight is None, the corresponding amax
# will be set to 0. # will be set to 0.
# --------------------------------------------------------------------------------------------- # ---------------------------------------------------------------------------------------------
for (model_weight, master_weight, _), amax in zip(params, amaxes): for (model_weight, master_weight, _, _), amax in zip(params, amaxes):
# Make sure all the model weights have the same numerical options. # Make sure all the model weights have the same numerical options.
quantizer = model_weight._get_quantizer() quantizer = model_weight._get_quantizer()
...@@ -260,7 +282,9 @@ def _cast_master_weights_to_fp8_current_scaling(params, group): ...@@ -260,7 +282,9 @@ def _cast_master_weights_to_fp8_current_scaling(params, group):
# --------------------------------------------------------------------------------------------- # ---------------------------------------------------------------------------------------------
# Step 4: Cast master weights to FP8. # Step 4: Cast master weights to FP8.
# --------------------------------------------------------------------------------------------- # ---------------------------------------------------------------------------------------------
for (model_weight, master_weight, start_offset), scale in zip(params, scales): for (model_weight, master_weight, start_offset, model_weight_fragment), scale in zip(
params, scales
):
# Reset transpose cache for all model weights. # Reset transpose cache for all model weights.
# We cannot create transpose cache here because users (like megatron) may want to overlap # We cannot create transpose cache here because users (like megatron) may want to overlap
# the all-gather of model weights and forward process, so the model weight is not updated # the all-gather of model weights and forward process, so the model weight is not updated
...@@ -274,6 +298,7 @@ def _cast_master_weights_to_fp8_current_scaling(params, group): ...@@ -274,6 +298,7 @@ def _cast_master_weights_to_fp8_current_scaling(params, group):
# Cast master weight to FP8 # Cast master weight to FP8
end_offset = start_offset + master_weight.numel() end_offset = start_offset + master_weight.numel()
if not use_fsdp_shard_model_weights:
model_weight_fragment = model_weight.reshape(-1)[start_offset:end_offset] model_weight_fragment = model_weight.reshape(-1)[start_offset:end_offset]
quantizer = Float8Quantizer( quantizer = Float8Quantizer(
scale=scale, scale=scale,
......