# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""TE FP8 extensions and GEMMs"""
from typing import Optional, Tuple, Union

import torch

import transformer_engine_extensions as tex

from .constants import TE_DType


def fp8_gemm(
    A: torch.Tensor,
    A_scale_inv: torch.Tensor,
    A_fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    A_dtype: tex.DType,
    B: torch.Tensor,
    B_scale_inv: torch.Tensor,
    B_fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    B_dtype: tex.DType,
    out_dtype: torch.dtype,
    workspace: torch.Tensor,
    accumulate: bool = False,
    out: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    use_bias: bool = False,
    fp32_output: bool = False,
    use_split_accumulator: bool = False,
) -> torch.Tensor:
    """TN layout GEMM with fp8 inputs.

    Computes ``out = B @ A^T`` (TN layout is hard-coded: ``transa=True``,
    ``transb=False``) via the fused ``te_gemm_ts`` extension op.

    Parameters
    ----------
    A, B : FP8-encoded input tensors; ``*_scale_inv`` and ``*_fp8_tensor``
        select the dequantization scale from the FP8 metadata, and
        ``*_dtype`` gives the FP8 encoding of each operand.
    out_dtype : torch dtype of the output (overridden to float32 when
        ``fp32_output`` is set).
    workspace : scratch buffer for the extension op; its length is passed
        through as the workspace size.
    accumulate : when True, accumulate into ``out`` instead of overwriting.
    out : optional preallocated output. When omitted, a fresh output of
        shape ``(B.shape[0], A.shape[0])`` is allocated on CUDA.
    bias : bias tensor, forwarded only when ``use_bias`` is True.
    use_split_accumulator : forwarded to the op; selects the split-accumulator
        FP8 GEMM path.

    Returns
    -------
    The freshly allocated output tensor, or ``None`` when the caller
    supplied ``out`` (the result is written in place).
    """
    # Placeholder passed for every optional tensor slot the op does not use.
    empty_tensor = torch.Tensor()

    return_output = False
    if out is None:
        # TN layout: result has B's rows and A's rows.
        out = torch.empty(
            B.shape[0],
            A.shape[0],
            dtype=torch.float32 if fp32_output else out_dtype,
            device="cuda",
        )
        return_output = True
    # Translate the torch dtype into the extension's DType enum.
    out_dtype = tex.DType.kFloat32 if fp32_output else TE_DType[out_dtype]

    _ = torch.ops.tex_ts.te_gemm_ts(
        A,
        A_scale_inv,
        A_fp8_tensor,
        A_dtype,
        True,  # transa
        B,
        B_scale_inv,
        B_fp8_tensor,
        B_dtype,
        False,  # transb
        out,
        out_dtype,
        bias if use_bias else empty_tensor,
        empty_tensor,  # this is pre_gelu_out
        False,  # grad
        workspace,
        workspace.shape[0],
        accumulate,
        use_split_accumulator,
    )

    if return_output:
        return out
    return None


