Unverified Commit 6e90fcb7 authored by Kirthi Shankar Sivamani, committed by GitHub

Upgrade pylint to 3.3.1 (#1257)



* Upgrade pylint and first round formatting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* round 2
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* round 3
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Format and fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Paddle lint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Reviews
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* More linting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Run formatter
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Paddle lint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 161b1d98
@@ -8,7 +8,9 @@ extension-pkg-whitelist=flash_attn_2_cuda,
 extension-pkg-allow-list=transformer_engine.transformer_engine_jax
 disable=too-many-locals,
+        too-few-public-methods,
         too-many-public-methods,
+        too-many-positional-arguments,
         invalid-name,
         too-many-arguments,
         abstract-method,
...
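pylint 3.3 introduces the too-many-positional-arguments check, and the config now also silences the long-standing too-few-public-methods check repo-wide. A minimal sketch (hypothetical function, not from this repo) of the kind of signature that the new check would otherwise flag:

def fused_forward(q, k, v, attn_mask, softmax_scale, dropout_p):
    """Toy kernel wrapper with six positional parameters (default limit is five)."""
    return q, k, v, attn_mask, softmax_scale, dropout_p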
@@ -6,7 +6,7 @@ set -e
 : "${TE_PATH:=/opt/transformerengine}"
-pip install cpplint==1.6.0 pylint==2.13.5
+pip install cpplint==1.6.0 pylint==3.3.1
 if [ -z "${PYTHON_ONLY}" ]
 then
   cd $TE_PATH
...
@@ -6,7 +6,7 @@ set -e
 : "${TE_PATH:=/opt/transformerengine}"
-pip install cpplint==1.6.0 pylint==2.13.5
+pip install cpplint==1.6.0 pylint==3.3.1
 if [ -z "${PYTHON_ONLY}" ]
 then
   cd $TE_PATH
...
@@ -6,7 +6,7 @@ set -e
 : "${TE_PATH:=/opt/transformerengine}"
-pip install cpplint==1.6.0 pylint==2.13.5
+pip install cpplint==1.6.0 pylint==3.3.1
 if [ -z "${PYTHON_ONLY}" ]
 then
   cd $TE_PATH
...
@@ -583,6 +583,7 @@ def fused_attn_fwd_qkvpacked(
         fused_attention_backend != FusedAttnBackend["No_Backend"]
     ), "Fused attention does not support this input combination."
+    rng_elts_per_thread = None
     # BF16/FP16 fused attention API from fmha_v1 apex
     if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
         rng_elts_per_thread = (
@@ -773,6 +774,7 @@ def fused_attn_fwd_kvpacked(
         fused_attention_backend != FusedAttnBackend["No_Backend"]
     ), "Fused attention does not support this input combination."
+    rng_elts_per_thread = None
     # BF16/FP16 fused attention API from fmha_v1 apex
     if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
         rng_elts_per_thread = (
@@ -982,6 +984,7 @@ def fused_attn_fwd(
         fused_attention_backend != FusedAttnBackend["No_Backend"]
     ), "Fused attention does not support this input combination."
+    rng_elts_per_thread = None
    # BF16/FP16 fused attention API from fmha_v1 apex
    if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
        rng_elts_per_thread = (
...
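In all three forward entry points, rng_elts_per_thread was previously assigned only inside backend-specific branches. Initializing it to None up front keeps newer pylint's used-before-assignment analysis satisfied on every path. A toy illustration of the pattern (hypothetical names):

def rng_elements(backend):
    """Define the name on every path, then refine it per branch."""
    rng_elts_per_thread = None  # assigned unconditionally
    if backend == "F16_max512_seqlen":
        rng_elts_per_thread = 512
    return rng_elts_per_thread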
@@ -100,6 +100,7 @@ class FP8MetaBufferBase(ABC):
         self._dp_amax_reduce_interval = int(os.getenv("NVTE_DP_AMAX_REDUCE_INTERVAL", "1"))
         tp_amax_reduce = False
+        reduce_group = -1  # Set value that will raise error if not set. `None` is a valid group.
         if self._dp_amax_reduce_idx == 0:
             reduce_group = fp8_meta["fp8_group"]
         else:
...
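Because None is a legitimate process-group value here, the sentinel is -1 rather than None, so a branch that forgets to assign reduce_group fails loudly downstream instead of silently passing None. A minimal sketch of the idea with hypothetical names:

def pick_reduce_group(dp_reduce_now, fp8_group, tp_group):
    """Use an impossible sentinel when None is itself a valid value."""
    reduce_group = -1  # not a valid group; None would be
    if dp_reduce_now:
        reduce_group = fp8_group
    else:
        reduce_group = tp_group
    assert reduce_group != -1, "reduce_group was never assigned"
    return reduce_group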
@@ -1008,6 +1008,7 @@ class MultiHeadAttention(paddle.nn.Layer):
         else:
             raise ValueError(f"hidden_states should have 2 or 3 dimensions, got {input_dim}.")
+        layernorm_output = None
         if self.attention_type == "self":
             if self.input_layernorm:
                 layernorm_qkv_outputs = self.layernorm_qkv(
...
@@ -266,6 +266,8 @@ def _mlp_backward(
         accumulate_wgrad_into_param_main_grad,
     )
+    dgelu_t = None
+    fc1_bgrad_ = None
     if activation == "gelu":
         # GELU Bwd
         dgelu, dgelu_t, fc1_bgrad_ = dgelu_cast_transpose_bgrad_fp8(
...
@@ -78,10 +78,10 @@ def canonicalize_fp8_scales(
         scale_inv_offset = 0
     # Pack tensors and offsets into dicts
-    tensors = dict(scale=scale, amax=amax, scale_inv=scale_inv)
-    offsets = dict(
-        scale_offset=scale_offset,
-        amax_offset=amax_offset,
-        scale_inv_offset=scale_inv_offset,
-    )
+    tensors = {"scale": scale, "amax": amax, "scale_inv": scale_inv}
+    offsets = {
+        "scale_offset": scale_offset,
+        "amax_offset": amax_offset,
+        "scale_inv_offset": scale_inv_offset,
+    }
     return tensors, offsets
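The dict(...) calls are rewritten as dict literals, in line with pylint's use-dict-literal check (the literal form also avoids a name lookup and call). A small before/after sketch with placeholder values:

tensors_old = dict(scale=1.0, amax=0.0, scale_inv=1.0)       # flagged: use-dict-literal
tensors_new = {"scale": 1.0, "amax": 0.0, "scale_inv": 1.0}  # preferred literal form
assert tensors_old == tensors_new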
@@ -240,13 +240,11 @@ def fused_attn_fwd_qkvpacked(
         rng_elts_per_thread = (
             max_seqlen * max_seqlen + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1
         ) // BACKEND_F16m512_FP8_THREADS_PER_CTA
     # BF16/FP16 fused attention API from fmha_v2
-    if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
+    elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
         rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS
     # FP8 fused attention API from fmha_v2
-    if fused_attention_backend == FusedAttnBackend["FP8"]:
+    elif fused_attention_backend == FusedAttnBackend["FP8"]:
         rng_elts_per_thread = (
             max_seqlen * max_seqlen + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1
         ) // BACKEND_F16m512_FP8_THREADS_PER_CTA
@@ -259,6 +257,8 @@ def fused_attn_fwd_qkvpacked(
         assert q_scale_o is not None, "q_scale_o is required as an input for FP8 fused attention."
         assert amax_s is not None, "amax_s is required as an input for FP8 fused attention."
         assert amax_o is not None, "amax_o is required as an input for FP8 fused attention."
+    else:
+        raise ValueError(f"Unsupported backend {fused_attention_backend}")
     # execute kernel
     output_tensors = tex.fused_attn_fwd_qkvpacked(
@@ -633,13 +633,11 @@ def fused_attn_fwd_kvpacked(
         rng_elts_per_thread = (
             max_seqlen_q * max_seqlen_kv + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1
         ) // BACKEND_F16m512_FP8_THREADS_PER_CTA
     # BF16/FP16 fused attention API from fmha_v2
-    if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
+    elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
         rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS
     # FP8 fused attention API from fmha_v2
-    if fused_attention_backend == FusedAttnBackend["FP8"]:
+    elif fused_attention_backend == FusedAttnBackend["FP8"]:
         rng_elts_per_thread = (
             max_seqlen_q * max_seqlen_q + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1
         ) // BACKEND_F16m512_FP8_THREADS_PER_CTA
@@ -652,6 +650,8 @@ def fused_attn_fwd_kvpacked(
         assert q_scale_o is not None, "q_scale_o is required as an input for FP8 fused attention."
         assert amax_s is not None, "amax_s is required as an input for FP8 fused attention."
         assert amax_o is not None, "amax_o is required as an input for FP8 fused attention."
+    else:
+        raise ValueError(f"Unsupported backend {fused_attention_backend}")
     # execute kernel
     output_tensors = tex.fused_attn_fwd_kvpacked(
@@ -1058,13 +1058,11 @@ def fused_attn_fwd(
         rng_elts_per_thread = (
             max_seqlen_q * max_seqlen_kv + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1
         ) // BACKEND_F16m512_FP8_THREADS_PER_CTA
     # BF16/FP16 fused attention API from fmha_v2
-    if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
+    elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
         rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS
     # FP8 fused attention API from fmha_v2
-    if fused_attention_backend == FusedAttnBackend["FP8"]:
+    elif fused_attention_backend == FusedAttnBackend["FP8"]:
         rng_elts_per_thread = (
             max_seqlen_q * max_seqlen_q + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1
         ) // BACKEND_F16m512_FP8_THREADS_PER_CTA
@@ -1077,6 +1075,8 @@ def fused_attn_fwd(
         assert q_scale_o is not None, "q_scale_o is required as an input for FP8 fused attention."
         assert amax_s is not None, "amax_s is required as an input for FP8 fused attention."
         assert amax_o is not None, "amax_o is required as an input for FP8 fused attention."
+    else:
+        raise ValueError(f"Unsupported backend {fused_attention_backend}")
     # execute kernel
     output_tensors = tex.fused_attn_fwd(
...
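Beyond satisfying the linter, turning the independent `if` checks into an if/elif/else chain makes the backend dispatch mutually exclusive and guarantees that an unrecognized backend is rejected instead of silently leaving rng_elts_per_thread unset. A toy dispatcher mirroring that structure (hypothetical names and placeholder values):

def select_rng_elts_per_thread(backend, max_seqlen_q, max_seqlen_kv, threads_per_cta=128):
    """Exclusive backend dispatch with an explicit failure path."""
    if backend == "F16_max512_seqlen":
        rng_elts = (max_seqlen_q * max_seqlen_kv + threads_per_cta - 1) // threads_per_cta
    elif backend == "F16_arbitrary_seqlen":
        rng_elts = 16  # placeholder constant
    elif backend == "FP8":
        rng_elts = (max_seqlen_q * max_seqlen_q + threads_per_cta - 1) // threads_per_cta
    else:
        raise ValueError(f"Unsupported backend {backend}")
    return rng_elts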
@@ -161,6 +161,7 @@ class GroupCommitFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, tensor, cpu_offload_handler):
+        # pylint: disable=missing-function-docstring
         cpu_offload_handler.on_group_commit_forward()
         ctx.cpu_offload_handler = cpu_offload_handler
         # return the identical tensor
@@ -168,6 +169,7 @@ class GroupCommitFunction(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
+        # pylint: disable=missing-function-docstring
         cpu_offload_handler = ctx.cpu_offload_handler
         cpu_offload_handler.on_group_commit_backward()
         return grad_output, None
...
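The same per-method `# pylint: disable=missing-function-docstring` comment is added to the forward/backward of every torch.autograd.Function touched by this PR; these are internal autograd hooks, so the check is suppressed locally rather than adding boilerplate docstrings or a repo-wide disable. A minimal sketch with a hypothetical function:

import torch

class _Identity(torch.autograd.Function):
    """Toy autograd function used only to illustrate the lint suppression."""

    @staticmethod
    def forward(ctx, inp):
        # pylint: disable=missing-function-docstring
        return inp

    @staticmethod
    def backward(ctx, grad_output):
        # pylint: disable=missing-function-docstring
        return grad_output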
@@ -752,11 +752,11 @@ class CudaRNGStatesTracker:
         """
         # Check seed is not already used.
         if seed in self.seeds_:
-            raise Exception(f"seed {seed} already exists")
+            raise RuntimeError(f"seed {seed} already exists")
         self.seeds_.add(seed)
         # Check that state is not already defined.
         if name in self.states_:
-            raise Exception(f"cuda rng state {name} already exists")
+            raise RuntimeError(f"cuda rng state {name} already exists")
         if graph_safe_rng_available():
             new_state = _get_cuda_rng_state(clone=True)
@@ -786,7 +786,7 @@ class CudaRNGStatesTracker:
         """
         # Check if we have added the state
         if name not in self.states_:
-            raise Exception(f"cuda rng state {name} is not added")
+            raise KeyError(f"cuda rng state {name} is not added")
         # Get the reference to current rng state.
         orig_cuda_rng_state = _get_cuda_rng_state()
         # Set rng state to the desired one
...
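Raising the bare Exception class is what pylint 3.x reports as broad-exception-raised; the fix is to raise the most specific built-in that fits (RuntimeError for illegal re-registration, KeyError for a missing entry). A condensed sketch of the pattern (hypothetical tracker, not the real class):

class StateTracker:
    """Toy registry mirroring the exception choices made in the diff."""

    def __init__(self):
        self.states_ = {}

    def add(self, name, state):
        if name in self.states_:
            raise RuntimeError(f"cuda rng state {name} already exists")  # was: raise Exception(...)
        self.states_[name] = state

    def get(self, name):
        if name not in self.states_:
            raise KeyError(f"cuda rng state {name} is not added")  # was: raise Exception(...)
        return self.states_[name]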
@@ -358,6 +358,7 @@ def _make_graphed_callables(
         @staticmethod
         def forward(ctx, skip_fp8_weight_update, *inputs):
+            # pylint: disable=missing-function-docstring
             # Set flag for whether to update FP8 weight updates
             ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
@@ -377,6 +378,7 @@ def _make_graphed_callables(
         @staticmethod
         @torch.autograd.function.once_differentiable
         def backward(ctx, *grads):
+            # pylint: disable=missing-function-docstring
             # Replay backward graph
             assert len(grads) == len(static_grad_outputs)
...
@@ -8,6 +8,8 @@ from typing import Callable, Optional, Tuple
 import torch
+# pylint: disable=unnecessary-lambda-assignment
 jit_fuser = torch.jit.script
 if torch.__version__ >= "2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
     jit_fuser = torch.compile
...
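The module-level disable suggests this file assigns lambdas to names somewhere below the shown context, which pylint 3.x reports as unnecessary-lambda-assignment. A hypothetical one-liner showing the kind of statement the comment silences:

no_fuser = lambda func: func  # would otherwise be flagged: define with `def` instead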
@@ -124,6 +124,7 @@ class _NoopCatFunc(torch.autograd.Function):
         dim: int,
         *tensors: Tuple[torch.Tensor, ...],
     ) -> torch.Tensor:
+        # pylint: disable=missing-function-docstring
         # Check first tensor
         if not tensors:
@@ -192,6 +193,7 @@ class _NoopCatFunc(torch.autograd.Function):
         ctx,
         grad_output: torch.Tensor,
     ) -> Tuple[Optional[torch.Tensor], ...]:
+        # pylint: disable=missing-function-docstring
         grad_inputs = []
         for split_start, split_end in ctx.split_ranges:
             slices = [slice(None)] * grad_output.dim()
...
@@ -694,7 +694,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         else:
             # If fp8 isn't enabled, turn off and return.
             self.fp8_initialized = False
-            return

     @contextmanager
     def prepare_forward(
@@ -744,7 +743,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         if self.fp8 and in_fp8_activation_recompute_phase():
             FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
-            return

     def set_nccl_overlap_warning_if_tp(self) -> None:
         """When using TP, the NCCL communication needs to be scheduled
...
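These two deletions address pylint's useless-return check: a bare `return` as the last statement of a method adds nothing, so the code simply falls off the end. A tiny sketch with a hypothetical helper:

def disable_fp8(module):
    """After the fix the method just ends; the old trailing `return` was redundant."""
    module.fp8_initialized = False
    # return  # <- removed; flagged as useless-return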
@@ -28,6 +28,7 @@ class _Fp8Padding(torch.autograd.Function):
         padded_m_splits: List[int],
         is_grad_enabled: bool,
     ) -> torch.Tensor:
+        # pylint: disable=missing-function-docstring
         # Make sure input dimensions are compatible
         in_features = inp.shape[-1]
@@ -46,6 +47,7 @@ class _Fp8Padding(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output: torch.Tensor):
+        # pylint: disable=missing-function-docstring
         grad_input = None
         if ctx.requires_dgrad:
...
@@ -28,6 +28,7 @@ class _Fp8Unpadding(torch.autograd.Function):
         padded_m_splits: List[int],
         is_grad_enabled: bool,
     ) -> torch.Tensor:
+        # pylint: disable=missing-function-docstring
         inputmats = torch.split(inp.view(-1, inp.shape[-1]), padded_m_splits)
         out_ret = torch.cat(
             [grad_output_mat[: m_splits[i]] for i, grad_output_mat in enumerate(inputmats)], dim=0
@@ -42,6 +43,7 @@ class _Fp8Unpadding(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output: torch.Tensor):
+        # pylint: disable=missing-function-docstring
         grad_input = None
         if ctx.requires_dgrad:
             grad_output = grad_output.contiguous()
...
@@ -70,6 +70,7 @@ class _GroupedLinear(torch.autograd.Function):
         weights_fp8: List[Union[Float8Tensor, None]],
         *weights_and_biases: Union[Float8Tensor, torch.Tensor, None],
     ) -> torch.Tensor:
+        # pylint: disable=missing-function-docstring
         num_gemms = len(m_splits)
         weights = weights_and_biases[:num_gemms]
         biases = weights_and_biases[num_gemms:]
@@ -268,6 +269,7 @@ class _GroupedLinear(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
+        # pylint: disable=missing-function-docstring
         with torch.cuda.nvtx.range("_GroupedLinear_backward"):
             (
                 inputmat_scale_inv,
@@ -641,7 +643,7 @@ class GroupedLinear(TransformerEngineBaseModule):
         if self.primary_weights_in_fp8:
             self.init_fp8_metadata(num_gemms=self.num_gemms)
-        self.reset_parameters(defer_init=(device == "meta"))
+        self.reset_parameters(defer_init=device == "meta")
         # For RPL, bias has to be added after TP collectives
         # So it cannot be fused with the GEMM
...