Unverified Commit 66055973 authored by Kirthi Shankar Sivamani, committed by GitHub

Raise autocast usage error (#93)



* catch incorrect usage of fp8_autocast
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* catch error on first-time double execution
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent c9245c02
@@ -95,6 +95,12 @@ def add_amax_to_global_buffer(fp8_meta: Dict[str, Any], forward: bool = True) ->
     if buffer_position_key not in fp8_meta:
         fp8_meta[buffer_position_key] = len(_global_fp8_buffer[buffer_key]) - 1
 
+    # Catch incorrect fp8_autocast usage.
+    assert fp8_meta[buffer_position_key] == len(_global_fp8_buffer[buffer_key]) - 1, \
+        "Same module is being invoked more than once inside an `fp8_autocast` region when using " \
+        "FP8 with amax reduction. This behavior is currently unsupported. For more details and " \
+        "correct usage, please see https://github.com/NVIDIA/TransformerEngine/pull/93."
+
 
 def copy_forward_fp8_meta_tensors_for_recompute(fp8_meta: Dict[str, Any]) -> None:
     """Copy the scaling factors and amaxes for recompute forward phase
@@ -157,7 +163,10 @@ def copy_amax_from_global_buffer(
     buffer_position_key = get_buffer_position_key(forward=forward)
     if buffer_position_key not in fp8_meta:
         return
     amax_buffer_key = get_amax_buffer_key(fp8_meta, forward=forward)
+
+    assert amax_buffer_key in _global_fp8_buffer, "TE internal error."
+
     fp8_meta[fp8_meta_tensor_key].amax_history[0] = _global_fp8_buffer[amax_buffer_key][
         fp8_meta[buffer_position_key]
     ]
@@ -204,6 +213,14 @@ def fp8_autocast(
     with shapes where both dimensions are divisible by 16. In terms of the input to the full
     Transformer network, this typically requires padding sequence length to be multiple of 16.
 
+    .. note::
+
+        When :attr:`fp8_recipe.reduce_amax==True`, any module must not be invoked more than once
+        inside a single `fp8_autocast` region. This is unsupported behavior because the amax
+        reduction is handled during the exit of the `fp8_autocast` context. Calling the same
+        module more than once inside an `fp8_autocast` region overwrites the amax tensors
+        before reduction can occur.
+
     Parameters
     ----------
     enabled: bool, default = `False`
...
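
Below is a minimal sketch of the usage pattern the new assertion rejects, plus a supported alternative. It assumes the `transformer_engine.pytorch` API at the time of this change; the layer sizes, tensor shapes, and recipe settings are illustrative assumptions, not taken from this PR.

```python
# Minimal sketch, assuming transformer_engine.pytorch as of this change.
# Shapes and recipe settings are illustrative, not from the PR.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

fp8_recipe = recipe.DelayedScaling(reduce_amax=True)  # amax reduction enabled
layer = te.Linear(768, 768).cuda()  # dims divisible by 16, as FP8 requires
inp = torch.randn(16, 768, device="cuda")

# Unsupported: the second call overwrites the module's amax entry in the
# global buffer before the reduction at context exit, so the new assert in
# add_amax_to_global_buffer raises instead of silently corrupting scales.
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = layer(inp)
    out = layer(out)  # AssertionError: module invoked twice in one region

# Supported: one invocation per module per autocast region, so each context
# exit reduces a consistent set of amaxes.
for _ in range(2):
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        inp = layer(inp)
```

If the recipe disables amax reduction, the restriction does not apply, since there is no cross-rank reduction to perform at context exit.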