Unverified Commit 605786f4 authored by Przemyslaw Tredak, committed by GitHub

[pyTorch] CPU performance optimizations (#2439)



* PoC of the changes
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Early exit from the Free function for the empty tensor
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Use the proper function for nvtx range
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Only do mark_not_offload when cpu_offloading is enabled
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* First pass on making the setattr issue not come back
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Actually add pytest.ini
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Changes to __init__
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* A different way
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* WAR the fact that it is not possible to set __setattr__ dynamically
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Simpler solution and fixes
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix for the inference mode DPA
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Start of debugging the debug tools
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* More fixes in debug
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Speculatively moving the validate_name to the constructor
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Making the debug tools names saner
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Change the setattr usage in the tensor parallel group setting
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Adding try/finally - it does not seem to impact the time in an observable way
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fixing lint issues and the thunder test
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix 1 of the debug tests
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Removed the warning and enforcement in the CI
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* try-finally in the context manager
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fixing the debug tests
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 36f4e451
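The unifying theme of the hunks below is shaving Python-side CPU overhead off the forward path: bypassing `torch.nn.Module.__setattr__` for plain attributes, exiting early before taking locks, replacing a generator-based context manager with explicit begin/end calls, and hoisting repeated lookups into locals. As a rough illustration of the first point, here is a minimal, hypothetical micro-benchmark (not part of the commit; absolute numbers will vary by machine):

```python
# Hypothetical micro-benchmark: torch.nn.Module.__setattr__ consults the
# _parameters, _buffers, and _modules registries on every assignment,
# which is pure overhead for plain Python attributes such as flags.
import time
import torch

m = torch.nn.Linear(16, 16)

def bench(fn, n=100_000):
    start = time.perf_counter()
    for _ in range(n):
        fn()
    return time.perf_counter() - start

via_setattr = bench(lambda: setattr(m, "fp8", False))    # full Module.__setattr__
via_dict = bench(lambda: m.__dict__.update(fp8=False))   # what fast_setattr does
print(f"Module.__setattr__: {via_setattr:.3f}s  __dict__ write: {via_dict:.3f}s")
```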
......@@ -2790,7 +2790,7 @@ class Custom_MHA_FP8(TransformerEngineBaseModule):
cu_seqlens,
max_s,
) -> torch.Tensor:
-with self.prepare_forward(inp, num_gemms=3) as inp:
+with self.prepare_forward_ctx(inp, num_gemms=3) as inp:
out = _custom_mha_fp8.apply(
inp,
self.qkv_weight,
......
......@@ -30,10 +30,17 @@ configs = {
stats: [min, max, mean, std, l1_norm, l2_norm, cur_amax, dynamic_range]
start_step : 0
end_step: 1
""",
"log_fp8": """log_fp8:
layers:
layer_types: [linear]
enabled:
True
transformer_engine:
LogFp8TensorStats:
enabled: True
tensors: [activation, gradient, weight]
-stats: [underflows, overflows]
+stats: [underflows%]
start_step : 0
end_step: 1
""",
......@@ -46,22 +53,26 @@ fake_quant_config:
FakeQuant:
enabled: True
gemms: [fprop, dgrad, wgrad]
tensors: [activation, weight, gradient]
quant_format: FP8E5M2
""",
}
+# Configs that require FP8 to be enabled
+fp8_required_configs = {"log_fp8"}
def _get_model(model_key):
if model_key == "linear":
-return te.Linear(D, D)
+return te.Linear(D, D, name="layer")
if model_key == "layernorm_linear":
-return te.LayerNormLinear(D, D)
+return te.LayerNormLinear(D, D, name="layer")
if model_key == "layernorm_mlp":
-return te.LayerNormMLP(D, D, D)
+return te.LayerNormMLP(D, D, D, name="layer")
if model_key == "mha_attention":
-return te.MultiheadAttention(D, H)
+return te.MultiheadAttention(D, H, name="layer")
if model_key == "transformer_layer":
-return te.TransformerLayer(D, D, H)
+return te.TransformerLayer(D, D, H, name="layer")
def _run_forward_backward(model, fp8):
......@@ -95,4 +106,6 @@ def _run_test(model_key, fp8, config, feature_dirs, config_file, log_dir):
def test_sanity_debug(model_key, fp8, config_key, feature_dirs):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
+if not fp8 and config_key in fp8_required_configs:
+pytest.skip(f"Config '{config_key}' requires FP8")
_run_test(model_key, fp8, configs[config_key], feature_dirs)
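For readers skimming the test change: the new `fp8_required_configs` set lets FP8-only debug configs coexist with the non-FP8 parametrization. A self-contained sketch of the same skip pattern (the names here are illustrative, not the real test file):

```python
# Illustrative skip pattern (hypothetical standalone test, mirroring the
# change above): configs that only make sense under FP8 are skipped when
# the test is parametrized with fp8=False.
import pytest

configs = {"log_stats": "...", "log_fp8": "..."}
fp8_required_configs = {"log_fp8"}

@pytest.mark.parametrize("fp8", [False, True])
@pytest.mark.parametrize("config_key", list(configs))
def test_sanity(fp8, config_key):
    if not fp8 and config_key in fp8_required_configs:
        pytest.skip(f"Config '{config_key}' requires FP8")
    assert config_key in configs  # placeholder for the real test body
```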
......@@ -454,9 +454,9 @@ class TensorAllocator {
}
void Free(NVTETensor t) {
-std::lock_guard<std::mutex> lock(mutex);
uintptr_t index = reinterpret_cast<uintptr_t>(t);
+if (index == 0) return;
+std::lock_guard<std::mutex> lock(mutex);
NVTE_CHECK(index <= memory.size(), "Invalid tensor.");
free_list.push_back(index);
// Clean up
......@@ -564,9 +564,9 @@ class GroupedTensorAllocator {
}
void Free(NVTEGroupedTensor t) {
-std::lock_guard<std::mutex> lock(mutex);
uintptr_t index = reinterpret_cast<uintptr_t>(t);
+if (index == 0) return;
+std::lock_guard<std::mutex> lock(mutex);
NVTE_CHECK(index <= memory.size(), "Invalid grouped tensor.");
free_list.push_back(index);
// Clean up
......
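Both `Free` overloads get the same treatment: the null-handle check now runs before the mutex is acquired, so freeing an empty tensor costs neither a lock nor a bounds check. A hedged Python rendering of the same pattern (the C++ above is the actual change; this sketch merely restates it):

```python
# Python sketch of the early-exit-before-lock pattern used in Free():
# validate the trivial case first, and take the mutex only when there is
# real work to do.
import threading

class Allocator:
    def __init__(self, capacity=1024):
        self.mutex = threading.Lock()
        self.memory = [None] * capacity
        self.free_list = []

    def free(self, handle: int) -> None:
        if handle == 0:              # empty tensor: nothing to release
            return
        with self.mutex:             # lock only on the non-trivial path
            assert handle <= len(self.memory), "Invalid tensor."
            self.free_list.append(handle)
```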
......@@ -676,9 +676,9 @@ class DotProductAttention(TransformerEngineBaseModule):
# assume attention uses the same fp8_group as GEMMs
fp8_group = FP8GlobalStateManager.get_fp8_group()
-self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
-self.fp8 = FP8GlobalStateManager.is_fp8_enabled()
-self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
+self.fast_setattr("fp8_parameters", FP8GlobalStateManager.with_fp8_parameters())
+self.fast_setattr("fp8", FP8GlobalStateManager.is_fp8_enabled())
+self.fast_setattr("fp8_calibration", FP8GlobalStateManager.is_fp8_calibration())
fp8_enabled = self.fp8 or self.fp8_calibration
self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration
if self.fp8_parameters or fp8_enabled:
......@@ -703,7 +703,7 @@ class DotProductAttention(TransformerEngineBaseModule):
)
else:
# If fp8 isn't enabled, turn off and return.
-self.fp8_initialized = False
+self.fast_setattr("fp8_initialized", False)
return
if self.fp8_parameters and not self.fp8_initialized:
......@@ -721,7 +721,7 @@ class DotProductAttention(TransformerEngineBaseModule):
# Allocate scales and amaxes
self.init_fp8_meta_tensors(fp8_recipes)
-self.fp8_initialized = True
+self.fast_setattr("fp8_initialized", True)
self.fp8_meta["recipe"] = fp8_recipe_dpa
if fp8_recipe != fp8_recipe_dpa:
......@@ -1000,7 +1000,7 @@ class DotProductAttention(TransformerEngineBaseModule):
cases. It is ignored for other backends and when context parallelism is enabled.
"""
-with self.prepare_forward(
+with self.prepare_forward_ctx(
query_layer,
num_gemms=3,
allow_non_contiguous=True,
......@@ -1145,7 +1145,8 @@ class DotProductAttention(TransformerEngineBaseModule):
if attn_mask_type == "padding_causal":
attn_mask_type = attn_mask_type + "_bottom_right"
-self.attention_type = "cross"
+if self.attention_type != "cross":
+self.fast_setattr("attention_type", "cross")
self.flash_attention.attention_type = self.attention_type
self.fused_attention.attention_type = self.attention_type
self.unfused_attention.attention_type = self.attention_type
......
......@@ -8,7 +8,6 @@ import collections
from typing import Callable, List, Optional, Tuple, Union
import torch
-from transformer_engine.debug.pytorch.debug_state import TEDebugState
from transformer_engine.pytorch.quantization import FP8GlobalStateManager
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
......@@ -335,6 +334,7 @@ class MultiheadAttention(torch.nn.Module):
self.hidden_size_kv = self.hidden_size_per_attention_head * self.num_gqa_groups
self.name = name
+TransformerEngineBaseModule._validate_name(self)
common_gemm_kwargs = {
"fuse_wgrad_accumulation": fuse_wgrad_accumulation,
......@@ -739,9 +739,6 @@ class MultiheadAttention(torch.nn.Module):
core_attention_bias_type in AttnBiasTypes
), f"core_attention_bias_type {core_attention_bias_type} is not supported!"
-if TEDebugState.debug_enabled:
-TransformerEngineBaseModule._validate_name(self)
# =================================================
# Pre-allocate memory for key-value cache for inference
# =================================================
......
......@@ -729,8 +729,8 @@ def checkpoint(
if isinstance(function, TransformerEngineBaseModule):
# If this TE module is FSDP-wrapped, clear its FSDP group information because there's no need
# to scatter/gather activations that we will recompute anyway.
-setattr(function, "fsdp_wrapped", False)
-setattr(function, "fsdp_group", None)
+function.fast_setattr("fsdp_wrapped", False)
+function.fast_setattr("fsdp_group", None)
# Otherwise discard unused te.utils.checkpoint.checkpoint() arguments
# and execute TE's own checkpointing
......@@ -2022,7 +2022,7 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None:
)
root_state = _get_module_fsdp_state(fsdp_root)
assert root_state is not None, "Root module does not have a valid _FSDPState."
-setattr(fsdp_root.module, "fsdp_group", root_state.process_group)
+fsdp_root.module.fast_setattr("fsdp_group", root_state.process_group)
# Iterate through all FSDP-wrapped submodules and inject FSDP information into TE modules
fsdp_states, fsdp_modules = _get_fsdp_states_with_modules(fsdp_root)
......@@ -2033,7 +2033,7 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None:
"TE modules with primary weights in FP8 cannot be FSDP-wrapped. "
"Please initialize your model without the te.quantized_model_init(...) context."
)
-setattr(fsdp_module.module, "fsdp_group", state.process_group)
+fsdp_module.module.fast_setattr("fsdp_group", state.process_group)
class FullyShardedDataParallel(FSDP):
......
......@@ -10,9 +10,8 @@ import pickle
import warnings
from enum import Enum
from abc import ABC, abstractmethod
-from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union
+from typing import Any, Dict, Generator, List, Optional, Tuple, Union
from contextlib import contextmanager
-import logging
from types import MethodType
import torch
......@@ -50,6 +49,8 @@ from ..utils import (
is_non_tn_fp8_gemm_supported,
torch_get_autocast_gpu_dtype,
get_nvtx_range_context,
+nvtx_range_push,
+nvtx_range_pop,
)
from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
from ...common.recipe import DelayedScaling, Recipe
......@@ -605,10 +606,10 @@ def fill_userbuffers_buffer_for_all_gather(
class TransformerEngineBaseModule(torch.nn.Module, ABC):
"""Base TE module."""
-def __init__(self) -> None:
+def __init__(self, name: Optional[str] = None) -> None:
super().__init__()
assert torch.cuda.is_available(), "TransformerEngine needs CUDA."
-self.name = None
+self.name = name
self.next_iter_when_debug_should_be_run = 0
self.fp8_initialized = False
self.fp8 = False
......@@ -633,25 +634,21 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
if not TEDebugState.debug_enabled:
TEDebugState.initialize()
self._validate_name()
-# Names of attributes that can be set quickly (see __setattr__
-# method)
-_fast_setattr_names: Set[str] = {
-"activation_dtype",
-"fp8",
-"fp8_initialized",
-"fp8_calibration",
-"fp8_parameters",
-}
-def __setattr__(self, name: str, value: Any) -> None:
-if name in TransformerEngineBaseModule._fast_setattr_names:
# torch.nn.Module has a custom __setattr__ that handles
# modules, parameters, and buffers. This is unnecessary
# overhead when setting plain attrs.
+def fast_setattr(self, name: str, value: Any) -> None:
+"""
+Fast version of the Module's set attribute function.
+Should be used for regular attributes, but not properties nor parameters/buffers.
+"""
self.__dict__[name] = value
-else:
-# Default case
+def module_setattr(self, name: str, value: Any) -> None:
+"""
+Regular version of the Module's set attribute function.
+Should be used only when the fast version cannot be used - for the properties,
+parameters and buffers.
+"""
super().__setattr__(name, value)
def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None:
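The split into `fast_setattr`/`module_setattr` exists because `torch.nn.Module.__setattr__` cannot be swapped out per instance (hence the "WAR the fact that it is not possible to set __setattr__ dynamically" commit above), and the stock implementation roughly does the following on every write (a simplified paraphrase from memory, not the upstream source):

```python
# Simplified paraphrase of torch.nn.Module.__setattr__ (illustrative only):
# every assignment pays for registry lookups before the plain dict write.
def module_setattr_sketch(self, name, value):
    params = self.__dict__.get("_parameters")
    if params is not None and name in params:
        ...  # route through parameter registration
    buffers = self.__dict__.get("_buffers")
    if buffers is not None and name in buffers:
        ...  # route through buffer registration
    modules = self.__dict__.get("_modules")
    if modules is not None and name in modules:
        ...  # route through submodule registration
    object.__setattr__(self, name, value)  # finally, the plain write

# fast_setattr collapses all of this to: self.__dict__[name] = value
```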
......@@ -773,7 +770,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.set_meta_tensor(True, recipe)
self.set_meta_tensor(False, recipe)
-self.fp8_meta_tensors_initialized = True
+self.fast_setattr("fp8_meta_tensors_initialized", True)
def get_fp8_meta_tensors(self) -> None:
"""Get scales and amaxes."""
......@@ -930,7 +927,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
"""Get activation data type for AMP."""
# Native AMP (`torch.autocast`) gets highest priority
if torch.is_autocast_enabled():
-self.activation_dtype = torch_get_autocast_gpu_dtype()
+self.fast_setattr("activation_dtype", torch_get_autocast_gpu_dtype())
return
# All checks after this have already been performed once, thus skip
......@@ -945,7 +942,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
"Data types for parameters must match when outside of autocasted region. "
f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}"
)
-self.activation_dtype = dtype
+self.fast_setattr("activation_dtype", dtype)
def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
"""
......@@ -957,8 +954,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
tp_group : ProcessGroup, default = None
tensor parallel process group.
"""
-self.tp_group = tp_group
-self.tp_group_initialized = True
+self.fast_setattr("tp_group", tp_group)
+self.fast_setattr("tp_group_initialized", True)
def _get_fp8_params(self) -> Union[List[torch.Tensor], None]:
"""returns the FP8 weights."""
......@@ -974,48 +971,51 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# assume FP8 execution.
def init_fp8_metadata(self, num_gemms: int = 1) -> None:
"""Initialize fp8 related metadata and tensors during fprop."""
-_original_recipe = self.fp8_meta.get("recipe", None)
-self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
-self.fp8 = FP8GlobalStateManager.is_fp8_enabled()
-self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
-fp8_enabled = self.fp8 or self.fp8_calibration
-self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration
-if self.fp8_parameters or fp8_enabled:
-if (
-self.fp8_initialized
-and FP8GlobalStateManager.get_fp8_recipe() == self.fp8_meta["recipe"]
-):
+meta = self.fp8_meta
+fp8 = FP8GlobalStateManager.is_fp8_enabled()
+fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
+fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
+self.fast_setattr("fp8_parameters", fp8_parameters)
+self.fast_setattr("fp8", fp8)
+self.fast_setattr("fp8_calibration", fp8_calibration)
+fp8_enabled = fp8 or fp8_calibration
+meta["fp8_checkpoint"] = fp8_enabled
+_original_recipe = None
+if fp8_parameters or fp8_enabled:
+_original_recipe = meta.get("recipe", None)
+if self.fp8_initialized and FP8GlobalStateManager.get_fp8_recipe() == _original_recipe:
# FP8 init has already been run and recipe is the same, don't do anything.
return
-self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
+meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
else:
# If fp8 isn't enabled, turn off and return.
-self.fp8_initialized = False
+self.fast_setattr("fp8_initialized", False)
return
-if self.fp8_parameters and not self.fp8_initialized:
-self.fp8_meta["num_gemms"] = num_gemms
-self.init_fp8_meta_tensors(self.fp8_meta["recipe"])
+if fp8_parameters and not self.fp8_initialized:
+meta["num_gemms"] = num_gemms
+self.init_fp8_meta_tensors(meta["recipe"])
if fp8_enabled:
# Set FP8 and other FP8 metadata
-self.fp8_meta["num_gemms"] = num_gemms
-self.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()
+meta["num_gemms"] = num_gemms
+meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()
# Set FP8_MAX per tensor according to recipe
-if hasattr(self.fp8_meta["recipe"], "fp8_format"):
-self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd
-self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd
+if hasattr(meta["recipe"], "fp8_format"):
+meta["fp8_max_fwd"] = meta["recipe"].fp8_format.value.max_fwd
+meta["fp8_max_bwd"] = meta["recipe"].fp8_format.value.max_bwd
# Allocate scales and amaxes
-self.init_fp8_meta_tensors(self.fp8_meta["recipe"])
-self.fp8_initialized = True
+self.init_fp8_meta_tensors(meta["recipe"])
+self.fast_setattr("fp8_initialized", True)
-self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
+meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
-_current_recipe = self.fp8_meta["recipe"]
+_current_recipe = meta["recipe"]
if _original_recipe is not None and not (
issubclass(_current_recipe.__class__, _original_recipe.__class__)
or issubclass(_original_recipe.__class__, _current_recipe.__class__)
......@@ -1028,22 +1028,18 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# Clear cached workspaces as they were created with the old recipe/quantizer type
self._fp8_workspaces.clear()
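Beyond `fast_setattr`, the rewritten `init_fp8_metadata` binds `self.fp8_meta` to a local `meta` once and reads the global FP8 state into locals, so each subsequent use is a single dict access instead of an attribute lookup plus a dict access. A minimal illustration of that micro-optimization (hypothetical class, same idea):

```python
# Hypothetical illustration of the local-alias pattern from init_fp8_metadata:
# hoisting self.fp8_meta into a local removes one attribute lookup per use.
class Module:
    def __init__(self):
        self.fp8_meta = {}

    def update_slow(self, n):
        for i in range(n):
            self.fp8_meta["step"] = i   # attribute + dict lookup each time

    def update_fast(self, n):
        meta = self.fp8_meta            # one attribute lookup, hoisted
        for i in range(n):
            meta["step"] = i            # dict lookup only

m = Module()
m.update_slow(3)
m.update_fast(3)                        # identical behavior, less work per write
```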
-@contextmanager
def prepare_forward(
self,
inp: torch.Tensor,
num_gemms: int = 1,
allow_non_contiguous: bool = False,
allow_different_data_and_param_types: bool = False,
-) -> Generator[torch.Tensor, None, None]:
-"""Checks and prep for FWD.
-The context manager is needed because there isn't a way for a module to know
-if it's the last FP8 module in the forward autocast. It is useful
-to setup the forward aggregated amax reduction for every module
-just in case. The autocast exit will pick up the most recent one.
-"""
-self.allow_different_data_and_param_types = allow_different_data_and_param_types
-self.forwarded_at_least_once = True
+) -> torch.Tensor:
+"""Checks and prepares for FWD execution."""
+self.fast_setattr(
+"allow_different_data_and_param_types", allow_different_data_and_param_types
+)
+self.fast_setattr("forwarded_at_least_once", True)
# Activation recomputation is used and this is the second forward phase.
if self.fp8 and in_fp8_activation_recompute_phase():
......@@ -1074,13 +1070,37 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
if self.training and is_fp8_activation_recompute_enabled():
FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)
-with get_nvtx_range_context(self.__class__.__name__ + " forward"):
+nvtx_range_push(self.__class__.__name__ + " forward")
if not allow_non_contiguous and not inp.is_contiguous():
inp = inp.contiguous()
-yield inp
+return inp
+def end_forward(self):
+"""
+Must be called at the end of the forward function to finalize DelayedScaling
+metadata handling and close the NVTX range.
+"""
+delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed()
+if delayed_scaling_recipe and self.fp8 and in_fp8_activation_recompute_phase():
+FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
+nvtx_range_pop()
+@contextmanager
+def prepare_forward_ctx(
+self,
+inp: torch.Tensor,
+num_gemms: int = 1,
+allow_non_contiguous: bool = False,
+allow_different_data_and_param_types: bool = False,
+) -> Generator[torch.Tensor, None, None]:
+"""Checks and prepares for FWD execution."""
+inp = self.prepare_forward(
+inp, num_gemms, allow_non_contiguous, allow_different_data_and_param_types
+)
+try:
+yield inp
+finally:
+self.end_forward()
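The old `prepare_forward` was a `@contextmanager`, so every forward call allocated a generator and paid the context-manager protocol cost. It is now a plain function paired with `end_forward`, while `prepare_forward_ctx` keeps the `with`-style API as a thin try/finally wrapper for call sites that were not converted. The shape of the refactor, as a generic hedged sketch (hypothetical names):

```python
# Generic sketch of the contextmanager -> begin/end refactor (hypothetical
# names): hot paths call begin/end directly and skip the generator frame,
# while the thin wrapper preserves the old `with` interface.
from contextlib import contextmanager

class Layer:
    def begin_forward(self, inp):
        # checks, NVTX range push, etc.
        return inp

    def end_forward(self):
        # NVTX range pop, metadata restore, etc.
        pass

    @contextmanager
    def forward_ctx(self, inp):
        inp = self.begin_forward(inp)
        try:
            yield inp
        finally:
            self.end_forward()

layer = Layer()
x = layer.begin_forward("x")       # converted hot path: explicit pairing
try:
    out = x.upper()                # stand-in for the real forward body
finally:
    layer.end_forward()
with layer.forward_ctx("x") as x:  # unconverted call sites keep working
    out = x.upper()
```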
def set_nccl_overlap_warning_if_tp(self) -> None:
"""When using TP, the NCCL communication needs to be scheduled
......@@ -1315,9 +1335,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# Update the parameter based on its type
if not is_dtensor:
-setattr(self, name, param)
+self.module_setattr(name, param)
else:
-setattr(self, name, dtensor_param)
+self.module_setattr(name, dtensor_param)
@abstractmethod
def forward(self):
......@@ -1516,7 +1536,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
debug = TEDebugState.debug_enabled
if not debug:
return False
-self._validate_name()
# If layer is run first time in new iteration,
# we need to check if the debug should be enabled for this layer -
......@@ -1530,14 +1549,14 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
debug = False
else:
debug = TEDebugState.get_iteration() >= self.next_iter_when_debug_should_be_run
-self.debug_last_iteration = TEDebugState.get_iteration()
-self.debug_enabled_in_this_iteration = debug
+self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())
+self.fast_setattr("debug_enabled_in_this_iteration", debug)
else:
# If this is the same iteration as previous invocation of the module,
# we use the debug value from the first invocation in the iteration.
debug = self.debug_enabled_in_this_iteration
-self.debug_last_iteration = TEDebugState.get_iteration()
+self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())
if self.wgrad_store is not None:
if debug and self.wgrad_store.delay_wgrad_compute():
......@@ -1553,7 +1572,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# Sometimes features inform that they will not be enabled for particular layer
# for multiple next iterations.
-self.next_iter_when_debug_should_be_run = next_iter_when_debug_should_be_run(quantizers)
+self.fast_setattr(
+"next_iter_when_debug_should_be_run", next_iter_when_debug_should_be_run(quantizers)
+)
if not run_current:
return True
......@@ -1565,21 +1586,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
def _validate_name(self):
"""
Validate name passed to the module.
-This is invoked in the forward() method as module names are assigned after Model is initialized in Megatron-LM.
-If no name is assigned, it creates a default name with layer count as the variable.
+It creates a default name with layer count as the variable
+which may be changed by the user of the module.
"""
if self.name is not None:
return
-assert TEDebugState.debug_enabled
-import nvdlfw_inspect.api as debug_api
-if self.name is None:
-debug_api.log_message(
-"Names are not provided to debug modules. ",
-"Creating and using generic names. Pass names to debug modules for better"
-" insight. ",
-level=logging.WARNING,
-)
self.name = f"Layer_{TEDebugState.get_layer_count()}"
def _check_weight_tensor_recipe_correspondence(self) -> None:
......
......@@ -614,7 +614,7 @@ class GroupedLinear(TransformerEngineBaseModule):
save_original_input: bool = False,
name: Optional[str] = None,
) -> None:
-super().__init__()
+super().__init__(name)
params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
self.num_gemms = num_gemms
......@@ -633,7 +633,6 @@ class GroupedLinear(TransformerEngineBaseModule):
), "GroupedLinear doesn't support Userbuffer overlap."
self.get_rng_state_tracker = get_rng_state_tracker
self.rng_tracker_name = rng_tracker_name
-self.name = name
self.wgrad_store = WeightGradStore(delay_wgrad_compute)
......@@ -789,7 +788,8 @@ class GroupedLinear(TransformerEngineBaseModule):
is_grad_enabled = torch.is_grad_enabled()
-with self.prepare_forward(inp, num_gemms=self.num_gemms) as inp:
+inp = self.prepare_forward(inp, num_gemms=self.num_gemms)
+try:
weight_tensors = self._get_weight_tensors()
bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
......@@ -844,6 +844,9 @@ class GroupedLinear(TransformerEngineBaseModule):
)
out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors)
+finally:
+self.end_forward()
if self.return_bias:
return out, [cast_if_needed(b, self.activation_dtype) for b in bias_tensors]
return out
......
......@@ -1158,9 +1158,9 @@ class LayerNormLinear(TransformerEngineBaseModule):
ub_name: Optional[str] = None,
delay_wgrad_compute: bool = False,
symmetric_ar_type: Optional[str] = None,
-name: str = None,
+name: Optional[str] = None,
) -> None:
-super().__init__()
+super().__init__(name)
params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
self.in_features = in_features
......@@ -1179,7 +1179,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.symmetric_ar_type = symmetric_ar_type
self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)
-self.name = name
if tp_group is None:
self.tp_size = tp_size
......@@ -1508,10 +1507,11 @@ class LayerNormLinear(TransformerEngineBaseModule):
).is_fp8_ubuf():
fp8_grad = True
-with self.prepare_forward(
+inp = self.prepare_forward(
inp, allow_non_contiguous=False  # removed .contiguous from inside the layer
-) as inp:
+)
+try:
# Get concatenated weight and bias tensors
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
......@@ -1590,6 +1590,9 @@ class LayerNormLinear(TransformerEngineBaseModule):
non_tensor_args,
)
+finally:
+self.end_forward()
if self.return_layernorm_output:
out, ln_out = out
......
......@@ -1787,7 +1787,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
zero_centered_gamma: bool = False,
device: Union[torch.device, str] = "cuda",
ub_overlap_ag: bool = False,
-name: str = None,
+name: Optional[str] = None,
ub_overlap_rs: bool = False,
ub_overlap_rs_dgrad: bool = False,
ub_bulk_dgrad: bool = False,
......@@ -1796,7 +1796,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
symmetric_ar_type: Optional[str] = None,
checkpoint: bool = False,
) -> None:
-super().__init__()
+super().__init__(name)
params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
......@@ -1827,7 +1827,6 @@ class LayerNormMLP(TransformerEngineBaseModule):
for use_fp8 in [False, True]
)
)
-self.name = name
self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)
......@@ -2047,8 +2046,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
if get_ub("fc2_fprop", FP8GlobalStateManager.is_fp8_enabled()).is_fp8_ubuf():
fp8_output = True
-with self.prepare_forward(inp, num_gemms=2) as inp:
+inp = self.prepare_forward(inp, num_gemms=2)
+try:
quantizers = (
self._get_quantizers(fp8_output, is_grad_enabled)
if not debug
......@@ -2087,7 +2087,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
# Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode
if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute():
-self.bias_gelu_nvfusion = False
+self.fast_setattr("bias_gelu_nvfusion", False)
if is_grad_enabled:
fwd_fn = _LayerNormMLP.apply
......@@ -2157,6 +2157,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
non_tensor_args,
)
+finally:
+self.end_forward()
if self.return_layernorm_output:
out, ln_out = out
......
......@@ -428,8 +428,8 @@ class _Linear(torch.autograd.Function):
# weights if weights are externally touched outside this module
ctx.weight_object = weight
+if cpu_offloading:
mark_not_offload(weight, weightmat, bias)
# TODO(ksivamani): Check memory usage
tensors_to_save, tensor_objects = prepare_for_saving(
saved_inputmat,
......@@ -1098,7 +1098,7 @@ class Linear(TransformerEngineBaseModule):
save_original_input: bool = False,
name: Optional[str] = None,
) -> None:
-super().__init__()
+super().__init__(name)
params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
self.in_features = in_features
......@@ -1111,7 +1111,6 @@ class Linear(TransformerEngineBaseModule):
self.rng_tracker_name = rng_tracker_name
self.symmetric_ar_type = symmetric_ar_type
self.save_original_input = save_original_input
-self.name = name
self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)
......@@ -1395,11 +1394,8 @@ class Linear(TransformerEngineBaseModule):
).is_fp8_ubuf():
fp8_grad = True
-with self.prepare_forward(
-inp,
-allow_non_contiguous=isinstance(inp, QuantizedTensor),
-) as inp:
+inp = self.prepare_forward(inp, allow_non_contiguous=isinstance(inp, QuantizedTensor))
+try:
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
quantizers = (
......@@ -1470,6 +1466,8 @@ class Linear(TransformerEngineBaseModule):
bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None,
non_tensor_args,
)
+finally:
+self.end_forward()
if self.gemm_bias_unfused_add:
out = out + cast_if_needed(bias_tensor, self.activation_dtype)
......
......@@ -12,7 +12,6 @@ import torch
from transformer_engine.pytorch.torch_version import torch_version
from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm
-from transformer_engine.debug.pytorch.debug_state import TEDebugState
from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention
from transformer_engine.pytorch.attention.inference import InferenceParams
from transformer_engine.pytorch.jit import (
......@@ -398,6 +397,7 @@ class TransformerLayer(torch.nn.Module):
self.softmax_type = softmax_type
self.name = name
+TransformerEngineBaseModule._validate_name(self)
attention_args = (
hidden_size,
......@@ -446,7 +446,7 @@ class TransformerLayer(torch.nn.Module):
qk_norm_type=qk_norm_type,
qk_norm_eps=qk_norm_eps,
qk_norm_before_rope=qk_norm_before_rope,
-name=name + ".self_attention" if name is not None else None,
+name=self.name + ".self_attention" if self.name is not None else None,
)
if layer_type == "decoder":
......@@ -463,7 +463,7 @@ class TransformerLayer(torch.nn.Module):
qk_norm_type=qk_norm_type,
qk_norm_eps=qk_norm_eps,
qk_norm_before_rope=qk_norm_before_rope,
-name=name + ".inter_attention" if name is not None else None,
+name=self.name + ".inter_attention" if self.name is not None else None,
)
# LayerNorm -> activation(Linear + Bias) -> Linear
......@@ -499,7 +499,7 @@ class TransformerLayer(torch.nn.Module):
activation_params=activation_params,
normalization=normalization,
device=device,
-name=name + ".layernorm_mlp" if name is not None else None,
+name=self.name + ".layernorm_mlp" if self.name is not None else None,
)
self.hidden_dropout = hidden_dropout
......@@ -768,9 +768,6 @@ class TransformerLayer(torch.nn.Module):
enc_dec_attn_mask[i].dtype == torch.bool for i in range(len(enc_dec_attn_mask))
), "Encoder-decoder attention mask must be boolean tensor(s)"
-if TEDebugState.debug_enabled:
-TransformerEngineBaseModule._validate_name(self)
# For AMP
if torch.is_autocast_enabled():
hidden_states = cast_if_needed(hidden_states, torch_get_autocast_gpu_dtype())
......
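Taken together with the base-module changes, `name` is now threaded through the constructors, so a single name on a `TransformerLayer` yields dotted names on its children instead of the old auto-generated `Layer_<n>` fallback. A hedged usage sketch (the sizes and the name are illustrative; running it requires a CUDA-capable setup):

```python
# Illustrative usage of the constructor-time naming (hypothetical sizes):
# the layer name propagates as "decoder_0.self_attention",
# "decoder_0.layernorm_mlp", etc., matching the diff above.
import transformer_engine.pytorch as te

layer = te.TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
    name="decoder_0",
)
```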