def gemm(
    A: torch.Tensor,
    B: torch.Tensor,
    dtype: torch.dtype,
    workspace: torch.Tensor,
    gelu: bool = False,
    gelu_input: Optional[torch.Tensor] = None,
    grad: bool = False,
    accumulate: bool = False,
    layout: str = "TN",
    out: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    use_bias: bool = False,
    fp32_output: bool = False,
) -> Tuple[Union[torch.Tensor, None], ...]:
    """Non FP8 GEMM.

    Dispatches to the same ``te_gemm_ts`` op as :func:`fp8_gemm` but with
    dummy FP8 metadata, optionally fusing bias (forward), bias-gradient
    (backward) and GeLU into the GEMM.

    Parameters
    ----------
    A, B : input operands; transposition is controlled by ``layout``
        ("TN", "NN" or "NT" — first letter applies to A, second to B).
    dtype : torch dtype of the inputs (also the output dtype unless
        ``fp32_output`` is set).
    gelu : fuse GeLU; in forward mode the pre-GeLU activation is captured
        into a fresh ``gelu_input``, in backward (``grad``) mode the caller
        must supply the saved ``gelu_input``.
    grad : backward-pass mode; with ``use_bias`` this produces ``grad_bias``
        instead of applying ``bias``.
    out : optional preallocated output; when omitted one is allocated on
        CUDA with the shape implied by ``layout``.

    Returns
    -------
    ``(out, grad_bias, gelu_input)`` — ``out`` is ``None`` when the caller
    supplied it (written in place); ``grad_bias``/``gelu_input`` are empty
    tensors when the corresponding fusion is disabled.
    """
    assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported."
    transa = layout[0] == "T"
    transb = layout[1] == "T"
    # Placeholder for unused tensor slots; -1 disables FP8 scale lookup.
    empty_tensor = torch.Tensor()
    fp8_index = -1  # dummy index

    input_dtype = TE_DType[dtype]
    output_dtype = tex.DType.kFloat32 if fp32_output else input_dtype

    return_output = False
    if out is None:
        # Output rows come from B, columns from A, adjusted for transposition.
        out = torch.empty(
            B.shape[1] if transb else B.shape[0],
            A.shape[0] if transa else A.shape[1],
            dtype=torch.float32 if fp32_output else dtype,
            device="cuda",
        )
        return_output = True

    if gelu and not grad:
        # Forward fused GeLU: op writes the pre-activation here for backward.
        gelu_input = torch.empty_like(out, dtype=dtype)
    elif not gelu:
        gelu_input = empty_tensor

    if grad and use_bias:
        # Backward: op reduces the bias gradient into this buffer.
        grad_bias = torch.empty(
            B.shape[1], dtype=torch.float32 if fp32_output else dtype, device="cuda"
        )
    else:
        grad_bias = empty_tensor

    bias = bias if use_bias else empty_tensor

    _ = torch.ops.tex_ts.te_gemm_ts(
        A,
        empty_tensor,
        fp8_index,
        input_dtype,
        transa,
        B,
        empty_tensor,
        fp8_index,
        input_dtype,
        transb,
        out,
        output_dtype,
        # The same slot carries bias forward and its gradient backward.
        grad_bias if grad else bias,
        gelu_input,
        grad,
        workspace,
        workspace.shape[0],
        accumulate,
        False,  # use_split_accumulator
    )

    if return_output:
        return out, grad_bias, gelu_input
    return None, grad_bias, gelu_input


def fp8_cast_transpose_fused(
    inp: torch.Tensor,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    otype: tex.DType,
    cast_out: Optional[torch.Tensor] = None,
    transpose_out: Optional[torch.Tensor] = None,
) -> Union[Tuple[torch.Tensor, torch.Tensor], None]:
    """Cast + Transpose with FP8 output.

    Quantizes ``inp`` to FP8 (int8 storage) and simultaneously produces its
    transpose, updating the amax history in place.

    Returns the ``(cast_out, transpose_out)`` pair when either buffer was
    allocated here, otherwise ``None`` (results written into the caller's
    buffers). NOTE: if only one of the two buffers is supplied, both are
    reallocated — callers must pass both or neither.
    """
    return_outputs = False
    if cast_out is None or transpose_out is None:
        cast_out = torch.empty_like(inp, dtype=torch.int8)
        transpose_out = torch.empty(
            inp.shape[1], inp.shape[0], device="cuda", dtype=torch.int8
        )
        return_outputs = True

    tex.fused_cast_transpose(
        inp,
        fp8_meta_tensor.scale[fp8_tensor],
        fp8_meta_tensor.amax_history[0][fp8_tensor],
        fp8_meta_tensor.scale_inv[fp8_tensor],
        cast_out,
        transpose_out,
        otype,
    )

    if return_outputs:
        return cast_out, transpose_out
    return None


