Unverified Commit 427c736d authored by Kirthi Shankar Sivamani, committed by GitHub

[PyTorch] Fixes and tests for FP8 + activation recompute (#487)



* Initial test fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Drop eval for selective checkpointing tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Remove redundant recompute for FA
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* CI fix; Decouple fused attention and numerics tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent f5d720a0
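In outline, each reworked test builds a layer from a fixed RNG state, runs it once normally and once with activation recompute, and requires bit-identical outputs and gradients, with FP8 now a test parameter. A condensed sketch of that pattern, where _run is a hypothetical stand-in for the _test_e2e_* helpers in the diff below and fp8_available / reason_for_no_fp8 are the module-level flags it defines:

import pytest
import torch

@pytest.mark.parametrize("fp8", [False, True])
def test_recompute_matches_baseline(fp8):
    # Skip FP8 runs on hardware without FP8 support, as the tests below do.
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    outputs = _run(fp8, recompute=False)           # baseline forward/backward
    outputs_recompute = _run(fp8, recompute=True)  # recompute during backward
    # Recompute must be numerically invisible: output and every grad match.
    for ref, out in zip(outputs, outputs_recompute):
        assert torch.equal(ref, out)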
@@ -25,8 +25,6 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import (
QKVLayout,
fused_attn_bwd,
fused_attn_fwd,
fused_attn_bwd_qkvpacked,
fused_attn_fwd_qkvpacked,
)
import transformer_engine.pytorch.fp8 as fp8
from transformer_engine.pytorch.module.base import (
@@ -38,13 +36,24 @@ from transformer_engine.pytorch.utils import (
init_method_normal,
scaled_init_method_normal,
)
from transformer_engine.pytorch.distributed import _set_cuda_rng_state, CudaRNGStatesTracker
import transformer_engine_extensions as tex
from test_numerics import get_dummy_cuda_rng_tracker, reset_rng_states
# Only run FP8 tests on H100.
fp8_available, reason_for_no_fp8 = fp8.FP8GlobalStateManager.is_fp8_available()
_flash_attn_version = packaging.version.Version(version("flash-attn"))
_flash_attn_2_available = _flash_attn_version >= packaging.version.Version("2")
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Record the initial RNG state of the script run.
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()
def _get_cudnn_version():
cudnn_version_encoded = tex.get_cudnn_version()
cudnn_major = cudnn_version_encoded // 1000
@@ -52,6 +61,13 @@ def _get_cudnn_version():
cudnn_patch = cudnn_version_encoded - 1000 * cudnn_major - 100 * cudnn_minor
return [cudnn_major, cudnn_minor, cudnn_patch]
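# The helpers below are local copies of what this file previously imported
# from test_numerics (import removed above), so the fused-attention tests
# now run standalone -- the 'decouple' part of this change.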
def reset_rng_states() -> None:
"""revert back to initial RNG state."""
torch.set_rng_state(_cpu_rng_state)
_set_cuda_rng_state(_cuda_rng_state)
_cudnn_version = _get_cudnn_version()
@@ -210,6 +226,13 @@ def _run_dot_product_attention(dtype, bs, config, backend, ckpt_attn, bias_type)
else:
bias = None
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
def get_dummy_cuda_rng_tracker():
"""Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER
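# The dummy tracker lets the attention module snapshot and restore the CUDA
# RNG state, so dropout is replayed identically when core attention is
# recomputed in the backward pass.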
block = (
DotProductAttention(
config.num_attention_heads,
@@ -733,6 +756,13 @@ def _run_dpa_fp8_ref(dtype, bs, config, backend):
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
op_grad = torch.load('op_grad.pt').cuda().view(bs, config.seq_len, -1).transpose(0,1)
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
def get_dummy_cuda_rng_tracker():
"""Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER
block = (
DotProductAttention(
config.num_attention_heads,
......
@@ -12,6 +12,7 @@ import torch
import torch.nn as nn
from torch.nn import Parameter
from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager
from transformer_engine.pytorch.utils import (
init_method_normal,
scaled_init_method_normal,
@@ -25,6 +26,10 @@ from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
from transformer_engine.pytorch.distributed import _set_cuda_rng_state, CudaRNGStatesTracker
# Only run FP8 tests on H100.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
@@ -90,20 +95,11 @@ def assert_allclose(l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float)
def reset_rng_states() -> None:
# revert back to initial RNG state.
"""revert back to initial RNG state."""
torch.set_rng_state(_cpu_rng_state)
_set_cuda_rng_state(_cuda_rng_state)
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
def get_dummy_cuda_rng_tracker():
"""Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER
class TorchScaledMaskedSoftmax(nn.Module):
def __init__(self) -> None:
super().__init__()
@@ -343,41 +339,21 @@ class TorchGPT(nn.Module):
return x
def _test_e2e_selective_recompute(block, bs, dtype, config, recompute=False):
def _test_e2e_selective_recompute(bs, dtype, config, fp8, recompute=False):
reset_rng_states()
te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
te_inp_hidden_states.retain_grad()
te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
te_out = block(
te_inp_hidden_states,
attention_mask=te_inp_attn_mask,
checkpoint_core_attention=recompute,
)
loss = te_out.sum()
loss.backward()
torch.cuda.synchronize()
outputs = [te_out, te_inp_hidden_states.grad]
for p in block.parameters():
if p.requires_grad:
outputs.append(p.grad)
return outputs
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_gpt_selective_activation_recompute(dtype, bs, model):
config = model_configs[model]
FP8GlobalStateManager.reset()
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
def get_dummy_cuda_rng_tracker():
"""Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER
block = (
TransformerLayer(
config.hidden_size,
@@ -395,38 +371,19 @@ def test_gpt_selective_activation_recompute(dtype, bs, model):
params_dtype=dtype,
)
.cuda()
.eval()
)
outputs = _test_e2e_selective_recompute(block, bs, dtype, config, recompute=False)
outputs_recompute = _test_e2e_selective_recompute(block, bs, dtype, config, recompute=True)
assert_all_equal(outputs, outputs_recompute)
def _test_e2e_full_recompute(block, bs, dtype, config, recompute=False):
reset_rng_states()
te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
te_inp_hidden_states.retain_grad()
te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
if recompute:
te_out = te_checkpoint(
block,
False, # distribute_saved_activations
get_dummy_cuda_rng_tracker,
None, # tp_group
te_inp_hidden_states,
attention_mask=te_inp_attn_mask,
checkpoint_core_attention=False,
)
else:
with fp8_autocast(enabled=fp8):
te_out = block(
te_inp_hidden_states,
attention_mask=te_inp_attn_mask,
checkpoint_core_attention=False,
checkpoint_core_attention=recompute,
)
loss = te_out.sum()
loss.backward()
@@ -442,13 +399,33 @@ def _test_e2e_full_recompute(block, bs, dtype, config, recompute=False):
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_gpt_full_activation_recompute(dtype, bs, model):
@pytest.mark.parametrize("fp8", all_boolean)
def test_gpt_selective_activation_recompute(dtype, bs, model, fp8):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
config = model_configs[model]
outputs = _test_e2e_selective_recompute(bs, dtype, config, fp8, recompute=False)
outputs_recompute = _test_e2e_selective_recompute(bs, dtype, config, fp8, recompute=True)
assert_all_equal(outputs, outputs_recompute)
def _test_e2e_full_recompute(bs, dtype, config, fp8, recompute=False):
reset_rng_states()
FP8GlobalStateManager.reset()
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
def get_dummy_cuda_rng_tracker():
"""Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER
block = (
TransformerLayer(
config.hidden_size,
@@ -466,11 +443,54 @@ def test_gpt_full_activation_recompute(dtype, bs, model):
params_dtype=dtype,
)
.cuda()
.eval()
)
outputs = _test_e2e_full_recompute(block, bs, dtype, config, recompute=False)
outputs_recompute = _test_e2e_full_recompute(block, bs, dtype, config, recompute=True)
te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
te_inp_hidden_states.retain_grad()
te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
with fp8_autocast(enabled=fp8):
if recompute:
te_out = te_checkpoint(
block,
False, # distribute_saved_activations
get_dummy_cuda_rng_tracker,
None, # tp_group
te_inp_hidden_states,
attention_mask=te_inp_attn_mask,
checkpoint_core_attention=False,
)
else:
te_out = block(
te_inp_hidden_states,
attention_mask=te_inp_attn_mask,
checkpoint_core_attention=False,
)
loss = te_out.sum()
loss.backward()
torch.cuda.synchronize()
outputs = [te_out, te_inp_hidden_states.grad]
for p in block.parameters():
if p.requires_grad:
outputs.append(p.grad)
return outputs
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("fp8", all_boolean)
def test_gpt_full_activation_recompute(dtype, bs, model, fp8):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
config = model_configs[model]
outputs = _test_e2e_full_recompute(bs, dtype, config, fp8, recompute=False)
outputs_recompute = _test_e2e_full_recompute(bs, dtype, config, fp8, recompute=True)
assert_all_equal(outputs, outputs_recompute)
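For orientation: the selective test drives recompute through the layer's own checkpoint_core_attention flag, while the full test wraps the entire TransformerLayer call in te_checkpoint; after this fix, fp8_autocast encloses both paths. The two call shapes, condensed from the diff above (hidden and mask stand in for the test tensors):

# Selective recompute: only core attention is checkpointed, inside the layer.
with fp8_autocast(enabled=fp8):
    out = block(hidden, attention_mask=mask, checkpoint_core_attention=True)

# Full recompute: the whole layer is checkpointed; the dummy RNG tracker
# replays dropout deterministically during the recomputed forward.
with fp8_autocast(enabled=fp8):
    out = te_checkpoint(
        block,
        False,                        # distribute_saved_activations
        get_dummy_cuda_rng_tracker,   # CUDA RNG replay for dropout
        None,                         # tp_group
        hidden,
        attention_mask=mask,
        checkpoint_core_attention=False,
    )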
@@ -565,8 +585,8 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
def test_gpt_checkpointing(dtype, bs, model):
config = model_configs[model]
outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False)
outputs_recompute = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True)
assert_all_equal(outputs, outputs_recompute)
outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True)
assert_all_equal(outputs, outputs_checkpoint)
def _test_e2e_gpt_accuracy(block, bs, dtype, config):
......
@@ -2164,19 +2164,6 @@ class DotProductAttention(torch.nn.Module):
)
if use_flash_attention:
if checkpoint_core_attention:
return self._checkpointed_attention_forward(self.flash_attention,
query_layer,
key_layer,
value_layer,
attention_mask=attention_mask,
qkv_layout=qkv_layout,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
attn_mask_type=attn_mask_type,
cp_group=self.cp_group,
cp_global_ranks=self.cp_global_ranks,
cp_stream=self.cp_stream)
return self.flash_attention(query_layer,
key_layer,
value_layer,
......
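The branch removed above is the 'redundant recompute for FA' fix from the commit message: FlashAttention never materializes the full attention-probability matrix and already recomputes it inside its own backward pass, so wrapping it in _checkpointed_attention_forward duplicated work without saving memory. The module now calls flash_attention directly.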
@@ -75,6 +75,29 @@ class FP8GlobalStateManager:
dp_amax_reduce_forward_idx = 0
dp_amax_reduce_backward_idx = 0
@classmethod
def reset(cls) -> None:
"""Reset the global state"""
cls.FP8_ENABLED = False
cls.FP8_CALIBRATION = False
cls.FP8_RECIPE = None
cls.FP8_DISTRIBUTED_GROUP = None
cls.IS_FIRST_FP8_MODULE = False
cls.FP8_AUTOCAST_COUNTER = 0
cls.FP8_CURRENT_CONTEXT_ID = 0
cls.FP8_AUTOCAST_DEPTH = 0
cls.global_fp8_buffer = {}
cls.fp8_tensors_recompute_buffer = []
cls.amax_forward_global_reduce_func = None
cls.buffer_delete_key_fwd = None
cls.buffer_delete_key_bwd = None
cls.amax_reduce_handle_fwd = None
cls.fp8_available = None
cls.reason_for_no_fp8 = ""
cls.dp_amax_reduce_interval = None
cls.dp_amax_reduce_forward_idx = 0
cls.dp_amax_reduce_backward_idx = 0
@classmethod
def is_fp8_available(cls) -> Tuple[bool, str]:
"""Return if fp8 support is available"""
......
@@ -28,6 +28,7 @@ from ..distributed import (
gather_along_first_dim,
is_fp8_activation_recompute_enabled,
in_fp8_activation_recompute_phase,
get_distributed_world_size,
)
from ..cpp_extensions import (
fp8_cast_transpose_fused,
@@ -77,9 +78,7 @@ def _prepare_backward(
_amax_reduce_handle_bwd = None
# Update amax and scale; Skip all setup for global amax reduction
if not fp8_meta["recipe"].reduce_amax:
amax_and_scale_update(fp8_meta, False)
else:
if fp8_meta["recipe"].reduce_amax and get_distributed_world_size(fp8_meta["fp8_group"]) > 1:
# From previous iteration
FP8GlobalStateManager.copy_amax_from_global_buffer(fp8_meta, forward=False)
amax_and_scale_update(fp8_meta, False)
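With the new world-size guard, the global amax-reduction path (copying amaxes from the global buffer and the reduction bookkeeping) only engages when the FP8 process group actually spans more than one rank; single-process runs, such as these unit tests, instead take the plain amax_and_scale_update in the else branch of the next hunk.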
@@ -89,11 +88,14 @@
fp8_meta["autocast_id_bwd"] = fp8_meta["autocast_id_fwd_stack"].pop(0)
FP8GlobalStateManager.add_amax_to_global_buffer(fp8_meta, forward=False)
else:
amax_and_scale_update(fp8_meta, False)
with torch.cuda.nvtx.range(name + " backward"):
yield
if fp8 and fp8_meta["recipe"].reduce_amax:
if (fp8 and fp8_meta["recipe"].reduce_amax
and get_distributed_world_size(fp8_meta["fp8_group"]) > 1):
if fp8_meta["first_module"]:
_amax_reduce_handle_bwd = FP8GlobalStateManager.global_amax_reduction(
fp8_meta,
@@ -549,7 +551,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# Previous iteration was grad_enabled
if self.fp8_meta.get("update_amax_and_scale_fwd", False):
if self.fp8_meta["recipe"].reduce_amax:
if (self.fp8_meta["recipe"].reduce_amax
and get_distributed_world_size(self.fp8_meta["fp8_group"]) > 1):
FP8GlobalStateManager.copy_amax_from_global_buffer(self.fp8_meta, forward=True)
amax_and_scale_update(
self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
@@ -562,7 +565,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
if self.fp8 and self.training:
# Setup for amax reduction
if self.fp8_meta["recipe"].reduce_amax:
if (self.fp8_meta["recipe"].reduce_amax
and get_distributed_world_size(self.fp8_meta["fp8_group"]) > 1):
self.fp8_meta["first_module"] = FP8GlobalStateManager.is_first_fp8_module()
if self.fp8_meta["first_module"]:
# Wait for the prior AMAX reduction to finish
@@ -588,7 +592,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.fp8
and self.training
and is_fp8_activation_recompute_enabled()
and not in_fp8_activation_recompute_phase()
):
FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)
@@ -599,7 +602,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
return
if self.fp8 and self.training and self.fp8_meta["recipe"].reduce_amax:
if (self.fp8 and self.training and self.fp8_meta["recipe"].reduce_amax
and get_distributed_world_size(self.fp8_meta["fp8_group"]) > 1):
FP8GlobalStateManager.set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
reduce_func = partial(
FP8GlobalStateManager.global_amax_reduction,
......