"git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "953b6d6802fad9b5a9bc1bd73c0b436c07ada29c"
Unverified commit 66d6afbf, authored by Paweł Gadziński, committed by GitHub

[PyTorch] More precise test for CPU offloading (#1668)



* test change
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* test fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* small changes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* small changes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* test
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* clear
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* base
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent aee78831
@@ -39,7 +39,7 @@ python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py"
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py"
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py"
-python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
+NVTE_FLASH_ATTN=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py"
 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_kv_cache.py || test_fail "test_kv_cache.py"
@@ -2,41 +2,84 @@
 #
 # See LICENSE for license information.
+import os
+from contextlib import nullcontext
+
 import pytest
 import torch
-from contextlib import nullcontext
 
 import transformer_engine.pytorch as te
+from transformer_engine.common import recipe
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 
-# Check if FP8 supported
+# Check if FP8 is supported
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
+
+fp8_recipes = [
+    None,  # non-fp8
+    # recipe.MXFP8BlockScaling() - scale inverse tensor offloading does not work yet
+    recipe.Float8CurrentScaling(),
+    recipe.DelayedScaling(),
+]
 
 SIZE = 512
+NUM_HEADS = 8
+NUM_LAYERS = 5
+EPSILON = 0.1
+
+# Flash attention saves some internal tensors for the backward pass
+# that cannot be offloaded to CPU.
+assert os.getenv("NVTE_FLASH_ATTN") == "0"
 
-models = {
-    "linear": te.Linear,
-    "layernorm_mlp": te.LayerNormMLP,
-    "layernorm_linear": te.LayerNormLinear,
+# For attention, offloading is supported only with the fused and flash attention
+# backends, so bfloat16 parameters are required.
+#
+# For the TransformerLayer, activation offloading with dropout is not supported,
+# so we set hidden_dropout to 0.0.
+model_types = {
+    "linear": lambda: te.Linear(SIZE, SIZE, params_dtype=torch.bfloat16),
+    "layernorm_mlp": lambda: te.LayerNormMLP(SIZE, SIZE, params_dtype=torch.bfloat16),
+    "layernorm_linear": lambda: te.LayerNormLinear(SIZE, SIZE, params_dtype=torch.bfloat16),
+    "multihead_attention": lambda: te.MultiheadAttention(
+        SIZE, NUM_HEADS, params_dtype=torch.bfloat16
+    ),
+    "transformer_layer": lambda: te.TransformerLayer(
+        SIZE, SIZE, NUM_HEADS, params_dtype=torch.bfloat16, hidden_dropout=0.0
+    ),
 }
 
 
 def _get_input():
-    return torch.empty((128, SIZE, SIZE)).cuda()
+    return torch.empty((128, SIZE, SIZE), dtype=torch.bfloat16).cuda()
 
 
-def _measure_memory_between_forward_and_backward(model_cls, fp8, cpu_offload):
-
-    input_layer = model_cls(SIZE, SIZE)
-    hidden_layer = model_cls(SIZE, SIZE)
-    output_layer = model_cls(SIZE, SIZE)
-
-    input = _get_input()
+def _get_fp8_weight_cache_size(models, fp8_recipe):
+    """
+    Calculate the total FP8 weight cache size (in MB) for a list of models.
+    """
+    if fp8_recipe is None:
+        return 0
+
+    params_bytes = 0
+    for model in models:
+        for name, param in model.named_parameters():
+            if "weight" in name:
+                params_bytes += param.numel()
+
+    # One byte for columnwise and one byte for rowwise data, hence multiply by 2
+    # and convert to MB; mxFP8 adds 1 byte of scale per 32 elements.
+    factor_for_scale_inv_tensor = (1 + 1 / 32) if fp8_recipe.mxfp8() else 1
+    return (2 * params_bytes * factor_for_scale_inv_tensor) / (1024**2)
+
+
+def _measure_memory_between_forward_and_backward(models, fp8_recipe, cpu_offload):
+    tensor = _get_input()
     if cpu_offload:
         offload_context, sync_function = te.get_cpu_offload_context(
             enabled=True,
-            num_layers=2,
-            model_layers=3,
+            num_layers=len(models) - 1,
+            model_layers=len(models),
             offload_activations=True,
             offload_weights=False,
         )
@@ -44,42 +87,58 @@ def _measure_memory_between_forward_and_backward(model_cls, fp8, cpu_offload):
         offload_context = nullcontext()
         sync_function = lambda x: x
 
-    with te.fp8_autocast(enabled=fp8), offload_context:
-        out = input_layer(input)
-    out = sync_function(out)
-
-    with te.fp8_autocast(enabled=fp8), offload_context:
-        out = hidden_layer(out)
-    out = sync_function(out)
-
-    with te.fp8_autocast(enabled=fp8), offload_context:
-        out = output_layer(out)
-    out = sync_function(out)
-
-    max_mem_used = torch.cuda.memory_allocated() / 1024**2
-    out.sum().backward()
-
-    del input_layer
-    del hidden_layer
-    del output_layer
-    del input
-    del out
-
+    for model in models:
+        with te.fp8_autocast(
+            enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe
+        ), offload_context:
+            tensor = model(tensor)
+        tensor = sync_function(tensor)
+
+    max_mem_used = torch.cuda.memory_allocated() / (1024**2)
     torch.cuda.synchronize()
 
     return max_mem_used
 
 
-@pytest.mark.parametrize("fp8", [True, False])
-@pytest.mark.parametrize("model_key", models.keys())
-def test_cpu_offload(fp8, model_key) -> None:
-
-    if fp8 and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-
-    model_cls = models[model_key]
-
-    without_offloading = _measure_memory_between_forward_and_backward(model_cls, fp8, False)
-    with_offloading = _measure_memory_between_forward_and_backward(model_cls, fp8, True)
+@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
+@pytest.mark.parametrize("model_key", model_types.keys())
+def test_cpu_offload(fp8_recipe, model_key) -> None:
+    """
+    We run three configurations:
+    (1) No offloading: All activations remain on the GPU between forward and backward passes.
+    (2) No offloading (one layer): Only the first layer's activations remain on the GPU
+        between forward and backward passes.
+    (3) With offloading (all layers): Only the last layer's activations remain on the GPU
+        between forward and backward passes, while all other layers are offloaded to the CPU.
+
+    We expect the memory consumption of configurations (2) and (3) to be similar, with
+    the difference being the size of the FP8 cache that is not offloaded to the CPU.
+    We also expect this memory consumption to be smaller than in scenario (1).
+    """
+    model_cls = model_types[model_key]
+    models_list = [model_cls() for _ in range(NUM_LAYERS)]
+
+    if fp8_recipe and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if fp8_recipe is not None:
+        if fp8_recipe.mxfp8() and not mxfp8_available:
+            pytest.skip(reason_for_no_mxfp8)
+
+    without_offloading = _measure_memory_between_forward_and_backward(
+        models_list, fp8_recipe, False
+    )
+    without_offloading_one_layer = _measure_memory_between_forward_and_backward(
+        models_list[:1], fp8_recipe, False
+    )
+    with_offloading = _measure_memory_between_forward_and_backward(models_list, fp8_recipe, True)
 
     assert with_offloading < without_offloading
+
+    # The only difference between the memory consumption of with_offloading and
+    # without_offloading_one_layer should be the size of the FP8 weight cache,
+    # which is not offloaded to the CPU.
+    memory_consumption_diff = abs(with_offloading - without_offloading_one_layer)
+    assert (
+        memory_consumption_diff < _get_fp8_weight_cache_size(models_list[1:], fp8_recipe) + EPSILON
+    )
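
To make the tightened bound concrete, the sketch below reproduces the arithmetic behind _get_fp8_weight_cache_size for the simplest case covered by the test. It is illustrative only: the cache_mb helper is not part of the commit, and the figures assume the constants taken from the diff above (SIZE = 512, NUM_LAYERS = 5, a stack of te.Linear layers whose only "weight" parameter has SIZE * SIZE elements).

# Worked example (illustrative, not part of the commit): the FP8 weight-cache
# bound used by the new assertion, computed for the "linear" model type with
# the constants from the diff.
SIZE = 512
NUM_LAYERS = 5

# The assertion compares with_offloading against without_offloading_one_layer,
# so only the caches of models_list[1:] (NUM_LAYERS - 1 layers) contribute.
params_bytes = (NUM_LAYERS - 1) * SIZE * SIZE


def cache_mb(params_bytes, mxfp8=False):
    # One byte rowwise plus one byte columnwise per element (hence the factor 2);
    # the mxFP8 recipe would add one scale byte per 32 elements.
    factor = (1 + 1 / 32) if mxfp8 else 1
    return 2 * params_bytes * factor / (1024**2)


print(cache_mb(params_bytes))              # 2.0 MB for DelayedScaling / Float8CurrentScaling
print(cache_mb(params_bytes, mxfp8=True))  # ~2.06 MB if the mxFP8 recipe were enabled

With EPSILON = 0.1 MB of slack on top of that bound, the assertion tolerates allocator noise but should still fail if even one layer's activations stay on the GPU, since a single 128 x 512 x 512 bf16 tensor is already 64 MB.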