Unverified Commit 144e4888 authored by Jan Bielak, committed by GitHub

Fix some Pylance errors (#259)



* Ignore IDE files
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Fix typing errors
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Ignore devcontainer files
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Avoid import from private module
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Apply @timmoon10's suggestions
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

---------
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
parent 80825fde
@@ -16,6 +16,11 @@ build/
 __pycache__
 .ycm_extra_conf.py
 .vimrc
+.vs
+.vscode
+.cache
+.hypothesis
+.devcontainer.json
 tests/cpp/build/
 docs/_build
 .ipynb_checkpoints
...
@@ -1501,7 +1501,7 @@ class SoftmaxPrimitive(BasePrimitive):
         pow2 = 1 << (k_seqlen - 1).bit_length()
         warp_size = pow2 if pow2 < threads_per_warp else threads_per_warp
         batches_per_warp = 2 if pow2 <= 128 else 1
-        warps_per_block = threads_per_block / warp_size
+        warps_per_block = threads_per_block // warp_size
         batches_per_block = warps_per_block * batches_per_warp
         return batches_per_block
...
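This JAX helper and the PyTorch `FusedScaleMaskSoftmax` hunk further down get the same one-character fix: true division (`/`) produces a `float` even when the operands divide evenly, which Pylance flags because the result feeds integer block-size arithmetic. A minimal sketch of the difference, using illustrative values rather than the actual kernel configuration:

```python
# Illustrative values only -- not taken from the kernels.
THREADS_PER_BLOCK = 128
warp_size = 32

# True division always yields a float, so downstream integer
# arithmetic (and Pylance's int annotations) break:
assert THREADS_PER_BLOCK / warp_size == 4.0
assert isinstance(THREADS_PER_BLOCK / warp_size, float)

# Floor division keeps the computation in int, matching the intent:
assert THREADS_PER_BLOCK // warp_size == 4
assert isinstance(THREADS_PER_BLOCK // warp_size, int)
```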
@@ -281,7 +281,7 @@ class FP8Helper:
         return jnp.vstack([fp8_max_per_gemm] * num_of_gemm)
     @staticmethod
-    def get_fp8_meta_indices(gemm_idx: int) -> Tuple[int]:
+    def get_fp8_meta_indices(gemm_idx: int) -> Tuple[int, int, int]:
         """
         Obtain the index about FP8 metas by the given GEMM index.
         """
@@ -453,7 +453,7 @@ def get_delayed_scaling():
     """
     amax_compute_algo = "max" if FP8Helper.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX \
                         else "most_recent"
-    return DelayedScaling(margin=FP8Helper.MARGIN,
+    return DelayedScaling(margin=int(FP8Helper.MARGIN),
                           interval=FP8Helper.UPDATE_FP8META_INTERVAL,
                           fp8_format=FP8Helper.FP8_FORMAT,
                           amax_history_len=FP8Helper.AMAX_HISTORY_LEN,
...
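The `int(...)` cast silences a type mismatch rather than changing behavior: `DelayedScaling.margin` is declared as an `int`, so passing a class attribute with a wider inferred numeric type makes Pylance complain at the call site. A contrived sketch of the pattern, with hypothetical names:

```python
from dataclasses import dataclass

@dataclass
class Recipe:
    margin: int = 0  # the parameter is declared as int

MARGIN = 0.0  # hypothetical attribute inferred as float

# The explicit cast narrows float -> int for the type checker.
recipe = Recipe(margin=int(MARGIN))
```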
@@ -4,6 +4,7 @@
 """Enums for e2e transformer"""
 import torch
+import torch.distributed
 import transformer_engine_extensions as tex
@@ -29,4 +30,4 @@ LayerTypes = ("encoder", "decoder")
 GemmParallelModes = ("row", "column", None)
-dist_group_type = torch._C._distributed_c10d.ProcessGroup
+dist_group_type = torch.distributed.ProcessGroup
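Reaching into `torch._C._distributed_c10d` worked, but private modules carry no stability guarantees and type checkers often cannot resolve them. The public `torch.distributed.ProcessGroup` alias refers to the same class and is part of the supported API, so it can be used anywhere a process-group annotation is needed. A sketch of such usage; the helper itself is hypothetical, not part of this commit:

```python
from typing import Optional
import torch
import torch.distributed

def maybe_all_reduce(
    t: torch.Tensor,
    group: Optional[torch.distributed.ProcessGroup] = None,
) -> torch.Tensor:
    # All-reduce across the given process group, if one is running.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.all_reduce(t, group=group)
    return t
```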
@@ -7,7 +7,7 @@ import os
 import pickle
 import warnings
 from abc import ABC, abstractmethod
-from typing import Union, Optional, Tuple, Dict, Any, List
+from typing import Generator, Union, Optional, Tuple, Dict, Any, List
 from functools import partial
 from contextlib import contextmanager
@@ -86,7 +86,7 @@ def _prepare_backward(
     tp_group: dist_group_type,
     tp_size: int,
     name: str = ""
-) -> None:
+) -> Generator[None, None, None]:
     """Checks and prep for BWD."""
     if fp8:
         global _amax_reduce_handle_bwd
@@ -542,7 +542,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         inp: torch.Tensor,
         is_first_microbatch: Union[bool, None],
         num_gemms: int = 1,
-    ) -> None:
+    ) -> Generator[torch.Tensor, None, None]:
         """Checks and prep for FWD.
         The context manager is needed because there isn't a way for a module to know
         if it's the last FP8 module in the forward autocast. It is useful
...
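`_prepare_backward` and the forward-prep method above are `@contextmanager` generator functions, so the annotated return type should describe the generator the decorator consumes, `Generator[YieldType, SendType, ReturnType]`, rather than `None`. A self-contained sketch of the pattern:

```python
from contextlib import contextmanager
from typing import Generator

@contextmanager
def prepared(name: str = "") -> Generator[None, None, None]:
    # Setup runs before the yield; the with-block body runs at the yield.
    print(f"enter {name}")
    try:
        yield
    finally:
        # Cleanup runs when the with-block exits, even on error.
        print(f"exit {name}")

with prepared("demo"):
    pass  # prints "enter demo", then "exit demo"
```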
@@ -342,6 +342,6 @@ class FusedScaleMaskSoftmax(nn.Module):
         pow2 = 1 << (key_seq_len - 1).bit_length()
         warp_size = pow2 if pow2 < THREADS_PER_WARP else THREADS_PER_WARP
         batches_per_warp = 2 if pow2 <= 128 else 1
-        warps_per_block = THREADS_PER_BLOCK / warp_size
+        warps_per_block = THREADS_PER_BLOCK // warp_size
         batches_per_block = warps_per_block * batches_per_warp
         return batches_per_block
...
@@ -4,7 +4,7 @@
 """FP8 utilies for TransformerEngine"""
 from contextlib import contextmanager
-from typing import Optional, Dict, Any
+from typing import Generator, Optional, Dict, Any
 import tensorflow as tf
 import transformer_engine_tensorflow as tex
@@ -69,7 +69,7 @@ def get_default_fp8_recipe():
 def fp8_autocast(
     enabled: bool = False,
     fp8_recipe: Optional[DelayedScaling] = None,
-) -> None:
+) -> Generator[None, None, None]:
     """
     Context manager for FP8 usage.
...