diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 732f0a1..29f40bb 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -39,6 +39,7 @@ python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py" NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_paged_attn.py || test_fail "test_paged_attn.py" + if [ "$RET" -ne 0 ]; then echo "Error in the following test cases:$FAILED_CASES" exit 1 diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index 5776734..36d491e 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -26,6 +26,7 @@ python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py | python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py" # python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py" ### TODO Debug UB support with te.Sequential python3 -m pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py || test_fail "test_fused_attn_with_cp.py" +python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py" if [ "$RET" -ne 0 ]; then echo "Error in the following test cases:$FAILED_CASES" diff --git a/tests/pytorch/distributed/run_cast_master_weights_to_fp8.py b/tests/pytorch/distributed/run_cast_master_weights_to_fp8.py new file mode 100644 index 0000000..1b38f72 --- /dev/null +++ b/tests/pytorch/distributed/run_cast_master_weights_to_fp8.py @@ -0,0 +1,671 @@ +#!/usr/bin/python3 + +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
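+
+# This script is launched through torchrun by test_cast_master_weights_to_fp8.py,
+# roughly as `torchrun --nproc_per_node=2 run_cast_master_weights_to_fp8.py
+# --quantization fp8`.  It builds an FP8 model and a matching BF16 model, trains
+# both with the mini ZeRO-1 / FSDP style optimizers defined below, and checks that
+# updating the FP8 weights through cast_master_weights_to_fp8 yields exactly the
+# same losses as the BF16 model run under the same fp8_autocast recipe
+# (compared with atol=0, rtol=0).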
+ +import argparse +import datetime +import os +import sys + +import torch +from torch import nn +import torch.distributed as dist + +from transformer_engine.common.recipe import ( + DelayedScaling, + Float8CurrentScaling, + Format, + Recipe, +) +import transformer_engine.pytorch as te +from transformer_engine.pytorch.tensor import QuantizedTensor, cast_master_weights_to_fp8 +from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8Tensor, + Float8CurrentScalingQuantizer, +) +from transformer_engine.pytorch.tensor.utils import replace_raw_data + + +def _get_raw_data(quantized_tensor): + """Get the underlying data of a quantized tensor, used in zero-1 optimizer""" + if isinstance(quantized_tensor, Float8Tensor): + assert hasattr(quantized_tensor, "_data"), "Float8Tensor does not have _data attribute" + assert quantized_tensor._data.dtype == torch.uint8, "Float8Tensor _data must be uint8" + return quantized_tensor._data + else: + raise ValueError(f"Unsupported quantized tensor type: {type(quantized_tensor)}") + + +class MiniZero_1: + """A mini zero-1 optimizer implementation, just used for this test""" + + def __init__(self, weights, lr, dp_group): + self.rank = dist.get_rank(dp_group) + self.world_size = dist.get_world_size(dp_group) + + self.weights = weights + self.lr = lr + self.dp_group = dp_group + + # [self.offsets[i], self.offsets[i+1]) is the range of weights[i] in the global buffer + self.offsets = [0] + for weight in self.weights: + self.offsets.append(self.offsets[-1] + weight.numel()) + + # Padding to avoid global buffer cannot be divided by world size, so the offsets[-1] may + # not be the end range of the last weight. + if self.offsets[-1] % self.world_size != 0: + self.offsets[-1] += self.world_size - self.offsets[-1] % self.world_size + + self.master_weights = [] + # The start offset of the master weight in the weight + self.start_offsets = [] + # The overlapping area of the weight and this rank's local buffer + self.overlapping_areas = [] + + # The start and end of this rank's local buffer in the global buffer + rank_start = self.offsets[-1] // self.world_size * self.rank + rank_end = rank_start + self.offsets[-1] // self.world_size + + for weight, offset in zip(self.weights, self.offsets[:-1]): + if offset >= rank_end or (offset + weight.numel()) <= rank_start: + # This weight is not in this rank's local buffer + master_weight = None + start_offset = None + overlapping_area = None + else: + overlapping_start = max(rank_start, offset) + overlapping_end = min(rank_end, offset + weight.numel()) + length = overlapping_end - overlapping_start + start_offset = overlapping_start - offset + if isinstance(weight, QuantizedTensor): + # If weight is a FP8 tensor, we need to use the original high precision version + # to initialize the master weight. 
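+                    # The FP8 payload has already lost precision, but fp8_model_init was
+                    # called with preserve_high_precision_init_val=True, so the original
+                    # BF16 initialization is still available and the master weights start
+                    # from exactly the same values as in the BF16 reference model.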
+ high_precision_init_val = weight.get_high_precision_init_val().view(-1) + master_weight = high_precision_init_val.to(weight.device).float()[ + start_offset : start_offset + length + ] + else: + master_weight = ( + weight.detach().view(-1).float()[start_offset : start_offset + length] + ) + overlapping_area = (overlapping_start, overlapping_end) + self.master_weights.append(master_weight) + self.start_offsets.append(start_offset) + self.overlapping_areas.append(overlapping_area) + + # Create global buffer for grads reduce-scatter + self.grad_buffer = torch.empty( + [self.offsets[-1]], dtype=torch.float32, device=weights[0].device + ) + self.grad_buffer_slice = self.grad_buffer[rank_start:rank_end] + + # Create global buffer for weights all-gather + if isinstance(self.weights[0], QuantizedTensor): + weight_buffer_dtype = torch.uint8 + else: + weight_buffer_dtype = weights[0].dtype + self.weight_buffer = torch.empty( + [self.offsets[-1]], dtype=weight_buffer_dtype, device=weights[0].device + ) + self.weight_buffer_slice = self.weight_buffer[rank_start:rank_end] + + def step(self): + # ----------------------------------------------------------------------------------------- + # Step 1: Copy grads to the grad buffer + # ----------------------------------------------------------------------------------------- + for weight, offset in zip(self.weights, self.offsets[:-1]): + start = offset + end = offset + weight.numel() + self.grad_buffer[start:end].copy_(weight.main_grad.view(-1)) + + # ----------------------------------------------------------------------------------------- + # Step 2: Grads reduce-scatter + # ----------------------------------------------------------------------------------------- + # Don't use reduce_scatter directly to explicitly control the reduce order. + # dist.reduce_scatter_tensor(self.grad_buffer_slice, self.grad_buffer, op=dist.ReduceOp.AVG, + # group=self.dp_group) + buffers = [torch.empty_like(self.grad_buffer) for _ in range(self.world_size)] + dist.all_gather(buffers, self.grad_buffer, group=self.dp_group) + for i in range(1, self.world_size): + buffers[0] += buffers[i] + rank_start = self.offsets[-1] // self.world_size * self.rank + rank_end = rank_start + self.offsets[-1] // self.world_size + self.grad_buffer_slice.copy_(buffers[0][rank_start:rank_end]) + self.grad_buffer_slice /= self.world_size + + # ----------------------------------------------------------------------------------------- + # Step 3: Update master weights + # ----------------------------------------------------------------------------------------- + for master_weight, overlapping_area in zip(self.master_weights, self.overlapping_areas): + if master_weight is None: + # This weight's master weight is in other rank. 
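+                # Under ZeRO-1 each rank only keeps the master-weight slices that fall
+                # inside its local shard of the flattened buffer; slices owned by other
+                # ranks are updated there and come back via the weight all-gather below.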
+ continue + grad = self.grad_buffer[overlapping_area[0] : overlapping_area[1]] + master_weight -= grad * self.lr + + # ----------------------------------------------------------------------------------------- + # Step 4: Cast master weights to BF16 or FP8, depending on the type of the weight + # ----------------------------------------------------------------------------------------- + if isinstance(self.weights[0], QuantizedTensor): + # FP8 weights case + for i in range(1, len(self.weights)): + assert isinstance(self.weights[i], QuantizedTensor) + cast_master_weights_to_fp8( + self.weights, self.master_weights, self.start_offsets, self.dp_group + ) + else: + # BF16 weights case + for weight, master_weight, start_offset in zip( + self.weights, self.master_weights, self.start_offsets + ): + if master_weight is None: + continue + start = start_offset + end = start_offset + master_weight.numel() + weight.data.view(-1)[start:end].copy_(master_weight) + + # ----------------------------------------------------------------------------------------- + # Step 5: Copy the updated weights (not all weights) to the weight buffer + # ----------------------------------------------------------------------------------------- + for i in range(len(self.weights)): + master_weight = self.master_weights[i] + if master_weight is None: + continue + start_offset = self.start_offsets[i] + if isinstance(self.weights[i], QuantizedTensor): + weight = _get_raw_data(self.weights[i]) + else: + weight = self.weights[i] + weight_slice = weight.view(-1)[start_offset : start_offset + master_weight.numel()] + overlapping_start, overlapping_end = self.overlapping_areas[i] + self.weight_buffer[overlapping_start:overlapping_end].copy_(weight_slice) + + # ----------------------------------------------------------------------------------------- + # Step 6: Weight all-gather (FP8 or BF16) + # ----------------------------------------------------------------------------------------- + dist.all_gather_into_tensor( + self.weight_buffer, self.weight_buffer_slice, group=self.dp_group + ) + + # ----------------------------------------------------------------------------------------- + # Step 7: Copy the gathered weights from weight buffer to the actual weights + # ----------------------------------------------------------------------------------------- + for weight, offset in zip(self.weights, self.offsets[:-1]): + start = offset + end = offset + weight.numel() + if isinstance(weight, QuantizedTensor): + weight = _get_raw_data(weight) + weight.view(-1).data.copy_(self.weight_buffer[start:end]) + + +class MiniOptimizer: + + def __init__(self, weights, lr, dp_group): + self.world_size = dist.get_world_size(dp_group) + + self.weights = weights + self.lr = lr + self.dp_group = dp_group + + master_weights = [] + for weight in self.weights: + master_weights.append(weight.detach().float()) + self.master_weights = master_weights + + def step(self): + for weight, master_weight in zip(self.weights, self.master_weights): + main_grad = weight.main_grad + + # Don't use all-reduce directly to explicitly control the reduce order. 
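+            # Floating-point summation is not associative, so both optimizers gather the
+            # per-rank gradients and add them in rank order. This keeps the reductions of
+            # MiniZero_1 and MiniOptimizer bitwise identical, which is what allows the
+            # exact (atol=0, rtol=0) comparisons in _test_zero_1.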
+ # dist.all_reduce(main_grad, op=dist.ReduceOp.AVG, group=self.dp_group) + buffers = [torch.empty_like(main_grad) for _ in range(self.world_size)] + dist.all_gather(buffers, main_grad, group=self.dp_group) + for i in range(1, self.world_size): + buffers[0] += buffers[i] + main_grad.copy_(buffers[0]) + main_grad /= self.world_size + + master_weight -= main_grad * self.lr + weight.data.copy_(master_weight) + + +class MiniFSDP: + def __init__(self, weights, lr, dp_group): + rank = dist.get_rank(dp_group) + world_size = dist.get_world_size(dp_group) + + self.weights = weights + self.lr = lr + self.dp_group = dp_group + + # Flatten the weights and pad to align with world size + raw_data_list = [ + _get_raw_data(w).view(-1) if isinstance(w, QuantizedTensor) else w.view(-1) + for w in weights + ] + if isinstance(weights[0], QuantizedTensor): + raw_data_list = [_get_raw_data(w).view(-1) for w in weights] + else: + raw_data_list = [w.view(-1) for w in weights] + self.flatten_weight, original_length = self._flatten_tensors_with_pad(raw_data_list) + + # Split flattened weights into shards + self.local_weight_shard = torch.chunk(self.flatten_weight, world_size)[rank] + self.local_main_grad_shard = torch.zeros_like(self.local_weight_shard) + shard_size = self.flatten_weight.size(0) // world_size + + # Map original tensors to flattened indices + tensor_indices = [] + cumulative_length = 0 + for tensor in raw_data_list: + length = tensor.size(0) + tensor_indices.append((cumulative_length, cumulative_length + length)) + cumulative_length += length + + # Build shard index mappings + self.weight_indices = [] + self.shard_indices = [] + for idx, (start, end) in enumerate(tensor_indices): + shard_start = rank * shard_size + shard_end = shard_start + shard_size + adjusted_end = min(shard_end, original_length) + + if start <= adjusted_end and end >= shard_start: + start_idx = max(start, shard_start) + end_idx = min(end, adjusted_end) + self.weight_indices.append((start_idx - start, end_idx - start)) + self.shard_indices.append((start_idx - shard_start, end_idx - shard_start)) + else: + self.weight_indices.append((None, None)) + self.shard_indices.append((None, None)) + + if isinstance(weights[idx], QuantizedTensor): + replace_raw_data( + weights[idx], self.flatten_weight[start:end].view(weights[idx].shape) + ) + else: + weights[idx].data = self.flatten_weight[start:end].view(weights[idx].shape) + + # Initialize local model weights and high-precision master weights + self.local_weights = [] + self.master_weights = [] + for i, weight in enumerate(self.weights): + weight_start, weight_end = self.weight_indices[i] + shard_start, shard_end = self.shard_indices[i] + if shard_start is not None and shard_end is not None: + local_weight_shard = self.local_weight_shard[shard_start:shard_end] + self.local_weights.append(local_weight_shard) + + if isinstance(weight, QuantizedTensor): + high_precision_init_val = weight.get_high_precision_init_val().view(-1) + master_weight_shard = high_precision_init_val.to(weight.device).float()[ + weight_start:weight_end + ] + else: + master_weight_shard = weight.detach().view(-1).float()[weight_start:weight_end] + self.master_weights.append(master_weight_shard) + else: + self.local_weights.append(None) + self.master_weights.append(None) + setattr( + weight, "main_grad", torch.zeros_like(weight, dtype=torch.float32, device="cuda") + ) + + def _flatten_tensors_with_pad(self, tensors): + """ + Flatten the list of tensors and pad them to align with the world size. 
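+        Padding with zeros makes the flattened length divisible by the world size so
+        the buffer can be split into equally sized per-rank shards.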
+ + Args: + tensors (list): List of tensors to flatten. + + Returns: + tuple: Flattened tensor and its original length before padding. + """ + world_size = dist.get_world_size(self.dp_group) + + flatten_tensor = torch.cat(tensors) + original_length = flatten_tensor.size(0) + + padding_needed = (world_size - original_length % world_size) % world_size + if padding_needed > 0: + flatten_tensor = torch.cat( + [flatten_tensor, torch.zeros(padding_needed, dtype=flatten_tensor.dtype)] + ) + + return flatten_tensor, original_length + + def zero_grad(self): + for weight in self.weights: + weight.grad = None + weight.main_grad.zero_() + + def step(self): + """ + Perform an optimization step for the distributed sharded model. + + This method includes: + 1. Gradient reduce-scatter: Synchronize gradients across all processes. + 2. Master weight update: Update high-precision master weights using local gradients. + 3. Precision casting: Cast updated master weights to FP8 or BF16 precision. + 4. Weight synchronization: All-gather updated weights across all processes. + + Returns: + None + """ + # Step 1: Reduce-scatter the gradients + main_grad_buffer, _ = self._flatten_tensors_with_pad( + [weight.main_grad.view(-1) for weight in self.weights] + ) + main_grad_buffer = main_grad_buffer.to(self.local_main_grad_shard.dtype) + dist.reduce_scatter_tensor( + self.local_main_grad_shard, main_grad_buffer, group=self.dp_group + ) + + # Step 2: Update the master weights + for weight, master_weight, (shard_start, shard_end) in zip( + self.weights, self.master_weights, self.shard_indices + ): + if master_weight is None: + continue + + # Extract the local gradient shard for this weight + grad = self.local_main_grad_shard[shard_start:shard_end] + + # Update the master weight using gradient descent + master_weight -= grad * self.lr + + # Step 3: Cast master weights to FP8 or BF16 precision + if isinstance(self.weights[0], QuantizedTensor): + local_weights = [] + for local_weight in self.local_weights: + if local_weight is None: + local_weights.append(None) + continue + + local_weights.append(local_weight) + + cast_master_weights_to_fp8( + self.weights, + self.master_weights, + [idx[0] for idx in self.weight_indices], + self.dp_group, + local_weights, + ) + else: + for weight, master_weight in zip(self.local_weights, self.master_weights): + if master_weight is None: + continue + + # Copy updated master weights to local weights + weight.data.copy_(master_weight) + + # Step 4: All-gather updated weights across processes + dist.all_gather_into_tensor( + self.flatten_weight, self.local_weight_shard, group=self.dp_group + ) + + +def _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group): + rank = dist.get_rank(dp_group) + world_size = dist.get_world_size(dp_group) + + # Configuration constants + NUM_STEPS = 100 + SEED = 12345 + + torch.manual_seed(SEED) + torch.cuda.manual_seed(SEED) + + mock_groups = [dist.new_group(ranks=[i]) for i in range(world_size)] + mock_group = mock_groups[rank] + + linear_kwargs = { + "params_dtype": torch.bfloat16, + "bias": False, + "fuse_wgrad_accumulation": False, + } + + # Create model with FP8 weights + with te.fp8.fp8_model_init( + enabled=quantization is not None, + recipe=quantization_recipe(quantization), + preserve_high_precision_init_val=True, + ): + model_fp8 = nn.Sequential( + te.Linear(128, 256, **linear_kwargs), + te.Linear(256, 256 * 3, **linear_kwargs), + te.Linear(256 * 3, 128, **linear_kwargs), + ) + + # Create model with BF16 weights + model = nn.Sequential( + 
te.Linear(128, 256, **linear_kwargs), + te.Linear(256, 256 * 3, **linear_kwargs), + te.Linear(256 * 3, 128, **linear_kwargs), + ) + + # Make sure the BF16 model and FP8 model have the same initial weights + for w_fp8, w in zip(model_fp8.parameters(), model.parameters()): + high_precision_init_val = w_fp8.get_high_precision_init_val() + w.data.copy_(high_precision_init_val) + + optimizer_fp8 = MiniFSDP([w for w in model_fp8.parameters()], 10.0, dp_group) + optimizer = MiniFSDP([w for w in model.parameters()], 10.0, dp_group) + + for _ in range(100): + optimizer_fp8.zero_grad() + optimizer.zero_grad() + + inputs = [ + torch.randn(16, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size) + ] + # Choose based on rank to make sure the inputs of different ranks are different. + x = inputs[rank] + + with te.fp8.fp8_autocast( + enabled=quantization is not None, + fp8_recipe=quantization_recipe(quantization), + fp8_group=mock_group, + ): + y_fp8 = model_fp8(x) + + with te.fp8_autocast( + enabled=quantization is not None, + fp8_recipe=quantization_recipe(quantization), + fp8_group=mock_group, + ): + y = model(x) + + targets = [torch.randn_like(y) for _ in range(world_size)] + # Choose based on rank to make sure the targets of different ranks are different. + target = targets[rank] + loss_fp8 = nn.MSELoss()(y_fp8, target) + loss = nn.MSELoss()(y, target) + + loss_fp8.backward() + loss.backward() + + optimizer_fp8.step() + optimizer.step() + + torch.testing.assert_close(loss_fp8, loss, atol=0, rtol=0) + + print( + f"✅ Successfully validated FSDP {NUM_STEPS} training steps with" + f" {quantization} quantization" + ) + + +def _test_zero_1(dp_group): + """Make sure the implementation of zero-1 optimizer is correct""" + rank = dist.get_rank(dp_group) + world_size = dist.get_world_size(dp_group) + + torch.manual_seed(12345) + torch.cuda.manual_seed(12345) + + weights = [ + torch.randn(256 * 256, dtype=torch.bfloat16, device="cuda"), + torch.randn(256 * 256 * 3, dtype=torch.bfloat16, device="cuda"), + torch.randn(256 * 256 * 2 - 1, dtype=torch.bfloat16, device="cuda"), + ] + + weights_1 = weights + weights_2 = [weight.clone() for weight in weights] + + lr = 1.0 + optimizer_1 = MiniZero_1(weights_1, lr, dp_group) + optimizer_2 = MiniOptimizer(weights_2, lr, dp_group) + + for _ in range(100): + for w1, w2 in zip(weights_1, weights_2): + main_grads = [ + torch.randn_like(w1, dtype=torch.float32, device="cuda") for _ in range(world_size) + ] + # Choose based on rank to make sure the grads of different ranks are different. 
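+            # All ranks share the same RNG seed, so they generate an identical list of
+            # candidate gradients and then index it with their own rank: the grads differ
+            # across ranks, while w1 and w2 always receive the same tensor.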
+ main_grad = main_grads[rank] + w1.main_grad = main_grad + w2.main_grad = main_grad + + optimizer_1.step() + optimizer_2.step() + + for w1, w2 in zip(weights_1, weights_2): + torch.testing.assert_close(w1, w2, atol=0, rtol=0) + + +def quantization_recipe(quantization) -> Recipe: + """Quantization recipe setup""" + if quantization == "fp8": + return DelayedScaling( + fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max" + ) + elif quantization == "fp8_cs": + return Float8CurrentScaling() + else: + raise ValueError(f"Unsupported quantization: {quantization}") + + +def _test_cast_master_weights_to_fp8(quantization, dp_group): + rank = dist.get_rank(dp_group) + world_size = dist.get_world_size(dp_group) + + torch.manual_seed(12345) + torch.cuda.manual_seed(12345) + + mock_groups = [dist.new_group(ranks=[i]) for i in range(world_size)] + mock_group = mock_groups[rank] + + linear_kwargs = {"params_dtype": torch.bfloat16, "bias": False, "fuse_wgrad_accumulation": True} + + # Create model with FP8 weights + with te.fp8.fp8_model_init( + enabled=quantization is not None, + recipe=quantization_recipe(quantization), + preserve_high_precision_init_val=True, + ): + model_fp8 = nn.Sequential( + te.Linear(128, 256, **linear_kwargs), + te.Linear(256, 256 * 3, **linear_kwargs), + te.Linear(256 * 3, 128, **linear_kwargs), + ) + + # Create model with BF16 weights + model = nn.Sequential( + te.Linear(128, 256, **linear_kwargs), + te.Linear(256, 256 * 3, **linear_kwargs), + te.Linear(256 * 3, 128, **linear_kwargs), + ) + + # Make sure the BF16 model and FP8 model have the same initial weights + for w_fp8, w in zip(model_fp8.parameters(), model.parameters()): + high_precision_init_val = w_fp8.get_high_precision_init_val() + w.data.copy_(high_precision_init_val) + + # Allocate main_grads for each weight + for w_fp8, w in zip(model_fp8.parameters(), model.parameters()): + w_fp8.main_grad = torch.zeros_like(w_fp8, dtype=torch.float32, device="cuda") + w.main_grad = torch.zeros_like(w, dtype=torch.float32, device="cuda") + + optimizer_fp8 = MiniZero_1([w for w in model_fp8.parameters()], 10.0, dp_group) + optimizer = MiniZero_1([w for w in model.parameters()], 10.0, dp_group) + + for _ in range(100): + for w_fp8, w in zip(model_fp8.parameters(), model.parameters()): + w_fp8.main_grad.zero_() + w.main_grad.zero_() + + inputs = [ + torch.randn(16, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size) + ] + # Choose based on rank to make sure the inputs of different ranks are different. + x = inputs[rank] + + with te.fp8.fp8_autocast( + enabled=quantization is not None, + fp8_recipe=quantization_recipe(quantization), + fp8_group=mock_group, + ): + y_fp8 = model_fp8(x) + + with te.fp8_autocast( + enabled=quantization is not None, + fp8_recipe=quantization_recipe(quantization), + fp8_group=mock_group, + ): + y = model(x) + + targets = [torch.randn_like(y) for _ in range(world_size)] + # Choose based on rank to make sure the targets of different ranks are different. 
+ target = targets[rank] + loss_fp8 = nn.MSELoss()(y_fp8, target) + loss = nn.MSELoss()(y, target) + + loss_fp8.backward() + loss.backward() + + optimizer_fp8.step() + optimizer.step() + + torch.testing.assert_close(loss_fp8, loss, atol=0, rtol=0) + + +def main(argv=None, namespace=None): + WORLD_RANK = int(os.getenv("RANK", "0")) + WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) + LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) + LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1")) + + assert WORLD_SIZE == LOCAL_SIZE # this test supports only 1 node + assert LOCAL_SIZE <= torch.cuda.device_count() + dist_init_kwargs = { + "backend": "nccl", + "rank": WORLD_RANK, + "world_size": WORLD_SIZE, + "timeout": datetime.timedelta(seconds=30), + } + dist_init_kwargs["init_method"] = "env://" + dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}") + assert dist.is_nccl_available() + torch.cuda.set_device(LOCAL_RANK) + dist.init_process_group(**dist_init_kwargs) + + parser = argparse.ArgumentParser() + parser.add_argument("--quantization", type=str, default=None, choices=["fp8", "fp8_cs"]) + args = parser.parse_args(argv, namespace) + + dp_group = dist.new_group(backend="nccl") + _test_zero_1(dp_group) + _test_cast_master_weights_to_fp8(args.quantization, dp_group) + _test_fsdp_cast_master_weights_to_fp8(args.quantization, dp_group) + + dist.destroy_process_group() + return 0 + + +if __name__ == "__main__": + + sys.exit(main()) diff --git a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py new file mode 100644 index 0000000..8ebe86b --- /dev/null +++ b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py @@ -0,0 +1,35 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import os +import subprocess +from pathlib import Path + +import pytest +import torch +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager + + +if torch.cuda.device_count() < 2: + pytest.skip("cast_master_weights_to_fp8 test needs at least 2 GPUs.") + +fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() + +TEST_ROOT = Path(__file__).parent.resolve() +NUM_PROCS: int = min(2, torch.cuda.device_count()) +LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"] + + +def _run_test(quantization): + test_path = TEST_ROOT / "run_cast_master_weights_to_fp8.py" + test_cmd = LAUNCH_CMD + [str(test_path)] + ["--quantization", quantization] + result = subprocess.run(test_cmd, env=os.environ, check=False) + assert result.returncode == 0 + + +@pytest.mark.parametrize("quantization", ["fp8", "fp8_cs"]) +def test_cast_master_weights_to_fp8(quantization): + if not fp8_available: + pytest.skip(reason_for_no_fp8) + _run_test(quantization) diff --git a/tests/pytorch/references/ref_per_tensor_cs.py b/tests/pytorch/references/ref_per_tensor_cs.py index 1895b31..dad0c42 100644 --- a/tests/pytorch/references/ref_per_tensor_cs.py +++ b/tests/pytorch/references/ref_per_tensor_cs.py @@ -8,12 +8,8 @@ import transformer_engine_torch as tex from transformer_engine.pytorch.constants import TE_DType_To_Torch -# compute amax and scale -def _ref_compute_amax_scale(x, quant_dtype, eps, pow_2_scales): - x_fp32 = x.to(torch.float32) - amax = torch.amax(torch.abs(x_fp32)).view(1) - assert amax.dtype == torch.float, "amax must be a float tensor." 
- fp8_max = torch.finfo(quant_dtype).max +# Compute scale and scale_inv from amax +def _ref_compute_scale_and_scale_inv_from_amax(amax, fp8_max, eps, pow_2_scales): # Clamping amax to avoid division by small numbers amax = torch.max(amax, torch.tensor(eps)) @@ -52,6 +48,20 @@ def _ref_compute_amax_scale(x, quant_dtype, eps, pow_2_scales): # Compute scale_inv scale_inv = torch.reciprocal(scale) + return scale, scale_inv + + +# compute amax and scale +def _ref_compute_amax_scale(x, quant_dtype, eps, pow_2_scales): + x_fp32 = x.to(torch.float32) + amax = torch.amax(torch.abs(x_fp32)).view(1) + assert amax.dtype == torch.float, "amax must be a float tensor." + fp8_max = torch.finfo(quant_dtype).max + + scale, scale_inv = _ref_compute_scale_and_scale_inv_from_amax(amax, fp8_max, eps, pow_2_scales) + # Clamping amax to avoid division by small numbers + amax = torch.max(amax, torch.tensor(eps)) + return scale, scale_inv, amax @@ -103,3 +113,7 @@ def ref_per_tensor_cs_cast( qx_t = _multi_dim_transpose(qx) sx_t = sx return qx, sx, qx_t, sx_t + + +def ref_compute_scale_and_scale_inv_from_amax(amax, fp8_max, eps, pow_2_scales): + return _ref_compute_scale_and_scale_inv_from_amax(amax, fp8_max, eps, pow_2_scales) diff --git a/tests/pytorch/test_fused_rope.py b/tests/pytorch/test_fused_rope.py index e236a29..5d1adf4 100644 --- a/tests/pytorch/test_fused_rope.py +++ b/tests/pytorch/test_fused_rope.py @@ -11,52 +11,6 @@ from transformer_engine.pytorch.dot_product_attention.rope import ( ) -def _get_thd_freqs_on_this_cp_rank( - cp_rank: int, cp_size: int, x: torch.Tensor, freqs: torch.Tensor -) -> torch.Tensor: - if cp_size > 1: - cp_seg = x.size(0) // 2 - full_seqlen = cp_size * x.size(0) - return torch.cat( - [ - freqs[cp_rank * cp_seg : (cp_rank + 1) * cp_seg], - freqs[full_seqlen - (cp_rank + 1) * cp_seg : full_seqlen - cp_rank * cp_seg], - ] - ) - else: - return freqs[: x.size(0)] - - -def apply_rotary_pos_emb_thd( - t: torch.Tensor, - cu_seqlens: torch.Tensor, - freqs: torch.Tensor, - cp_size: int = 1, - cp_rank: int = 0, -) -> torch.Tensor: - """A baseline implementation of applying RoPE for `thd` format. - - Args: - t (Tensor): Input tensor T is of shape [t, h, d] - cu_seqlens(Tensor): Cumulative sum of sequence lengths in a batch for `t`, - with shape [b + 1] and dtype torch.int32. - freqs (Tensor): Rotary Positional embedding tensor freq is of shape [max_s, 1, 1, d] - - Returns: - Tensor: Shape [t, h, d]. The input tensor after applying RoPE. 
- """ - cu_seqlens = cu_seqlens // cp_size - seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() - return torch.cat( - [ - apply_rotary_pos_emb( - x.unsqueeze(1), _get_thd_freqs_on_this_cp_rank(cp_rank, cp_size, x, freqs) - ) - for x in torch.split(t, seqlens) - ] - ).squeeze(1) - - # Gradient is a broadcasted scalar def _overlapping_grad(output: torch.Tensor) -> torch.Tensor: return output.sum() * 2 @@ -76,6 +30,8 @@ def _non_overlapping_grad(output: torch.Tensor) -> torch.Tensor: @pytest.mark.parametrize("transpose", [None, (0, 1), (2, 3)]) @pytest.mark.parametrize("tensor_format", ["sbhd", "bshd"]) @pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad]) +@pytest.mark.parametrize("cp_size", [1, 2]) +@pytest.mark.parametrize("interleaved", [True, False]) def test_fused_rope( dtype: torch.dtype, seq_length: int, @@ -85,6 +41,8 @@ def test_fused_rope( transpose: Union[Tuple, None], tensor_format: str, loss_func: Callable, + cp_size: int, + interleaved: bool, ) -> None: device = torch.device("cuda:0") batch_size, head_num = 2, 64 @@ -99,35 +57,46 @@ def test_fused_rope( t = t.transpose(*transpose).contiguous().transpose(*transpose) t.requires_grad = True - rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent) - emb = rotary_pos_emb(seq_length) - - # unfused - # The fused kernel computes in float32 internally, so we force the unfused func to use float32 - # for more accurate comparison - output_unfused = apply_rotary_pos_emb( - t.float(), emb, tensor_format=tensor_format, fused=False - ).to(dtype) - loss_unfused = loss_func(output_unfused) - loss_unfused.backward() - grad_unfused = t.grad.detach().clone() - t.grad = None - - # fused - output_fused = apply_rotary_pos_emb( - t, - emb, - tensor_format=tensor_format, - fused=True, - ) - loss_fused = loss_func(output_fused) - loss_fused.backward() - grad_fused = t.grad.detach().clone() - t.grad = None + rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved) + emb = rotary_pos_emb(seq_length * cp_size) + assert emb.is_contiguous() - torch.testing.assert_close(output_fused, output_unfused) - torch.testing.assert_close(grad_fused, grad_unfused) - assert output_fused.is_contiguous() + for cp_rank in range(cp_size): + # unfused + # The fused kernel computes in float32 internally, so we force the unfused func to use float32 + # for more accurate comparison + output_unfused = apply_rotary_pos_emb( + t.float(), + emb, + tensor_format=tensor_format, + interleaved=interleaved, + fused=False, + cp_size=cp_size, + cp_rank=cp_rank, + ).to(dtype) + loss_unfused = loss_func(output_unfused) + loss_unfused.backward() + grad_unfused = t.grad.detach().clone() + t.grad = None + + # fused + output_fused = apply_rotary_pos_emb( + t, + emb, + tensor_format=tensor_format, + interleaved=interleaved, + fused=True, + cp_size=cp_size, + cp_rank=cp_rank, + ) + loss_fused = loss_func(output_fused) + loss_fused.backward() + grad_fused = t.grad.detach().clone() + t.grad = None + + torch.testing.assert_close(output_fused, output_unfused) + torch.testing.assert_close(grad_fused, grad_unfused) + assert output_fused.is_contiguous() @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16]) @@ -135,7 +104,8 @@ def test_fused_rope( @pytest.mark.parametrize("rotary_percent", [0.5, 1.0]) @pytest.mark.parametrize("transpose", [None, (1, 2)]) @pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad]) -@pytest.mark.parametrize("cp_size", [1, 2, 3]) 
+@pytest.mark.parametrize("cp_size", [1, 2]) +@pytest.mark.parametrize("interleaved", [True, False]) def test_fused_rope_thd( dtype: torch.dtype, hidden_size: int, @@ -143,6 +113,7 @@ def test_fused_rope_thd( transpose: Union[Tuple, None], loss_func: Callable, cp_size: int, + interleaved: bool, ) -> None: device = torch.device("cuda:0") batch_size, head_num = 2, 64 @@ -170,15 +141,23 @@ def test_fused_rope_thd( t = t.transpose(*transpose).contiguous().transpose(*transpose) t.requires_grad = True - rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent) + rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved) emb = rotary_pos_emb(cu_seqlens_padded[-1]) + assert emb.is_contiguous() for cp_rank in range(cp_size): # unfused # The fused kernel computes in float32 internally, so we force the unfused func to use float32 # for more accurate comparison - output_unfused = apply_rotary_pos_emb_thd( - t.float(), cu_seqlens_padded, emb, cp_size, cp_rank + output_unfused = apply_rotary_pos_emb( + t.float(), + emb, + tensor_format="thd", + interleaved=interleaved, + fused=False, + cu_seqlens=cu_seqlens_padded, + cp_size=cp_size, + cp_rank=cp_rank, ).to(dtype) loss_unfused = loss_func(output_unfused) loss_unfused.backward() @@ -189,6 +168,7 @@ def test_fused_rope_thd( output_fused = apply_rotary_pos_emb( t, emb, + interleaved=interleaved, fused=True, tensor_format="thd", cu_seqlens=cu_seqlens_padded, diff --git a/tests/pytorch/test_multi_tensor.py b/tests/pytorch/test_multi_tensor.py index ecc06c3..4dc1ec0 100644 --- a/tests/pytorch/test_multi_tensor.py +++ b/tests/pytorch/test_multi_tensor.py @@ -9,6 +9,9 @@ import transformer_engine.pytorch as te import transformer_engine_torch as tex from transformer_engine.pytorch.optimizers import MultiTensorApply +from references.ref_per_tensor_cs import ref_compute_scale_and_scale_inv_from_amax + + input_size_pairs = [ (7777 * 77, 555 * 555), (777, 555), @@ -216,3 +219,42 @@ def test_multi_tensor_unscale_l2norm(input_size_pair, applier, repeat, in_type, if per_tensor: torch.testing.assert_close(norm_per_tensor, normab.broadcast_to(norm_per_tensor.shape)) assert overflow_buf.item() == 0 + + +@pytest.mark.parametrize("input_size_pair", input_size_pairs + [(1, 1)]) +@pytest.mark.parametrize("applier", appliers) +@pytest.mark.parametrize("repeat", [1, 55]) +@pytest.mark.parametrize("max_fp8", [448.0, 57344.0]) +@pytest.mark.parametrize("pow_2_scales", [False, True]) +@pytest.mark.parametrize("epsilon", [0.0, 100.0]) +def test_multi_tensor_compute_scale_and_scale_inv( + input_size_pair, applier, repeat, max_fp8, pow_2_scales, epsilon +): + sizea, sizeb = input_size_pair + device = torch.device("cuda") + overflow_buf = torch.zeros(1, dtype=torch.int32, device=device) + a = torch.randn([sizea], dtype=torch.float32, device=device).abs() + b = torch.randn([sizeb], dtype=torch.float32, device=device).abs() + + amax_list = [] + for i in range(repeat): + amax_list += [a.clone(), b.clone()] + + scale_list = [torch.empty_like(x) for x in amax_list] + scale_inv_list = [torch.empty_like(x) for x in amax_list] + + applier( + tex.multi_tensor_compute_scale_and_scale_inv, + overflow_buf, + [amax_list, scale_list, scale_inv_list], + max_fp8, + pow_2_scales, + epsilon, + ) + + for amax, scale, scale_inv in zip(amax_list, scale_list, scale_inv_list): + scale_ref, scale_inv_ref = ref_compute_scale_and_scale_inv_from_amax( + amax, max_fp8, epsilon, pow_2_scales + ) + torch.testing.assert_close(scale, scale_ref, rtol=0, atol=0) + 
torch.testing.assert_close(scale_inv, scale_inv_ref, rtol=0, atol=0) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 1e6250f..980eeef 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -36,7 +36,12 @@ from transformer_engine.common import recipe import transformer_engine_torch as tex from transformer_engine.pytorch.cpp_extensions import general_gemm from transformer_engine.pytorch.module.base import get_workspace -from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer +from transformer_engine.pytorch.tensor import QuantizedTensor +from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8Quantizer, + Float8CurrentScalingQuantizer, +) +from transformer_engine.pytorch.tensor.utils import replace_raw_data from test_numerics import reset_rng_states, dtype_tols # Only run FP8 tests on supported devices. @@ -1196,3 +1201,70 @@ def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False outputs.append(p.grad) return outputs + + +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +def test_replace_raw_data_for_float8tensor(): + """Test the functionality of replace_raw_data""" + torch.manual_seed(12345) + torch.cuda.manual_seed(12345) + + fp8_quantizer = Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda") + fp8_tensor = fp8_quantizer.make_empty([128, 128], dtype=torch.bfloat16, device="cuda") + random_bf16_data = torch.randn(fp8_tensor.shape, dtype=torch.bfloat16, device="cuda") + fp8_quantizer.update_quantized(random_bf16_data, fp8_tensor) + + attrs_to_check = ["_quantizer", "_fp8_dtype", "_scale_inv", "_transpose", "_transpose_invalid"] + attrs = {} + for attr in attrs_to_check: + attrs[attr] = getattr(fp8_tensor, attr) + + old_data = fp8_tensor._data + new_data = torch.empty_like(old_data) + replace_raw_data(fp8_tensor, new_data) + + # Make sure the new_data is properly assigned. + assert fp8_tensor._data.data_ptr() != old_data.data_ptr() + assert fp8_tensor._data.data_ptr() == new_data.data_ptr() + # Make sure the values are not changed. 
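+    # replace_raw_data is expected to copy the FP8 bytes into the new storage before
+    # swapping it in, so only the backing memory changes; this is what MiniFSDP in the
+    # distributed test relies on when re-pointing FP8 weights into a flattened buffer.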
+ torch.testing.assert_close(old_data, fp8_tensor._data, atol=0, rtol=0) + # Make sure other attributes are not changed (totally identical) + for attr in attrs_to_check: + assert id(getattr(fp8_tensor, attr)) == id(attrs[attr]) + + +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +def test_fp8_model_init_high_precision_init_val(): + """Test fp8_model_init with preserve_high_precision_init_val=True""" + with fp8_model_init(preserve_high_precision_init_val=True): + model = Linear(768, 768) + + weight = model.weight + + assert isinstance(weight, QuantizedTensor), "Weight should be QuantizedTensor" + assert hasattr(weight, "_high_precision_init_val"), "_high_precision_init_val not found" + assert hasattr(weight, "get_high_precision_init_val"), "get_high_precision_init_val() not found" + assert hasattr( + weight, "clear_high_precision_init_val" + ), "clear_high_precision_init_val() not found" + + high_precision = weight.get_high_precision_init_val() + assert high_precision.device.type == "cpu", "high_precision_init_val is not on the CPU" + + new_weight = weight._get_quantizer().make_empty( + shape=weight.shape, dtype=weight.dtype, device=weight.device + ) + weight._get_quantizer().update_quantized(high_precision.to(weight.device), new_weight) + + torch.testing.assert_close( + new_weight.dequantize(dtype=weight.dtype), + weight.dequantize(dtype=weight.dtype), + rtol=0, + atol=0, + ) + + weight.clear_high_precision_init_val() + assert weight.get_high_precision_init_val() is None, "clear_high_precision_init_val() not work" + assert not hasattr( + weight, "._high_precision_init_val" + ), "clear_high_precision_init_val() not work" diff --git a/transformer_engine/common/fused_rope/fused_rope.cu b/transformer_engine/common/fused_rope/fused_rope.cu index 7f35ddd..1ab6d4e 100644 --- a/transformer_engine/common/fused_rope/fused_rope.cu +++ b/transformer_engine/common/fused_rope/fused_rope.cu @@ -16,10 +16,11 @@ namespace transformer_engine { template __device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs, scalar_t *dst, - const int s_id, const int offset_block, - const int offset_block_dst, const int h, const int d, - const int d2, const int stride_h, const int stride_d, - const int o_stride_h, const int o_stride_d) { + const bool interleaved, const int s_id, + const int offset_block, const int offset_block_dst, + const int h, const int d, const int d2, const int stride_h, + const int stride_d, const int o_stride_h, + const int o_stride_d) { #pragma unroll for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { float v_cos, v_sin; @@ -29,9 +30,18 @@ __device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs int offset_src = offset_block + h_id * stride_h + d_id * stride_d; int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d; float v_src = src[offset_src]; - float v_src_rotate = (d_id + d2 / 2 < d2) - ? -static_cast(src[offset_src + (d2 / 2) * stride_d]) - : static_cast(src[offset_src + (d2 / 2 - d2) * stride_d]); + float v_src_rotate; + if (!interleaved) { + v_src_rotate = (d_id + d2 / 2 < d2) + ? -static_cast(src[offset_src + (d2 / 2) * stride_d]) + : static_cast(src[offset_src + (d2 / 2 - d2) * stride_d]); + } else { + v_src_rotate = (d_id % 2 == 0) + // d_id + 1 + ? 
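+                         // In interleaved mode adjacent elements form a rotation pair, so
+                         // even lanes read -x[d_id + 1] and odd lanes read x[d_id - 1],
+                         // instead of pairing d_id with d_id +/- d2/2 as in the
+                         // non-interleaved path above.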
-static_cast(src[offset_src + stride_d]) + // d_id - 1 + : static_cast(src[offset_src - stride_d]); + } dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin; } } @@ -52,22 +62,39 @@ __device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs template __device__ void fused_rope_block_backward(const scalar_t *src, const float *freqs, scalar_t *dst, - const int s_id, const int offset_block, - const int offset_block_dst, const int h, const int d, - const int d2, const int stride_h, const int stride_d, + const bool interleaved, const int s_id, + const int offset_block, const int offset_block_dst, + const int h, const int d, const int d2, + const int stride_h, const int stride_d, const int o_stride_h, const int o_stride_d) { #pragma unroll for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { float v_cos = cosf(freqs[s_id * d2 + d_id]); - float v_sin = (d_id + d2 / 2 < d2) ? sinf(freqs[s_id * d2 + d_id + d2 / 2]) - : -sinf(freqs[s_id * d2 + d_id + d2 / 2 - d2]); + float v_sin; + if (!interleaved) { + v_sin = (d_id + d2 / 2 < d2) ? sinf(freqs[s_id * d2 + d_id + d2 / 2]) + : -sinf(freqs[s_id * d2 + d_id + d2 / 2 - d2]); + } else { + v_sin = + (d_id % 2 == 0) ? sinf(freqs[s_id * d2 + d_id + 1]) : -sinf(freqs[s_id * d2 + d_id - 1]); + } #pragma unroll for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { int offset_src = offset_block + h_id * stride_h + d_id * stride_d; int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d; float v_src = src[offset_src]; - float v_src_rotate = (d_id + d2 / 2 < d2) ? src[offset_src + (d2 / 2) * stride_d] - : src[offset_src + (d2 / 2 - d2) * stride_d]; + float v_src_rotate; + if (!interleaved) { + v_src_rotate = (d_id + d2 / 2 < d2) + ? static_cast(src[offset_src + (d2 / 2) * stride_d]) + : static_cast(src[offset_src + (d2 / 2 - d2) * stride_d]); + } else { + v_src_rotate = (d_id % 2 == 0) + // d_id + 1 + ? 
static_cast(src[offset_src + stride_d]) + // d_id - 1 + : static_cast(src[offset_src - stride_d]); + } dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin; } } @@ -87,51 +114,33 @@ __device__ void fused_rope_block_backward(const scalar_t *src, const float *freq } template -__global__ void fused_rope_forward_kernel(const scalar_t *src, const float *freqs, scalar_t *dst, +__global__ void fused_rope_forward_kernel(const scalar_t *src, const int *cu_seqlens, + const float *freqs, scalar_t *dst, const bool interleaved, + const int cp_size, const int cp_rank, const int s, const int h, const int d, const int d2, - const int stride_s, const int stride_b, + const int stride_s_or_t, const int stride_b, const int stride_h, const int stride_d, - const int o_stride_s, const int o_stride_b, + const int o_stride_s_or_t, const int o_stride_b, const int o_stride_h, const int o_stride_d) { int s_id = blockIdx.x, b_id = blockIdx.y; - int offset_block = s_id * stride_s + b_id * stride_b; - int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b; - fused_rope_block_forward(src, freqs, dst, s_id, offset_block, offset_block_dst, h, d, d2, - stride_h, stride_d, o_stride_h, o_stride_d); -} - -template -__global__ void fused_rope_backward_kernel(const scalar_t *src, const float *freqs, scalar_t *dst, - const int h, const int d, const int d2, - const int stride_s, const int stride_b, - const int stride_h, const int stride_d, - const int o_stride_s, const int o_stride_b, - const int o_stride_h, const int o_stride_d) { - int s_id = blockIdx.x, b_id = blockIdx.y; - int offset_block = s_id * stride_s + b_id * stride_b; - int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b; - fused_rope_block_backward(src, freqs, dst, s_id, offset_block, offset_block_dst, h, d, d2, - stride_h, stride_d, o_stride_h, o_stride_d); -} - -template -__global__ void fused_rope_thd_forward_kernel(const scalar_t *src, const int *cu_seqlens, - const float *freqs, scalar_t *dst, const int cp_size, - const int cp_rank, const int h, const int d, - const int d2, const int stride_t, const int stride_h, - const int stride_d, const int o_stride_t, - const int o_stride_h, const int o_stride_d) { - int s_id = blockIdx.x, b_id = blockIdx.y; - int start = cu_seqlens[b_id] / cp_size; - int end = cu_seqlens[b_id + 1] / cp_size; - int t_id = s_id + start; - if (t_id >= end) return; - int offset_block = t_id * stride_t; - int offset_block_dst = t_id * o_stride_t; + int offset_block, offset_block_dst; + int cur_seqlens; + if (cu_seqlens != nullptr) { // THD + int start = cu_seqlens[b_id] / cp_size; + int end = cu_seqlens[b_id + 1] / cp_size; + int t_id = s_id + start; + if (t_id >= end) return; + offset_block = t_id * stride_s_or_t; + offset_block_dst = t_id * o_stride_s_or_t; + cur_seqlens = end - start; + } else { // SBHD/BSHD + offset_block = s_id * stride_s_or_t + b_id * stride_b; + offset_block_dst = s_id * o_stride_s_or_t + b_id * o_stride_b; + cur_seqlens = s; + } int s_id_for_freqs; if (cp_size > 1) { - int cur_seqlens = end - start; assert(cur_seqlens % 2 == 0); if (s_id < cur_seqlens / 2) { s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2; @@ -142,28 +151,37 @@ __global__ void fused_rope_thd_forward_kernel(const scalar_t *src, const int *cu } else { s_id_for_freqs = s_id; } - fused_rope_block_forward(src, freqs, dst, s_id_for_freqs, offset_block, offset_block_dst, h, d, - d2, stride_h, stride_d, o_stride_h, o_stride_d); + + fused_rope_block_forward(src, freqs, dst, interleaved, s_id_for_freqs, offset_block, + offset_block_dst, h, d, 
d2, stride_h, stride_d, o_stride_h, o_stride_d); } template -__global__ void fused_rope_thd_backward_kernel(const scalar_t *src, const int *cu_seqlens, - const float *freqs, scalar_t *dst, const int cp_size, - const int cp_rank, const int h, const int d, - const int d2, const int stride_t, const int stride_h, - const int stride_d, const int o_stride_t, - const int o_stride_h, const int o_stride_d) { +__global__ void fused_rope_backward_kernel( + const scalar_t *src, const int *cu_seqlens, const float *freqs, scalar_t *dst, + const bool interleaved, const int cp_size, const int cp_rank, const int s, const int h, + const int d, const int d2, const int stride_s_or_t, const int stride_b, const int stride_h, + const int stride_d, const int o_stride_s_or_t, const int o_stride_b, const int o_stride_h, + const int o_stride_d) { int s_id = blockIdx.x, b_id = blockIdx.y; - int start = cu_seqlens[b_id] / cp_size; - int end = cu_seqlens[b_id + 1] / cp_size; - int t_id = s_id + start; - if (t_id >= end) return; - int offset_block = t_id * stride_t; - int offset_block_dst = t_id * o_stride_t; + int offset_block, offset_block_dst; + int cur_seqlens; + if (cu_seqlens != nullptr) { // THD + int start = cu_seqlens[b_id] / cp_size; + int end = cu_seqlens[b_id + 1] / cp_size; + int t_id = s_id + start; + if (t_id >= end) return; + offset_block = t_id * stride_s_or_t; + offset_block_dst = t_id * o_stride_s_or_t; + cur_seqlens = end - start; + } else { // SBHD/BSHD + offset_block = s_id * stride_s_or_t + b_id * stride_b; + offset_block_dst = s_id * o_stride_s_or_t + b_id * o_stride_b; + cur_seqlens = s; + } int s_id_for_freqs; if (cp_size > 1) { - int cur_seqlens = end - start; assert(cur_seqlens % 2 == 0); if (s_id < cur_seqlens / 2) { s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2; @@ -174,193 +192,136 @@ __global__ void fused_rope_thd_backward_kernel(const scalar_t *src, const int *c } else { s_id_for_freqs = s_id; } - fused_rope_block_backward(src, freqs, dst, s_id_for_freqs, offset_block, offset_block_dst, h, d, - d2, stride_h, stride_d, o_stride_h, o_stride_d); + + fused_rope_block_backward(src, freqs, dst, interleaved, s_id_for_freqs, offset_block, + offset_block_dst, h, d, d2, stride_h, stride_d, o_stride_h, o_stride_d); } template -void fused_rope_forward_launcher(const scalar_t *input, const float *freqs, scalar_t *output, +void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, const float *freqs, + scalar_t *output, const NVTE_QKV_Format qkv_format, + const bool interleaved, const int cp_size, const int cp_rank, const int s, const int b, const int h, const int d, const int d2, - const int stride_s, const int stride_b, const int stride_h, - const int stride_d, const int o_stride_s, const int o_stride_b, - const int o_stride_h, const int o_stride_d, cudaStream_t stream) { + const int stride_s_or_t, const int stride_b, const int stride_h, + const int stride_d, cudaStream_t stream) { int warps_per_block = h < 16 ? 
4 : 8; dim3 blocks(s, b); dim3 threads(THREADS_PER_WARP, warps_per_block); + int o_stride_s_or_t, o_stride_b; + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + NVTE_CHECK(cu_seqlens != nullptr, "cu_seqlens is required for THD format"); + o_stride_s_or_t = h * d; + o_stride_b = 0; + } else if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) { + o_stride_s_or_t = b * h * d; + o_stride_b = h * d; + } else { + o_stride_s_or_t = h * d; + o_stride_b = s * h * d; + } + const int o_stride_h = d; + const int o_stride_d = 1; fused_rope_forward_kernel<<>>( - input, freqs, output, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, - o_stride_b, o_stride_h, o_stride_d); + input, cu_seqlens, freqs, output, interleaved, cp_size, cp_rank, s, h, d, d2, stride_s_or_t, + stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, o_stride_h, o_stride_d); NVTE_CHECK_CUDA(cudaGetLastError()); } template -void fused_rope_backward_launcher(const scalar_t *output_grads, const float *freqs, - scalar_t *input_grads, const int s, const int b, const int h, - const int d, const int d2, const int stride_s, const int stride_b, - const int stride_h, const int stride_d, const int o_stride_s, - const int o_stride_b, const int o_stride_h, const int o_stride_d, +void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_seqlens, + const float *freqs, scalar_t *input_grads, + const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank, const int s, const int b, + const int h, const int d, const int d2, const int stride_s_or_t, + const int stride_b, const int stride_h, const int stride_d, cudaStream_t stream) { int warps_per_block = h < 16 ? 4 : 8; dim3 blocks(s, b); dim3 threads(THREADS_PER_WARP, warps_per_block); + int o_stride_s_or_t, o_stride_b; + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + NVTE_CHECK(cu_seqlens != nullptr, "cu_seqlens is required for THD format"); + o_stride_s_or_t = h * d; + o_stride_b = 0; + } else if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) { + o_stride_s_or_t = b * h * d; + o_stride_b = h * d; + } else { + o_stride_s_or_t = h * d; + o_stride_b = s * h * d; + } + const int o_stride_h = d; + const int o_stride_d = 1; fused_rope_backward_kernel<<>>( - output_grads, freqs, input_grads, h, d, d2, stride_s, stride_b, stride_h, stride_d, - o_stride_s, o_stride_b, o_stride_h, o_stride_d); - NVTE_CHECK_CUDA(cudaGetLastError()); -} - -template -void fused_rope_thd_forward_launcher(const scalar_t *input, const int *cu_seqlens, - const float *freqs, scalar_t *output, const int cp_size, - const int cp_rank, const int max_s, const int b, const int h, - const int d, const int d2, const int stride_t, - const int stride_h, const int stride_d, const int o_stride_t, - const int o_stride_h, const int o_stride_d, - cudaStream_t stream) { - int warps_per_block = h < 16 ? 
4 : 8; - dim3 blocks(max_s, b); - dim3 threads(THREADS_PER_WARP, warps_per_block); - - fused_rope_thd_forward_kernel<<>>( - input, cu_seqlens, freqs, output, cp_size, cp_rank, h, d, d2, stride_t, stride_h, stride_d, - o_stride_t, o_stride_h, o_stride_d); + output_grads, cu_seqlens, freqs, input_grads, interleaved, cp_size, cp_rank, s, h, d, d2, + stride_s_or_t, stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, o_stride_h, + o_stride_d); NVTE_CHECK_CUDA(cudaGetLastError()); } -template -void fused_rope_thd_backward_launcher(const scalar_t *output_grads, const int *cu_seqlens, - const float *freqs, scalar_t *input_grads, const int cp_size, - const int cp_rank, const int max_s, const int b, const int h, - const int d, const int d2, const int stride_t, - const int stride_h, const int stride_d, const int o_stride_t, - const int o_stride_h, const int o_stride_d, - cudaStream_t stream) { - int warps_per_block = h < 16 ? 4 : 8; - dim3 blocks(max_s, b); - dim3 threads(THREADS_PER_WARP, warps_per_block); - - fused_rope_thd_backward_kernel<<>>( - output_grads, cu_seqlens, freqs, input_grads, cp_size, cp_rank, h, d, d2, stride_t, stride_h, - stride_d, o_stride_t, o_stride_h, o_stride_d); - NVTE_CHECK_CUDA(cudaGetLastError()); -} - -void fused_rope_forward(const Tensor &input, const Tensor &freqs, Tensor *output, const int s, - const int b, const int h, const int d, const int d2, const int stride_s, - const int stride_b, const int stride_h, const int stride_d, - const int o_stride_s, const int o_stride_b, const int o_stride_h, - const int o_stride_d, cudaStream_t stream) { +void fused_rope_forward(const Tensor &input, const Tensor &cu_seqlens, const Tensor &freqs, + Tensor *output, const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank, const int s, const int b, const int h, + const int d, const int d2, const int stride_s_or_t, const int stride_b, + const int stride_h, const int stride_d, cudaStream_t stream) { TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( input.data.dtype, scalar_t, fused_rope_forward_launcher(reinterpret_cast(input.data.dptr), + reinterpret_cast(cu_seqlens.data.dptr), reinterpret_cast(freqs.data.dptr), - reinterpret_cast(output->data.dptr), s, b, h, d, d2, - stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b, - o_stride_h, o_stride_d, stream);); + reinterpret_cast(output->data.dptr), qkv_format, + interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t, + stride_b, stride_h, stride_d, stream);); } -void fused_rope_backward(const Tensor &output_grads, const Tensor &freqs, Tensor *input_grads, - const int s, const int b, const int h, const int d, const int d2, - const int stride_s, const int stride_b, const int stride_h, - const int stride_d, const int o_stride_s, const int o_stride_b, - const int o_stride_h, const int o_stride_d, cudaStream_t stream) { +void fused_rope_backward(const Tensor &output_grads, const Tensor &cu_seqlens, const Tensor &freqs, + Tensor *input_grads, const NVTE_QKV_Format qkv_format, + const bool interleaved, const int cp_size, const int cp_rank, const int s, + const int b, const int h, const int d, const int d2, + const int stride_s_or_t, const int stride_b, const int stride_h, + const int stride_d, cudaStream_t stream) { TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( output_grads.data.dtype, scalar_t, fused_rope_backward_launcher(reinterpret_cast(output_grads.data.dptr), + reinterpret_cast(cu_seqlens.data.dptr), reinterpret_cast(freqs.data.dptr), - reinterpret_cast(input_grads->data.dptr), s, b, h, d, - d2, 
stride_s, stride_b, stride_h, stride_d, o_stride_s, - o_stride_b, o_stride_h, o_stride_d, stream);); -} - -void fused_rope_thd_forward(const Tensor &input, const Tensor &cu_seqlens, const Tensor &freqs, - Tensor *output, const int cp_size, const int cp_rank, const int max_s, - const int b, const int h, const int d, const int d2, const int stride_t, - const int stride_h, const int stride_d, const int o_stride_t, - const int o_stride_h, const int o_stride_d, cudaStream_t stream) { - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.data.dtype, scalar_t, - fused_rope_thd_forward_launcher(reinterpret_cast(input.data.dptr), - reinterpret_cast(cu_seqlens.data.dptr), - reinterpret_cast(freqs.data.dptr), - reinterpret_cast(output->data.dptr), cp_size, - cp_rank, max_s, b, h, d, d2, stride_t, stride_h, stride_d, - o_stride_t, o_stride_h, o_stride_d, stream);); -} - -void fused_rope_thd_backward(const Tensor &output_grads, const Tensor &cu_seqlens, - const Tensor &freqs, Tensor *input_grads, const int cp_size, - const int cp_rank, const int max_s, const int b, const int h, - const int d, const int d2, const int stride_t, const int stride_h, - const int stride_d, const int o_stride_t, const int o_stride_h, - const int o_stride_d, cudaStream_t stream) { - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - output_grads.data.dtype, scalar_t, - fused_rope_thd_backward_launcher(reinterpret_cast(output_grads.data.dptr), - reinterpret_cast(cu_seqlens.data.dptr), - reinterpret_cast(freqs.data.dptr), - reinterpret_cast(input_grads->data.dptr), - cp_size, cp_rank, max_s, b, h, d, d2, stride_t, stride_h, - stride_d, o_stride_t, o_stride_h, o_stride_d, stream);); + reinterpret_cast(input_grads->data.dptr), qkv_format, + interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t, + stride_b, stride_h, stride_d, stream);); } } // end namespace transformer_engine -void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs, NVTETensor output, - const int s, const int b, const int h, const int d, const int d2, - const int stride_s, const int stride_b, const int stride_h, - const int stride_d, const int o_stride_s, const int o_stride_b, - const int o_stride_h, const int o_stride_d, cudaStream_t stream) { +void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens, + const NVTETensor freqs, NVTETensor output, + const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank, const int s, const int b, + const int h, const int d, const int d2, const int stride_s_or_t, + const int stride_b, const int stride_h, const int stride_d, + cudaStream_t stream) { NVTE_API_CALL(nvte_fused_rope_forward); using namespace transformer_engine; fused_rope_forward(*reinterpret_cast(input), + *reinterpret_cast(cu_seqlens), *reinterpret_cast(freqs), reinterpret_cast(output), - s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b, - o_stride_h, o_stride_d, stream); + qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t, + stride_b, stride_h, stride_d, stream); } -void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor freqs, - NVTETensor input_grads, const int s, const int b, const int h, - const int d, const int d2, const int stride_s, const int stride_b, - const int stride_h, const int stride_d, const int o_stride_s, - const int o_stride_b, const int o_stride_h, const int o_stride_d, +void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens, + const NVTETensor freqs, NVTETensor 
input_grads, + const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank, const int s, const int b, + const int h, const int d, const int d2, const int stride_s_or_t, + const int stride_b, const int stride_h, const int stride_d, cudaStream_t stream) { NVTE_API_CALL(nvte_fused_rope_backward); using namespace transformer_engine; fused_rope_backward(*reinterpret_cast(output_grads), + *reinterpret_cast(cu_seqlens), *reinterpret_cast(freqs), - reinterpret_cast(input_grads), s, b, h, d, d2, stride_s, stride_b, - stride_h, stride_d, o_stride_s, o_stride_b, o_stride_h, o_stride_d, stream); -} - -void nvte_fused_rope_thd_forward(const NVTETensor input, const NVTETensor cu_seqlens, - const NVTETensor freqs, NVTETensor output, const int cp_size, - const int cp_rank, const int max_s, const int b, const int h, - const int d, const int d2, const int stride_t, const int stride_h, - const int stride_d, const int o_stride_t, const int o_stride_h, - const int o_stride_d, cudaStream_t stream) { - NVTE_API_CALL(nvte_fused_rope_thd_forward); - using namespace transformer_engine; - fused_rope_thd_forward(*reinterpret_cast(input), - *reinterpret_cast(cu_seqlens), - *reinterpret_cast(freqs), - reinterpret_cast(output), cp_size, cp_rank, max_s, b, h, d, d2, - stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, stream); -} - -void nvte_fused_rope_thd_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens, - const NVTETensor freqs, NVTETensor input_grads, const int cp_size, - const int cp_rank, const int max_s, const int b, const int h, - const int d, const int d2, const int stride_t, const int stride_h, - const int stride_d, const int o_stride_t, const int o_stride_h, - const int o_stride_d, cudaStream_t stream) { - NVTE_API_CALL(nvte_fused_rope_thd_backward); - using namespace transformer_engine; - fused_rope_thd_backward( - *reinterpret_cast(output_grads), - *reinterpret_cast(cu_seqlens), *reinterpret_cast(freqs), - reinterpret_cast(input_grads), cp_size, cp_rank, max_s, b, h, d, d2, stride_t, - stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, stream); + reinterpret_cast(input_grads), qkv_format, interleaved, cp_size, + cp_rank, s, b, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, stream); } diff --git a/transformer_engine/common/include/transformer_engine/fused_rope.h b/transformer_engine/common/include/transformer_engine/fused_rope.h index 41a0e3b..5a5bcc7 100644 --- a/transformer_engine/common/include/transformer_engine/fused_rope.h +++ b/transformer_engine/common/include/transformer_engine/fused_rope.h @@ -7,6 +7,7 @@ #ifndef TRANSFORMER_ENGINE_FUSED_ROPE_H_ #define TRANSFORMER_ENGINE_FUSED_ROPE_H_ +#include "fused_attn.h" #include "transformer_engine.h" #ifdef __cplusplus @@ -16,112 +17,63 @@ extern "C" { /*! \brief Apply rotary positional embedding to the input tensor. * * \param[in] input Input tensor for fused rope. + * \param[in] cu_seqlens The cumulative sum of sequence lengths tensor. + * (Required for the thd format, empty tensor for other formats) * \param[in] freqs The freqs tensor. * \param[out] output Output tensor. + * \param[in] qkv_format QKV format. + * \param[in] interleaved Whether to use interleaved rotary position embedding. + * \param[in] cp_size Context parallel world size. + * \param[in] cp_rank Context parallel rank. * \param[in] s Length of the s dimension of input. * \param[in] b Length of the b dimension of input. * \param[in] h Length of the h dimension of input. 
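For the thd format the new cu_seqlens argument is the zero-prefixed cumulative sum of the per-sequence lengths, so both the batch size and the total token count of the packed (t, h, d) tensor can be recovered from it. A minimal sketch of that relationship (illustrative values only, not part of the API):

import torch

# Hypothetical per-sequence lengths for a batch of 3 packed sequences.
seqlens = torch.tensor([5, 3, 7], dtype=torch.int32)

# cu_seqlens: prefix sum with a leading zero -> [0, 5, 8, 15]
cu_seqlens = torch.cat(
    [torch.zeros(1, dtype=torch.int32), torch.cumsum(seqlens, 0).to(torch.int32)]
)

b = cu_seqlens.numel() - 1   # batch size, as the extension computes it (cu_seqlens.size(0) - 1)
t = int(cu_seqlens[-1])      # total token count, i.e. the t dimension of the (t, h, d) input
assert b == 3 and t == 15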
* \param[in] d Length of the d dimension of input. * \param[in] d2 Length of the d dimension of freqs. - * \param[in] stride_s Stride of the s dimension of input. - * \param[in] stride_b Stride of the b dimension of input. + * \param[in] stride_s_or_t Stride of the s (sbhd/bshd)/t (thd) dimension of input. + * \param[in] stride_b Stride of the b dimension of input. (0 for thd). * \param[in] stride_h Stride of the h dimension of input. * \param[in] stride_d Stride of the d dimension of input. - * \param[in] o_stride_s Stride of the s dimension of output. - * \param[in] o_stride_b Stride of the b dimension of output. - * \param[in] o_stride_h Stride of the h dimension of output. - * \param[in] o_stride_d Stride of the d dimension of output. * \param[in] stream CUDA stream used for the operation. */ -void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs, NVTETensor output, - const int s, const int b, const int h, const int d, const int d2, - const int stride_s, const int stride_b, const int stride_h, - const int stride_d, const int o_stride_s, const int o_stride_b, - const int o_stride_h, const int o_stride_d, cudaStream_t stream); +void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens, + const NVTETensor freqs, NVTETensor output, + const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank, const int s, const int b, + const int h, const int d, const int d2, const int stride_s_or_t, + const int stride_b, const int stride_h, const int stride_d, + cudaStream_t stream); /*! \brief Compute the backward of the fused rope. * * \param[in] output_grads Incoming gradient tensor for backward. + * \param[in] cu_seqlens The cumulative sum of sequence lengths tensor. + * (Required for the thd format, empty tensor for other formats) * \param[in] freqs The freqs tensor. * \param[out] input_grads Input gradient tensor to calculate. + * \param[in] qkv_format QKV format. + * \param[in] interleaved Whether to use interleaved rotary position embedding. + * \param[in] cp_size Context parallel world size. + * \param[in] cp_rank Context parallel rank. * \param[in] s Length of the s dimension of output_grads. * \param[in] b Length of the b dimension of output_grads. * \param[in] h Length of the h dimension of output_grads. * \param[in] d Length of the d dimension of output_grads. * \param[in] d2 Length of the d dimension of freqs. - * \param[in] stride_s Stride of the s dimension of output_grads. - * \param[in] stride_b Stride of the b dimension of output_grads. + * \param[in] stride_s_or_t Stride of the s (sbhd/bshd)/t (thd) dimension of output_grads. + * \param[in] stride_b Stride of the b dimension of output_grads. (0 for thd). * \param[in] stride_h Stride of the h dimension of output_grads. * \param[in] stride_d Stride of the d dimension of output_grads. - * \param[in] o_stride_s Stride of the s dimension of input_grads. - * \param[in] o_stride_b Stride of the b dimension of input_grads. - * \param[in] o_stride_h Stride of the h dimension of input_grads. - * \param[in] o_stride_d Stride of the d dimension of input_grads. * \param[in] stream CUDA stream used for the operation. 
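The single stride_s_or_t / stride_b pair replaces the per-layout strides of the old API; for contiguous tensors the values the launchers pick for the output match what torch reports for the input, roughly as sketched below (an illustration of the convention, not the launcher code itself):

import torch

s, b, h, d = 4, 2, 8, 64
sbhd = torch.empty(s, b, h, d)    # stride_s_or_t = b*h*d, stride_b = h*d
bshd = torch.empty(b, s, h, d)    # stride_s_or_t = h*d,   stride_b = s*h*d
thd = torch.empty(s * b, h, d)    # stride_s_or_t = h*d,   stride_b unused (passed as 0)

assert sbhd.stride(0) == b * h * d and sbhd.stride(1) == h * d
assert bshd.stride(1) == h * d and bshd.stride(0) == s * h * d
assert thd.stride(0) == h * d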
*/ -void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor freqs, - NVTETensor input_grads, const int s, const int b, const int h, - const int d, const int d2, const int stride_s, const int stride_b, - const int stride_h, const int stride_d, const int o_stride_s, - const int o_stride_b, const int o_stride_h, const int o_stride_d, +void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens, + const NVTETensor freqs, NVTETensor input_grads, + const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank, const int s, const int b, + const int h, const int d, const int d2, const int stride_s_or_t, + const int stride_b, const int stride_h, const int stride_d, cudaStream_t stream); -/*! \brief Apply rotary positional embedding to the input tensor in thd format. - * - * \param[in] input Input tensor for fused rope. - * \param[in] cu_seqlens The cumulative sum of sequence lengths tensor. - * \param[in] freqs The freqs tensor. - * \param[out] output Output tensor. - * \param[in] cp_size Context parallel world size. - * \param[in] cp_rank Context parallel rank. - * \param[in] max_s Max sequence length. - * \param[in] b Batch size. - * \param[in] h Length of the h dimension of input. - * \param[in] d Length of the d dimension of input. - * \param[in] d2 Length of the d dimension of freqs. - * \param[in] stride_t Stride of the t dimension of input. - * \param[in] stride_h Stride of the h dimension of input. - * \param[in] stride_d Stride of the d dimension of input. - * \param[in] o_stride_t Stride of the t dimension of output. - * \param[in] o_stride_h Stride of the h dimension of output. - * \param[in] o_stride_d Stride of the d dimension of output. - * \param[in] stream CUDA stream used for the operation. - */ -void nvte_fused_rope_thd_forward(const NVTETensor input, const NVTETensor cu_seqlens, - const NVTETensor freqs, NVTETensor output, const int cp_size, - const int cp_rank, const int max_s, const int b, const int h, - const int d, const int d2, const int stride_t, const int stride_h, - const int stride_d, const int o_stride_t, const int o_stride_h, - const int o_stride_d, cudaStream_t stream); - -/*! \brief Compute the backward of the fused rope in thd format. - * - * \param[in] output_grads Incoming gradient tensor for backward. - * \param[in] cu_seqlens The cumulative sum of sequence lengths tensor. - * \param[in] freqs The freqs tensor. - * \param[out] input_grads Input gradient to calculate. - * \param[in] cp_size Context parallel world size. - * \param[in] cp_rank Context parallel rank. - * \param[in] max_s Max sequence length. - * \param[in] b Batch size. - * \param[in] h Length of the h dimension of output_grads. - * \param[in] d Length of the d dimension of output_grads. - * \param[in] d2 Length of the d dimension of freqs. - * \param[in] stride_t Stride of the t dimension of output_grads. - * \param[in] stride_h Stride of the h dimension of output_grads. - * \param[in] stride_d Stride of the d dimension of output_grads. - * \param[in] o_stride_t Stride of the t dimension of input_grads. - * \param[in] o_stride_h Stride of the h dimension of input_grads. - * \param[in] o_stride_d Stride of the d dimension of input_grads. - * \param[in] stream CUDA stream used for the operation. 
- */ -void nvte_fused_rope_thd_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens, - const NVTETensor freqs, NVTETensor input_grads, const int cp_size, - const int cp_rank, const int max_s, const int b, const int h, - const int d, const int d2, const int stride_t, const int stride_h, - const int stride_d, const int o_stride_t, const int o_stride_h, - const int o_stride_d, cudaStream_t stream); - #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/recipe/current_scaling.cu b/transformer_engine/common/recipe/current_scaling.cu index 3a25d71..cf07d12 100644 --- a/transformer_engine/common/recipe/current_scaling.cu +++ b/transformer_engine/common/recipe/current_scaling.cu @@ -13,6 +13,7 @@ #include "../common.h" #include "../util/logging.h" #include "../util/vectorized_pointwise.h" +#include "recipe_common.cuh" namespace transformer_engine { namespace { @@ -135,7 +136,7 @@ void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaSt "Output tensor for amax computation has invalid amax tensor " "(expected FP32, got dtype=", to_string(output.amax.dtype), ")"); - CheckOutputTensor(output, "output_compute_amax"); + CheckOutputTensor(output, "output_compute_amax", true); // Compute amax TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( @@ -151,41 +152,7 @@ namespace { __global__ void compute_scale_from_amax_kernel(const float *amax_ptr, float *scale_ptr, const float max_fp8, const bool force_pow_2_scales, const float epsilon) { - float amax = *amax_ptr; - if (amax < epsilon) { - amax = epsilon; - } - - float scale = 1.f; - - if (isinf(amax) || amax == 0.f) { - *scale_ptr = scale; - return; - } - - scale = max_fp8 / amax; - - // The amax is too small that the scale becoming infinite in FP32. In other word, - // the scale is not representable in FP32. - if (isinf(scale)) { - // use fp32 max to represent the scale - scale = std::numeric_limits::max(); - } - - if (isnan(scale)) { - scale = 1.f; - } - - if (force_pow_2_scales) { - uint32_t scale_bits = *reinterpret_cast(&scale); - scale_bits &= 0xFF800000; - // If the exponent was zero, we have a logic error. - __builtin_assume(scale_bits != 0); - __builtin_assume(scale_bits != 0x80000000); - scale = *reinterpret_cast(&scale_bits); - } - - *scale_ptr = scale; + *scale_ptr = compute_scale_from_amax(*amax_ptr, max_fp8, force_pow_2_scales, epsilon); } } // namespace diff --git a/transformer_engine/common/recipe/recipe_common.cuh b/transformer_engine/common/recipe/recipe_common.cuh new file mode 100644 index 0000000..c789a9b --- /dev/null +++ b/transformer_engine/common/recipe/recipe_common.cuh @@ -0,0 +1,56 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_RECIPE_RECIPE_COMMON_CUH_ +#define TRANSFORMER_ENGINE_RECIPE_RECIPE_COMMON_CUH_ + +#include + +namespace transformer_engine { + +__device__ __forceinline__ float compute_scale_from_amax(float amax, float max_fp8, + bool force_pow_2_scales, float epsilon) { + if (amax < epsilon) { + amax = epsilon; + } + + float scale = 1.f; + + if (isinf(amax) || amax == 0.f) { + return scale; + } + + // Here we don't use "scale = max_fp8 / amax" because it has different results with/without + // "--use_fast_math". + // "__fdiv_rn" has the same behavior with "max_fp8 / amax" when not using fast math. 
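A rough Python rendering of what compute_scale_from_amax does, for intuition only (the device code additionally uses __fdiv_rn so the result is bitwise identical with and without --use_fast_math, which plain Python arithmetic does not model):

import math
import struct

def compute_scale_from_amax(amax, max_fp8, force_pow_2_scales, epsilon):
    amax = max(amax, epsilon)
    if math.isinf(amax) or amax == 0.0:
        return 1.0
    scale = max_fp8 / amax
    if math.isinf(scale):                  # amax too small: scale not representable in FP32
        scale = 3.4028234663852886e38      # FLT_MAX
    if math.isnan(scale):
        scale = 1.0
    if force_pow_2_scales:
        bits = struct.unpack("I", struct.pack("f", scale))[0]
        bits &= 0xFF800000                 # keep sign and exponent, zero the mantissa
        scale = struct.unpack("f", struct.pack("I", bits))[0]
    return scale

# E4M3 max is 448, so amax = 2.0 gives 224, rounded down to 128 when power-of-2 scales are forced.
assert compute_scale_from_amax(2.0, 448.0, True, 1e-10) == 128.0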
+ scale = __fdiv_rn(max_fp8, amax); + + // The amax is too small that the scale becoming infinite in FP32. In other word, + // the scale is not representable in FP32. + if (isinf(scale)) { + // use fp32 max to represent the scale + scale = std::numeric_limits::max(); + } + + if (isnan(scale)) { + scale = 1.f; + } + + if (force_pow_2_scales) { + uint32_t scale_bits = *reinterpret_cast(&scale); + scale_bits &= 0xFF800000; + // If the exponent was zero, we have a logic error. + __builtin_assume(scale_bits != 0); + __builtin_assume(scale_bits != 0x80000000); + scale = *reinterpret_cast(&scale_bits); + } + + return scale; +} + +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_RECIPE_RECIPE_COMMON_CUH_ diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index e430be0..5e022c7 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -252,6 +252,8 @@ at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grads * FP8 recipe **************************************************************************************************/ +void compute_amax(const at::Tensor &tensor, at::Tensor &amax); + void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer, std::vector amax_histories, std::vector scales, @@ -263,16 +265,14 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reductio **************************************************************************************************/ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs, - const bool transpose_output_memory); + const NVTE_QKV_Format qkv_format, const bool interleaved, + const c10::optional cu_seqlens, const int cp_size, + const int cp_rank); at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs, - const bool transpose_output_memory); - -at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_seqlens, - const at::Tensor &freqs, const int cp_size, const int cp_rank); - -at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Tensor &cu_seqlens, - const at::Tensor &freqs, const int cp_size, const int cp_rank); + const NVTE_QKV_Format qkv_format, const bool interleaved, + const c10::optional cu_seqlens, const int cp_size, + const int cp_rank); /*************************************************************************************************** * Miscellaneous @@ -359,6 +359,10 @@ void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag, float momentum, float dampening, float lr, bool nesterov, bool first_run, bool wd_after_momentum, float scale); +void multi_tensor_compute_scale_and_scale_inv_cuda( + int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, + float max_fp8, bool force_pow_2_scales, float epsilon); + /*************************************************************************************************** * padding **************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp index c323e7b..424a988 100644 --- a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp +++ b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp @@ -7,217 +7,181 @@ #include "extensions.h" at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs, - const bool 
transpose_output_memory) { + const NVTE_QKV_Format qkv_format, const bool interleaved, + const c10::optional cu_seqlens, const int cp_size, + const int cp_rank) { using namespace transformer_engine::pytorch; - TORCH_CHECK(input.dim() == 4, "expected 4D tensor"); + TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor"); - TORCH_CHECK(input.size(0) <= freqs.size(0), - "expected freqs tensor has a longer sequence length than input"); TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1, "expected the second and third dims of the freqs tensor equal 1"); - TORCH_CHECK(input.size(3) >= freqs.size(3), - "expected the last dim of the input tensor equals or is " - "greater than the freqs tensor"); TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float, "Dtype of the freqs tensor must be float"); - // input sizes: (s, b, h, d) + // output + auto act_options = at::TensorOptions().dtype(input.scalar_type()).device(input.device()); + auto output = at::empty(input.sizes(), act_options); + + auto input_cu = makeTransformerEngineTensor(input); + auto freqs_cu = makeTransformerEngineTensor(freqs); + auto output_cu = makeTransformerEngineTensor(output); + + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + TORCH_CHECK(input.dim() == 3, "expected 3D tensor"); + TORCH_CHECK(cu_seqlens.has_value(), "expected cu_seqlens tensor"); + TORCH_CHECK(cu_seqlens.value().dim() == 1, "expected 1D tensor"); + TORCH_CHECK(input.size(2) >= freqs.size(3), + "expected the last dim of the input tensor equals or is " + "greater than the freqs tensor"); + + // input sizes: (t, h, d) + // t: cumulative sum of sequence lengths + // h: head num + // d: dim of each head + // const int t = input.size(0); + const int h = input.size(1); + const int d = input.size(2); + // input strides + const int stride_t = input.stride(0); + const int stride_h = input.stride(1); + const int stride_d = input.stride(2); + // batch size + const int b = cu_seqlens.value().size(0) - 1; + // freqs' shape is (max_s, 1, 1, d2) + const int max_s = freqs.size(0); + const int d2 = freqs.size(3); + + auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens.value()); + + nvte_fused_rope_forward(input_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), + output_cu.data(), qkv_format, interleaved, cp_size, cp_rank, max_s, b, + h, d, d2, stride_t, /*stride_b=*/0, stride_h, stride_d, + at::cuda::getCurrentCUDAStream()); + + return output; + } + + TORCH_CHECK(input.dim() == 4, "expected 4D tensor"); + // input sizes: (s, b, h, d) or (b, s, h, d) // s: sequence length // b: batch size // h: head num // d: dim of each head - const int s = input.size(0); - const int b = input.size(1); + const int s = qkv_format == NVTE_QKV_Format::NVTE_SBHD ? input.size(0) : input.size(1); + const int b = qkv_format == NVTE_QKV_Format::NVTE_SBHD ? input.size(1) : input.size(0); const int h = input.size(2); const int d = input.size(3); // input strides - const int stride_s = input.stride(0); - const int stride_b = input.stride(1); + const int stride_s = qkv_format == NVTE_QKV_Format::NVTE_SBHD ? input.stride(0) : input.stride(1); + const int stride_b = qkv_format == NVTE_QKV_Format::NVTE_SBHD ? 
input.stride(1) : input.stride(0); const int stride_h = input.stride(2); const int stride_d = input.stride(3); // freqs' shape is always (s, 1, 1, d2), so the strides are same under // different memory formats const int d2 = freqs.size(3); - // output - auto act_options = input.options().requires_grad(false); - at::Tensor output; - if (transpose_output_memory) { - output = torch::empty({b, s, h, d}, act_options).transpose(0, 1); - } else { - output = torch::empty({s, b, h, d}, act_options); - } - // output strides - const int o_stride_s = output.stride(0); - const int o_stride_b = output.stride(1); - const int o_stride_h = output.stride(2); - const int o_stride_d = output.stride(3); - - auto input_cu = makeTransformerEngineTensor(input); - auto freqs_cu = makeTransformerEngineTensor(freqs); - auto output_cu = makeTransformerEngineTensor(output); + TORCH_CHECK(s * cp_size <= freqs.size(0), + "expected freqs tensor has a longer sequence length than input"); + TORCH_CHECK(d >= d2, + "expected the last dim of the input tensor equals or is " + "greater than the freqs tensor"); - nvte_fused_rope_forward(input_cu.data(), freqs_cu.data(), output_cu.data(), s, b, h, d, d2, - stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b, - o_stride_h, o_stride_d, at::cuda::getCurrentCUDAStream()); + auto cu_seqlens_cu = transformer_engine::TensorWrapper(); // empty cu_seqlens tensor + nvte_fused_rope_forward(input_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), output_cu.data(), + qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s, + stride_b, stride_h, stride_d, at::cuda::getCurrentCUDAStream()); return output; } at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs, - const bool transpose_output_memory) { + const NVTE_QKV_Format qkv_format, const bool interleaved, + const c10::optional cu_seqlens, const int cp_size, + const int cp_rank) { using namespace transformer_engine::pytorch; - TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor"); TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor"); - TORCH_CHECK(output_grads.size(0) <= freqs.size(0), - "expected freqs tensor has a longer sequence length than output_grads"); TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1, "expected the second and third dims of the freqs tensor equal 1"); - TORCH_CHECK(output_grads.size(3) >= freqs.size(3), - "expected the last dim of the output_grads tensor equals or is " - "greater than the freqs tensor"); TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float, "Dtype of the freqs tensor must be float"); + auto act_options = + at::TensorOptions().dtype(output_grads.scalar_type()).device(output_grads.device()); + auto input_grads = at::empty(output_grads.sizes(), act_options); + + auto output_grads_cu = makeTransformerEngineTensor(output_grads); + auto freqs_cu = makeTransformerEngineTensor(freqs); + auto input_grads_cu = makeTransformerEngineTensor(input_grads); + + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); + TORCH_CHECK(cu_seqlens.has_value(), "expected cu_seqlens tensor"); + TORCH_CHECK(cu_seqlens.value().dim() == 1, "expected 1D tensor"); + TORCH_CHECK(output_grads.size(2) >= freqs.size(3), + "expected the last dim of the output_grads tensor equals or is " + "greater than the freqs tensor"); + + // output_grads sizes: (t, h, d) + // t: cumulative sum of sequence lengths + // h: head num + // d: dim of each head + // const int t = output_grads.size(0); + const int h = output_grads.size(1); 
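The same Python entry point now serves sbhd and bshd by swapping which dimension is read as s and which as b; only the sizes and strides handed to the common C API change. It also checks s * cp_size <= freqs.size(0), because under context parallelism each rank holds only s of the s * cp_size positions while freqs must cover the full sequence. A sketch of the selection logic (not the extension code itself):

import torch

def sizes_and_strides(x, qkv_format):
    # Mirrors the dimension bookkeeping for 4D inputs in fused_rope_forward/backward.
    if qkv_format == "sbhd":
        return x.size(0), x.size(1), x.stride(0), x.stride(1)
    return x.size(1), x.size(0), x.stride(1), x.stride(0)   # "bshd"

x = torch.empty(2, 6, 8, 64)                        # (b, s, h, d)
s, b, stride_s, stride_b = sizes_and_strides(x, "bshd")
assert (s, b) == (6, 2) and stride_s == 8 * 64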
+ const int d = output_grads.size(2); + // output_grads strides + const int stride_t = output_grads.stride(0); + const int stride_h = output_grads.stride(1); + const int stride_d = output_grads.stride(2); + // batch size + const int b = cu_seqlens.value().size(0) - 1; + // freqs' shape is (max_s, 1, 1, d2) + const int max_s = freqs.size(0); + const int d2 = freqs.size(3); + + auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens.value()); + + nvte_fused_rope_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), + input_grads_cu.data(), qkv_format, interleaved, cp_size, cp_rank, + max_s, b, h, d, d2, stride_t, /*stride_b=*/0, stride_h, stride_d, + at::cuda::getCurrentCUDAStream()); + + return input_grads; + } + + TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor"); // output_grads sizes: (s, b, h, d) // s: sequence length // b: batch size // h: head num // d: dim of each head - const int s = output_grads.size(0); - const int b = output_grads.size(1); + const int s = + qkv_format == NVTE_QKV_Format::NVTE_SBHD ? output_grads.size(0) : output_grads.size(1); + const int b = + qkv_format == NVTE_QKV_Format::NVTE_SBHD ? output_grads.size(1) : output_grads.size(0); const int h = output_grads.size(2); const int d = output_grads.size(3); // output_grads strides - const int stride_s = output_grads.stride(0); - const int stride_b = output_grads.stride(1); + const int stride_s = + qkv_format == NVTE_QKV_Format::NVTE_SBHD ? output_grads.stride(0) : output_grads.stride(1); + const int stride_b = + qkv_format == NVTE_QKV_Format::NVTE_SBHD ? output_grads.stride(1) : output_grads.stride(0); const int stride_h = output_grads.stride(2); const int stride_d = output_grads.stride(3); // freqs' shape is always (s, 1, 1, d2), so the strides are same under // different memory formats const int d2 = freqs.size(3); - auto act_options = output_grads.options().requires_grad(false); - at::Tensor input_grads; - if (transpose_output_memory) { - input_grads = torch::empty({b, s, h, d}, act_options).transpose(0, 1); - } else { - input_grads = torch::empty({s, b, h, d}, act_options); - } - const int o_stride_s = input_grads.stride(0); - const int o_stride_b = input_grads.stride(1); - const int o_stride_h = input_grads.stride(2); - const int o_stride_d = input_grads.stride(3); - - auto output_grads_cu = makeTransformerEngineTensor(output_grads); - auto freqs_cu = makeTransformerEngineTensor(freqs); - auto input_grads_cu = makeTransformerEngineTensor(input_grads); - - nvte_fused_rope_backward(output_grads_cu.data(), freqs_cu.data(), input_grads_cu.data(), s, b, h, - d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b, - o_stride_h, o_stride_d, at::cuda::getCurrentCUDAStream()); - - return input_grads; -} - -at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_seqlens, - const at::Tensor &freqs, const int cp_size, const int cp_rank) { - using namespace transformer_engine::pytorch; - TORCH_CHECK(input.dim() == 3, "expected 3D tensor"); - TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor"); - TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor"); - TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1, - "expected the second and third dims of the freqs tensor equal 1"); - TORCH_CHECK(input.size(2) >= freqs.size(3), - "expected the last dim of the input tensor equals or is " - "greater than the freqs tensor"); - TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float, - "Dtype of the freqs tensor must be float"); - - // input sizes: (t, h, d) - // t: 
cumulative sum of sequence lengths - // h: head num - // d: dim of each head - const int t = input.size(0); - const int h = input.size(1); - const int d = input.size(2); - // input strides - const int stride_t = input.stride(0); - const int stride_h = input.stride(1); - const int stride_d = input.stride(2); - // batch size - const int b = cu_seqlens.size(0) - 1; - // freqs' shape is (max_s, 1, 1, d2) - const int max_s = freqs.size(0); - const int d2 = freqs.size(3); - - // output - auto act_options = input.options().requires_grad(false); - auto output = torch::empty({t, h, d}, act_options); - // output strides - const int o_stride_t = output.stride(0); - const int o_stride_h = output.stride(1); - const int o_stride_d = output.stride(2); - - auto input_cu = makeTransformerEngineTensor(input); - auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens); - auto freqs_cu = makeTransformerEngineTensor(freqs); - auto output_cu = makeTransformerEngineTensor(output); - - nvte_fused_rope_thd_forward(input_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), - output_cu.data(), cp_size, cp_rank, max_s, b, h, d, d2, stride_t, - stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, - at::cuda::getCurrentCUDAStream()); - - return output; -} - -at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Tensor &cu_seqlens, - const at::Tensor &freqs, const int cp_size, const int cp_rank) { - using namespace transformer_engine::pytorch; - TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); - TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor"); - TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor"); - TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1, - "expected the second and third dims of the freqs tensor equal 1"); - TORCH_CHECK(output_grads.size(2) >= freqs.size(3), + TORCH_CHECK(s * cp_size <= freqs.size(0), + "expected freqs tensor has a longer sequence length than output_grads"); + TORCH_CHECK(d >= d2, "expected the last dim of the output_grads tensor equals or is " "greater than the freqs tensor"); - TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float, - "Dtype of the freqs tensor must be float"); - - // output_grads sizes: (t, h, d) - // t: cumulative sum of sequence lengths - // h: head num - // d: dim of each head - const int t = output_grads.size(0); - const int h = output_grads.size(1); - const int d = output_grads.size(2); - // output_grads strides - const int stride_t = output_grads.stride(0); - const int stride_h = output_grads.stride(1); - const int stride_d = output_grads.stride(2); - // batch size - const int b = cu_seqlens.size(0) - 1; - // freqs' shape is (max_s, 1, 1, d2) - const int max_s = freqs.size(0); - const int d2 = freqs.size(3); - - auto act_options = output_grads.options().requires_grad(false); - auto input_grads = torch::empty({t, h, d}, act_options); - const int o_stride_t = input_grads.stride(0); - const int o_stride_h = input_grads.stride(1); - const int o_stride_d = input_grads.stride(2); - - auto output_grads_cu = makeTransformerEngineTensor(output_grads); - auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens); - auto freqs_cu = makeTransformerEngineTensor(freqs); - auto input_grads_cu = makeTransformerEngineTensor(input_grads); - nvte_fused_rope_thd_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), - input_grads_cu.data(), cp_size, cp_rank, max_s, b, h, d, d2, - stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, - at::cuda::getCurrentCUDAStream()); + auto cu_seqlens_cu = 
transformer_engine::TensorWrapper(); // empty cu_seqlens tensor + nvte_fused_rope_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), + input_grads_cu.data(), qkv_format, interleaved, cp_size, cp_rank, s, b, + h, d, d2, stride_s, stride_b, stride_h, stride_d, + at::cuda::getCurrentCUDAStream()); return input_grads; } diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_compute_scale.cu b/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_compute_scale.cu new file mode 100644 index 0000000..d262767 --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_compute_scale.cu @@ -0,0 +1,66 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +// Another possibility: +// #include + +#include +// Stringstream is a big hammer, but I want to rely on operator<< for dtype. +#include + +#include "common/recipe/recipe_common.cuh" +#include "common/utils.cuh" +#include "multi_tensor_apply.cuh" +#include "type_shim.h" + +#define BLOCK_SIZE 256 + +struct ComputeScaleAndScaleInvFunctor { + __device__ __forceinline__ void operator()(int chunk_size, volatile int *noop_gmem, + TensorListMetadata<3> &tl, // NOLINT(*) + float max_fp8, bool force_pow_2_scales, + float epsilon) { + // I'd like this kernel to propagate infs/nans. + // if(*noop_gmem == 1) + // return; + + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.sizes[tensor_loc]; + + float *amax = reinterpret_cast(tl.addresses[0][tensor_loc]); + amax += chunk_idx * chunk_size; + + float *scale = reinterpret_cast(tl.addresses[1][tensor_loc]); + scale += chunk_idx * chunk_size; + + float *scale_inv = reinterpret_cast(tl.addresses[2][tensor_loc]); + scale_inv += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + + for (int i_start = threadIdx.x; i_start < n && i_start < chunk_size; i_start += blockDim.x) { + float scale_val = transformer_engine::compute_scale_from_amax(amax[i_start], max_fp8, + force_pow_2_scales, epsilon); + scale[i_start] = scale_val; + transformer_engine::reciprocal(scale_inv + i_start, scale_val); + } + } +}; + +void multi_tensor_compute_scale_and_scale_inv_cuda( + int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, + float max_fp8, bool force_pow_2_scales, float epsilon) { + using namespace at; + + multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + ComputeScaleAndScaleInvFunctor(), max_fp8, force_pow_2_scales, epsilon); + AT_CUDA_CHECK(cudaGetLastError()); +} diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index a58fd3a..ffd524c 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -178,6 +178,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("dtype"), py::kw_only(), py::arg("out"), py::call_guard()); m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend", py::call_guard()); + m.def("compute_amax", &compute_amax, "Compute amax", py::arg("input"), py::arg("amax")); m.def("fused_amax_and_scale_update_after_reduction", &fused_amax_and_scale_update_after_reduction, "Update amax history and 
FP8 scale/scale_inv after reduction", py::call_guard()); @@ -202,10 +203,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::call_guard()); m.def("fused_rope_backward", &fused_rope_backward, "Fused Apply RoPE BWD", py::call_guard()); - m.def("fused_rope_thd_forward", &fused_rope_thd_forward, "Fused Apply RoPE FWD for thd format", - py::call_guard()); - m.def("fused_rope_thd_backward", &fused_rope_thd_backward, "Fused Apply RoPE BWD for thd format", - py::call_guard()); // Misc m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version", @@ -265,6 +262,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("multi_tensor_sgd", &multi_tensor_sgd_cuda, "Fused SGD optimizer for list of contiguous tensors", py::call_guard()); + m.def("multi_tensor_compute_scale_and_scale_inv", &multi_tensor_compute_scale_and_scale_inv_cuda, + "Fused compute scale and scale_inv from amax", py::call_guard()); // Data structures py::class_(m, "FP8TensorMeta") diff --git a/transformer_engine/pytorch/csrc/extensions/recipe.cpp b/transformer_engine/pytorch/csrc/extensions/recipe.cpp index e8a31da..2dc3b69 100644 --- a/transformer_engine/pytorch/csrc/extensions/recipe.cpp +++ b/transformer_engine/pytorch/csrc/extensions/recipe.cpp @@ -12,10 +12,27 @@ #include "common/common.h" #include "extensions.h" -void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer, +void compute_amax(const at::Tensor& tensor, at::Tensor& amax) { + using namespace transformer_engine; + using namespace transformer_engine::pytorch; + + auto input_tensor = tensor.contiguous(); + const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor); + + TORCH_CHECK(amax.scalar_type() == at::kFloat, "amax must be a float tensor"); + TORCH_CHECK(amax.numel() == 1, "amax must have exactly one element"); + TensorWrapper fake_te_output( + nullptr, te_input.shape(), + transformer_engine::DType::kFloat8E4M3, // It doesn't matter because we only compute amax. + amax.data_ptr()); + + nvte_compute_amax(te_input.data(), fake_te_output.data(), at::cuda::getCurrentCUDAStream()); +} + +void fused_amax_and_scale_update_after_reduction(const at::Tensor& amax_reduction_buffer, std::vector amax_histories, std::vector scales, - const std::string &amax_compute_algo, + const std::string& amax_compute_algo, transformer_engine::DType fp8_dtype, float margin) { using namespace transformer_engine; diff --git a/transformer_engine/pytorch/dot_product_attention/rope.py b/transformer_engine/pytorch/dot_product_attention/rope.py index 83698c7..6793f1b 100644 --- a/transformer_engine/pytorch/dot_product_attention/rope.py +++ b/transformer_engine/pytorch/dot_product_attention/rope.py @@ -7,7 +7,12 @@ Rotary Position Embedding implementation of different types along with helper fu """ from typing import Optional, Tuple, Union import torch + import transformer_engine_torch as tex +from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat + + +__all__ = ["RotaryPositionEmbedding", "apply_rotary_pos_emb"] class RotaryPositionEmbedding(torch.nn.Module): @@ -22,19 +27,24 @@ class RotaryPositionEmbedding(torch.nn.Module): seq_len_interpolation_factor: Optional[int] = None, pretrained_max_position_embeddings: Optional[int] = None, rotary_base: float = 10000.0, + interleaved: bool = False, ): """ Parameters ---------- dim: int - rotary embedding dimension - rotary_percent: float + Rotary embedding dimension. + rotary_percent: float, default = 1.0 Percent of rotary dimension to use for rotary position embeddings. 
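The compute_amax binding added above takes an explicit output buffer rather than allocating one, so a caller is expected to pass a one-element float32 tensor. A hedged usage sketch (assuming a CUDA build with this change; everything beyond the two bound argument names is illustrative):

import torch
import transformer_engine_torch as tex

x = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)
amax = torch.zeros(1, dtype=torch.float32, device="cuda")   # must be float32 with exactly one element

tex.compute_amax(x, amax)   # writes max(|x|) into amax[0] on the current CUDA stream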
- seq_len_interpolation_factor: int - if not None, discrete positions will be interpolated by this factor via the trick in + seq_len_interpolation_factor: int, default = None + If not None, discrete positions will be interpolated by this factor via the trick in https://arxiv.org/abs/2306.15595 - pretrained_max_position_embeddings: int - pre-trained max_position_embeddings before position interpolation + pretrained_max_position_embeddings: int, default = None + Pre-trained max_position_embeddings before position interpolation. + rotary_base: float, default = 10000.0 + Base of the rotary position embedding. + interleaved: bool, default = False + Whether to use interleaved rotary position embedding. """ super().__init__() if rotary_percent < 1.0: @@ -50,17 +60,18 @@ class RotaryPositionEmbedding(torch.nn.Module): ) self.register_buffer("inv_freq", inv_freq) self.pretrained_max_position_embeddings = pretrained_max_position_embeddings + self.interleaved = interleaved def forward(self, max_seq_len: int, offset: int = 0): """ - Create rotary position embedding frequencies + Create rotary position embedding frequencies. Parameters ---------- max_seq_len: int - sequence length of a sample + Sequence length of a sample. offset: int, default = 0 - fixed offset for freqencies + Fixed offset for frequencies. """ seq = ( torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype) @@ -84,7 +95,12 @@ class RotaryPositionEmbedding(torch.nn.Module): freqs = torch.einsum("i , j -> i j", seq, self.inv_freq) # first part even vector components, second part odd vector components, # 2 * dim in dimension size - emb = torch.cat((freqs, freqs), dim=-1) + if not self.interleaved: + emb = torch.cat((freqs, freqs), dim=-1) + else: + emb = torch.stack((freqs.view(-1, 1), freqs.view(-1, 1)), dim=-1).view( + freqs.shape[0], -1 + ) # emb [seq_length, .., dim] return emb.reshape(emb.size(0), 1, 1, emb.size(1)) @@ -104,61 +120,146 @@ class FusedRoPEFunc(torch.autograd.Function): t: torch.Tensor, freqs: torch.Tensor, tensor_format: str = "sbhd", + interleaved: bool = False, cu_seqlens: Union[torch.Tensor, None] = None, cp_size: int = 1, cp_rank: int = 0, ) -> torch.Tensor: - # pylint: disable=missing-function-docstring + """Fused RoPE forward.""" if freqs.dtype != torch.float32: freqs = freqs.float() - if tensor_format == "sbhd": - output = tex.fused_rope_forward(t, freqs, False) - elif tensor_format == "bshd": - output = tex.fused_rope_forward(t.transpose(0, 1), freqs, True).transpose(0, 1) - elif tensor_format == "thd": - output = tex.fused_rope_thd_forward(t, cu_seqlens, freqs, cp_size, cp_rank) - else: - raise ValueError(f"Unsupported tensor_format: {tensor_format}.") + assert tensor_format in ( + "sbhd", + "bshd", + "thd", + ), f"Unsupported tensor_format: {tensor_format}." 
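The interleaved flag only changes how the frequencies are duplicated along the last dimension before cos/sin are taken: half-and-half versus pairwise. A small numeric illustration of the two layouts produced by the code above (illustrative values):

import torch

freqs = torch.tensor([[0.1, 0.2, 0.3]])   # (seq, dim/2)

non_interleaved = torch.cat((freqs, freqs), dim=-1)
# -> [[0.1, 0.2, 0.3, 0.1, 0.2, 0.3]]

interleaved = torch.stack((freqs.view(-1, 1), freqs.view(-1, 1)), dim=-1).view(freqs.shape[0], -1)
# -> [[0.1, 0.1, 0.2, 0.2, 0.3, 0.3]]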
+ output = tex.fused_rope_forward( + t, freqs, QKVFormat[tensor_format], interleaved, cu_seqlens, cp_size, cp_rank + ) ctx.save_for_backward(freqs, cu_seqlens) ctx.tensor_format = tensor_format ctx.cp_size = cp_size ctx.cp_rank = cp_rank + ctx.interleaved = interleaved return output @staticmethod def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: - # pylint: disable=missing-function-docstring + """Fused RoPE backward.""" freqs, cu_seqlens = ctx.saved_tensors - if ctx.tensor_format == "sbhd": - grad_input = tex.fused_rope_backward(grad_output, freqs, False) - elif ctx.tensor_format == "bshd": - grad_input = tex.fused_rope_backward( - grad_output.transpose(0, 1), freqs, True - ).transpose(0, 1) - elif ctx.tensor_format == "thd": - grad_input = tex.fused_rope_thd_backward( - grad_output, cu_seqlens, freqs, ctx.cp_size, ctx.cp_rank - ) - else: - raise ValueError(f"Unsupported tensor_format: {ctx.tensor_format}.") + grad_input = tex.fused_rope_backward( + grad_output, + freqs, + QKVFormat[ctx.tensor_format], + ctx.interleaved, + cu_seqlens, + ctx.cp_size, + ctx.cp_rank, + ) + + return grad_input, None, None, None, None, None, None - return grad_input, None, None, None, None, None +def _rotate_half(x: torch.Tensor, interleaved: bool) -> torch.Tensor: + """Change sign so the last dimension becomes [-odd, +even] -def _rotate_half(x: torch.Tensor) -> torch.Tensor: + Args: + x: torch.Tensor. Input tensor. + interleaved: bool. Whether to use interleaved rotary position embedding. + + Returns: + Tensor: Tensor rotated half. """ - change sign so the last dimension becomes [-odd, +even] + if not interleaved: + x1, x2 = torch.chunk(x, 2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + # interleaved + x1 = x[:, :, :, ::2] + x2 = x[:, :, :, 1::2] + x_new = torch.stack((-x2, x1), dim=-1) + return x_new.view(x_new.shape[0], x_new.shape[1], x_new.shape[2], -1) + + +def _apply_rotary_pos_emb_base( + t: torch.Tensor, + freqs: torch.Tensor, + tensor_format: str = "sbhd", + interleaved: bool = False, +) -> torch.Tensor: + """ + Base implementation of applying rotary positional embedding tensor to the input tensor. + + Parameters + ---------- + t: torch.Tensor + Input tensor of shape `[s, b, h, d]` or `[b, s, h, d]`, on which rotary positional + embedding will be applied. + freqs: torch.Tensor + Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float', + with `s2 >= s` and `d2 <= d`. + tensor_format: {'sbhd', 'bshd'}, default = 'sbhd' + Should be `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is of shape + `[seq, bs, ...]`. + interleaved: bool, default = False + Whether to use interleaved rotary position embedding. + """ + max_seq_len = freqs.shape[0] + cur_seq_len = t.shape[1] if tensor_format == "bshd" else t.shape[0] + + # Only apply the rotary embeddings up to the sequence length of the running + # input. + assert ( + cur_seq_len <= max_seq_len + ), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!" 
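_rotate_half pairs elements differently in the two layouts: across the two halves by default, and between neighbouring (even, odd) positions when interleaved. A small check of the expected outputs, following the same indexing as the function above (illustrative values):

import torch

x = torch.arange(1.0, 7.0).view(1, 1, 1, 6)   # last dim: [1, 2, 3, 4, 5, 6]

# Non-interleaved: split into halves, return (-second_half, first_half).
x1, x2 = torch.chunk(x, 2, dim=-1)
assert torch.equal(torch.cat((-x2, x1), dim=-1).flatten(),
                   torch.tensor([-4., -5., -6., 1., 2., 3.]))

# Interleaved: pair (even, odd) positions, return (-odd, even) re-interleaved.
x1, x2 = x[..., ::2], x[..., 1::2]
rotated = torch.stack((-x2, x1), dim=-1).view(*x.shape[:3], -1)
assert torch.equal(rotated.flatten(), torch.tensor([-2., 1., -4., 3., -6., 5.]))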
+ freqs = freqs[:cur_seq_len] + if tensor_format == "bshd": + freqs = freqs.transpose(0, 1) # [seq, 1, 1, dim] -> [1, seq, 1, dim] + # cos/sin first then dtype conversion for better precision + cos_ = torch.cos(freqs).to(t.dtype) + sin_ = torch.sin(freqs).to(t.dtype) + + rot_dim = freqs.shape[-1] + # ideally t_pass is empty so rotary pos embedding is applied to all tensor t + t, t_pass = t[..., :rot_dim], t[..., rot_dim:] + + # first part is cosine component + # second part is sine component, need to change signs with _rotate_half method + t = (t * cos_) + (_rotate_half(t, interleaved) * sin_) + return torch.cat((t, t_pass), dim=-1) + + +def _get_freqs_on_this_cp_rank( + freqs: torch.Tensor, seqlen: int, cp_size: int, cp_rank: int +) -> torch.Tensor: + """Get the position embedding on the current context parallel rank. + + Args: + freqs: torch.Tensor. Positional embedding tensor in shape `[s2, 1, 1, d2]`. + seqlen: int. Length of the current sequence. + cp_size: int. Context parallel world size. + cp_rank: int. Context parallel rank. """ - x = x.view(x.shape[:-1] + torch.Size((2, x.shape[-1] // 2))) - x1, x2 = x.unbind(dim=-2) - return torch.cat((-x2, x1), dim=-1) + if cp_size > 1: + cp_seg = seqlen // 2 + full_seqlen = cp_size * seqlen + return torch.cat( + [ + freqs[cp_rank * cp_seg : (cp_rank + 1) * cp_seg], + freqs[full_seqlen - (cp_rank + 1) * cp_seg : full_seqlen - cp_rank * cp_seg], + ] + ) + + # cp_size == 1 + return freqs[:seqlen] def apply_rotary_pos_emb( t: torch.Tensor, freqs: torch.Tensor, tensor_format: str = "sbhd", + interleaved: bool = False, fused: bool = False, cu_seqlens: Union[torch.Tensor, None] = None, cp_size: int = 1, @@ -175,11 +276,13 @@ def apply_rotary_pos_emb( freqs: torch.Tensor Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float', with `s2 >= s` and `d2 <= d`. - fused: bool, default = False - Whether to use a fused applying RoPE implementation. tensor_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd' is `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is of shape `[seq, bs, ...]`. 'thd' is only supported when `fused` is True. + interleaved: bool, default = False + Whether to use interleaved rotary position embedding. + fused: bool, default = False + Whether to use a fused applying RoPE implementation. cu_seqlens: torch.Tensor, default = None. Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and dtype torch.int32. Only valid when `tensor_format` is 'thd'. @@ -189,37 +292,40 @@ def apply_rotary_pos_emb( cp_rank: int, default = 0. Context parallel rank. Only valid when `tensor_format` is 'thd' and `fused` is True. """ - if fused: - assert ( - tensor_format != "thd" or cu_seqlens is not None - ), "cu_seqlens must not be None when tensor_format is 'thd'." - return FusedRoPEFunc.apply(t, freqs, tensor_format, cu_seqlens, cp_size, cp_rank) - - assert tensor_format in ("sbhd", "bshd"), ( - "Only formats `sbhd` or `bshd` are supported for input tensor `t` " - f"when fused is False, got {tensor_format}." - ) - - max_seq_len = freqs.shape[0] - cur_seq_len = t.shape[1] if tensor_format == "bshd" else t.shape[0] - - # Only apply the rotary embeddings up to the sequence length of the running - # input. assert ( - cur_seq_len <= max_seq_len - ), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!" 
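With context parallelism each rank's local sequence is made of two chunks of the full sequence (one from the front and the mirrored one from the back, for load balancing), and _get_freqs_on_this_cp_rank slices the frequency table the same way. An index-level example with cp_size = 2 and a local seqlen of 4 (illustrative only):

import torch

cp_size, seqlen = 2, 4
full_seqlen = cp_size * seqlen
freqs = torch.arange(full_seqlen)   # stand-in for the (s2, 1, 1, d2) freqs tensor

def freqs_on_rank(cp_rank):
    cp_seg = seqlen // 2
    return torch.cat([
        freqs[cp_rank * cp_seg:(cp_rank + 1) * cp_seg],
        freqs[full_seqlen - (cp_rank + 1) * cp_seg:full_seqlen - cp_rank * cp_seg],
    ])

assert freqs_on_rank(0).tolist() == [0, 1, 6, 7]   # rank 0: first and last chunk
assert freqs_on_rank(1).tolist() == [2, 3, 4, 5]   # rank 1: the two middle chunks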
- freqs = freqs[:cur_seq_len] - if tensor_format == "bshd": - freqs = freqs.transpose(0, 1) # [seq, 1, 1, dim] -> [1, seq, 1, dim] - # cos/sin first then dtype conversion for better precision - cos_ = torch.cos(freqs).to(t.dtype) - sin_ = torch.sin(freqs).to(t.dtype) + tensor_format != "thd" or cu_seqlens is not None + ), "cu_seqlens must not be None when tensor_format is 'thd'." - rot_dim = freqs.shape[-1] - # ideally t_pass is empty so rotary pos embedding is applied to all tensor t - t, t_pass = t[..., :rot_dim], t[..., rot_dim:] + if fused: + return FusedRoPEFunc.apply( + t, freqs, tensor_format, interleaved, cu_seqlens, cp_size, cp_rank + ) - # first part is cosine component - # second part is sine component, need to change signs with _rotate_half method - t = (t * cos_) + (_rotate_half(t) * sin_) - return torch.cat((t, t_pass), dim=-1) + # Unfused THD format + if tensor_format == "thd": + cu_seqlens = cu_seqlens // cp_size + seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + return torch.cat( + [ + _apply_rotary_pos_emb_base( + x.unsqueeze(1), + _get_freqs_on_this_cp_rank(freqs, x.size(0), cp_size, cp_rank), + interleaved=interleaved, + ) + for x in torch.split(t, seqlens) + ] + ).squeeze(1) + + # Unfused SBHD/BSHD format + if tensor_format == "sbhd": + seqlen = t.size(0) + elif tensor_format == "bshd": + seqlen = t.size(1) + else: + raise ValueError(f"Unsupported tensor_format: {tensor_format}.") + return _apply_rotary_pos_emb_base( + t, + _get_freqs_on_this_cp_rank(freqs, seqlen, cp_size, cp_rank), + tensor_format, + interleaved=interleaved, + ) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 87298c2..38f829c 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -93,6 +93,7 @@ class FP8GlobalStateManager: FP8_RECIPE = None FP8_DISTRIBUTED_GROUP = None FP8_PARAMETERS = False + HIGH_PRECISION_INIT_VAL = False IS_FIRST_FP8_MODULE = False FP8_GRAPH_CAPTURING = False FP8_AUTOCAST_DEPTH = 0 @@ -117,6 +118,7 @@ class FP8GlobalStateManager: cls.FP8_RECIPE = None cls.FP8_DISTRIBUTED_GROUP = None cls.FP8_PARAMETERS = False + cls.HIGH_PRECISION_INIT_VAL = False cls.IS_FIRST_FP8_MODULE = False cls.FP8_GRAPH_CAPTURING = False cls.FP8_AUTOCAST_DEPTH = 0 @@ -267,6 +269,11 @@ class FP8GlobalStateManager: """Should the parameters be stored as FP8""" return cls.FP8_PARAMETERS + @classmethod + def with_high_precision_init_val(cls) -> bool: + """Should the high precision initial values be stored with FP8 parameters""" + return cls.HIGH_PRECISION_INIT_VAL + @classmethod def fp8_graph_capturing(cls) -> bool: """Is CUDA graph capture under way?""" @@ -500,7 +507,11 @@ class FP8GlobalStateManager: @contextmanager -def fp8_model_init(enabled: bool = True, recipe: Optional[Recipe] = None) -> None: +def fp8_model_init( + enabled: bool = True, + recipe: Optional[Recipe] = None, + preserve_high_precision_init_val: bool = False, +) -> None: """ Context manager for FP8 initialization of parameters. 
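The unfused thd branch of apply_rotary_pos_emb above first divides cu_seqlens by cp_size (the packed tensor only holds this rank's shard of each sequence), then splits the (t, h, d) tensor back into individual sequences and reuses the sbhd base implementation on each one. A sketch of just the splitting step (illustrative shapes):

import torch

h, d = 2, 4
cu_seqlens = torch.tensor([0, 5, 8, 15], dtype=torch.int32)   # 3 packed sequences
t = torch.randn(int(cu_seqlens[-1]), h, d)                    # (t, h, d)

seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()         # [5, 3, 7]
chunks = torch.split(t, seqlens)                              # one (s_i, h, d) tensor per sequence
assert [c.shape[0] for c in chunks] == [5, 3, 7]
# Each chunk is unsqueezed to (s_i, 1, h, d), RoPE is applied, and the results are re-concatenated.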
@@ -511,6 +522,12 @@ def fp8_model_init(enabled: bool = True, recipe: Optional[Recipe] = None) -> Non with fp8_model_init(enabled=True): model = transformer_engine.pytorch.Linear(768, 768) + # Preserving high precision initial value to initialize master weight + with fp8_model_init(enabled=True, preserve_high_precision_init_val=True): + model = transformer_engine.pytorch.Linear(768, 768) + master_weight = model.weight.get_high_precision_init_val() + model.weight.clear_high_precision_init_val() + Parameters ---------- enabled: bool, default = `True` @@ -526,18 +543,29 @@ def fp8_model_init(enabled: bool = True, recipe: Optional[Recipe] = None) -> Non * LoRA-like fine-tuning, where the main parameters of the model do not change. recipe: transformer_engine.common.recipe.Recipe, default = `None` Recipe used to create the parameters. If left to None, it uses the default FP8 recipe. + preserve_high_precision_init_val: bool, default = `False` + when enabled, store the high precision tensor used to initialize FP8 parameters + in CPU memory, and add two function attributes named `get_high_precision_init_val()` + and `clear_high_precision_init_val()` to FP8 parameters to get/clear this high + precision tensor. The purpose is that users can use this high-precision copy + to initialize master weights, avoiding the loss of precision that can occur when + using FP8 parameters directly. Note that after the master weights are initialized, + users should call `clear_high_precision_init_val()` to release this CPU memory. This functionality is *EXPERIMENTAL*. """ _fp8_parameters = FP8GlobalStateManager.FP8_PARAMETERS _fp8_recipe = FP8GlobalStateManager.FP8_RECIPE + _high_precision_init_val = FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL FP8GlobalStateManager.FP8_PARAMETERS = enabled FP8GlobalStateManager.FP8_RECIPE = get_default_fp8_recipe() if recipe is None else recipe + FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL = preserve_high_precision_init_val try: yield finally: FP8GlobalStateManager.FP8_PARAMETERS = _fp8_parameters FP8GlobalStateManager.FP8_RECIPE = _fp8_recipe + FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL = _high_precision_init_val @contextmanager diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 4b82054..cdb75aa 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -10,6 +10,7 @@ import warnings from abc import ABC, abstractmethod from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union from contextlib import contextmanager +from types import MethodType import torch import torch.nn.functional as F @@ -405,6 +406,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): self.sequence_parallel = False self.param_init_meta = {} self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters() + self.preserve_high_precision_init_val = FP8GlobalStateManager.with_high_precision_init_val() self.fsdp_wrapped = False self.fsdp_group = None self._fp8_workspaces: Dict[str, QuantizedTensor] = {} @@ -902,7 +904,11 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): # If primary weights are in fp8, wrap the parameter as FP8Tensor fp8_meta_index = self.param_init_meta[name].fp8_meta_index + high_precision_init_val = None if self.primary_weights_in_fp8 and fp8_meta_index is not None: + if self.preserve_high_precision_init_val: + high_precision_init_val = param.detach().cpu() + quantizer = self.quantizers["scaling_fwd"][fp8_meta_index] assert ( quantizer is not None 
@@ -914,7 +920,34 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): # NOTE: Currently this can only be broken when primary weights are in Fp8 but # re-applying the nn.Parameter() wrap is a no-op when the input is already # a parameter so we always re-apply it just for extra safety. - setattr(self, name, torch.nn.Parameter(param)) + param = torch.nn.Parameter(param) + if high_precision_init_val is not None: + + # - Master weights are initialized from model weights, if we use fp8 primary + # weights to initialize master weights, the numerical values of master weights + # are not consistent with the numerical values when we initialize them from + # bf16/fp16 weights. + # - So we add a `_high_precision_init_val` attribute to each model weight to store + # the original bf16/fp16 weight on cpu before casting it to fp8. And users can + # use `get_high_precision_init_val` to get this cpu tensor. + # - This cpu tensor is not needed once the master weight is initialized, so users + # should call `clear_high_precision_init_val` to remove it after master weight + # is initialized. + + def get(self): + if hasattr(self, "_high_precision_init_val"): + return self._high_precision_init_val + return None + + def clear(self): + if hasattr(self, "_high_precision_init_val"): + del self._high_precision_init_val + + param._high_precision_init_val = high_precision_init_val + param.get_high_precision_init_val = MethodType(get, param) + param.clear_high_precision_init_val = MethodType(clear, param) + + setattr(self, name, param) @abstractmethod def forward(self): @@ -953,10 +986,26 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): FSDP process group that the weights are distributed over. """ + # FP8 primary weights + if isinstance(tensor, QuantizedTensor): + if update_workspace and quantizer is not None: + tensor.update_usage( + rowwise_usage=quantizer.rowwise_usage, + columnwise_usage=quantizer.columnwise_usage, + ) + return tensor + # Try getting workspace from cache out = None if cache_name is not None: out = self._fp8_workspaces.get(cache_name, None) + if quantizer is not None and isinstance(out, MXFP8TensorBase): + if quantizer.rowwise_usage and out._rowwise_data is None: + out = None + del self._fp8_workspaces[cache_name] + elif quantizer.columnwise_usage and out._columnwise_data is None: + out = None + del self._fp8_workspaces[cache_name] # Gather cached Fp8 workspace if it's distributed # NOTE: FSDP sharding is supported only for Fp8 buffers and will not work diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 8bf420a..8963a61 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -130,20 +130,17 @@ class _GroupedLinear(torch.autograd.Function): ) weights_fp8 = [] bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype - if not isinstance(weights[0], QuantizedTensor): - # FP8 cast to workspace buffer - update_workspace = is_first_microbatch is None or is_first_microbatch - for i in range(num_gemms): - weight_fp8 = module.get_weight_workspace( - tensor=weights[i], - quantizer=weight_quantizers[i], - cache_name=(None if is_first_microbatch is None else f"weight{i}"), - update_workspace=update_workspace, - skip_update_flag=skip_fp8_weight_update, - ) - weights_fp8.append(weight_fp8) - else: - weights_fp8 = weights + # FP8 cast to workspace buffer + update_workspace = is_first_microbatch is None or is_first_microbatch + 
for i in range(num_gemms): + weight_fp8 = module.get_weight_workspace( + tensor=weights[i], + quantizer=weight_quantizers[i], + cache_name=(None if is_first_microbatch is None else f"weight{i}"), + update_workspace=update_workspace, + skip_update_flag=skip_fp8_weight_update, + ) + weights_fp8.append(weight_fp8) else: inputmats = inputmats_no_fp8 @@ -180,7 +177,7 @@ class _GroupedLinear(torch.autograd.Function): weight_quantizers[i].calibrate(weights[i]) if is_grad_enabled: - + ctx.weight_quantizers = weight_quantizers ctx.weights_shape_1 = weights[0].shape[1] tensors_to_save, tensor_objects = prepare_for_saving(*inputmats, *weights_fp8, *biases) @@ -270,6 +267,12 @@ class _GroupedLinear(torch.autograd.Function): device=ctx.device, ) + for weight, quantizer in zip(weights, ctx.weight_quantizers): + if quantizer is not None and isinstance(weight, QuantizedTensor): + weight.update_usage( + rowwise_usage=quantizer.rowwise_usage, + columnwise_usage=quantizer.columnwise_usage, + ) general_grouped_gemm( weights, grad_output, diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 4022924..fc316e3 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -262,28 +262,26 @@ class _LayerNormLinear(torch.autograd.Function): nvtx_range_pop(f"{nvtx_label}.gemm_input_cast_comm") # Cast weight to expected dtype - weightmat = weight - quantized_weight = False if not fp8: - weightmat = cast_if_needed(weightmat, activation_dtype) + quantized_weight = False + weightmat = cast_if_needed(weight, activation_dtype) else: - if not isinstance(weight, QuantizedTensor): - quantized_weight = True - - # Configure quantizer - if weight_quantizer is not None: - weight_quantizer.set_usage(rowwise=True, columnwise=True) - - # FP8 cast to workspace buffer - update_workspace = is_first_microbatch is None or is_first_microbatch - weightmat = module.get_weight_workspace( - tensor=weight, - quantizer=weight_quantizer, - cache_name=(None if is_first_microbatch is None else "weight"), - update_workspace=update_workspace, - skip_update_flag=skip_fp8_weight_update, - fsdp_group=fsdp_group, - ) + quantized_weight = not isinstance(weight, QuantizedTensor) + + # Configure quantizer + if weight_quantizer is not None: + weight_quantizer.set_usage(rowwise=True, columnwise=True) + + # FP8 cast to workspace buffer + update_workspace = is_first_microbatch is None or is_first_microbatch + weightmat = module.get_weight_workspace( + tensor=weight, + quantizer=weight_quantizer, + cache_name=(None if is_first_microbatch is None else "weight"), + update_workspace=update_workspace, + skip_update_flag=skip_fp8_weight_update, + fsdp_group=fsdp_group, + ) # Cast bias to expected dtype bias_dtype = activation_dtype @@ -345,11 +343,12 @@ class _LayerNormLinear(torch.autograd.Function): clear_tensor_data(ln_out, ln_out_total) if is_grad_enabled: + ctx.weight_quantizer = weight_quantizer ctx.ln_out_needs_gather = ( weight.requires_grad and parallel_mode == "column" and sequence_parallel ) - # Input with column-wise usage is needed for dgrad GEMM. + # Input with column-wise usage is needed for wgrad GEMM. 
if backward_needs_input: if isinstance(ln_out, QuantizedTensor): # For sequence parallel in vanilla FP8, rowwise data is @@ -358,6 +357,11 @@ class _LayerNormLinear(torch.autograd.Function): if isinstance(ln_out, MXFP8TensorBase) or not ctx.ln_out_needs_gather: ln_out.update_usage(rowwise_usage=False) + # Weight with column-wise usage is needed for dgrad GEMM. + if inp.requires_grad: + if isinstance(weightmat, QuantizedTensor): + weightmat.update_usage(columnwise_usage=True) + if cpu_offloading: if fp8 and weightmat is not None: set_offloading_param(weightmat, "weight_offloading", True) @@ -642,6 +646,11 @@ class _LayerNormLinear(torch.autograd.Function): if hasattr(recipe, "fp8_gemm_dgrad"): dgrad_gemm_use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator + if ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensor): + weight.update_usage( + rowwise_usage=ctx.weight_quantizer.rowwise_usage, + columnwise_usage=ctx.weight_quantizer.columnwise_usage, + ) dgrad, *_ = general_gemm( weight, grad_output, @@ -1274,6 +1283,7 @@ class LayerNormLinear(TransformerEngineBaseModule): inp: torch.Tensor, is_first_microbatch: Optional[bool] = None, fp8_output: Optional[bool] = False, + fp8_grad: Optional[bool] = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: """ Apply layer normalization to the input followed by a linear transformation. @@ -1304,6 +1314,13 @@ class LayerNormLinear(TransformerEngineBaseModule): if skip_fp8_weight_update is not None: is_first_microbatch = False + if self.ub_overlap_rs_fprop: + if get_ub(self.ub_name + "_fprop").is_fp8_ubuf(): + fp8_output = True + if self.ub_overlap_rs_dgrad: + if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf(): + fp8_grad = True + with self.prepare_forward( inp, allow_non_contiguous=False # removed .contiguous from inside the layer ) as inp: @@ -1331,7 +1348,7 @@ class LayerNormLinear(TransformerEngineBaseModule): output_quantizer, grad_output_quantizer, grad_input_quantizer, - ) = self._get_quantizers(fp8_output) + ) = self._get_quantizers(fp8_output, fp8_grad) if torch.is_grad_enabled(): fwd_fn = _LayerNormLinear.apply @@ -1397,7 +1414,7 @@ class LayerNormLinear(TransformerEngineBaseModule): return out, ln_out return out - def _get_quantizers(self, fp8_output): + def _get_quantizers(self, fp8_output, fp8_grad): if not self.fp8: return [None] * 5 grad_input_quantizer = None @@ -1412,6 +1429,8 @@ class LayerNormLinear(TransformerEngineBaseModule): if torch.is_grad_enabled(): grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1] grad_output_quantizer.internal = True + if fp8_grad: + grad_input_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1] return ( input_quantizer, diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 633690b..9cffc47 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -319,35 +319,31 @@ class _LayerNormMLP(torch.autograd.Function): ln_out_total = ln_out # Cast weights to expected dtype - fc1_weight_final = fc1_weight - fc2_weight_final = fc2_weight if not fp8: - fc1_weight_final = cast_if_needed(fc1_weight_final, activation_dtype) - fc2_weight_final = cast_if_needed(fc2_weight_final, activation_dtype) + fc1_weight_final = cast_if_needed(fc1_weight, activation_dtype) + fc2_weight_final = cast_if_needed(fc2_weight, activation_dtype) else: # If weights are not quantized, we call get_weight_workspace, # 
which handles weight caching etc. - if not isinstance(fc1_weight, QuantizedTensor): - # FP8 cast to workspace buffer - update_workspace = is_first_microbatch is None or is_first_microbatch - fc1_weight_final = module.get_weight_workspace( - tensor=fc1_weight, - quantizer=fc1_weight_quantizer, - cache_name=(None if is_first_microbatch is None else "fc1_weight"), - update_workspace=update_workspace, - skip_update_flag=skip_fp8_weight_update, - fsdp_group=fsdp_group, - ) - if not isinstance(fc2_weight, QuantizedTensor): - fc2_weight_quantizer.set_usage(rowwise=True, columnwise=True) - fc2_weight_final = module.get_weight_workspace( - tensor=fc2_weight, - quantizer=fc2_weight_quantizer, - cache_name=(None if is_first_microbatch is None else "fc2_weight"), - update_workspace=update_workspace, - skip_update_flag=skip_fp8_weight_update, - fsdp_group=fsdp_group, - ) + # FP8 cast to workspace buffer + update_workspace = is_first_microbatch is None or is_first_microbatch + fc1_weight_final = module.get_weight_workspace( + tensor=fc1_weight, + quantizer=fc1_weight_quantizer, + cache_name=(None if is_first_microbatch is None else "fc1_weight"), + update_workspace=update_workspace, + skip_update_flag=skip_fp8_weight_update, + fsdp_group=fsdp_group, + ) + fc2_weight_quantizer.set_usage(rowwise=True, columnwise=True) + fc2_weight_final = module.get_weight_workspace( + tensor=fc2_weight, + quantizer=fc2_weight_quantizer, + cache_name=(None if is_first_microbatch is None else "fc2_weight"), + update_workspace=update_workspace, + skip_update_flag=skip_fp8_weight_update, + fsdp_group=fsdp_group, + ) # Cast biases to expected dtype bias_dtype = activation_dtype @@ -430,7 +426,6 @@ class _LayerNormMLP(torch.autograd.Function): dim_size[0] = dim_size[0] // tp_world_size dim_size[1] = fc2_weight.size(0) rs_out = torch.empty(dim_size, dtype=activation_dtype, device=device) - fc2_out = ub_obj_fc2out.get_buffer(output_quantizer) else: dim_size = list(act_out.size()) dim_size[1] = fc2_weight.size(0) @@ -450,6 +445,14 @@ class _LayerNormMLP(torch.autograd.Function): ub_type=tex.CommOverlapType.RS if ub_overlap_rs else None, extra_output=rs_out, ) + + # Weight with column-wise usage is needed for dgrad GEMM. 
+ if is_grad_enabled and inp.requires_grad: + if isinstance(fc1_weight_final, QuantizedTensor): + fc1_weight_final.update_usage(columnwise_usage=True) + if isinstance(fc2_weight_final, QuantizedTensor): + fc2_weight_final.update_usage(columnwise_usage=True) + if not is_grad_enabled: clear_tensor_data(act_out, fc1_out_without_bias, fc1_out) @@ -488,6 +491,8 @@ class _LayerNormMLP(torch.autograd.Function): fc2_weight_final if fp8 and not isinstance(fc2_weight, Float8Tensor) else None, ) + ctx.fc1_weight_quantizer = fc1_weight_quantizer + ctx.fc2_weight_quantizer = fc2_weight_quantizer if not fc1_weight.requires_grad: if not return_layernorm_output: clear_tensor_data(ln_out) @@ -500,11 +505,13 @@ class _LayerNormMLP(torch.autograd.Function): ln_weight, ln_out.clone() if ub_overlap_ag else ln_out, # avoid saving a UB buffer fc1_weight_final, + fc1_weight, fc1_bias, fc1_out, fc1_out_without_bias, act_out, fc2_weight_final, + fc2_weight, fc2_bias, mu, rsigma, @@ -619,11 +626,13 @@ class _LayerNormMLP(torch.autograd.Function): ln_weight, ln_out, fc1_weight, + origin_fc1_weight, fc1_bias, fc1_out, fc1_out_without_bias, act_out, fc2_weight, + origin_fc2_weight, fc2_bias, mu, rsigma, @@ -642,7 +651,7 @@ class _LayerNormMLP(torch.autograd.Function): ) fc2_weight_main_grad = ( ctx.fc2_main_grad - if fc2_weight is not None + if origin_fc2_weight is not None and ctx.fuse_wgrad_accumulation and ctx.fc2_weight_requires_grad else None @@ -651,8 +660,8 @@ class _LayerNormMLP(torch.autograd.Function): # For CPU offloading, we offloaded weight and weight.main_grad to different tensors, # we need to connect them into one. if ctx.fuse_wgrad_accumulation: - fc1_weight.main_grad = fc1_weight_main_grad - fc2_weight.main_grad = fc2_weight_main_grad + origin_fc1_weight.main_grad = fc1_weight_main_grad + origin_fc2_weight.main_grad = fc2_weight_main_grad # TODO: Fix this # pylint: disable=fixme # Gather saved autograd context tensors when running with FSDP @@ -735,6 +744,11 @@ class _LayerNormMLP(torch.autograd.Function): ) # FC2 DGRAD; Unconditional + if ctx.fc2_weight_quantizer is not None and isinstance(ctx.fc2_weight, QuantizedTensor): + ctx.fc2_weight.update_usage( + rowwise_usage=ctx.fc2_weight_quantizer.rowwise_usage, + columnwise_usage=ctx.fc2_weight_quantizer.columnwise_usage, + ) gemm_output, *_ = general_gemm( fc2_weight, grad_output, @@ -769,14 +783,18 @@ class _LayerNormMLP(torch.autograd.Function): act_out, grad_output, get_workspace(), - out_dtype=ctx.activation_dtype, + out_dtype=( + origin_fc2_weight.main_grad.dtype + if ctx.fuse_wgrad_accumulation + else ctx.activation_dtype + ), quantization_params=None, # wgrad in high precision layout="NT", grad=True, bias=fc2_bias if fc2_bias is not None and fc2_bias_grad is None else None, accumulate=accumulate_wgrad_into_param_main_grad, use_split_accumulator=_2X_ACC_WGRAD, - out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None, + out=origin_fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None, ) if fc2_bias_grad is None: fc2_bias_grad = fc2_bias_grad_ @@ -864,6 +882,13 @@ class _LayerNormMLP(torch.autograd.Function): fc1_dgrad_bulk = ub_obj_fc1_wgrad.get_buffer(None) # FC1 DGRAD: Unconditional + if ctx.fc1_weight_quantizer is not None and isinstance( + ctx.fc1_weight_quantizer, QuantizedTensor + ): + ctx.fc1_weight.update_usage( + rowwise_usage=ctx.fc1_weight_quantizer.rowwise_usage, + columnwise_usage=ctx.fc1_weight_quantizer.columnwise_usage, + ) fc1_dgrad, *_, fc1_dgrad_rs_out = general_gemm( fc1_weight, dact, @@ -930,12 +955,16 @@ 
class _LayerNormMLP(torch.autograd.Function): ln_out_total, dact, get_workspace(), - out_dtype=ctx.activation_dtype, + out_dtype=( + origin_fc1_weight.main_grad.dtype + if ctx.fuse_wgrad_accumulation + else ctx.activation_dtype + ), layout="NT", grad=fuse_gemm_and_bias_fc1_wgrad, bias=fc1_bias if fuse_gemm_and_bias_fc1_wgrad else None, accumulate=accumulate_wgrad_into_param_main_grad, - out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None, + out=origin_fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None, ub=ub_obj_fc1_wgrad, ub_type=tex.CommOverlapType.RS if ctx.ub_bulk_wgrad else None, extra_output=fc1_dgrad_rs_out, @@ -996,16 +1025,21 @@ class _LayerNormMLP(torch.autograd.Function): if ctx.fc1_weight_requires_grad: # Handle custom DDP from mcore. if ctx.fuse_wgrad_accumulation and hasattr(fc1_weight, "grad_added_to_main_grad"): - fc1_weight.grad_added_to_main_grad = True - if getattr(fc1_weight, "zero_out_wgrad", False): + origin_fc1_weight.grad_added_to_main_grad = True + if getattr(origin_fc1_weight, "zero_out_wgrad", False): fc1_wgrad = torch.zeros( - fc1_weight.main_grad.shape, - dtype=fc1_weight.dtype, + origin_fc1_weight.main_grad.shape, + dtype=origin_fc1_weight.dtype, device=torch.cuda.current_device(), requires_grad=False, ) else: - fc1_wgrad = None + fc1_wgrad = torch.empty( + origin_fc1_weight.main_grad.shape, + dtype=origin_fc1_weight.dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) elif ctx.fuse_wgrad_accumulation: fc1_wgrad = None else: @@ -1013,17 +1047,24 @@ class _LayerNormMLP(torch.autograd.Function): if ctx.fc2_weight_requires_grad: # Handle custom DDP from mcore. - if ctx.fuse_wgrad_accumulation and hasattr(fc2_weight, "grad_added_to_main_grad"): - fc2_weight.grad_added_to_main_grad = True - if getattr(fc2_weight, "zero_out_wgrad", False): + if ctx.fuse_wgrad_accumulation and hasattr( + origin_fc2_weight, "grad_added_to_main_grad" + ): + origin_fc2_weight.grad_added_to_main_grad = True + if getattr(origin_fc2_weight, "zero_out_wgrad", False): fc2_wgrad = torch.zeros( - fc2_weight.main_grad.shape, - dtype=fc2_weight.dtype, + origin_fc2_weight.main_grad.shape, + dtype=origin_fc2_weight.dtype, device=torch.cuda.current_device(), requires_grad=False, ) else: - fc2_wgrad = None + fc2_wgrad = torch.empty( + origin_fc2_weight.main_grad.shape, + dtype=origin_fc2_weight.dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) elif ctx.fuse_wgrad_accumulation: fc2_wgrad = None else: @@ -1429,6 +1470,11 @@ class LayerNormMLP(TransformerEngineBaseModule): if skip_fp8_weight_update is not None: is_first_microbatch = False + fp8_output = False + if self.ub_overlap_rs: + if get_ub("fc2_fprop").is_fp8_ubuf(): + fp8_output = True + with self.prepare_forward(inp, num_gemms=2) as inp: # Get quantizers ( @@ -1440,7 +1486,7 @@ class LayerNormMLP(TransformerEngineBaseModule): grad_fc1_output_quantizer, grad_fc2_output_quantizer, grad_input_quantizer, - ) = self._get_quantizers() + ) = self._get_quantizers(fp8_output) # Get weight tensors fc1_weight = self.fc1_weight @@ -1528,7 +1574,7 @@ class LayerNormMLP(TransformerEngineBaseModule): return out, ln_out return out - def _get_quantizers(self): + def _get_quantizers(self, fp8_output): ( fc1_input_quantizer, fc1_weight_quantizer, @@ -1550,6 +1596,8 @@ class LayerNormMLP(TransformerEngineBaseModule): ) fc2_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_WEIGHT] fc2_weight_quantizer.internal = True + if fp8_output: + output_quantizer = 
self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_OUTPUT] if torch.is_grad_enabled(): grad_fc2_output_quantizer = self.quantizers["scaling_bwd"][ tex.FP8BwdTensors.GRAD_OUTPUT1 diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index f96355a..91dfe92 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -176,31 +176,29 @@ class _Linear(torch.autograd.Function): nvtx_range_pop(f"{nvtx_label}.input_cast_comm") # Cast weight to expected dtype - weightmat = weight if not fp8: - weightmat = cast_if_needed(weightmat, activation_dtype) + weightmat = cast_if_needed(weight, activation_dtype) else: - if not isinstance(weight, QuantizedTensor): - # Configure quantizer - if weight_quantizer is not None: - columnwise_usage = is_grad_enabled and inp.requires_grad - if not columnwise_usage: - columnwise_usage = ( - is_fp8_activation_recompute_enabled() - and not in_fp8_activation_recompute_phase() - ) - weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) - - # FP8 cast to workspace buffer - update_workspace = is_first_microbatch is None or is_first_microbatch - weightmat = module.get_weight_workspace( - tensor=weight, - quantizer=weight_quantizer, - cache_name=(None if is_first_microbatch is None else "weight"), - update_workspace=update_workspace, - skip_update_flag=skip_fp8_weight_update, - fsdp_group=fsdp_group, - ) + # Configure quantizer + if weight_quantizer is not None: + columnwise_usage = is_grad_enabled and inp.requires_grad + if not columnwise_usage: + columnwise_usage = ( + is_fp8_activation_recompute_enabled() + and not in_fp8_activation_recompute_phase() + ) + weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) + + # FP8 cast to workspace buffer + update_workspace = is_first_microbatch is None or is_first_microbatch + weightmat = module.get_weight_workspace( + tensor=weight, + quantizer=weight_quantizer, + cache_name=(None if is_first_microbatch is None else "weight"), + update_workspace=update_workspace, + skip_update_flag=skip_fp8_weight_update, + fsdp_group=fsdp_group, + ) # Cast bias to expected dtype bias_dtype = activation_dtype @@ -259,6 +257,7 @@ class _Linear(torch.autograd.Function): nvtx_range_pop(f"{nvtx_label}.gemm") if is_grad_enabled: + ctx.weight_quantizer = weight_quantizer saved_inputmat = None ctx.backward_input_needs_gather = ( @@ -274,6 +273,11 @@ class _Linear(torch.autograd.Function): inputmat.update_usage(rowwise_usage=False) saved_inputmat = inputmat + # Weight with column-wise usage is needed for dgrad GEMM. 
+ if inp.requires_grad: + if isinstance(weightmat, QuantizedTensor): + weightmat.update_usage(columnwise_usage=True) + if cpu_offloading: set_offloading_param(weight, "weight_offloading", True) set_offloading_param(weightmat, "weight_offloading", True) @@ -530,6 +534,12 @@ class _Linear(torch.autograd.Function): recipe.fp8_gemm_dgrad.use_split_accumulator ) + if ctx.weight_quantizer is not None and isinstance(weight_fp8, QuantizedTensor): + weight_fp8.update_usage( + rowwise_usage=ctx.weight_quantizer.rowwise_usage, + columnwise_usage=ctx.weight_quantizer.columnwise_usage, + ) + dgrad, *_, rs_out = general_gemm( weight_fp8, grad_output, @@ -1077,6 +1087,13 @@ class Linear(TransformerEngineBaseModule): if skip_fp8_weight_update is not None: is_first_microbatch = False + if self.ub_overlap_rs_fprop: + if get_ub(self.ub_name + "_fprop").is_fp8_ubuf(): + fp8_output = True + if self.ub_overlap_rs_dgrad: + if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf(): + fp8_grad = True + with self.prepare_forward( inp, allow_non_contiguous=isinstance(inp, QuantizedTensor), diff --git a/transformer_engine/pytorch/tensor/__init__.py b/transformer_engine/pytorch/tensor/__init__.py index 610ec2a..22b86fb 100644 --- a/transformer_engine/pytorch/tensor/__init__.py +++ b/transformer_engine/pytorch/tensor/__init__.py @@ -7,6 +7,7 @@ import torch from .quantized_tensor import QuantizedTensor, Quantizer +from .utils import cast_master_weights_to_fp8, replace_raw_data __all__ = [ "QuantizedTensor", diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index e45010b..2fb1283 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -185,9 +185,9 @@ class Float8CurrentScalingQuantizer(Quantizer): """ - """Scaling factor to multiply when quantizing to FP8""" + """Workspace buffer for FP8 scaling factor""" scale: torch.Tensor - """Max-abs value from last FP8 cast""" + """Workspace buffer for max-abs value""" amax: torch.Tensor """FP8 datatype""" dtype: TE_DType diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py new file mode 100644 index 0000000..8dd04b5 --- /dev/null +++ b/transformer_engine/pytorch/tensor/utils.py @@ -0,0 +1,315 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Helper functions for using fp8 tensors as weights""" + +import torch + +import transformer_engine_torch as tex +from transformer_engine_torch import multi_tensor_scale, multi_tensor_compute_scale_and_scale_inv + +from .quantized_tensor import QuantizedTensor +from .float8_tensor import Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer +from .mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer +from ..optimizers.multi_tensor_apply import multi_tensor_applier + + +def replace_raw_data(tensor: QuantizedTensor, new_raw_data: torch.Tensor): + r"""Change a quantized tensor's data buffer while preserving values + + This function modifies only the address space of the underlying + raw data and does not alter any other tensor attributes or values. + + This may be used for custom buffer allocations, e.g. packing + multiple parameter tensors together into a single contiguous + buffer for ZeRO-2. 
+ + """ + if isinstance(tensor, Float8Tensor): + old_raw_data = tensor._data + assert old_raw_data.dtype == new_raw_data.dtype, "The data types of raw data don't match" + new_raw_data.detach().copy_(old_raw_data) + tensor._data = new_raw_data + del old_raw_data + elif isinstance(tensor, MXFP8Tensor): + raise NotImplementedError("replace_raw_data for MXFP8Tensor is not supported yet") + else: + raise ValueError(f"replace_raw_data for {type(tensor)} is not supported yet") + + +def cast_master_weights_to_fp8( + model_weights, master_weights, start_offsets, group, fsdp_shard_model_weights=None +): + r"""Helper function to cast master weights to FP8 primary weights. + + This is intended for use with ZeRO/FSDP. Each rank has a shard of + the master weights (possibly empty) and a full copy of the model + weights. + + Parameters + ---------- + model_weights : list of FP8 weights. + master_weights : list of master weights. Typically they are FP32 weights. + start_offsets : list of integers, the starting index of the master weight in the model weight. + master_weight may be smaller than model_weight because it could be distributed + across multiple ranks. These offsets indicate which part of the model_weight + should be updated. + group : The distributed group to do amax reduction. Typically it's the data parallel + group. + fsdp_shard_model_weights : list of FSDP shard model weights. If None, it means that the model weights are + not sharded. Otherwise, it means that the model weights are sharded and we get + target model weights data storage using the FSDP shard model weights. + + """ + + delayed_scaling_params = [] + current_scaling_params = [] + + if fsdp_shard_model_weights is None: + use_fsdp_shard_model_weights = False + fsdp_shard_model_weights = [None] * len(model_weights) + else: + use_fsdp_shard_model_weights = True + + for model_weight, master_weight, start_offset, fsdp_shard_model_weight in zip( + model_weights, master_weights, start_offsets, fsdp_shard_model_weights + ): + # Clear `_high_precision_init_val` of model_weight automatically. + # - Master weights are initialized from model weights, if we use fp8 primary weights to + # initialize master weights, the numerical values of master weights are not consistent + # with the numerical values when we initialize them from bf16/fp16 weights. + # - So we add a `_high_precision_init_val` attribute to each model weight to store the + # original bf16/fp16 weight on cpu before casting it to fp8. And users can use + # `get_high_precision_init_val` to get this cpu tensor. + # - This cpu tensor is not needed once the master weight is initialized, so users should + # call `clear_high_precision_init_val` to remove it after master weight is initialized. + # - In case users don't call `clear_high_precision_init_val`, we will clear it automatically + # here. It's safe to clear the `_high_precision_init_val` at this time because this + # function is supposed to be called after the master weights are initialized and updated. + if hasattr(model_weight, "clear_high_precision_init_val"): + model_weight.clear_high_precision_init_val() + + if master_weight is not None: + # When not using fp8_primary_weights, the master_weight (fp32) is first cast to + # bf16/fp16, and then cast to fp8 during forward. Although it's not necessary when + # fp8_primary_weights is enabled, we still keep this logic to keep numerical + # consistency. So here we cast the master_weight to model_weight.dtype. 
+ master_weight = master_weight.to(model_weight.dtype) + + quantizer = model_weight._get_quantizer() + if isinstance(quantizer, Float8Quantizer): + delayed_scaling_params.append( + (model_weight, master_weight, start_offset, fsdp_shard_model_weight) + ) + elif isinstance(quantizer, Float8CurrentScalingQuantizer): + current_scaling_params.append( + (model_weight, master_weight, start_offset, fsdp_shard_model_weight) + ) + elif isinstance(quantizer, MXFP8Quantizer): + raise NotImplementedError( + "cast_master_weights_to_fp8 for MXFP8BlockScaling is not supported yet" + ) + else: + raise ValueError( + f"cast_master_weights_to_fp8 for {type(quantizer)} is not supported yet" + ) + + if len(delayed_scaling_params) > 0: + _cast_master_weights_to_fp8_delayed_scaling( + delayed_scaling_params, group, use_fsdp_shard_model_weights + ) + if len(current_scaling_params) > 0: + _cast_master_weights_to_fp8_current_scaling( + current_scaling_params, group, use_fsdp_shard_model_weights + ) + + +def _cast_master_weights_to_fp8_delayed_scaling(params, group, use_fsdp_shard_model_weights=False): + r"""Helper function to cast master weights to FP8 primary weights for delayed scaling. + + Parameters + ---------- + params : List of tuple, each tuple contains a model weight, a master weight, and an offset + indicating the starting index of the master weight in the model weight. + group : The distributed group to do amax reduction. Typically it's the data parallel + group. + use_fsdp_shard_model_weights : bool, if True, it means that the model weights are sharded. + """ + + # Collect amaxes to do reduce-max among dp group. + # Collect scales and scale_invs to update scale_invs of the fp8 weights. + amaxes, scales, scale_invs = [], [], [] + + for model_weight, master_weight, start_offset, shard_model_weight_raw in params: + # Reset transpose cache for all model weights. + # We cannot create transpose cache here because users (like megatron) may want to overlap + # the all-gather of model weights and forward process, so the model weight is not updated + # currently. + model_weight._reset_caches() + + quantizer = model_weight._get_quantizer() + + amaxes.append(quantizer.amax.view(1)) + scales.append(quantizer.scale.view(1)) + scale_invs.append(model_weight._scale_inv.view(1)) + + # If master weight is None, it means that the master weight of the current model weight + # is in other DP ranks. + if master_weight is None: + continue + + # If master weight is not None, start_offset must be a valid value. + assert start_offset is not None + assert start_offset >= 0 + end_offset = start_offset + master_weight.numel() + assert end_offset <= model_weight.numel() + + # master_weight may be smaller than model_weight because it could be distributed across + # multiple ranks. So we need to create a dummy weight using the raw data from model_weight. + if not use_fsdp_shard_model_weights: + shard_model_weight_raw = model_weight._data.view(-1)[start_offset:end_offset] + shard_model_weight_fp8 = quantizer.create_tensor_from_data( + shard_model_weight_raw.view(1, -1), + model_weight.dtype, + ) + + # Cast master weight to fp8. + quantizer.update_quantized(master_weight.view(1, -1), shard_model_weight_fp8) + + if len(amaxes) > 0: + dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=amaxes[0].device) + + # Reduce amaxes. 
+ packed_amaxes = torch.empty(len(amaxes), dtype=torch.float32, device=amaxes[0].device) + packed_amax_views = [packed_amaxes[i].view(1) for i in range(len(amaxes))] + multi_tensor_applier( + multi_tensor_scale, dummy_overflow_buf, [amaxes, packed_amax_views], 1.0 + ) + torch.distributed.all_reduce( + packed_amaxes, + op=torch.distributed.ReduceOp.MAX, + group=group, + ) + multi_tensor_applier( + multi_tensor_scale, dummy_overflow_buf, [packed_amax_views, amaxes], 1.0 + ) + + # Update scale_invs. + packed_scales = torch.empty(len(scales), dtype=torch.float32, device=scales[0].device) + packed_scale_views = [packed_scales[i].view(1) for i in range(len(scales))] + multi_tensor_applier( + multi_tensor_scale, dummy_overflow_buf, [scales, packed_scale_views], 1.0 + ) + torch.reciprocal(packed_scales, out=packed_scales) + multi_tensor_applier( + multi_tensor_scale, dummy_overflow_buf, [packed_scale_views, scale_invs], 1.0 + ) + + +def _cast_master_weights_to_fp8_current_scaling(params, group, use_fsdp_shard_model_weights=False): + r"""Helper function to cast master weights to FP8 primary weights for current scaling. + + Parameters + ---------- + params : List of tuple, each tuple contains a model weight, a master weight, and an offset + indicating the starting index of the master weight in the model weight. + group : The distributed group to do amax reduction. Typically it's the data parallel + group. + use_fsdp_shard_model_weights : bool, if True, it means that the model weights are sharded. + """ + + # Parameter attributes + device = params[0][0].device + fp8_dtype = params[0][0]._get_quantizer().dtype + force_pow_2_scales = params[0][0]._get_quantizer().force_pow_2_scales + amax_epsilon = params[0][0]._get_quantizer().amax_epsilon + + # Create a dummy overflow buffer, it's needed by multi_tensor_applier. + dummy_overflow_buf = torch.zeros(1, dtype=torch.int, device=device) + + # Create a contiguous buffer to store amaxes temporarily, so we can perform all all-reduce + # NCCL kernels at once. + packed_amaxes = torch.zeros(len(params), dtype=torch.float32, device=device) + amaxes = [packed_amaxes[i : i + 1] for i in range(len(params))] + + # Collect scales and scale_invs to update them after amax reduction. + scales, scale_invs = [], [] + + # --------------------------------------------------------------------------------------------- + # Step 1: Iterate through all the none empty master weights and compute amax of them. Store the + # amaxes in a contiguous buffer. If the master weight is None, the corresponding amax + # will be set to 0. + # --------------------------------------------------------------------------------------------- + for (model_weight, master_weight, _, _), amax in zip(params, amaxes): + + # Make sure all the model weights have the same numerical options. + quantizer = model_weight._get_quantizer() + assert quantizer.dtype == fp8_dtype + assert quantizer.force_pow_2_scales == force_pow_2_scales + assert quantizer.amax_epsilon == amax_epsilon + + scales.append(quantizer.scale.view(1)) + scale_invs.append(model_weight._scale_inv.view(1)) + + # Compute amax of the master weight and store it in packed_amaxes. + if master_weight is not None: + tex.compute_amax(master_weight, amax) + + # --------------------------------------------------------------------------------------------- + # Step 2: Perform all-reduce on packed_amaxes to get the global amax. 
+ # --------------------------------------------------------------------------------------------- + torch.distributed.all_reduce(packed_amaxes, op=torch.distributed.ReduceOp.MAX, group=group) + + # --------------------------------------------------------------------------------------------- + # Step 3: Update scales and scale_invs. + # --------------------------------------------------------------------------------------------- + if fp8_dtype == tex.DType.kFloat8E4M3: + max_fp8 = 448.0 + elif fp8_dtype == tex.DType.kFloat8E5M2: + max_fp8 = 57344.0 + else: + raise ValueError(f"Unsupported FP8 dtype: {fp8_dtype}") + multi_tensor_applier( + multi_tensor_compute_scale_and_scale_inv, + dummy_overflow_buf, + [amaxes, scales, scale_invs], + max_fp8, + force_pow_2_scales, + amax_epsilon, + ) + + # --------------------------------------------------------------------------------------------- + # Step 4: Cast master weights to FP8. + # --------------------------------------------------------------------------------------------- + for (model_weight, master_weight, start_offset, model_weight_fragment), scale in zip( + params, scales + ): + # Reset transpose cache for all model weights. + # We cannot create transpose cache here because users (like megatron) may want to overlap + # the all-gather of model weights and forward process, so the model weight is not updated + # currently. + model_weight._reset_caches() + + # If master weight is None, it means that the master weight of the current model weight + # is in other DP ranks. + if master_weight is None: + continue + + # Cast master weight to FP8 + end_offset = start_offset + master_weight.numel() + if not use_fsdp_shard_model_weights: + model_weight_fragment = model_weight.reshape(-1)[start_offset:end_offset] + quantizer = Float8Quantizer( + scale=scale, + amax=torch.Tensor(), + fp8_dtype=model_weight._fp8_dtype, + ) + if use_fsdp_shard_model_weights and not isinstance(model_weight_fragment, Float8Tensor): + # NOTE: The fsdp shard model weight may be a unit8 tensor instead of + # a float8 tensor. We should handle this situation properly. + model_weight_fragment = quantizer.create_tensor_from_data( + model_weight_fragment.view(-1), + model_weight.dtype, + ) + quantizer.update_quantized(master_weight, model_weight_fragment)
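A minimal sketch of how the helper above might be driven from a ZeRO-1-style optimizer step. The shard bookkeeping (which rank owns which slice of each weight, and the post-step all-gather of the FP8 raw data) is assumed to be handled by the caller, consistent with the offsets described in the docstring above.

    from transformer_engine.pytorch.tensor import cast_master_weights_to_fp8

    def zero1_step(optimizer, model_weights, master_weights, start_offsets, dp_group):
        """model_weights: Float8Tensor parameters; master_weights: this rank's FP32
        shards (None where the rank owns no slice); start_offsets: start index of
        each shard in the flattened model weight."""
        optimizer.step()  # updates the local FP32 master-weight shards
        # Quantize each rank's shard directly into its FP8 model weight; the amax
        # reduction runs over dp_group inside the helper.
        cast_master_weights_to_fp8(model_weights, master_weights, start_offsets, dp_group)
        # The updated FP8 raw data still has to be all-gathered across dp_group by
        # the caller (e.g. over a contiguous buffer installed with replace_raw_data).

The same call dispatches to the delayed-scaling or current-scaling path above based on each weight's quantizer, so the caller does not need to distinguish the two recipes.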