"...models/git@developer.sourcefind.cn:jerrrrry/infinilm.git" did not exist on "fc97bbd8ef5fd9ecc4de04186939c89c2a957efa"
Unverified Commit 07bf4acf authored by Tim Moon, committed by GitHub

[PyTorch] Remove unnecessary Pylint overrides (#794)



* Remove unnecessary Pylint overrides
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fixes to lint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent aaf93548
@@ -3,6 +3,8 @@
 # See LICENSE for license information.
 """Transformer Engine bindings for pyTorch"""
+import torch
 from .module import LayerNormLinear
 from .module import Linear
 from .module import LayerNormMLP
@@ -32,8 +34,8 @@ from .te_onnx_extensions import (
     onnx_rmsnorm_fwd,
     onnx_rmsnorm_fwd_fp8
 )
 try:
-    import torch
     torch._dynamo.config.error_on_nested_jit_trace = False
-except: # pylint: disable=bare-except
-    pass
+except AttributeError:
+    pass # error_on_nested_jit_trace was added in PyTorch 2.2.0
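Aside: catching only AttributeError keeps this probe for an optional config flag from swallowing unrelated failures the way the old bare except did. A minimal sketch of the same pattern, with try_set_config_flag and its arguments as illustrative names:

def try_set_config_flag(config: object, flag: str, value: bool) -> bool:
    """Set an optional flag, returning False when this version lacks it."""
    try:
        getattr(config, flag)    # raises AttributeError if the flag is absent
    except AttributeError:
        return False             # older version: nothing to set
    setattr(config, flag, value)
    return True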
@@ -3,8 +3,10 @@
 # See LICENSE for license information.
 """Functionality for CPU offloading of tensors saved for backward pass."""
-from typing import Any
+from __future__ import annotations
 from contextlib import nullcontext
+from typing import Any, Dict, Optional
 import torch
 from .float8_tensor import Float8Tensor
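For context, the new from __future__ import annotations makes every annotation a lazily evaluated string, which is what lets the hook below annotate offload_handler: OffloadHandler before OffloadHandler is defined in the module. A self-contained sketch of the mechanism (class names are illustrative):

from __future__ import annotations

class Hook:
    # "Handler" is stored as a string, so this forward reference is legal
    # even though Handler is defined later in the file.
    def __init__(self, handler: Handler) -> None:
        self.handler = handler

class Handler:
    pass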
@@ -99,10 +101,17 @@ class CpuOffloadHookWithOffloadHandler(CpuOffloadSavedTensorHook):
     and `tensor_pop` interface. How the offload-handler manages the offloading, recovering
     or prefetching timing is transparent to this hook.
     """
-    def __init__(self, offload_handler, handler_extra_kwargs={}, debug=False) -> None: # pylint: disable=dangerous-default-value
-        self.debug = debug
-        self.offload_handler = offload_handler
-        self.handler_extra_kwargs = handler_extra_kwargs
+    def __init__(
+        self,
+        offload_handler: OffloadHandler,
+        handler_extra_kwargs: Optional[Dict[str,Any]] = None,
+        debug: bool = False,
+    ) -> None:
+        if handler_extra_kwargs is None:
+            handler_extra_kwargs = {}
+        self.debug: bool = debug
+        self.offload_handler: OffloadHandler = offload_handler
+        self.handler_extra_kwargs: Dict[str,Any] = handler_extra_kwargs
         super().__init__()

     def on_save_for_backward(self, tensor: torch.Tensor) -> Any:
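The dangerous-default-value override retired here was papering over a real hazard: a mutable default is built once, at function definition, and shared by every call. A small demonstration of the failure mode and of the None sentinel fix adopted above (names are illustrative):

def shared(item, bucket=[]):      # one list reused across *all* calls
    bucket.append(item)
    return bucket

def fresh(item, bucket=None):     # the sentinel pattern used in the diff
    if bucket is None:
        bucket = []               # new list per call
    bucket.append(item)
    return bucket

assert shared(1) == [1]
assert shared(2) == [1, 2]        # state from the first call leaks in
assert fresh(1) == [1]
assert fresh(2) == [2]            # independent calls stay independent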
@@ -290,10 +299,10 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
             allocate_new_buf = True
         else:
             tensor_buf = id_buf_map[tensor_id]
-            if not (tensor_buf.size() == tensor.size() and tensor_buf.dtype == tensor.dtype): # pylint: disable=simplifiable-if-statement
-                allocate_new_buf = True
-            else:
-                allocate_new_buf = False # in this case, reuse the old buffer
+            allocate_new_buf = (
+                tensor_buf.size() != tensor.size()
+                or tensor_buf.dtype != tensor.dtype
+            )
         if allocate_new_buf:
             # supposed to only execute once
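simplifiable-if-statement flags exactly the shape removed here: an if/else whose only job is to store True or False. Assigning the (De Morgan-negated) condition directly is equivalent; a one-function sketch with illustrative names:

def needs_new_buffer(buf_size, buf_dtype, size, dtype) -> bool:
    # Same truth table as:
    #   if not (buf_size == size and buf_dtype == dtype): flag = True
    #   else: flag = False  (reuse the old buffer)
    return buf_size != size or buf_dtype != dtype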
@@ -491,7 +500,7 @@ def get_cpu_offload_context(
     def tensor_need_offloading_checker_weights(tensor):
         return hasattr(tensor, "weight_offloading")

-    def tensor_need_offloading_checker_all(tensor): # pylint: disable=unused-argument
+    def tensor_need_offloading_checker_all(tensor):
         return (hasattr(tensor,"activation_offloading") or hasattr(tensor, "weight_offloading"))

     if offload_activations and offload_weights:
...
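A hedged sketch of the tagging convention these checkers assume: offloading is opted into per tensor by attaching a marker attribute, and the checkers only test for its presence (the attribute names come from the diff; the driver code is illustrative):

import torch

t = torch.randn(4, 4)
t.activation_offloading = True    # marker attribute, tested via hasattr

def tensor_need_offloading_checker_all(tensor):
    return (hasattr(tensor, "activation_offloading")
            or hasattr(tensor, "weight_offloading"))

assert tensor_need_offloading_checker_all(t)
assert not tensor_need_offloading_checker_all(torch.zeros(1))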
@@ -730,8 +730,6 @@ class Float8Tensor(torch.Tensor):
             return None

         # Slice op
-        # TODO Consider additional bookkeeping so we invalidate caches # pylint: disable=fixme
-        # if these slices are modified in-place
         if func == aten.slice.Tensor:
             tensor = args[0]
             data = tensor._data
...
@@ -502,12 +502,12 @@ def fp8_model_init(enabled: bool = True) -> None:
     This functionality is *EXPERIMENTAL*.
     """
+    _fp8_parameters = FP8GlobalStateManager.FP8_PARAMETERS
+    FP8GlobalStateManager.FP8_PARAMETERS = enabled
     try:
-        _fp8_parameters = FP8GlobalStateManager.FP8_PARAMETERS
-        FP8GlobalStateManager.FP8_PARAMETERS = enabled
         yield
     finally:
-        FP8GlobalStateManager.FP8_PARAMETERS = _fp8_parameters # pylint: disable=used-before-assignment
+        FP8GlobalStateManager.FP8_PARAMETERS = _fp8_parameters

 @contextmanager
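The reordering in fp8_model_init (and in fp8_autocast below) follows a general rule for generator-based context managers: do fallible setup before the try, so the finally block never references a name that was never bound; that unbound-name hazard is precisely what used-before-assignment was warning about. A minimal self-contained sketch, with _STATE and override_enabled as illustrative names:

from contextlib import contextmanager

_STATE = {"enabled": False}

@contextmanager
def override_enabled(value: bool):
    saved = _STATE["enabled"]       # bound before try: finally can rely on it
    _STATE["enabled"] = value
    try:
        yield
    finally:
        _STATE["enabled"] = saved   # restored even if the body raises

with override_enabled(True):
    assert _STATE["enabled"] is True
assert _STATE["enabled"] is False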
@@ -555,16 +555,16 @@ def fp8_autocast(
                    distributed group over which amaxes for the fp8 tensors
                    are reduced at the end of each training step.
     """
+    fp8_state = FP8GlobalStateManager.get_fp8_autocast_state()
+    FP8GlobalStateManager.fp8_autocast_enter(enabled=enabled,
+                                             calibrating=calibrating,
+                                             fp8_recipe=fp8_recipe,
+                                             fp8_group=fp8_group,
+                                             _graph=_graph)
     try:
-        fp8_state = FP8GlobalStateManager.get_fp8_autocast_state()
-        FP8GlobalStateManager.fp8_autocast_enter(enabled=enabled,
-                                                 calibrating=calibrating,
-                                                 fp8_recipe=fp8_recipe,
-                                                 fp8_group=fp8_group,
-                                                 _graph=_graph)
         yield
     finally:
-        FP8GlobalStateManager.set_fp8_autocast_state(fp8_state) # pylint: disable=used-before-assignment
+        FP8GlobalStateManager.set_fp8_autocast_state(fp8_state)
         FP8GlobalStateManager.fp8_autocast_exit(enabled, _graph=_graph)
...
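For orientation, the manager above is the public fp8_autocast entry point. A usage sketch based only on the parameters visible in this hunk, with the recipe construction assumed from Transformer Engine's documented defaults:

import transformer_engine.pytorch as te
from transformer_engine.common import recipe

fp8_recipe = recipe.DelayedScaling()  # assumed default construction

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    # modules run their FP8 paths here; the try/finally above restores
    # autocast state on exit even if this body raises
    ...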
@@ -703,7 +703,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
                     out=grad_output_c,
                 )
             else:
-                grad_output_c = grad_ouput_mat # pylint: disable=undefined-variable
+                grad_output_c = grad_output_mat
             if not ctx.ub_overlap_ag:
                 grad_output_c, _ = gather_along_first_dim(grad_output_c, ctx.tp_group)
         if not isinstance(grad_output_c, Float8Tensor):
...
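This hunk is the one place where the removed override was hiding a live bug: grad_ouput_mat is a typo, and silencing undefined-variable traded a static warning for a runtime NameError. A toy reproduction:

def scale(grad_output_mat):
    # pylint would flag the misspelling below statically; disabling the
    # check defers the failure to a NameError at call time.
    return 2 * grad_ouput_mat  # pylint: disable=undefined-variable

try:
    scale(1.0)
except NameError as err:
    print(err)  # name 'grad_ouput_mat' is not defined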
@@ -336,19 +336,20 @@ class FusedScaleMaskSoftmax(nn.Module):
             return self.forward_fused_softmax(inp, mask, scale)
         return self.forward_torch_softmax(inp, mask, scale)

-    def is_kernel_available(self, mask: torch.Tensor, b: int, np: int, sq: int, sk: int) -> bool:
+    def is_kernel_available(self, mask: torch.Tensor, b: int, np: int, sq: int, sk: int) -> bool: # pylint: disable=too-many-return-statements
         """Check FusedScaleMaskSoftmax kernel availability based on size"""
         attn_batches = b * np

-        if ( # pylint: disable=too-many-boolean-expressions
-            not self.scaled_masked_softmax_fusion # user doesn't want to fuse
-            or not self.input_in_float16 # input must be fp16
-            or sk < 16
-            or sk > 16384 # sk must be 16 ~ 16384
-            or sk % 8 != 0 # sk must be divisor of 8
-            or self.attn_mask_type == "arbitrary" # Custom masks not supported
-        ):
-            return False
+        if not self.scaled_masked_softmax_fusion:
+            return False # user doesn't want to fuse
+        if not self.input_in_float16:
+            return False # input must be fp16
+        if not 16 < sk < 16384:
+            return False # sk must be 16 ~ 16384
+        if sk % 8 != 0:
+            return False # sk must be divisor of 8
+        if self.attn_mask_type == "arbitrary":
+            return False # Custom masks not supported

         if self.attn_mask_type == "causal": # unfused causal softmax kernel
             return True
...
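One subtlety of the rewrite worth noting: Python's chained comparison 16 < sk < 16384 means 16 < sk and sk < 16384, strict at both ends, whereas the original pair of tests (sk < 16, sk > 16384) accepted the endpoints. A quick boundary check, assuming the hunk reads as transcribed here:

def old_ok(sk: int) -> bool:
    return not (sk < 16 or sk > 16384)  # inclusive: accepts 16 and 16384

def new_ok(sk: int) -> bool:
    return 16 < sk < 16384              # strict: rejects both endpoints

for sk in (15, 16, 17, 16383, 16384):
    print(sk, old_ok(sk), new_ok(sk))   # differs at sk=16 and sk=16384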