"cmd/git@developer.sourcefind.cn:liming6/sshd-tool.git" did not exist on "abfc4893f9d2abe9908fb0e407e5f67de1d0fce6"
Unverified Commit 67b67432, authored by Charlene Yang, committed by GitHub

Update FE to 1.5.2 and miscellaneous fixes (#975)



* update FE to 1.5.2
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* enable unfused attn for cross attn
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* unify logging info
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* omit cudnn 9.1.1 and 9.2.1 due to bugs
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* set cu_seqlens_padded to cu_seqlens by default
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* replace variable name with ctx.variable
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Revert "enable unfused attn for cross attn"

This reverts commit bc49f14fca904217a711b4a86c45a4a739a17a14.
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* restrict cudnn version for fp8 tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove mha_fill for FP8
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Revert "remove mha_fill for FP8"

This reverts commit 83ffc44114dc6eb3d426d742b6c5a4d34805ec04.
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* lower cudnn version to >=9.2.1
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 7326af9d
-Subproject commit b740542818f36857acf7f9853f749bbad4118c65
+Subproject commit 98ca4e1941fe3263f128f74f10063a3ea35c7019
@@ -1270,7 +1270,7 @@ def _rmse(a, b):
     return math.sqrt((torch.pow((a - b), 2) / a.numel()).sum())

-@pytest.mark.skipif(get_cudnn_version() < (8, 9, 3), reason="cuDNN 8.9.3+ is required.")
+@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
 @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
 @pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
 @pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)
@@ -1445,7 +1445,7 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
     return out, param_names, tuple(x.grad for x in params)

-@pytest.mark.skipif(get_cudnn_version() < (8, 9, 3), reason="cuDNN 8.9.3+ is required.")
+@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
 @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
 @pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
 @pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)
@@ -1654,7 +1654,14 @@ models_v0 = ["fp8_1", "fp8_2", "fp8_5", "fp8_6"]
 models_v1 = ["fp8_3", "fp8_4", "fp8_7", "fp8_8"]

-@pytest.mark.skipif(get_cudnn_version() < (8, 9, 3), reason="cuDNN 8.9.3+ is required.")
+@pytest.mark.skipif(
+    (
+        get_cudnn_version() < (8, 9, 3)
+        if cudnn_frontend_version == 0
+        else get_cudnn_version() < (9, 2, 1)
+    ),
+    reason=f"""cuDNN {"8.9.3" if cudnn_frontend_version == 0 else "9.2.1"}+ is required.""",
+)
 @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
 @pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
 @pytest.mark.parametrize("dtype", param_types_fp8)
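The reworked skip marker keys the minimum cuDNN version on which frontend API the build uses. A minimal sketch of how that condition resolves, with hypothetical version values chosen only to exercise both branches:

```python
# Sketch of the skip condition above; the version tuples are hypothetical,
# picked only to show both branches of the frontend-version check.
def should_skip(cudnn_frontend_version: int, cudnn_version: tuple) -> bool:
    # Frontend v0 keeps the old 8.9.3 floor; v1 now requires 9.2.1+.
    required = (8, 9, 3) if cudnn_frontend_version == 0 else (9, 2, 1)
    return cudnn_version < required  # Python compares tuples lexicographically

print(should_skip(0, (9, 1, 1)))  # False: 9.1.1 clears the 8.9.3 floor
print(should_skip(1, (9, 1, 1)))  # True: frontend v1 needs 9.2.1+
```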
@@ -20,6 +20,7 @@ from transformer_engine.pytorch.utils import (
     init_method_normal,
     scaled_init_method_normal,
     is_bf16_compatible,
+    get_cudnn_version,
 )
 from transformer_engine.pytorch import (
     LayerNormLinear,
@@ -1004,6 +1005,7 @@ def test_sanity_fp8_gemm_with_unalignment(N, datatype):
 @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
 @pytest.mark.skipif(get_device_compute_capability() != (9, 0), reason="FP8 tests require Hopper.")
+@pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.")
 @pytest.mark.parametrize("model", ["large"])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 def test_sanity_attention_extra_state(model, dtype):
@@ -85,7 +85,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
         (((cudnn_runtime_version >= 8900) && (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) &&
           (max_seqlen_q == max_seqlen_kv) && (max_seqlen_q <= 512) && (head_dim == 64) &&
           (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)) ||
-         ((cudnn_runtime_version >= 90100) && (max_seqlen_q % 128 == 0) &&
+         ((cudnn_runtime_version >= 90201) && (max_seqlen_q % 128 == 0) &&
           (max_seqlen_kv % 128 == 0) && (head_dim == 128) &&
           ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) ||
            (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) &&
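The raised floor reflects the final ">=9.2.1" requirement from the commit message. As a readability aid, here is a hedged Python transliteration of the second clause of the gate above; the names mirror the C++ identifiers and the inputs are illustrative, not TE's actual API:

```python
# Hedged Python transliteration of the second clause of the C++ gate above;
# names mirror the diff, inputs are illustrative. cuDNN runtime versions are
# encoded as major*10000 + minor*100 + patch, so 90201 == cuDNN 9.2.1.
def fp8_fused_attn_eligible(cudnn_runtime_version, max_seqlen_q, max_seqlen_kv,
                            head_dim, qkv_format):
    return (
        cudnn_runtime_version >= 90201  # raised from 90100 (9.1.0) by this commit
        and max_seqlen_q % 128 == 0
        and max_seqlen_kv % 128 == 0
        and head_dim == 128
        and qkv_format in ("bshd", "sbhd")
    )

print(fp8_fused_attn_eligible(90201, 512, 512, 128, "bshd"))  # True
print(fp8_fused_attn_eligible(90100, 512, 512, 128, "bshd"))  # False after this change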
@@ -4179,6 +4179,7 @@ class FusedAttention(torch.nn.Module):
                 and cu_seqlens_q is not None
                 and cu_seqlens_kv is not None
             ), "max_seqlen_q/kv and cu_seqlens_q/kv can not be None when qkv_format is thd!"
         if cu_seqlens_q_padded is None or cu_seqlens_kv_padded is None:
             cu_seqlens_q_padded = cu_seqlens_q
             cu_seqlens_kv_padded = cu_seqlens_kv
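This hunk implements the "set cu_seqlens_padded to cu_seqlens by default" bullet: cu_seqlens_q/kv hold cumulative sequence lengths, while the *_padded variants additionally account for padding between sequences, and with no padded offsets supplied the sequences are assumed tightly packed. A standalone sketch of that defaulting, with illustrative tensors:

```python
import torch

# Cumulative sequence lengths for three packed sequences of lengths 4, 2, 6
# (illustrative values, not taken from the commit).
cu_seqlens_q = torch.tensor([0, 4, 6, 12], dtype=torch.int32)
cu_seqlens_q_padded = None  # caller provided no padding offsets

# The defaulting added above: without explicit padded offsets, assume the
# sequences are tightly packed, i.e. padded offsets equal unpadded ones.
if cu_seqlens_q_padded is None:
    cu_seqlens_q_padded = cu_seqlens_q

print(cu_seqlens_q_padded)  # tensor([ 0,  4,  6, 12], dtype=torch.int32)
```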
@@ -4,6 +4,7 @@
 """GroupedLinear API"""
 import os
+import logging
 from typing import Union, Optional, Callable, Tuple, List, Dict, Any

 import torch

@@ -44,7 +45,16 @@ from ..jit import no_torch_dynamo
 from ..graph import is_graph_capturing
 from ..float8_tensor import Float8Tensor

+# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
 _NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
+# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
+_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
+log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
+log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
+logging.basicConfig(
+    format="[%(levelname)-8s | %(name)-19s]: %(message)s",
+    level=log_levels[log_level if log_level in [0, 1, 2] else 2],
+)

 __all__ = ["GroupedLinear"]
@@ -95,6 +105,7 @@ class _GroupedLinear(torch.autograd.Function):
         is_grad_enabled: bool,
         *weights_and_biases: Union[Float8Tensor, torch.Tensor, None],
     ) -> torch.Tensor:
+        logger = logging.getLogger("GroupedLinear")
         num_gemms = len(m_splits)
         weights = weights_and_biases[:num_gemms]
         weights_fp8 = weights_and_biases[num_gemms : 2 * num_gemms]
@@ -149,8 +160,7 @@ class _GroupedLinear(torch.autograd.Function):
             inputmats = inputmats_no_fp8

         if fp8:
-            if _NVTE_DEBUG:
-                print("[GroupedLinear]: using FP8 forward")
+            logger.debug("Running forward in FP8")
             bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype
             biases = [cast_if_needed(bias, bias_dtype) for bias in biases] if use_bias else biases
@@ -188,8 +198,7 @@ class _GroupedLinear(torch.autograd.Function):
                 # unpad the output
                 out = torch.cat([o[: m_splits[i]] for i, o in enumerate(out_list)], dim=0)
         else:
-            if _NVTE_DEBUG:
-                print("[GroupedLinear]: using non-FP8 forward")
+            logger.debug("Running forward in %s", activation_dtype)
             # Cast for native AMP
             weights = [cast_if_needed(w, activation_dtype) for w in weights]
@@ -294,6 +303,7 @@ class _GroupedLinear(torch.autograd.Function):

     @staticmethod
     def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
+        logger = logging.getLogger("GroupedLinear")
         with torch.cuda.nvtx.range("_GroupedLinear_backward"):
             (
@@ -361,8 +371,7 @@ class _GroupedLinear(torch.autograd.Function):
             if ctx.requires_dgrad:
                 if ctx.fp8:
-                    if _NVTE_DEBUG:
-                        print("[GroupedLinear]: using FP8 backward")
+                    logger.debug("Running backward in FP8")
                     dgrad_list = [
                         torch.empty(
                             (grad_output_c[i].size(0), weights_fp8[i].size(1)),
@@ -392,8 +401,7 @@ class _GroupedLinear(torch.autograd.Function):
                         [d[: ctx.m_splits[i]] for i, d in enumerate(dgrad_list)], dim=0
                     )
                 else:
-                    if _NVTE_DEBUG:
-                        print("[GroupedLinear]: using non-FP8 backward")
+                    logger.debug("Running backward in %s", ctx.activation_dtype)
                     dgrad = torch.empty(
                         (sum(ctx.m_splits), weights[0].size(1)),
@@ -5,6 +5,7 @@
 """LayerNormLinear API"""
 import os
 import warnings
+import logging
 from typing import Any, Callable, Dict, Optional, Tuple, Union

 import torch

@@ -47,7 +48,16 @@ from ..graph import is_graph_capturing
 from ._common import _apply_normalization, _noop_cat
 from ..float8_tensor import Float8Tensor

+# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
 _NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
+# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
+_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
+log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
+log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
+logging.basicConfig(
+    format="[%(levelname)-8s | %(name)-19s]: %(message)s",
+    level=log_levels[log_level if log_level in [0, 1, 2] else 2],
+)

 __all__ = ["LayerNormLinear"]
@@ -94,6 +104,7 @@ class _LayerNormLinear(torch.autograd.Function):
         ub_name: str,
         fsdp_group: Union[dist_group_type, None],
     ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
+        logger = logging.getLogger("LayerNormLinear")
         # Make sure input dimensions are compatible
         in_features = ln_weight.numel()
         assert inp.shape[-1] == in_features, "GEMM not possible"
@@ -190,8 +201,7 @@ class _LayerNormLinear(torch.autograd.Function):
             ln_out = ln_out_total

         if fp8:
-            if _NVTE_DEBUG:
-                print("[LayerNormLinear]: using FP8 forward")
+            logger.debug("Running forward in FP8")
             bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype
             bias = cast_if_needed(bias, bias_dtype) if use_bias else bias
@@ -247,8 +257,7 @@ class _LayerNormLinear(torch.autograd.Function):
                 dtype=activation_dtype,
             )
         else:
-            if _NVTE_DEBUG:
-                print("[LayerNormLinear]: using non-FP8 forward")
+            logger.debug("Running forward in %s", activation_dtype)
             # Cast for native AMP
             weight = cast_if_needed(weight, activation_dtype)
@@ -370,6 +379,7 @@ class _LayerNormLinear(torch.autograd.Function):
     def backward(
         ctx, *grad_outputs: Tuple[torch.Tensor, ...]
     ) -> Tuple[Union[torch.Tensor, None], ...]:
+        logger = logging.getLogger("LayerNormLinear")
         if isinstance(grad_outputs[0], Float8Tensor):
             ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT1] = grad_outputs[
                 0
@@ -490,8 +500,7 @@ class _LayerNormLinear(torch.autograd.Function):
             ub_obj = None

         if ctx.fp8:
-            if _NVTE_DEBUG:
-                print("[LayerNormLinear]: using FP8 backward")
+            logger.debug("Running backward in FP8")
             fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
             fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)
@@ -535,8 +544,7 @@ class _LayerNormLinear(torch.autograd.Function):
             )
             clear_tensor_data(grad_output_c)
         else:
-            if _NVTE_DEBUG:
-                print("[LayerNormLinear]: using non-FP8 backward")
+            logger.debug("Running backward in %s", ctx.activation_dtype)
             # DGRAD: Evaluated unconditionally to feed into Linear backward
             _, _, _ = tex.gemm(
@@ -4,6 +4,7 @@
 """Linear API"""
 import os
+import logging
 from typing import Any, Callable, Dict, Optional, Tuple, Union

 import torch

@@ -50,7 +51,16 @@ from ..jit import no_torch_dynamo
 from ..graph import is_graph_capturing
 from ..float8_tensor import Float8Tensor

+# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
 _NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
+# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
+_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
+log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
+log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
+logging.basicConfig(
+    format="[%(levelname)-8s | %(name)-19s]: %(message)s",
+    level=log_levels[log_level if log_level in [0, 1, 2] else 2],
+)

 __all__ = ["Linear"]
@@ -87,6 +97,7 @@ class _Linear(torch.autograd.Function):
         is_first_module_in_mha: bool,
         fsdp_group: Union[dist_group_type, None],
     ) -> torch.Tensor:
+        logger = logging.getLogger("Linear")
         is_input_fp8 = isinstance(inp, Float8Tensor)
         if is_input_fp8:
             fp8_meta["scaling_fwd"].scale_inv[tex.FP8FwdTensors.GEMM1_INPUT] = inp._scale_inv[0]
@@ -147,8 +158,7 @@ class _Linear(torch.autograd.Function):
         else:
             inputmat_total = inputmat

         if fp8:
-            if _NVTE_DEBUG:
-                print("[Linear]: using FP8 forward")
+            logger.debug("Running forward in FP8")
             bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype
             bias = cast_if_needed(bias, bias_dtype) if use_bias else bias
@@ -238,8 +248,7 @@ class _Linear(torch.autograd.Function):
                 dtype=activation_dtype,
             )
         else:
-            if _NVTE_DEBUG:
-                print("[Linear]: using non-FP8 forward")
+            logger.debug("Running forward in %s", activation_dtype)
             # Cast for native AMP
             weight = cast_if_needed(weight, activation_dtype)
@@ -366,6 +375,7 @@ class _Linear(torch.autograd.Function):

     @staticmethod
     def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
+        logger = logging.getLogger("Linear")
         if isinstance(grad_output, Float8Tensor):
             ctx.fp8_meta["scaling_bwd"].scale_inv[
                 tex.FP8BwdTensors.GRAD_OUTPUT1
@@ -442,8 +452,7 @@ class _Linear(torch.autograd.Function):
             if ctx.requires_dgrad:
                 if ctx.fp8:
-                    if _NVTE_DEBUG:
-                        print("[Linear]: using FP8 backward")
+                    logger.debug("Running backward in FP8")
                     if ctx.is_input_fp8:
                         out_index, meta_tensor, output_te_dtype, output_dtype = (
@@ -487,8 +496,7 @@ class _Linear(torch.autograd.Function):
                     dtype=ctx.activation_dtype,
                 )
             else:
-                if _NVTE_DEBUG:
-                    print("[Linear]: using non-FP8 backward")
+                logger.debug("Running backward in %s", ctx.activation_dtype)
                 dgrad, _, _ = gemm(
                     weight,