def fp8_cast_transpose_bgrad_fused(
    inp: torch.Tensor,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    otype: tex.DType,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Cast + Transpose + BGRAD with FP8 output.

    Fused kernel producing the bias gradient alongside the FP8 cast and
    transpose of ``inp``; FP8 scale/amax metadata is updated in place.
    """
    return tex.fused_cast_transpose_bgrad(
        inp,
        fp8_meta_tensor.scale[fp8_tensor],
        fp8_meta_tensor.amax_history[0][fp8_tensor],
        fp8_meta_tensor.scale_inv[fp8_tensor],
        otype,
    )


def fp8_cast_transpose_bgrad_dgelu_fused(
    grad_output: torch.Tensor,
    gelu_input: torch.Tensor,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    otype: tex.DType,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Cast + Transpose + BGRAD + DGELU with FP8 output.

    Backward-pass fusion: applies the GeLU gradient (using the saved
    ``gelu_input`` pre-activation) to ``grad_output``, then casts,
    transposes and reduces the bias gradient in one kernel.
    """
    return tex.fused_cast_transpose_bgrad_dgelu(
        grad_output,
        gelu_input,
        fp8_meta_tensor.scale[fp8_tensor],
        fp8_meta_tensor.amax_history[0][fp8_tensor],
        fp8_meta_tensor.scale_inv[fp8_tensor],
        otype,
    )


def fp8_gelu(
    inp: torch.Tensor,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    otype: tex.DType,
) -> torch.Tensor:
    """GeLU with FP8 output.

    TorchScript-compatible op: the full scale/amax/scale_inv metadata
    tensors are passed with ``fp8_tensor`` as the index, rather than
    pre-indexed slices.
    """
    return torch.ops.tex_ts.fp8_gelu_ts(
        inp,
        fp8_meta_tensor.scale,
        fp8_meta_tensor.amax_history,
        fp8_meta_tensor.scale_inv,
        fp8_tensor,
        otype,
    )


def layernorm_fwd_fp8(
    inp: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    otype: tex.DType,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """LayerNorm with FP8 output.

    Training-mode forward: returns the FP8 output together with the
    saved statistics needed for backward (per the extension op's
    three-tensor return).
    """
    return tex.layernorm_fwd_fp8(
        inp,
        weight,
        bias,
        eps,
        fp8_meta_tensor.scale[fp8_tensor],
        fp8_meta_tensor.amax_history[0][fp8_tensor],
        fp8_meta_tensor.scale_inv[fp8_tensor],
        otype,
    )


def layernorm_fwd_fp8_inf(
    inp: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    otype: tex.DType,
) -> torch.Tensor:
    """LayerNorm with FP8 output.

    This version of layernorm_fwd_fp8 is specialized for inference, and returns
    only the normalized output.
    """
    ret = torch.ops.tex_ts.layernorm_fwd_fp8_inf_ts(
        inp,
        weight,
        bias,
        eps,
        fp8_meta_tensor.scale,
        fp8_meta_tensor.amax_history,
        fp8_meta_tensor.scale_inv,
        fp8_tensor,
        otype,
    )
    return ret


def layernorm_fwd_inf(
    inp: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float,
) -> torch.Tensor:
    """Non-FP8 LayerNorm forward specialized for inference.

    Returns only the normalized output; no FP8 metadata is involved.
    (Previous docstring wrongly claimed FP8 output.)
    """
    return torch.ops.tex_ts.layernorm_fwd_inf_ts(
        inp,
        weight,
        bias,
        eps,
    )


def cast_to_fp8(
    inp: torch.Tensor,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    otype: tex.DType,
) -> torch.Tensor:
    """Cast input to FP8.

    Quantizes ``inp`` using the scale selected by ``fp8_tensor`` and
    records the amax in the metadata's history (updated in place).
    """
    return torch.ops.tex_ts.cast_to_fp8_ts(
        inp,
        fp8_meta_tensor.scale,
        fp8_meta_tensor.amax_history,
        fp8_meta_tensor.scale_inv,
        fp8_tensor,
        otype,
    )


def cast_from_fp8(
    inp: torch.Tensor,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    itype: tex.DType,
    otype: tex.DType,
) -> torch.Tensor:
    """Cast input from FP8.

    Dequantizes ``inp`` (FP8 encoding ``itype``) to ``otype`` using the
    inverse scale selected by ``fp8_tensor``.
    """
    return torch.ops.tex_ts.cast_from_fp8_ts(
        inp,
        fp8_meta_tensor.scale_inv,
        fp8_tensor,
        itype,
        otype,
    )