Unverified Commit 9416519d authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Apply formatting (#929)



* Apply formatting
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Apply formatting
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent d99142a0
...@@ -12,7 +12,7 @@ from .constants import RecomputeFunctionNames ...@@ -12,7 +12,7 @@ from .constants import RecomputeFunctionNames
from .fp8 import get_global_fp8_state from .fp8 import get_global_fp8_state
__all__ = ['recompute'] __all__ = ["recompute"]
_DISABLE_RECOMPUTE = int(os.getenv("NVTE_DISABLE_RECOMPUTE", "0")) _DISABLE_RECOMPUTE = int(os.getenv("NVTE_DISABLE_RECOMPUTE", "0"))
...@@ -48,8 +48,9 @@ def recompute(function, *args, **kwargs): ...@@ -48,8 +48,9 @@ def recompute(function, *args, **kwargs):
kwargs : dict kwargs : dict
dictionary of string keys for keyword arguments to :attr:`function`. dictionary of string keys for keyword arguments to :attr:`function`.
""" """
assert not _DISABLE_RECOMPUTE, "Recompute is disabled. " \ assert (
f"Got NVTE_DISABLE_RECOMPUTE={_DISABLE_RECOMPUTE}." not _DISABLE_RECOMPUTE
), f"Recompute is disabled. Got NVTE_DISABLE_RECOMPUTE={_DISABLE_RECOMPUTE}."
global_fp8_state = get_global_fp8_state() global_fp8_state = get_global_fp8_state()
......
...@@ -27,7 +27,10 @@ if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or os.path.isdir(build_tools_ ...@@ -27,7 +27,10 @@ if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or os.path.isdir(build_tools_
from build_tools.build_ext import get_build_ext # pylint: disable=wrong-import-position from build_tools.build_ext import get_build_ext # pylint: disable=wrong-import-position
from build_tools.utils import package_files, copy_common_headers # pylint: disable=wrong-import-position from build_tools.utils import (
package_files,
copy_common_headers,
) # pylint: disable=wrong-import-position
from build_tools.te_version import te_version # pylint: disable=wrong-import-position from build_tools.te_version import te_version # pylint: disable=wrong-import-position
from build_tools.paddle import setup_paddle_extension # pylint: disable=wrong-import-position from build_tools.paddle import setup_paddle_extension # pylint: disable=wrong-import-position
...@@ -38,12 +41,12 @@ CMakeBuildExtension = get_build_ext(BuildExtension) ...@@ -38,12 +41,12 @@ CMakeBuildExtension = get_build_ext(BuildExtension)
if __name__ == "__main__": if __name__ == "__main__":
# Extensions # Extensions
common_headers_dir = "common_headers" common_headers_dir = "common_headers"
copy_common_headers( copy_common_headers(current_file_path.parent, str(current_file_path / common_headers_dir))
current_file_path.parent,
str(current_file_path / common_headers_dir))
ext_modules = [ ext_modules = [
setup_paddle_extension( setup_paddle_extension(
"csrc", current_file_path / "csrc", current_file_path / common_headers_dir)] "csrc", current_file_path / "csrc", current_file_path / common_headers_dir
)
]
# Configure package # Configure package
setuptools.setup( setuptools.setup(
...@@ -56,9 +59,11 @@ if __name__ == "__main__": ...@@ -56,9 +59,11 @@ if __name__ == "__main__":
install_requires=["paddlepaddle-gpu"], install_requires=["paddlepaddle-gpu"],
tests_require=["numpy"], tests_require=["numpy"],
include_package_data=True, include_package_data=True,
package_data={"csrc": package_files("csrc"), package_data={
"csrc": package_files("csrc"),
common_headers_dir: package_files(common_headers_dir), common_headers_dir: package_files(common_headers_dir),
"build_tools": package_files("build_tools")}, "build_tools": package_files("build_tools"),
},
) )
if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")): if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")):
shutil.rmtree(common_headers_dir) shutil.rmtree(common_headers_dir)
...@@ -10,14 +10,16 @@ import paddle.nn.functional as F ...@@ -10,14 +10,16 @@ import paddle.nn.functional as F
from .cpp_extensions import swiglu_pd from .cpp_extensions import swiglu_pd
def cast_if_needed(tensor: Union[paddle.Tensor, None], def cast_if_needed(
dtype: paddle.dtype) -> Union[paddle.Tensor, None]: tensor: Union[paddle.Tensor, None], dtype: paddle.dtype
) -> Union[paddle.Tensor, None]:
"""Cast tensor to dtype""" """Cast tensor to dtype"""
return tensor if tensor is None or tensor.dtype == dtype else paddle.cast(tensor, dtype) return tensor if tensor is None or tensor.dtype == dtype else paddle.cast(tensor, dtype)
def cast_if_needed_inplace(tensor: Union[paddle.Tensor, None], def cast_if_needed_inplace(
dtype: paddle.dtype) -> Union[paddle.Tensor, None]: tensor: Union[paddle.Tensor, None], dtype: paddle.dtype
) -> Union[paddle.Tensor, None]:
"""Cast tensor to dtype (inplace), not to be used on layer inputs""" """Cast tensor to dtype (inplace), not to be used on layer inputs"""
return tensor if tensor is None or tensor.dtype == dtype else tensor._to(dtype=dtype) return tensor if tensor is None or tensor.dtype == dtype else tensor._to(dtype=dtype)
...@@ -36,7 +38,8 @@ def assert_dim_for_fp8_forward_exec(tensor: paddle.Tensor) -> None: ...@@ -36,7 +38,8 @@ def assert_dim_for_fp8_forward_exec(tensor: paddle.Tensor) -> None:
# single tensor check so it's clear which tensor is triggering the assertion # single tensor check so it's clear which tensor is triggering the assertion
assert check_dim_for_fp8_forward_exec(tensor), ( assert check_dim_for_fp8_forward_exec(tensor), (
"Tensor dimensions are not compatible for FP8 execution: " "Tensor dimensions are not compatible for FP8 execution: "
f"({tensor.shape[0]} % 8 != 0, {tensor.shape[1]} % 16 != 0)") f"({tensor.shape[0]} % 8 != 0, {tensor.shape[1]} % 16 != 0)"
)
def get_bias_dtype(activation_dtype: paddle.dtype): def get_bias_dtype(activation_dtype: paddle.dtype):
...@@ -47,18 +50,19 @@ def get_bias_dtype(activation_dtype: paddle.dtype): ...@@ -47,18 +50,19 @@ def get_bias_dtype(activation_dtype: paddle.dtype):
def get_paddle_act_func(activation): def get_paddle_act_func(activation):
"""Get paddle activation function""" """Get paddle activation function"""
funcs = { funcs = {
'gelu': F.gelu, "gelu": F.gelu,
'relu': F.relu, "relu": F.relu,
'silu': F.silu, "silu": F.silu,
'swiglu': swiglu_pd, "swiglu": swiglu_pd,
} }
if activation not in funcs: if activation not in funcs:
raise "Activation type " + activation + " is not supported." raise "Activation type " + activation + " is not supported."
return funcs[activation] return funcs[activation]
def attention_mask_func(attention_scores: paddle.Tensor, def attention_mask_func(
attention_mask: paddle.Tensor) -> paddle.Tensor: attention_scores: paddle.Tensor, attention_mask: paddle.Tensor
) -> paddle.Tensor:
"""Get attention mask""" """Get attention mask"""
def _masked_fill(x, mask, value): def _masked_fill(x, mask, value):
...@@ -71,14 +75,14 @@ def attention_mask_func(attention_scores: paddle.Tensor, ...@@ -71,14 +75,14 @@ def attention_mask_func(attention_scores: paddle.Tensor,
def mask_to_cu_seqlens(mask: paddle.Tensor, need_kv: bool = False) -> paddle.Tensor: def mask_to_cu_seqlens(mask: paddle.Tensor, need_kv: bool = False) -> paddle.Tensor:
"""Convert mask to cu_seqlens""" """Convert mask to cu_seqlens"""
assert 'bool' in str(mask.dtype), "mask must be bool dtype" assert "bool" in str(mask.dtype), "mask must be bool dtype"
assert len(mask.shape) == 4 and mask.shape[1] == 1, "mask must be [b, 1, s_q, s_kv]" assert len(mask.shape) == 4 and mask.shape[1] == 1, "mask must be [b, 1, s_q, s_kv]"
q_actual_seqlens = paddle.sum(mask[:, :, :, 0].logical_not(), axis=(-1, -2), dtype='int32') q_actual_seqlens = paddle.sum(mask[:, :, :, 0].logical_not(), axis=(-1, -2), dtype="int32")
q_cu_seqlens = paddle.cumsum(q_actual_seqlens) q_cu_seqlens = paddle.cumsum(q_actual_seqlens)
q_cu_seqlens = paddle.concat([paddle.zeros([1], dtype=paddle.int32), q_cu_seqlens], axis=0) q_cu_seqlens = paddle.concat([paddle.zeros([1], dtype=paddle.int32), q_cu_seqlens], axis=0)
if not need_kv: if not need_kv:
return q_cu_seqlens, None return q_cu_seqlens, None
kv_actual_seqlens = paddle.sum(mask[:, :, 0, :].logical_not(), axis=(-1, -2), dtype='int32') kv_actual_seqlens = paddle.sum(mask[:, :, 0, :].logical_not(), axis=(-1, -2), dtype="int32")
kv_cu_seqlens = paddle.cumsum(kv_actual_seqlens) kv_cu_seqlens = paddle.cumsum(kv_actual_seqlens)
kv_cu_seqlens = paddle.concat([paddle.zeros([1], dtype=paddle.int32), kv_cu_seqlens], axis=0) kv_cu_seqlens = paddle.concat([paddle.zeros([1], dtype=paddle.int32), kv_cu_seqlens], axis=0)
return q_cu_seqlens, kv_cu_seqlens return q_cu_seqlens, kv_cu_seqlens
...@@ -87,7 +91,7 @@ def mask_to_cu_seqlens(mask: paddle.Tensor, need_kv: bool = False) -> paddle.Ten ...@@ -87,7 +91,7 @@ def mask_to_cu_seqlens(mask: paddle.Tensor, need_kv: bool = False) -> paddle.Ten
def divide(numerator: int, denominator: int) -> int: def divide(numerator: int, denominator: int) -> int:
"""Ensure that numerator is divisible by the denominator and return """Ensure that numerator is divisible by the denominator and return
the division value.""" the division value."""
assert (numerator % denominator == 0), f"{numerator} is not divisible by {denominator}" assert numerator % denominator == 0, f"{numerator} is not divisible by {denominator}"
return numerator // denominator return numerator // denominator
...@@ -110,8 +114,9 @@ def save_for_backward_allow_none(ctx, *args) -> None: ...@@ -110,8 +114,9 @@ def save_for_backward_allow_none(ctx, *args) -> None:
def saved_tensor_allow_none(ctx) -> Tuple[Optional[paddle.Tensor]]: def saved_tensor_allow_none(ctx) -> Tuple[Optional[paddle.Tensor]]:
"""Used with `save_for_backward_allow_none` in pair. Get saved tensors from ctx.""" """Used with `save_for_backward_allow_none` in pair. Get saved tensors from ctx."""
assert hasattr(ctx, '_indices_mapping'), "`saved_tensor_allow_none` must be used " \ assert hasattr(
"with `save_for_backward_allow_none` in pair." ctx, "_indices_mapping"
), "`saved_tensor_allow_none` must be used with `save_for_backward_allow_none` in pair."
indices_mapping = ctx._indices_mapping indices_mapping = ctx._indices_mapping
outputs = [] outputs = []
...@@ -132,8 +137,12 @@ def clear_tensor_data(*tensors: Tuple[Optional[paddle.Tensor], ...]) -> None: ...@@ -132,8 +137,12 @@ def clear_tensor_data(*tensors: Tuple[Optional[paddle.Tensor], ...]) -> None:
""" """
def can_free(t): def can_free(t):
return (t is not None and isinstance(t, paddle.Tensor) and t._is_initialized() return (
and t.inplace_version == 0) t is not None
and isinstance(t, paddle.Tensor)
and t._is_initialized()
and t.inplace_version == 0
)
for t in tensors: for t in tensors:
if can_free(t): if can_free(t):
......
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -29,9 +29,22 @@ AttnTypes = ("self", "cross") ...@@ -29,9 +29,22 @@ AttnTypes = ("self", "cross")
AttnBiasTypes = ("pre_scale_bias", "post_scale_bias", "no_bias", "alibi") AttnBiasTypes = ("pre_scale_bias", "post_scale_bias", "no_bias", "alibi")
QKVLayouts = ( QKVLayouts = (
"sb3hd", "sbh3d", "sbhd_sb2hd", "sbhd_sbh2d", "sbhd_sbhd_sbhd", "sb3hd",
"bs3hd", "bsh3d", "bshd_bs2hd", "bshd_bsh2d", "bshd_bshd_bshd", "sbh3d",
"t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd") "sbhd_sb2hd",
"sbhd_sbh2d",
"sbhd_sbhd_sbhd",
"bs3hd",
"bsh3d",
"bshd_bs2hd",
"bshd_bsh2d",
"bshd_bshd_bshd",
"t3hd",
"th3d",
"thd_t2hd",
"thd_th2d",
"thd_thd_thd",
)
LayerTypes = ("encoder", "decoder") LayerTypes = ("encoder", "decoder")
......
...@@ -8,7 +8,7 @@ import torch ...@@ -8,7 +8,7 @@ import torch
import transformer_engine_torch as tex import transformer_engine_torch as tex
__all__ = ['gelu', 'relu', 'reglu', 'geglu', 'swiglu', 'qgelu', 'srelu'] __all__ = ["gelu", "relu", "reglu", "geglu", "swiglu", "qgelu", "srelu"]
def gelu( def gelu(
...@@ -167,6 +167,7 @@ def qgelu( ...@@ -167,6 +167,7 @@ def qgelu(
otype, otype,
) )
def srelu( def srelu(
inp: torch.Tensor, inp: torch.Tensor,
fp8_meta_tensor: tex.FP8TensorMeta, fp8_meta_tensor: tex.FP8TensorMeta,
......
...@@ -8,8 +8,7 @@ import torch ...@@ -8,8 +8,7 @@ import torch
import transformer_engine_torch as tex import transformer_engine_torch as tex
__all__ = ['cast_to_fp8', __all__ = ["cast_to_fp8", "cast_from_fp8"]
'cast_from_fp8']
def cast_to_fp8( def cast_to_fp8(
...@@ -30,7 +29,7 @@ def cast_to_fp8( ...@@ -30,7 +29,7 @@ def cast_to_fp8(
fp8_meta_tensor.amax_history, fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv, fp8_meta_tensor.scale_inv,
fp8_tensor, fp8_tensor,
otype otype,
) )
return None return None
...@@ -43,6 +42,7 @@ def cast_to_fp8( ...@@ -43,6 +42,7 @@ def cast_to_fp8(
otype, otype,
) )
def cast_from_fp8( def cast_from_fp8(
inp: torch.Tensor, inp: torch.Tensor,
fp8_meta_tensor: tex.FP8TensorMeta, fp8_meta_tensor: tex.FP8TensorMeta,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment