Unverified commit 427c736d, authored by Kirthi Shankar Sivamani, committed by GitHub
Browse files

[PyTorch] Fixes and tests for FP8 + activation recompute (#487)



* initial test fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Drop eval for selective checkpointing tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Remove redundant recompute for FA
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* CI fix; Decouple fused attention and numerics tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent f5d720a0
...@@ -25,8 +25,6 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import ( ...@@ -25,8 +25,6 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import (
QKVLayout, QKVLayout,
fused_attn_bwd, fused_attn_bwd,
fused_attn_fwd, fused_attn_fwd,
fused_attn_bwd_qkvpacked,
fused_attn_fwd_qkvpacked,
) )
import transformer_engine.pytorch.fp8 as fp8 import transformer_engine.pytorch.fp8 as fp8
from transformer_engine.pytorch.module.base import ( from transformer_engine.pytorch.module.base import (
...@@ -38,13 +36,24 @@ from transformer_engine.pytorch.utils import ( ...@@ -38,13 +36,24 @@ from transformer_engine.pytorch.utils import (
init_method_normal, init_method_normal,
scaled_init_method_normal, scaled_init_method_normal,
) )
from transformer_engine.pytorch.distributed import _set_cuda_rng_state, CudaRNGStatesTracker
import transformer_engine_extensions as tex import transformer_engine_extensions as tex
from test_numerics import get_dummy_cuda_rng_tracker, reset_rng_states
# Only run FP8 tests on H100.
fp8_available, reason_for_no_fp8 = fp8.FP8GlobalStateManager.is_fp8_available() fp8_available, reason_for_no_fp8 = fp8.FP8GlobalStateManager.is_fp8_available()
_flash_attn_version = packaging.version.Version(version("flash-attn")) _flash_attn_version = packaging.version.Version(version("flash-attn"))
_flash_attn_2_available = _flash_attn_version >= packaging.version.Version("2") _flash_attn_2_available = _flash_attn_version >= packaging.version.Version("2")
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Record initial RNG state from script run.
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()
def _get_cudnn_version(): def _get_cudnn_version():
cudnn_version_encoded = ext.get_cudnn_version() cudnn_version_encoded = ext.get_cudnn_version()
cudnn_major = cudnn_version_encoded // 1000 cudnn_major = cudnn_version_encoded // 1000
...@@ -52,6 +61,13 @@ def _get_cudnn_version(): ...@@ -52,6 +61,13 @@ def _get_cudnn_version():
cudnn_patch = cudnn_version_encoded - 1000 * cudnn_major - 100 * cudnn_minor cudnn_patch = cudnn_version_encoded - 1000 * cudnn_major - 100 * cudnn_minor
return [cudnn_major, cudnn_minor, cudnn_patch] return [cudnn_major, cudnn_minor, cudnn_patch]
def reset_rng_states() -> None:
"""Revert the CPU and CUDA RNG states to the ones recorded at module load.

Restores `_cpu_rng_state` and `_cuda_rng_state`, which are captured right
after seeding at the top of this module.
"""
torch.set_rng_state(_cpu_rng_state)
_set_cuda_rng_state(_cuda_rng_state)
_cudnn_version = _get_cudnn_version() _cudnn_version = _get_cudnn_version()
...@@ -210,6 +226,13 @@ def _run_dot_product_attention(dtype, bs, config, backend, ckpt_attn, bias_type) ...@@ -210,6 +226,13 @@ def _run_dot_product_attention(dtype, bs, config, backend, ckpt_attn, bias_type)
else: else:
bias = None bias = None
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
def get_dummy_cuda_rng_tracker():
"""Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER
block = ( block = (
DotProductAttention( DotProductAttention(
config.num_attention_heads, config.num_attention_heads,
...@@ -733,6 +756,13 @@ def _run_dpa_fp8_ref(dtype, bs, config, backend): ...@@ -733,6 +756,13 @@ def _run_dpa_fp8_ref(dtype, bs, config, backend):
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0) cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
op_grad = torch.load('op_grad.pt').cuda().view(bs, config.seq_len, -1).transpose(0,1) op_grad = torch.load('op_grad.pt').cuda().view(bs, config.seq_len, -1).transpose(0,1)
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
def get_dummy_cuda_rng_tracker():
"""Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER
block = ( block = (
DotProductAttention( DotProductAttention(
config.num_attention_heads, config.num_attention_heads,
......
...@@ -12,6 +12,7 @@ import torch ...@@ -12,6 +12,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from torch.nn import Parameter from torch.nn import Parameter
from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager
from transformer_engine.pytorch.utils import ( from transformer_engine.pytorch.utils import (
init_method_normal, init_method_normal,
scaled_init_method_normal, scaled_init_method_normal,
...@@ -25,6 +26,10 @@ from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint ...@@ -25,6 +26,10 @@ from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
from transformer_engine.pytorch.distributed import _set_cuda_rng_state, CudaRNGStatesTracker from transformer_engine.pytorch.distributed import _set_cuda_rng_state, CudaRNGStatesTracker
# Only run FP8 tests on H100.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
seed = 1234 seed = 1234
torch.manual_seed(seed) torch.manual_seed(seed)
torch.cuda.manual_seed(seed) torch.cuda.manual_seed(seed)
...@@ -90,20 +95,11 @@ def assert_allclose(l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float) ...@@ -90,20 +95,11 @@ def assert_allclose(l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float)
def reset_rng_states() -> None: def reset_rng_states() -> None:
# revert back to initial RNG state. """revert back to initial RNG state."""
torch.set_rng_state(_cpu_rng_state) torch.set_rng_state(_cpu_rng_state)
_set_cuda_rng_state(_cuda_rng_state) _set_cuda_rng_state(_cuda_rng_state)
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
def get_dummy_cuda_rng_tracker():
"""Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER
class TorchScaledMaskedSoftmax(nn.Module): class TorchScaledMaskedSoftmax(nn.Module):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
...@@ -343,8 +339,39 @@ class TorchGPT(nn.Module): ...@@ -343,8 +339,39 @@ class TorchGPT(nn.Module):
return x return x
def _test_e2e_selective_recompute(block, bs, dtype, config, recompute=False): def _test_e2e_selective_recompute(bs, dtype, config, fp8, recompute=False):
reset_rng_states() reset_rng_states()
FP8GlobalStateManager.reset()
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
def get_dummy_cuda_rng_tracker():
"""Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER
block = (
TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
get_rng_state_tracker=get_dummy_cuda_rng_tracker,
params_dtype=dtype,
)
.cuda()
)
te_inp_hidden_states = torch.randn( te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
...@@ -352,6 +379,7 @@ def _test_e2e_selective_recompute(block, bs, dtype, config, recompute=False): ...@@ -352,6 +379,7 @@ def _test_e2e_selective_recompute(block, bs, dtype, config, recompute=False):
te_inp_hidden_states.retain_grad() te_inp_hidden_states.retain_grad()
te_inp_attn_mask = get_causal_attn_mask(config.seq_len) te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
with fp8_autocast(enabled=fp8):
te_out = block( te_out = block(
te_inp_hidden_states, te_inp_hidden_states,
attention_mask=te_inp_attn_mask, attention_mask=te_inp_attn_mask,
...@@ -371,13 +399,33 @@ def _test_e2e_selective_recompute(block, bs, dtype, config, recompute=False): ...@@ -371,13 +399,33 @@ def _test_e2e_selective_recompute(block, bs, dtype, config, recompute=False):
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys()) @pytest.mark.parametrize("model", model_configs.keys())
def test_gpt_selective_activation_recompute(dtype, bs, model): @pytest.mark.parametrize("fp8", all_boolean)
def test_gpt_selective_activation_recompute(dtype, bs, model, fp8):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
config = model_configs[model] config = model_configs[model]
outputs = _test_e2e_selective_recompute(bs, dtype, config, fp8, recompute=False)
outputs_recompute = _test_e2e_selective_recompute(bs, dtype, config, fp8, recompute=True)
assert_all_equal(outputs, outputs_recompute)
def _test_e2e_full_recompute(bs, dtype, config, fp8, recompute=False):
reset_rng_states()
FP8GlobalStateManager.reset()
sigma = 0.023 sigma = 0.023
init_method = init_method_normal(sigma) init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
def get_dummy_cuda_rng_tracker():
"""Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER
block = ( block = (
TransformerLayer( TransformerLayer(
config.hidden_size, config.hidden_size,
...@@ -395,23 +443,15 @@ def test_gpt_selective_activation_recompute(dtype, bs, model): ...@@ -395,23 +443,15 @@ def test_gpt_selective_activation_recompute(dtype, bs, model):
params_dtype=dtype, params_dtype=dtype,
) )
.cuda() .cuda()
.eval()
) )
outputs = _test_e2e_selective_recompute(block, bs, dtype, config, recompute=False)
outputs_recompute = _test_e2e_selective_recompute(block, bs, dtype, config, recompute=True)
assert_all_equal(outputs, outputs_recompute)
def _test_e2e_full_recompute(block, bs, dtype, config, recompute=False):
reset_rng_states()
te_inp_hidden_states = torch.randn( te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
).cuda() ).cuda()
te_inp_hidden_states.retain_grad() te_inp_hidden_states.retain_grad()
te_inp_attn_mask = get_causal_attn_mask(config.seq_len) te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
with fp8_autocast(enabled=fp8):
if recompute: if recompute:
te_out = te_checkpoint( te_out = te_checkpoint(
block, block,
...@@ -442,35 +482,15 @@ def _test_e2e_full_recompute(block, bs, dtype, config, recompute=False): ...@@ -442,35 +482,15 @@ def _test_e2e_full_recompute(block, bs, dtype, config, recompute=False):
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys()) @pytest.mark.parametrize("model", model_configs.keys())
def test_gpt_full_activation_recompute(dtype, bs, model): @pytest.mark.parametrize("fp8", all_boolean)
config = model_configs[model] def test_gpt_full_activation_recompute(dtype, bs, model, fp8):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
sigma = 0.023 config = model_configs[model]
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = (
TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
get_rng_state_tracker=get_dummy_cuda_rng_tracker,
params_dtype=dtype,
)
.cuda()
.eval()
)
outputs = _test_e2e_full_recompute(block, bs, dtype, config, recompute=False) outputs = _test_e2e_full_recompute(bs, dtype, config, fp8, recompute=False)
outputs_recompute = _test_e2e_full_recompute(block, bs, dtype, config, recompute=True) outputs_recompute = _test_e2e_full_recompute(bs, dtype, config, fp8, recompute=True)
assert_all_equal(outputs, outputs_recompute) assert_all_equal(outputs, outputs_recompute)
...@@ -565,8 +585,8 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path= ...@@ -565,8 +585,8 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
def test_gpt_checkpointing(dtype, bs, model): def test_gpt_checkpointing(dtype, bs, model):
config = model_configs[model] config = model_configs[model]
outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False) outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False)
outputs_recompute = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True) outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True)
assert_all_equal(outputs, outputs_recompute) assert_all_equal(outputs, outputs_checkpoint)
def _test_e2e_gpt_accuracy(block, bs, dtype, config): def _test_e2e_gpt_accuracy(block, bs, dtype, config):
......
...@@ -2164,19 +2164,6 @@ class DotProductAttention(torch.nn.Module): ...@@ -2164,19 +2164,6 @@ class DotProductAttention(torch.nn.Module):
) )
if use_flash_attention: if use_flash_attention:
if checkpoint_core_attention:
return self._checkpointed_attention_forward(self.flash_attention,
query_layer,
key_layer,
value_layer,
attention_mask=attention_mask,
qkv_layout=qkv_layout,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
attn_mask_type=attn_mask_type,
cp_group=self.cp_group,
cp_global_ranks=self.cp_global_ranks,
cp_stream=self.cp_stream)
return self.flash_attention(query_layer, return self.flash_attention(query_layer,
key_layer, key_layer,
value_layer, value_layer,
......
...@@ -75,6 +75,29 @@ class FP8GlobalStateManager: ...@@ -75,6 +75,29 @@ class FP8GlobalStateManager:
dp_amax_reduce_forward_idx = 0 dp_amax_reduce_forward_idx = 0
dp_amax_reduce_backward_idx = 0 dp_amax_reduce_backward_idx = 0
@classmethod
def reset(cls) -> None:
"""Reset the global state

Rebinds each class-level attribute listed below to a fixed initial value.
The tests in this change call it before building a model so that each run
starts from the same FP8 bookkeeping.
"""
cls.FP8_ENABLED = False
cls.FP8_CALIBRATION = False
cls.FP8_RECIPE = None
cls.FP8_DISTRIBUTED_GROUP = None
cls.IS_FIRST_FP8_MODULE = False
cls.FP8_AUTOCAST_COUNTER = 0
cls.FP8_CURRENT_CONTEXT_ID = 0
cls.FP8_AUTOCAST_DEPTH = 0
# Rebind to new empty containers rather than clearing in place.
cls.global_fp8_buffer = {}
cls.fp8_tensors_recompute_buffer = []
cls.amax_forward_global_reduce_func = None
cls.buffer_delete_key_fwd = None
cls.buffer_delete_key_bwd = None
cls.amax_reduce_handle_fwd = None
# NOTE(review): presumably these two cache the result of is_fp8_available(),
# so clearing them makes the next call re-run the check — confirm.
cls.fp8_available = None
cls.reason_for_no_fp8 = ""
cls.dp_amax_reduce_interval = None
cls.dp_amax_reduce_forward_idx = 0
cls.dp_amax_reduce_backward_idx = 0
@classmethod @classmethod
def is_fp8_available(cls) -> Tuple[bool, str]: def is_fp8_available(cls) -> Tuple[bool, str]:
"""Return if fp8 support is available""" """Return if fp8 support is available"""
......
...@@ -28,6 +28,7 @@ from ..distributed import ( ...@@ -28,6 +28,7 @@ from ..distributed import (
gather_along_first_dim, gather_along_first_dim,
is_fp8_activation_recompute_enabled, is_fp8_activation_recompute_enabled,
in_fp8_activation_recompute_phase, in_fp8_activation_recompute_phase,
get_distributed_world_size,
) )
from ..cpp_extensions import ( from ..cpp_extensions import (
fp8_cast_transpose_fused, fp8_cast_transpose_fused,
...@@ -77,9 +78,7 @@ def _prepare_backward( ...@@ -77,9 +78,7 @@ def _prepare_backward(
_amax_reduce_handle_bwd = None _amax_reduce_handle_bwd = None
# Update amax and scale; Skip all setup for global amax reduction # Update amax and scale; Skip all setup for global amax reduction
if not fp8_meta["recipe"].reduce_amax: if fp8_meta["recipe"].reduce_amax and get_distributed_world_size(fp8_meta["fp8_group"]) > 1:
amax_and_scale_update(fp8_meta, False)
else:
# From previous iteration # From previous iteration
FP8GlobalStateManager.copy_amax_from_global_buffer(fp8_meta, forward=False) FP8GlobalStateManager.copy_amax_from_global_buffer(fp8_meta, forward=False)
amax_and_scale_update(fp8_meta, False) amax_and_scale_update(fp8_meta, False)
...@@ -89,11 +88,14 @@ def _prepare_backward( ...@@ -89,11 +88,14 @@ def _prepare_backward(
fp8_meta["autocast_id_bwd"] = fp8_meta["autocast_id_fwd_stack"].pop(0) fp8_meta["autocast_id_bwd"] = fp8_meta["autocast_id_fwd_stack"].pop(0)
FP8GlobalStateManager.add_amax_to_global_buffer(fp8_meta, forward=False) FP8GlobalStateManager.add_amax_to_global_buffer(fp8_meta, forward=False)
else:
amax_and_scale_update(fp8_meta, False)
with torch.cuda.nvtx.range(name + " backward"): with torch.cuda.nvtx.range(name + " backward"):
yield yield
if fp8 and fp8_meta["recipe"].reduce_amax: if (fp8 and fp8_meta["recipe"].reduce_amax
and get_distributed_world_size(fp8_meta["fp8_group"]) > 1):
if fp8_meta["first_module"]: if fp8_meta["first_module"]:
_amax_reduce_handle_bwd = FP8GlobalStateManager.global_amax_reduction( _amax_reduce_handle_bwd = FP8GlobalStateManager.global_amax_reduction(
fp8_meta, fp8_meta,
...@@ -549,7 +551,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -549,7 +551,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# Previous iteration was grad_enabled # Previous iteration was grad_enabled
if self.fp8_meta.get("update_amax_and_scale_fwd", False): if self.fp8_meta.get("update_amax_and_scale_fwd", False):
if self.fp8_meta["recipe"].reduce_amax: if (self.fp8_meta["recipe"].reduce_amax
and get_distributed_world_size(self.fp8_meta["fp8_group"]) > 1):
FP8GlobalStateManager.copy_amax_from_global_buffer(self.fp8_meta, forward=True) FP8GlobalStateManager.copy_amax_from_global_buffer(self.fp8_meta, forward=True)
amax_and_scale_update( amax_and_scale_update(
self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
...@@ -562,7 +565,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -562,7 +565,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
if self.fp8 and self.training: if self.fp8 and self.training:
# Setup for amax reduction # Setup for amax reduction
if self.fp8_meta["recipe"].reduce_amax: if (self.fp8_meta["recipe"].reduce_amax
and get_distributed_world_size(self.fp8_meta["fp8_group"]) > 1):
self.fp8_meta["first_module"] = FP8GlobalStateManager.is_first_fp8_module() self.fp8_meta["first_module"] = FP8GlobalStateManager.is_first_fp8_module()
if self.fp8_meta["first_module"]: if self.fp8_meta["first_module"]:
# Wait for the prior AMAX reduction to finish # Wait for the prior AMAX reduction to finish
...@@ -588,7 +592,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -588,7 +592,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.fp8 self.fp8
and self.training and self.training
and is_fp8_activation_recompute_enabled() and is_fp8_activation_recompute_enabled()
and not in_fp8_activation_recompute_phase()
): ):
FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta) FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)
...@@ -599,7 +602,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -599,7 +602,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta) FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
return return
if self.fp8 and self.training and self.fp8_meta["recipe"].reduce_amax: if (self.fp8 and self.training and self.fp8_meta["recipe"].reduce_amax
and get_distributed_world_size(self.fp8_meta["fp8_group"]) > 1):
FP8GlobalStateManager.set_fp8_context_id(self.fp8_meta["autocast_id_fwd"]) FP8GlobalStateManager.set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
reduce_func = partial( reduce_func = partial(
FP8GlobalStateManager.global_amax_reduction, FP8GlobalStateManager.global_amax_reduction,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment