Unverified Commit 605786f4 authored by Przemyslaw Tredak, committed by GitHub

[pyTorch] CPU performance optimizations (#2439)



* PoC of the changes
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Early exit from the Free function for the empty tensor
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Use the proper function for nvtx range
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Only do mark_not_offload when the cpu_offloading is enabled
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* First pass on making the setattr issue not come back
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Actually add pytest.ini
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Changes to __init__
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* A different way
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* WAR the fact that it is not possible to set __setattr__ dynamically
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Simpler solution and fixes
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix for the inference mode DPA
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Start of debugging the debug tools
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* More fixes in debug
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Speculatively moving validate_name to the constructor
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Making the debug tools names saner
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Change the setattr usage in the tensor parallel group setting
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Adding try/finally - it does not seem to impact the time in an observable way
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fixing lint issues and the thunder test
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix 1 of the debug tests
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Removed the warning and enforcement in the CI
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* try-finally in the context manager
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fixing the debug tests
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 36f4e451
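The main CPU saving behind the setattr-related commits above is avoiding torch.nn.Module.__setattr__ for plain attribute writes on the hot path; the base module gains fast_setattr/module_setattr helpers for this (see the base.py hunks further down). A minimal, self-contained sketch of the idea follows; the class and attribute names are illustrative, not the actual Transformer Engine implementation.

import torch

class FastAttrModule(torch.nn.Module):
    """Illustrative only: not the Transformer Engine class, just the pattern it uses."""

    def fast_setattr(self, name, value):
        # torch.nn.Module.__setattr__ walks the _parameters/_buffers/_modules
        # registries on every assignment; writing straight into __dict__ skips
        # that work. Only safe for plain Python attributes, never for
        # Parameters, buffers, or submodules.
        self.__dict__[name] = value

m = FastAttrModule()
m.fast_setattr("fp8_initialized", False)  # hot-path flag: fast path
m.proj = torch.nn.Linear(4, 4)            # submodule: keep the regular __setattr__ path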
@@ -2790,7 +2790,7 @@ class Custom_MHA_FP8(TransformerEngineBaseModule):
        cu_seqlens,
        max_s,
    ) -> torch.Tensor:
-        with self.prepare_forward(inp, num_gemms=3) as inp:
+        with self.prepare_forward_ctx(inp, num_gemms=3) as inp:
            out = _custom_mha_fp8.apply(
                inp,
                self.qkv_weight,
...
@@ -30,10 +30,17 @@ configs = {
      stats: [min, max, mean, std, l1_norm, l2_norm, cur_amax, dynamic_range]
      start_step : 0
      end_step: 1
+""",
+    "log_fp8": """log_fp8:
+  layers:
+    layer_types: [linear]
+  enabled:
+    True
+  transformer_engine:
    LogFp8TensorStats:
      enabled: True
      tensors: [activation, gradient, weight]
-      stats: [underflows, overflows]
+      stats: [underflows%]
      start_step : 0
      end_step: 1
""",
@@ -46,22 +53,26 @@ fake_quant_config:
    FakeQuant:
      enabled: True
      gemms: [fprop, dgrad, wgrad]
+      tensors: [activation, weight, gradient]
      quant_format: FP8E5M2
""",
}

+# Configs that require FP8 to be enabled
+fp8_required_configs = {"log_fp8"}
+

def _get_model(model_key):
    if model_key == "linear":
-        return te.Linear(D, D)
+        return te.Linear(D, D, name="layer")
    if model_key == "layernorm_linear":
-        return te.LayerNormLinear(D, D)
+        return te.LayerNormLinear(D, D, name="layer")
    if model_key == "layernorm_mlp":
-        return te.LayerNormMLP(D, D, D)
+        return te.LayerNormMLP(D, D, D, name="layer")
    if model_key == "mha_attention":
-        return te.MultiheadAttention(D, H)
+        return te.MultiheadAttention(D, H, name="layer")
    if model_key == "transformer_layer":
-        return te.TransformerLayer(D, D, H)
+        return te.TransformerLayer(D, D, H, name="layer")


def _run_forward_backward(model, fp8):
@@ -95,4 +106,6 @@ def _run_test(model_key, fp8, config, feature_dirs, config_file, log_dir):
def test_sanity_debug(model_key, fp8, config_key, feature_dirs):
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
+    if not fp8 and config_key in fp8_required_configs:
+        pytest.skip(f"Config '{config_key}' requires FP8")
    _run_test(model_key, fp8, configs[config_key], feature_dirs)
@@ -454,9 +454,9 @@ class TensorAllocator {
  }

  void Free(NVTETensor t) {
-    std::lock_guard<std::mutex> lock(mutex);
    uintptr_t index = reinterpret_cast<uintptr_t>(t);
    if (index == 0) return;
+    std::lock_guard<std::mutex> lock(mutex);
    NVTE_CHECK(index <= memory.size(), "Invalid tensor.");
    free_list.push_back(index);
    // Clean up
@@ -564,9 +564,9 @@ class GroupedTensorAllocator {
  }

  void Free(NVTEGroupedTensor t) {
-    std::lock_guard<std::mutex> lock(mutex);
    uintptr_t index = reinterpret_cast<uintptr_t>(t);
    if (index == 0) return;
+    std::lock_guard<std::mutex> lock(mutex);
    NVTE_CHECK(index <= memory.size(), "Invalid grouped tensor.");
    free_list.push_back(index);
    // Clean up
...
@@ -676,9 +676,9 @@ class DotProductAttention(TransformerEngineBaseModule):
            # assume attention uses the same fp8_group as GEMMs
            fp8_group = FP8GlobalStateManager.get_fp8_group()
-            self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
-            self.fp8 = FP8GlobalStateManager.is_fp8_enabled()
-            self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
+            self.fast_setattr("fp8_parameters", FP8GlobalStateManager.with_fp8_parameters())
+            self.fast_setattr("fp8", FP8GlobalStateManager.is_fp8_enabled())
+            self.fast_setattr("fp8_calibration", FP8GlobalStateManager.is_fp8_calibration())
            fp8_enabled = self.fp8 or self.fp8_calibration
            self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration
            if self.fp8_parameters or fp8_enabled:
@@ -703,7 +703,7 @@ class DotProductAttention(TransformerEngineBaseModule):
                )
            else:
                # If fp8 isn't enabled, turn off and return.
-                self.fp8_initialized = False
+                self.fast_setattr("fp8_initialized", False)
                return

            if self.fp8_parameters and not self.fp8_initialized:
@@ -721,7 +721,7 @@ class DotProductAttention(TransformerEngineBaseModule):
                # Allocate scales and amaxes
                self.init_fp8_meta_tensors(fp8_recipes)
-                self.fp8_initialized = True
+                self.fast_setattr("fp8_initialized", True)

            self.fp8_meta["recipe"] = fp8_recipe_dpa
            if fp8_recipe != fp8_recipe_dpa:
@@ -1000,7 +1000,7 @@ class DotProductAttention(TransformerEngineBaseModule):
            cases. It is ignored for other backends and when context parallelism is enabled.
        """

-        with self.prepare_forward(
+        with self.prepare_forward_ctx(
            query_layer,
            num_gemms=3,
            allow_non_contiguous=True,
@@ -1145,10 +1145,11 @@ class DotProductAttention(TransformerEngineBaseModule):
            if attn_mask_type == "padding_causal":
                attn_mask_type = attn_mask_type + "_bottom_right"
-            self.attention_type = "cross"
-            self.flash_attention.attention_type = self.attention_type
-            self.fused_attention.attention_type = self.attention_type
-            self.unfused_attention.attention_type = self.attention_type
+            if self.attention_type != "cross":
+                self.fast_setattr("attention_type", "cross")
+                self.flash_attention.attention_type = self.attention_type
+                self.fused_attention.attention_type = self.attention_type
+                self.unfused_attention.attention_type = self.attention_type
        query_layer, key_layer, value_layer = [
            x.contiguous() if not x.is_contiguous() else x
...
@@ -8,7 +8,6 @@ import collections
from typing import Callable, List, Optional, Tuple, Union
import torch

-from transformer_engine.debug.pytorch.debug_state import TEDebugState
from transformer_engine.pytorch.quantization import FP8GlobalStateManager
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
@@ -335,6 +334,7 @@ class MultiheadAttention(torch.nn.Module):
        self.hidden_size_kv = self.hidden_size_per_attention_head * self.num_gqa_groups

        self.name = name
+        TransformerEngineBaseModule._validate_name(self)

        common_gemm_kwargs = {
            "fuse_wgrad_accumulation": fuse_wgrad_accumulation,
@@ -739,9 +739,6 @@ class MultiheadAttention(torch.nn.Module):
            core_attention_bias_type in AttnBiasTypes
        ), f"core_attention_bias_type {core_attention_bias_type} is not supported!"

-        if TEDebugState.debug_enabled:
-            TransformerEngineBaseModule._validate_name(self)
-
        # =================================================
        # Pre-allocate memory for key-value cache for inference
        # =================================================
...
@@ -729,8 +729,8 @@ def checkpoint(
    if isinstance(function, TransformerEngineBaseModule):
        # If this TE module is FSDP-wrapped, clear its FSDP group information because there's no need
        # to scatter/gather activations that we will recompute anyway.
-        setattr(function, "fsdp_wrapped", False)
-        setattr(function, "fsdp_group", None)
+        function.fast_setattr("fsdp_wrapped", False)
+        function.fast_setattr("fsdp_group", None)

    # Otherwise discard unused te.utils.checkpoint.checkpoint() arguments
    # and execute TE's own checkpointing
@@ -2022,7 +2022,7 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None:
    )
    root_state = _get_module_fsdp_state(fsdp_root)
    assert root_state is not None, "Root module does not have a valid _FSDPState."
-    setattr(fsdp_root.module, "fsdp_group", root_state.process_group)
+    fsdp_root.module.fast_setattr("fsdp_group", root_state.process_group)

    # Iterate through all FSDP-wrapped submodules and inject FSDP information into TE modules
    fsdp_states, fsdp_modules = _get_fsdp_states_with_modules(fsdp_root)
@@ -2033,7 +2033,7 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None:
                    "TE modules with primary weights in FP8 cannot be FSDP-wrapped. "
                    "Please initialize your model without the te.quantized_model_init(...) context."
                )
-            setattr(fsdp_module.module, "fsdp_group", state.process_group)
+            fsdp_module.module.fast_setattr("fsdp_group", state.process_group)


class FullyShardedDataParallel(FSDP):
...
@@ -10,9 +10,8 @@ import pickle
import warnings
from enum import Enum
from abc import ABC, abstractmethod
-from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union
+from typing import Any, Dict, Generator, List, Optional, Tuple, Union
from contextlib import contextmanager
-import logging
from types import MethodType

import torch
@@ -50,6 +49,8 @@ from ..utils import (
    is_non_tn_fp8_gemm_supported,
    torch_get_autocast_gpu_dtype,
    get_nvtx_range_context,
+    nvtx_range_push,
+    nvtx_range_pop,
)
from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
from ...common.recipe import DelayedScaling, Recipe
@@ -605,10 +606,10 @@ def fill_userbuffers_buffer_for_all_gather(
class TransformerEngineBaseModule(torch.nn.Module, ABC):
    """Base TE module."""

-    def __init__(self) -> None:
+    def __init__(self, name: Optional[str] = None) -> None:
        super().__init__()
        assert torch.cuda.is_available(), "TransformerEngine needs CUDA."
-        self.name = None
+        self.name = name
        self.next_iter_when_debug_should_be_run = 0
        self.fp8_initialized = False
        self.fp8 = False
@@ -633,26 +634,22 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        if not TEDebugState.debug_enabled:
            TEDebugState.initialize()
+        self._validate_name()

-    # Names of attributes that can be set quickly (see __setattr__
-    # method)
-    _fast_setattr_names: Set[str] = {
-        "activation_dtype",
-        "fp8",
-        "fp8_initialized",
-        "fp8_calibration",
-        "fp8_parameters",
-    }
-
-    def __setattr__(self, name: str, value: Any) -> None:
-        if name in TransformerEngineBaseModule._fast_setattr_names:
-            # torch.nn.Module has a custom __setattr__ that handles
-            # modules, parameters, and buffers. This is unnecessary
-            # overhead when setting plain attrs.
-            self.__dict__[name] = value
-        else:
-            # Default case
-            super().__setattr__(name, value)
+    def fast_setattr(self, name: str, value: Any) -> None:
+        """
+        Fast version of the Module's set attribute function.
+        Should be used for regular attributes, but not properties nor parameters/buffers.
+        """
+        self.__dict__[name] = value
+
+    def module_setattr(self, name: str, value: Any) -> None:
+        """
+        Regular version of the Module's set attribute function.
+        Should be used only when the fast version cannot be used - for the properties,
+        parameters and buffers.
+        """
+        super().__setattr__(name, value)

    def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None:
        """
@@ -773,7 +770,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        self.set_meta_tensor(True, recipe)
        self.set_meta_tensor(False, recipe)

-        self.fp8_meta_tensors_initialized = True
+        self.fast_setattr("fp8_meta_tensors_initialized", True)

    def get_fp8_meta_tensors(self) -> None:
        """Get scales and amaxes."""
@@ -930,7 +927,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        """Get activation data type for AMP."""
        # Native AMP (`torch.autocast`) gets highest priority
        if torch.is_autocast_enabled():
-            self.activation_dtype = torch_get_autocast_gpu_dtype()
+            self.fast_setattr("activation_dtype", torch_get_autocast_gpu_dtype())
            return

        # All checks after this have already been performed once, thus skip
@@ -945,7 +942,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
                    "Data types for parameters must match when outside of autocasted region. "
                    f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}"
                )
-        self.activation_dtype = dtype
+        self.fast_setattr("activation_dtype", dtype)

    def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
        """
@@ -957,8 +954,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        tp_group : ProcessGroup, default = None
                   tensor parallel process group.
        """
-        self.tp_group = tp_group
-        self.tp_group_initialized = True
+        self.fast_setattr("tp_group", tp_group)
+        self.fast_setattr("tp_group_initialized", True)

    def _get_fp8_params(self) -> Union[List[torch.Tensor], None]:
        """returns the FP8 weights."""
@@ -974,48 +971,51 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
    # assume FP8 execution.
    def init_fp8_metadata(self, num_gemms: int = 1) -> None:
        """Initialize fp8 related metadata and tensors during fprop."""
-        _original_recipe = self.fp8_meta.get("recipe", None)
-
-        self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
-        self.fp8 = FP8GlobalStateManager.is_fp8_enabled()
-        self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
-        fp8_enabled = self.fp8 or self.fp8_calibration
-        self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration
-
-        if self.fp8_parameters or fp8_enabled:
-            if (
-                self.fp8_initialized
-                and FP8GlobalStateManager.get_fp8_recipe() == self.fp8_meta["recipe"]
-            ):
+        meta = self.fp8_meta
+        fp8 = FP8GlobalStateManager.is_fp8_enabled()
+        fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
+        fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
+        self.fast_setattr("fp8_parameters", fp8_parameters)
+        self.fast_setattr("fp8", fp8)
+        self.fast_setattr("fp8_calibration", fp8_calibration)
+        fp8_enabled = fp8 or fp8_calibration
+        meta["fp8_checkpoint"] = fp8_enabled
+
+        _original_recipe = None
+
+        if fp8_parameters or fp8_enabled:
+            _original_recipe = meta.get("recipe", None)
+            if self.fp8_initialized and FP8GlobalStateManager.get_fp8_recipe() == _original_recipe:
                # FP8 init has already been run and recipe is the same, don't do anything.
                return
-            self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
+            meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
        else:
            # If fp8 isn't enabled, turn off and return.
-            self.fp8_initialized = False
+            self.fast_setattr("fp8_initialized", False)
            return

-        if self.fp8_parameters and not self.fp8_initialized:
-            self.fp8_meta["num_gemms"] = num_gemms
-            self.init_fp8_meta_tensors(self.fp8_meta["recipe"])
+        if fp8_parameters and not self.fp8_initialized:
+            meta["num_gemms"] = num_gemms
+            self.init_fp8_meta_tensors(meta["recipe"])

        if fp8_enabled:
            # Set FP8 and other FP8 metadata
-            self.fp8_meta["num_gemms"] = num_gemms
-            self.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()
+            meta["num_gemms"] = num_gemms
+            meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()

            # Set FP8_MAX per tensor according to recipe
-            if hasattr(self.fp8_meta["recipe"], "fp8_format"):
-                self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd
-                self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd
+            if hasattr(meta["recipe"], "fp8_format"):
+                meta["fp8_max_fwd"] = meta["recipe"].fp8_format.value.max_fwd
+                meta["fp8_max_bwd"] = meta["recipe"].fp8_format.value.max_bwd

            # Allocate scales and amaxes
-            self.init_fp8_meta_tensors(self.fp8_meta["recipe"])
-            self.fp8_initialized = True
-            self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
+            self.init_fp8_meta_tensors(meta["recipe"])
+            self.fast_setattr("fp8_initialized", True)
+            meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()

-        _current_recipe = self.fp8_meta["recipe"]
+        _current_recipe = meta["recipe"]
        if _original_recipe is not None and not (
            issubclass(_current_recipe.__class__, _original_recipe.__class__)
            or issubclass(_original_recipe.__class__, _current_recipe.__class__)
@@ -1028,22 +1028,18 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
            # Clear cached workspaces as they were created with the old recipe/quantizer type
            self._fp8_workspaces.clear()

-    @contextmanager
    def prepare_forward(
        self,
        inp: torch.Tensor,
        num_gemms: int = 1,
        allow_non_contiguous: bool = False,
        allow_different_data_and_param_types: bool = False,
-    ) -> Generator[torch.Tensor, None, None]:
-        """Checks and prep for FWD.
-        The context manager is needed because there isn't a way for a module to know
-        if it's the last FP8 module in the forward autocast. It is useful
-        to setup the forward aggregated amax reduction for every module
-        just in case. The autocast exit will pick up the most recent one.
-        """
-        self.allow_different_data_and_param_types = allow_different_data_and_param_types
-        self.forwarded_at_least_once = True
+    ) -> torch.Tensor:
+        """Checks and prepares for FWD execution."""
+        self.fast_setattr(
+            "allow_different_data_and_param_types", allow_different_data_and_param_types
+        )
+        self.fast_setattr("forwarded_at_least_once", True)

        # Activation recomputation is used and this is the second forward phase.
        if self.fp8 and in_fp8_activation_recompute_phase():
@@ -1074,13 +1070,37 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        if self.training and is_fp8_activation_recompute_enabled():
            FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)

-        with get_nvtx_range_context(self.__class__.__name__ + " forward"):
-            if not allow_non_contiguous and not inp.is_contiguous():
-                inp = inp.contiguous()
-            yield inp
-            if delayed_scaling_recipe and self.fp8 and in_fp8_activation_recompute_phase():
-                FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
+        nvtx_range_push(self.__class__.__name__ + " forward")
+        if not allow_non_contiguous and not inp.is_contiguous():
+            inp = inp.contiguous()
+        return inp
+
+    def end_forward(self):
+        """
+        Required to be called at the end of the forward function to properly handle
+        DelayedScaling metadata handling and the NVTX ranges.
+        """
+        delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed()
+        if delayed_scaling_recipe and self.fp8 and in_fp8_activation_recompute_phase():
+            FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
+        nvtx_range_pop()
+
+    @contextmanager
+    def prepare_forward_ctx(
+        self,
+        inp: torch.Tensor,
+        num_gemms: int = 1,
+        allow_non_contiguous: bool = False,
+        allow_different_data_and_param_types: bool = False,
+    ) -> Generator[torch.Tensor, None, None]:
+        """Checks and prepares for FWD execution."""
+        inp = self.prepare_forward(
+            inp, num_gemms, allow_non_contiguous, allow_different_data_and_param_types
+        )
+        try:
+            yield inp
+        finally:
+            self.end_forward()

    def set_nccl_overlap_warning_if_tp(self) -> None:
        """When using TP, the NCCL communication needs to be scheduled
@@ -1315,9 +1335,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
            # Update the parameter based on its type
            if not is_dtensor:
-                setattr(self, name, param)
+                self.module_setattr(name, param)
            else:
-                setattr(self, name, dtensor_param)
+                self.module_setattr(name, dtensor_param)

    @abstractmethod
    def forward(self):
@@ -1516,7 +1536,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        debug = TEDebugState.debug_enabled
        if not debug:
            return False
-        self._validate_name()

        # If layer is run first time in new iteration,
        # we need to check if the debug should be enabled for this layer -
@@ -1530,14 +1549,14 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
                debug = False
            else:
                debug = TEDebugState.get_iteration() >= self.next_iter_when_debug_should_be_run
-            self.debug_last_iteration = TEDebugState.get_iteration()
-            self.debug_enabled_in_this_iteration = debug
+            self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())
+            self.fast_setattr("debug_enabled_in_this_iteration", debug)
        else:
            # If this is the same iteration as previous invocation of the module,
            # we use the debug value from the first invocation in the iteration.
            debug = self.debug_enabled_in_this_iteration
-            self.debug_last_iteration = TEDebugState.get_iteration()
+            self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())

        if self.wgrad_store is not None:
            if debug and self.wgrad_store.delay_wgrad_compute():
@@ -1553,7 +1572,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        # Sometimes features inform that they will not be enabled for particular layer
        # for multiple next iterations.
-        self.next_iter_when_debug_should_be_run = next_iter_when_debug_should_be_run(quantizers)
+        self.fast_setattr(
+            "next_iter_when_debug_should_be_run", next_iter_when_debug_should_be_run(quantizers)
+        )

        if not run_current:
            return True
@@ -1565,22 +1586,13 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
    def _validate_name(self):
        """
        Validate name passed to the module.
-        This is invoked in the forward() method as module names are assigned after Model is initialized in Megatron-LM.
-        If no name is assigned, it creates a default name with layer count as the variable.
+        It creates a default name with layer count as the variable
+        which may be changed by the user of the module.
        """
        if self.name is not None:
            return
-        assert TEDebugState.debug_enabled
-
-        import nvdlfw_inspect.api as debug_api
-
-        if self.name is None:
-            debug_api.log_message(
-                "Names are not provided to debug modules. ",
-                "Creating and using generic names. Pass names to debug modules for better"
-                " insight. ",
-                level=logging.WARNING,
-            )
-            self.name = f"Layer_{TEDebugState.get_layer_count()}"
+
+        self.name = f"Layer_{TEDebugState.get_layer_count()}"

    def _check_weight_tensor_recipe_correspondence(self) -> None:
        """
...
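For orientation before the per-module hunks that follow: the generator-based prepare_forward context manager above is split into a plain prepare_forward/end_forward pair that modules call around a try/finally, while prepare_forward_ctx keeps the old with-statement interface. A self-contained toy sketch of that calling pattern, with placeholder computation standing in for the modules' real work:

from contextlib import contextmanager

class _ForwardPrep:
    """Toy stand-in for the prepare_forward/end_forward pair (illustrative only)."""

    def prepare_forward(self, inp):
        # The real module pushes an NVTX range and runs FP8 checks here.
        return inp

    def end_forward(self):
        # The real module restores DelayedScaling metadata and pops the NVTX range here.
        pass

    @contextmanager
    def prepare_forward_ctx(self, inp):
        # Compatibility wrapper with the same shape as in the diff above.
        inp = self.prepare_forward(inp)
        try:
            yield inp
        finally:
            self.end_forward()

m = _ForwardPrep()

# New hot path used inside the TE modules: plain call plus try/finally,
# avoiding generator/contextmanager overhead on every forward.
x = m.prepare_forward([1.0, 2.0])
try:
    out = [v * 2 for v in x]
finally:
    m.end_forward()

# Existing call sites can keep the with-statement form:
with m.prepare_forward_ctx([1.0, 2.0]) as x:
    out = [v * 2 for v in x]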
@@ -614,7 +614,7 @@ class GroupedLinear(TransformerEngineBaseModule):
        save_original_input: bool = False,
        name: Optional[str] = None,
    ) -> None:
-        super().__init__()
+        super().__init__(name)
        params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
        self.num_gemms = num_gemms
@@ -633,7 +633,6 @@ class GroupedLinear(TransformerEngineBaseModule):
        ), "GroupedLinear doesn't support Userbuffer overlap."
        self.get_rng_state_tracker = get_rng_state_tracker
        self.rng_tracker_name = rng_tracker_name
-        self.name = name

        self.wgrad_store = WeightGradStore(delay_wgrad_compute)
@@ -789,7 +788,8 @@ class GroupedLinear(TransformerEngineBaseModule):
        is_grad_enabled = torch.is_grad_enabled()

-        with self.prepare_forward(inp, num_gemms=self.num_gemms) as inp:
+        inp = self.prepare_forward(inp, num_gemms=self.num_gemms)
+        try:
            weight_tensors = self._get_weight_tensors()
            bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
@@ -844,6 +844,9 @@ class GroupedLinear(TransformerEngineBaseModule):
            )
            out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors)
+        finally:
+            self.end_forward()

        if self.return_bias:
            return out, [cast_if_needed(b, self.activation_dtype) for b in bias_tensors]
        return out
...
@@ -1158,9 +1158,9 @@ class LayerNormLinear(TransformerEngineBaseModule):
        ub_name: Optional[str] = None,
        delay_wgrad_compute: bool = False,
        symmetric_ar_type: Optional[str] = None,
-        name: str = None,
+        name: Optional[str] = None,
    ) -> None:
-        super().__init__()
+        super().__init__(name)
        params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
        self.in_features = in_features
@@ -1179,7 +1179,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
        self.symmetric_ar_type = symmetric_ar_type
        self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)
-        self.name = name

        if tp_group is None:
            self.tp_size = tp_size
@@ -1508,10 +1507,11 @@ class LayerNormLinear(TransformerEngineBaseModule):
        ).is_fp8_ubuf():
            fp8_grad = True

-        with self.prepare_forward(
-            inp, allow_non_contiguous=False  # removed .contiguous from inside the layer
-        ) as inp:
+        inp = self.prepare_forward(
+            inp, allow_non_contiguous=False  # removed .contiguous from inside the layer
+        )
+        try:
            # Get concatenated weight and bias tensors
            weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
@@ -1590,6 +1590,9 @@ class LayerNormLinear(TransformerEngineBaseModule):
                non_tensor_args,
            )
+        finally:
+            self.end_forward()

        if self.return_layernorm_output:
            out, ln_out = out
...
@@ -1787,7 +1787,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
        zero_centered_gamma: bool = False,
        device: Union[torch.device, str] = "cuda",
        ub_overlap_ag: bool = False,
-        name: str = None,
+        name: Optional[str] = None,
        ub_overlap_rs: bool = False,
        ub_overlap_rs_dgrad: bool = False,
        ub_bulk_dgrad: bool = False,
@@ -1796,7 +1796,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
        symmetric_ar_type: Optional[str] = None,
        checkpoint: bool = False,
    ) -> None:
-        super().__init__()
+        super().__init__(name)
        params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
        self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
@@ -1827,7 +1827,6 @@ class LayerNormMLP(TransformerEngineBaseModule):
                for use_fp8 in [False, True]
            )
        )
-        self.name = name

        self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)
@@ -2047,8 +2046,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
        if get_ub("fc2_fprop", FP8GlobalStateManager.is_fp8_enabled()).is_fp8_ubuf():
            fp8_output = True

-        with self.prepare_forward(inp, num_gemms=2) as inp:
+        inp = self.prepare_forward(inp, num_gemms=2)
+        try:
            quantizers = (
                self._get_quantizers(fp8_output, is_grad_enabled)
                if not debug
@@ -2087,7 +2087,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
            # Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode
            if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute():
-                self.bias_gelu_nvfusion = False
+                self.fast_setattr("bias_gelu_nvfusion", False)

            if is_grad_enabled:
                fwd_fn = _LayerNormMLP.apply
@@ -2157,6 +2157,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
                non_tensor_args,
            )
+        finally:
+            self.end_forward()

        if self.return_layernorm_output:
            out, ln_out = out
...
@@ -428,8 +428,8 @@ class _Linear(torch.autograd.Function):
        # weights if weights are externally touched outside this module
        ctx.weight_object = weight

-        mark_not_offload(weight, weightmat, bias)
-
+        if cpu_offloading:
+            mark_not_offload(weight, weightmat, bias)
        # TODO(ksivamani): Check memory usage
        tensors_to_save, tensor_objects = prepare_for_saving(
            saved_inputmat,
@@ -1098,7 +1098,7 @@ class Linear(TransformerEngineBaseModule):
        save_original_input: bool = False,
        name: Optional[str] = None,
    ) -> None:
-        super().__init__()
+        super().__init__(name)
        params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
        self.in_features = in_features
@@ -1111,7 +1111,6 @@ class Linear(TransformerEngineBaseModule):
        self.rng_tracker_name = rng_tracker_name
        self.symmetric_ar_type = symmetric_ar_type
        self.save_original_input = save_original_input
-        self.name = name

        self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)
@@ -1395,11 +1394,8 @@ class Linear(TransformerEngineBaseModule):
        ).is_fp8_ubuf():
            fp8_grad = True

-        with self.prepare_forward(
-            inp,
-            allow_non_contiguous=isinstance(inp, QuantizedTensor),
-        ) as inp:
+        inp = self.prepare_forward(inp, allow_non_contiguous=isinstance(inp, QuantizedTensor))
+        try:
            weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()

            quantizers = (
@@ -1470,6 +1466,8 @@ class Linear(TransformerEngineBaseModule):
                bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None,
                non_tensor_args,
            )
+        finally:
+            self.end_forward()

        if self.gemm_bias_unfused_add:
            out = out + cast_if_needed(bias_tensor, self.activation_dtype)
...
@@ -12,7 +12,6 @@ import torch
from transformer_engine.pytorch.torch_version import torch_version
from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm
-from transformer_engine.debug.pytorch.debug_state import TEDebugState
from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention
from transformer_engine.pytorch.attention.inference import InferenceParams
from transformer_engine.pytorch.jit import (
@@ -398,6 +397,7 @@ class TransformerLayer(torch.nn.Module):
        self.softmax_type = softmax_type
        self.name = name
+        TransformerEngineBaseModule._validate_name(self)

        attention_args = (
            hidden_size,
@@ -446,7 +446,7 @@ class TransformerLayer(torch.nn.Module):
                qk_norm_type=qk_norm_type,
                qk_norm_eps=qk_norm_eps,
                qk_norm_before_rope=qk_norm_before_rope,
-                name=name + ".self_attention" if name is not None else None,
+                name=self.name + ".self_attention" if self.name is not None else None,
            )

        if layer_type == "decoder":
@@ -463,7 +463,7 @@ class TransformerLayer(torch.nn.Module):
                qk_norm_type=qk_norm_type,
                qk_norm_eps=qk_norm_eps,
                qk_norm_before_rope=qk_norm_before_rope,
-                name=name + ".inter_attention" if name is not None else None,
+                name=self.name + ".inter_attention" if self.name is not None else None,
            )

        # LayerNorm -> activation(Linear + Bias) -> Linear
@@ -499,7 +499,7 @@ class TransformerLayer(torch.nn.Module):
            activation_params=activation_params,
            normalization=normalization,
            device=device,
-            name=name + ".layernorm_mlp" if name is not None else None,
+            name=self.name + ".layernorm_mlp" if self.name is not None else None,
        )

        self.hidden_dropout = hidden_dropout
@@ -768,9 +768,6 @@ class TransformerLayer(torch.nn.Module):
                enc_dec_attn_mask[i].dtype == torch.bool for i in range(len(enc_dec_attn_mask))
            ), "Encoder-decoder attention mask must be boolean tensor(s)"

-        if TEDebugState.debug_enabled:
-            TransformerEngineBaseModule._validate_name(self)
-
        # For AMP
        if torch.is_autocast_enabled():
            hidden_states = cast_if_needed(hidden_states, torch_get_autocast_gpu_dtype())
...