Unverified commit 83a4c219, authored by cyanguwa and committed by GitHub

[C/PyTorch] Add FP8 DPA and MHA (#768)



* WIP: fp8 v1 fprop integration
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: minor fixes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add debug info
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add more debug info
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fprop working for h1; w/ debug info
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: add bprop
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* cleanup; bprop running but has mismatches
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add gitlab frontend as submodule
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up and add back v0.9.2 FE support; fprop/bprop passing with 5e-2 tols
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix after merge; add bias_b/h to caching descriptor
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* distinguish fwd/bwd tensor types for bprop
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fix for F16 cases; include added dqkv_type and d_scale_dp
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* adjust out shape for bwd in test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add casting from/to FP8 to DPA module
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: bshd_bshd_bshd layout
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: support all sbhd/bshd layouts
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add qkvpacked and kvpacked support in both FusedAttnFunc and C levels
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove qkvpacked/kvpacked calls in DPA module (used for testing)
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove tp setup; add allow_non_contiguous; update FE; revert to sbh3d in tests; clean up
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add NVTE_FP8_DPA_BWD to control whether to use FP8 bwd or F16 bwd (usage sketch below)
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix MQA
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix MQA/GQA in FP8 v1 API
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update FE to 705d8e3, with API change
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* test causal mask
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* restrict mha_fill for THD format
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix fused attn with CP and comment out is_alibi code
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up FE0.9 vs FE1.0 FP8 implementations, and related unit tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* change NVTE_FP8_DPA_BWD default to 1, and fix its use in qkvpacked/kvpacked APIs
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint and self.tp_size/group in FusedAttention()
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update FE to 6902c94
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add FP8 MHA support
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update to FE v1.3.0
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fixes for FP8 MHA with different configs
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* emit stats regardless of is_training
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix linear when input is not Float8Tensor
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix d_out type when f16 bprop
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix user buffer for layernorm_linear/linear and revert two FP8 casts in MHA
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add docstring for fp8_dpa/mha in recipe (see the usage sketch below)
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix backend selection to avoid FA
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* replace transpose with transpose_2d
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* use RMSE for FP8 unit tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* replace two more transpose calls with transpose_2d
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add FP8 initialization to FusedAttention
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* rm docs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Revert "add FP8 initialization to FusedAttention"

This reverts commit 15fffd825d6f23f31ea709b16ba01dfd61efabf8.
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Change order of ctxs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* minor fixes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add back docs and mark as beta
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fixes for tests and docs
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent f69e45be
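
Before the diff itself, a minimal usage sketch of what this PR adds. It assumes the fp8_dpa/fp8_mha recipe flags and the NVTE_FP8_DPA_BWD environment variable behave as the commit messages above describe; the shapes, dtypes, and module arguments are purely illustrative.

import os
os.environ["NVTE_FP8_DPA_BWD"] = "1"  # 1: FP8 backward (the new default), 0: fall back to F16 bprop

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# fp8_dpa/fp8_mha are the recipe fields documented by this PR (marked beta)
fp8_recipe = recipe.DelayedScaling(fp8_dpa=True, fp8_mha=True)

# sbhd layout: [sequence, batch, heads, head_dim]
q, k, v = (torch.randn(128, 2, 16, 64, dtype=torch.bfloat16,
                       device="cuda", requires_grad=True) for _ in range(3))

dpa = te.DotProductAttention(num_attention_heads=16, kv_channels=64)
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = dpa(q, k, v)  # takes the FP8 fused-attention path when the backend supports it
out.sum().backward()
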
@@ -3,6 +3,7 @@
# See LICENSE for license information.
"""Linear API"""
import os
from typing import Union, Optional, Callable, Tuple, List, Dict, Any
import torch
@@ -46,6 +47,8 @@ from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ..float8_tensor import Float8Tensor
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
__all__ = ["Linear"]
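
The debug prints added further down in this file ("[Linear]: using FP8 forward", and so on) are gated by this module-level _NVTE_DEBUG flag. Since it is evaluated once via os.getenv at import time, the variable has to be set before Transformer Engine is imported; a small sketch:

import os
os.environ["NVTE_DEBUG"] = "1"            # read once when the module is imported

import transformer_engine.pytorch as te   # Linear forward/backward now report which path they take
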
@@ -81,11 +84,16 @@ class _Linear(torch.autograd.Function):
ub_overlap_rs: bool,
ub_overlap_ag: bool,
ub_name: str,
is_first_module_in_mha: bool,
) -> torch.Tensor:
is_input_fp8 = isinstance(inp, Float8Tensor)
if is_input_fp8:
fp8_meta["scaling_fwd"].scale_inv[tex.FP8FwdTensors.GEMM1_INPUT] = inp._scale_inv[0]
# Make sure input dimensions are compatible
in_features = weight.shape[-1]
assert inp.shape[-1] == in_features, "GEMM not possible"
- inputmat = inp.view((-1, in_features))
+ inputmat = inp.view(-1, in_features)
if fp8:
assert_dim_for_fp8_exec(inputmat)
assert_dim_for_fp8_exec(weight)
@@ -103,8 +111,19 @@ class _Linear(torch.autograd.Function):
inputmat = cast_if_needed(inputmat, activation_dtype)
inputmat_t = None
inputmat_no_fp8 = inputmat
if fp8:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
if isinstance(inputmat, Float8Tensor):
if (
not fp8_meta["recipe"].override_linear_precision.wgrad
and is_grad_enabled
and weight.requires_grad
and not sequence_parallel
):
# FP8 input for forward, FP8 input transpose for backward wgrad
inputmat_t = inputmat.transpose_2d()
else:
if (
not fp8_meta["recipe"].override_linear_precision.wgrad
and is_grad_enabled
@@ -134,6 +153,9 @@
inputmat_total = inputmat
if fp8:
if _NVTE_DEBUG:
print('[Linear]: using FP8 forward')
bias_dtype = (
torch.bfloat16
if activation_dtype == torch.float32
@@ -174,8 +196,16 @@
)
weight_t_fp8 = None
if is_first_module_in_mha:
proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = (
tex.FP8FwdTensors.GEMM1_OUTPUT,
fp8_meta["scaling_fwd"],
fp8_dtype_forward,
torch.uint8)
else:
proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = (
None, None, None, activation_dtype)
if ub_overlap_rs:
ub_obj_projout = get_ub(ub_name+"_fprop")
out = ub_obj_projout.get_ubuf_output(1)
@@ -202,14 +232,15 @@
else:
dim_size = list(inputmat_total.size())
dim_size[1] = weight.size(0)
- out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
+ out = torch.empty(dim_size, dtype=proj_out_pttype, device=inputmat_total.device)
_ = fp8_gemm(
weight_fp8._data,
fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
- inputmat_total,
+ inputmat_total._data
+ if isinstance(inputmat_total, Float8Tensor) else inputmat_total,
fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
@@ -226,7 +257,18 @@
fp8_meta_tensor = meta_tensor,
D_dtype = proj_out_tetype,
)
if is_first_module_in_mha:
out = Float8Tensor(data=out,
fp8_meta=fp8_meta,
fp8_meta_forward=True,
fp8_meta_index=tex.FP8FwdTensors.GEMM1_OUTPUT,
fp8_dtype=fp8_dtype_forward,
dtype=activation_dtype,
)
else:
if _NVTE_DEBUG:
print('[Linear]: using non-FP8 forward')
# Cast for native AMP
weight = cast_if_needed(weight, activation_dtype)
bias = cast_if_needed(bias, activation_dtype) if use_bias else bias
@@ -319,6 +361,7 @@
ctx.ub_name = ub_name
ctx.tp_size = tp_size
ctx.requires_dgrad = inp.requires_grad
ctx.is_input_fp8 = is_input_fp8
ctx.primary_weights_in_fp8 = primary_weights_in_fp8
ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
@@ -338,6 +381,10 @@
def backward(
ctx, grad_output: torch.Tensor
) -> Tuple[Union[torch.Tensor, None], ...]:
if isinstance(grad_output[0], Float8Tensor):
ctx.fp8_meta["scaling_bwd"].scale_inv[
tex.FP8BwdTensors.GRAD_OUTPUT1] = grad_output._scale_inv
with torch.cuda.nvtx.range("_Linear_backward"):
(
inputmat,
@@ -412,6 +459,18 @@
if ctx.requires_dgrad:
if ctx.fp8:
if _NVTE_DEBUG:
print('[Linear]: using FP8 backward')
if ctx.is_input_fp8:
out_index, meta_tensor, output_te_dtype, output_dtype = (
tex.FP8BwdTensors.GRAD_INPUT1,
ctx.fp8_meta["scaling_bwd"],
fp8_dtype_backward,
torch.uint8)
else:
out_index, meta_tensor, output_te_dtype, output_dtype = (
None, None, None, ctx.activation_dtype)
dgrad, _ = fp8_gemm(
weight_t_fp8,
fwd_scale_inverses,
@@ -421,13 +480,27 @@
ctx.fp8_meta["scaling_bwd"].scale_inv,
tex.FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward,
- ctx.activation_dtype,
+ output_dtype,
get_workspace(),
use_split_accumulator=_2X_ACC_DGRAD,
ub_algo=ub_algo if ctx.ub_overlap_ag else None,
ub=ctx.ub_obj_gradout if ctx.ub_overlap_ag else None,
out_index=out_index,
fp8_meta_tensor=meta_tensor,
D_dtype=output_te_dtype,
)
if output_dtype == torch.uint8:
dgrad = Float8Tensor(data=dgrad,
fp8_meta=ctx.fp8_meta,
fp8_meta_forward=False,
fp8_meta_index=tex.FP8BwdTensors.GRAD_INPUT1,
fp8_dtype=fp8_dtype_backward,
dtype=ctx.activation_dtype,
)
else:
if _NVTE_DEBUG:
print('[Linear]: using non-FP8 backward')
dgrad, _, _ = gemm(
weight,
grad_output,
@@ -455,11 +528,19 @@
# WGRAD
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
if ctx.ub_overlap_ag:
if isinstance(grad_output_c, Float8Tensor):
grad_output_t = grad_output_c.transpose_2d()
else:
grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
if inputmat_t_total is None:
- inputmat_t_total = tex.fp8_transpose(inputmat_total, fp8_dtype_backward)
+ if isinstance(inputmat_total, Float8Tensor):
+     inputmat_t_total = inputmat_total.transpose_2d()
+ else:
+     inputmat_t_total = tex.fp8_transpose(
+         inputmat_total, fp8_dtype_backward)
wgrad, _ = fp8_gemm(
- inputmat_t_total,
+ inputmat_t_total._data
+ if isinstance(inputmat_t_total, Float8Tensor) else inputmat_t_total,
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
@@ -558,6 +639,7 @@
None,
None,
None,
None,
)
@@ -850,6 +932,7 @@ class Linear(TransformerEngineBaseModule):
self,
inp: torch.Tensor,
is_first_microbatch: Optional[bool] = None,
is_first_module_in_mha: Optional[bool] = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
"""
Apply the linear transformation to the input.
@@ -871,16 +954,22 @@
* it also allows skipping gradient accumulation during the
first microbatch (since it is the first gradient being
produced)
is_first_module_in_mha: Optional[bool], default = False
Whether to output in FP8. By default, Linear outputs in inp.dtype.
"""
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
if skip_fp8_weight_update is not None:
is_first_microbatch = False
- with self.prepare_forward(inp, is_first_microbatch) as inp:
+ with self.prepare_forward(inp,
+                           is_first_microbatch,
+                           allow_non_contiguous=isinstance(inp,Float8Tensor)) as inp:
assert self.fp8 or not self.primary_weights_in_fp8, \
"Need to run inside fp8_autocast region when weights are stored in FP8."
is_first_module_in_mha = is_first_module_in_mha and self.fp8_meta["recipe"].fp8_mha
# Get concatenated weight and bias tensors
if len(self.parameter_split_sizes) == 1:
weight_tensor = getattr(self, self.weight_names[0])
@@ -939,6 +1028,7 @@
self.ub_overlap_rs,
self.ub_overlap_ag,
self.ub_name,
is_first_module_in_mha,
)
out = linear_fn(*args)
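
Taken together, the is_first_module_in_mha plumbing above lets a Linear that feeds an FP8 attention block hand its output over as a Float8Tensor instead of casting back to the activation dtype, and lets its backward consume a Float8Tensor gradient. A rough sketch of the calling pattern, assuming the recipe flags from the commit message; sizes and constructor arguments are illustrative only:

import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.common import recipe

qkv_proj = te.Linear(1024, 3 * 1024, bias=True, params_dtype=torch.bfloat16)
x = torch.randn(2048, 1024, dtype=torch.bfloat16, device="cuda")

fp8_recipe = recipe.DelayedScaling(fp8_dpa=True, fp8_mha=True)
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    # Gated on FP8 being active and recipe.fp8_mha being set (see the forward() diff above);
    # otherwise the output stays in x.dtype.
    y = qkv_proj(x, is_first_module_in_mha=True)

print(isinstance(y, Float8Tensor))  # expected True when the FP8 output path is taken
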
@@ -15,8 +15,13 @@ def clear_tensor_data(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
Must be used carefully.
"""
from .float8_tensor import Float8Tensor
for t in tensors:
if t is not None:
if isinstance(t, Float8Tensor):
t._data.data = torch.Tensor()
del t
else:
t.data = torch.Tensor()
del t
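
The clear_tensor_data change extends the existing storage-release trick to Float8Tensor, whose payload lives in _data rather than data. For reference, a quick demonstration of the underlying mechanism in plain PyTorch (not Transformer Engine code):

import torch

t = torch.empty(1024, 1024, device="cuda")  # ~4 MiB of FP32
before = torch.cuda.memory_allocated()
t.data = torch.Tensor()                     # drop the CUDA storage but keep the Python object alive
after = torch.cuda.memory_allocated()
print(before - after)                       # roughly 4 MiB returned to the caching allocator
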