diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 732f0a1..29f40bb 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -39,6 +39,7 @@ python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py" NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_paged_attn.py || test_fail "test_paged_attn.py" + if [ "$RET" -ne 0 ]; then echo "Error in the following test cases:$FAILED_CASES" exit 1 diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index 5776734..36d491e 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -26,6 +26,7 @@ python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py | python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py" # python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py" ### TODO Debug UB support with te.Sequential python3 -m pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py || test_fail "test_fused_attn_with_cp.py" +python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py" if [ "$RET" -ne 0 ]; then echo "Error in the following test cases:$FAILED_CASES" diff --git a/tests/pytorch/distributed/run_cast_master_weights_to_fp8.py b/tests/pytorch/distributed/run_cast_master_weights_to_fp8.py new file mode 100644 index 0000000..1b38f72 --- /dev/null +++ b/tests/pytorch/distributed/run_cast_master_weights_to_fp8.py @@ -0,0 +1,671 @@ +#!/usr/bin/python3 + +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +import argparse +import datetime +import os +import sys + +import torch +from torch import nn +import torch.distributed as dist + +from transformer_engine.common.recipe import ( + DelayedScaling, + Float8CurrentScaling, + Format, + Recipe, +) +import transformer_engine.pytorch as te +from transformer_engine.pytorch.tensor import QuantizedTensor, cast_master_weights_to_fp8 +from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8Tensor, + Float8CurrentScalingQuantizer, +) +from transformer_engine.pytorch.tensor.utils import replace_raw_data + + +def _get_raw_data(quantized_tensor): + """Get the underlying data of a quantized tensor, used in zero-1 optimizer""" + if isinstance(quantized_tensor, Float8Tensor): + assert hasattr(quantized_tensor, "_data"), "Float8Tensor does not have _data attribute" + assert quantized_tensor._data.dtype == torch.uint8, "Float8Tensor _data must be uint8" + return quantized_tensor._data + else: + raise ValueError(f"Unsupported quantized tensor type: {type(quantized_tensor)}") + + +class MiniZero_1: + """A mini zero-1 optimizer implementation, just used for this test""" + + def __init__(self, weights, lr, dp_group): + self.rank = dist.get_rank(dp_group) + self.world_size = dist.get_world_size(dp_group) + + self.weights = weights + self.lr = lr + self.dp_group = dp_group + + # [self.offsets[i], self.offsets[i+1]) is the range of weights[i] in the global buffer + self.offsets = [0] + for weight in self.weights: + self.offsets.append(self.offsets[-1] + weight.numel()) + + # Padding to avoid global buffer cannot be divided by world size, so the offsets[-1] may + # not be the end range of the last weight. + if self.offsets[-1] % self.world_size != 0: + self.offsets[-1] += self.world_size - self.offsets[-1] % self.world_size + + self.master_weights = [] + # The start offset of the master weight in the weight + self.start_offsets = [] + # The overlapping area of the weight and this rank's local buffer + self.overlapping_areas = [] + + # The start and end of this rank's local buffer in the global buffer + rank_start = self.offsets[-1] // self.world_size * self.rank + rank_end = rank_start + self.offsets[-1] // self.world_size + + for weight, offset in zip(self.weights, self.offsets[:-1]): + if offset >= rank_end or (offset + weight.numel()) <= rank_start: + # This weight is not in this rank's local buffer + master_weight = None + start_offset = None + overlapping_area = None + else: + overlapping_start = max(rank_start, offset) + overlapping_end = min(rank_end, offset + weight.numel()) + length = overlapping_end - overlapping_start + start_offset = overlapping_start - offset + if isinstance(weight, QuantizedTensor): + # If weight is a FP8 tensor, we need to use the original high precision version + # to initialize the master weight. 
+ high_precision_init_val = weight.get_high_precision_init_val().view(-1) + master_weight = high_precision_init_val.to(weight.device).float()[ + start_offset : start_offset + length + ] + else: + master_weight = ( + weight.detach().view(-1).float()[start_offset : start_offset + length] + ) + overlapping_area = (overlapping_start, overlapping_end) + self.master_weights.append(master_weight) + self.start_offsets.append(start_offset) + self.overlapping_areas.append(overlapping_area) + + # Create global buffer for grads reduce-scatter + self.grad_buffer = torch.empty( + [self.offsets[-1]], dtype=torch.float32, device=weights[0].device + ) + self.grad_buffer_slice = self.grad_buffer[rank_start:rank_end] + + # Create global buffer for weights all-gather + if isinstance(self.weights[0], QuantizedTensor): + weight_buffer_dtype = torch.uint8 + else: + weight_buffer_dtype = weights[0].dtype + self.weight_buffer = torch.empty( + [self.offsets[-1]], dtype=weight_buffer_dtype, device=weights[0].device + ) + self.weight_buffer_slice = self.weight_buffer[rank_start:rank_end] + + def step(self): + # ----------------------------------------------------------------------------------------- + # Step 1: Copy grads to the grad buffer + # ----------------------------------------------------------------------------------------- + for weight, offset in zip(self.weights, self.offsets[:-1]): + start = offset + end = offset + weight.numel() + self.grad_buffer[start:end].copy_(weight.main_grad.view(-1)) + + # ----------------------------------------------------------------------------------------- + # Step 2: Grads reduce-scatter + # ----------------------------------------------------------------------------------------- + # Don't use reduce_scatter directly to explicitly control the reduce order. + # dist.reduce_scatter_tensor(self.grad_buffer_slice, self.grad_buffer, op=dist.ReduceOp.AVG, + # group=self.dp_group) + buffers = [torch.empty_like(self.grad_buffer) for _ in range(self.world_size)] + dist.all_gather(buffers, self.grad_buffer, group=self.dp_group) + for i in range(1, self.world_size): + buffers[0] += buffers[i] + rank_start = self.offsets[-1] // self.world_size * self.rank + rank_end = rank_start + self.offsets[-1] // self.world_size + self.grad_buffer_slice.copy_(buffers[0][rank_start:rank_end]) + self.grad_buffer_slice /= self.world_size + + # ----------------------------------------------------------------------------------------- + # Step 3: Update master weights + # ----------------------------------------------------------------------------------------- + for master_weight, overlapping_area in zip(self.master_weights, self.overlapping_areas): + if master_weight is None: + # This weight's master weight is in other rank. 
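For reference, a small worked example of the sharding bookkeeping above (hypothetical sizes, assuming world_size = 2; not part of the test itself):

sizes = [6, 5]               # two hypothetical weight sizes
offsets = [0, 6, 11]         # cumulative offsets, padded to 12 so world_size = 2 divides it
# rank 0 owns global range [0, 6):  weight 0 -> start_offset 0, overlapping_area (0, 6);
#                                   weight 1 -> master_weight is None on this rank
# rank 1 owns global range [6, 12): weight 0 -> None;
#                                   weight 1 -> start_offset 0, overlapping_area (6, 11)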
+ continue + grad = self.grad_buffer[overlapping_area[0] : overlapping_area[1]] + master_weight -= grad * self.lr + + # ----------------------------------------------------------------------------------------- + # Step 4: Cast master weights to BF16 or FP8, depending on the type of the weight + # ----------------------------------------------------------------------------------------- + if isinstance(self.weights[0], QuantizedTensor): + # FP8 weights case + for i in range(1, len(self.weights)): + assert isinstance(self.weights[i], QuantizedTensor) + cast_master_weights_to_fp8( + self.weights, self.master_weights, self.start_offsets, self.dp_group + ) + else: + # BF16 weights case + for weight, master_weight, start_offset in zip( + self.weights, self.master_weights, self.start_offsets + ): + if master_weight is None: + continue + start = start_offset + end = start_offset + master_weight.numel() + weight.data.view(-1)[start:end].copy_(master_weight) + + # ----------------------------------------------------------------------------------------- + # Step 5: Copy the updated weights (not all weights) to the weight buffer + # ----------------------------------------------------------------------------------------- + for i in range(len(self.weights)): + master_weight = self.master_weights[i] + if master_weight is None: + continue + start_offset = self.start_offsets[i] + if isinstance(self.weights[i], QuantizedTensor): + weight = _get_raw_data(self.weights[i]) + else: + weight = self.weights[i] + weight_slice = weight.view(-1)[start_offset : start_offset + master_weight.numel()] + overlapping_start, overlapping_end = self.overlapping_areas[i] + self.weight_buffer[overlapping_start:overlapping_end].copy_(weight_slice) + + # ----------------------------------------------------------------------------------------- + # Step 6: Weight all-gather (FP8 or BF16) + # ----------------------------------------------------------------------------------------- + dist.all_gather_into_tensor( + self.weight_buffer, self.weight_buffer_slice, group=self.dp_group + ) + + # ----------------------------------------------------------------------------------------- + # Step 7: Copy the gathered weights from weight buffer to the actual weights + # ----------------------------------------------------------------------------------------- + for weight, offset in zip(self.weights, self.offsets[:-1]): + start = offset + end = offset + weight.numel() + if isinstance(weight, QuantizedTensor): + weight = _get_raw_data(weight) + weight.view(-1).data.copy_(self.weight_buffer[start:end]) + + +class MiniOptimizer: + + def __init__(self, weights, lr, dp_group): + self.world_size = dist.get_world_size(dp_group) + + self.weights = weights + self.lr = lr + self.dp_group = dp_group + + master_weights = [] + for weight in self.weights: + master_weights.append(weight.detach().float()) + self.master_weights = master_weights + + def step(self): + for weight, master_weight in zip(self.weights, self.master_weights): + main_grad = weight.main_grad + + # Don't use all-reduce directly to explicitly control the reduce order. 
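The explicit summation order matters because floating-point addition is not associative, and the test later compares the two optimizers bit-exactly (atol=0, rtol=0). A minimal illustration of the non-associativity being avoided:

a, b, c = 0.1, 0.2, 0.3
assert (a + b) + c != a + (b + c)   # 0.6000000000000001 vs 0.6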
+ # dist.all_reduce(main_grad, op=dist.ReduceOp.AVG, group=self.dp_group) + buffers = [torch.empty_like(main_grad) for _ in range(self.world_size)] + dist.all_gather(buffers, main_grad, group=self.dp_group) + for i in range(1, self.world_size): + buffers[0] += buffers[i] + main_grad.copy_(buffers[0]) + main_grad /= self.world_size + + master_weight -= main_grad * self.lr + weight.data.copy_(master_weight) + + +class MiniFSDP: + def __init__(self, weights, lr, dp_group): + rank = dist.get_rank(dp_group) + world_size = dist.get_world_size(dp_group) + + self.weights = weights + self.lr = lr + self.dp_group = dp_group + + # Flatten the weights and pad to align with world size + raw_data_list = [ + _get_raw_data(w).view(-1) if isinstance(w, QuantizedTensor) else w.view(-1) + for w in weights + ] + if isinstance(weights[0], QuantizedTensor): + raw_data_list = [_get_raw_data(w).view(-1) for w in weights] + else: + raw_data_list = [w.view(-1) for w in weights] + self.flatten_weight, original_length = self._flatten_tensors_with_pad(raw_data_list) + + # Split flattened weights into shards + self.local_weight_shard = torch.chunk(self.flatten_weight, world_size)[rank] + self.local_main_grad_shard = torch.zeros_like(self.local_weight_shard) + shard_size = self.flatten_weight.size(0) // world_size + + # Map original tensors to flattened indices + tensor_indices = [] + cumulative_length = 0 + for tensor in raw_data_list: + length = tensor.size(0) + tensor_indices.append((cumulative_length, cumulative_length + length)) + cumulative_length += length + + # Build shard index mappings + self.weight_indices = [] + self.shard_indices = [] + for idx, (start, end) in enumerate(tensor_indices): + shard_start = rank * shard_size + shard_end = shard_start + shard_size + adjusted_end = min(shard_end, original_length) + + if start <= adjusted_end and end >= shard_start: + start_idx = max(start, shard_start) + end_idx = min(end, adjusted_end) + self.weight_indices.append((start_idx - start, end_idx - start)) + self.shard_indices.append((start_idx - shard_start, end_idx - shard_start)) + else: + self.weight_indices.append((None, None)) + self.shard_indices.append((None, None)) + + if isinstance(weights[idx], QuantizedTensor): + replace_raw_data( + weights[idx], self.flatten_weight[start:end].view(weights[idx].shape) + ) + else: + weights[idx].data = self.flatten_weight[start:end].view(weights[idx].shape) + + # Initialize local model weights and high-precision master weights + self.local_weights = [] + self.master_weights = [] + for i, weight in enumerate(self.weights): + weight_start, weight_end = self.weight_indices[i] + shard_start, shard_end = self.shard_indices[i] + if shard_start is not None and shard_end is not None: + local_weight_shard = self.local_weight_shard[shard_start:shard_end] + self.local_weights.append(local_weight_shard) + + if isinstance(weight, QuantizedTensor): + high_precision_init_val = weight.get_high_precision_init_val().view(-1) + master_weight_shard = high_precision_init_val.to(weight.device).float()[ + weight_start:weight_end + ] + else: + master_weight_shard = weight.detach().view(-1).float()[weight_start:weight_end] + self.master_weights.append(master_weight_shard) + else: + self.local_weights.append(None) + self.master_weights.append(None) + setattr( + weight, "main_grad", torch.zeros_like(weight, dtype=torch.float32, device="cuda") + ) + + def _flatten_tensors_with_pad(self, tensors): + """ + Flatten the list of tensors and pad them to align with the world size. 
+ + Args: + tensors (list): List of tensors to flatten. + + Returns: + tuple: Flattened tensor and its original length before padding. + """ + world_size = dist.get_world_size(self.dp_group) + + flatten_tensor = torch.cat(tensors) + original_length = flatten_tensor.size(0) + + padding_needed = (world_size - original_length % world_size) % world_size + if padding_needed > 0: + flatten_tensor = torch.cat( + [flatten_tensor, torch.zeros(padding_needed, dtype=flatten_tensor.dtype)] + ) + + return flatten_tensor, original_length + + def zero_grad(self): + for weight in self.weights: + weight.grad = None + weight.main_grad.zero_() + + def step(self): + """ + Perform an optimization step for the distributed sharded model. + + This method includes: + 1. Gradient reduce-scatter: Synchronize gradients across all processes. + 2. Master weight update: Update high-precision master weights using local gradients. + 3. Precision casting: Cast updated master weights to FP8 or BF16 precision. + 4. Weight synchronization: All-gather updated weights across all processes. + + Returns: + None + """ + # Step 1: Reduce-scatter the gradients + main_grad_buffer, _ = self._flatten_tensors_with_pad( + [weight.main_grad.view(-1) for weight in self.weights] + ) + main_grad_buffer = main_grad_buffer.to(self.local_main_grad_shard.dtype) + dist.reduce_scatter_tensor( + self.local_main_grad_shard, main_grad_buffer, group=self.dp_group + ) + + # Step 2: Update the master weights + for weight, master_weight, (shard_start, shard_end) in zip( + self.weights, self.master_weights, self.shard_indices + ): + if master_weight is None: + continue + + # Extract the local gradient shard for this weight + grad = self.local_main_grad_shard[shard_start:shard_end] + + # Update the master weight using gradient descent + master_weight -= grad * self.lr + + # Step 3: Cast master weights to FP8 or BF16 precision + if isinstance(self.weights[0], QuantizedTensor): + local_weights = [] + for local_weight in self.local_weights: + if local_weight is None: + local_weights.append(None) + continue + + local_weights.append(local_weight) + + cast_master_weights_to_fp8( + self.weights, + self.master_weights, + [idx[0] for idx in self.weight_indices], + self.dp_group, + local_weights, + ) + else: + for weight, master_weight in zip(self.local_weights, self.master_weights): + if master_weight is None: + continue + + # Copy updated master weights to local weights + weight.data.copy_(master_weight) + + # Step 4: All-gather updated weights across processes + dist.all_gather_into_tensor( + self.flatten_weight, self.local_weight_shard, group=self.dp_group + ) + + +def _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group): + rank = dist.get_rank(dp_group) + world_size = dist.get_world_size(dp_group) + + # Configuration constants + NUM_STEPS = 100 + SEED = 12345 + + torch.manual_seed(SEED) + torch.cuda.manual_seed(SEED) + + mock_groups = [dist.new_group(ranks=[i]) for i in range(world_size)] + mock_group = mock_groups[rank] + + linear_kwargs = { + "params_dtype": torch.bfloat16, + "bias": False, + "fuse_wgrad_accumulation": False, + } + + # Create model with FP8 weights + with te.fp8.fp8_model_init( + enabled=quantization is not None, + recipe=quantization_recipe(quantization), + preserve_high_precision_init_val=True, + ): + model_fp8 = nn.Sequential( + te.Linear(128, 256, **linear_kwargs), + te.Linear(256, 256 * 3, **linear_kwargs), + te.Linear(256 * 3, 128, **linear_kwargs), + ) + + # Create model with BF16 weights + model = nn.Sequential( + 
te.Linear(128, 256, **linear_kwargs), + te.Linear(256, 256 * 3, **linear_kwargs), + te.Linear(256 * 3, 128, **linear_kwargs), + ) + + # Make sure the BF16 model and FP8 model have the same initial weights + for w_fp8, w in zip(model_fp8.parameters(), model.parameters()): + high_precision_init_val = w_fp8.get_high_precision_init_val() + w.data.copy_(high_precision_init_val) + + optimizer_fp8 = MiniFSDP([w for w in model_fp8.parameters()], 10.0, dp_group) + optimizer = MiniFSDP([w for w in model.parameters()], 10.0, dp_group) + + for _ in range(100): + optimizer_fp8.zero_grad() + optimizer.zero_grad() + + inputs = [ + torch.randn(16, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size) + ] + # Choose based on rank to make sure the inputs of different ranks are different. + x = inputs[rank] + + with te.fp8.fp8_autocast( + enabled=quantization is not None, + fp8_recipe=quantization_recipe(quantization), + fp8_group=mock_group, + ): + y_fp8 = model_fp8(x) + + with te.fp8_autocast( + enabled=quantization is not None, + fp8_recipe=quantization_recipe(quantization), + fp8_group=mock_group, + ): + y = model(x) + + targets = [torch.randn_like(y) for _ in range(world_size)] + # Choose based on rank to make sure the targets of different ranks are different. + target = targets[rank] + loss_fp8 = nn.MSELoss()(y_fp8, target) + loss = nn.MSELoss()(y, target) + + loss_fp8.backward() + loss.backward() + + optimizer_fp8.step() + optimizer.step() + + torch.testing.assert_close(loss_fp8, loss, atol=0, rtol=0) + + print( + f"✅ Successfully validated FSDP {NUM_STEPS} training steps with" + f" {quantization} quantization" + ) + + +def _test_zero_1(dp_group): + """Make sure the implementation of zero-1 optimizer is correct""" + rank = dist.get_rank(dp_group) + world_size = dist.get_world_size(dp_group) + + torch.manual_seed(12345) + torch.cuda.manual_seed(12345) + + weights = [ + torch.randn(256 * 256, dtype=torch.bfloat16, device="cuda"), + torch.randn(256 * 256 * 3, dtype=torch.bfloat16, device="cuda"), + torch.randn(256 * 256 * 2 - 1, dtype=torch.bfloat16, device="cuda"), + ] + + weights_1 = weights + weights_2 = [weight.clone() for weight in weights] + + lr = 1.0 + optimizer_1 = MiniZero_1(weights_1, lr, dp_group) + optimizer_2 = MiniOptimizer(weights_2, lr, dp_group) + + for _ in range(100): + for w1, w2 in zip(weights_1, weights_2): + main_grads = [ + torch.randn_like(w1, dtype=torch.float32, device="cuda") for _ in range(world_size) + ] + # Choose based on rank to make sure the grads of different ranks are different. 
+ main_grad = main_grads[rank] + w1.main_grad = main_grad + w2.main_grad = main_grad + + optimizer_1.step() + optimizer_2.step() + + for w1, w2 in zip(weights_1, weights_2): + torch.testing.assert_close(w1, w2, atol=0, rtol=0) + + +def quantization_recipe(quantization) -> Recipe: + """Quantization recipe setup""" + if quantization == "fp8": + return DelayedScaling( + fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max" + ) + elif quantization == "fp8_cs": + return Float8CurrentScaling() + else: + raise ValueError(f"Unsupported quantization: {quantization}") + + +def _test_cast_master_weights_to_fp8(quantization, dp_group): + rank = dist.get_rank(dp_group) + world_size = dist.get_world_size(dp_group) + + torch.manual_seed(12345) + torch.cuda.manual_seed(12345) + + mock_groups = [dist.new_group(ranks=[i]) for i in range(world_size)] + mock_group = mock_groups[rank] + + linear_kwargs = {"params_dtype": torch.bfloat16, "bias": False, "fuse_wgrad_accumulation": True} + + # Create model with FP8 weights + with te.fp8.fp8_model_init( + enabled=quantization is not None, + recipe=quantization_recipe(quantization), + preserve_high_precision_init_val=True, + ): + model_fp8 = nn.Sequential( + te.Linear(128, 256, **linear_kwargs), + te.Linear(256, 256 * 3, **linear_kwargs), + te.Linear(256 * 3, 128, **linear_kwargs), + ) + + # Create model with BF16 weights + model = nn.Sequential( + te.Linear(128, 256, **linear_kwargs), + te.Linear(256, 256 * 3, **linear_kwargs), + te.Linear(256 * 3, 128, **linear_kwargs), + ) + + # Make sure the BF16 model and FP8 model have the same initial weights + for w_fp8, w in zip(model_fp8.parameters(), model.parameters()): + high_precision_init_val = w_fp8.get_high_precision_init_val() + w.data.copy_(high_precision_init_val) + + # Allocate main_grads for each weight + for w_fp8, w in zip(model_fp8.parameters(), model.parameters()): + w_fp8.main_grad = torch.zeros_like(w_fp8, dtype=torch.float32, device="cuda") + w.main_grad = torch.zeros_like(w, dtype=torch.float32, device="cuda") + + optimizer_fp8 = MiniZero_1([w for w in model_fp8.parameters()], 10.0, dp_group) + optimizer = MiniZero_1([w for w in model.parameters()], 10.0, dp_group) + + for _ in range(100): + for w_fp8, w in zip(model_fp8.parameters(), model.parameters()): + w_fp8.main_grad.zero_() + w.main_grad.zero_() + + inputs = [ + torch.randn(16, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size) + ] + # Choose based on rank to make sure the inputs of different ranks are different. + x = inputs[rank] + + with te.fp8.fp8_autocast( + enabled=quantization is not None, + fp8_recipe=quantization_recipe(quantization), + fp8_group=mock_group, + ): + y_fp8 = model_fp8(x) + + with te.fp8_autocast( + enabled=quantization is not None, + fp8_recipe=quantization_recipe(quantization), + fp8_group=mock_group, + ): + y = model(x) + + targets = [torch.randn_like(y) for _ in range(world_size)] + # Choose based on rank to make sure the targets of different ranks are different. 
+ target = targets[rank] + loss_fp8 = nn.MSELoss()(y_fp8, target) + loss = nn.MSELoss()(y, target) + + loss_fp8.backward() + loss.backward() + + optimizer_fp8.step() + optimizer.step() + + torch.testing.assert_close(loss_fp8, loss, atol=0, rtol=0) + + +def main(argv=None, namespace=None): + WORLD_RANK = int(os.getenv("RANK", "0")) + WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) + LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) + LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1")) + + assert WORLD_SIZE == LOCAL_SIZE # this test supports only 1 node + assert LOCAL_SIZE <= torch.cuda.device_count() + dist_init_kwargs = { + "backend": "nccl", + "rank": WORLD_RANK, + "world_size": WORLD_SIZE, + "timeout": datetime.timedelta(seconds=30), + } + dist_init_kwargs["init_method"] = "env://" + dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}") + assert dist.is_nccl_available() + torch.cuda.set_device(LOCAL_RANK) + dist.init_process_group(**dist_init_kwargs) + + parser = argparse.ArgumentParser() + parser.add_argument("--quantization", type=str, default=None, choices=["fp8", "fp8_cs"]) + args = parser.parse_args(argv, namespace) + + dp_group = dist.new_group(backend="nccl") + _test_zero_1(dp_group) + _test_cast_master_weights_to_fp8(args.quantization, dp_group) + _test_fsdp_cast_master_weights_to_fp8(args.quantization, dp_group) + + dist.destroy_process_group() + return 0 + + +if __name__ == "__main__": + + sys.exit(main()) diff --git a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py new file mode 100644 index 0000000..8ebe86b --- /dev/null +++ b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py @@ -0,0 +1,35 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import os +import subprocess +from pathlib import Path + +import pytest +import torch +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager + + +if torch.cuda.device_count() < 2: + pytest.skip("cast_master_weights_to_fp8 test needs at least 2 GPUs.") + +fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() + +TEST_ROOT = Path(__file__).parent.resolve() +NUM_PROCS: int = min(2, torch.cuda.device_count()) +LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"] + + +def _run_test(quantization): + test_path = TEST_ROOT / "run_cast_master_weights_to_fp8.py" + test_cmd = LAUNCH_CMD + [str(test_path)] + ["--quantization", quantization] + result = subprocess.run(test_cmd, env=os.environ, check=False) + assert result.returncode == 0 + + +@pytest.mark.parametrize("quantization", ["fp8", "fp8_cs"]) +def test_cast_master_weights_to_fp8(quantization): + if not fp8_available: + pytest.skip(reason_for_no_fp8) + _run_test(quantization) diff --git a/tests/pytorch/references/ref_per_tensor_cs.py b/tests/pytorch/references/ref_per_tensor_cs.py index 1895b31..dad0c42 100644 --- a/tests/pytorch/references/ref_per_tensor_cs.py +++ b/tests/pytorch/references/ref_per_tensor_cs.py @@ -8,12 +8,8 @@ import transformer_engine_torch as tex from transformer_engine.pytorch.constants import TE_DType_To_Torch -# compute amax and scale -def _ref_compute_amax_scale(x, quant_dtype, eps, pow_2_scales): - x_fp32 = x.to(torch.float32) - amax = torch.amax(torch.abs(x_fp32)).view(1) - assert amax.dtype == torch.float, "amax must be a float tensor." 
- fp8_max = torch.finfo(quant_dtype).max +# Compute scale and scale_inv from amax +def _ref_compute_scale_and_scale_inv_from_amax(amax, fp8_max, eps, pow_2_scales): # Clamping amax to avoid division by small numbers amax = torch.max(amax, torch.tensor(eps)) @@ -52,6 +48,20 @@ def _ref_compute_amax_scale(x, quant_dtype, eps, pow_2_scales): # Compute scale_inv scale_inv = torch.reciprocal(scale) + return scale, scale_inv + + +# compute amax and scale +def _ref_compute_amax_scale(x, quant_dtype, eps, pow_2_scales): + x_fp32 = x.to(torch.float32) + amax = torch.amax(torch.abs(x_fp32)).view(1) + assert amax.dtype == torch.float, "amax must be a float tensor." + fp8_max = torch.finfo(quant_dtype).max + + scale, scale_inv = _ref_compute_scale_and_scale_inv_from_amax(amax, fp8_max, eps, pow_2_scales) + # Clamping amax to avoid division by small numbers + amax = torch.max(amax, torch.tensor(eps)) + return scale, scale_inv, amax @@ -103,3 +113,7 @@ def ref_per_tensor_cs_cast( qx_t = _multi_dim_transpose(qx) sx_t = sx return qx, sx, qx_t, sx_t + + +def ref_compute_scale_and_scale_inv_from_amax(amax, fp8_max, eps, pow_2_scales): + return _ref_compute_scale_and_scale_inv_from_amax(amax, fp8_max, eps, pow_2_scales) diff --git a/tests/pytorch/test_multi_tensor.py b/tests/pytorch/test_multi_tensor.py index ecc06c3..4dc1ec0 100644 --- a/tests/pytorch/test_multi_tensor.py +++ b/tests/pytorch/test_multi_tensor.py @@ -9,6 +9,9 @@ import transformer_engine.pytorch as te import transformer_engine_torch as tex from transformer_engine.pytorch.optimizers import MultiTensorApply +from references.ref_per_tensor_cs import ref_compute_scale_and_scale_inv_from_amax + + input_size_pairs = [ (7777 * 77, 555 * 555), (777, 555), @@ -216,3 +219,42 @@ def test_multi_tensor_unscale_l2norm(input_size_pair, applier, repeat, in_type, if per_tensor: torch.testing.assert_close(norm_per_tensor, normab.broadcast_to(norm_per_tensor.shape)) assert overflow_buf.item() == 0 + + +@pytest.mark.parametrize("input_size_pair", input_size_pairs + [(1, 1)]) +@pytest.mark.parametrize("applier", appliers) +@pytest.mark.parametrize("repeat", [1, 55]) +@pytest.mark.parametrize("max_fp8", [448.0, 57344.0]) +@pytest.mark.parametrize("pow_2_scales", [False, True]) +@pytest.mark.parametrize("epsilon", [0.0, 100.0]) +def test_multi_tensor_compute_scale_and_scale_inv( + input_size_pair, applier, repeat, max_fp8, pow_2_scales, epsilon +): + sizea, sizeb = input_size_pair + device = torch.device("cuda") + overflow_buf = torch.zeros(1, dtype=torch.int32, device=device) + a = torch.randn([sizea], dtype=torch.float32, device=device).abs() + b = torch.randn([sizeb], dtype=torch.float32, device=device).abs() + + amax_list = [] + for i in range(repeat): + amax_list += [a.clone(), b.clone()] + + scale_list = [torch.empty_like(x) for x in amax_list] + scale_inv_list = [torch.empty_like(x) for x in amax_list] + + applier( + tex.multi_tensor_compute_scale_and_scale_inv, + overflow_buf, + [amax_list, scale_list, scale_inv_list], + max_fp8, + pow_2_scales, + epsilon, + ) + + for amax, scale, scale_inv in zip(amax_list, scale_list, scale_inv_list): + scale_ref, scale_inv_ref = ref_compute_scale_and_scale_inv_from_amax( + amax, max_fp8, epsilon, pow_2_scales + ) + torch.testing.assert_close(scale, scale_ref, rtol=0, atol=0) + torch.testing.assert_close(scale_inv, scale_inv_ref, rtol=0, atol=0) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 1e6250f..980eeef 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ 
-36,7 +36,12 @@ from transformer_engine.common import recipe import transformer_engine_torch as tex from transformer_engine.pytorch.cpp_extensions import general_gemm from transformer_engine.pytorch.module.base import get_workspace -from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer +from transformer_engine.pytorch.tensor import QuantizedTensor +from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8Quantizer, + Float8CurrentScalingQuantizer, +) +from transformer_engine.pytorch.tensor.utils import replace_raw_data from test_numerics import reset_rng_states, dtype_tols # Only run FP8 tests on supported devices. @@ -1196,3 +1201,70 @@ def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False outputs.append(p.grad) return outputs + + +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +def test_replace_raw_data_for_float8tensor(): + """Test the functionality of replace_raw_data""" + torch.manual_seed(12345) + torch.cuda.manual_seed(12345) + + fp8_quantizer = Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda") + fp8_tensor = fp8_quantizer.make_empty([128, 128], dtype=torch.bfloat16, device="cuda") + random_bf16_data = torch.randn(fp8_tensor.shape, dtype=torch.bfloat16, device="cuda") + fp8_quantizer.update_quantized(random_bf16_data, fp8_tensor) + + attrs_to_check = ["_quantizer", "_fp8_dtype", "_scale_inv", "_transpose", "_transpose_invalid"] + attrs = {} + for attr in attrs_to_check: + attrs[attr] = getattr(fp8_tensor, attr) + + old_data = fp8_tensor._data + new_data = torch.empty_like(old_data) + replace_raw_data(fp8_tensor, new_data) + + # Make sure the new_data is properly assigned. + assert fp8_tensor._data.data_ptr() != old_data.data_ptr() + assert fp8_tensor._data.data_ptr() == new_data.data_ptr() + # Make sure the values are not changed. 
+ torch.testing.assert_close(old_data, fp8_tensor._data, atol=0, rtol=0) + # Make sure other attributes are not changed (totally identical) + for attr in attrs_to_check: + assert id(getattr(fp8_tensor, attr)) == id(attrs[attr]) + + +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +def test_fp8_model_init_high_precision_init_val(): + """Test fp8_model_init with preserve_high_precision_init_val=True""" + with fp8_model_init(preserve_high_precision_init_val=True): + model = Linear(768, 768) + + weight = model.weight + + assert isinstance(weight, QuantizedTensor), "Weight should be QuantizedTensor" + assert hasattr(weight, "_high_precision_init_val"), "_high_precision_init_val not found" + assert hasattr(weight, "get_high_precision_init_val"), "get_high_precision_init_val() not found" + assert hasattr( + weight, "clear_high_precision_init_val" + ), "clear_high_precision_init_val() not found" + + high_precision = weight.get_high_precision_init_val() + assert high_precision.device.type == "cpu", "high_precision_init_val is not on the CPU" + + new_weight = weight._get_quantizer().make_empty( + shape=weight.shape, dtype=weight.dtype, device=weight.device + ) + weight._get_quantizer().update_quantized(high_precision.to(weight.device), new_weight) + + torch.testing.assert_close( + new_weight.dequantize(dtype=weight.dtype), + weight.dequantize(dtype=weight.dtype), + rtol=0, + atol=0, + ) + + weight.clear_high_precision_init_val() + assert weight.get_high_precision_init_val() is None, "clear_high_precision_init_val() not work" + assert not hasattr( + weight, "._high_precision_init_val" + ), "clear_high_precision_init_val() not work" diff --git a/transformer_engine/common/recipe/current_scaling.cu b/transformer_engine/common/recipe/current_scaling.cu index 3a25d71..cf07d12 100644 --- a/transformer_engine/common/recipe/current_scaling.cu +++ b/transformer_engine/common/recipe/current_scaling.cu @@ -13,6 +13,7 @@ #include "../common.h" #include "../util/logging.h" #include "../util/vectorized_pointwise.h" +#include "recipe_common.cuh" namespace transformer_engine { namespace { @@ -135,7 +136,7 @@ void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaSt "Output tensor for amax computation has invalid amax tensor " "(expected FP32, got dtype=", to_string(output.amax.dtype), ")"); - CheckOutputTensor(output, "output_compute_amax"); + CheckOutputTensor(output, "output_compute_amax", true); // Compute amax TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( @@ -151,41 +152,7 @@ namespace { __global__ void compute_scale_from_amax_kernel(const float *amax_ptr, float *scale_ptr, const float max_fp8, const bool force_pow_2_scales, const float epsilon) { - float amax = *amax_ptr; - if (amax < epsilon) { - amax = epsilon; - } - - float scale = 1.f; - - if (isinf(amax) || amax == 0.f) { - *scale_ptr = scale; - return; - } - - scale = max_fp8 / amax; - - // The amax is too small that the scale becoming infinite in FP32. In other word, - // the scale is not representable in FP32. - if (isinf(scale)) { - // use fp32 max to represent the scale - scale = std::numeric_limits::max(); - } - - if (isnan(scale)) { - scale = 1.f; - } - - if (force_pow_2_scales) { - uint32_t scale_bits = *reinterpret_cast(&scale); - scale_bits &= 0xFF800000; - // If the exponent was zero, we have a logic error. 
- __builtin_assume(scale_bits != 0); - __builtin_assume(scale_bits != 0x80000000); - scale = *reinterpret_cast(&scale_bits); - } - - *scale_ptr = scale; + *scale_ptr = compute_scale_from_amax(*amax_ptr, max_fp8, force_pow_2_scales, epsilon); } } // namespace diff --git a/transformer_engine/common/recipe/recipe_common.cuh b/transformer_engine/common/recipe/recipe_common.cuh new file mode 100644 index 0000000..c789a9b --- /dev/null +++ b/transformer_engine/common/recipe/recipe_common.cuh @@ -0,0 +1,56 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_RECIPE_RECIPE_COMMON_CUH_ +#define TRANSFORMER_ENGINE_RECIPE_RECIPE_COMMON_CUH_ + +#include + +namespace transformer_engine { + +__device__ __forceinline__ float compute_scale_from_amax(float amax, float max_fp8, + bool force_pow_2_scales, float epsilon) { + if (amax < epsilon) { + amax = epsilon; + } + + float scale = 1.f; + + if (isinf(amax) || amax == 0.f) { + return scale; + } + + // Here we don't use "scale = max_fp8 / amax" because it has different results with/without + // "--use_fast_math". + // "__fdiv_rn" has the same behavior with "max_fp8 / amax" when not using fast math. + scale = __fdiv_rn(max_fp8, amax); + + // The amax is too small that the scale becoming infinite in FP32. In other word, + // the scale is not representable in FP32. + if (isinf(scale)) { + // use fp32 max to represent the scale + scale = std::numeric_limits::max(); + } + + if (isnan(scale)) { + scale = 1.f; + } + + if (force_pow_2_scales) { + uint32_t scale_bits = *reinterpret_cast(&scale); + scale_bits &= 0xFF800000; + // If the exponent was zero, we have a logic error. 
+ __builtin_assume(scale_bits != 0); + __builtin_assume(scale_bits != 0x80000000); + scale = *reinterpret_cast(&scale_bits); + } + + return scale; +} + +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_RECIPE_RECIPE_COMMON_CUH_ diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index e430be0..9561fda 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -252,6 +252,8 @@ at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grads * FP8 recipe **************************************************************************************************/ +void compute_amax(const at::Tensor &tensor, at::Tensor &amax); + void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer, std::vector amax_histories, std::vector scales, @@ -359,6 +361,10 @@ void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag, float momentum, float dampening, float lr, bool nesterov, bool first_run, bool wd_after_momentum, float scale); +void multi_tensor_compute_scale_and_scale_inv_cuda( + int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, + float max_fp8, bool force_pow_2_scales, float epsilon); + /*************************************************************************************************** * padding **************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_compute_scale.cu b/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_compute_scale.cu new file mode 100644 index 0000000..d262767 --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_compute_scale.cu @@ -0,0 +1,66 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +// Another possibility: +// #include + +#include +// Stringstream is a big hammer, but I want to rely on operator<< for dtype. +#include + +#include "common/recipe/recipe_common.cuh" +#include "common/utils.cuh" +#include "multi_tensor_apply.cuh" +#include "type_shim.h" + +#define BLOCK_SIZE 256 + +struct ComputeScaleAndScaleInvFunctor { + __device__ __forceinline__ void operator()(int chunk_size, volatile int *noop_gmem, + TensorListMetadata<3> &tl, // NOLINT(*) + float max_fp8, bool force_pow_2_scales, + float epsilon) { + // I'd like this kernel to propagate infs/nans. 
+ // if(*noop_gmem == 1) + // return; + + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.sizes[tensor_loc]; + + float *amax = reinterpret_cast(tl.addresses[0][tensor_loc]); + amax += chunk_idx * chunk_size; + + float *scale = reinterpret_cast(tl.addresses[1][tensor_loc]); + scale += chunk_idx * chunk_size; + + float *scale_inv = reinterpret_cast(tl.addresses[2][tensor_loc]); + scale_inv += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + + for (int i_start = threadIdx.x; i_start < n && i_start < chunk_size; i_start += blockDim.x) { + float scale_val = transformer_engine::compute_scale_from_amax(amax[i_start], max_fp8, + force_pow_2_scales, epsilon); + scale[i_start] = scale_val; + transformer_engine::reciprocal(scale_inv + i_start, scale_val); + } + } +}; + +void multi_tensor_compute_scale_and_scale_inv_cuda( + int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, + float max_fp8, bool force_pow_2_scales, float epsilon) { + using namespace at; + + multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + ComputeScaleAndScaleInvFunctor(), max_fp8, force_pow_2_scales, epsilon); + AT_CUDA_CHECK(cudaGetLastError()); +} diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index a58fd3a..097cf63 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -178,6 +178,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("dtype"), py::kw_only(), py::arg("out"), py::call_guard()); m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend", py::call_guard()); + m.def("compute_amax", &compute_amax, "Compute amax", py::arg("input"), py::arg("amax")); m.def("fused_amax_and_scale_update_after_reduction", &fused_amax_and_scale_update_after_reduction, "Update amax history and FP8 scale/scale_inv after reduction", py::call_guard()); @@ -265,6 +266,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("multi_tensor_sgd", &multi_tensor_sgd_cuda, "Fused SGD optimizer for list of contiguous tensors", py::call_guard()); + m.def("multi_tensor_compute_scale_and_scale_inv", &multi_tensor_compute_scale_and_scale_inv_cuda, + "Fused compute scale and scale_inv from amax", py::call_guard()); // Data structures py::class_(m, "FP8TensorMeta") diff --git a/transformer_engine/pytorch/csrc/extensions/recipe.cpp b/transformer_engine/pytorch/csrc/extensions/recipe.cpp index e8a31da..2dc3b69 100644 --- a/transformer_engine/pytorch/csrc/extensions/recipe.cpp +++ b/transformer_engine/pytorch/csrc/extensions/recipe.cpp @@ -12,10 +12,27 @@ #include "common/common.h" #include "extensions.h" -void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer, +void compute_amax(const at::Tensor& tensor, at::Tensor& amax) { + using namespace transformer_engine; + using namespace transformer_engine::pytorch; + + auto input_tensor = tensor.contiguous(); + const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor); + + TORCH_CHECK(amax.scalar_type() == at::kFloat, "amax must be a float tensor"); + TORCH_CHECK(amax.numel() == 1, "amax must have exactly one element"); + TensorWrapper fake_te_output( + nullptr, te_input.shape(), + transformer_engine::DType::kFloat8E4M3, // It doesn't matter because we only compute amax. 
+ amax.data_ptr()); + + nvte_compute_amax(te_input.data(), fake_te_output.data(), at::cuda::getCurrentCUDAStream()); +} + +void fused_amax_and_scale_update_after_reduction(const at::Tensor& amax_reduction_buffer, std::vector amax_histories, std::vector scales, - const std::string &amax_compute_algo, + const std::string& amax_compute_algo, transformer_engine::DType fp8_dtype, float margin) { using namespace transformer_engine; diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 87298c2..38f829c 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -93,6 +93,7 @@ class FP8GlobalStateManager: FP8_RECIPE = None FP8_DISTRIBUTED_GROUP = None FP8_PARAMETERS = False + HIGH_PRECISION_INIT_VAL = False IS_FIRST_FP8_MODULE = False FP8_GRAPH_CAPTURING = False FP8_AUTOCAST_DEPTH = 0 @@ -117,6 +118,7 @@ class FP8GlobalStateManager: cls.FP8_RECIPE = None cls.FP8_DISTRIBUTED_GROUP = None cls.FP8_PARAMETERS = False + cls.HIGH_PRECISION_INIT_VAL = False cls.IS_FIRST_FP8_MODULE = False cls.FP8_GRAPH_CAPTURING = False cls.FP8_AUTOCAST_DEPTH = 0 @@ -267,6 +269,11 @@ class FP8GlobalStateManager: """Should the parameters be stored as FP8""" return cls.FP8_PARAMETERS + @classmethod + def with_high_precision_init_val(cls) -> bool: + """Should the high precision initial values be stored with FP8 parameters""" + return cls.HIGH_PRECISION_INIT_VAL + @classmethod def fp8_graph_capturing(cls) -> bool: """Is CUDA graph capture under way?""" @@ -500,7 +507,11 @@ class FP8GlobalStateManager: @contextmanager -def fp8_model_init(enabled: bool = True, recipe: Optional[Recipe] = None) -> None: +def fp8_model_init( + enabled: bool = True, + recipe: Optional[Recipe] = None, + preserve_high_precision_init_val: bool = False, +) -> None: """ Context manager for FP8 initialization of parameters. @@ -511,6 +522,12 @@ def fp8_model_init(enabled: bool = True, recipe: Optional[Recipe] = None) -> Non with fp8_model_init(enabled=True): model = transformer_engine.pytorch.Linear(768, 768) + # Preserving high precision initial value to initialize master weight + with fp8_model_init(enabled=True, preserve_high_precision_init_val=True): + model = transformer_engine.pytorch.Linear(768, 768) + master_weight = model.weight.get_high_precision_init_val() + model.weight.clear_high_precision_init_val() + Parameters ---------- enabled: bool, default = `True` @@ -526,18 +543,29 @@ def fp8_model_init(enabled: bool = True, recipe: Optional[Recipe] = None) -> Non * LoRA-like fine-tuning, where the main parameters of the model do not change. recipe: transformer_engine.common.recipe.Recipe, default = `None` Recipe used to create the parameters. If left to None, it uses the default FP8 recipe. + preserve_high_precision_init_val: bool, default = `False` + when enabled, store the high precision tensor used to initialize FP8 parameters + in CPU memory, and add two function attributes named `get_high_precision_init_val()` + and `clear_high_precision_init_val()` to FP8 parameters to get/clear this high + precision tensor. The purpose is that users can use this high-precision copy + to initialize master weights, avoiding the loss of precision that can occur when + using FP8 parameters directly. Note that after the master weights are initialized, + users should call `clear_high_precision_init_val()` to release this CPU memory. This functionality is *EXPERIMENTAL*. 
""" _fp8_parameters = FP8GlobalStateManager.FP8_PARAMETERS _fp8_recipe = FP8GlobalStateManager.FP8_RECIPE + _high_precision_init_val = FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL FP8GlobalStateManager.FP8_PARAMETERS = enabled FP8GlobalStateManager.FP8_RECIPE = get_default_fp8_recipe() if recipe is None else recipe + FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL = preserve_high_precision_init_val try: yield finally: FP8GlobalStateManager.FP8_PARAMETERS = _fp8_parameters FP8GlobalStateManager.FP8_RECIPE = _fp8_recipe + FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL = _high_precision_init_val @contextmanager diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 4b82054..cdb75aa 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -10,6 +10,7 @@ import warnings from abc import ABC, abstractmethod from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union from contextlib import contextmanager +from types import MethodType import torch import torch.nn.functional as F @@ -405,6 +406,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): self.sequence_parallel = False self.param_init_meta = {} self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters() + self.preserve_high_precision_init_val = FP8GlobalStateManager.with_high_precision_init_val() self.fsdp_wrapped = False self.fsdp_group = None self._fp8_workspaces: Dict[str, QuantizedTensor] = {} @@ -902,7 +904,11 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): # If primary weights are in fp8, wrap the parameter as FP8Tensor fp8_meta_index = self.param_init_meta[name].fp8_meta_index + high_precision_init_val = None if self.primary_weights_in_fp8 and fp8_meta_index is not None: + if self.preserve_high_precision_init_val: + high_precision_init_val = param.detach().cpu() + quantizer = self.quantizers["scaling_fwd"][fp8_meta_index] assert ( quantizer is not None @@ -914,7 +920,34 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): # NOTE: Currently this can only be broken when primary weights are in Fp8 but # re-applying the nn.Parameter() wrap is a no-op when the input is already # a parameter so we always re-apply it just for extra safety. - setattr(self, name, torch.nn.Parameter(param)) + param = torch.nn.Parameter(param) + if high_precision_init_val is not None: + + # - Master weights are initialized from model weights, if we use fp8 primary + # weights to initialize master weights, the numerical values of master weights + # are not consistent with the numerical values when we initialize them from + # bf16/fp16 weights. + # - So we add a `_high_precision_init_val` attribute to each model weight to store + # the original bf16/fp16 weight on cpu before casting it to fp8. And users can + # use `get_high_precision_init_val` to get this cpu tensor. + # - This cpu tensor is not needed once the master weight is initialized, so users + # should call `clear_high_precision_init_val` to remove it after master weight + # is initialized. 
+ + def get(self): + if hasattr(self, "_high_precision_init_val"): + return self._high_precision_init_val + return None + + def clear(self): + if hasattr(self, "_high_precision_init_val"): + del self._high_precision_init_val + + param._high_precision_init_val = high_precision_init_val + param.get_high_precision_init_val = MethodType(get, param) + param.clear_high_precision_init_val = MethodType(clear, param) + + setattr(self, name, param) @abstractmethod def forward(self): @@ -953,10 +986,26 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): FSDP process group that the weights are distributed over. """ + # FP8 primary weights + if isinstance(tensor, QuantizedTensor): + if update_workspace and quantizer is not None: + tensor.update_usage( + rowwise_usage=quantizer.rowwise_usage, + columnwise_usage=quantizer.columnwise_usage, + ) + return tensor + # Try getting workspace from cache out = None if cache_name is not None: out = self._fp8_workspaces.get(cache_name, None) + if quantizer is not None and isinstance(out, MXFP8TensorBase): + if quantizer.rowwise_usage and out._rowwise_data is None: + out = None + del self._fp8_workspaces[cache_name] + elif quantizer.columnwise_usage and out._columnwise_data is None: + out = None + del self._fp8_workspaces[cache_name] # Gather cached Fp8 workspace if it's distributed # NOTE: FSDP sharding is supported only for Fp8 buffers and will not work diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 8bf420a..8963a61 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -130,20 +130,17 @@ class _GroupedLinear(torch.autograd.Function): ) weights_fp8 = [] bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype - if not isinstance(weights[0], QuantizedTensor): - # FP8 cast to workspace buffer - update_workspace = is_first_microbatch is None or is_first_microbatch - for i in range(num_gemms): - weight_fp8 = module.get_weight_workspace( - tensor=weights[i], - quantizer=weight_quantizers[i], - cache_name=(None if is_first_microbatch is None else f"weight{i}"), - update_workspace=update_workspace, - skip_update_flag=skip_fp8_weight_update, - ) - weights_fp8.append(weight_fp8) - else: - weights_fp8 = weights + # FP8 cast to workspace buffer + update_workspace = is_first_microbatch is None or is_first_microbatch + for i in range(num_gemms): + weight_fp8 = module.get_weight_workspace( + tensor=weights[i], + quantizer=weight_quantizers[i], + cache_name=(None if is_first_microbatch is None else f"weight{i}"), + update_workspace=update_workspace, + skip_update_flag=skip_fp8_weight_update, + ) + weights_fp8.append(weight_fp8) else: inputmats = inputmats_no_fp8 @@ -180,7 +177,7 @@ class _GroupedLinear(torch.autograd.Function): weight_quantizers[i].calibrate(weights[i]) if is_grad_enabled: - + ctx.weight_quantizers = weight_quantizers ctx.weights_shape_1 = weights[0].shape[1] tensors_to_save, tensor_objects = prepare_for_saving(*inputmats, *weights_fp8, *biases) @@ -270,6 +267,12 @@ class _GroupedLinear(torch.autograd.Function): device=ctx.device, ) + for weight, quantizer in zip(weights, ctx.weight_quantizers): + if quantizer is not None and isinstance(weight, QuantizedTensor): + weight.update_usage( + rowwise_usage=quantizer.rowwise_usage, + columnwise_usage=quantizer.columnwise_usage, + ) general_grouped_gemm( weights, grad_output, diff --git 
a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 4022924..fc316e3 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -262,28 +262,26 @@ class _LayerNormLinear(torch.autograd.Function): nvtx_range_pop(f"{nvtx_label}.gemm_input_cast_comm") # Cast weight to expected dtype - weightmat = weight - quantized_weight = False if not fp8: - weightmat = cast_if_needed(weightmat, activation_dtype) + quantized_weight = False + weightmat = cast_if_needed(weight, activation_dtype) else: - if not isinstance(weight, QuantizedTensor): - quantized_weight = True - - # Configure quantizer - if weight_quantizer is not None: - weight_quantizer.set_usage(rowwise=True, columnwise=True) - - # FP8 cast to workspace buffer - update_workspace = is_first_microbatch is None or is_first_microbatch - weightmat = module.get_weight_workspace( - tensor=weight, - quantizer=weight_quantizer, - cache_name=(None if is_first_microbatch is None else "weight"), - update_workspace=update_workspace, - skip_update_flag=skip_fp8_weight_update, - fsdp_group=fsdp_group, - ) + quantized_weight = not isinstance(weight, QuantizedTensor) + + # Configure quantizer + if weight_quantizer is not None: + weight_quantizer.set_usage(rowwise=True, columnwise=True) + + # FP8 cast to workspace buffer + update_workspace = is_first_microbatch is None or is_first_microbatch + weightmat = module.get_weight_workspace( + tensor=weight, + quantizer=weight_quantizer, + cache_name=(None if is_first_microbatch is None else "weight"), + update_workspace=update_workspace, + skip_update_flag=skip_fp8_weight_update, + fsdp_group=fsdp_group, + ) # Cast bias to expected dtype bias_dtype = activation_dtype @@ -345,11 +343,12 @@ class _LayerNormLinear(torch.autograd.Function): clear_tensor_data(ln_out, ln_out_total) if is_grad_enabled: + ctx.weight_quantizer = weight_quantizer ctx.ln_out_needs_gather = ( weight.requires_grad and parallel_mode == "column" and sequence_parallel ) - # Input with column-wise usage is needed for dgrad GEMM. + # Input with column-wise usage is needed for wgrad GEMM. if backward_needs_input: if isinstance(ln_out, QuantizedTensor): # For sequence parallel in vanilla FP8, rowwise data is @@ -358,6 +357,11 @@ class _LayerNormLinear(torch.autograd.Function): if isinstance(ln_out, MXFP8TensorBase) or not ctx.ln_out_needs_gather: ln_out.update_usage(rowwise_usage=False) + # Weight with column-wise usage is needed for dgrad GEMM. 
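Roughly, for a linear layer y = x @ W.T the two backward GEMMs consume different layouts, which is why the column-wise (transposed) FP8 copy of the weight is kept whenever dgrad is needed (shape sketch only, not the actual kernel call):

# dgrad: dx = dy @ W     -> needs the weight's column-wise (transposed) FP8 data
# wgrad: dW = dy.T @ x   -> needs the column-wise input (ln_out) handled above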
+ if inp.requires_grad: + if isinstance(weightmat, QuantizedTensor): + weightmat.update_usage(columnwise_usage=True) + if cpu_offloading: if fp8 and weightmat is not None: set_offloading_param(weightmat, "weight_offloading", True) @@ -642,6 +646,11 @@ class _LayerNormLinear(torch.autograd.Function): if hasattr(recipe, "fp8_gemm_dgrad"): dgrad_gemm_use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator + if ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensor): + weight.update_usage( + rowwise_usage=ctx.weight_quantizer.rowwise_usage, + columnwise_usage=ctx.weight_quantizer.columnwise_usage, + ) dgrad, *_ = general_gemm( weight, grad_output, @@ -1274,6 +1283,7 @@ class LayerNormLinear(TransformerEngineBaseModule): inp: torch.Tensor, is_first_microbatch: Optional[bool] = None, fp8_output: Optional[bool] = False, + fp8_grad: Optional[bool] = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: """ Apply layer normalization to the input followed by a linear transformation. @@ -1304,6 +1314,13 @@ class LayerNormLinear(TransformerEngineBaseModule): if skip_fp8_weight_update is not None: is_first_microbatch = False + if self.ub_overlap_rs_fprop: + if get_ub(self.ub_name + "_fprop").is_fp8_ubuf(): + fp8_output = True + if self.ub_overlap_rs_dgrad: + if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf(): + fp8_grad = True + with self.prepare_forward( inp, allow_non_contiguous=False # removed .contiguous from inside the layer ) as inp: @@ -1331,7 +1348,7 @@ class LayerNormLinear(TransformerEngineBaseModule): output_quantizer, grad_output_quantizer, grad_input_quantizer, - ) = self._get_quantizers(fp8_output) + ) = self._get_quantizers(fp8_output, fp8_grad) if torch.is_grad_enabled(): fwd_fn = _LayerNormLinear.apply @@ -1397,7 +1414,7 @@ class LayerNormLinear(TransformerEngineBaseModule): return out, ln_out return out - def _get_quantizers(self, fp8_output): + def _get_quantizers(self, fp8_output, fp8_grad): if not self.fp8: return [None] * 5 grad_input_quantizer = None @@ -1412,6 +1429,8 @@ class LayerNormLinear(TransformerEngineBaseModule): if torch.is_grad_enabled(): grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1] grad_output_quantizer.internal = True + if fp8_grad: + grad_input_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1] return ( input_quantizer, diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 633690b..9cffc47 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -319,35 +319,31 @@ class _LayerNormMLP(torch.autograd.Function): ln_out_total = ln_out # Cast weights to expected dtype - fc1_weight_final = fc1_weight - fc2_weight_final = fc2_weight if not fp8: - fc1_weight_final = cast_if_needed(fc1_weight_final, activation_dtype) - fc2_weight_final = cast_if_needed(fc2_weight_final, activation_dtype) + fc1_weight_final = cast_if_needed(fc1_weight, activation_dtype) + fc2_weight_final = cast_if_needed(fc2_weight, activation_dtype) else: # If weights are not quantized, we call get_weight_workspace, # which handles weight caching etc. 
- if not isinstance(fc1_weight, QuantizedTensor): - # FP8 cast to workspace buffer - update_workspace = is_first_microbatch is None or is_first_microbatch - fc1_weight_final = module.get_weight_workspace( - tensor=fc1_weight, - quantizer=fc1_weight_quantizer, - cache_name=(None if is_first_microbatch is None else "fc1_weight"), - update_workspace=update_workspace, - skip_update_flag=skip_fp8_weight_update, - fsdp_group=fsdp_group, - ) - if not isinstance(fc2_weight, QuantizedTensor): - fc2_weight_quantizer.set_usage(rowwise=True, columnwise=True) - fc2_weight_final = module.get_weight_workspace( - tensor=fc2_weight, - quantizer=fc2_weight_quantizer, - cache_name=(None if is_first_microbatch is None else "fc2_weight"), - update_workspace=update_workspace, - skip_update_flag=skip_fp8_weight_update, - fsdp_group=fsdp_group, - ) + # FP8 cast to workspace buffer + update_workspace = is_first_microbatch is None or is_first_microbatch + fc1_weight_final = module.get_weight_workspace( + tensor=fc1_weight, + quantizer=fc1_weight_quantizer, + cache_name=(None if is_first_microbatch is None else "fc1_weight"), + update_workspace=update_workspace, + skip_update_flag=skip_fp8_weight_update, + fsdp_group=fsdp_group, + ) + fc2_weight_quantizer.set_usage(rowwise=True, columnwise=True) + fc2_weight_final = module.get_weight_workspace( + tensor=fc2_weight, + quantizer=fc2_weight_quantizer, + cache_name=(None if is_first_microbatch is None else "fc2_weight"), + update_workspace=update_workspace, + skip_update_flag=skip_fp8_weight_update, + fsdp_group=fsdp_group, + ) # Cast biases to expected dtype bias_dtype = activation_dtype @@ -430,7 +426,6 @@ class _LayerNormMLP(torch.autograd.Function): dim_size[0] = dim_size[0] // tp_world_size dim_size[1] = fc2_weight.size(0) rs_out = torch.empty(dim_size, dtype=activation_dtype, device=device) - fc2_out = ub_obj_fc2out.get_buffer(output_quantizer) else: dim_size = list(act_out.size()) dim_size[1] = fc2_weight.size(0) @@ -450,6 +445,14 @@ class _LayerNormMLP(torch.autograd.Function): ub_type=tex.CommOverlapType.RS if ub_overlap_rs else None, extra_output=rs_out, ) + + # Weight with column-wise usage is needed for dgrad GEMM. 
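# Toy stand-in (not the TE QuantizedTensor implementation, only the method name is shared)
# for what the update_usage calls below manage: a quantized weight keeps an optional
# transposed ("column-wise") copy next to its row-wise data, materializing or dropping it
# on demand.

import torch

class ToyDualLayoutTensor:
    def __init__(self, data: torch.Tensor):
        self._rowwise = data          # layout consumed by the fprop GEMM
        self._columnwise = None       # layout consumed by the dgrad GEMM

    def update_usage(self, rowwise_usage=None, columnwise_usage=None):
        if columnwise_usage and self._columnwise is None:
            self._columnwise = self._rowwise.t().contiguous()
        if columnwise_usage is False:
            self._columnwise = None
        if rowwise_usage is False:
            assert self._columnwise is not None, "would lose the data entirely"
            self._rowwise = None

w = ToyDualLayoutTensor(torch.randn(32, 16))
w.update_usage(columnwise_usage=True)    # keep the transposed copy around for dgrad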
+ if is_grad_enabled and inp.requires_grad: + if isinstance(fc1_weight_final, QuantizedTensor): + fc1_weight_final.update_usage(columnwise_usage=True) + if isinstance(fc2_weight_final, QuantizedTensor): + fc2_weight_final.update_usage(columnwise_usage=True) + if not is_grad_enabled: clear_tensor_data(act_out, fc1_out_without_bias, fc1_out) @@ -488,6 +491,8 @@ class _LayerNormMLP(torch.autograd.Function): fc2_weight_final if fp8 and not isinstance(fc2_weight, Float8Tensor) else None, ) + ctx.fc1_weight_quantizer = fc1_weight_quantizer + ctx.fc2_weight_quantizer = fc2_weight_quantizer if not fc1_weight.requires_grad: if not return_layernorm_output: clear_tensor_data(ln_out) @@ -500,11 +505,13 @@ class _LayerNormMLP(torch.autograd.Function): ln_weight, ln_out.clone() if ub_overlap_ag else ln_out, # avoid saving a UB buffer fc1_weight_final, + fc1_weight, fc1_bias, fc1_out, fc1_out_without_bias, act_out, fc2_weight_final, + fc2_weight, fc2_bias, mu, rsigma, @@ -619,11 +626,13 @@ class _LayerNormMLP(torch.autograd.Function): ln_weight, ln_out, fc1_weight, + origin_fc1_weight, fc1_bias, fc1_out, fc1_out_without_bias, act_out, fc2_weight, + origin_fc2_weight, fc2_bias, mu, rsigma, @@ -642,7 +651,7 @@ class _LayerNormMLP(torch.autograd.Function): ) fc2_weight_main_grad = ( ctx.fc2_main_grad - if fc2_weight is not None + if origin_fc2_weight is not None and ctx.fuse_wgrad_accumulation and ctx.fc2_weight_requires_grad else None @@ -651,8 +660,8 @@ class _LayerNormMLP(torch.autograd.Function): # For CPU offloading, we offloaded weight and weight.main_grad to different tensors, # we need to connect them into one. if ctx.fuse_wgrad_accumulation: - fc1_weight.main_grad = fc1_weight_main_grad - fc2_weight.main_grad = fc2_weight_main_grad + origin_fc1_weight.main_grad = fc1_weight_main_grad + origin_fc2_weight.main_grad = fc2_weight_main_grad # TODO: Fix this # pylint: disable=fixme # Gather saved autograd context tensors when running with FSDP @@ -735,6 +744,11 @@ class _LayerNormMLP(torch.autograd.Function): ) # FC2 DGRAD; Unconditional + if ctx.fc2_weight_quantizer is not None and isinstance(ctx.fc2_weight, QuantizedTensor): + ctx.fc2_weight.update_usage( + rowwise_usage=ctx.fc2_weight_quantizer.rowwise_usage, + columnwise_usage=ctx.fc2_weight_quantizer.columnwise_usage, + ) gemm_output, *_ = general_gemm( fc2_weight, grad_output, @@ -769,14 +783,18 @@ class _LayerNormMLP(torch.autograd.Function): act_out, grad_output, get_workspace(), - out_dtype=ctx.activation_dtype, + out_dtype=( + origin_fc2_weight.main_grad.dtype + if ctx.fuse_wgrad_accumulation + else ctx.activation_dtype + ), quantization_params=None, # wgrad in high precision layout="NT", grad=True, bias=fc2_bias if fc2_bias is not None and fc2_bias_grad is None else None, accumulate=accumulate_wgrad_into_param_main_grad, use_split_accumulator=_2X_ACC_WGRAD, - out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None, + out=origin_fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None, ) if fc2_bias_grad is None: fc2_bias_grad = fc2_bias_grad_ @@ -864,6 +882,13 @@ class _LayerNormMLP(torch.autograd.Function): fc1_dgrad_bulk = ub_obj_fc1_wgrad.get_buffer(None) # FC1 DGRAD: Unconditional + if ctx.fc1_weight_quantizer is not None and isinstance( + ctx.fc1_weight_quantizer, QuantizedTensor + ): + ctx.fc1_weight.update_usage( + rowwise_usage=ctx.fc1_weight_quantizer.rowwise_usage, + columnwise_usage=ctx.fc1_weight_quantizer.columnwise_usage, + ) fc1_dgrad, *_, fc1_dgrad_rs_out = general_gemm( fc1_weight, dact, @@ -930,12 +955,16 @@ 
class _LayerNormMLP(torch.autograd.Function): ln_out_total, dact, get_workspace(), - out_dtype=ctx.activation_dtype, + out_dtype=( + origin_fc1_weight.main_grad.dtype + if ctx.fuse_wgrad_accumulation + else ctx.activation_dtype + ), layout="NT", grad=fuse_gemm_and_bias_fc1_wgrad, bias=fc1_bias if fuse_gemm_and_bias_fc1_wgrad else None, accumulate=accumulate_wgrad_into_param_main_grad, - out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None, + out=origin_fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None, ub=ub_obj_fc1_wgrad, ub_type=tex.CommOverlapType.RS if ctx.ub_bulk_wgrad else None, extra_output=fc1_dgrad_rs_out, @@ -996,16 +1025,21 @@ class _LayerNormMLP(torch.autograd.Function): if ctx.fc1_weight_requires_grad: # Handle custom DDP from mcore. if ctx.fuse_wgrad_accumulation and hasattr(fc1_weight, "grad_added_to_main_grad"): - fc1_weight.grad_added_to_main_grad = True - if getattr(fc1_weight, "zero_out_wgrad", False): + origin_fc1_weight.grad_added_to_main_grad = True + if getattr(origin_fc1_weight, "zero_out_wgrad", False): fc1_wgrad = torch.zeros( - fc1_weight.main_grad.shape, - dtype=fc1_weight.dtype, + origin_fc1_weight.main_grad.shape, + dtype=origin_fc1_weight.dtype, device=torch.cuda.current_device(), requires_grad=False, ) else: - fc1_wgrad = None + fc1_wgrad = torch.empty( + origin_fc1_weight.main_grad.shape, + dtype=origin_fc1_weight.dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) elif ctx.fuse_wgrad_accumulation: fc1_wgrad = None else: @@ -1013,17 +1047,24 @@ class _LayerNormMLP(torch.autograd.Function): if ctx.fc2_weight_requires_grad: # Handle custom DDP from mcore. - if ctx.fuse_wgrad_accumulation and hasattr(fc2_weight, "grad_added_to_main_grad"): - fc2_weight.grad_added_to_main_grad = True - if getattr(fc2_weight, "zero_out_wgrad", False): + if ctx.fuse_wgrad_accumulation and hasattr( + origin_fc2_weight, "grad_added_to_main_grad" + ): + origin_fc2_weight.grad_added_to_main_grad = True + if getattr(origin_fc2_weight, "zero_out_wgrad", False): fc2_wgrad = torch.zeros( - fc2_weight.main_grad.shape, - dtype=fc2_weight.dtype, + origin_fc2_weight.main_grad.shape, + dtype=origin_fc2_weight.dtype, device=torch.cuda.current_device(), requires_grad=False, ) else: - fc2_wgrad = None + fc2_wgrad = torch.empty( + origin_fc2_weight.main_grad.shape, + dtype=origin_fc2_weight.dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) elif ctx.fuse_wgrad_accumulation: fc2_wgrad = None else: @@ -1429,6 +1470,11 @@ class LayerNormMLP(TransformerEngineBaseModule): if skip_fp8_weight_update is not None: is_first_microbatch = False + fp8_output = False + if self.ub_overlap_rs: + if get_ub("fc2_fprop").is_fp8_ubuf(): + fp8_output = True + with self.prepare_forward(inp, num_gemms=2) as inp: # Get quantizers ( @@ -1440,7 +1486,7 @@ class LayerNormMLP(TransformerEngineBaseModule): grad_fc1_output_quantizer, grad_fc2_output_quantizer, grad_input_quantizer, - ) = self._get_quantizers() + ) = self._get_quantizers(fp8_output) # Get weight tensors fc1_weight = self.fc1_weight @@ -1528,7 +1574,7 @@ class LayerNormMLP(TransformerEngineBaseModule): return out, ln_out return out - def _get_quantizers(self): + def _get_quantizers(self, fp8_output): ( fc1_input_quantizer, fc1_weight_quantizer, @@ -1550,6 +1596,8 @@ class LayerNormMLP(TransformerEngineBaseModule): ) fc2_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_WEIGHT] fc2_weight_quantizer.internal = True + if fp8_output: + output_quantizer = 
self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_OUTPUT] if torch.is_grad_enabled(): grad_fc2_output_quantizer = self.quantizers["scaling_bwd"][ tex.FP8BwdTensors.GRAD_OUTPUT1 diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index f96355a..91dfe92 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -176,31 +176,29 @@ class _Linear(torch.autograd.Function): nvtx_range_pop(f"{nvtx_label}.input_cast_comm") # Cast weight to expected dtype - weightmat = weight if not fp8: - weightmat = cast_if_needed(weightmat, activation_dtype) + weightmat = cast_if_needed(weight, activation_dtype) else: - if not isinstance(weight, QuantizedTensor): - # Configure quantizer - if weight_quantizer is not None: - columnwise_usage = is_grad_enabled and inp.requires_grad - if not columnwise_usage: - columnwise_usage = ( - is_fp8_activation_recompute_enabled() - and not in_fp8_activation_recompute_phase() - ) - weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) - - # FP8 cast to workspace buffer - update_workspace = is_first_microbatch is None or is_first_microbatch - weightmat = module.get_weight_workspace( - tensor=weight, - quantizer=weight_quantizer, - cache_name=(None if is_first_microbatch is None else "weight"), - update_workspace=update_workspace, - skip_update_flag=skip_fp8_weight_update, - fsdp_group=fsdp_group, - ) + # Configure quantizer + if weight_quantizer is not None: + columnwise_usage = is_grad_enabled and inp.requires_grad + if not columnwise_usage: + columnwise_usage = ( + is_fp8_activation_recompute_enabled() + and not in_fp8_activation_recompute_phase() + ) + weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) + + # FP8 cast to workspace buffer + update_workspace = is_first_microbatch is None or is_first_microbatch + weightmat = module.get_weight_workspace( + tensor=weight, + quantizer=weight_quantizer, + cache_name=(None if is_first_microbatch is None else "weight"), + update_workspace=update_workspace, + skip_update_flag=skip_fp8_weight_update, + fsdp_group=fsdp_group, + ) # Cast bias to expected dtype bias_dtype = activation_dtype @@ -259,6 +257,7 @@ class _Linear(torch.autograd.Function): nvtx_range_pop(f"{nvtx_label}.gemm") if is_grad_enabled: + ctx.weight_quantizer = weight_quantizer saved_inputmat = None ctx.backward_input_needs_gather = ( @@ -274,6 +273,11 @@ class _Linear(torch.autograd.Function): inputmat.update_usage(rowwise_usage=False) saved_inputmat = inputmat + # Weight with column-wise usage is needed for dgrad GEMM. 
+ if inp.requires_grad: + if isinstance(weightmat, QuantizedTensor): + weightmat.update_usage(columnwise_usage=True) + if cpu_offloading: set_offloading_param(weight, "weight_offloading", True) set_offloading_param(weightmat, "weight_offloading", True) @@ -530,6 +534,12 @@ class _Linear(torch.autograd.Function): recipe.fp8_gemm_dgrad.use_split_accumulator ) + if ctx.weight_quantizer is not None and isinstance(weight_fp8, QuantizedTensor): + weight_fp8.update_usage( + rowwise_usage=ctx.weight_quantizer.rowwise_usage, + columnwise_usage=ctx.weight_quantizer.columnwise_usage, + ) + dgrad, *_, rs_out = general_gemm( weight_fp8, grad_output, @@ -1077,6 +1087,13 @@ class Linear(TransformerEngineBaseModule): if skip_fp8_weight_update is not None: is_first_microbatch = False + if self.ub_overlap_rs_fprop: + if get_ub(self.ub_name + "_fprop").is_fp8_ubuf(): + fp8_output = True + if self.ub_overlap_rs_dgrad: + if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf(): + fp8_grad = True + with self.prepare_forward( inp, allow_non_contiguous=isinstance(inp, QuantizedTensor), diff --git a/transformer_engine/pytorch/tensor/__init__.py b/transformer_engine/pytorch/tensor/__init__.py index 610ec2a..22b86fb 100644 --- a/transformer_engine/pytorch/tensor/__init__.py +++ b/transformer_engine/pytorch/tensor/__init__.py @@ -7,6 +7,7 @@ import torch from .quantized_tensor import QuantizedTensor, Quantizer +from .utils import cast_master_weights_to_fp8, replace_raw_data __all__ = [ "QuantizedTensor", diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index e45010b..2fb1283 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -185,9 +185,9 @@ class Float8CurrentScalingQuantizer(Quantizer): """ - """Scaling factor to multiply when quantizing to FP8""" + """Workspace buffer for FP8 scaling factor""" scale: torch.Tensor - """Max-abs value from last FP8 cast""" + """Workspace buffer for max-abs value""" amax: torch.Tensor """FP8 datatype""" dtype: TE_DType diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py new file mode 100644 index 0000000..8dd04b5 --- /dev/null +++ b/transformer_engine/pytorch/tensor/utils.py @@ -0,0 +1,315 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Helper functions for using fp8 tensors as weights""" + +import torch + +import transformer_engine_torch as tex +from transformer_engine_torch import multi_tensor_scale, multi_tensor_compute_scale_and_scale_inv + +from .quantized_tensor import QuantizedTensor +from .float8_tensor import Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer +from .mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer +from ..optimizers.multi_tensor_apply import multi_tensor_applier + + +def replace_raw_data(tensor: QuantizedTensor, new_raw_data: torch.Tensor): + r"""Change a quantized tensor's data buffer while preserving values + + This function modifies only the address space of the underlying + raw data and does not alter any other tensor attributes or values. + + This may be used for custom buffer allocations, e.g. packing + multiple parameter tensors together into a single contiguous + buffer for ZeRO-2. 
+ + """ + if isinstance(tensor, Float8Tensor): + old_raw_data = tensor._data + assert old_raw_data.dtype == new_raw_data.dtype, "The data types of raw data don't match" + new_raw_data.detach().copy_(old_raw_data) + tensor._data = new_raw_data + del old_raw_data + elif isinstance(tensor, MXFP8Tensor): + raise NotImplementedError("replace_raw_data for MXFP8Tensor is not supported yet") + else: + raise ValueError(f"replace_raw_data for {type(tensor)} is not supported yet") + + +def cast_master_weights_to_fp8( + model_weights, master_weights, start_offsets, group, fsdp_shard_model_weights=None +): + r"""Helper function to cast master weights to FP8 primary weights. + + This is intended for use with ZeRO/FSDP. Each rank has a shard of + the master weights (possibly empty) and a full copy of the model + weights. + + Parameters + ---------- + model_weights : list of FP8 weights. + master_weights : list of master weights. Typically they are FP32 weights. + start_offsets : list of integers, the starting index of the master weight in the model weight. + master_weight may be smaller than model_weight because it could be distributed + across multiple ranks. These offsets indicate which part of the model_weight + should be updated. + group : The distributed group to do amax reduction. Typically it's the data parallel + group. + fsdp_shard_model_weights : list of FSDP shard model weights. If None, it means that the model weights are + not sharded. Otherwise, it means that the model weights are sharded and we get + target model weights data storage using the FSDP shard model weights. + + """ + + delayed_scaling_params = [] + current_scaling_params = [] + + if fsdp_shard_model_weights is None: + use_fsdp_shard_model_weights = False + fsdp_shard_model_weights = [None] * len(model_weights) + else: + use_fsdp_shard_model_weights = True + + for model_weight, master_weight, start_offset, fsdp_shard_model_weight in zip( + model_weights, master_weights, start_offsets, fsdp_shard_model_weights + ): + # Clear `_high_precision_init_val` of model_weight automatically. + # - Master weights are initialized from model weights, if we use fp8 primary weights to + # initialize master weights, the numerical values of master weights are not consistent + # with the numerical values when we initialize them from bf16/fp16 weights. + # - So we add a `_high_precision_init_val` attribute to each model weight to store the + # original bf16/fp16 weight on cpu before casting it to fp8. And users can use + # `get_high_precision_init_val` to get this cpu tensor. + # - This cpu tensor is not needed once the master weight is initialized, so users should + # call `clear_high_precision_init_val` to remove it after master weight is initialized. + # - In case users don't call `clear_high_precision_init_val`, we will clear it automatically + # here. It's safe to clear the `_high_precision_init_val` at this time because this + # function is supposed to be called after the master weights are initialized and updated. + if hasattr(model_weight, "clear_high_precision_init_val"): + model_weight.clear_high_precision_init_val() + + if master_weight is not None: + # When not using fp8_primary_weights, the master_weight (fp32) is first cast to + # bf16/fp16, and then cast to fp8 during forward. Although it's not necessary when + # fp8_primary_weights is enabled, we still keep this logic to keep numerical + # consistency. So here we cast the master_weight to model_weight.dtype. 
+ master_weight = master_weight.to(model_weight.dtype) + + quantizer = model_weight._get_quantizer() + if isinstance(quantizer, Float8Quantizer): + delayed_scaling_params.append( + (model_weight, master_weight, start_offset, fsdp_shard_model_weight) + ) + elif isinstance(quantizer, Float8CurrentScalingQuantizer): + current_scaling_params.append( + (model_weight, master_weight, start_offset, fsdp_shard_model_weight) + ) + elif isinstance(quantizer, MXFP8Quantizer): + raise NotImplementedError( + "cast_master_weights_to_fp8 for MXFP8BlockScaling is not supported yet" + ) + else: + raise ValueError( + f"cast_master_weights_to_fp8 for {type(quantizer)} is not supported yet" + ) + + if len(delayed_scaling_params) > 0: + _cast_master_weights_to_fp8_delayed_scaling( + delayed_scaling_params, group, use_fsdp_shard_model_weights + ) + if len(current_scaling_params) > 0: + _cast_master_weights_to_fp8_current_scaling( + current_scaling_params, group, use_fsdp_shard_model_weights + ) + + +def _cast_master_weights_to_fp8_delayed_scaling(params, group, use_fsdp_shard_model_weights=False): + r"""Helper function to cast master weights to FP8 primary weights for delayed scaling. + + Parameters + ---------- + params : List of tuple, each tuple contains a model weight, a master weight, and an offset + indicating the starting index of the master weight in the model weight. + group : The distributed group to do amax reduction. Typically it's the data parallel + group. + use_fsdp_shard_model_weights : bool, if True, it means that the model weights are sharded. + """ + + # Collect amaxes to do reduce-max among dp group. + # Collect scales and scale_invs to update scale_invs of the fp8 weights. + amaxes, scales, scale_invs = [], [], [] + + for model_weight, master_weight, start_offset, shard_model_weight_raw in params: + # Reset transpose cache for all model weights. + # We cannot create transpose cache here because users (like megatron) may want to overlap + # the all-gather of model weights and forward process, so the model weight is not updated + # currently. + model_weight._reset_caches() + + quantizer = model_weight._get_quantizer() + + amaxes.append(quantizer.amax.view(1)) + scales.append(quantizer.scale.view(1)) + scale_invs.append(model_weight._scale_inv.view(1)) + + # If master weight is None, it means that the master weight of the current model weight + # is in other DP ranks. + if master_weight is None: + continue + + # If master weight is not None, start_offset must be a valid value. + assert start_offset is not None + assert start_offset >= 0 + end_offset = start_offset + master_weight.numel() + assert end_offset <= model_weight.numel() + + # master_weight may be smaller than model_weight because it could be distributed across + # multiple ranks. So we need to create a dummy weight using the raw data from model_weight. + if not use_fsdp_shard_model_weights: + shard_model_weight_raw = model_weight._data.view(-1)[start_offset:end_offset] + shard_model_weight_fp8 = quantizer.create_tensor_from_data( + shard_model_weight_raw.view(1, -1), + model_weight.dtype, + ) + + # Cast master weight to fp8. + quantizer.update_quantized(master_weight.view(1, -1), shard_model_weight_fp8) + + if len(amaxes) > 0: + dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=amaxes[0].device) + + # Reduce amaxes. 
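# Non-fused restatement of the amax reduction and scale_inv refresh performed below, for
# readability; the real code packs via multi_tensor_scale to avoid one kernel launch per
# tensor. Assumes torch.distributed is already initialized.

import torch
import torch.distributed as dist

def reduce_amaxes_and_refresh_scale_invs(amaxes, scales, scale_invs, group):
    packed = torch.cat([a.view(1) for a in amaxes])                 # pack
    dist.all_reduce(packed, op=dist.ReduceOp.MAX, group=group)      # global max amax
    for amax, reduced in zip(amaxes, packed):
        amax.copy_(reduced)                                         # unpack
    for scale, scale_inv in zip(scales, scale_invs):
        scale_inv.copy_(1.0 / scale)                                # keep scale_inv in sync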
+ packed_amaxes = torch.empty(len(amaxes), dtype=torch.float32, device=amaxes[0].device) + packed_amax_views = [packed_amaxes[i].view(1) for i in range(len(amaxes))] + multi_tensor_applier( + multi_tensor_scale, dummy_overflow_buf, [amaxes, packed_amax_views], 1.0 + ) + torch.distributed.all_reduce( + packed_amaxes, + op=torch.distributed.ReduceOp.MAX, + group=group, + ) + multi_tensor_applier( + multi_tensor_scale, dummy_overflow_buf, [packed_amax_views, amaxes], 1.0 + ) + + # Update scale_invs. + packed_scales = torch.empty(len(scales), dtype=torch.float32, device=scales[0].device) + packed_scale_views = [packed_scales[i].view(1) for i in range(len(scales))] + multi_tensor_applier( + multi_tensor_scale, dummy_overflow_buf, [scales, packed_scale_views], 1.0 + ) + torch.reciprocal(packed_scales, out=packed_scales) + multi_tensor_applier( + multi_tensor_scale, dummy_overflow_buf, [packed_scale_views, scale_invs], 1.0 + ) + + +def _cast_master_weights_to_fp8_current_scaling(params, group, use_fsdp_shard_model_weights=False): + r"""Helper function to cast master weights to FP8 primary weights for current scaling. + + Parameters + ---------- + params : List of tuple, each tuple contains a model weight, a master weight, and an offset + indicating the starting index of the master weight in the model weight. + group : The distributed group to do amax reduction. Typically it's the data parallel + group. + use_fsdp_shard_model_weights : bool, if True, it means that the model weights are sharded. + """ + + # Parameter attributes + device = params[0][0].device + fp8_dtype = params[0][0]._get_quantizer().dtype + force_pow_2_scales = params[0][0]._get_quantizer().force_pow_2_scales + amax_epsilon = params[0][0]._get_quantizer().amax_epsilon + + # Create a dummy overflow buffer, it's needed by multi_tensor_applier. + dummy_overflow_buf = torch.zeros(1, dtype=torch.int, device=device) + + # Create a contiguous buffer to store amaxes temporarily, so we can perform all all-reduce + # NCCL kernels at once. + packed_amaxes = torch.zeros(len(params), dtype=torch.float32, device=device) + amaxes = [packed_amaxes[i : i + 1] for i in range(len(params))] + + # Collect scales and scale_invs to update them after amax reduction. + scales, scale_invs = [], [] + + # --------------------------------------------------------------------------------------------- + # Step 1: Iterate through all the none empty master weights and compute amax of them. Store the + # amaxes in a contiguous buffer. If the master weight is None, the corresponding amax + # will be set to 0. + # --------------------------------------------------------------------------------------------- + for (model_weight, master_weight, _, _), amax in zip(params, amaxes): + + # Make sure all the model weights have the same numerical options. + quantizer = model_weight._get_quantizer() + assert quantizer.dtype == fp8_dtype + assert quantizer.force_pow_2_scales == force_pow_2_scales + assert quantizer.amax_epsilon == amax_epsilon + + scales.append(quantizer.scale.view(1)) + scale_invs.append(model_weight._scale_inv.view(1)) + + # Compute amax of the master weight and store it in packed_amaxes. + if master_weight is not None: + tex.compute_amax(master_weight, amax) + + # --------------------------------------------------------------------------------------------- + # Step 2: Perform all-reduce on packed_amaxes to get the global amax. 
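# Per-tensor restatement (not the fused kernel) of what multi_tensor_compute_scale_and_scale_inv
# computes in Step 3 below, under the assumption that the power-of-two option rounds the scale
# down to a power of two: the scale maps the reduced amax onto the FP8 representable range.

import math

def compute_scale_and_scale_inv(amax, max_fp8, force_pow_2_scales, amax_epsilon):
    amax = max(amax, amax_epsilon)
    scale = max_fp8 / amax if amax > 0.0 else 1.0
    if force_pow_2_scales and scale > 0.0:
        scale = 2.0 ** math.floor(math.log2(scale))
    return scale, 1.0 / scale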
+ # --------------------------------------------------------------------------------------------- + torch.distributed.all_reduce(packed_amaxes, op=torch.distributed.ReduceOp.MAX, group=group) + + # --------------------------------------------------------------------------------------------- + # Step 3: Update scales and scale_invs. + # --------------------------------------------------------------------------------------------- + if fp8_dtype == tex.DType.kFloat8E4M3: + max_fp8 = 448.0 + elif fp8_dtype == tex.DType.kFloat8E5M2: + max_fp8 = 57344.0 + else: + raise ValueError(f"Unsupported FP8 dtype: {fp8_dtype}") + multi_tensor_applier( + multi_tensor_compute_scale_and_scale_inv, + dummy_overflow_buf, + [amaxes, scales, scale_invs], + max_fp8, + force_pow_2_scales, + amax_epsilon, + ) + + # --------------------------------------------------------------------------------------------- + # Step 4: Cast master weights to FP8. + # --------------------------------------------------------------------------------------------- + for (model_weight, master_weight, start_offset, model_weight_fragment), scale in zip( + params, scales + ): + # Reset transpose cache for all model weights. + # We cannot create transpose cache here because users (like megatron) may want to overlap + # the all-gather of model weights and forward process, so the model weight is not updated + # currently. + model_weight._reset_caches() + + # If master weight is None, it means that the master weight of the current model weight + # is in other DP ranks. + if master_weight is None: + continue + + # Cast master weight to FP8 + end_offset = start_offset + master_weight.numel() + if not use_fsdp_shard_model_weights: + model_weight_fragment = model_weight.reshape(-1)[start_offset:end_offset] + quantizer = Float8Quantizer( + scale=scale, + amax=torch.Tensor(), + fp8_dtype=model_weight._fp8_dtype, + ) + if use_fsdp_shard_model_weights and not isinstance(model_weight_fragment, Float8Tensor): + # NOTE: The fsdp shard model weight may be a unit8 tensor instead of + # a float8 tensor. We should handle this situation properly. + model_weight_fragment = quantizer.create_tensor_from_data( + model_weight_fragment.view(-1), + model_weight.dtype, + ) + quantizer.update_quantized(master_weight, model_weight_fragment)