Unverified Commit 121ff62a authored by Charlene Yang, committed by GitHub

[PyTorch] Improve logging/messaging in attention (#1074)



* fix logging in attention
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove logging in fwd/bwd methods due to CPU overhead
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: fix check_set_window_size messaging
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix typo
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix window_size messaging
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove redundant imports
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 5bb3a412
@@ -98,12 +98,12 @@ META_DP = tex.FP8BwdTensors.GRAD_INPUT3
 _NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
 # NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
 _NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
-log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
-log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
-logging.basicConfig(
-    format="[%(levelname)-8s | %(name)-19s]: %(message)s",
-    level=log_levels[log_level if log_level in [0, 1, 2] else 2],
-)
+_log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
+_log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
+_log_level = _log_levels[_log_level if _log_level in [0, 1, 2] else 2]
+_formatter = logging.Formatter("[%(levelname)-8s | %(name)-19s]: %(message)s")
+_stream_handler = logging.StreamHandler()
+_stream_handler.setFormatter(_formatter)
 _NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))
 _NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1"))
@@ -266,6 +266,9 @@ def get_attention_backend(
     # Run config
     logger = logging.getLogger("DotProductAttention")
+    logger.setLevel(_log_level)
+    if not logger.hasHandlers():
+        logger.addHandler(_stream_handler)
     device_compute_capability = get_device_compute_capability()
     cudnn_version = get_cudnn_version()
     run_config = {
@@ -3236,31 +3239,28 @@ def check_set_window_size(
     """
     orig_window_size = window_size
     if "causal" in attn_mask_type:
-        if orig_window_size is None or (
-            orig_window_size[0] == -1 and orig_window_size[1] in [-1, 0]
-        ):
+        if orig_window_size is None:
             window_size = (-1, 0)
-            warnings.warn(
-                "window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type
-            )
-        elif orig_window_size[0] >= 0:
+        elif orig_window_size == (-1, -1) or (
+            orig_window_size[0] >= 0 and orig_window_size[1] != 0
+        ):
             window_size = (orig_window_size[0], 0)
             warnings.warn(
                 "window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type
             )
-        else:
+        elif orig_window_size != (-1, 0) and (orig_window_size[0] < 0 or orig_window_size[1] != 0):
             assert False, (
                 "window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type
             )
     elif attn_mask_type in ["no_mask", "padding", "arbitrary"]:
-        if orig_window_size is None or (
-            orig_window_size[0] == -1 and orig_window_size[1] in [-1, 0]
-        ):
+        if orig_window_size is None:
+            window_size = (-1, -1)
+        elif orig_window_size == (-1, 0):
             window_size = (-1, -1)
             warnings.warn(
                 "window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + attn_mask_type
             )
-        elif orig_window_size[0] < 0 or orig_window_size[1] < 0:
+        elif orig_window_size != (-1, -1) and (orig_window_size[0] < 0 or orig_window_size[1] < 0):
             assert False, (
                 "window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + attn_mask_type
             )
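To make the new check_set_window_size branches concrete, here is a hedged, stand-alone mirror of the causal branch only (normalize_causal_window is a hypothetical helper written for illustration, not the library API; the real function also handles the non-causal mask types shown above):

import warnings
from typing import Optional, Tuple


def normalize_causal_window(window_size: Optional[Tuple[int, int]]) -> Tuple[int, int]:
    """Illustrative mirror of the new causal-mask branch above."""
    if window_size is None:
        return (-1, 0)  # silent default for causal masks
    if window_size == (-1, -1) or (window_size[0] >= 0 and window_size[1] != 0):
        warnings.warn("window_size should be (-1, 0) or (>=0, 0) for causal masks")
        return (window_size[0], 0)  # clamp the right half-window and warn
    assert window_size == (-1, 0) or (window_size[0] >= 0 and window_size[1] == 0), (
        "window_size should be (-1, 0) or (>=0, 0) for causal masks"
    )
    return window_size  # already valid, returned unchanged


assert normalize_causal_window(None) == (-1, 0)       # default, no warning
assert normalize_causal_window((-1, -1)) == (-1, 0)   # warns, right edge clamped
assert normalize_causal_window((128, 5)) == (128, 0)  # warns, right edge clamped
assert normalize_causal_window((128, 0)) == (128, 0)  # valid, unchanged

The net behavioral change versus the old code is that already-valid inputs such as (-1, 0) or (128, 0) no longer trigger a spurious warning, while clearly invalid inputs still fail the assertion.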
@@ -3560,9 +3560,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
         fp8_meta,
         deterministic,
     ):
-        logger = logging.getLogger("FusedAttnFunc_qkvpacked")
         if fp8:
-            logger.debug("Running forward in FP8")
             if fp8_meta["recipe"].fp8_mha:
                 assert isinstance(qkv, Float8Tensor), "qkv must be Float8Tensors for FP8 MHA."
                 fp8_meta["scaling_fwd"].scale_inv[META_QKV] = qkv._scale_inv
@@ -3646,7 +3644,6 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
                 fp8_meta["scaling_fwd"].scale_inv.clone(),
             )
         else:
-            logger.debug("Running forward in %s", qkv.dtype)
             out_ret, aux_ctx_tensors = fused_attn_fwd_qkvpacked(
                 is_training,
                 max_seqlen,
@@ -3699,7 +3696,6 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
     @staticmethod
     def backward(ctx, d_out):
-        logger = logging.getLogger("FusedAttnFunc_qkvpacked")
         if ctx.fp8_meta["recipe"].fp8_mha:
             assert isinstance(
                 d_out, Float8Tensor
@@ -3753,7 +3749,6 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
         else:
             with torch.cuda.nvtx.range("_FusedAttn_qkvpacked"):
                 if ctx.fp8:
-                    logger.debug("Running backward in FP8")
                     fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
                     fp8_dtype_backward = get_fp8_te_dtype(
                         ctx.fp8_meta["recipe"], fprop_tensor=False
@@ -3819,7 +3814,6 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
                         ctx.qkv_dtype,
                     ).view(dqkv_fp8.shape)
                 else:
-                    logger.debug("Running backward in %s", qkv.dtype)
                     if d_out.dtype == torch.uint8:
                         d_out = d_out_f8tensor.from_float8(qkv.dtype)
                     dqkv, *rest = fused_attn_bwd_qkvpacked(
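The deletions in these forward/backward hunks implement the "remove logging in fwd/bwd methods due to CPU overhead" item from the commit message: a logging.getLogger() lookup plus a debug() call on every autograd invocation adds host-side work on hot paths even when the messages are filtered out. As general Python-logging background (not something this diff introduces), the usual way to keep such a message cheap if it must stay is to hoist the logger to module scope and guard the call:

import logging

# Looked up once at import time rather than on every forward/backward call.
logger = logging.getLogger("FusedAttnFunc_qkvpacked")


def log_precision(fp8: bool) -> None:
    # isEnabledFor() short-circuits before any argument formatting or handler
    # work when DEBUG is disabled; the PR instead drops these calls entirely.
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug("Running forward in %s", "FP8" if fp8 else "high precision")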
@@ -3937,9 +3931,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
         fp8_meta,
         deterministic,
     ):
-        logger = logging.getLogger("FusedAttnFunc_kvpacked")
         if fp8:
-            logger.debug("Running forward in FP8")
             if fp8_meta["recipe"].fp8_mha:
                 assert isinstance(q, Float8Tensor) and isinstance(
                     kv, Float8Tensor
@@ -4036,7 +4028,6 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
                 fp8_meta["scaling_fwd"].scale_inv.clone(),
             )
         else:
-            logger.debug("Running forward in %s", q.dtype)
             out_ret, aux_ctx_tensors = fused_attn_fwd_kvpacked(
                 is_training,
                 max_seqlen_q,
@@ -4100,7 +4091,6 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
     @staticmethod
     def backward(ctx, d_out):
-        logger = logging.getLogger("FusedAttnFunc_kvpacked")
         if ctx.fp8_meta["recipe"].fp8_mha:
             assert isinstance(
                 d_out, Float8Tensor
@@ -4158,7 +4148,6 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
         else:
             with torch.cuda.nvtx.range("_FusedAttn_kvpacked"):
                 if ctx.fp8:
-                    logger.debug("Running backward in FP8")
                     fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
                     fp8_dtype_backward = get_fp8_te_dtype(
                         ctx.fp8_meta["recipe"], fprop_tensor=False
@@ -4243,7 +4232,6 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
                         ctx.qkv_dtype,
                     ).view(dkv_fp8.shape)
                 else:
-                    logger.debug("Running backward in %s", q.dtype)
                     if d_out.dtype == torch.uint8:
                         d_out = d_out_f8tensor.from_float8(q.dtype)
                     dq, dkv, *rest = fused_attn_bwd_kvpacked(
@@ -4374,9 +4362,7 @@ class FusedAttnFunc(torch.autograd.Function):
         fp8_meta,
         deterministic,
     ):
-        logger = logging.getLogger("FusedAttnFunc")
         if fp8:
-            logger.debug("Running forward in FP8")
             fused_attention_backend = FusedAttnBackend["FP8"]
             fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
             if fp8_meta["recipe"].fp8_mha:
@@ -4544,7 +4530,6 @@ class FusedAttnFunc(torch.autograd.Function):
                 fp8_meta["scaling_fwd"].scale_inv.clone(),
             )
         else:
-            logger.debug("Running forward in %s", q.dtype)
             out_ret, aux_ctx_tensors = fused_attn_fwd(
                 is_training,
                 max_seqlen_q,
@@ -4618,7 +4603,6 @@ class FusedAttnFunc(torch.autograd.Function):
     @staticmethod
     def backward(ctx, d_out):
-        logger = logging.getLogger("FusedAttnFunc")
         if ctx.fp8_meta["recipe"].fp8_mha:
             assert isinstance(
                 d_out, Float8Tensor
@@ -4680,7 +4664,6 @@ class FusedAttnFunc(torch.autograd.Function):
         else:
             with torch.cuda.nvtx.range("_FusedAttn"):
                 if ctx.fp8:
-                    logger.debug("Running backward in FP8")
                     fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
                     fp8_dtype_backward = get_fp8_te_dtype(
                         ctx.fp8_meta["recipe"], fprop_tensor=False
@@ -4818,7 +4801,6 @@ class FusedAttnFunc(torch.autograd.Function):
                         ctx.qkv_dtype,
                     ).view(dv_fp8.shape)
                 else:
-                    logger.debug("Running backward in %s", q.dtype)
                     if d_out.dtype == torch.uint8:
                         d_out = d_out_f8tensor.from_float8(q.dtype)
                     dq, dk, dv, *rest = fused_attn_bwd(
@@ -4959,7 +4941,6 @@ class FusedAttention(torch.nn.Module):
     ) -> None:
         super().__init__()
-        self.logger = logging.getLogger("FusedAttention")
         self.softmax_scale = softmax_scale
         self.attention_dropout = attention_dropout
         self.attention_dropout_ctx = attention_dropout_ctx
@@ -5306,6 +5287,9 @@ class DotProductAttention(TransformerEngineBaseModule):
         super().__init__()
         self.logger = logging.getLogger("DotProductAttention")
+        self.logger.setLevel(_log_level)
+        if not self.logger.hasHandlers():
+            self.logger.addHandler(_stream_handler)
         self.qkv_format = qkv_format
         attn_mask_type = attn_mask_type.replace(",", "_")
         if attn_mask_type == "causal_padding":
@@ -5619,7 +5603,7 @@ class DotProductAttention(TransformerEngineBaseModule):
         if self.fp8_meta["recipe"].fp8_mha:
             if not self.fp8_meta["recipe"].fp8_dpa:
                 self.fp8_meta["recipe"].fp8_dpa = True
-                self.logger.WARNING(
+                self.logger.warning(
                     """Forcing fp8_meta["recipe"].fp8_dpa=True due to """
                     """fp8_meta["recipe"].fp8_mha=True"""
                 )
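The one-character change in this hunk fixes a real bug rather than a style issue: logging.Logger instances only expose lower-case level methods (debug, info, warning, ...), while WARNING is a module-level integer constant, so self.logger.WARNING(...) would have raised AttributeError the first time this branch ran. A quick, purely illustrative reproduction:

import logging

logger = logging.getLogger("DotProductAttention")
logger.warning("fp8_dpa forced to True")  # correct: lower-case method

try:
    logger.WARNING("fp8_dpa forced to True")  # no such attribute on Logger
except AttributeError as err:
    print(err)  # e.g. 'Logger' object has no attribute 'WARNING'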
...
@@ -3,8 +3,6 @@
 # See LICENSE for license information.
 """GroupedLinear API"""
-import os
-import logging
 from typing import Union, Optional, Callable, Tuple, List, Dict, Any
 import torch
@@ -45,17 +43,6 @@ from ..jit import no_torch_dynamo
 from ..graph import is_graph_capturing
 from ..float8_tensor import Float8Tensor
-# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
-_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
-# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
-_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
-log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
-log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
-logging.basicConfig(
-    format="[%(levelname)-8s | %(name)-19s]: %(message)s",
-    level=log_levels[log_level if log_level in [0, 1, 2] else 2],
-)
 __all__ = ["GroupedLinear"]
 """
@@ -97,7 +84,6 @@ class _GroupedLinear(torch.autograd.Function):
         is_grad_enabled: bool,
         *weights_and_biases: Union[Float8Tensor, torch.Tensor, None],
     ) -> torch.Tensor:
-        logger = logging.getLogger("GroupedLinear")
         num_gemms = len(m_splits)
         weights = weights_and_biases[:num_gemms]
         weights_fp8 = weights_and_biases[num_gemms : 2 * num_gemms]
@@ -151,8 +137,6 @@ class _GroupedLinear(torch.autograd.Function):
             inputmats = inputmats_no_fp8
         if fp8:
-            logger.debug("Running forward in FP8")
             bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype
             biases = [cast_if_needed(bias, bias_dtype) for bias in biases] if use_bias else biases
@@ -184,8 +168,6 @@ class _GroupedLinear(torch.autograd.Function):
                 use_split_accumulator=_2X_ACC_FPROP,
             )
         else:
-            logger.debug("Running forward in %s", activation_dtype)
             # Cast for native AMP
             weights = [cast_if_needed(w, activation_dtype) for w in weights]
             biases = (
@@ -286,8 +268,6 @@ class _GroupedLinear(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
-        logger = logging.getLogger("GroupedLinear")
         with torch.cuda.nvtx.range("_GroupedLinear_backward"):
             (
                 fwd_scale_inverses,
@@ -353,7 +333,6 @@ class _GroupedLinear(torch.autograd.Function):
             if ctx.requires_dgrad:
                 if ctx.fp8:
-                    logger.debug("Running backward in FP8")
                     dgrad = torch.empty(
                         (sum(ctx.m_splits), weights_fp8[i].size(1)),
                         dtype=ctx.activation_dtype,
@@ -376,8 +355,6 @@ class _GroupedLinear(torch.autograd.Function):
                         use_split_accumulator=_2X_ACC_DGRAD,
                     )
                 else:
-                    logger.debug("Running backward in %s", ctx.activation_dtype)
                     dgrad = torch.empty(
                         (sum(ctx.m_splits), weights[0].size(1)),
                         dtype=ctx.activation_dtype,
...
@@ -5,7 +5,6 @@
 """LayerNormLinear API"""
 import os
 import warnings
-import logging
 from typing import Any, Callable, Dict, Optional, Tuple, Union
 import torch
@@ -48,17 +47,6 @@ from ..graph import is_graph_capturing
 from ._common import _apply_normalization, _noop_cat
 from ..float8_tensor import Float8Tensor
-# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
-_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
-# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
-_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
-log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
-log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
-logging.basicConfig(
-    format="[%(levelname)-8s | %(name)-19s]: %(message)s",
-    level=log_levels[log_level if log_level in [0, 1, 2] else 2],
-)
 __all__ = ["LayerNormLinear"]
@@ -104,7 +92,6 @@ class _LayerNormLinear(torch.autograd.Function):
         ub_name: str,
         fsdp_group: Union[dist_group_type, None],
     ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
-        logger = logging.getLogger("LayerNormLinear")
         # Make sure input dimensions are compatible
         in_features = ln_weight.numel()
         assert inp.shape[-1] == in_features, "GEMM not possible"
@@ -203,8 +190,6 @@ class _LayerNormLinear(torch.autograd.Function):
             ln_out = ln_out_total
         if fp8:
-            logger.debug("Running forward in FP8")
             bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype
             bias = cast_if_needed(bias, bias_dtype) if use_bias else bias
@@ -259,8 +244,6 @@ class _LayerNormLinear(torch.autograd.Function):
                 dtype=activation_dtype,
             )
         else:
-            logger.debug("Running forward in %s", activation_dtype)
             # Cast for native AMP
             weight = cast_if_needed(weight, activation_dtype)
             bias = cast_if_needed(bias, activation_dtype) if use_bias else bias
@@ -379,7 +362,6 @@ class _LayerNormLinear(torch.autograd.Function):
     def backward(
         ctx, *grad_outputs: Tuple[torch.Tensor, ...]
     ) -> Tuple[Union[torch.Tensor, None], ...]:
-        logger = logging.getLogger("LayerNormLinear")
         if isinstance(grad_outputs[0], Float8Tensor):
             ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT1] = grad_outputs[
                 0
@@ -500,8 +482,6 @@ class _LayerNormLinear(torch.autograd.Function):
                 ub_obj = None
             if ctx.fp8:
-                logger.debug("Running backward in FP8")
                 fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
                 fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)
                 out_index, meta_tensor, out_te_type, out_type = (
@@ -544,8 +524,6 @@ class _LayerNormLinear(torch.autograd.Function):
                 )
                 clear_tensor_data(grad_output_c)
             else:
-                logger.debug("Running backward in %s", ctx.activation_dtype)
                 # DGRAD: Evaluated unconditionally to feed into Linear backward
                 _, _, _ = tex.gemm(
                     weight,
...
@@ -3,8 +3,6 @@
 # See LICENSE for license information.
 """Linear API"""
-import os
-import logging
 from typing import Any, Callable, Dict, Optional, Tuple, Union
 import torch
@@ -51,17 +49,6 @@ from ..jit import no_torch_dynamo
 from ..graph import is_graph_capturing
 from ..float8_tensor import Float8Tensor
-# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
-_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
-# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
-_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
-log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
-log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
-logging.basicConfig(
-    format="[%(levelname)-8s | %(name)-19s]: %(message)s",
-    level=log_levels[log_level if log_level in [0, 1, 2] else 2],
-)
 __all__ = ["Linear"]
@@ -97,7 +84,6 @@ class _Linear(torch.autograd.Function):
         is_first_module_in_mha: bool,
         fsdp_group: Union[dist_group_type, None],
     ) -> torch.Tensor:
-        logger = logging.getLogger("Linear")
         is_input_fp8 = isinstance(inp, Float8Tensor)
         if is_input_fp8:
             fp8_meta["scaling_fwd"].scale_inv[tex.FP8FwdTensors.GEMM1_INPUT] = inp._scale_inv[0]
@@ -158,8 +144,6 @@ class _Linear(torch.autograd.Function):
         else:
             inputmat_total = inputmat
         if fp8:
-            logger.debug("Running forward in FP8")
             bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype
             bias = cast_if_needed(bias, bias_dtype) if use_bias else bias
@@ -248,8 +232,6 @@ class _Linear(torch.autograd.Function):
                 dtype=activation_dtype,
             )
         else:
-            logger.debug("Running forward in %s", activation_dtype)
             # Cast for native AMP
             weight = cast_if_needed(weight, activation_dtype)
             bias = cast_if_needed(bias, activation_dtype) if use_bias else bias
@@ -373,7 +355,6 @@ class _Linear(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
-        logger = logging.getLogger("Linear")
         if isinstance(grad_output, Float8Tensor):
             ctx.fp8_meta["scaling_bwd"].scale_inv[
                 tex.FP8BwdTensors.GRAD_OUTPUT1
@@ -450,8 +431,6 @@ class _Linear(torch.autograd.Function):
             if ctx.requires_dgrad:
                 if ctx.fp8:
-                    logger.debug("Running backward in FP8")
                     if ctx.is_input_fp8:
                         out_index, meta_tensor, output_te_dtype, output_dtype = (
                             tex.FP8BwdTensors.GRAD_INPUT1,
@@ -494,8 +473,6 @@ class _Linear(torch.autograd.Function):
                         dtype=ctx.activation_dtype,
                     )
                 else:
-                    logger.debug("Running backward in %s", ctx.activation_dtype)
                     dgrad, _, _ = gemm(
                         weight,
                         grad_output,
...