Unverified commit 6e90fcb7, authored by Kirthi Shankar Sivamani, committed by GitHub

Upgrade pylint to 3.3.1 (#1257)



* Upgrade pylint and first round formatting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* round 2
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* round 3
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Format and fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Paddle lint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Reviews
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* More linting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Run formatter
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Paddle lint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 161b1d98
......@@ -8,7 +8,9 @@ extension-pkg-whitelist=flash_attn_2_cuda,
extension-pkg-allow-list=transformer_engine.transformer_engine_jax
disable=too-many-locals,
too-few-public-methods,
too-many-public-methods,
too-many-positional-arguments,
invalid-name,
too-many-arguments,
abstract-method,
......
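
Pylint 3.3 adds the too-many-positional-arguments check (R0917), which is presumably why it now joins the repo-wide disable list above. Where a narrower exception is preferred, the same check can also be silenced inline; a minimal sketch with hypothetical names (not from this repo):

# Hypothetical function: this many positional parameters would typically trip R0917.
def fused_attn_entry(q, k, v, bias, mask, scale, dropout):  # pylint: disable=too-many-positional-arguments
    """Toy wrapper with many positional parameters."""
    return (q, k, v, bias, mask, scale, dropout)
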
......@@ -6,7 +6,7 @@ set -e
: "${TE_PATH:=/opt/transformerengine}"
-pip install cpplint==1.6.0 pylint==2.13.5
+pip install cpplint==1.6.0 pylint==3.3.1
if [ -z "${PYTHON_ONLY}" ]
then
cd $TE_PATH
......
......@@ -6,7 +6,7 @@ set -e
: "${TE_PATH:=/opt/transformerengine}"
-pip install cpplint==1.6.0 pylint==2.13.5
+pip install cpplint==1.6.0 pylint==3.3.1
if [ -z "${PYTHON_ONLY}" ]
then
cd $TE_PATH
......
......@@ -6,7 +6,7 @@ set -e
: "${TE_PATH:=/opt/transformerengine}"
-pip install cpplint==1.6.0 pylint==2.13.5
+pip install cpplint==1.6.0 pylint==3.3.1
if [ -z "${PYTHON_ONLY}" ]
then
cd $TE_PATH
......
......@@ -583,6 +583,7 @@ def fused_attn_fwd_qkvpacked(
fused_attention_backend != FusedAttnBackend["No_Backend"]
), "Fused attention does not support this input combination."
rng_elts_per_thread = None
# BF16/FP16 fused attention API from fmha_v1 apex
if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
rng_elts_per_thread = (
......@@ -773,6 +774,7 @@ def fused_attn_fwd_kvpacked(
fused_attention_backend != FusedAttnBackend["No_Backend"]
), "Fused attention does not support this input combination."
rng_elts_per_thread = None
# BF16/FP16 fused attention API from fmha_v1 apex
if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
rng_elts_per_thread = (
......@@ -982,6 +984,7 @@ def fused_attn_fwd(
fused_attention_backend != FusedAttnBackend["No_Backend"]
), "Fused attention does not support this input combination."
rng_elts_per_thread = None
# BF16/FP16 fused attention API from fmha_v1 apex
if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
rng_elts_per_thread = (
......
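
The one-line `rng_elts_per_thread = None` additions above (and the many similar `<name> = None` lines later in this diff, e.g. `layernorm_output = None`, `seq_dim = None`) pre-initialize variables that are otherwise assigned only inside some branches, so the newer pylint's used-before-assignment analysis sees a definite binding before the later read. A minimal sketch of the pattern, with illustrative names and placeholder values only:

def rng_elts(backend: str) -> int:
    """Hypothetical sketch; the constants are placeholders, not the real values."""
    rng_elts_per_thread = None  # definite assignment keeps used-before-assignment quiet
    if backend == "F16_max512_seqlen":
        rng_elts_per_thread = 512
    if backend == "F16_arbitrary_seqlen":
        rng_elts_per_thread = 16
    assert rng_elts_per_thread is not None, "Unsupported backend."
    return rng_elts_per_thread
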
......@@ -100,6 +100,7 @@ class FP8MetaBufferBase(ABC):
self._dp_amax_reduce_interval = int(os.getenv("NVTE_DP_AMAX_REDUCE_INTERVAL", "1"))
tp_amax_reduce = False
reduce_group = -1 # Set value that will raise error if not set. `None` is a valid group.
if self._dp_amax_reduce_idx == 0:
reduce_group = fp8_meta["fp8_group"]
else:
......
......@@ -1008,6 +1008,7 @@ class MultiHeadAttention(paddle.nn.Layer):
else:
raise ValueError(f"hidden_states should have 2 or 3 dimensions, got {input_dim}.")
layernorm_output = None
if self.attention_type == "self":
if self.input_layernorm:
layernorm_qkv_outputs = self.layernorm_qkv(
......
......@@ -266,6 +266,8 @@ def _mlp_backward(
accumulate_wgrad_into_param_main_grad,
)
dgelu_t = None
fc1_bgrad_ = None
if activation == "gelu":
# GELU Bwd
dgelu, dgelu_t, fc1_bgrad_ = dgelu_cast_transpose_bgrad_fp8(
......
......@@ -86,6 +86,7 @@ from transformer_engine.pytorch.export import is_in_onnx_export_mode
from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo
from transformer_engine.pytorch.graph import is_graph_capturing
# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
......@@ -123,6 +124,13 @@ _flash_attn_2_4_plus = False
_flash_attn_2_4_1_plus = False
_flash_attn_2_5_7_plus = False
_flash_attn_2_6_0_plus = False
flash_attn_func = None
flash_attn_varlen_func = None
flash_attn_varlen_fwd = None
flash_attn_varlen_bwd = None
flash_attn_cuda_bwd = None
try:
_flash_attn_version = PkgVersion(get_pkg_version("flash-attn"))
except PackageNotFoundError:
......@@ -196,7 +204,6 @@ else:
_flash_attn_3_0_0_beta = PkgVersion("3.0.0b") < _flash_attn_3_version < PkgVersion("3.0.0")
_use_flash_attn_3 = True
_attention_backends = {
"attention_params": None,
"use_flash_attention": None,
......@@ -304,6 +311,11 @@ _alibi_cache = {
__all__ = ["DotProductAttention", "InferenceParams", "MultiheadAttention"]
def maybe_contiguous(tensor: torch.Tensor) -> torch.Tensor:
"""Make tensor contiguous if final stride is not 1."""
return tensor.contiguous() if tensor.stride(-1) != 1 else tensor
def get_attention_backend(
attention_params: AttentionParams = None,
):
......@@ -362,9 +374,7 @@ def get_attention_backend(
run_config = {
"transformer_engine_version": te.__version__,
"compute_capability": "sm"
+ str(
(lambda x, y: x * 10 + y)(device_compute_capability[0], device_compute_capability[1])
),
+ str(10 * device_compute_capability[0] + device_compute_capability[1]),
"flash_attn_version": (
str(_flash_attn_version) if _flash_attn_is_installed else "not installed"
),
......@@ -1135,8 +1145,11 @@ def get_alibi(
assert _alibi_cache["_alibi_slopes"] is not None, "ALiBi slopes can not be None!"
if _alibi_cache["_alibi_slopes"].dim() == 1:
slopes_shape = torch.Size([1, _alibi_cache["_alibi_slopes"].shape[0], 1, 1])
if _alibi_cache["_alibi_slopes"].dim() == 2:
elif _alibi_cache["_alibi_slopes"].dim() == 2:
slopes_shape = torch.Size([*_alibi_cache["_alibi_slopes"].shape[:], 1, 1])
else:
raise ValueError("ALiBi slopes cannot exceed 2 dimensions.")
bias = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view(
1, 1, max_seqlen_q, 1
) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view(
......@@ -1372,6 +1385,7 @@ class PackTensors(torch.autograd.Function):
def forward(
ctx, indices: torch.Tensor, *tensors: Tuple[torch.Tensor, ...]
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# pylint: disable=missing-function-docstring
assert 1 <= len(tensors) <= 3, f"Packing {len(tensors)} tensors not supported."
ctx.save_for_backward(indices)
ctx.dim0 = tensors[0].shape[0]
......@@ -1383,6 +1397,7 @@ class PackTensors(torch.autograd.Function):
@staticmethod
def backward(ctx, *grad_outputs: Tuple[torch.Tensor, ...]):
# pylint: disable=missing-function-docstring
(indices,) = ctx.saved_tensors
if len(grad_outputs) == 1:
return None, unpack_tensor(indices, ctx.dim0, *grad_outputs)
......@@ -1403,11 +1418,13 @@ class UnpackTensor(torch.autograd.Function):
dim0: int,
tensor: torch.Tensor,
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
ctx.save_for_backward(indices)
return unpack_tensor(indices, dim0, tensor)
@staticmethod
def backward(ctx, grad_output):
# pylint: disable=missing-function-docstring
(indices,) = ctx.saved_tensors
return None, None, pack_tensor(indices, grad_output)
......@@ -1661,6 +1678,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
cp_global_ranks,
cp_stream,
):
# pylint: disable=missing-function-docstring
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
......@@ -1690,6 +1708,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
causal = "causal" in attn_mask_type
padding = "padding" in attn_mask_type
seq_dim = None
if qkv_format in ["bshd", "sbhd"]:
seq_dim = qkv_format.index("s")
qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:]
......@@ -1705,6 +1724,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
cu_seqlens_q_per_step = [None for _ in range(cp_size)]
cu_seqlens_kv_per_step = [None for _ in range(cp_size)]
fused_attn_qkv_dtype = None
fused_attn_backend = None
amax_per_step = None
if fp8:
if use_fused_attention:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
......@@ -1796,6 +1818,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
softmax_lse_in_packed_format = not use_fused_attention and (
_flash_attn_2_6_0_plus or _use_flash_attn_3
)
flash_attn_fwd = None
if not use_fused_attention:
fa_forward_kwargs = {"softmax_scale": softmax_scale}
if _use_flash_attn_3:
......@@ -1834,6 +1857,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0)
send_recv_reqs = [[], []]
softmax_lse_ = None
out = None
for i in range(cp_size + 1):
if i < cp_size:
with torch.cuda.stream(flash_attn_streams[i % 2]):
......@@ -2326,8 +2351,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
softmax_lse = softmax_lse.to(torch.float)
for i in range(cp_size):
out_ = None
if qkv_format == "bshd":
out_per_step[i] = out_per_step[i].view(out.shape[0], -1, *out.shape[-2:])
out_per_step[i] = out_per_step[i].view(
out.shape[0], -1, *out.shape[-2:]
) # pylint: disable=used-before-assignment
out_ = out[:, 1, ...]
elif qkv_format == "sbhd":
out_per_step[i] = out_per_step[i].view(-1, *out.shape[-3:])
......@@ -2405,6 +2433,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
fp8_meta["scaling_fwd"].amax_history[0][META_S] = amax_cp_fwd[0]
fp8_meta["scaling_fwd"].amax_history[0][META_O_CP] = amax_cp_fwd[1]
out_fp8 = None
out_f16 = out.to(q_fp8.dtype if fp8 and fp8_meta["recipe"].fp8_mha else q_f16.dtype)
if fp8 and (fp8_meta["recipe"].fp8_mha or int(os.getenv("NVTE_FP8_DPA_BWD", "1"))):
out_fp8 = cast_to_fp8(out_f16, fp8_meta["scaling_fwd"], META_O, fp8_dtype_forward)
......@@ -2486,6 +2515,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
@staticmethod
def backward(ctx, dout):
# pylint: disable=missing-function-docstring
cp_size_a2a = ctx.cp_size_a2a
rank_a2a = ctx.rank_a2a
......@@ -2521,6 +2551,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
)
else:
attn_dbias = None
attn_dbias_ = None
softmax_lse_in_packed_format = not ctx.use_fused_attention and (
_flash_attn_2_6_0_plus or _use_flash_attn_3
......@@ -2545,6 +2576,12 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
softmax_lse.unsqueeze_(-1)
dout_dtype = dout.dtype
fused_attn_backend = None
fused_attn_qkv_dtype = None
fused_attn_dqkv_dtype = None
amax_per_step = None
seq_dim = None
dout_fp8_dtype = None
if ctx.fp8:
if ctx.use_fused_attention:
fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
......@@ -2615,13 +2652,19 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
)
if not ctx.fp8 and ctx.fp8_meta is not None and ctx.fp8_meta["recipe"].fp8_mha:
dout = cast_from_fp8(
dout, None, None, dout_fp8_dtype, TE_DType[dout_dtype], scale_inv=dout_scale_inv
dout,
None,
None,
dout_fp8_dtype,
TE_DType[dout_dtype],
scale_inv=dout_scale_inv, # pylint: disable=used-before-assignment
)
out = out.view(*q.shape)
dout = dout.view(*q.shape)
send_recv_reqs = []
flash_attn_bwd = None
if not ctx.use_fused_attention:
fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale}
if _use_flash_attn_3:
......@@ -2673,6 +2716,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
)
kv = p2p_comm_buffers[i % 2][0]
dk_, dv_ = None, None
if ctx.fp8 and ctx.use_fused_attention:
fp8_meta_kwargs["amax_dp"] = amax_per_step[0][i]
fp8_meta_kwargs["amax_dqkv"] = amax_per_step[0][i]
......@@ -3106,7 +3150,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
else:
dkv = p2p_comm_buffers[(i + 1) % 2][1]
if ctx.use_fused_attention:
dkv_ = torch.cat((dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0)
dkv_ = torch.cat(
(dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0
) # pylint: disable=used-before-assignment
if ctx.qkv_format in ["bshd", "sbhd"]:
# [b, 2, sk//2, 2, np, hn] -> [2, b, 2, sk//2, np, hn] or
# [2, sk//2, b, 2, np, hn] -> [2, 2, sk//2, b, np, hn]
......@@ -3334,6 +3380,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
cp_group,
cp_stream,
):
# pylint: disable=missing-function-docstring
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
......@@ -3351,6 +3398,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
use_fused_attention or _flash_attn_2_3_plus
), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!"
flash_attn_fwd = None
if not use_fused_attention:
fa_forward_kwargs = {"softmax_scale": softmax_scale}
if _use_flash_attn_3:
......@@ -3521,6 +3569,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
@staticmethod
def backward(ctx, dout):
# pylint: disable=missing-function-docstring
cp_size = get_distributed_world_size(ctx.cp_group)
rank = get_distributed_rank(ctx.cp_group)
......@@ -3565,6 +3614,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1]
flash_attn_bwd = None
if not ctx.use_fused_attention:
fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale}
if _use_flash_attn_3:
......@@ -3751,6 +3801,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
cp_group,
cp_stream,
):
# pylint: disable=missing-function-docstring
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
......@@ -3768,6 +3819,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
or _flash_attn_2_3_plus
), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!"
flash_attn_fwd = None
if not use_fused_attention:
fa_forward_kwargs = {"softmax_scale": softmax_scale}
if _use_flash_attn_3:
......@@ -3797,6 +3849,8 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0
), "Sequence length per GPU needs to be divisible by 2!"
fused_attn_backend = None
fused_attn_qkv_dtype = None
if fp8:
if use_fused_attention:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
......@@ -3990,6 +4044,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
@staticmethod
def backward(ctx, dout):
# pylint: disable=missing-function-docstring
cp_size = get_distributed_world_size(ctx.cp_group)
q, k, v, out = ctx.saved_tensors[:4]
......@@ -4003,6 +4058,9 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
causal = "causal" in ctx.attn_mask_type
seq_dim = ctx.qkv_format.index("s")
fused_attn_backend = None
fused_attn_dqkv_dtype = None
fused_attn_qkv_dtype = None
if ctx.fp8:
if ctx.use_fused_attention:
fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
......@@ -4054,6 +4112,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
[out, dout], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, True
)
flash_attn_bwd = None
if not ctx.use_fused_attention:
fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale}
if _use_flash_attn_3:
......@@ -4400,6 +4459,7 @@ class FusedRoPEFunc(torch.autograd.Function):
cp_size: int = 1,
cp_rank: int = 0,
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
if freqs.dtype != torch.float32:
freqs = freqs.float()
if tensor_format == "sbhd":
......@@ -4419,6 +4479,7 @@ class FusedRoPEFunc(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
# pylint: disable=missing-function-docstring
freqs, cu_seqlens = ctx.saved_tensors
if ctx.tensor_format == "sbhd":
grad_input = tex.fused_rope_backward(grad_output, freqs, False)
......@@ -4525,6 +4586,7 @@ class _SplitAlongDim(torch.autograd.Function):
split_dim: int,
split_size_or_sections: Union[int, List[int], Tuple[int]],
) -> Tuple[torch.Tensor, ...]:
# pylint: disable=missing-function-docstring
ctx.split_dim = split_dim
ctx.split_size_or_sections = split_size_or_sections
if isinstance(mixed_x_layer, Float8Tensor):
......@@ -4543,6 +4605,7 @@ class _SplitAlongDim(torch.autograd.Function):
@staticmethod
def backward(ctx, *grad_outputs):
# pylint: disable=missing-function-docstring
assert len(grad_outputs) > 0, "No gradients received for backprop!"
if isinstance(ctx.split_size_or_sections, (list, tuple)):
......@@ -4887,6 +4950,7 @@ class _PrepareQKVForFA(torch.autograd.Function):
key_layer: torch.Tensor,
value_layer: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# pylint: disable=missing-function-docstring
# All inputs received are non-contiguous tensors.
# The `query_layer` tensor is used to access the
# full memory region of the QKV tensor.
......@@ -4904,6 +4968,7 @@ class _PrepareQKVForFA(torch.autograd.Function):
dk: torch.Tensor,
dv: torch.Tensor,
) -> Tuple[Union[torch.Tensor, None], ...]:
# pylint: disable=missing-function-docstring
dqkv = tex.fa_prepare_bwd(dq, dk, dv)
dq, dk, dv = split_tensor_along_dim(dqkv, -1, 3)
return dq, dk, dv
......@@ -5055,7 +5120,7 @@ def get_qkv_layout(
q, k, v = [x.contiguous() for x in [q, k, v]]
qkv_layout = run_iteratively(q, k, v)
if qkv_layout == "not_supported":
raise Exception("The provided qkv memory layout is not supported!")
raise RuntimeError("The provided qkv memory layout is not supported!")
return qkv_layout, q, k, v
......@@ -5351,9 +5416,9 @@ class FlashAttention(torch.nn.Module):
fa_3_optional_forward_kwargs = {}
fa_3_optional_forward_kwargs["window_size"] = window_size
fa_3_optional_forward_kwargs["deterministic"] = self.deterministic
activation_dtype = query_layer.dtype
if fp8:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
activation_dtype = query_layer.dtype
torch_dtype = get_fp8_torch_dtype(fp8_meta["recipe"], fprop_tensor=True)
def convert_to_torch_float8(tensor, dtype):
......@@ -5511,6 +5576,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
fp8_meta,
deterministic,
):
# pylint: disable=missing-function-docstring
is_input_fp8 = False
is_output_fp8 = fp8_meta["recipe"].fp8_mha
if fp8:
......@@ -5664,6 +5730,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
@staticmethod
def backward(ctx, d_out):
# pylint: disable=missing-function-docstring
if ctx.is_output_fp8:
assert isinstance(
d_out, Float8Tensor
......@@ -5683,12 +5750,12 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
fwd_scale_invs,
*aux_ctx_tensors,
) = ctx.saved_tensors
rest = [None]
if not aux_ctx_tensors[0].is_contiguous():
aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous()
if ctx.use_FAv2_bwd:
softmax_lse, rng_state = aux_ctx_tensors
dqkv = torch.empty_like(qkv)
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
d_out, q, k, v, out = [
maybe_contiguous(x) for x in (d_out, qkv[:, 0], qkv[:, 1], qkv[:, 2], out)
]
......@@ -5899,6 +5966,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
fp8_meta,
deterministic,
):
# pylint: disable=missing-function-docstring
is_input_fp8 = False
is_output_fp8 = fp8_meta["recipe"].fp8_mha
if fp8:
......@@ -6080,6 +6148,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
@staticmethod
def backward(ctx, d_out):
# pylint: disable=missing-function-docstring
if ctx.is_output_fp8:
assert isinstance(
d_out, Float8Tensor
......@@ -6103,13 +6172,13 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
fwd_scale_invs,
*aux_ctx_tensors,
) = ctx.saved_tensors
rest = [None]
if not aux_ctx_tensors[0].is_contiguous():
aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous()
if ctx.use_FAv2_bwd:
softmax_lse, rng_state = aux_ctx_tensors
dq = torch.empty_like(q)
dkv = torch.empty_like(kv)
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
d_out, q, k, v, out = [maybe_contiguous(x) for x in (d_out, q, kv[:, 0], kv[:, 1], out)]
flash_attn_cuda_bwd(
d_out,
......@@ -6351,6 +6420,7 @@ class FusedAttnFunc(torch.autograd.Function):
fp8_meta,
deterministic,
):
# pylint: disable=missing-function-docstring
is_input_fp8 = False
is_output_fp8 = fp8_meta["recipe"].fp8_mha
if fp8:
......@@ -6616,6 +6686,7 @@ class FusedAttnFunc(torch.autograd.Function):
@staticmethod
def backward(ctx, d_out):
# pylint: disable=missing-function-docstring
if ctx.is_output_fp8:
assert isinstance(
d_out, Float8Tensor
......@@ -6643,12 +6714,12 @@ class FusedAttnFunc(torch.autograd.Function):
) = ctx.saved_tensors
if not aux_ctx_tensors[0].is_contiguous():
aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous()
rest = [None]
if ctx.use_FAv2_bwd:
softmax_lse, rng_state = aux_ctx_tensors
dq = torch.empty_like(q)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
d_out, q, k, v, out = [maybe_contiguous(x) for x in (d_out, q, k, v, out)]
flash_attn_cuda_bwd(
d_out,
......@@ -7859,7 +7930,7 @@ class DotProductAttention(TransformerEngineBaseModule):
if qkv_format == "sbhd":
max_seqlen_q, max_seqlen_kv = (query_layer.shape[0], key_layer.shape[0])
batch_size = query_layer.shape[1]
if qkv_format == "bshd":
else:
max_seqlen_q, max_seqlen_kv = (query_layer.shape[1], key_layer.shape[1])
batch_size = query_layer.shape[0]
max_seqlen_q *= cp_size
......@@ -8168,7 +8239,7 @@ class DotProductAttention(TransformerEngineBaseModule):
alibi_slopes=alibi_slopes,
)
raise Exception("No dot product attention support for the provided inputs!")
raise ValueError("No dot product attention support for the provided inputs!")
class MultiheadAttention(torch.nn.Module):
......@@ -8522,6 +8593,7 @@ class MultiheadAttention(torch.nn.Module):
def _allocate_memory(
self, inference_max_sequence_len: int, batch_size: int, dtype: torch.dtype
) -> torch.Tensor:
"""Allocates memory for KV cache."""
return torch.empty(
inference_max_sequence_len,
batch_size,
......@@ -8692,10 +8764,8 @@ class MultiheadAttention(torch.nn.Module):
window_size = check_set_window_size(attn_mask_type, window_size)
if "padding" in attn_mask_type and attention_mask is not None:
-for i, _ in enumerate(attention_mask):
-    assert (
-        attention_mask[i].dtype == torch.bool
-    ), "Attention mask must be in boolean type!"
+for mask in attention_mask:
+    assert mask.dtype == torch.bool, "Attention mask must be in boolean type!"
assert (
core_attention_bias_type in AttnBiasTypes
......@@ -8737,6 +8807,7 @@ class MultiheadAttention(torch.nn.Module):
and FP8GlobalStateManager.get_fp8_recipe().fp8_mha
)
layernorm_output = None
if self.attention_type == "self":
# Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn]
if self.input_layernorm:
......@@ -8904,6 +8975,8 @@ class MultiheadAttention(torch.nn.Module):
sequence_length = key_layer.size(0)
elif self.qkv_format == "bshd":
sequence_length = key_layer.size(1)
else:
raise ValueError(f"QKV format {self.qkv_format} not supported for KV caching.")
sequence_start = inference_params.sequence_len_offset
sequence_end = sequence_start + sequence_length
......
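
Many hunks above add `# pylint: disable=missing-function-docstring` to the `forward`/`backward` overrides of `torch.autograd.Function` subclasses, where a docstring would only restate the class-level one. A minimal sketch of the convention (the class itself is a toy example, not from this file):

import torch

class _Scale(torch.autograd.Function):
    """Toy autograd function showing the per-override pylint disables."""

    @staticmethod
    def forward(ctx, inp: torch.Tensor, factor: float) -> torch.Tensor:
        # pylint: disable=missing-function-docstring
        ctx.factor = factor
        return inp * factor

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # pylint: disable=missing-function-docstring
        return grad_output * ctx.factor, None
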
......@@ -78,10 +78,10 @@ def canonicalize_fp8_scales(
scale_inv_offset = 0
# Pack tensors and offsets into dicts
-tensors = dict(scale=scale, amax=amax, scale_inv=scale_inv)
-offsets = dict(
-    scale_offset=scale_offset,
-    amax_offset=amax_offset,
-    scale_inv_offset=scale_inv_offset,
-)
+tensors = {"scale": scale, "amax": amax, "scale_inv": scale_inv}
+offsets = {
+    "scale_offset": scale_offset,
+    "amax_offset": amax_offset,
+    "scale_inv_offset": scale_inv_offset,
+}
return tensors, offsets
......@@ -240,13 +240,11 @@ def fused_attn_fwd_qkvpacked(
rng_elts_per_thread = (
max_seqlen * max_seqlen + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1
) // BACKEND_F16m512_FP8_THREADS_PER_CTA
# BF16/FP16 fused attention API from fmha_v2
if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS
# FP8 fused attention API from fmha_v2
if fused_attention_backend == FusedAttnBackend["FP8"]:
elif fused_attention_backend == FusedAttnBackend["FP8"]:
rng_elts_per_thread = (
max_seqlen * max_seqlen + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1
) // BACKEND_F16m512_FP8_THREADS_PER_CTA
......@@ -259,6 +257,8 @@ def fused_attn_fwd_qkvpacked(
assert q_scale_o is not None, "q_scale_o is required as an input for FP8 fused attention."
assert amax_s is not None, "amax_s is required as an input for FP8 fused attention."
assert amax_o is not None, "amax_o is required as an input for FP8 fused attention."
else:
raise ValueError(f"Unsupported backend {fused_attention_backend}")
# execute kernel
output_tensors = tex.fused_attn_fwd_qkvpacked(
......@@ -633,13 +633,11 @@ def fused_attn_fwd_kvpacked(
rng_elts_per_thread = (
max_seqlen_q * max_seqlen_kv + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1
) // BACKEND_F16m512_FP8_THREADS_PER_CTA
# BF16/FP16 fused attention API from fmha_v2
if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS
# FP8 fused attention API from fmha_v2
if fused_attention_backend == FusedAttnBackend["FP8"]:
elif fused_attention_backend == FusedAttnBackend["FP8"]:
rng_elts_per_thread = (
max_seqlen_q * max_seqlen_q + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1
) // BACKEND_F16m512_FP8_THREADS_PER_CTA
......@@ -652,6 +650,8 @@ def fused_attn_fwd_kvpacked(
assert q_scale_o is not None, "q_scale_o is required as an input for FP8 fused attention."
assert amax_s is not None, "amax_s is required as an input for FP8 fused attention."
assert amax_o is not None, "amax_o is required as an input for FP8 fused attention."
else:
raise ValueError(f"Unsupported backend {fused_attention_backend}")
# execute kernel
output_tensors = tex.fused_attn_fwd_kvpacked(
......@@ -1058,13 +1058,11 @@ def fused_attn_fwd(
rng_elts_per_thread = (
max_seqlen_q * max_seqlen_kv + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1
) // BACKEND_F16m512_FP8_THREADS_PER_CTA
# BF16/FP16 fused attention API from fmha_v2
if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS
# FP8 fused attention API from fmha_v2
if fused_attention_backend == FusedAttnBackend["FP8"]:
elif fused_attention_backend == FusedAttnBackend["FP8"]:
rng_elts_per_thread = (
max_seqlen_q * max_seqlen_q + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1
) // BACKEND_F16m512_FP8_THREADS_PER_CTA
......@@ -1077,6 +1075,8 @@ def fused_attn_fwd(
assert q_scale_o is not None, "q_scale_o is required as an input for FP8 fused attention."
assert amax_s is not None, "amax_s is required as an input for FP8 fused attention."
assert amax_o is not None, "amax_o is required as an input for FP8 fused attention."
else:
raise ValueError(f"Unsupported backend {fused_attention_backend}")
# execute kernel
output_tensors = tex.fused_attn_fwd(
......
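
Besides raising a specific ValueError for unsupported backends, rewriting the backend checks above as an if/elif chain with a trailing `else: raise` makes the dispatch exhaustive, so pylint can also prove `rng_elts_per_thread` is bound before the kernel launch. The control-flow shape, reduced to a hypothetical sketch (constants are illustrative only):

def rng_elts(backend: str, max_seqlen: int, threads_per_cta: int = 128) -> int:
    """Hypothetical reduction of the dispatch above."""
    if backend == "F16_max512_seqlen":
        rng_elts_per_thread = (max_seqlen * max_seqlen + threads_per_cta - 1) // threads_per_cta
    elif backend == "F16_arbitrary_seqlen":
        rng_elts_per_thread = 16
    elif backend == "FP8":
        rng_elts_per_thread = (max_seqlen * max_seqlen + threads_per_cta - 1) // threads_per_cta
    else:
        raise ValueError(f"Unsupported backend {backend}")
    return rng_elts_per_thread
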
......@@ -161,6 +161,7 @@ class GroupCommitFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, tensor, cpu_offload_handler):
# pylint: disable=missing-function-docstring
cpu_offload_handler.on_group_commit_forward()
ctx.cpu_offload_handler = cpu_offload_handler
# return the identical tensor
......@@ -168,6 +169,7 @@ class GroupCommitFunction(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output):
# pylint: disable=missing-function-docstring
cpu_offload_handler = ctx.cpu_offload_handler
cpu_offload_handler.on_group_commit_backward()
return grad_output, None
......
......@@ -752,11 +752,11 @@ class CudaRNGStatesTracker:
"""
# Check seed is not already used.
if seed in self.seeds_:
raise Exception(f"seed {seed} already exists")
raise RuntimeError(f"seed {seed} already exists")
self.seeds_.add(seed)
# Check that state is not already defined.
if name in self.states_:
raise Exception(f"cuda rng state {name} already exists")
raise RuntimeError(f"cuda rng state {name} already exists")
if graph_safe_rng_available():
new_state = _get_cuda_rng_state(clone=True)
......@@ -786,7 +786,7 @@ class CudaRNGStatesTracker:
"""
# Check if we have added the state
if name not in self.states_:
raise Exception(f"cuda rng state {name} is not added")
raise KeyError(f"cuda rng state {name} is not added")
# Get the reference to current rng state.
orig_cuda_rng_state = _get_cuda_rng_state()
# Set rng state to the desired one
......
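
The CudaRNGStatesTracker changes swap bare `raise Exception(...)` calls for specific built-ins (RuntimeError for invalid state, KeyError for a missing entry), which presumably satisfies pylint's broad-exception-raised warning and gives callers something meaningful to catch. A minimal, hypothetical sketch of the same idea:

class _ToyStateTracker:
    """Illustrative only; not the real tracker."""

    def __init__(self):
        self.states_ = {}

    def add(self, name: str, state: int) -> None:
        if name in self.states_:
            raise RuntimeError(f"cuda rng state {name} already exists")
        self.states_[name] = state

    def get(self, name: str) -> int:
        if name not in self.states_:
            raise KeyError(f"cuda rng state {name} is not added")
        return self.states_[name]
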
......@@ -358,6 +358,7 @@ def _make_graphed_callables(
@staticmethod
def forward(ctx, skip_fp8_weight_update, *inputs):
# pylint: disable=missing-function-docstring
# Set flag for whether to update FP8 weight updates
ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
......@@ -377,6 +378,7 @@ def _make_graphed_callables(
@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, *grads):
# pylint: disable=missing-function-docstring
# Replay backward graph
assert len(grads) == len(static_grad_outputs)
......
......@@ -8,6 +8,8 @@ from typing import Callable, Optional, Tuple
import torch
# pylint: disable=unnecessary-lambda-assignment
jit_fuser = torch.jit.script
if torch.__version__ >= "2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
jit_fuser = torch.compile
......
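
Two different responses to pylint's unnecessary-lambda-assignment check appear in this PR: the jit module disables the check once at module scope (the `# pylint: disable=unnecessary-lambda-assignment` comment above), while the attention module instead replaces its local `maybe_contiguous = lambda ...` helpers with the named `maybe_contiguous` function added earlier in this diff. Both options, as a hedged sketch:

# Option 1: replace the lambda assignment with a def (as done for maybe_contiguous above).
def maybe_contiguous(tensor):
    """Make tensor contiguous if final stride is not 1."""
    return tensor.contiguous() if tensor.stride(-1) != 1 else tensor

# Option 2: keep assigning callables and disable the check for the rest of the module.
# pylint: disable=unnecessary-lambda-assignment
square = lambda x: x * x
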
......@@ -124,6 +124,7 @@ class _NoopCatFunc(torch.autograd.Function):
dim: int,
*tensors: Tuple[torch.Tensor, ...],
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
# Check first tensor
if not tensors:
......@@ -192,6 +193,7 @@ class _NoopCatFunc(torch.autograd.Function):
ctx,
grad_output: torch.Tensor,
) -> Tuple[Optional[torch.Tensor], ...]:
# pylint: disable=missing-function-docstring
grad_inputs = []
for split_start, split_end in ctx.split_ranges:
slices = [slice(None)] * grad_output.dim()
......
......@@ -694,7 +694,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
else:
# If fp8 isn't enabled, turn off and return.
self.fp8_initialized = False
return
@contextmanager
def prepare_forward(
......@@ -744,7 +743,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
if self.fp8 and in_fp8_activation_recompute_phase():
FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
return
def set_nccl_overlap_warning_if_tp(self) -> None:
"""When using TP, the NCCL communication needs to be scheduled
......
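
The dropped trailing `return` statements in TransformerEngineBaseModule address what is presumably pylint's useless-return check: a bare `return` as a function's last statement adds nothing. A trivial, hypothetical sketch:

# Previously-flagged form (illustrative):
#     state["fp8_initialized"] = False
#     return
def reset_fp8_state(state: dict) -> None:
    """Hypothetical example; the bare trailing return has been dropped."""
    state["fp8_initialized"] = False
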
......@@ -28,6 +28,7 @@ class _Fp8Padding(torch.autograd.Function):
padded_m_splits: List[int],
is_grad_enabled: bool,
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
# Make sure input dimensions are compatible
in_features = inp.shape[-1]
......@@ -46,6 +47,7 @@ class _Fp8Padding(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output: torch.Tensor):
# pylint: disable=missing-function-docstring
grad_input = None
if ctx.requires_dgrad:
......
......@@ -28,6 +28,7 @@ class _Fp8Unpadding(torch.autograd.Function):
padded_m_splits: List[int],
is_grad_enabled: bool,
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
inputmats = torch.split(inp.view(-1, inp.shape[-1]), padded_m_splits)
out_ret = torch.cat(
[grad_output_mat[: m_splits[i]] for i, grad_output_mat in enumerate(inputmats)], dim=0
......@@ -42,6 +43,7 @@ class _Fp8Unpadding(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output: torch.Tensor):
# pylint: disable=missing-function-docstring
grad_input = None
if ctx.requires_dgrad:
grad_output = grad_output.contiguous()
......
......@@ -70,6 +70,7 @@ class _GroupedLinear(torch.autograd.Function):
weights_fp8: List[Union[Float8Tensor, None]],
*weights_and_biases: Union[Float8Tensor, torch.Tensor, None],
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
num_gemms = len(m_splits)
weights = weights_and_biases[:num_gemms]
biases = weights_and_biases[num_gemms:]
......@@ -268,6 +269,7 @@ class _GroupedLinear(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
# pylint: disable=missing-function-docstring
with torch.cuda.nvtx.range("_GroupedLinear_backward"):
(
inputmat_scale_inv,
......@@ -641,7 +643,7 @@ class GroupedLinear(TransformerEngineBaseModule):
if self.primary_weights_in_fp8:
self.init_fp8_metadata(num_gemms=self.num_gemms)
-self.reset_parameters(defer_init=(device == "meta"))
+self.reset_parameters(defer_init=device == "meta")
# For RPL, bias has to be added after TP collectives
# So it cannot be fused with the GEMM
......