Unverified Commit 6e90fcb7 authored by Kirthi Shankar Sivamani, committed by GitHub
Browse files

Upgrade pylint to 3.3.1 (#1257)



* Upgrade pylint and first round formatting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* round 2
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* round 3
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Format and fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Paddle lint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Reviews
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* More linting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Run formatter
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Paddle lint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 161b1d98
......@@ -8,7 +8,9 @@ extension-pkg-whitelist=flash_attn_2_cuda,
extension-pkg-allow-list=transformer_engine.transformer_engine_jax
disable=too-many-locals,
too-few-public-methods,
too-many-public-methods,
too-many-positional-arguments,
invalid-name,
too-many-arguments,
abstract-method,
......
......@@ -6,7 +6,7 @@ set -e
: "${TE_PATH:=/opt/transformerengine}"
pip install cpplint==1.6.0 pylint==2.13.5
pip install cpplint==1.6.0 pylint==3.3.1
if [ -z "${PYTHON_ONLY}" ]
then
cd $TE_PATH
......
......@@ -6,7 +6,7 @@ set -e
: "${TE_PATH:=/opt/transformerengine}"
pip install cpplint==1.6.0 pylint==2.13.5
pip install cpplint==1.6.0 pylint==3.3.1
if [ -z "${PYTHON_ONLY}" ]
then
cd $TE_PATH
......
......@@ -6,7 +6,7 @@ set -e
: "${TE_PATH:=/opt/transformerengine}"
pip install cpplint==1.6.0 pylint==2.13.5
pip install cpplint==1.6.0 pylint==3.3.1
if [ -z "${PYTHON_ONLY}" ]
then
cd $TE_PATH
......
......@@ -583,6 +583,7 @@ def fused_attn_fwd_qkvpacked(
fused_attention_backend != FusedAttnBackend["No_Backend"]
), "Fused attention does not support this input combination."
rng_elts_per_thread = None
# BF16/FP16 fused attention API from fmha_v1 apex
if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
rng_elts_per_thread = (
......@@ -773,6 +774,7 @@ def fused_attn_fwd_kvpacked(
fused_attention_backend != FusedAttnBackend["No_Backend"]
), "Fused attention does not support this input combination."
rng_elts_per_thread = None
# BF16/FP16 fused attention API from fmha_v1 apex
if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
rng_elts_per_thread = (
......@@ -982,6 +984,7 @@ def fused_attn_fwd(
fused_attention_backend != FusedAttnBackend["No_Backend"]
), "Fused attention does not support this input combination."
rng_elts_per_thread = None
# BF16/FP16 fused attention API from fmha_v1 apex
if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
rng_elts_per_thread = (
......
......@@ -100,6 +100,7 @@ class FP8MetaBufferBase(ABC):
self._dp_amax_reduce_interval = int(os.getenv("NVTE_DP_AMAX_REDUCE_INTERVAL", "1"))
tp_amax_reduce = False
reduce_group = -1 # Set value that will raise error if not set. `None` is a valid group.
if self._dp_amax_reduce_idx == 0:
reduce_group = fp8_meta["fp8_group"]
else:
......
......@@ -1008,6 +1008,7 @@ class MultiHeadAttention(paddle.nn.Layer):
else:
raise ValueError(f"hidden_states should have 2 or 3 dimensions, got {input_dim}.")
layernorm_output = None
if self.attention_type == "self":
if self.input_layernorm:
layernorm_qkv_outputs = self.layernorm_qkv(
......
......@@ -266,6 +266,8 @@ def _mlp_backward(
accumulate_wgrad_into_param_main_grad,
)
dgelu_t = None
fc1_bgrad_ = None
if activation == "gelu":
# GELU Bwd
dgelu, dgelu_t, fc1_bgrad_ = dgelu_cast_transpose_bgrad_fp8(
......
This diff is collapsed.
......@@ -78,10 +78,10 @@ def canonicalize_fp8_scales(
scale_inv_offset = 0
# Pack tensors and offsets into dicts
tensors = dict(scale=scale, amax=amax, scale_inv=scale_inv)
offsets = dict(
scale_offset=scale_offset,
amax_offset=amax_offset,
scale_inv_offset=scale_inv_offset,
)
tensors = {"scale": scale, "amax": amax, "scale_inv": scale_inv}
offsets = {
"scale_offset": scale_offset,
"amax_offset": amax_offset,
"scale_inv_offset": scale_inv_offset,
}
return tensors, offsets
......@@ -240,13 +240,11 @@ def fused_attn_fwd_qkvpacked(
rng_elts_per_thread = (
max_seqlen * max_seqlen + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1
) // BACKEND_F16m512_FP8_THREADS_PER_CTA
# BF16/FP16 fused attention API from fmha_v2
if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS
# FP8 fused attention API from fmha_v2
if fused_attention_backend == FusedAttnBackend["FP8"]:
elif fused_attention_backend == FusedAttnBackend["FP8"]:
rng_elts_per_thread = (
max_seqlen * max_seqlen + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1
) // BACKEND_F16m512_FP8_THREADS_PER_CTA
......@@ -259,6 +257,8 @@ def fused_attn_fwd_qkvpacked(
assert q_scale_o is not None, "q_scale_o is required as an input for FP8 fused attention."
assert amax_s is not None, "amax_s is required as an input for FP8 fused attention."
assert amax_o is not None, "amax_o is required as an input for FP8 fused attention."
else:
raise ValueError(f"Unsupported backend {fused_attention_backend}")
# execute kernel
output_tensors = tex.fused_attn_fwd_qkvpacked(
......@@ -633,13 +633,11 @@ def fused_attn_fwd_kvpacked(
rng_elts_per_thread = (
max_seqlen_q * max_seqlen_kv + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1
) // BACKEND_F16m512_FP8_THREADS_PER_CTA
# BF16/FP16 fused attention API from fmha_v2
if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS
# FP8 fused attention API from fmha_v2
if fused_attention_backend == FusedAttnBackend["FP8"]:
elif fused_attention_backend == FusedAttnBackend["FP8"]:
rng_elts_per_thread = (
max_seqlen_q * max_seqlen_q + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1
) // BACKEND_F16m512_FP8_THREADS_PER_CTA
......@@ -652,6 +650,8 @@ def fused_attn_fwd_kvpacked(
assert q_scale_o is not None, "q_scale_o is required as an input for FP8 fused attention."
assert amax_s is not None, "amax_s is required as an input for FP8 fused attention."
assert amax_o is not None, "amax_o is required as an input for FP8 fused attention."
else:
raise ValueError(f"Unsupported backend {fused_attention_backend}")
# execute kernel
output_tensors = tex.fused_attn_fwd_kvpacked(
......@@ -1058,13 +1058,11 @@ def fused_attn_fwd(
rng_elts_per_thread = (
max_seqlen_q * max_seqlen_kv + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1
) // BACKEND_F16m512_FP8_THREADS_PER_CTA
# BF16/FP16 fused attention API from fmha_v2
if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS
# FP8 fused attention API from fmha_v2
if fused_attention_backend == FusedAttnBackend["FP8"]:
elif fused_attention_backend == FusedAttnBackend["FP8"]:
rng_elts_per_thread = (
max_seqlen_q * max_seqlen_q + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1
) // BACKEND_F16m512_FP8_THREADS_PER_CTA
......@@ -1077,6 +1075,8 @@ def fused_attn_fwd(
assert q_scale_o is not None, "q_scale_o is required as an input for FP8 fused attention."
assert amax_s is not None, "amax_s is required as an input for FP8 fused attention."
assert amax_o is not None, "amax_o is required as an input for FP8 fused attention."
else:
raise ValueError(f"Unsupported backend {fused_attention_backend}")
# execute kernel
output_tensors = tex.fused_attn_fwd(
......
......@@ -161,6 +161,7 @@ class GroupCommitFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, tensor, cpu_offload_handler):
# pylint: disable=missing-function-docstring
cpu_offload_handler.on_group_commit_forward()
ctx.cpu_offload_handler = cpu_offload_handler
# return the identical tensor
......@@ -168,6 +169,7 @@ class GroupCommitFunction(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output):
# pylint: disable=missing-function-docstring
cpu_offload_handler = ctx.cpu_offload_handler
cpu_offload_handler.on_group_commit_backward()
return grad_output, None
......
......@@ -752,11 +752,11 @@ class CudaRNGStatesTracker:
"""
# Check seed is not already used.
if seed in self.seeds_:
raise Exception(f"seed {seed} already exists")
raise RuntimeError(f"seed {seed} already exists")
self.seeds_.add(seed)
# Check that state is not already defined.
if name in self.states_:
raise Exception(f"cuda rng state {name} already exists")
raise RuntimeError(f"cuda rng state {name} already exists")
if graph_safe_rng_available():
new_state = _get_cuda_rng_state(clone=True)
......@@ -786,7 +786,7 @@ class CudaRNGStatesTracker:
"""
# Check if we have added the state
if name not in self.states_:
raise Exception(f"cuda rng state {name} is not added")
raise KeyError(f"cuda rng state {name} is not added")
# Get the reference to current rng state.
orig_cuda_rng_state = _get_cuda_rng_state()
# Set rng state to the desired one
......
......@@ -358,6 +358,7 @@ def _make_graphed_callables(
@staticmethod
def forward(ctx, skip_fp8_weight_update, *inputs):
# pylint: disable=missing-function-docstring
# Set flag for whether to update FP8 weight updates
ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
......@@ -377,6 +378,7 @@ def _make_graphed_callables(
@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, *grads):
# pylint: disable=missing-function-docstring
# Replay backward graph
assert len(grads) == len(static_grad_outputs)
......
......@@ -8,6 +8,8 @@ from typing import Callable, Optional, Tuple
import torch
# pylint: disable=unnecessary-lambda-assignment
jit_fuser = torch.jit.script
if torch.__version__ >= "2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
jit_fuser = torch.compile
......
......@@ -124,6 +124,7 @@ class _NoopCatFunc(torch.autograd.Function):
dim: int,
*tensors: Tuple[torch.Tensor, ...],
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
# Check first tensor
if not tensors:
......@@ -192,6 +193,7 @@ class _NoopCatFunc(torch.autograd.Function):
ctx,
grad_output: torch.Tensor,
) -> Tuple[Optional[torch.Tensor], ...]:
# pylint: disable=missing-function-docstring
grad_inputs = []
for split_start, split_end in ctx.split_ranges:
slices = [slice(None)] * grad_output.dim()
......
......@@ -694,7 +694,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
else:
# If fp8 isn't enabled, turn off and return.
self.fp8_initialized = False
return
@contextmanager
def prepare_forward(
......@@ -744,7 +743,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
if self.fp8 and in_fp8_activation_recompute_phase():
FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
return
def set_nccl_overlap_warning_if_tp(self) -> None:
"""When using TP, the NCCL communication needs to be scheduled
......
......@@ -28,6 +28,7 @@ class _Fp8Padding(torch.autograd.Function):
padded_m_splits: List[int],
is_grad_enabled: bool,
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
# Make sure input dimensions are compatible
in_features = inp.shape[-1]
......@@ -46,6 +47,7 @@ class _Fp8Padding(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output: torch.Tensor):
# pylint: disable=missing-function-docstring
grad_input = None
if ctx.requires_dgrad:
......
......@@ -28,6 +28,7 @@ class _Fp8Unpadding(torch.autograd.Function):
padded_m_splits: List[int],
is_grad_enabled: bool,
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
inputmats = torch.split(inp.view(-1, inp.shape[-1]), padded_m_splits)
out_ret = torch.cat(
[grad_output_mat[: m_splits[i]] for i, grad_output_mat in enumerate(inputmats)], dim=0
......@@ -42,6 +43,7 @@ class _Fp8Unpadding(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output: torch.Tensor):
# pylint: disable=missing-function-docstring
grad_input = None
if ctx.requires_dgrad:
grad_output = grad_output.contiguous()
......
......@@ -70,6 +70,7 @@ class _GroupedLinear(torch.autograd.Function):
weights_fp8: List[Union[Float8Tensor, None]],
*weights_and_biases: Union[Float8Tensor, torch.Tensor, None],
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
num_gemms = len(m_splits)
weights = weights_and_biases[:num_gemms]
biases = weights_and_biases[num_gemms:]
......@@ -268,6 +269,7 @@ class _GroupedLinear(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
# pylint: disable=missing-function-docstring
with torch.cuda.nvtx.range("_GroupedLinear_backward"):
(
inputmat_scale_inv,
......@@ -641,7 +643,7 @@ class GroupedLinear(TransformerEngineBaseModule):
if self.primary_weights_in_fp8:
self.init_fp8_metadata(num_gemms=self.num_gemms)
self.reset_parameters(defer_init=(device == "meta"))
self.reset_parameters(defer_init=device == "meta")
# For RPL, bias has to be added after TP collectives
# So it cannot be fused with the GEMM
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment