Unverified Commit 2712bb95 authored by Kunlun Li, committed by GitHub

Add post-processing API for FP8 primary weights to support CUDA Graph (#2266)



* Add post-processing API for FP8 primary weights to support CUDA Graph
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Add post-processing support for plain pytorch tensors
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Update type hint
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

---------
Signed-off-by: kunlunl <kunlunl@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent ce2f9fa4
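
Before the diff, a minimal sketch of how the new API is meant to slot into a ZeRO-1-style optimizer step, mirroring Step 7 of the MiniZero_1 test below. Only post_all_gather_processing and its import path come from this commit; the function name, arguments, and the _get_raw_data stand-in (the test defines its own helper) are illustrative assumptions.

    import torch
    from transformer_engine.pytorch.tensor import QuantizedTensor
    from transformer_engine.pytorch.tensor.utils import post_all_gather_processing

    def _get_raw_data(quantized_tensor):
        # Stand-in for the test's helper: the raw FP8 byte storage.
        # Assumption: for Float8Tensor this is the `_data` attribute.
        return quantized_tensor._data

    def copy_gathered_weights(weights, weight_buffer, offsets):
        # Copy the all-gathered bytes into each weight's raw FP8 storage,
        # collecting the quantized weights along the way.
        quantized_weights = []
        for weight, offset in zip(weights, offsets[:-1]):
            start, end = offset, offset + weight.numel()
            if isinstance(weight, QuantizedTensor):
                quantized_weights.append(weight)
                weight = _get_raw_data(weight)
            weight.view(-1).data.copy_(weight_buffer[start:end])
        # One explicit call creates transposes / column-wise storage for all FP8
        # weights, instead of resetting caches and rebuilding them lazily later.
        post_all_gather_processing(quantized_weights)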
@@ -27,7 +27,7 @@ from transformer_engine.pytorch import (
     Float8BlockwiseQTensor,
 )
 from transformer_engine.pytorch.tensor import cast_master_weights_to_fp8
-from transformer_engine.pytorch.tensor.utils import replace_raw_data
+from transformer_engine.pytorch.tensor.utils import post_all_gather_processing, replace_raw_data


 def _get_raw_data(quantized_tensor):
@@ -203,12 +203,15 @@ class MiniZero_1:
         # -----------------------------------------------------------------------------------------
         # Step 7: Copy the gathered weights from weight buffer to the actual weights
         # -----------------------------------------------------------------------------------------
+        quantized_weights = []
         for weight, offset in zip(self.weights, self.offsets[:-1]):
             start = offset
             end = offset + weight.numel()
             if isinstance(weight, QuantizedTensor):
+                quantized_weights.append(weight)
                 weight = _get_raw_data(weight)
             weight.view(-1).data.copy_(self.weight_buffer[start:end])
+        post_all_gather_processing(quantized_weights)


 class MiniOptimizer:
@@ -252,10 +255,6 @@ class MiniFSDP:
         self.dp_group = dp_group

         # Flatten the weights and pad to align with world size
-        raw_data_list = [
-            _get_raw_data(w).view(-1) if isinstance(w, QuantizedTensor) else w.view(-1)
-            for w in weights
-        ]
         if isinstance(weights[0], QuantizedTensor):
             raw_data_list = [_get_raw_data(w).view(-1) for w in weights]
         else:
@@ -264,7 +263,9 @@

         # Split flattened weights into shards
         self.local_weight_shard = torch.chunk(self.flatten_weight, world_size)[rank]
-        self.local_main_grad_shard = torch.zeros_like(self.local_weight_shard)
+        self.local_main_grad_shard = torch.zeros_like(
+            self.local_weight_shard, dtype=torch.float32, device="cuda"
+        )
         shard_size = self.flatten_weight.size(0) // world_size

         # Map original tensors to flattened indices
@@ -341,9 +342,8 @@ class MiniFSDP:

         padding_needed = (world_size - original_length % world_size) % world_size
         if padding_needed > 0:
-            flatten_tensor = torch.cat(
-                [flatten_tensor, torch.zeros(padding_needed, dtype=flatten_tensor.dtype)]
-            )
+            zeros = torch.zeros(padding_needed, dtype=flatten_tensor.dtype, device="cuda")
+            flatten_tensor = torch.cat([flatten_tensor, zeros])
         return flatten_tensor, original_length
@@ -369,10 +369,10 @@ class MiniFSDP:
         main_grad_buffer, _ = self._flatten_tensors_with_pad(
             [weight.main_grad.view(-1) for weight in self.weights]
         )
+        main_grad_buffer = main_grad_buffer.to(self.local_main_grad_shard.dtype)
         dist.reduce_scatter_tensor(
             self.local_main_grad_shard, main_grad_buffer, group=self.dp_group
         )
-        self.local_main_grad_shard /= dist.get_world_size(self.dp_group)

         # Step 2: Update the master weights
         for weight, master_weight, (shard_start, shard_end) in zip(
@@ -416,6 +416,11 @@ class MiniFSDP:
         dist.all_gather_into_tensor(
             self.flatten_weight, self.local_weight_shard, group=self.dp_group
         )
+        quantized_weights = []
+        for weight in self.weights:
+            if isinstance(weight, QuantizedTensor):
+                quantized_weights.append(weight)
+        post_all_gather_processing(quantized_weights)


 def _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group):
@@ -435,7 +440,7 @@ def _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group):
     linear_kwargs = {
         "params_dtype": torch.bfloat16,
         "bias": False,
-        "fuse_wgrad_accumulation": False,
+        "fuse_wgrad_accumulation": True,
     }

     # Create model with FP8 weights
@@ -503,14 +508,9 @@ def _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group):
     torch.testing.assert_close(loss_fp8, loss, atol=0, rtol=0)

     print(
         f"✅ Successfully validated FSDP {NUM_STEPS} training steps with"
         f" {quantization} quantization"
     )


-def _test_zero_1(dp_group):
-    """Make sure the implementation of zero-1 optimizer is correct"""
+def _test_mini_optimizer(dp_group):
+    """Make sure the implementation of MiniZero_1 and MiniFSDP is correct"""
     rank = dist.get_rank(dp_group)
     world_size = dist.get_world_size(dp_group)
@@ -525,13 +525,15 @@ def _test_zero_1(dp_group):

     weights_1 = weights
     weights_2 = [weight.clone() for weight in weights]
+    weights_3 = [weight.clone() for weight in weights]

     lr = 1.0
     optimizer_1 = MiniZero_1(weights_1, lr, dp_group)
     optimizer_2 = MiniOptimizer(weights_2, lr, dp_group)
+    optimizer_3 = MiniFSDP(weights_3, lr, dp_group)

     for _ in range(100):
-        for w1, w2 in zip(weights_1, weights_2):
+        for w1, w2, w3 in zip(weights_1, weights_2, weights_3):
             main_grads = [
                 torch.randn_like(w1, dtype=torch.float32, device="cuda") for _ in range(world_size)
             ]
@@ -539,12 +541,16 @@
             main_grad = main_grads[rank]
             w1.main_grad = main_grad
             w2.main_grad = main_grad
+            w3.main_grad = main_grad

         optimizer_1.step()
         optimizer_2.step()
+        optimizer_3.step()

         for w1, w2 in zip(weights_1, weights_2):
             torch.testing.assert_close(w1, w2, atol=0, rtol=0)
+        for w1, w3 in zip(weights_1, weights_3):
+            torch.testing.assert_close(w1, w3, atol=0, rtol=0)


 def quantization_recipe(quantization) -> Recipe:
@@ -671,7 +677,7 @@ def main(argv=None, namespace=None):
     args = parser.parse_args(argv, namespace)

     dp_group = dist.new_group(backend="nccl")

-    _test_zero_1(dp_group)
+    _test_mini_optimizer(dp_group)
     _test_cast_master_weights_to_fp8(args.quantization, dp_group)
     _test_fsdp_cast_master_weights_to_fp8(args.quantization, dp_group)
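
The CUDA Graph motivation in the title, as a hedged sketch: now that transpose/column-wise creation is an explicit call rather than a lazy cache reset inside the cast, the whole weight-update path can in principle be captured and replayed. torch.cuda.CUDAGraph and torch.cuda.graph are standard PyTorch APIs; the optimizer object and the contents of its step() are hypothetical.

    import torch

    def capture_update(optimizer, num_replays):
        # Eager warm-up step so workspaces and communication buffers are
        # allocated at stable addresses before capture.
        optimizer.step()
        torch.cuda.synchronize()

        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            # cast master weights -> all-gather -> copy -> post_all_gather_processing
            optimizer.step()

        for _ in range(num_replays):
            graph.replay()  # replays the captured update against fixed addresses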
transformer_engine/pytorch/tensor/utils.py
@@ -5,7 +5,7 @@
 """Helper functions for using fp8 tensors as weights"""

 import os
-from typing import Optional, Union
+from typing import Optional, List, Union
 import torch
 import transformer_engine_torch as tex
 from transformer_engine_torch import multi_tensor_scale, multi_tensor_compute_scale_and_scale_inv
@@ -15,6 +15,7 @@ from .float8_tensor import Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer
 from .mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer
 from .float8_blockwise_tensor import Float8BlockwiseQTensor, Float8BlockQuantizer
 from ..optimizers.multi_tensor_apply import multi_tensor_applier
+from ..utils import is_non_tn_fp8_gemm_supported


 def replace_raw_data(tensor: QuantizedTensor, new_raw_data: torch.Tensor):
@@ -159,12 +160,6 @@ def _cast_master_weights_to_fp8_delayed_scaling(params, group, use_fsdp_shard_model_weights):
     amaxes, scales, scale_invs = [], [], []

     for model_weight, master_weight, start_offset, shard_model_weight_raw in params:
-        # Reset transpose cache for all model weights.
-        # We cannot create transpose cache here because users (like megatron) may want to overlap
-        # the all-gather of model weights and forward process, so the model weight is not updated
-        # currently.
-        model_weight._reset_caches()
-
         quantizer = model_weight._get_quantizer()
         amaxes.append(quantizer.amax.view(1))
@@ -302,12 +297,6 @@ def _cast_master_weights_to_fp8_current_scaling(params, group, use_fsdp_shard_model_weights):
     for (model_weight, master_weight, start_offset, model_weight_fragment), scale in zip(
         params, scales
     ):
-        # Reset transpose cache for all model weights.
-        # We cannot create transpose cache here because users (like megatron) may want to overlap
-        # the all-gather of model weights and forward process, so the model weight is not updated
-        # currently.
-        model_weight._reset_caches()
-
         # If master weight is None, it means that the master weight of the current model weight
         # is in other DP ranks.
         if master_weight is None:
@@ -432,12 +421,6 @@ def _cast_master_weights_to_fp8_blockwise_scaling(
     for (model_weight, master_weight, start_offset, model_weight_fragment), scale in zip(
         params, scales
     ):
-        # Clear columnwise data for all model weights.
-        # We cannot create columnwise data here because users (like megatron) may want to overlap
-        # the all-gather of model weights and forward process, so the model weight is not updated
-        # at this moment.
-        model_weight.update_usage(rowwise_usage=True, columnwise_usage=False)
-
         # If master weight is None, it means that the master weight of the current model weight
         # is in other DP ranks.
         if master_weight is None:
@@ -454,6 +437,28 @@
     )


+def post_all_gather_processing(model_weights: Union[torch.Tensor, List[torch.Tensor]]):
+    """
+    Post-processing after all-gather for weights in distributed optimizer.
+    - Float8Tensor: may need to create a transposed view to match backend GEMM.
+    - Float8BlockwiseQTensor: create column-wise storage.
+    - Plain pytorch tensor: noop.
+    """
+    if not isinstance(model_weights, list):
+        model_weights = [model_weights]
+    for model_weight in model_weights:
+        if isinstance(model_weight, Float8Tensor):
+            # Delayed scaling and per-tensor current scaling: if backend does not support
+            # non-transposed FP8 GEMM, pre-create the transpose.
+            if not is_non_tn_fp8_gemm_supported():
+                model_weight._create_transpose()
+        elif isinstance(model_weight, Float8BlockwiseQTensor):
+            # Blockwise scaling: create column-wise storage.
+            model_weight._create_columnwise()
+        elif isinstance(model_weight, QuantizedTensor):
+            raise ValueError(f"post_processing for {type(model_weight)} is not supported")
+
+
 def is_experimental(x: Optional[Union[Quantizer, QuantizedTensorStorage]] = None) -> bool:
     """Check if an environment or object is using experimental Kitchen middleware.
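
As a quick reference for the new helper's accepted inputs, per its docstring above (the three arguments are placeholders for tensors created elsewhere, e.g. module parameters and a plain buffer):

    from transformer_engine.pytorch.tensor.utils import post_all_gather_processing

    def refresh_weights(w_fp8, w_blockwise, w_plain):
        post_all_gather_processing(w_fp8)                # single Float8Tensor: builds its transpose if the GEMM backend needs one
        post_all_gather_processing([w_fp8, w_blockwise]) # list form; Float8BlockwiseQTensor gets column-wise storage
        post_all_gather_processing(w_plain)              # plain torch.Tensor: a no-op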