Unverified Commit 121ff62a authored by Charlene Yang, committed by GitHub

[PyTorch] Improve logging/messaging in attention (#1074)



* fix logging in attention
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove logging in fwd/bwd methods due to CPU overhead
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: fix check_set_window_size messaging
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix typo
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix window_size messaging
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* remove redundant imports
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 5bb3a412
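For context on the "remove logging in fwd/bwd methods due to CPU overhead" bullet: logging.getLogger() performs a lock-protected registry lookup, and even a disabled logger.debug() pays a per-call level check, so doing either inside an autograd forward/backward lands squarely on the host-side hot path. A minimal sketch of the removed pattern versus the retained one; _HotFunction is a hypothetical stand-in, not code from this diff:

import logging

import torch


class _HotFunction(torch.autograd.Function):  # hypothetical example
    @staticmethod
    def forward(ctx, x):
        # Removed pattern: a registry lookup plus a level check on every
        # single forward call, all of it CPU-side work.
        logger = logging.getLogger("FusedAttnFunc")
        logger.debug("Running forward in %s", x.dtype)
        return x

    @staticmethod
    def backward(ctx, grad):
        return grad


# Retained pattern: resolve the logger once at module scope (or in __init__)
# and keep per-call logging out of the hot methods entirely.
_logger = logging.getLogger("FusedAttnFunc")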
@@ -98,12 +98,12 @@ META_DP = tex.FP8BwdTensors.GRAD_INPUT3
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
logging.basicConfig(
format="[%(levelname)-8s | %(name)-19s]: %(message)s",
level=log_levels[log_level if log_level in [0, 1, 2] else 2],
)
_log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
_log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
_log_level = _log_levels[_log_level if _log_level in [0, 1, 2] else 2]
_formatter = logging.Formatter("[%(levelname)-8s | %(name)-19s]: %(message)s")
_stream_handler = logging.StreamHandler()
_stream_handler.setFormatter(_formatter)
_NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))
_NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1"))
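The first hunk replaces logging.basicConfig() — which configures the process-wide root logger and silently does nothing if the application has already configured logging — with a module-level formatter/handler pair that is later attached only to Transformer Engine's own named loggers. A condensed sketch of the resulting behavior, assuming the environment variables are read once at import time as shown above:

import logging
import os

# NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=2 -> DEBUG logs; default (NVTE_DEBUG=0) -> WARNING only
_level_map = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
_raw = int(os.getenv("NVTE_DEBUG", "0")) * int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
_log_level = _level_map[_raw if _raw in (0, 1, 2) else 2]

_formatter = logging.Formatter("[%(levelname)-8s | %(name)-19s]: %(message)s")
_stream_handler = logging.StreamHandler()
_stream_handler.setFormatter(_formatter)

logger = logging.getLogger("DotProductAttention")
logger.setLevel(_log_level)       # scoped to this logger, not the root logger
if not logger.hasHandlers():      # avoids stacking duplicate handlers on repeated setup
    logger.addHandler(_stream_handler)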
@@ -266,6 +266,9 @@ def get_attention_backend(
# Run config
logger = logging.getLogger("DotProductAttention")
logger.setLevel(_log_level)
if not logger.hasHandlers():
logger.addHandler(_stream_handler)
device_compute_capability = get_device_compute_capability()
cudnn_version = get_cudnn_version()
run_config = {
@@ -3236,31 +3239,28 @@ def check_set_window_size(
"""
orig_window_size = window_size
if "causal" in attn_mask_type:
if orig_window_size is None or (
orig_window_size[0] == -1 and orig_window_size[1] in [-1, 0]
):
if orig_window_size is None:
window_size = (-1, 0)
warnings.warn(
"window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type
)
elif orig_window_size[0] >= 0:
elif orig_window_size == (-1, -1) or (
orig_window_size[0] >= 0 and orig_window_size[1] != 0
):
window_size = (orig_window_size[0], 0)
warnings.warn(
"window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type
)
else:
elif orig_window_size != (-1, 0) and (orig_window_size[0] < 0 or orig_window_size[1] != 0):
assert False, (
"window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type
)
elif attn_mask_type in ["no_mask", "padding", "arbitrary"]:
if orig_window_size is None or (
orig_window_size[0] == -1 and orig_window_size[1] in [-1, 0]
):
if orig_window_size is None:
window_size = (-1, -1)
elif orig_window_size == (-1, 0):
window_size = (-1, -1)
warnings.warn(
"window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + attn_mask_type
)
elif orig_window_size[0] < 0 or orig_window_size[1] < 0:
elif orig_window_size != (-1, -1) and (orig_window_size[0] < 0 or orig_window_size[1] < 0):
assert False, (
"window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + attn_mask_type
)
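Net behavior of the rewritten checks: causal masks normalize the right half of the window to 0, the other mask types normalize to (-1, -1) or require both halves non-negative, warning when a value is silently corrected and asserting on irreconcilable values. Illustrative outcomes, assuming the function returns the corrected tuple and is importable from this module (both inferred from the code above, not stated in the diff):

from transformer_engine.pytorch.attention import check_set_window_size  # assumed path

check_set_window_size("causal", None)        # -> (-1, 0), silently
check_set_window_size("causal", (-1, -1))    # -> (-1, 0), with a warning
check_set_window_size("causal", (16, 8))     # -> (16, 0), with a warning
check_set_window_size("causal", (-2, 0))     # AssertionError
check_set_window_size("no_mask", None)       # -> (-1, -1), silently
check_set_window_size("no_mask", (-1, 0))    # -> (-1, -1), with a warning
check_set_window_size("padding", (16, -1))   # AssertionError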
@@ -3560,9 +3560,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
fp8_meta,
deterministic,
):
logger = logging.getLogger("FusedAttnFunc_qkvpacked")
if fp8:
logger.debug("Running forward in FP8")
if fp8_meta["recipe"].fp8_mha:
assert isinstance(qkv, Float8Tensor), "qkv must be Float8Tensors for FP8 MHA."
fp8_meta["scaling_fwd"].scale_inv[META_QKV] = qkv._scale_inv
@@ -3646,7 +3644,6 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
fp8_meta["scaling_fwd"].scale_inv.clone(),
)
else:
logger.debug("Running forward in %s", qkv.dtype)
out_ret, aux_ctx_tensors = fused_attn_fwd_qkvpacked(
is_training,
max_seqlen,
@@ -3699,7 +3696,6 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
@staticmethod
def backward(ctx, d_out):
logger = logging.getLogger("FusedAttnFunc_qkvpacked")
if ctx.fp8_meta["recipe"].fp8_mha:
assert isinstance(
d_out, Float8Tensor
@@ -3753,7 +3749,6 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
else:
with torch.cuda.nvtx.range("_FusedAttn_qkvpacked"):
if ctx.fp8:
logger.debug("Running backward in FP8")
fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
fp8_dtype_backward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=False
@@ -3819,7 +3814,6 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
ctx.qkv_dtype,
).view(dqkv_fp8.shape)
else:
logger.debug("Running backward in %s", qkv.dtype)
if d_out.dtype == torch.uint8:
d_out = d_out_f8tensor.from_float8(qkv.dtype)
dqkv, *rest = fused_attn_bwd_qkvpacked(
@@ -3937,9 +3931,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
fp8_meta,
deterministic,
):
logger = logging.getLogger("FusedAttnFunc_kvpacked")
if fp8:
logger.debug("Running forward in FP8")
if fp8_meta["recipe"].fp8_mha:
assert isinstance(q, Float8Tensor) and isinstance(
kv, Float8Tensor
@@ -4036,7 +4028,6 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
fp8_meta["scaling_fwd"].scale_inv.clone(),
)
else:
logger.debug("Running forward in %s", q.dtype)
out_ret, aux_ctx_tensors = fused_attn_fwd_kvpacked(
is_training,
max_seqlen_q,
@@ -4100,7 +4091,6 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
@staticmethod
def backward(ctx, d_out):
logger = logging.getLogger("FusedAttnFunc_kvpacked")
if ctx.fp8_meta["recipe"].fp8_mha:
assert isinstance(
d_out, Float8Tensor
@@ -4158,7 +4148,6 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
else:
with torch.cuda.nvtx.range("_FusedAttn_kvpacked"):
if ctx.fp8:
logger.debug("Running backward in FP8")
fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
fp8_dtype_backward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=False
@@ -4243,7 +4232,6 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
ctx.qkv_dtype,
).view(dkv_fp8.shape)
else:
logger.debug("Running backward in %s", q.dtype)
if d_out.dtype == torch.uint8:
d_out = d_out_f8tensor.from_float8(q.dtype)
dq, dkv, *rest = fused_attn_bwd_kvpacked(
@@ -4374,9 +4362,7 @@ class FusedAttnFunc(torch.autograd.Function):
fp8_meta,
deterministic,
):
logger = logging.getLogger("FusedAttnFunc")
if fp8:
logger.debug("Running forward in FP8")
fused_attention_backend = FusedAttnBackend["FP8"]
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
if fp8_meta["recipe"].fp8_mha:
@@ -4544,7 +4530,6 @@ class FusedAttnFunc(torch.autograd.Function):
fp8_meta["scaling_fwd"].scale_inv.clone(),
)
else:
logger.debug("Running forward in %s", q.dtype)
out_ret, aux_ctx_tensors = fused_attn_fwd(
is_training,
max_seqlen_q,
@@ -4618,7 +4603,6 @@ class FusedAttnFunc(torch.autograd.Function):
@staticmethod
def backward(ctx, d_out):
logger = logging.getLogger("FusedAttnFunc")
if ctx.fp8_meta["recipe"].fp8_mha:
assert isinstance(
d_out, Float8Tensor
@@ -4680,7 +4664,6 @@ class FusedAttnFunc(torch.autograd.Function):
else:
with torch.cuda.nvtx.range("_FusedAttn"):
if ctx.fp8:
logger.debug("Running backward in FP8")
fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
fp8_dtype_backward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=False
@@ -4818,7 +4801,6 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.qkv_dtype,
).view(dv_fp8.shape)
else:
logger.debug("Running backward in %s", q.dtype)
if d_out.dtype == torch.uint8:
d_out = d_out_f8tensor.from_float8(q.dtype)
dq, dk, dv, *rest = fused_attn_bwd(
@@ -4959,7 +4941,6 @@ class FusedAttention(torch.nn.Module):
) -> None:
super().__init__()
self.logger = logging.getLogger("FusedAttention")
self.softmax_scale = softmax_scale
self.attention_dropout = attention_dropout
self.attention_dropout_ctx = attention_dropout_ctx
@@ -5306,6 +5287,9 @@ class DotProductAttention(TransformerEngineBaseModule):
super().__init__()
self.logger = logging.getLogger("DotProductAttention")
self.logger.setLevel(_log_level)
if not self.logger.hasHandlers():
self.logger.addHandler(_stream_handler)
self.qkv_format = qkv_format
attn_mask_type = attn_mask_type.replace(",", "_")
if attn_mask_type == "causal_padding":
@@ -5619,7 +5603,7 @@ class DotProductAttention(TransformerEngineBaseModule):
if self.fp8_meta["recipe"].fp8_mha:
if not self.fp8_meta["recipe"].fp8_dpa:
self.fp8_meta["recipe"].fp8_dpa = True
self.logger.WARNING(
self.logger.warning(
"""Forcing fp8_meta["recipe"].fp8_dpa=True due to """
"""fp8_meta["recipe"].fp8_mha=True"""
)
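The lowercase fix in this hunk is a real bug fix, not a style change: WARNING is an integer level constant on the logging module, not a Logger method, so the old self.logger.WARNING(...) would have raised AttributeError the first time this branch executed. For reference:

import logging

logger = logging.getLogger("DotProductAttention")
assert logging.WARNING == 30    # module-level constant, used with setLevel()
logger.warning("message")       # correct: Logger methods are lowercase
# logger.WARNING("message")     # AttributeError: 'Logger' object has no attribute 'WARNING'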
@@ -3,8 +3,6 @@
# See LICENSE for license information.
"""GroupedLinear API"""
import os
import logging
from typing import Union, Optional, Callable, Tuple, List, Dict, Any
import torch
@@ -45,17 +43,6 @@ from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ..float8_tensor import Float8Tensor
# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
logging.basicConfig(
format="[%(levelname)-8s | %(name)-19s]: %(message)s",
level=log_levels[log_level if log_level in [0, 1, 2] else 2],
)
__all__ = ["GroupedLinear"]
"""
@@ -97,7 +84,6 @@ class _GroupedLinear(torch.autograd.Function):
is_grad_enabled: bool,
*weights_and_biases: Union[Float8Tensor, torch.Tensor, None],
) -> torch.Tensor:
logger = logging.getLogger("GroupedLinear")
num_gemms = len(m_splits)
weights = weights_and_biases[:num_gemms]
weights_fp8 = weights_and_biases[num_gemms : 2 * num_gemms]
@@ -151,8 +137,6 @@ class _GroupedLinear(torch.autograd.Function):
inputmats = inputmats_no_fp8
if fp8:
logger.debug("Running forward in FP8")
bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype
biases = [cast_if_needed(bias, bias_dtype) for bias in biases] if use_bias else biases
@@ -184,8 +168,6 @@ class _GroupedLinear(torch.autograd.Function):
use_split_accumulator=_2X_ACC_FPROP,
)
else:
logger.debug("Running forward in %s", activation_dtype)
# Cast for native AMP
weights = [cast_if_needed(w, activation_dtype) for w in weights]
biases = (
@@ -286,8 +268,6 @@ class _GroupedLinear(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
logger = logging.getLogger("GroupedLinear")
with torch.cuda.nvtx.range("_GroupedLinear_backward"):
(
fwd_scale_inverses,
@@ -353,7 +333,6 @@ class _GroupedLinear(torch.autograd.Function):
if ctx.requires_dgrad:
if ctx.fp8:
logger.debug("Running backward in FP8")
dgrad = torch.empty(
(sum(ctx.m_splits), weights_fp8[i].size(1)),
dtype=ctx.activation_dtype,
@@ -376,8 +355,6 @@ class _GroupedLinear(torch.autograd.Function):
use_split_accumulator=_2X_ACC_DGRAD,
)
else:
logger.debug("Running backward in %s", ctx.activation_dtype)
dgrad = torch.empty(
(sum(ctx.m_splits), weights[0].size(1)),
dtype=ctx.activation_dtype,
@@ -5,7 +5,6 @@
"""LayerNormLinear API"""
import os
import warnings
import logging
from typing import Any, Callable, Dict, Optional, Tuple, Union
import torch
@@ -48,17 +47,6 @@ from ..graph import is_graph_capturing
from ._common import _apply_normalization, _noop_cat
from ..float8_tensor import Float8Tensor
# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
logging.basicConfig(
format="[%(levelname)-8s | %(name)-19s]: %(message)s",
level=log_levels[log_level if log_level in [0, 1, 2] else 2],
)
__all__ = ["LayerNormLinear"]
@@ -104,7 +92,6 @@ class _LayerNormLinear(torch.autograd.Function):
ub_name: str,
fsdp_group: Union[dist_group_type, None],
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
logger = logging.getLogger("LayerNormLinear")
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
assert inp.shape[-1] == in_features, "GEMM not possible"
@@ -203,8 +190,6 @@ class _LayerNormLinear(torch.autograd.Function):
ln_out = ln_out_total
if fp8:
logger.debug("Running forward in FP8")
bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype
bias = cast_if_needed(bias, bias_dtype) if use_bias else bias
@@ -259,8 +244,6 @@ class _LayerNormLinear(torch.autograd.Function):
dtype=activation_dtype,
)
else:
logger.debug("Running forward in %s", activation_dtype)
# Cast for native AMP
weight = cast_if_needed(weight, activation_dtype)
bias = cast_if_needed(bias, activation_dtype) if use_bias else bias
@@ -379,7 +362,6 @@ class _LayerNormLinear(torch.autograd.Function):
def backward(
ctx, *grad_outputs: Tuple[torch.Tensor, ...]
) -> Tuple[Union[torch.Tensor, None], ...]:
logger = logging.getLogger("LayerNormLinear")
if isinstance(grad_outputs[0], Float8Tensor):
ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT1] = grad_outputs[
0
@@ -500,8 +482,6 @@ class _LayerNormLinear(torch.autograd.Function):
ub_obj = None
if ctx.fp8:
logger.debug("Running backward in FP8")
fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)
out_index, meta_tensor, out_te_type, out_type = (
@@ -544,8 +524,6 @@ class _LayerNormLinear(torch.autograd.Function):
)
clear_tensor_data(grad_output_c)
else:
logger.debug("Running backward in %s", ctx.activation_dtype)
# DGRAD: Evaluated unconditionally to feed into Linear backward
_, _, _ = tex.gemm(
weight,
@@ -3,8 +3,6 @@
# See LICENSE for license information.
"""Linear API"""
import os
import logging
from typing import Any, Callable, Dict, Optional, Tuple, Union
import torch
@@ -51,17 +49,6 @@ from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ..float8_tensor import Float8Tensor
# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
logging.basicConfig(
format="[%(levelname)-8s | %(name)-19s]: %(message)s",
level=log_levels[log_level if log_level in [0, 1, 2] else 2],
)
__all__ = ["Linear"]
@@ -97,7 +84,6 @@ class _Linear(torch.autograd.Function):
is_first_module_in_mha: bool,
fsdp_group: Union[dist_group_type, None],
) -> torch.Tensor:
logger = logging.getLogger("Linear")
is_input_fp8 = isinstance(inp, Float8Tensor)
if is_input_fp8:
fp8_meta["scaling_fwd"].scale_inv[tex.FP8FwdTensors.GEMM1_INPUT] = inp._scale_inv[0]
@@ -158,8 +144,6 @@ class _Linear(torch.autograd.Function):
else:
inputmat_total = inputmat
if fp8:
logger.debug("Running forward in FP8")
bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype
bias = cast_if_needed(bias, bias_dtype) if use_bias else bias
@@ -248,8 +232,6 @@ class _Linear(torch.autograd.Function):
dtype=activation_dtype,
)
else:
logger.debug("Running forward in %s", activation_dtype)
# Cast for native AMP
weight = cast_if_needed(weight, activation_dtype)
bias = cast_if_needed(bias, activation_dtype) if use_bias else bias
@@ -373,7 +355,6 @@ class _Linear(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
logger = logging.getLogger("Linear")
if isinstance(grad_output, Float8Tensor):
ctx.fp8_meta["scaling_bwd"].scale_inv[
tex.FP8BwdTensors.GRAD_OUTPUT1
@@ -450,8 +431,6 @@ class _Linear(torch.autograd.Function):
if ctx.requires_dgrad:
if ctx.fp8:
logger.debug("Running backward in FP8")
if ctx.is_input_fp8:
out_index, meta_tensor, output_te_dtype, output_dtype = (
tex.FP8BwdTensors.GRAD_INPUT1,
@@ -494,8 +473,6 @@ class _Linear(torch.autograd.Function):
dtype=ctx.activation_dtype,
)
else:
logger.debug("Running backward in %s", ctx.activation_dtype)
dgrad, _, _ = gemm(
weight,
grad_output,