Unverified Commit 6e90fcb7 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Upgrade pylint to 3.3.1 (#1257)



* Upgrade pylint and first round formatting
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* round 2
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* round 3
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Format and fixes
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Paddle lint
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Reviews
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* FIxes
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* More linting
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Run formatter
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Paddle lint
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixes
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 161b1d98
...@@ -8,7 +8,9 @@ extension-pkg-whitelist=flash_attn_2_cuda, ...@@ -8,7 +8,9 @@ extension-pkg-whitelist=flash_attn_2_cuda,
extension-pkg-allow-list=transformer_engine.transformer_engine_jax extension-pkg-allow-list=transformer_engine.transformer_engine_jax
disable=too-many-locals, disable=too-many-locals,
too-few-public-methods,
too-many-public-methods, too-many-public-methods,
too-many-positional-arguments,
invalid-name, invalid-name,
too-many-arguments, too-many-arguments,
abstract-method, abstract-method,
......
...@@ -6,7 +6,7 @@ set -e ...@@ -6,7 +6,7 @@ set -e
: "${TE_PATH:=/opt/transformerengine}" : "${TE_PATH:=/opt/transformerengine}"
pip install cpplint==1.6.0 pylint==2.13.5 pip install cpplint==1.6.0 pylint==3.3.1
if [ -z "${PYTHON_ONLY}" ] if [ -z "${PYTHON_ONLY}" ]
then then
cd $TE_PATH cd $TE_PATH
......
...@@ -6,7 +6,7 @@ set -e ...@@ -6,7 +6,7 @@ set -e
: "${TE_PATH:=/opt/transformerengine}" : "${TE_PATH:=/opt/transformerengine}"
pip install cpplint==1.6.0 pylint==2.13.5 pip install cpplint==1.6.0 pylint==3.3.1
if [ -z "${PYTHON_ONLY}" ] if [ -z "${PYTHON_ONLY}" ]
then then
cd $TE_PATH cd $TE_PATH
......
...@@ -6,7 +6,7 @@ set -e ...@@ -6,7 +6,7 @@ set -e
: "${TE_PATH:=/opt/transformerengine}" : "${TE_PATH:=/opt/transformerengine}"
pip install cpplint==1.6.0 pylint==2.13.5 pip install cpplint==1.6.0 pylint==3.3.1
if [ -z "${PYTHON_ONLY}" ] if [ -z "${PYTHON_ONLY}" ]
then then
cd $TE_PATH cd $TE_PATH
......
...@@ -583,6 +583,7 @@ def fused_attn_fwd_qkvpacked( ...@@ -583,6 +583,7 @@ def fused_attn_fwd_qkvpacked(
fused_attention_backend != FusedAttnBackend["No_Backend"] fused_attention_backend != FusedAttnBackend["No_Backend"]
), "Fused attention does not support this input combination." ), "Fused attention does not support this input combination."
rng_elts_per_thread = None
# BF16/FP16 fused attention API from fmha_v1 apex # BF16/FP16 fused attention API from fmha_v1 apex
if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
rng_elts_per_thread = ( rng_elts_per_thread = (
...@@ -773,6 +774,7 @@ def fused_attn_fwd_kvpacked( ...@@ -773,6 +774,7 @@ def fused_attn_fwd_kvpacked(
fused_attention_backend != FusedAttnBackend["No_Backend"] fused_attention_backend != FusedAttnBackend["No_Backend"]
), "Fused attention does not support this input combination." ), "Fused attention does not support this input combination."
rng_elts_per_thread = None
# BF16/FP16 fused attention API from fmha_v1 apex # BF16/FP16 fused attention API from fmha_v1 apex
if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
rng_elts_per_thread = ( rng_elts_per_thread = (
...@@ -982,6 +984,7 @@ def fused_attn_fwd( ...@@ -982,6 +984,7 @@ def fused_attn_fwd(
fused_attention_backend != FusedAttnBackend["No_Backend"] fused_attention_backend != FusedAttnBackend["No_Backend"]
), "Fused attention does not support this input combination." ), "Fused attention does not support this input combination."
rng_elts_per_thread = None
# BF16/FP16 fused attention API from fmha_v1 apex # BF16/FP16 fused attention API from fmha_v1 apex
if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
rng_elts_per_thread = ( rng_elts_per_thread = (
......
...@@ -100,6 +100,7 @@ class FP8MetaBufferBase(ABC): ...@@ -100,6 +100,7 @@ class FP8MetaBufferBase(ABC):
self._dp_amax_reduce_interval = int(os.getenv("NVTE_DP_AMAX_REDUCE_INTERVAL", "1")) self._dp_amax_reduce_interval = int(os.getenv("NVTE_DP_AMAX_REDUCE_INTERVAL", "1"))
tp_amax_reduce = False tp_amax_reduce = False
reduce_group = -1 # Set value that will raise error if not set. `None` is a valid group.
if self._dp_amax_reduce_idx == 0: if self._dp_amax_reduce_idx == 0:
reduce_group = fp8_meta["fp8_group"] reduce_group = fp8_meta["fp8_group"]
else: else:
......
...@@ -1008,6 +1008,7 @@ class MultiHeadAttention(paddle.nn.Layer): ...@@ -1008,6 +1008,7 @@ class MultiHeadAttention(paddle.nn.Layer):
else: else:
raise ValueError(f"hidden_states should have 2 or 3 dimensions, got {input_dim}.") raise ValueError(f"hidden_states should have 2 or 3 dimensions, got {input_dim}.")
layernorm_output = None
if self.attention_type == "self": if self.attention_type == "self":
if self.input_layernorm: if self.input_layernorm:
layernorm_qkv_outputs = self.layernorm_qkv( layernorm_qkv_outputs = self.layernorm_qkv(
......
...@@ -266,6 +266,8 @@ def _mlp_backward( ...@@ -266,6 +266,8 @@ def _mlp_backward(
accumulate_wgrad_into_param_main_grad, accumulate_wgrad_into_param_main_grad,
) )
dgelu_t = None
fc1_bgrad_ = None
if activation == "gelu": if activation == "gelu":
# GELU Bwd # GELU Bwd
dgelu, dgelu_t, fc1_bgrad_ = dgelu_cast_transpose_bgrad_fp8( dgelu, dgelu_t, fc1_bgrad_ = dgelu_cast_transpose_bgrad_fp8(
......
...@@ -86,6 +86,7 @@ from transformer_engine.pytorch.export import is_in_onnx_export_mode ...@@ -86,6 +86,7 @@ from transformer_engine.pytorch.export import is_in_onnx_export_mode
from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo
from transformer_engine.pytorch.graph import is_graph_capturing from transformer_engine.pytorch.graph import is_graph_capturing
# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0 # NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0")) _NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0 # NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
...@@ -123,6 +124,13 @@ _flash_attn_2_4_plus = False ...@@ -123,6 +124,13 @@ _flash_attn_2_4_plus = False
_flash_attn_2_4_1_plus = False _flash_attn_2_4_1_plus = False
_flash_attn_2_5_7_plus = False _flash_attn_2_5_7_plus = False
_flash_attn_2_6_0_plus = False _flash_attn_2_6_0_plus = False
flash_attn_func = None
flash_attn_varlen_func = None
flash_attn_varlen_fwd = None
flash_attn_varlen_bwd = None
flash_attn_cuda_bwd = None
try: try:
_flash_attn_version = PkgVersion(get_pkg_version("flash-attn")) _flash_attn_version = PkgVersion(get_pkg_version("flash-attn"))
except PackageNotFoundError: except PackageNotFoundError:
...@@ -196,7 +204,6 @@ else: ...@@ -196,7 +204,6 @@ else:
_flash_attn_3_0_0_beta = PkgVersion("3.0.0b") < _flash_attn_3_version < PkgVersion("3.0.0") _flash_attn_3_0_0_beta = PkgVersion("3.0.0b") < _flash_attn_3_version < PkgVersion("3.0.0")
_use_flash_attn_3 = True _use_flash_attn_3 = True
_attention_backends = { _attention_backends = {
"attention_params": None, "attention_params": None,
"use_flash_attention": None, "use_flash_attention": None,
...@@ -304,6 +311,11 @@ _alibi_cache = { ...@@ -304,6 +311,11 @@ _alibi_cache = {
__all__ = ["DotProductAttention", "InferenceParams", "MultiheadAttention"] __all__ = ["DotProductAttention", "InferenceParams", "MultiheadAttention"]
def maybe_contiguous(tensor: torch.Tensor) -> torch.Tensor:
"""Make tensor contiguous if final stride is not 1."""
return tensor.contiguous() if tensor.stride(-1) != 1 else tensor
def get_attention_backend( def get_attention_backend(
attention_params: AttentionParams = None, attention_params: AttentionParams = None,
): ):
...@@ -362,9 +374,7 @@ def get_attention_backend( ...@@ -362,9 +374,7 @@ def get_attention_backend(
run_config = { run_config = {
"transformer_engine_version": te.__version__, "transformer_engine_version": te.__version__,
"compute_capability": "sm" "compute_capability": "sm"
+ str( + str(10 * device_compute_capability[0] + device_compute_capability[1]),
(lambda x, y: x * 10 + y)(device_compute_capability[0], device_compute_capability[1])
),
"flash_attn_version": ( "flash_attn_version": (
str(_flash_attn_version) if _flash_attn_is_installed else "not installed" str(_flash_attn_version) if _flash_attn_is_installed else "not installed"
), ),
...@@ -1135,8 +1145,11 @@ def get_alibi( ...@@ -1135,8 +1145,11 @@ def get_alibi(
assert _alibi_cache["_alibi_slopes"] is not None, "ALiBi slopes can not be None!" assert _alibi_cache["_alibi_slopes"] is not None, "ALiBi slopes can not be None!"
if _alibi_cache["_alibi_slopes"].dim() == 1: if _alibi_cache["_alibi_slopes"].dim() == 1:
slopes_shape = torch.Size([1, _alibi_cache["_alibi_slopes"].shape[0], 1, 1]) slopes_shape = torch.Size([1, _alibi_cache["_alibi_slopes"].shape[0], 1, 1])
if _alibi_cache["_alibi_slopes"].dim() == 2: elif _alibi_cache["_alibi_slopes"].dim() == 2:
slopes_shape = torch.Size([*_alibi_cache["_alibi_slopes"].shape[:], 1, 1]) slopes_shape = torch.Size([*_alibi_cache["_alibi_slopes"].shape[:], 1, 1])
else:
raise ValueError("ALiBi slopes cannot exceed 2 dimensions.")
bias = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view( bias = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view(
1, 1, max_seqlen_q, 1 1, 1, max_seqlen_q, 1
) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view( ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view(
...@@ -1372,6 +1385,7 @@ class PackTensors(torch.autograd.Function): ...@@ -1372,6 +1385,7 @@ class PackTensors(torch.autograd.Function):
def forward( def forward(
ctx, indices: torch.Tensor, *tensors: Tuple[torch.Tensor, ...] ctx, indices: torch.Tensor, *tensors: Tuple[torch.Tensor, ...]
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# pylint: disable=missing-function-docstring
assert 1 <= len(tensors) <= 3, f"Packing {len(tensors)} tensors not supported." assert 1 <= len(tensors) <= 3, f"Packing {len(tensors)} tensors not supported."
ctx.save_for_backward(indices) ctx.save_for_backward(indices)
ctx.dim0 = tensors[0].shape[0] ctx.dim0 = tensors[0].shape[0]
...@@ -1383,6 +1397,7 @@ class PackTensors(torch.autograd.Function): ...@@ -1383,6 +1397,7 @@ class PackTensors(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, *grad_outputs: Tuple[torch.Tensor, ...]): def backward(ctx, *grad_outputs: Tuple[torch.Tensor, ...]):
# pylint: disable=missing-function-docstring
(indices,) = ctx.saved_tensors (indices,) = ctx.saved_tensors
if len(grad_outputs) == 1: if len(grad_outputs) == 1:
return None, unpack_tensor(indices, ctx.dim0, *grad_outputs) return None, unpack_tensor(indices, ctx.dim0, *grad_outputs)
...@@ -1403,11 +1418,13 @@ class UnpackTensor(torch.autograd.Function): ...@@ -1403,11 +1418,13 @@ class UnpackTensor(torch.autograd.Function):
dim0: int, dim0: int,
tensor: torch.Tensor, tensor: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
# pylint: disable=missing-function-docstring
ctx.save_for_backward(indices) ctx.save_for_backward(indices)
return unpack_tensor(indices, dim0, tensor) return unpack_tensor(indices, dim0, tensor)
@staticmethod @staticmethod
def backward(ctx, grad_output): def backward(ctx, grad_output):
# pylint: disable=missing-function-docstring
(indices,) = ctx.saved_tensors (indices,) = ctx.saved_tensors
return None, None, pack_tensor(indices, grad_output) return None, None, pack_tensor(indices, grad_output)
...@@ -1661,6 +1678,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1661,6 +1678,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
cp_global_ranks, cp_global_ranks,
cp_stream, cp_stream,
): ):
# pylint: disable=missing-function-docstring
if softmax_scale is None: if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5) softmax_scale = q.shape[-1] ** (-0.5)
...@@ -1690,6 +1708,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1690,6 +1708,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
causal = "causal" in attn_mask_type causal = "causal" in attn_mask_type
padding = "padding" in attn_mask_type padding = "padding" in attn_mask_type
seq_dim = None
if qkv_format in ["bshd", "sbhd"]: if qkv_format in ["bshd", "sbhd"]:
seq_dim = qkv_format.index("s") seq_dim = qkv_format.index("s")
qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:] qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:]
...@@ -1705,6 +1724,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1705,6 +1724,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
cu_seqlens_q_per_step = [None for _ in range(cp_size)] cu_seqlens_q_per_step = [None for _ in range(cp_size)]
cu_seqlens_kv_per_step = [None for _ in range(cp_size)] cu_seqlens_kv_per_step = [None for _ in range(cp_size)]
fused_attn_qkv_dtype = None
fused_attn_backend = None
amax_per_step = None
if fp8: if fp8:
if use_fused_attention: if use_fused_attention:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
...@@ -1796,6 +1818,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1796,6 +1818,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
softmax_lse_in_packed_format = not use_fused_attention and ( softmax_lse_in_packed_format = not use_fused_attention and (
_flash_attn_2_6_0_plus or _use_flash_attn_3 _flash_attn_2_6_0_plus or _use_flash_attn_3
) )
flash_attn_fwd = None
if not use_fused_attention: if not use_fused_attention:
fa_forward_kwargs = {"softmax_scale": softmax_scale} fa_forward_kwargs = {"softmax_scale": softmax_scale}
if _use_flash_attn_3: if _use_flash_attn_3:
...@@ -1834,6 +1857,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1834,6 +1857,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0) p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0)
send_recv_reqs = [[], []] send_recv_reqs = [[], []]
softmax_lse_ = None
out = None
for i in range(cp_size + 1): for i in range(cp_size + 1):
if i < cp_size: if i < cp_size:
with torch.cuda.stream(flash_attn_streams[i % 2]): with torch.cuda.stream(flash_attn_streams[i % 2]):
...@@ -2326,8 +2351,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2326,8 +2351,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
softmax_lse = softmax_lse.to(torch.float) softmax_lse = softmax_lse.to(torch.float)
for i in range(cp_size): for i in range(cp_size):
out_ = None
if qkv_format == "bshd": if qkv_format == "bshd":
out_per_step[i] = out_per_step[i].view(out.shape[0], -1, *out.shape[-2:]) out_per_step[i] = out_per_step[i].view(
out.shape[0], -1, *out.shape[-2:]
) # pylint: disable=used-before-assignment
out_ = out[:, 1, ...] out_ = out[:, 1, ...]
elif qkv_format == "sbhd": elif qkv_format == "sbhd":
out_per_step[i] = out_per_step[i].view(-1, *out.shape[-3:]) out_per_step[i] = out_per_step[i].view(-1, *out.shape[-3:])
...@@ -2405,6 +2433,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2405,6 +2433,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
fp8_meta["scaling_fwd"].amax_history[0][META_S] = amax_cp_fwd[0] fp8_meta["scaling_fwd"].amax_history[0][META_S] = amax_cp_fwd[0]
fp8_meta["scaling_fwd"].amax_history[0][META_O_CP] = amax_cp_fwd[1] fp8_meta["scaling_fwd"].amax_history[0][META_O_CP] = amax_cp_fwd[1]
out_fp8 = None
out_f16 = out.to(q_fp8.dtype if fp8 and fp8_meta["recipe"].fp8_mha else q_f16.dtype) out_f16 = out.to(q_fp8.dtype if fp8 and fp8_meta["recipe"].fp8_mha else q_f16.dtype)
if fp8 and (fp8_meta["recipe"].fp8_mha or int(os.getenv("NVTE_FP8_DPA_BWD", "1"))): if fp8 and (fp8_meta["recipe"].fp8_mha or int(os.getenv("NVTE_FP8_DPA_BWD", "1"))):
out_fp8 = cast_to_fp8(out_f16, fp8_meta["scaling_fwd"], META_O, fp8_dtype_forward) out_fp8 = cast_to_fp8(out_f16, fp8_meta["scaling_fwd"], META_O, fp8_dtype_forward)
...@@ -2486,6 +2515,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2486,6 +2515,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, dout): def backward(ctx, dout):
# pylint: disable=missing-function-docstring
cp_size_a2a = ctx.cp_size_a2a cp_size_a2a = ctx.cp_size_a2a
rank_a2a = ctx.rank_a2a rank_a2a = ctx.rank_a2a
...@@ -2521,6 +2551,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2521,6 +2551,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
) )
else: else:
attn_dbias = None attn_dbias = None
attn_dbias_ = None
softmax_lse_in_packed_format = not ctx.use_fused_attention and ( softmax_lse_in_packed_format = not ctx.use_fused_attention and (
_flash_attn_2_6_0_plus or _use_flash_attn_3 _flash_attn_2_6_0_plus or _use_flash_attn_3
...@@ -2545,6 +2576,12 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2545,6 +2576,12 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
softmax_lse.unsqueeze_(-1) softmax_lse.unsqueeze_(-1)
dout_dtype = dout.dtype dout_dtype = dout.dtype
fused_attn_backend = None
fused_attn_qkv_dtype = None
fused_attn_dqkv_dtype = None
amax_per_step = None
seq_dim = None
dout_fp8_dtype = None
if ctx.fp8: if ctx.fp8:
if ctx.use_fused_attention: if ctx.use_fused_attention:
fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
...@@ -2615,13 +2652,19 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2615,13 +2652,19 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
) )
if not ctx.fp8 and ctx.fp8_meta is not None and ctx.fp8_meta["recipe"].fp8_mha: if not ctx.fp8 and ctx.fp8_meta is not None and ctx.fp8_meta["recipe"].fp8_mha:
dout = cast_from_fp8( dout = cast_from_fp8(
dout, None, None, dout_fp8_dtype, TE_DType[dout_dtype], scale_inv=dout_scale_inv dout,
None,
None,
dout_fp8_dtype,
TE_DType[dout_dtype],
scale_inv=dout_scale_inv, # pylint: disable=used-before-assignment
) )
out = out.view(*q.shape) out = out.view(*q.shape)
dout = dout.view(*q.shape) dout = dout.view(*q.shape)
send_recv_reqs = [] send_recv_reqs = []
flash_attn_bwd = None
if not ctx.use_fused_attention: if not ctx.use_fused_attention:
fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale} fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale}
if _use_flash_attn_3: if _use_flash_attn_3:
...@@ -2673,6 +2716,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2673,6 +2716,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
) )
kv = p2p_comm_buffers[i % 2][0] kv = p2p_comm_buffers[i % 2][0]
dk_, dv_ = None, None
if ctx.fp8 and ctx.use_fused_attention: if ctx.fp8 and ctx.use_fused_attention:
fp8_meta_kwargs["amax_dp"] = amax_per_step[0][i] fp8_meta_kwargs["amax_dp"] = amax_per_step[0][i]
fp8_meta_kwargs["amax_dqkv"] = amax_per_step[0][i] fp8_meta_kwargs["amax_dqkv"] = amax_per_step[0][i]
...@@ -3106,7 +3150,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -3106,7 +3150,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
else: else:
dkv = p2p_comm_buffers[(i + 1) % 2][1] dkv = p2p_comm_buffers[(i + 1) % 2][1]
if ctx.use_fused_attention: if ctx.use_fused_attention:
dkv_ = torch.cat((dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0) dkv_ = torch.cat(
(dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0
) # pylint: disable=used-before-assignment
if ctx.qkv_format in ["bshd", "sbhd"]: if ctx.qkv_format in ["bshd", "sbhd"]:
# [b, 2, sk//2, 2, np, hn] -> [2, b, 2, sk//2, np, hn] or # [b, 2, sk//2, 2, np, hn] -> [2, b, 2, sk//2, np, hn] or
# [2, sk//2, b, 2, np, hn] -> [2, 2, sk//2, b, np, hn] # [2, sk//2, b, 2, np, hn] -> [2, 2, sk//2, b, np, hn]
...@@ -3334,6 +3380,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function): ...@@ -3334,6 +3380,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
cp_group, cp_group,
cp_stream, cp_stream,
): ):
# pylint: disable=missing-function-docstring
if softmax_scale is None: if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5) softmax_scale = q.shape[-1] ** (-0.5)
...@@ -3351,6 +3398,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function): ...@@ -3351,6 +3398,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
use_fused_attention or _flash_attn_2_3_plus use_fused_attention or _flash_attn_2_3_plus
), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!" ), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!"
flash_attn_fwd = None
if not use_fused_attention: if not use_fused_attention:
fa_forward_kwargs = {"softmax_scale": softmax_scale} fa_forward_kwargs = {"softmax_scale": softmax_scale}
if _use_flash_attn_3: if _use_flash_attn_3:
...@@ -3521,6 +3569,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function): ...@@ -3521,6 +3569,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, dout): def backward(ctx, dout):
# pylint: disable=missing-function-docstring
cp_size = get_distributed_world_size(ctx.cp_group) cp_size = get_distributed_world_size(ctx.cp_group)
rank = get_distributed_rank(ctx.cp_group) rank = get_distributed_rank(ctx.cp_group)
...@@ -3565,6 +3614,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function): ...@@ -3565,6 +3614,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1] local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1]
flash_attn_bwd = None
if not ctx.use_fused_attention: if not ctx.use_fused_attention:
fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale} fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale}
if _use_flash_attn_3: if _use_flash_attn_3:
...@@ -3751,6 +3801,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function): ...@@ -3751,6 +3801,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
cp_group, cp_group,
cp_stream, cp_stream,
): ):
# pylint: disable=missing-function-docstring
if softmax_scale is None: if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5) softmax_scale = q.shape[-1] ** (-0.5)
...@@ -3768,6 +3819,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function): ...@@ -3768,6 +3819,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
or _flash_attn_2_3_plus or _flash_attn_2_3_plus
), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!" ), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!"
flash_attn_fwd = None
if not use_fused_attention: if not use_fused_attention:
fa_forward_kwargs = {"softmax_scale": softmax_scale} fa_forward_kwargs = {"softmax_scale": softmax_scale}
if _use_flash_attn_3: if _use_flash_attn_3:
...@@ -3797,6 +3849,8 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function): ...@@ -3797,6 +3849,8 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0
), "Sequence length per GPU needs to be divisible by 2!" ), "Sequence length per GPU needs to be divisible by 2!"
fused_attn_backend = None
fused_attn_qkv_dtype = None
if fp8: if fp8:
if use_fused_attention: if use_fused_attention:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
...@@ -3990,6 +4044,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function): ...@@ -3990,6 +4044,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, dout): def backward(ctx, dout):
# pylint: disable=missing-function-docstring
cp_size = get_distributed_world_size(ctx.cp_group) cp_size = get_distributed_world_size(ctx.cp_group)
q, k, v, out = ctx.saved_tensors[:4] q, k, v, out = ctx.saved_tensors[:4]
...@@ -4003,6 +4058,9 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function): ...@@ -4003,6 +4058,9 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
causal = "causal" in ctx.attn_mask_type causal = "causal" in ctx.attn_mask_type
seq_dim = ctx.qkv_format.index("s") seq_dim = ctx.qkv_format.index("s")
fused_attn_backend = None
fused_attn_dqkv_dtype = None
fused_attn_qkv_dtype = None
if ctx.fp8: if ctx.fp8:
if ctx.use_fused_attention: if ctx.use_fused_attention:
fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
...@@ -4054,6 +4112,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function): ...@@ -4054,6 +4112,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
[out, dout], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, True [out, dout], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, True
) )
flash_attn_bwd = None
if not ctx.use_fused_attention: if not ctx.use_fused_attention:
fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale} fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale}
if _use_flash_attn_3: if _use_flash_attn_3:
...@@ -4400,6 +4459,7 @@ class FusedRoPEFunc(torch.autograd.Function): ...@@ -4400,6 +4459,7 @@ class FusedRoPEFunc(torch.autograd.Function):
cp_size: int = 1, cp_size: int = 1,
cp_rank: int = 0, cp_rank: int = 0,
) -> torch.Tensor: ) -> torch.Tensor:
# pylint: disable=missing-function-docstring
if freqs.dtype != torch.float32: if freqs.dtype != torch.float32:
freqs = freqs.float() freqs = freqs.float()
if tensor_format == "sbhd": if tensor_format == "sbhd":
...@@ -4419,6 +4479,7 @@ class FusedRoPEFunc(torch.autograd.Function): ...@@ -4419,6 +4479,7 @@ class FusedRoPEFunc(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
# pylint: disable=missing-function-docstring
freqs, cu_seqlens = ctx.saved_tensors freqs, cu_seqlens = ctx.saved_tensors
if ctx.tensor_format == "sbhd": if ctx.tensor_format == "sbhd":
grad_input = tex.fused_rope_backward(grad_output, freqs, False) grad_input = tex.fused_rope_backward(grad_output, freqs, False)
...@@ -4525,6 +4586,7 @@ class _SplitAlongDim(torch.autograd.Function): ...@@ -4525,6 +4586,7 @@ class _SplitAlongDim(torch.autograd.Function):
split_dim: int, split_dim: int,
split_size_or_sections: Union[int, List[int], Tuple[int]], split_size_or_sections: Union[int, List[int], Tuple[int]],
) -> Tuple[torch.Tensor, ...]: ) -> Tuple[torch.Tensor, ...]:
# pylint: disable=missing-function-docstring
ctx.split_dim = split_dim ctx.split_dim = split_dim
ctx.split_size_or_sections = split_size_or_sections ctx.split_size_or_sections = split_size_or_sections
if isinstance(mixed_x_layer, Float8Tensor): if isinstance(mixed_x_layer, Float8Tensor):
...@@ -4543,6 +4605,7 @@ class _SplitAlongDim(torch.autograd.Function): ...@@ -4543,6 +4605,7 @@ class _SplitAlongDim(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, *grad_outputs): def backward(ctx, *grad_outputs):
# pylint: disable=missing-function-docstring
assert len(grad_outputs) > 0, "No gradients received for backprop!" assert len(grad_outputs) > 0, "No gradients received for backprop!"
if isinstance(ctx.split_size_or_sections, (list, tuple)): if isinstance(ctx.split_size_or_sections, (list, tuple)):
...@@ -4887,6 +4950,7 @@ class _PrepareQKVForFA(torch.autograd.Function): ...@@ -4887,6 +4950,7 @@ class _PrepareQKVForFA(torch.autograd.Function):
key_layer: torch.Tensor, key_layer: torch.Tensor,
value_layer: torch.Tensor, value_layer: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# pylint: disable=missing-function-docstring
# All inputs received are non-contiguous tensors. # All inputs received are non-contiguous tensors.
# The `query_layer` tensor is used to access the # The `query_layer` tensor is used to access the
# full memory region of the QKV tensor. # full memory region of the QKV tensor.
...@@ -4904,6 +4968,7 @@ class _PrepareQKVForFA(torch.autograd.Function): ...@@ -4904,6 +4968,7 @@ class _PrepareQKVForFA(torch.autograd.Function):
dk: torch.Tensor, dk: torch.Tensor,
dv: torch.Tensor, dv: torch.Tensor,
) -> Tuple[Union[torch.Tensor, None], ...]: ) -> Tuple[Union[torch.Tensor, None], ...]:
# pylint: disable=missing-function-docstring
dqkv = tex.fa_prepare_bwd(dq, dk, dv) dqkv = tex.fa_prepare_bwd(dq, dk, dv)
dq, dk, dv = split_tensor_along_dim(dqkv, -1, 3) dq, dk, dv = split_tensor_along_dim(dqkv, -1, 3)
return dq, dk, dv return dq, dk, dv
...@@ -5055,7 +5120,7 @@ def get_qkv_layout( ...@@ -5055,7 +5120,7 @@ def get_qkv_layout(
q, k, v = [x.contiguous() for x in [q, k, v]] q, k, v = [x.contiguous() for x in [q, k, v]]
qkv_layout = run_iteratively(q, k, v) qkv_layout = run_iteratively(q, k, v)
if qkv_layout == "not_supported": if qkv_layout == "not_supported":
raise Exception("The provided qkv memory layout is not supported!") raise RuntimeError("The provided qkv memory layout is not supported!")
return qkv_layout, q, k, v return qkv_layout, q, k, v
...@@ -5351,9 +5416,9 @@ class FlashAttention(torch.nn.Module): ...@@ -5351,9 +5416,9 @@ class FlashAttention(torch.nn.Module):
fa_3_optional_forward_kwargs = {} fa_3_optional_forward_kwargs = {}
fa_3_optional_forward_kwargs["window_size"] = window_size fa_3_optional_forward_kwargs["window_size"] = window_size
fa_3_optional_forward_kwargs["deterministic"] = self.deterministic fa_3_optional_forward_kwargs["deterministic"] = self.deterministic
activation_dtype = query_layer.dtype
if fp8: if fp8:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
activation_dtype = query_layer.dtype
torch_dtype = get_fp8_torch_dtype(fp8_meta["recipe"], fprop_tensor=True) torch_dtype = get_fp8_torch_dtype(fp8_meta["recipe"], fprop_tensor=True)
def convert_to_torch_float8(tensor, dtype): def convert_to_torch_float8(tensor, dtype):
...@@ -5511,6 +5576,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function): ...@@ -5511,6 +5576,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
fp8_meta, fp8_meta,
deterministic, deterministic,
): ):
# pylint: disable=missing-function-docstring
is_input_fp8 = False is_input_fp8 = False
is_output_fp8 = fp8_meta["recipe"].fp8_mha is_output_fp8 = fp8_meta["recipe"].fp8_mha
if fp8: if fp8:
...@@ -5664,6 +5730,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function): ...@@ -5664,6 +5730,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, d_out): def backward(ctx, d_out):
# pylint: disable=missing-function-docstring
if ctx.is_output_fp8: if ctx.is_output_fp8:
assert isinstance( assert isinstance(
d_out, Float8Tensor d_out, Float8Tensor
...@@ -5683,12 +5750,12 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function): ...@@ -5683,12 +5750,12 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
fwd_scale_invs, fwd_scale_invs,
*aux_ctx_tensors, *aux_ctx_tensors,
) = ctx.saved_tensors ) = ctx.saved_tensors
rest = [None]
if not aux_ctx_tensors[0].is_contiguous(): if not aux_ctx_tensors[0].is_contiguous():
aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous() aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous()
if ctx.use_FAv2_bwd: if ctx.use_FAv2_bwd:
softmax_lse, rng_state = aux_ctx_tensors softmax_lse, rng_state = aux_ctx_tensors
dqkv = torch.empty_like(qkv) dqkv = torch.empty_like(qkv)
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
d_out, q, k, v, out = [ d_out, q, k, v, out = [
maybe_contiguous(x) for x in (d_out, qkv[:, 0], qkv[:, 1], qkv[:, 2], out) maybe_contiguous(x) for x in (d_out, qkv[:, 0], qkv[:, 1], qkv[:, 2], out)
] ]
...@@ -5899,6 +5966,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function): ...@@ -5899,6 +5966,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
fp8_meta, fp8_meta,
deterministic, deterministic,
): ):
# pylint: disable=missing-function-docstring
is_input_fp8 = False is_input_fp8 = False
is_output_fp8 = fp8_meta["recipe"].fp8_mha is_output_fp8 = fp8_meta["recipe"].fp8_mha
if fp8: if fp8:
...@@ -6080,6 +6148,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function): ...@@ -6080,6 +6148,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, d_out): def backward(ctx, d_out):
# pylint: disable=missing-function-docstring
if ctx.is_output_fp8: if ctx.is_output_fp8:
assert isinstance( assert isinstance(
d_out, Float8Tensor d_out, Float8Tensor
...@@ -6103,13 +6172,13 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function): ...@@ -6103,13 +6172,13 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
fwd_scale_invs, fwd_scale_invs,
*aux_ctx_tensors, *aux_ctx_tensors,
) = ctx.saved_tensors ) = ctx.saved_tensors
rest = [None]
if not aux_ctx_tensors[0].is_contiguous(): if not aux_ctx_tensors[0].is_contiguous():
aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous() aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous()
if ctx.use_FAv2_bwd: if ctx.use_FAv2_bwd:
softmax_lse, rng_state = aux_ctx_tensors softmax_lse, rng_state = aux_ctx_tensors
dq = torch.empty_like(q) dq = torch.empty_like(q)
dkv = torch.empty_like(kv) dkv = torch.empty_like(kv)
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
d_out, q, k, v, out = [maybe_contiguous(x) for x in (d_out, q, kv[:, 0], kv[:, 1], out)] d_out, q, k, v, out = [maybe_contiguous(x) for x in (d_out, q, kv[:, 0], kv[:, 1], out)]
flash_attn_cuda_bwd( flash_attn_cuda_bwd(
d_out, d_out,
...@@ -6351,6 +6420,7 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -6351,6 +6420,7 @@ class FusedAttnFunc(torch.autograd.Function):
fp8_meta, fp8_meta,
deterministic, deterministic,
): ):
# pylint: disable=missing-function-docstring
is_input_fp8 = False is_input_fp8 = False
is_output_fp8 = fp8_meta["recipe"].fp8_mha is_output_fp8 = fp8_meta["recipe"].fp8_mha
if fp8: if fp8:
...@@ -6616,6 +6686,7 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -6616,6 +6686,7 @@ class FusedAttnFunc(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, d_out): def backward(ctx, d_out):
# pylint: disable=missing-function-docstring
if ctx.is_output_fp8: if ctx.is_output_fp8:
assert isinstance( assert isinstance(
d_out, Float8Tensor d_out, Float8Tensor
...@@ -6643,12 +6714,12 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -6643,12 +6714,12 @@ class FusedAttnFunc(torch.autograd.Function):
) = ctx.saved_tensors ) = ctx.saved_tensors
if not aux_ctx_tensors[0].is_contiguous(): if not aux_ctx_tensors[0].is_contiguous():
aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous() aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous()
rest = [None]
if ctx.use_FAv2_bwd: if ctx.use_FAv2_bwd:
softmax_lse, rng_state = aux_ctx_tensors softmax_lse, rng_state = aux_ctx_tensors
dq = torch.empty_like(q) dq = torch.empty_like(q)
dk = torch.empty_like(k) dk = torch.empty_like(k)
dv = torch.empty_like(v) dv = torch.empty_like(v)
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
d_out, q, k, v, out = [maybe_contiguous(x) for x in (d_out, q, k, v, out)] d_out, q, k, v, out = [maybe_contiguous(x) for x in (d_out, q, k, v, out)]
flash_attn_cuda_bwd( flash_attn_cuda_bwd(
d_out, d_out,
...@@ -7859,7 +7930,7 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -7859,7 +7930,7 @@ class DotProductAttention(TransformerEngineBaseModule):
if qkv_format == "sbhd": if qkv_format == "sbhd":
max_seqlen_q, max_seqlen_kv = (query_layer.shape[0], key_layer.shape[0]) max_seqlen_q, max_seqlen_kv = (query_layer.shape[0], key_layer.shape[0])
batch_size = query_layer.shape[1] batch_size = query_layer.shape[1]
if qkv_format == "bshd": else:
max_seqlen_q, max_seqlen_kv = (query_layer.shape[1], key_layer.shape[1]) max_seqlen_q, max_seqlen_kv = (query_layer.shape[1], key_layer.shape[1])
batch_size = query_layer.shape[0] batch_size = query_layer.shape[0]
max_seqlen_q *= cp_size max_seqlen_q *= cp_size
...@@ -8168,7 +8239,7 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -8168,7 +8239,7 @@ class DotProductAttention(TransformerEngineBaseModule):
alibi_slopes=alibi_slopes, alibi_slopes=alibi_slopes,
) )
raise Exception("No dot product attention support for the provided inputs!") raise ValueError("No dot product attention support for the provided inputs!")
class MultiheadAttention(torch.nn.Module): class MultiheadAttention(torch.nn.Module):
...@@ -8522,6 +8593,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -8522,6 +8593,7 @@ class MultiheadAttention(torch.nn.Module):
def _allocate_memory( def _allocate_memory(
self, inference_max_sequence_len: int, batch_size: int, dtype: torch.dtype self, inference_max_sequence_len: int, batch_size: int, dtype: torch.dtype
) -> torch.Tensor: ) -> torch.Tensor:
"""Allocates memory for KV cache."""
return torch.empty( return torch.empty(
inference_max_sequence_len, inference_max_sequence_len,
batch_size, batch_size,
...@@ -8692,10 +8764,8 @@ class MultiheadAttention(torch.nn.Module): ...@@ -8692,10 +8764,8 @@ class MultiheadAttention(torch.nn.Module):
window_size = check_set_window_size(attn_mask_type, window_size) window_size = check_set_window_size(attn_mask_type, window_size)
if "padding" in attn_mask_type and attention_mask is not None: if "padding" in attn_mask_type and attention_mask is not None:
for i, _ in enumerate(attention_mask): for mask in attention_mask:
assert ( assert mask.dtype == torch.bool, "Attention mask must be in boolean type!"
attention_mask[i].dtype == torch.bool
), "Attention mask must be in boolean type!"
assert ( assert (
core_attention_bias_type in AttnBiasTypes core_attention_bias_type in AttnBiasTypes
...@@ -8737,6 +8807,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -8737,6 +8807,7 @@ class MultiheadAttention(torch.nn.Module):
and FP8GlobalStateManager.get_fp8_recipe().fp8_mha and FP8GlobalStateManager.get_fp8_recipe().fp8_mha
) )
layernorm_output = None
if self.attention_type == "self": if self.attention_type == "self":
# Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn] # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn]
if self.input_layernorm: if self.input_layernorm:
...@@ -8904,6 +8975,8 @@ class MultiheadAttention(torch.nn.Module): ...@@ -8904,6 +8975,8 @@ class MultiheadAttention(torch.nn.Module):
sequence_length = key_layer.size(0) sequence_length = key_layer.size(0)
elif self.qkv_format == "bshd": elif self.qkv_format == "bshd":
sequence_length = key_layer.size(1) sequence_length = key_layer.size(1)
else:
raise ValueError(f"QKV format {self.qkv_format} not supported for KV caching.")
sequence_start = inference_params.sequence_len_offset sequence_start = inference_params.sequence_len_offset
sequence_end = sequence_start + sequence_length sequence_end = sequence_start + sequence_length
......
...@@ -78,10 +78,10 @@ def canonicalize_fp8_scales( ...@@ -78,10 +78,10 @@ def canonicalize_fp8_scales(
scale_inv_offset = 0 scale_inv_offset = 0
# Pack tensors and offsets into dicts # Pack tensors and offsets into dicts
tensors = dict(scale=scale, amax=amax, scale_inv=scale_inv) tensors = {"scale": scale, "amax": amax, "scale_inv": scale_inv}
offsets = dict( offsets = {
scale_offset=scale_offset, "scale_offset": scale_offset,
amax_offset=amax_offset, "amax_offset": amax_offset,
scale_inv_offset=scale_inv_offset, "scale_inv_offset": scale_inv_offset,
) }
return tensors, offsets return tensors, offsets
...@@ -240,13 +240,11 @@ def fused_attn_fwd_qkvpacked( ...@@ -240,13 +240,11 @@ def fused_attn_fwd_qkvpacked(
rng_elts_per_thread = ( rng_elts_per_thread = (
max_seqlen * max_seqlen + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 max_seqlen * max_seqlen + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1
) // BACKEND_F16m512_FP8_THREADS_PER_CTA ) // BACKEND_F16m512_FP8_THREADS_PER_CTA
# BF16/FP16 fused attention API from fmha_v2 # BF16/FP16 fused attention API from fmha_v2
if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS
# FP8 fused attention API from fmha_v2 # FP8 fused attention API from fmha_v2
if fused_attention_backend == FusedAttnBackend["FP8"]: elif fused_attention_backend == FusedAttnBackend["FP8"]:
rng_elts_per_thread = ( rng_elts_per_thread = (
max_seqlen * max_seqlen + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 max_seqlen * max_seqlen + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1
) // BACKEND_F16m512_FP8_THREADS_PER_CTA ) // BACKEND_F16m512_FP8_THREADS_PER_CTA
...@@ -259,6 +257,8 @@ def fused_attn_fwd_qkvpacked( ...@@ -259,6 +257,8 @@ def fused_attn_fwd_qkvpacked(
assert q_scale_o is not None, "q_scale_o is required as an input for FP8 fused attention." assert q_scale_o is not None, "q_scale_o is required as an input for FP8 fused attention."
assert amax_s is not None, "amax_s is required as an input for FP8 fused attention." assert amax_s is not None, "amax_s is required as an input for FP8 fused attention."
assert amax_o is not None, "amax_o is required as an input for FP8 fused attention." assert amax_o is not None, "amax_o is required as an input for FP8 fused attention."
else:
raise ValueError(f"Unsupported backend {fused_attention_backend}")
# execute kernel # execute kernel
output_tensors = tex.fused_attn_fwd_qkvpacked( output_tensors = tex.fused_attn_fwd_qkvpacked(
...@@ -633,13 +633,11 @@ def fused_attn_fwd_kvpacked( ...@@ -633,13 +633,11 @@ def fused_attn_fwd_kvpacked(
rng_elts_per_thread = ( rng_elts_per_thread = (
max_seqlen_q * max_seqlen_kv + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 max_seqlen_q * max_seqlen_kv + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1
) // BACKEND_F16m512_FP8_THREADS_PER_CTA ) // BACKEND_F16m512_FP8_THREADS_PER_CTA
# BF16/FP16 fused attention API from fmha_v2 # BF16/FP16 fused attention API from fmha_v2
if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS
# FP8 fused attention API from fmha_v2 # FP8 fused attention API from fmha_v2
if fused_attention_backend == FusedAttnBackend["FP8"]: elif fused_attention_backend == FusedAttnBackend["FP8"]:
rng_elts_per_thread = ( rng_elts_per_thread = (
max_seqlen_q * max_seqlen_q + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 max_seqlen_q * max_seqlen_q + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1
) // BACKEND_F16m512_FP8_THREADS_PER_CTA ) // BACKEND_F16m512_FP8_THREADS_PER_CTA
...@@ -652,6 +650,8 @@ def fused_attn_fwd_kvpacked( ...@@ -652,6 +650,8 @@ def fused_attn_fwd_kvpacked(
assert q_scale_o is not None, "q_scale_o is required as an input for FP8 fused attention." assert q_scale_o is not None, "q_scale_o is required as an input for FP8 fused attention."
assert amax_s is not None, "amax_s is required as an input for FP8 fused attention." assert amax_s is not None, "amax_s is required as an input for FP8 fused attention."
assert amax_o is not None, "amax_o is required as an input for FP8 fused attention." assert amax_o is not None, "amax_o is required as an input for FP8 fused attention."
else:
raise ValueError(f"Unsupported backend {fused_attention_backend}")
# execute kernel # execute kernel
output_tensors = tex.fused_attn_fwd_kvpacked( output_tensors = tex.fused_attn_fwd_kvpacked(
...@@ -1058,13 +1058,11 @@ def fused_attn_fwd( ...@@ -1058,13 +1058,11 @@ def fused_attn_fwd(
rng_elts_per_thread = ( rng_elts_per_thread = (
max_seqlen_q * max_seqlen_kv + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 max_seqlen_q * max_seqlen_kv + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1
) // BACKEND_F16m512_FP8_THREADS_PER_CTA ) // BACKEND_F16m512_FP8_THREADS_PER_CTA
# BF16/FP16 fused attention API from fmha_v2 # BF16/FP16 fused attention API from fmha_v2
if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS
# FP8 fused attention API from fmha_v2 # FP8 fused attention API from fmha_v2
if fused_attention_backend == FusedAttnBackend["FP8"]: elif fused_attention_backend == FusedAttnBackend["FP8"]:
rng_elts_per_thread = ( rng_elts_per_thread = (
max_seqlen_q * max_seqlen_q + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 max_seqlen_q * max_seqlen_q + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1
) // BACKEND_F16m512_FP8_THREADS_PER_CTA ) // BACKEND_F16m512_FP8_THREADS_PER_CTA
...@@ -1077,6 +1075,8 @@ def fused_attn_fwd( ...@@ -1077,6 +1075,8 @@ def fused_attn_fwd(
assert q_scale_o is not None, "q_scale_o is required as an input for FP8 fused attention." assert q_scale_o is not None, "q_scale_o is required as an input for FP8 fused attention."
assert amax_s is not None, "amax_s is required as an input for FP8 fused attention." assert amax_s is not None, "amax_s is required as an input for FP8 fused attention."
assert amax_o is not None, "amax_o is required as an input for FP8 fused attention." assert amax_o is not None, "amax_o is required as an input for FP8 fused attention."
else:
raise ValueError(f"Unsupported backend {fused_attention_backend}")
# execute kernel # execute kernel
output_tensors = tex.fused_attn_fwd( output_tensors = tex.fused_attn_fwd(
......
...@@ -161,6 +161,7 @@ class GroupCommitFunction(torch.autograd.Function): ...@@ -161,6 +161,7 @@ class GroupCommitFunction(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, tensor, cpu_offload_handler): def forward(ctx, tensor, cpu_offload_handler):
# pylint: disable=missing-function-docstring
cpu_offload_handler.on_group_commit_forward() cpu_offload_handler.on_group_commit_forward()
ctx.cpu_offload_handler = cpu_offload_handler ctx.cpu_offload_handler = cpu_offload_handler
# return the identical tensor # return the identical tensor
...@@ -168,6 +169,7 @@ class GroupCommitFunction(torch.autograd.Function): ...@@ -168,6 +169,7 @@ class GroupCommitFunction(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, grad_output): def backward(ctx, grad_output):
# pylint: disable=missing-function-docstring
cpu_offload_handler = ctx.cpu_offload_handler cpu_offload_handler = ctx.cpu_offload_handler
cpu_offload_handler.on_group_commit_backward() cpu_offload_handler.on_group_commit_backward()
return grad_output, None return grad_output, None
......
...@@ -752,11 +752,11 @@ class CudaRNGStatesTracker: ...@@ -752,11 +752,11 @@ class CudaRNGStatesTracker:
""" """
# Check seed is not already used. # Check seed is not already used.
if seed in self.seeds_: if seed in self.seeds_:
raise Exception(f"seed {seed} already exists") raise RuntimeError(f"seed {seed} already exists")
self.seeds_.add(seed) self.seeds_.add(seed)
# Check that state is not already defined. # Check that state is not already defined.
if name in self.states_: if name in self.states_:
raise Exception(f"cuda rng state {name} already exists") raise RuntimeError(f"cuda rng state {name} already exists")
if graph_safe_rng_available(): if graph_safe_rng_available():
new_state = _get_cuda_rng_state(clone=True) new_state = _get_cuda_rng_state(clone=True)
...@@ -786,7 +786,7 @@ class CudaRNGStatesTracker: ...@@ -786,7 +786,7 @@ class CudaRNGStatesTracker:
""" """
# Check if we have added the state # Check if we have added the state
if name not in self.states_: if name not in self.states_:
raise Exception(f"cuda rng state {name} is not added") raise KeyError(f"cuda rng state {name} is not added")
# Get the reference to current rng state. # Get the reference to current rng state.
orig_cuda_rng_state = _get_cuda_rng_state() orig_cuda_rng_state = _get_cuda_rng_state()
# Set rng state to the desired one # Set rng state to the desired one
......
...@@ -358,6 +358,7 @@ def _make_graphed_callables( ...@@ -358,6 +358,7 @@ def _make_graphed_callables(
@staticmethod @staticmethod
def forward(ctx, skip_fp8_weight_update, *inputs): def forward(ctx, skip_fp8_weight_update, *inputs):
# pylint: disable=missing-function-docstring
# Set flag for whether to update FP8 weight updates # Set flag for whether to update FP8 weight updates
ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
...@@ -377,6 +378,7 @@ def _make_graphed_callables( ...@@ -377,6 +378,7 @@ def _make_graphed_callables(
@staticmethod @staticmethod
@torch.autograd.function.once_differentiable @torch.autograd.function.once_differentiable
def backward(ctx, *grads): def backward(ctx, *grads):
# pylint: disable=missing-function-docstring
# Replay backward graph # Replay backward graph
assert len(grads) == len(static_grad_outputs) assert len(grads) == len(static_grad_outputs)
......
...@@ -8,6 +8,8 @@ from typing import Callable, Optional, Tuple ...@@ -8,6 +8,8 @@ from typing import Callable, Optional, Tuple
import torch import torch
# pylint: disable=unnecessary-lambda-assignment
jit_fuser = torch.jit.script jit_fuser = torch.jit.script
if torch.__version__ >= "2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))): if torch.__version__ >= "2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
jit_fuser = torch.compile jit_fuser = torch.compile
......
...@@ -124,6 +124,7 @@ class _NoopCatFunc(torch.autograd.Function): ...@@ -124,6 +124,7 @@ class _NoopCatFunc(torch.autograd.Function):
dim: int, dim: int,
*tensors: Tuple[torch.Tensor, ...], *tensors: Tuple[torch.Tensor, ...],
) -> torch.Tensor: ) -> torch.Tensor:
# pylint: disable=missing-function-docstring
# Check first tensor # Check first tensor
if not tensors: if not tensors:
...@@ -192,6 +193,7 @@ class _NoopCatFunc(torch.autograd.Function): ...@@ -192,6 +193,7 @@ class _NoopCatFunc(torch.autograd.Function):
ctx, ctx,
grad_output: torch.Tensor, grad_output: torch.Tensor,
) -> Tuple[Optional[torch.Tensor], ...]: ) -> Tuple[Optional[torch.Tensor], ...]:
# pylint: disable=missing-function-docstring
grad_inputs = [] grad_inputs = []
for split_start, split_end in ctx.split_ranges: for split_start, split_end in ctx.split_ranges:
slices = [slice(None)] * grad_output.dim() slices = [slice(None)] * grad_output.dim()
......
...@@ -694,7 +694,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -694,7 +694,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
else: else:
# If fp8 isn't enabled, turn off and return. # If fp8 isn't enabled, turn off and return.
self.fp8_initialized = False self.fp8_initialized = False
return
@contextmanager @contextmanager
def prepare_forward( def prepare_forward(
...@@ -744,7 +743,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -744,7 +743,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
if self.fp8 and in_fp8_activation_recompute_phase(): if self.fp8 and in_fp8_activation_recompute_phase():
FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta) FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
return
def set_nccl_overlap_warning_if_tp(self) -> None: def set_nccl_overlap_warning_if_tp(self) -> None:
"""When using TP, the NCCL communication needs to be scheduled """When using TP, the NCCL communication needs to be scheduled
......
...@@ -28,6 +28,7 @@ class _Fp8Padding(torch.autograd.Function): ...@@ -28,6 +28,7 @@ class _Fp8Padding(torch.autograd.Function):
padded_m_splits: List[int], padded_m_splits: List[int],
is_grad_enabled: bool, is_grad_enabled: bool,
) -> torch.Tensor: ) -> torch.Tensor:
# pylint: disable=missing-function-docstring
# Make sure input dimensions are compatible # Make sure input dimensions are compatible
in_features = inp.shape[-1] in_features = inp.shape[-1]
...@@ -46,6 +47,7 @@ class _Fp8Padding(torch.autograd.Function): ...@@ -46,6 +47,7 @@ class _Fp8Padding(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, grad_output: torch.Tensor): def backward(ctx, grad_output: torch.Tensor):
# pylint: disable=missing-function-docstring
grad_input = None grad_input = None
if ctx.requires_dgrad: if ctx.requires_dgrad:
......
...@@ -28,6 +28,7 @@ class _Fp8Unpadding(torch.autograd.Function): ...@@ -28,6 +28,7 @@ class _Fp8Unpadding(torch.autograd.Function):
padded_m_splits: List[int], padded_m_splits: List[int],
is_grad_enabled: bool, is_grad_enabled: bool,
) -> torch.Tensor: ) -> torch.Tensor:
# pylint: disable=missing-function-docstring
inputmats = torch.split(inp.view(-1, inp.shape[-1]), padded_m_splits) inputmats = torch.split(inp.view(-1, inp.shape[-1]), padded_m_splits)
out_ret = torch.cat( out_ret = torch.cat(
[grad_output_mat[: m_splits[i]] for i, grad_output_mat in enumerate(inputmats)], dim=0 [grad_output_mat[: m_splits[i]] for i, grad_output_mat in enumerate(inputmats)], dim=0
...@@ -42,6 +43,7 @@ class _Fp8Unpadding(torch.autograd.Function): ...@@ -42,6 +43,7 @@ class _Fp8Unpadding(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, grad_output: torch.Tensor): def backward(ctx, grad_output: torch.Tensor):
# pylint: disable=missing-function-docstring
grad_input = None grad_input = None
if ctx.requires_dgrad: if ctx.requires_dgrad:
grad_output = grad_output.contiguous() grad_output = grad_output.contiguous()
......
...@@ -70,6 +70,7 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -70,6 +70,7 @@ class _GroupedLinear(torch.autograd.Function):
weights_fp8: List[Union[Float8Tensor, None]], weights_fp8: List[Union[Float8Tensor, None]],
*weights_and_biases: Union[Float8Tensor, torch.Tensor, None], *weights_and_biases: Union[Float8Tensor, torch.Tensor, None],
) -> torch.Tensor: ) -> torch.Tensor:
# pylint: disable=missing-function-docstring
num_gemms = len(m_splits) num_gemms = len(m_splits)
weights = weights_and_biases[:num_gemms] weights = weights_and_biases[:num_gemms]
biases = weights_and_biases[num_gemms:] biases = weights_and_biases[num_gemms:]
...@@ -268,6 +269,7 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -268,6 +269,7 @@ class _GroupedLinear(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
# pylint: disable=missing-function-docstring
with torch.cuda.nvtx.range("_GroupedLinear_backward"): with torch.cuda.nvtx.range("_GroupedLinear_backward"):
( (
inputmat_scale_inv, inputmat_scale_inv,
...@@ -641,7 +643,7 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -641,7 +643,7 @@ class GroupedLinear(TransformerEngineBaseModule):
if self.primary_weights_in_fp8: if self.primary_weights_in_fp8:
self.init_fp8_metadata(num_gemms=self.num_gemms) self.init_fp8_metadata(num_gemms=self.num_gemms)
self.reset_parameters(defer_init=(device == "meta")) self.reset_parameters(defer_init=device == "meta")
# For RPL, bias has to be added after TP collectives # For RPL, bias has to be added after TP collectives
# So it cannot be fused with the GEMM # So it cannot be fused with the GEMM
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment