Unverified commit 07bf4acf authored by Tim Moon, committed by GitHub

[PyTorch] Remove unnecessary Pylint overrides (#794)



* Remove unnecessary Pylint overrides
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fixes to lint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent aaf93548
@@ -3,6 +3,8 @@
# See LICENSE for license information.
"""Transformer Engine bindings for pyTorch"""
import torch
from .module import LayerNormLinear
from .module import Linear
from .module import LayerNormMLP
@@ -32,8 +34,8 @@ from .te_onnx_extensions import (
onnx_rmsnorm_fwd,
onnx_rmsnorm_fwd_fp8
)
try:
import torch
torch._dynamo.config.error_on_nested_jit_trace = False
except: # pylint: disable=bare-except
pass
except AttributeError:
pass # error_on_nested_jit_trace was added in PyTorch 2.2.0
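The hunk above narrows a bare `except:` to `except AttributeError`. A bare `except:` catches everything, including `KeyboardInterrupt` and `SystemExit`, so unrelated failures would be silently swallowed; catching `AttributeError` only covers the case where an older PyTorch build does not define the config knob. A minimal, hypothetical sketch of the same probing pattern (the class and function names below are illustrative, not code from this repository):

```python
# Hypothetical sketch of probing an optional config attribute. Catching
# AttributeError keeps unrelated failures (Ctrl+C, typos in our own code)
# visible instead of silently swallowing them the way a bare ``except:`` would.

class _OldConfig:
    """Stand-in for an older release whose config lacks the flag."""
    __slots__ = ("other_flag",)  # no error_on_nested_jit_trace here

class _NewConfig:
    """Stand-in for a newer release that defines the flag."""
    __slots__ = ("error_on_nested_jit_trace",)

def try_disable_nested_trace_error(config) -> bool:
    """Disable the flag if this config version exposes it."""
    try:
        config.error_on_nested_jit_trace = False
        return True
    except AttributeError:
        return False  # older version: nothing to disable

print(try_disable_nested_trace_error(_OldConfig()))  # False
print(try_disable_nested_trace_error(_NewConfig()))  # True
```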
@@ -3,8 +3,10 @@
# See LICENSE for license information.
"""Functionality for CPU offloading of tensors saved for backward pass."""
from typing import Any
from __future__ import annotations
from contextlib import nullcontext
from typing import Any, Dict, Optional
import torch
from .float8_tensor import Float8Tensor
@@ -99,10 +101,17 @@ class CpuOffloadHookWithOffloadHandler(CpuOffloadSavedTensorHook):
and `tensor_pop` interface. How the offload-handler manages the offloading, recovering
or prefetching timing is transparent to this hook.
"""
def __init__(self, offload_handler, handler_extra_kwargs={}, debug=False) -> None: # pylint: disable=dangerous-default-value
self.debug = debug
self.offload_handler = offload_handler
self.handler_extra_kwargs = handler_extra_kwargs
def __init__(
self,
offload_handler: OffloadHandler,
handler_extra_kwargs: Optional[Dict[str,Any]] = None,
debug: bool = False,
) -> None:
if handler_extra_kwargs is None:
handler_extra_kwargs = {}
self.debug: bool = debug
self.offload_handler: OffloadHandler = offload_handler
self.handler_extra_kwargs: Dict[str,Any] = handler_extra_kwargs
super().__init__()
def on_save_for_backward(self, tensor: torch.Tensor) -> Any:
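The `dangerous-default-value` override removed above existed because a mutable default such as `handler_extra_kwargs={}` is evaluated once at function-definition time and shared by every subsequent call. The rewritten `__init__` uses a `None` sentinel and builds a fresh dict per call instead. A small, hypothetical demonstration of the pitfall and the fix (illustrative names, not library code):

```python
# Hypothetical demonstration of why a mutable default argument is dangerous.

def register_bad(name, extra_kwargs={}):  # single dict shared by all calls
    extra_kwargs[name] = True
    return extra_kwargs

def register_good(name, extra_kwargs=None):  # None sentinel: fresh dict per call
    if extra_kwargs is None:
        extra_kwargs = {}
    extra_kwargs[name] = True
    return extra_kwargs

print(register_bad("a"))   # {'a': True}
print(register_bad("b"))   # {'a': True, 'b': True}  <- state leaked from the first call
print(register_good("a"))  # {'a': True}
print(register_good("b"))  # {'b': True}             <- independent dict each call
```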
@@ -290,10 +299,10 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
allocate_new_buf = True
else:
tensor_buf = id_buf_map[tensor_id]
if not (tensor_buf.size() == tensor.size() and tensor_buf.dtype == tensor.dtype): # pylint: disable=simplifiable-if-statement
allocate_new_buf = True
else:
allocate_new_buf = False # in this case, reuse the old buffer
allocate_new_buf = (
tensor_buf.size() != tensor.size()
or tensor_buf.dtype != tensor.dtype
)
if allocate_new_buf:
# supposed to only execute once
@@ -491,7 +500,7 @@ def get_cpu_offload_context(
def tensor_need_offloading_checker_weights(tensor):
return hasattr(tensor, "weight_offloading")
def tensor_need_offloading_checker_all(tensor): # pylint: disable=unused-argument
def tensor_need_offloading_checker_all(tensor):
return (hasattr(tensor,"activation_offloading") or hasattr(tensor, "weight_offloading"))
if offload_activations and offload_weights:
@@ -730,8 +730,6 @@ class Float8Tensor(torch.Tensor):
return None
# Slice op
# TODO Consider additional bookkeeping so we invalidate caches # pylint: disable=fixme
# if these slices are modified in-place
if func == aten.slice.Tensor:
tensor = args[0]
data = tensor._data
@@ -502,12 +502,12 @@ def fp8_model_init(enabled: bool = True) -> None:
This functionality is *EXPERIMENTAL*.
"""
_fp8_parameters = FP8GlobalStateManager.FP8_PARAMETERS
FP8GlobalStateManager.FP8_PARAMETERS = enabled
try:
_fp8_parameters = FP8GlobalStateManager.FP8_PARAMETERS
FP8GlobalStateManager.FP8_PARAMETERS = enabled
yield
finally:
FP8GlobalStateManager.FP8_PARAMETERS = _fp8_parameters # pylint: disable=used-before-assignment
FP8GlobalStateManager.FP8_PARAMETERS = _fp8_parameters
@contextmanager
@@ -555,16 +555,16 @@ def fp8_autocast(
distributed group over which amaxes for the fp8 tensors
are reduced at the end of each training step.
"""
fp8_state = FP8GlobalStateManager.get_fp8_autocast_state()
FP8GlobalStateManager.fp8_autocast_enter(enabled=enabled,
calibrating=calibrating,
fp8_recipe=fp8_recipe,
fp8_group=fp8_group,
_graph=_graph)
try:
fp8_state = FP8GlobalStateManager.get_fp8_autocast_state()
FP8GlobalStateManager.fp8_autocast_enter(enabled=enabled,
calibrating=calibrating,
fp8_recipe=fp8_recipe,
fp8_group=fp8_group,
_graph=_graph)
yield
finally:
FP8GlobalStateManager.set_fp8_autocast_state(fp8_state) # pylint: disable=used-before-assignment
FP8GlobalStateManager.set_fp8_autocast_state(fp8_state)
FP8GlobalStateManager.fp8_autocast_exit(enabled, _graph=_graph)
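Both context managers above read and install the global state *before* entering `try`, so the name restored in `finally` is guaranteed to be bound even if the body raises immediately; that is what makes the `used-before-assignment` overrides unnecessary. A minimal, hypothetical sketch of the save/restore pattern (not the library's actual FP8 state manager):

```python
# Hypothetical sketch of a save/restore context manager. Capturing the old
# value before entering ``try`` guarantees that ``finally`` never references
# an unbound name, even if the body raises right away.

from contextlib import contextmanager

_STATE = {"enabled": False}  # stand-in for a module-level flag

@contextmanager
def flag_enabled(enabled: bool = True):
    old_value = _STATE["enabled"]   # setup happens before try ...
    _STATE["enabled"] = enabled
    try:
        yield                       # ... so finally can always restore it
    finally:
        _STATE["enabled"] = old_value

with flag_enabled():
    assert _STATE["enabled"] is True
assert _STATE["enabled"] is False  # restored on exit
```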
@@ -703,7 +703,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
out=grad_output_c,
)
else:
grad_output_c = grad_ouput_mat # pylint: disable=undefined-variable
grad_output_c = grad_output_mat
if not ctx.ub_overlap_ag:
grad_output_c, _ = gather_along_first_dim(grad_output_c, ctx.tp_group)
if not isinstance(grad_output_c, Float8Tensor):
@@ -336,19 +336,20 @@ class FusedScaleMaskSoftmax(nn.Module):
return self.forward_fused_softmax(inp, mask, scale)
return self.forward_torch_softmax(inp, mask, scale)
def is_kernel_available(self, mask: torch.Tensor, b: int, np: int, sq: int, sk: int) -> bool:
def is_kernel_available(self, mask: torch.Tensor, b: int, np: int, sq: int, sk: int) -> bool: # pylint: disable=too-many-return-statements
"""Check FusedScaleMaskSoftmax kernel availability based on size"""
attn_batches = b * np
if ( # pylint: disable=too-many-boolean-expressions
not self.scaled_masked_softmax_fusion # user doesn't want to fuse
or not self.input_in_float16 # input must be fp16
or sk < 16
or sk > 16384 # sk must be 16 ~ 16384
or sk % 8 != 0 # sk must be divisor of 8
or self.attn_mask_type == "arbitrary" # Custom masks not supported
):
return False
if not self.scaled_masked_softmax_fusion:
return False # user doesn't want to fuse
if not self.input_in_float16:
return False # input must be fp16
if not 16 <= sk <= 16384:
return False # sk must be 16 ~ 16384
if sk % 8 != 0:
return False # sk must be divisible by 8
if self.attn_mask_type == "arbitrary":
return False # Custom masks not supported
if self.attn_mask_type == "causal": # unfused causal softmax kernel
return True