Unverified commit 2712bb95 authored by Kunlun Li, committed by GitHub

Add post-processing API for FP8 primary weights to support CUDA Graph (#2266)



* Add post-processing API for FP8 primary weights to support CUDA Graph
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Add post-processing support for plain pytorch tensors
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Update type hint
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

---------
Signed-off-by: kunlunl <kunlunl@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent ce2f9fa4
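
The new `post_all_gather_processing` helper is meant to be called by a distributed optimizer once the all-gathered FP8 raw data has been copied back into the model weights, so the transposed view (Float8Tensor) or column-wise storage (Float8BlockwiseQTensor) is rebuilt eagerly instead of lazily during the next forward pass. A minimal sketch of that call site, modeled on the test changes below; `weights` and `gather_raw_fp8_data` are illustrative names rather than library API, and the import paths follow the ones used in the diff:

    from transformer_engine.pytorch.tensor import QuantizedTensor
    from transformer_engine.pytorch.tensor.utils import post_all_gather_processing

    def finish_weight_update(weights, gather_raw_fp8_data):
        # Copy the all-gathered raw FP8 bytes into each weight's storage
        # (illustrative helper standing in for the optimizer's own buffer logic).
        gather_raw_fp8_data(weights)
        # Rebuild transpose / column-wise data once per update, after the gather.
        # Plain torch.Tensor entries are a no-op inside the API.
        fp8_weights = [w for w in weights if isinstance(w, QuantizedTensor)]
        post_all_gather_processing(fp8_weights)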
@@ -27,7 +27,7 @@ from transformer_engine.pytorch import (
     Float8BlockwiseQTensor,
 )
 from transformer_engine.pytorch.tensor import cast_master_weights_to_fp8
-from transformer_engine.pytorch.tensor.utils import replace_raw_data
+from transformer_engine.pytorch.tensor.utils import post_all_gather_processing, replace_raw_data


 def _get_raw_data(quantized_tensor):
@@ -203,12 +203,15 @@ class MiniZero_1:
         # -----------------------------------------------------------------------------------------
         # Step 7: Copy the gathered weights from weight buffer to the actual weights
         # -----------------------------------------------------------------------------------------
+        quantized_weights = []
         for weight, offset in zip(self.weights, self.offsets[:-1]):
             start = offset
             end = offset + weight.numel()
             if isinstance(weight, QuantizedTensor):
+                quantized_weights.append(weight)
                 weight = _get_raw_data(weight)
             weight.view(-1).data.copy_(self.weight_buffer[start:end])
+        post_all_gather_processing(quantized_weights)


 class MiniOptimizer:
@@ -252,10 +255,6 @@ class MiniFSDP:
         self.dp_group = dp_group

         # Flatten the weights and pad to align with world size
-        raw_data_list = [
-            _get_raw_data(w).view(-1) if isinstance(w, QuantizedTensor) else w.view(-1)
-            for w in weights
-        ]
         if isinstance(weights[0], QuantizedTensor):
             raw_data_list = [_get_raw_data(w).view(-1) for w in weights]
         else:
@@ -264,7 +263,9 @@
         # Split flattened weights into shards
         self.local_weight_shard = torch.chunk(self.flatten_weight, world_size)[rank]
-        self.local_main_grad_shard = torch.zeros_like(self.local_weight_shard)
+        self.local_main_grad_shard = torch.zeros_like(
+            self.local_weight_shard, dtype=torch.float32, device="cuda"
+        )
         shard_size = self.flatten_weight.size(0) // world_size

         # Map original tensors to flattened indices
@@ -341,9 +342,8 @@
         padding_needed = (world_size - original_length % world_size) % world_size
         if padding_needed > 0:
-            flatten_tensor = torch.cat(
-                [flatten_tensor, torch.zeros(padding_needed, dtype=flatten_tensor.dtype)]
-            )
+            zeros = torch.zeros(padding_needed, dtype=flatten_tensor.dtype, device="cuda")
+            flatten_tensor = torch.cat([flatten_tensor, zeros])
         return flatten_tensor, original_length
@@ -369,10 +369,10 @@
         main_grad_buffer, _ = self._flatten_tensors_with_pad(
             [weight.main_grad.view(-1) for weight in self.weights]
         )
-        main_grad_buffer = main_grad_buffer.to(self.local_main_grad_shard.dtype)
         dist.reduce_scatter_tensor(
             self.local_main_grad_shard, main_grad_buffer, group=self.dp_group
         )
+        self.local_main_grad_shard /= dist.get_world_size(self.dp_group)

         # Step 2: Update the master weights
         for weight, master_weight, (shard_start, shard_end) in zip(
@@ -416,6 +416,11 @@ class MiniFSDP:
         dist.all_gather_into_tensor(
             self.flatten_weight, self.local_weight_shard, group=self.dp_group
         )
+        quantized_weights = []
+        for weight in self.weights:
+            if isinstance(weight, QuantizedTensor):
+                quantized_weights.append(weight)
+        post_all_gather_processing(quantized_weights)


 def _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group):
@@ -435,7 +440,7 @@ def _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group):
     linear_kwargs = {
         "params_dtype": torch.bfloat16,
         "bias": False,
-        "fuse_wgrad_accumulation": False,
+        "fuse_wgrad_accumulation": True,
     }

     # Create model with FP8 weights
@@ -503,14 +508,9 @@ def _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group):
         torch.testing.assert_close(loss_fp8, loss, atol=0, rtol=0)

-    print(
-        f"✅ Successfully validated FSDP {NUM_STEPS} training steps with"
-        f" {quantization} quantization"
-    )

-def _test_zero_1(dp_group):
-    """Make sure the implementation of zero-1 optimizer is correct"""
+def _test_mini_optimizer(dp_group):
+    """Make sure the implementation of MiniZero_1 and MiniFSDP is correct"""
     rank = dist.get_rank(dp_group)
     world_size = dist.get_world_size(dp_group)
@@ -525,13 +525,15 @@ def _test_zero_1(dp_group):
     weights_1 = weights
     weights_2 = [weight.clone() for weight in weights]
+    weights_3 = [weight.clone() for weight in weights]

     lr = 1.0
     optimizer_1 = MiniZero_1(weights_1, lr, dp_group)
     optimizer_2 = MiniOptimizer(weights_2, lr, dp_group)
+    optimizer_3 = MiniFSDP(weights_3, lr, dp_group)

     for _ in range(100):
-        for w1, w2 in zip(weights_1, weights_2):
+        for w1, w2, w3 in zip(weights_1, weights_2, weights_3):
             main_grads = [
                 torch.randn_like(w1, dtype=torch.float32, device="cuda") for _ in range(world_size)
             ]
@@ -539,12 +541,16 @@
             main_grad = main_grads[rank]
             w1.main_grad = main_grad
             w2.main_grad = main_grad
+            w3.main_grad = main_grad

         optimizer_1.step()
         optimizer_2.step()
+        optimizer_3.step()

     for w1, w2 in zip(weights_1, weights_2):
         torch.testing.assert_close(w1, w2, atol=0, rtol=0)
+    for w1, w3 in zip(weights_1, weights_3):
+        torch.testing.assert_close(w1, w3, atol=0, rtol=0)


 def quantization_recipe(quantization) -> Recipe:
@@ -671,7 +677,7 @@ def main(argv=None, namespace=None):
     args = parser.parse_args(argv, namespace)

     dp_group = dist.new_group(backend="nccl")
-    _test_zero_1(dp_group)
+    _test_mini_optimizer(dp_group)
     _test_cast_master_weights_to_fp8(args.quantization, dp_group)
     _test_fsdp_cast_master_weights_to_fp8(args.quantization, dp_group)
...
@@ -5,7 +5,7 @@
 """Helper functions for using fp8 tensors as weights"""

 import os
-from typing import Optional, Union
+from typing import Optional, List, Union

 import torch
 import transformer_engine_torch as tex
 from transformer_engine_torch import multi_tensor_scale, multi_tensor_compute_scale_and_scale_inv
@@ -15,6 +15,7 @@ from .float8_tensor import Float8Tensor, Float8Quantizer, Float8CurrentScalingQu
 from .mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer
 from .float8_blockwise_tensor import Float8BlockwiseQTensor, Float8BlockQuantizer
 from ..optimizers.multi_tensor_apply import multi_tensor_applier
+from ..utils import is_non_tn_fp8_gemm_supported


 def replace_raw_data(tensor: QuantizedTensor, new_raw_data: torch.Tensor):
@@ -159,12 +160,6 @@ def _cast_master_weights_to_fp8_delayed_scaling(params, group, use_fsdp_shard_mo
     amaxes, scales, scale_invs = [], [], []
     for model_weight, master_weight, start_offset, shard_model_weight_raw in params:
-        # Reset transpose cache for all model weights.
-        # We cannot create transpose cache here because users (like megatron) may want to overlap
-        # the all-gather of model weights and forward process, so the model weight is not updated
-        # currently.
-        model_weight._reset_caches()
-
         quantizer = model_weight._get_quantizer()
         amaxes.append(quantizer.amax.view(1))
@@ -302,12 +297,6 @@ def _cast_master_weights_to_fp8_current_scaling(params, group, use_fsdp_shard_mo
     for (model_weight, master_weight, start_offset, model_weight_fragment), scale in zip(
         params, scales
     ):
-        # Reset transpose cache for all model weights.
-        # We cannot create transpose cache here because users (like megatron) may want to overlap
-        # the all-gather of model weights and forward process, so the model weight is not updated
-        # currently.
-        model_weight._reset_caches()
-
         # If master weight is None, it means that the master weight of the current model weight
         # is in other DP ranks.
         if master_weight is None:
@@ -432,12 +421,6 @@ def _cast_master_weights_to_fp8_blockwise_scaling(
     for (model_weight, master_weight, start_offset, model_weight_fragment), scale in zip(
         params, scales
     ):
-        # Clear columnwise data for all model weights.
-        # We cannot create columnwise data here because users (like megatron) may want to overlap
-        # the all-gather of model weights and forward process, so the model weight is not updated
-        # at this moment.
-        model_weight.update_usage(rowwise_usage=True, columnwise_usage=False)
-
         # If master weight is None, it means that the master weight of the current model weight
         # is in other DP ranks.
         if master_weight is None:
@@ -454,6 +437,28 @@
     )


+def post_all_gather_processing(model_weights: Union[torch.Tensor, List[torch.Tensor]]):
+    """
+    Post-processing after all-gather for weights in distributed optimizer.
+    - Float8Tensor: may need to create a transposed view to match backend GEMM.
+    - Float8BlockwiseQTensor: create column-wise storage.
+    - Plain pytorch tensor: noop.
+    """
+    if not isinstance(model_weights, list):
+        model_weights = [model_weights]
+
+    for model_weight in model_weights:
+        if isinstance(model_weight, Float8Tensor):
+            # Delayed scaling and per-tensor current scaling: if backend does not support
+            # non-transposed FP8 GEMM, pre-create the transpose.
+            if not is_non_tn_fp8_gemm_supported():
+                model_weight._create_transpose()
+        elif isinstance(model_weight, Float8BlockwiseQTensor):
+            # Blockwise scaling: create column-wise storage.
+            model_weight._create_columnwise()
+        elif isinstance(model_weight, QuantizedTensor):
+            raise ValueError(f"post_processing for {type(model_weight)} is not supported")
+
+
 def is_experimental(x: Optional[Union[Quantizer, QuantizedTensorStorage]] = None) -> bool:
     """Check if an environment or object is using experimental Kitchen middleware.
...
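
The commit title ties this API to CUDA Graphs: the cast helpers above no longer reset or clear the transpose/column-wise caches themselves, so the storage-allocating work for FP8 weights can happen in one explicit call after the all-gather, outside whatever region a user captures. One plausible per-iteration flow under that reading (a sketch only; the graph capture itself is not part of this change, and `update_and_gather_weights`, `static_input`, and `new_batch` are illustrative names):

    import torch
    from transformer_engine.pytorch.tensor import QuantizedTensor
    from transformer_engine.pytorch.tensor.utils import post_all_gather_processing

    def training_iteration(graph: torch.cuda.CUDAGraph, static_input, new_batch, params,
                           update_and_gather_weights):
        # Optimizer step plus all-gather of the updated FP8 shards (illustrative helper).
        update_and_gather_weights(params)
        # Rebuild transpose / column-wise storage before replay, so the captured
        # region never needs to allocate or reshape weight storage.
        post_all_gather_processing([p for p in params if isinstance(p, QuantizedTensor)])
        # Refill the static capture buffer and replay the pre-captured iteration.
        static_input.copy_(new_batch)
        graph.replay()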