Unverified commit 83a4c219, authored by cyanguwa, committed by GitHub

[C/PyTorch] Add FP8 DPA and MHA (#768)



* WIP: fp8 v1 fprop integration
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: minor fixes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add debug info
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add more debug info
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fprop working for h1; w/ debug info
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: add bprop
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* cleanup; bprop running but has mismatches
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add gitlab frontend as submodule
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up and add back v0.9.2 FE support; fprop/bprop passing with 5e-2 tols
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix after merge; add bias_b/h to caching descriptor
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* distinguish fwd/bwd tensor types for bprop
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fix for F16 cases; include added dqkv_type and d_scale_dp
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* adjust out shape for bwd in test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add casting from/to FP8 to DPA module
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: bshd_bshd_bshd layout
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: support all sbhd/bshd layouts
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add qkvpacked and kvpacked support at both the FusedAttnFunc and C levels
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove qkvpacked/kvpacked calls in DPA module (used for testing)
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove tp setup; add allow_non_contiguous; update FE; revert to sbh3d in tests; clean up
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add NVTE_FP8_DPA_BWD to control whether to use FP8 bwd or F16 bwd
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix MQA
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix MQA/GQA in FP8 v1 API
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update FE to 705d8e3, with API change
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* test causal mask
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* restrict mha_fill to THD format
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix fused attn with CP and comment out is_alibi code
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up FE0.9 vs FE1.0 FP8 implementations, and related unit tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* change NVTE_FP8_DPA_BWD default to 1, and fix its use in qkvpacked/kvpacked APIs
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint and self.tp_size/group in FusedAttention()
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update FE to 6902c94
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add FP8 MHA support
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update to FE v1.3.0
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fixes for FP8 MHA with different configs
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* emit stats regardless of is_training
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix Linear when input is not a Float8Tensor
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix d_out type for F16 bprop
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix user buffer for layernorm_linear/linear and revert two FP8 casts in MHA
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add docstring for fp8_dpa/mha in recipe
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix backend selection to avoid FA
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* replace transpose with transpose_2d
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* use RMSE for FP8 unit tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* replace two more transposes with transpose_2d
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add FP8 initialization to FusedAttention
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* rm docs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Revert "add FP8 initialization to FusedAttention"

This reverts commit 15fffd825d6f23f31ea709b16ba01dfd61efabf8.
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Change order of ctxs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* minor fixes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add back docs and mark as beta
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fixes for tests and docs
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent f69e45be
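The diff below wires FP8 handling through the Linear module so attention can consume and produce Float8Tensors directly. As a quick orientation, here is a hedged usage sketch of the feature described in the commit message; the `fp8_dpa`/`fp8_mha` recipe arguments and module defaults are assumptions based on this PR, not verbatim API documentation.

```python
# Hedged usage sketch: FP8 DPA/MHA via the recipe flags this PR documents.
# `fp8_dpa`/`fp8_mha` on DelayedScaling and NVTE_FP8_DPA_BWD are taken from the
# commit message above; check the released docs for the exact names.
import os

os.environ["NVTE_FP8_DPA_BWD"] = "1"  # 1 = FP8 backward for DPA (default set in this PR)

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

recipe = DelayedScaling(fp8_format=Format.HYBRID, fp8_dpa=True, fp8_mha=True)

mha = te.MultiheadAttention(
    hidden_size=1024,
    num_attention_heads=16,
    params_dtype=torch.bfloat16,
).cuda()

x = torch.randn(128, 2, 1024, device="cuda", dtype=torch.bfloat16)  # [seq, batch, hidden]

with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    y = mha(x)
y.sum().backward()
```

Setting `NVTE_FP8_DPA_BWD=0` would instead run the DPA backward in F16, per the commit message.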
@@ -3,6 +3,7 @@
 # See LICENSE for license information.
 """Linear API"""
+import os
 from typing import Union, Optional, Callable, Tuple, List, Dict, Any
 import torch
@@ -46,6 +47,8 @@ from ..jit import no_torch_dynamo
 from ..graph import is_graph_capturing
 from ..float8_tensor import Float8Tensor
+
+_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
 __all__ = ["Linear"]
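The `_NVTE_DEBUG` flag above is read once at module import time, so the environment variable has to be set before `transformer_engine.pytorch` is imported. A minimal illustration:

```python
# Minimal illustration of the new debug switch: set NVTE_DEBUG before importing
# TE, since the diff above reads it when the module is first imported.
import os

os.environ["NVTE_DEBUG"] = "1"            # enables the '[Linear]: using ... forward/backward' prints
import transformer_engine.pytorch as te   # noqa: E402  import after setting the variable
```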
@@ -81,11 +84,16 @@ class _Linear(torch.autograd.Function):
         ub_overlap_rs: bool,
         ub_overlap_ag: bool,
         ub_name: str,
+        is_first_module_in_mha: bool,
     ) -> torch.Tensor:
+        is_input_fp8 = isinstance(inp, Float8Tensor)
+        if is_input_fp8:
+            fp8_meta["scaling_fwd"].scale_inv[tex.FP8FwdTensors.GEMM1_INPUT] = inp._scale_inv[0]
+
         # Make sure input dimensions are compatible
         in_features = weight.shape[-1]
         assert inp.shape[-1] == in_features, "GEMM not possible"
-        inputmat = inp.view((-1, in_features))
+        inputmat = inp.view(-1, in_features)
         if fp8:
             assert_dim_for_fp8_exec(inputmat)
             assert_dim_for_fp8_exec(weight)
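The `scale_inv` copy above matters because an FP8 payload is meaningless without the scale that produced it: when `inp` is already a `Float8Tensor`, `_Linear` reuses its raw data in the GEMM, so the producer's inverse scale has to be installed under `GEMM1_INPUT` in this module's forward scaling metadata. A plain-PyTorch illustration of the quantize/dequantize relationship (not TE internals):

```python
# Illustration only: why the inverse scale must travel with FP8 data.
import torch

x = torch.randn(4, 8)
amax = x.abs().max()
scale = 448.0 / amax      # map the observed range onto the FP8 E4M3 maximum (~448)
scale_inv = 1.0 / scale

x_q = x * scale           # stand-in for the quantized (uint8) payload
x_deq = x_q * scale_inv   # only recoverable if the consumer has the matching scale_inv
assert torch.allclose(x, x_deq)
```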
@@ -103,8 +111,19 @@ class _Linear(torch.autograd.Function):
         inputmat = cast_if_needed(inputmat, activation_dtype)
         inputmat_t = None
         inputmat_no_fp8 = inputmat
         if fp8:
             fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
+            if isinstance(inputmat, Float8Tensor):
+                if (
+                    not fp8_meta["recipe"].override_linear_precision.wgrad
+                    and is_grad_enabled
+                    and weight.requires_grad
+                    and not sequence_parallel
+                ):
+                    # FP8 input for forward, FP8 input transpose for backward wgrad
+                    inputmat_t = inputmat.transpose_2d()
+            else:
                 if (
                     not fp8_meta["recipe"].override_linear_precision.wgrad
                     and is_grad_enabled
@@ -134,6 +153,9 @@ class _Linear(torch.autograd.Function):
             inputmat_total = inputmat
         if fp8:
+            if _NVTE_DEBUG:
+                print('[Linear]: using FP8 forward')
+
             bias_dtype = (
                 torch.bfloat16
                 if activation_dtype == torch.float32
@@ -174,8 +196,16 @@ class _Linear(torch.autograd.Function):
                 )
                 weight_t_fp8 = None
+            if is_first_module_in_mha:
+                proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = (
+                    tex.FP8FwdTensors.GEMM1_OUTPUT,
+                    fp8_meta["scaling_fwd"],
+                    fp8_dtype_forward,
+                    torch.uint8)
+            else:
                 proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = (
                     None, None, None, activation_dtype)
             if ub_overlap_rs:
                 ub_obj_projout = get_ub(ub_name+"_fprop")
                 out = ub_obj_projout.get_ubuf_output(1)
@@ -202,14 +232,15 @@ class _Linear(torch.autograd.Function):
             else:
                 dim_size = list(inputmat_total.size())
                 dim_size[1] = weight.size(0)
-                out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
+                out = torch.empty(dim_size, dtype=proj_out_pttype, device=inputmat_total.device)
             _ = fp8_gemm(
                 weight_fp8._data,
                 fp8_meta["scaling_fwd"].scale_inv,
                 tex.FP8FwdTensors.GEMM1_WEIGHT,
                 fp8_dtype_forward,
-                inputmat_total,
+                inputmat_total._data
+                if isinstance(inputmat_total, Float8Tensor) else inputmat_total,
                 fp8_meta["scaling_fwd"].scale_inv,
                 tex.FP8FwdTensors.GEMM1_INPUT,
                 fp8_dtype_forward,
@@ -226,7 +257,18 @@ class _Linear(torch.autograd.Function):
                 fp8_meta_tensor = meta_tensor,
                 D_dtype = proj_out_tetype,
             )
+            if is_first_module_in_mha:
+                out = Float8Tensor(data=out,
+                                   fp8_meta=fp8_meta,
+                                   fp8_meta_forward=True,
+                                   fp8_meta_index=tex.FP8FwdTensors.GEMM1_OUTPUT,
+                                   fp8_dtype=fp8_dtype_forward,
+                                   dtype=activation_dtype,
+                                   )
         else:
+            if _NVTE_DEBUG:
+                print('[Linear]: using non-FP8 forward')
+
             # Cast for native AMP
             weight = cast_if_needed(weight, activation_dtype)
             bias = cast_if_needed(bias, activation_dtype) if use_bias else bias
@@ -319,6 +361,7 @@
             ctx.ub_name = ub_name
             ctx.tp_size = tp_size
             ctx.requires_dgrad = inp.requires_grad
+            ctx.is_input_fp8 = is_input_fp8
             ctx.primary_weights_in_fp8 = primary_weights_in_fp8
             ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
@@ -338,6 +381,10 @@
     def backward(
         ctx, grad_output: torch.Tensor
     ) -> Tuple[Union[torch.Tensor, None], ...]:
+        if isinstance(grad_output[0], Float8Tensor):
+            ctx.fp8_meta["scaling_bwd"].scale_inv[
+                tex.FP8BwdTensors.GRAD_OUTPUT1] = grad_output._scale_inv
+
         with torch.cuda.nvtx.range("_Linear_backward"):
             (
                 inputmat,
@@ -412,6 +459,18 @@
             if ctx.requires_dgrad:
                 if ctx.fp8:
+                    if _NVTE_DEBUG:
+                        print('[Linear]: using FP8 backward')
+
+                    if ctx.is_input_fp8:
+                        out_index, meta_tensor, output_te_dtype, output_dtype = (
+                            tex.FP8BwdTensors.GRAD_INPUT1,
+                            ctx.fp8_meta["scaling_bwd"],
+                            fp8_dtype_backward,
+                            torch.uint8)
+                    else:
+                        out_index, meta_tensor, output_te_dtype, output_dtype = (
+                            None, None, None, ctx.activation_dtype)
                     dgrad, _ = fp8_gemm(
                         weight_t_fp8,
                         fwd_scale_inverses,
@@ -421,13 +480,27 @@
                         ctx.fp8_meta["scaling_bwd"].scale_inv,
                         tex.FP8BwdTensors.GRAD_OUTPUT1,
                         fp8_dtype_backward,
-                        ctx.activation_dtype,
+                        output_dtype,
                         get_workspace(),
                         use_split_accumulator=_2X_ACC_DGRAD,
                         ub_algo=ub_algo if ctx.ub_overlap_ag else None,
                         ub=ctx.ub_obj_gradout if ctx.ub_overlap_ag else None,
+                        out_index=out_index,
+                        fp8_meta_tensor=meta_tensor,
+                        D_dtype=output_te_dtype,
                     )
+                    if output_dtype == torch.uint8:
+                        dgrad = Float8Tensor(data=dgrad,
+                                             fp8_meta=ctx.fp8_meta,
+                                             fp8_meta_forward=False,
+                                             fp8_meta_index=tex.FP8BwdTensors.GRAD_INPUT1,
+                                             fp8_dtype=fp8_dtype_backward,
+                                             dtype=ctx.activation_dtype,
+                                             )
                 else:
+                    if _NVTE_DEBUG:
+                        print('[Linear]: using non-FP8 backward')
+
                     dgrad, _, _ = gemm(
                         weight,
                         grad_output,
@@ -455,11 +528,19 @@
                 # WGRAD
                 if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
                     if ctx.ub_overlap_ag:
-                        grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
+                        if isinstance(grad_output_c, Float8Tensor):
+                            grad_output_t = grad_output_c.transpose_2d()
+                        else:
+                            grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
                     if inputmat_t_total is None:
-                        inputmat_t_total = tex.fp8_transpose(inputmat_total, fp8_dtype_backward)
+                        if isinstance(inputmat_total, Float8Tensor):
+                            inputmat_t_total = inputmat_total.transpose_2d()
+                        else:
+                            inputmat_t_total = tex.fp8_transpose(
+                                inputmat_total, fp8_dtype_backward)
                     wgrad, _ = fp8_gemm(
-                        inputmat_t_total,
+                        inputmat_t_total._data
+                        if isinstance(inputmat_t_total, Float8Tensor) else inputmat_t_total,
                         fwd_scale_inverses,
                         tex.FP8FwdTensors.GEMM1_INPUT,
                         fp8_dtype_forward,
@@ -558,6 +639,7 @@
             None,
             None,
             None,
+            None,
         )
@@ -850,6 +932,7 @@ class Linear(TransformerEngineBaseModule):
         self,
         inp: torch.Tensor,
         is_first_microbatch: Optional[bool] = None,
+        is_first_module_in_mha: Optional[bool] = False,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
         """
         Apply the linear transformation to the input.
@@ -871,16 +954,22 @@ class Linear(TransformerEngineBaseModule):
                              * it also allows skipping gradient accumulation during the
                                first microbatch (since it is the first gradient being
                                produced)
+        is_first_module_in_mha: Optional[bool], default = False
+                               Whether to output in FP8. By default, Linear outputs in inp.dtype.
         """
         skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
         if skip_fp8_weight_update is not None:
             is_first_microbatch = False
-        with self.prepare_forward(inp, is_first_microbatch) as inp:
+        with self.prepare_forward(inp,
+                                  is_first_microbatch,
+                                  allow_non_contiguous=isinstance(inp, Float8Tensor)) as inp:
             assert self.fp8 or not self.primary_weights_in_fp8, \
                    "Need to run inside fp8_autocast region when weights are stored in FP8."
+            is_first_module_in_mha = is_first_module_in_mha and self.fp8_meta["recipe"].fp8_mha
+
             # Get concatenated weight and bias tensors
             if len(self.parameter_split_sizes) == 1:
                 weight_tensor = getattr(self, self.weight_names[0])
@@ -939,6 +1028,7 @@ class Linear(TransformerEngineBaseModule):
                 self.ub_overlap_rs,
                 self.ub_overlap_ag,
                 self.ub_name,
+                is_first_module_in_mha,
             )
             out = linear_fn(*args)
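Taken together, the Linear changes above let the QKV projection hand an FP8 tensor straight to fused attention. A hedged sketch of what calling the new argument might look like (the recipe's `fp8_mha` flag is an assumption from this PR; the rest follows the diff):

```python
# Hedged sketch: requesting FP8 output from te.Linear via the new argument.
import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.common.recipe import DelayedScaling, Format

qkv_proj = te.Linear(1024, 3 * 1024, params_dtype=torch.bfloat16).cuda()
x = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)

recipe = DelayedScaling(fp8_format=Format.HYBRID, fp8_mha=True)  # fp8_mha assumed from this PR
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    qkv = qkv_proj(x, is_first_module_in_mha=True)

# When the FP8 path is taken, the output is a Float8Tensor (uint8 payload plus scales).
print(isinstance(qkv, Float8Tensor))
```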
...
@@ -15,8 +15,13 @@ def clear_tensor_data(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
     Must be used carefully.
     """
+    from .float8_tensor import Float8Tensor
+
     for t in tensors:
         if t is not None:
+            if isinstance(t, Float8Tensor):
+                t._data.data = torch.Tensor()
+                del t
+            else:
                 t.data = torch.Tensor()
                 del t
...
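The clear_tensor_data change is needed because a Float8Tensor keeps its real storage on the inner `_data` tensor; clearing only the wrapper would leave the uint8 payload alive. A toy analogy with a hypothetical stand-in class (not the TE implementation):

```python
# Toy analogy only: `_FakeFp8` is a hypothetical stand-in, not TE's Float8Tensor.
import torch

class _FakeFp8:
    def __init__(self, payload: torch.Tensor):
        self._data = payload          # the uint8 payload lives on an inner tensor

t = _FakeFp8(torch.zeros(1024, 1024, dtype=torch.uint8))
t._data.data = torch.Tensor()         # what clear_tensor_data now does for Float8Tensor
print(t._data.numel())                # 0 -> the payload's storage has been released
```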