Unverified commit 2e0bfbd9, authored by Kirthi Shankar Sivamani, committed by GitHub

[PyTorch] FP8 fixes (#380)



* Initial refactor
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Reorder methods by purpose
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Save full global state
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* More fixes to test
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent cbfb8c6b
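
In short, this PR folds the module-level FP8 state helpers in transformer_engine/pytorch/fp8.py into a single FP8GlobalStateManager class. A minimal before/after sketch of the call-site migration, using only names taken from the diff below:

# Before this PR: free functions backed by module-level globals.
from transformer_engine.pytorch.fp8 import is_fp8_available, is_fp8_enabled
fp8_available, reason_for_no_fp8 = is_fp8_available()

# After this PR: the same queries as classmethods on one manager class.
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
enabled = FP8GlobalStateManager.is_fp8_enabled()
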
......@@ -3,6 +3,7 @@ extension-pkg-whitelist=torch,
transformer_engine_extensions
disable=too-many-locals,
too-many-public-methods,
invalid-name,
too-many-arguments,
abstract-method,
......
......@@ -10,7 +10,7 @@ from transformer_engine.pytorch.utils import (
scaled_init_method_normal,
get_device_compute_capability,
)
from transformer_engine.pytorch.fp8 import is_fp8_available
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch import TransformerLayer
from transformer_engine.pytorch.attention import DotProductAttention
import os
......@@ -18,7 +18,7 @@ import os
from pkg_resources import packaging
from importlib.metadata import version
from test_numerics import get_dummy_cuda_rng_tracker, reset_rng_states
fp8_available, reason_for_no_fp8 = is_fp8_available()
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
_flash_attn_version = packaging.version.Version(version("flash-attn"))
_flash_attn_2_available = _flash_attn_version >= packaging.version.Version("2")
......
......@@ -38,7 +38,7 @@ import transformer_engine.pytorch.cpp_extensions as texcpp
import transformer_engine.pytorch.softmax as softmax_defs
from transformer_engine.pytorch.utils import get_default_init_method
from transformer_engine.pytorch.export import is_in_onnx_export_mode
from transformer_engine.pytorch.fp8 import is_fp8_available
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
# Global test configuration knobs.
......@@ -66,7 +66,7 @@ assert OPSET >= TRILU_OPSET
# Shared library implementing custom FP8 Q/DQ operators for ONNX Runtime (ORT).
ORT_CUSTOM_OPS_LIB = os.path.join(TESTS_DIR, "./libcustom_ort_fp8_qdq_ops.so")
fp8_available, reason_for_no_fp8 = is_fp8_available()
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
skip_FP8 = pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
supported_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]
......
......@@ -5,7 +5,7 @@
import torch
import pytest
from transformer_engine.pytorch.fp8 import fp8_autocast, is_fp8_available
from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager
from transformer_engine.pytorch.utils import (
init_method_normal,
scaled_init_method_normal,
......@@ -21,7 +21,7 @@ from transformer_engine.pytorch import (
from transformer_engine.common import recipe
# Only run FP8 tests on H100.
fp8_available, reason_for_no_fp8 = is_fp8_available()
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
def custom_amax_to_scale(
......
......@@ -12,7 +12,7 @@ from torch.utils.checkpoint import detach_variable
from .utils import safely_set_viewless_tensor_data
from .constants import dist_group_type
from .fp8 import is_fp8_enabled
from .fp8 import FP8GlobalStateManager
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
"tensor_model_parallel": False,
......@@ -145,7 +145,8 @@ def activation_recompute_forward(
"""
global _FP8_ACTIVATION_RECOMPUTE_ENABLED, _FP8_ACTIVATION_RECOMPUTE_PHASE
try:
_FP8_ACTIVATION_RECOMPUTE_ENABLED = activation_recompute and is_fp8_enabled()
_FP8_ACTIVATION_RECOMPUTE_ENABLED = (
activation_recompute and FP8GlobalStateManager.is_fp8_enabled())
_FP8_ACTIVATION_RECOMPUTE_PHASE = recompute_phase
yield
finally:
......
......@@ -16,28 +16,11 @@ from .constants import dist_group_type
from .utils import get_device_compute_capability
from .jit import jit_fuser
_FP8_ENABLED = False
_FP8_CALIBRATION = False
_FP8_RECIPE = None
_FP8_DISTRIBUTED_GROUP = None
_IS_FIRST_FP8_MODULE = False
_FP8_AUTOCAST_COUNTER = 0
_FP8_CURRENT_CONTEXT_ID = 0
_FP8_AUTOCAST_DEPTH = 0
_global_fp8_buffer = {}
_fp8_tensors_recompute_buffer = []
_amax_forward_global_reduce_func = None
_buffer_delete_key_fwd = None
_buffer_delete_key_bwd = None
_amax_reduce_handle_fwd = None
_is_fp8_available = None
_reason_for_no_fp8 = ""
_dp_amax_reduce_interval = None
_dp_amax_reduce_forward_idx = 0
_dp_amax_reduce_backward_idx = 0
def _check_fp8_support() -> Tuple[bool, str]:
__all__ = ["fp8_autocast"]
def check_fp8_support() -> Tuple[bool, str]:
"""Return if fp8 support is available"""
if get_device_compute_capability() >= 9.0: # hopper and above
return True, ""
......@@ -50,182 +33,420 @@ def _check_fp8_support() -> Tuple[bool, str]:
return True, ""
def is_fp8_available() -> Tuple[bool, str]:
"""Return if fp8 support is available"""
global _is_fp8_available, _reason_for_no_fp8
if _is_fp8_available is None:
_is_fp8_available, _reason_for_no_fp8 = _check_fp8_support()
return _is_fp8_available, _reason_for_no_fp8
def get_meta_tensor_key(forward: bool = True) -> str:
"""Returns scaling key in `fp8_meta`."""
if forward:
return "scaling_fwd"
return "scaling_bwd"
def get_buffer_position_key(forward: bool = True) -> str:
"""Returns module position key in `fp8_meta`."""
if forward:
return "global_fp8_buffer_pos_fwd"
return "global_fp8_buffer_pos_bwd"
def get_autocast_key(forward: bool = True) -> str:
"""Returns module position key in `fp8_meta`."""
if forward:
return "autocast_id_fwd"
return "autocast_id_bwd"
def get_amax_reduce_handle_fwd() -> Union[bool, None]:
"""Return AMAX reduction wait handle of forward prop."""
global _amax_reduce_handle_fwd
return _amax_reduce_handle_fwd
def get_global_fp8_buffer() -> Dict[str, List[torch.Tensor]]:
"""Returns global fp8 buffer."""
return _global_fp8_buffer
def set_global_fp8_buffer(buffer: Dict[str, List[torch.Tensor]]) -> None:
"""Sets global fp8 buffer."""
global _global_fp8_buffer
# Map all tensors back to GPU.
for k, v in buffer.items():
buffer[k] = [tensor.cuda() for tensor in v]
_global_fp8_buffer = buffer
def setup_amax_forward_global_reduce_func(f: Callable) -> None:
"""Sets up the function to call during autocast exit."""
global _amax_forward_global_reduce_func
_amax_forward_global_reduce_func = f
def get_amax_buffer_key(fp8_meta: Dict[str, Any], forward: bool = True) -> str:
"""Return a key in `_global_fp8_buffer` for the AMAX storage."""
if forward:
return f"FWD_AMAX_{fp8_meta['autocast_id_fwd']}"
return f"BWD_AMAX_{fp8_meta['autocast_id_bwd']}"
def add_amax_to_global_buffer(fp8_meta: Dict[str, Any], forward: bool = True) -> None:
"""Append 1D tensor `amax` to global buffer."""
global _global_fp8_buffer
buffer_key = get_amax_buffer_key(fp8_meta, forward=forward)
fp8_meta_tensor_key = get_meta_tensor_key(forward=forward)
buffer_position_key = get_buffer_position_key(forward=forward)
if buffer_key not in _global_fp8_buffer:
_global_fp8_buffer[buffer_key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]]
else:
_global_fp8_buffer[buffer_key].append(
fp8_meta[fp8_meta_tensor_key].amax_history[0]
)
if buffer_position_key not in fp8_meta:
fp8_meta[buffer_position_key] = len(_global_fp8_buffer[buffer_key]) - 1
# Catch incorrect fp8_autocast usage.
assert fp8_meta[buffer_position_key] == len(_global_fp8_buffer[buffer_key]) - 1, \
"Same module is being invoked more than once inside an `fp8_autocast` region when using " \
"FP8 with amax reduction. This behavior is currently unsupported. For more details and " \
"correct usage, please see https://github.com/NVIDIA/TransformerEngine/pull/93."
def copy_forward_fp8_meta_tensors_for_recompute(fp8_meta: Dict[str, Any]) -> None:
    """Copy the scaling factors and amaxes for recompute forward phase
    to ensure both forward steps are numerically the same.
    """
    global _fp8_tensors_recompute_buffer
    buffer_position_key = "global_fp8_buffer_pos_fwd_recompute"
    to_copy = [
        fp8_meta["scaling_fwd"].amax_history.clone(),
        fp8_meta["scaling_fwd"].scale.clone(),
        fp8_meta["scaling_fwd"].scale_inv.clone(),
    ]
    if buffer_position_key in fp8_meta:
        _fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].append(to_copy)
    else:
        if len(_fp8_tensors_recompute_buffer) == 0:
            _fp8_tensors_recompute_buffer = [deque()]
        else:
            _fp8_tensors_recompute_buffer.append(deque())
        _fp8_tensors_recompute_buffer[-1].append(to_copy)
        fp8_meta[buffer_position_key] = len(_fp8_tensors_recompute_buffer) - 1
def get_default_fp8_recipe() -> DelayedScaling:
    """FP8 recipe if not provided by user
    Margin = 0, interval = 1, E4M3
    """
    return DelayedScaling()
def get_fp8_te_dtype(
fp8_recipe: DelayedScaling, fprop_tensor: bool = True
) -> tex.DType:
"""Get fp8 data type according to recipe and tensor"""
if fp8_recipe.fp8_format == Format.E4M3 or (
fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor
):
return tex.DType.kFloat8E4M3
return tex.DType.kFloat8E5M2
class FP8GlobalStateManager:
    """Class to keep track of and manipulate the global
    FP8 state at different stages of execution.
    """
FP8_ENABLED = False
FP8_CALIBRATION = False
FP8_RECIPE = None
FP8_DISTRIBUTED_GROUP = None
IS_FIRST_FP8_MODULE = False
FP8_AUTOCAST_COUNTER = 0
FP8_CURRENT_CONTEXT_ID = 0
FP8_AUTOCAST_DEPTH = 0
global_fp8_buffer = {}
fp8_tensors_recompute_buffer = []
amax_forward_global_reduce_func = None
buffer_delete_key_fwd = None
buffer_delete_key_bwd = None
amax_reduce_handle_fwd = None
fp8_available = None
reason_for_no_fp8 = ""
dp_amax_reduce_interval = None
dp_amax_reduce_forward_idx = 0
dp_amax_reduce_backward_idx = 0
@classmethod
def is_fp8_available(cls) -> Tuple[bool, str]:
"""Return if fp8 support is available"""
if cls.fp8_available is None:
cls.fp8_available, cls.reason_for_no_fp8 = check_fp8_support()
return cls.fp8_available, cls.reason_for_no_fp8
@classmethod
def get_global_fp8_state_checkpoint(cls) -> Dict[str, Union[int, str]]:
"""Returns global fp8 state variables."""
# Convert the class attributes to a dictionary so that restoring a
# checkpoint stays backwards compatible if the set of global state
# variables changes in the future.
global_fp8_state = {}
global_fp8_state["FP8_AUTOCAST_COUNTER"] = cls.FP8_AUTOCAST_COUNTER
global_fp8_state["FP8_CURRENT_CONTEXT_ID"] = cls.FP8_CURRENT_CONTEXT_ID
global_fp8_state["FP8_AUTOCAST_DEPTH"] = cls.FP8_AUTOCAST_DEPTH
global_fp8_state["buffer_delete_key_fwd"] = cls.buffer_delete_key_fwd
global_fp8_state["buffer_delete_key_bwd"] = cls.buffer_delete_key_bwd
global_fp8_state["dp_amax_reduce_interval"] = cls.dp_amax_reduce_interval
global_fp8_state["dp_amax_reduce_forward_idx"] = cls.dp_amax_reduce_forward_idx
global_fp8_state["dp_amax_reduce_backward_idx"] = cls.dp_amax_reduce_backward_idx
return global_fp8_state
@classmethod
def set_global_fp8_state_checkpoint(cls, state: Dict[str, Union[int, str]]) -> None:
"""Sets global fp8 state variables."""
for k, v in state.items():
if hasattr(cls, k):
setattr(cls, k, v)
@classmethod
def get_global_fp8_buffer_checkpoint(cls) -> Dict[str, List[torch.Tensor]]:
"""Returns global fp8 amax buffer."""
return cls.global_fp8_buffer
@classmethod
def set_global_fp8_buffer_checkpoint(cls, buffer: Dict[str, List[torch.Tensor]]) -> None:
"""Sets global fp8 amax buffer."""
# Map all tensors back to GPU.
for k, v in buffer.items():
buffer[k] = [tensor.cuda() for tensor in v]
cls.global_fp8_buffer = buffer
@staticmethod
def get_meta_tensor_key(forward: bool = True) -> str:
"""Returns scaling key in `fp8_meta`."""
if forward:
return "scaling_fwd"
return "scaling_bwd"
@staticmethod
def get_buffer_position_key(forward: bool = True) -> str:
"""Returns module position key in `fp8_meta`."""
if forward:
return "global_fp8_buffer_pos_fwd"
return "global_fp8_buffer_pos_bwd"
@staticmethod
def get_autocast_key(forward: bool = True) -> str:
"""Returns module position key in `fp8_meta`."""
if forward:
return "autocast_id_fwd"
return "autocast_id_bwd"
@staticmethod
def get_amax_buffer_key(fp8_meta: Dict[str, Any], forward: bool = True) -> str:
"""Return a key in `_global_fp8_buffer` for the AMAX storage."""
if forward:
return f"FWD_AMAX_{fp8_meta['autocast_id_fwd']}"
return f"BWD_AMAX_{fp8_meta['autocast_id_bwd']}"
@classmethod
def get_amax_reduce_handle_fwd(cls) -> Union[bool, None]:
"""Return AMAX reduction wait handle of forward prop."""
return cls.amax_reduce_handle_fwd
@classmethod
def setup_amax_forward_global_reduce_func(cls, f: Callable) -> None:
"""Sets up the function to call during autocast exit."""
cls.amax_forward_global_reduce_func = f
@classmethod
def add_amax_to_global_buffer(cls, fp8_meta: Dict[str, Any], forward: bool = True) -> None:
"""Append 1D tensor `amax` to global buffer."""
buffer_key = cls.get_amax_buffer_key(fp8_meta, forward=forward)
fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward)
buffer_position_key = cls.get_buffer_position_key(forward=forward)
if buffer_key not in cls.global_fp8_buffer:
cls.global_fp8_buffer[buffer_key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]]
else:
cls.global_fp8_buffer[buffer_key].append(
fp8_meta[fp8_meta_tensor_key].amax_history[0]
)
if buffer_position_key not in fp8_meta:
fp8_meta[buffer_position_key] = len(cls.global_fp8_buffer[buffer_key]) - 1
# Catch incorrect fp8_autocast usage.
assert fp8_meta[buffer_position_key] == len(cls.global_fp8_buffer[buffer_key]) - 1, \
"Same module is being invoked more than once inside an `fp8_autocast` " \
"region when using FP8 with amax reduction. This behavior is currently" \
" unsupported. For more details and correct usage, please see " \
"https://github.com/NVIDIA/TransformerEngine/pull/93."
@classmethod
def copy_amax_from_global_buffer(
cls, fp8_meta: Dict[str, Any], forward: bool = True
) -> None:
"""Populate current amax with the correct location from buffer."""
fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward)
buffer_position_key = cls.get_buffer_position_key(forward=forward)
if buffer_position_key not in fp8_meta:
return
amax_buffer_key = cls.get_amax_buffer_key(fp8_meta, forward=forward)
assert amax_buffer_key in cls.global_fp8_buffer, "TE internal error."
fp8_meta[fp8_meta_tensor_key].amax_history[0] = cls.global_fp8_buffer[amax_buffer_key][
fp8_meta[buffer_position_key]
]
@classmethod
def set_amax_buffer_key_deletion(
cls, fp8_meta: Dict[str, Any], forward: bool = True
) -> None:
"""Delete this amax key from global buffer during autocast end."""
if cls.get_autocast_key(forward=forward) not in fp8_meta:
return
if forward:
cls.buffer_delete_key_fwd = cls.get_amax_buffer_key(fp8_meta, forward=forward)
else:
cls.buffer_delete_key_bwd = cls.get_amax_buffer_key(fp8_meta, forward=forward)
@classmethod
def delete_key_from_amax_buffer(cls, forward: bool = True) -> None:
"""Delete the key from global amax buffer."""
if forward:
if (
cls.buffer_delete_key_fwd is not None
and cls.buffer_delete_key_fwd in cls.global_fp8_buffer
):
del cls.global_fp8_buffer[cls.buffer_delete_key_fwd]
else:
if (
cls.buffer_delete_key_bwd is not None
and cls.buffer_delete_key_bwd in cls.global_fp8_buffer
):
del cls.global_fp8_buffer[cls.buffer_delete_key_bwd]
@classmethod
def get_fp8_context_id(cls) -> int:
"""Returns an ID for the current FP8 context."""
return cls.FP8_CURRENT_CONTEXT_ID
@classmethod
def set_fp8_context_id(cls, ctx_id: int) -> None:
"""Sets the current FP8 context."""
cls.FP8_CURRENT_CONTEXT_ID = ctx_id
@classmethod
def new_fp8_context_id(cls) -> int:
"""Returns global autocast counter as a proxy to be used
as the autocast ID for FP8 modules.
"""
return cls.FP8_AUTOCAST_COUNTER
@classmethod
def is_fp8_enabled(cls) -> bool:
"""Is FP8 enabled"""
return cls.FP8_ENABLED
@classmethod
def is_fp8_calibration(cls) -> bool:
"""Is FP8 calibration"""
return cls.FP8_CALIBRATION
@classmethod
def is_first_fp8_module(cls):
"""Returns `True` only the first time when called multiple
times from within the same `fp8_autocast` context.
"""
tmp = cls.IS_FIRST_FP8_MODULE
cls.IS_FIRST_FP8_MODULE = False
return tmp
@classmethod
def get_fp8_recipe(cls) -> DelayedScaling:
"""Return the fp8 recipe"""
return cls.FP8_RECIPE
@classmethod
def get_fp8_group(cls) -> Union[dist_group_type, None]:
"""Return the fp8 group for scale/amax comm"""
return cls.FP8_DISTRIBUTED_GROUP
@classmethod
def get_fp8_autocast_state(cls) -> Tuple[bool, bool, DelayedScaling, dist_group_type, bool]:
"""FP8 autocast state getter"""
return (
cls.FP8_ENABLED,
cls.FP8_CALIBRATION,
cls.FP8_RECIPE,
cls.FP8_DISTRIBUTED_GROUP,
cls.IS_FIRST_FP8_MODULE)
@classmethod
def set_fp8_autocast_state(
cls,
fp8_state: Tuple[bool, bool, DelayedScaling, dist_group_type, bool]
) -> None:
"""FP8 autocast state setter"""
(cls.FP8_ENABLED,
cls.FP8_CALIBRATION,
cls.FP8_RECIPE,
cls.FP8_DISTRIBUTED_GROUP,
cls.IS_FIRST_FP8_MODULE) = fp8_state
@staticmethod
def reduce_tensor_across_group_op_max(
tensor: torch.Tensor, group: dist_group_type, async_op: bool
) -> None:
"""Reduce tensor across given group."""
if torch.distributed.is_initialized():
wait_handle = torch.distributed.all_reduce(
tensor,
op=torch.distributed.ReduceOp.MAX,
group=group,
async_op=async_op,
)
return wait_handle
return None
def get_old_fp8_meta_tensors_for_recompute(fp8_meta: Dict[str, Any]) -> None:
    """Switch to the copied scaling factors and amaxes from phase
    1 forward for identical numerical outputs.
    """
    # Store updated amaxes and scales from phase 1 post forward.
    fp8_meta["updated_amax_history_fwd"] = fp8_meta["scaling_fwd"].amax_history
    fp8_meta["updated_scale_fwd"] = fp8_meta["scaling_fwd"].scale
    fp8_meta["updated_scale_inv_fwd"] = fp8_meta["scaling_fwd"].scale_inv
    # Retrieve stashed amaxes and scales from phase 1 pre forward.
    buffer_position_key = "global_fp8_buffer_pos_fwd_recompute"
    stashed_fp8_meta = _fp8_tensors_recompute_buffer[
        fp8_meta[buffer_position_key]
    ].popleft()
    # Replace amaxes and scales with stashed values for phase 2 forward
    fp8_meta["scaling_fwd"].amax_history = stashed_fp8_meta[0]
    fp8_meta["scaling_fwd"].scale = stashed_fp8_meta[1]
    fp8_meta["scaling_fwd"].scale_inv = stashed_fp8_meta[2]
def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None:
    """Restore latest scaling factors and amaxes after recompute forward run."""
    fp8_meta["scaling_fwd"].amax_history = fp8_meta["updated_amax_history_fwd"]
    fp8_meta["scaling_fwd"].scale = fp8_meta["updated_scale_fwd"]
    fp8_meta["scaling_fwd"].scale_inv = fp8_meta["updated_scale_inv_fwd"]
def copy_amax_from_global_buffer(
    fp8_meta: Dict[str, Any], forward: bool = True
) -> None:
    """Populate current amax with the correct location from buffer."""
    fp8_meta_tensor_key = get_meta_tensor_key(forward=forward)
    buffer_position_key = get_buffer_position_key(forward=forward)
    if buffer_position_key not in fp8_meta:
        return
    amax_buffer_key = get_amax_buffer_key(fp8_meta, forward=forward)
    assert amax_buffer_key in _global_fp8_buffer, "TE internal error."
    fp8_meta[fp8_meta_tensor_key].amax_history[0] = _global_fp8_buffer[amax_buffer_key][
        fp8_meta[buffer_position_key]
    ]
@classmethod
def global_amax_reduction(
    cls,
    fp8_meta: Dict[str, Any],
    tp_group: dist_group_type,
    tp_size: int,
    forward: bool = True,
) -> None:
    """Concatenate, reduce, and split amaxes in the global buffer."""
    amax_buffer_key = cls.get_amax_buffer_key(fp8_meta, forward=forward)
    # Key already deleted.
    if amax_buffer_key not in cls.global_fp8_buffer:
        return None
    # Reduce AMAX in DP-domain at an interval.
    if cls.dp_amax_reduce_interval is None:
        cls.dp_amax_reduce_interval = int(os.getenv("NVTE_DP_AMAX_REDUCE_INTERVAL", "1"))
    tp_amax_reduce = False
    if forward:
        if cls.dp_amax_reduce_forward_idx == 0:
            reduce_group = fp8_meta["fp8_group"]
        else:
            tp_amax_reduce = True
        cls.dp_amax_reduce_forward_idx = (
            (cls.dp_amax_reduce_forward_idx + 1) % cls.dp_amax_reduce_interval)
    else:
        if cls.dp_amax_reduce_backward_idx == 0:
            reduce_group = fp8_meta["fp8_group"]
        else:
            tp_amax_reduce = True
        cls.dp_amax_reduce_backward_idx = (
            (cls.dp_amax_reduce_backward_idx + 1) % cls.dp_amax_reduce_interval)
    if tp_amax_reduce:
        if tp_size > 1:
            reduce_group = tp_group
        else:
            return None
    chunk_sizes = [x.numel() for x in cls.global_fp8_buffer[amax_buffer_key]]
    contiguous_amax = torch.cat(cls.global_fp8_buffer[amax_buffer_key])
    wait_handle = cls.reduce_tensor_across_group_op_max(
        contiguous_amax,
        reduce_group,
        fp8_meta["async_amax_reduction"],
    )
    cls.global_fp8_buffer[amax_buffer_key] = list(contiguous_amax.split(chunk_sizes))
    return wait_handle
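
To make the DP/TP interval gating in global_amax_reduction concrete, here is a toy walk-through; the loop and printout are purely illustrative and not Transformer Engine code:

import os

# Hypothetical setting: with NVTE_DP_AMAX_REDUCE_INTERVAL=4, the amax
# reduction spans the full fp8_group every 4th call and falls back to
# the (smaller) tp_group in between, provided tp_size > 1.
os.environ["NVTE_DP_AMAX_REDUCE_INTERVAL"] = "4"
interval = int(os.getenv("NVTE_DP_AMAX_REDUCE_INTERVAL", "1"))

idx = 0
for step in range(8):
    group = "fp8_meta['fp8_group']" if idx == 0 else "tp_group"
    print(f"step {step}: reduce amax over {group}")
    idx = (idx + 1) % interval
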
def set_amax_buffer_key_deletion(
fp8_meta: Dict[str, Any], forward: bool = True
) -> None:
"""Delete this amax key from global buffer during autocast end."""
if get_autocast_key(forward=forward) not in fp8_meta:
return
global _buffer_delete_key_fwd, _buffer_delete_key_bwd
if forward:
_buffer_delete_key_fwd = get_amax_buffer_key(fp8_meta, forward=forward)
else:
_buffer_delete_key_bwd = get_amax_buffer_key(fp8_meta, forward=forward)
@classmethod
def fp8_autocast_enter(
cls,
enabled: bool = False,
calibrating: bool = False,
fp8_recipe: Optional[DelayedScaling] = None,
fp8_group: Optional[dist_group_type] = None,
) -> None:
"""Set state and tracking variables for entry into FP8 region."""
cls.FP8_ENABLED = enabled
cls.FP8_CALIBRATION = calibrating
cls.FP8_RECIPE = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe
cls.FP8_DISTRIBUTED_GROUP = fp8_group
if cls.FP8_AUTOCAST_DEPTH == 0:
cls.IS_FIRST_FP8_MODULE = True
cls.FP8_AUTOCAST_COUNTER += 1
cls.FP8_AUTOCAST_DEPTH += 1
if enabled:
fp8_available, reason_for_no_fp8 = cls.is_fp8_available()
assert fp8_available, reason_for_no_fp8
def get_default_fp8_recipe() -> DelayedScaling:
"""FP8 recipe if not provided by user
Margin = 0, interval = 1, E4M3
"""
return DelayedScaling()
@classmethod
def fp8_autocast_exit(cls):
"""Set state and tracking variables for exit from FP8 region."""
cls.FP8_AUTOCAST_DEPTH -= 1
if cls.FP8_AUTOCAST_DEPTH == 0:
if callable(cls.amax_forward_global_reduce_func):
cls.amax_reduce_handle_fwd = cls.amax_forward_global_reduce_func() # pylint: disable=not-callable
cls.delete_key_from_amax_buffer(forward=True)
@classmethod
def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None:
"""Copy the scaling factors and amaxes for recompute forward phase
to ensure both forward steps are numerically the same.
"""
buffer_position_key = "global_fp8_buffer_pos_fwd_recompute"
to_copy = [
fp8_meta["scaling_fwd"].amax_history.clone(),
fp8_meta["scaling_fwd"].scale.clone(),
fp8_meta["scaling_fwd"].scale_inv.clone(),
]
if buffer_position_key in fp8_meta:
cls.fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].append(to_copy)
else:
if len(cls.fp8_tensors_recompute_buffer) == 0:
cls.fp8_tensors_recompute_buffer = [deque()]
else:
cls.fp8_tensors_recompute_buffer.append(deque())
cls.fp8_tensors_recompute_buffer[-1].append(to_copy)
fp8_meta[buffer_position_key] = len(cls.fp8_tensors_recompute_buffer) - 1
@classmethod
def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None:
"""Switch to the copied scaling factors and amaxes from phase
1 forward for identical numerical outputs.
"""
# Store updated amaxes and scales from phase 1 post forward.
fp8_meta["updated_amax_history_fwd"] = fp8_meta["scaling_fwd"].amax_history
fp8_meta["updated_scale_fwd"] = fp8_meta["scaling_fwd"].scale
fp8_meta["updated_scale_inv_fwd"] = fp8_meta["scaling_fwd"].scale_inv
# Retrieve stashed amaxes and scales from phase 1 pre forward.
buffer_position_key = "global_fp8_buffer_pos_fwd_recompute"
stashed_fp8_meta = cls.fp8_tensors_recompute_buffer[
fp8_meta[buffer_position_key]
].popleft()
# Replace amaxes and scales with stashed values for phase 2 forward
fp8_meta["scaling_fwd"].amax_history = stashed_fp8_meta[0]
fp8_meta["scaling_fwd"].scale = stashed_fp8_meta[1]
fp8_meta["scaling_fwd"].scale_inv = stashed_fp8_meta[2]
@staticmethod
def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None:
"""Restore latest scaling factors and amaxes after recompute forward run."""
fp8_meta["scaling_fwd"].amax_history = fp8_meta["updated_amax_history_fwd"]
fp8_meta["scaling_fwd"].scale = fp8_meta["updated_scale_fwd"]
fp8_meta["scaling_fwd"].scale_inv = fp8_meta["updated_scale_inv_fwd"]
@contextmanager
......@@ -272,96 +493,16 @@ def fp8_autocast(
distributed group over which amaxes for the fp8 tensors
are reduced at the end of each training step.
"""
global _FP8_ENABLED, _FP8_CALIBRATION, _FP8_RECIPE, _FP8_DISTRIBUTED_GROUP, _FP8_AUTOCAST_DEPTH
global _IS_FIRST_FP8_MODULE, _FP8_AUTOCAST_COUNTER
global _global_fp8_buffer, _buffer_delete_key_fwd
global _amax_reduce_handle_fwd
fp8_state = (
_FP8_ENABLED,
_FP8_CALIBRATION,
_FP8_RECIPE,
_FP8_DISTRIBUTED_GROUP,
_IS_FIRST_FP8_MODULE)
try:
_FP8_ENABLED = enabled
_FP8_CALIBRATION = calibrating
_FP8_RECIPE = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe
_FP8_DISTRIBUTED_GROUP = fp8_group
if _FP8_AUTOCAST_DEPTH == 0:
_IS_FIRST_FP8_MODULE = True
_FP8_AUTOCAST_COUNTER += 1
_FP8_AUTOCAST_DEPTH += 1
if enabled:
fp8_available, reason_for_no_fp8 = is_fp8_available()
assert fp8_available, reason_for_no_fp8
fp8_state = FP8GlobalStateManager.get_fp8_autocast_state()
FP8GlobalStateManager.fp8_autocast_enter(enabled, calibrating, fp8_recipe, fp8_group)
yield
finally:
    (_FP8_ENABLED,
     _FP8_CALIBRATION,
     _FP8_RECIPE,
     _FP8_DISTRIBUTED_GROUP,
     _IS_FIRST_FP8_MODULE) = fp8_state
    _FP8_AUTOCAST_DEPTH -= 1
    if _FP8_AUTOCAST_DEPTH == 0:
        if callable(_amax_forward_global_reduce_func):
            _amax_reduce_handle_fwd = _amax_forward_global_reduce_func()
        delete_key_from_amax_buffer(forward=True)
    FP8GlobalStateManager.set_fp8_autocast_state(fp8_state)  # pylint: disable=used-before-assignment
    FP8GlobalStateManager.fp8_autocast_exit()
def get_fp8_context_id() -> int:
    """Returns an ID for the current FP8 context."""
    return _FP8_CURRENT_CONTEXT_ID
def set_fp8_context_id(ctx_id: int) -> None:
    """Sets the current FP8 context."""
    global _FP8_CURRENT_CONTEXT_ID
    _FP8_CURRENT_CONTEXT_ID = ctx_id
def new_fp8_context_id() -> int:
    """Returns global autocast counter as a proxy to be used
    as the autocast ID for FP8 modules.
    """
    return _FP8_AUTOCAST_COUNTER
def is_fp8_enabled() -> bool:
    """Is FP8 enabled"""
    return _FP8_ENABLED
def is_fp8_calibration() -> bool:
    """Is FP8 calibration"""
    return _FP8_CALIBRATION
def is_first_fp8_module():
    """Returns `True` only on the first call from within a given
    `fp8_autocast` context.
    """
    global _IS_FIRST_FP8_MODULE
    tmp = _IS_FIRST_FP8_MODULE
    _IS_FIRST_FP8_MODULE = False
    return tmp
def get_fp8_recipe() -> DelayedScaling:
"""Return the fp8 recipe"""
return _FP8_RECIPE
def get_fp8_group() -> Union[dist_group_type, None]:
"""Return the fp8 group for scale/amax comm"""
return _FP8_DISTRIBUTED_GROUP
def update_amax_history(amax_history: torch.Tensor) -> torch.Tensor:
def _update_amax_history(amax_history: torch.Tensor) -> torch.Tensor:
"""Update amax history and set next amax to zero."""
if amax_history.shape[0] > 1:
amax_history = torch.roll(amax_history, -1, 0)
......@@ -380,7 +521,7 @@ def _default_get_amax(
else: # amax_compute_algo == "most_recent"
amax = amax_history[0].clone()
amax_history = update_amax_history(amax_history)
amax_history = _update_amax_history(amax_history)
return amax_history, amax
......@@ -415,7 +556,7 @@ def _compute_scaling_factor_inverse(
@jit_fuser
def fused_amax_and_scale_update(
def _fused_amax_and_scale_update(
amax_history: torch.Tensor,
scale: torch.Tensor,
scale_inv: torch.Tensor,
......@@ -460,7 +601,7 @@ def _compute_amax(
if callable(recipe.amax_compute_algo):
amax = recipe.amax_compute_algo(amax_history)
amax_history = update_amax_history(amax_history)
amax_history = _update_amax_history(amax_history)
return amax_history, amax
return _default_get_amax(
amax_history,
......@@ -502,7 +643,7 @@ def amax_and_scale_update(
fp8_meta[fp8_meta_tensor_key].amax_history,
fp8_meta[fp8_meta_tensor_key].scale,
fp8_meta[fp8_meta_tensor_key].scale_inv,
) = fused_amax_and_scale_update(
) = _fused_amax_and_scale_update(
fp8_meta[fp8_meta_tensor_key].amax_history,
fp8_meta[fp8_meta_tensor_key].scale,
fp8_meta[fp8_meta_tensor_key].scale_inv,
......@@ -529,99 +670,3 @@ def amax_and_scale_update(
fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"],
update_weight_scale_inv,
)
def get_fp8_te_dtype(
fp8_recipe: DelayedScaling, fprop_tensor: bool = True
) -> tex.DType:
"""Get fp8 data type according to recipe and tensor"""
if fp8_recipe.fp8_format == Format.E4M3 or (
fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor
):
return tex.DType.kFloat8E4M3
return tex.DType.kFloat8E5M2
def reduce_tensor_across_group_op_max(
tensor: torch.Tensor, group: dist_group_type, async_op: bool
) -> None:
"""Reduce tensor across given group."""
if torch.distributed.is_initialized():
wait_handle = torch.distributed.all_reduce(
tensor,
op=torch.distributed.ReduceOp.MAX,
group=group,
async_op=async_op,
)
return wait_handle
return None
def global_amax_reduction(
fp8_meta: Dict[str, Any],
tp_group: dist_group_type,
tp_size: int,
forward: bool = True,
) -> None:
"""Concatenate, reduce, and split amaxes in the global buffer."""
global _global_fp8_buffer
amax_buffer_key = get_amax_buffer_key(fp8_meta, forward=forward)
# Key already deleted.
if amax_buffer_key not in _global_fp8_buffer:
return None
# Reduce AMAX in DP-domain at an interval.
global _dp_amax_reduce_interval, _dp_amax_reduce_forward_idx, _dp_amax_reduce_backward_idx
if _dp_amax_reduce_interval is None:
_dp_amax_reduce_interval = int(os.getenv("NVTE_DP_AMAX_REDUCE_INTERVAL", "1"))
tp_amax_reduce = False
if forward:
if _dp_amax_reduce_forward_idx == 0:
reduce_group = fp8_meta["fp8_group"]
else:
tp_amax_reduce = True
_dp_amax_reduce_forward_idx = (_dp_amax_reduce_forward_idx + 1) % _dp_amax_reduce_interval
else:
if _dp_amax_reduce_backward_idx == 0:
reduce_group = fp8_meta["fp8_group"]
else:
tp_amax_reduce = True
_dp_amax_reduce_backward_idx = (_dp_amax_reduce_backward_idx + 1) % _dp_amax_reduce_interval
if tp_amax_reduce:
if tp_size > 1:
reduce_group = tp_group
else:
return None
chunk_sizes = [x.numel() for x in _global_fp8_buffer[amax_buffer_key]]
contiguous_amax = torch.cat(_global_fp8_buffer[amax_buffer_key])
wait_handle = reduce_tensor_across_group_op_max(
contiguous_amax,
reduce_group,
fp8_meta["async_amax_reduction"],
)
_global_fp8_buffer[amax_buffer_key] = list(contiguous_amax.split(chunk_sizes))
return wait_handle
def delete_key_from_amax_buffer(forward: bool = True) -> None:
"""Delete the key from global amax buffer."""
global _global_fp8_buffer, _buffer_delete_key_fwd, _buffer_delete_key_bwd
if forward:
if (
_buffer_delete_key_fwd is not None
and _buffer_delete_key_fwd in _global_fp8_buffer
):
del _global_fp8_buffer[_buffer_delete_key_fwd]
else:
if (
_buffer_delete_key_bwd is not None
and _buffer_delete_key_bwd in _global_fp8_buffer
):
del _global_fp8_buffer[_buffer_delete_key_bwd]
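
With the class in place, saving and restoring the global FP8 bookkeeping reduces to the paired checkpoint calls introduced above. A minimal round-trip sketch using only methods from this PR; the torch.save/torch.load packaging and file name are illustrative, since TE modules actually carry this data inside their own extra state:

import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager

# Capture the global FP8 bookkeeping alongside the model weights.
fp8_state = FP8GlobalStateManager.get_global_fp8_state_checkpoint()
fp8_buffer = FP8GlobalStateManager.get_global_fp8_buffer_checkpoint()
torch.save({"fp8_state": fp8_state, "fp8_buffer": fp8_buffer}, "fp8_ckpt.pt")

# Restore: keys that no longer exist as class attributes are skipped by the
# hasattr gate, which is what keeps the dictionary format backwards
# compatible; buffer tensors are mapped back onto the GPU by the setter.
ckpt = torch.load("fp8_ckpt.pt")
FP8GlobalStateManager.set_global_fp8_state_checkpoint(ckpt["fp8_state"])
FP8GlobalStateManager.set_global_fp8_buffer_checkpoint(ckpt["fp8_buffer"])
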
......@@ -19,29 +19,10 @@ from torch.nn.parameter import Parameter
import transformer_engine_extensions as tex
from ..export import is_in_onnx_export_mode
from ..fp8 import (
is_fp8_enabled,
is_fp8_calibration,
get_fp8_recipe,
get_fp8_group,
get_default_fp8_recipe,
get_fp8_te_dtype,
is_first_fp8_module,
new_fp8_context_id,
get_fp8_context_id,
set_fp8_context_id,
add_amax_to_global_buffer,
copy_amax_from_global_buffer,
global_amax_reduction,
setup_amax_forward_global_reduce_func,
FP8GlobalStateManager,
amax_and_scale_update,
get_global_fp8_buffer,
set_global_fp8_buffer,
set_amax_buffer_key_deletion,
delete_key_from_amax_buffer,
copy_forward_fp8_meta_tensors_for_recompute,
get_old_fp8_meta_tensors_for_recompute,
restore_fp8_meta_tensors,
get_amax_reduce_handle_fwd,
)
from ..distributed import (
gather_along_first_dim,
......@@ -97,30 +78,30 @@ def _prepare_backward(
# Update amax and scale; Skip all setup for global amax reduction
if not fp8_meta["recipe"].reduce_amax:
amax_and_scale_update(fp8_meta, False)
FP8GlobalStateManager.amax_and_scale_update(fp8_meta, False)
else:
# From previous iteration
copy_amax_from_global_buffer(fp8_meta, forward=False)
FP8GlobalStateManager.copy_amax_from_global_buffer(fp8_meta, forward=False)
amax_and_scale_update(fp8_meta, False)
set_amax_buffer_key_deletion(fp8_meta, forward=False)
FP8GlobalStateManager.set_amax_buffer_key_deletion(fp8_meta, forward=False)
# Get new backward key.
fp8_meta["autocast_id_bwd"] = fp8_meta["autocast_id_fwd_stack"].pop(0)
add_amax_to_global_buffer(fp8_meta, forward=False)
FP8GlobalStateManager.add_amax_to_global_buffer(fp8_meta, forward=False)
with torch.cuda.nvtx.range(name + " backward"):
yield
if fp8 and fp8_meta["recipe"].reduce_amax:
if fp8_meta["first_module"]:
_amax_reduce_handle_bwd = global_amax_reduction(
_amax_reduce_handle_bwd = FP8GlobalStateManager.global_amax_reduction(
fp8_meta,
tp_group,
tp_size,
forward=False
)
delete_key_from_amax_buffer(forward=False)
FP8GlobalStateManager.delete_key_from_amax_buffer(forward=False)
def initialize_ub(
......@@ -356,12 +337,13 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
state["scale_bwd"] = self.fp8_meta["scaling_bwd"].scale
state["scale_inv_bwd"] = self.fp8_meta["scaling_bwd"].scale_inv
state["amax_history_bwd"] = self.fp8_meta["scaling_bwd"].amax_history
state["global_fp8_buffer"] = get_global_fp8_buffer()
state["global_fp8_buffer"] = FP8GlobalStateManager.get_global_fp8_buffer_checkpoint()
state["global_fp8_state"] = FP8GlobalStateManager.get_global_fp8_state_checkpoint()
# Store other picklable values.
extra = {}
for k, v in self.fp8_meta.items():
if isinstance(v, (bool, int, float, str)):
if isinstance(v, (bool, int, float, str, list)):
extra[k] = v
state["extra_fp8_variables"] = extra
......@@ -403,7 +385,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.fp8_meta["scaling_bwd"].amax_history.copy_(amax_history_bwd)
# Restore global FP8 buffer state.
set_global_fp8_buffer(state[4])
FP8GlobalStateManager.set_global_fp8_buffer_checkpoint(state[4])
self.fp8_meta["update_amax_and_scale_fwd"] = state[5]
self.fp8_meta["global_fp8_buffer_pos_fwd"] = state[6]
self.fp8_meta["global_fp8_buffer_pos_bwd"] = state[7]
......@@ -420,8 +402,16 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
if state is None:
return
# Restore global FP8 buffer states.
set_global_fp8_buffer(state["global_fp8_buffer"])
# Restore global FP8 amax buffer.
FP8GlobalStateManager.set_global_fp8_buffer_checkpoint(state["global_fp8_buffer"])
# Restore global FP8 state.
if "global_fp8_state" in state:
FP8GlobalStateManager.set_global_fp8_state_checkpoint(state["global_fp8_state"])
else:
warnings.warn(
    "This checkpoint format is deprecated and will be "
    "removed in a future release of Transformer Engine."
)
# Load extra items.
self.fp8_meta.update(state["extra_fp8_variables"])
self.fp8_meta["recipe"].amax_history_len = state["amax_history_fwd"].shape[0]
......@@ -525,19 +515,20 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# assume FP8 execution.
def fp8_init(self, num_gemms: int = 1) -> None:
"""Initialize fp8 related metadata and tensors during fprop."""
self.fp8 = is_fp8_enabled()
self.fp8_calibration = is_fp8_calibration()
self.fp8 = FP8GlobalStateManager.is_fp8_enabled()
self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration
if self.fp8 or self.fp8_calibration:
# FP8 init has already been run and recipe is the same, don't do anything.
if self.fp8_initialized and get_fp8_recipe() == self.fp8_meta["recipe"]:
if (self.fp8_initialized
and FP8GlobalStateManager.get_fp8_recipe() == self.fp8_meta["recipe"]):
return
# Set FP8, recipe, and other FP8 metadata
self.fp8_meta["recipe"] = get_fp8_recipe()
self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
self.fp8_meta["num_gemms"] = num_gemms
self.fp8_meta["fp8_group"] = get_fp8_group()
self.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()
# Set FP8_MAX per tensor according to recipe
self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd
......@@ -567,7 +558,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# Activation recomputation is used and this is the second forward phase.
if self.fp8 and in_fp8_activation_recompute_phase():
get_old_fp8_meta_tensors_for_recompute(self.fp8_meta)
FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta)
else:
assert inp.is_cuda, "TransformerEngine needs CUDA."
......@@ -591,11 +582,11 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# Previous iteration was grad_enabled
if self.fp8_meta.get("update_amax_and_scale_fwd", False):
if self.fp8_meta["recipe"].reduce_amax:
copy_amax_from_global_buffer(self.fp8_meta, forward=True)
FP8GlobalStateManager.copy_amax_from_global_buffer(self.fp8_meta, forward=True)
amax_and_scale_update(
self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
)
set_amax_buffer_key_deletion(self.fp8_meta, forward=True)
FP8GlobalStateManager.set_amax_buffer_key_deletion(self.fp8_meta, forward=True)
else:
amax_and_scale_update(
self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
......@@ -604,20 +595,22 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
if self.fp8 and self.training:
# Setup for amax reduction
if self.fp8_meta["recipe"].reduce_amax:
self.fp8_meta["first_module"] = is_first_fp8_module()
self.fp8_meta["first_module"] = FP8GlobalStateManager.is_first_fp8_module()
if self.fp8_meta["first_module"]:
# Wait for the prior AMAX reduction to finish
amax_reduce_handle_fwd = get_amax_reduce_handle_fwd()
amax_reduce_handle_fwd = FP8GlobalStateManager.get_amax_reduce_handle_fwd()
if amax_reduce_handle_fwd is not None:
amax_reduce_handle_fwd.wait()
self.fp8_meta["autocast_id_fwd"] = new_fp8_context_id()
set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
self.fp8_meta["autocast_id_fwd"] = (
FP8GlobalStateManager.new_fp8_context_id())
FP8GlobalStateManager.set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
else:
self.fp8_meta["autocast_id_fwd"] = get_fp8_context_id()
self.fp8_meta["autocast_id_fwd"] = (
FP8GlobalStateManager.get_fp8_context_id())
self.fp8_meta["autocast_id_fwd_stack"].append(
self.fp8_meta["autocast_id_fwd"]
)
add_amax_to_global_buffer(self.fp8_meta, forward=True)
FP8GlobalStateManager.add_amax_to_global_buffer(self.fp8_meta, forward=True)
self.fp8_meta["update_amax_and_scale_fwd"] = True
else:
self.fp8_meta["update_amax_and_scale_fwd"] = False
......@@ -629,25 +622,25 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
and is_fp8_activation_recompute_enabled()
and not in_fp8_activation_recompute_phase()
):
copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)
FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)
with torch.cuda.nvtx.range(self.__class__.__name__ + " forward"):
yield inp.contiguous()
if self.fp8 and in_fp8_activation_recompute_phase():
restore_fp8_meta_tensors(self.fp8_meta)
FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
return
if self.fp8 and self.training and self.fp8_meta["recipe"].reduce_amax:
set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
FP8GlobalStateManager.set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
reduce_func = partial(
global_amax_reduction,
FP8GlobalStateManager.global_amax_reduction,
self.fp8_meta,
self.tp_group,
self.tp_size,
forward=True
)
setup_amax_forward_global_reduce_func(reduce_func)
FP8GlobalStateManager.setup_amax_forward_global_reduce_func(reduce_func)
def set_nccl_overlap_warning_if_tp(self) -> None:
"""When using TP, the NCCL communication needs to be scheduled
......
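
None of this changes the public entry point: user code still wraps forward passes in fp8_autocast. A minimal usage sketch against the public API, assuming FP8-capable hardware such as H100; the layer dimensions and recipe values are arbitrary examples:

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Any TE module works the same way; te.Linear is just the smallest example.
model = te.Linear(768, 3072, bias=True).cuda()
inp = torch.randn(16, 768, device="cuda")

fp8_recipe = recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.HYBRID)
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = model(inp)

# Backward runs outside the autocast region, as in the TE examples.
loss = out.float().sum()
loss.backward()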