Unverified Commit 2f61c401 authored by Paweł Gadziński, committed by GitHub

[PyTorch] Bunch of memory management fixes (#1686)



* init
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* clear() fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* lint fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* revert adding cpu offload tests for mxfp8
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* forgot to uncomment assert
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent c0df246a
......@@ -366,7 +366,9 @@ def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad):
torch.cuda.synchronize()
def _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad):
def _test_sanity_common(
block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching=True
):
if skip_dgrad and skip_wgrad:
pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")
......@@ -382,7 +384,11 @@ def _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad
use_fp8 = fp8_recipe is not None
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
if not microbatching:
te_out = block(te_inp)
else:
_ = block(te_inp, is_first_microbatch=True)
te_out = block(te_inp, is_first_microbatch=False)
if isinstance(te_out, tuple):
te_out = te_out[0]
loss = te_out.sum()
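For context, a minimal sketch of how a training loop typically drives the `is_first_microbatch` hint that this test now exercises; the module size, recipe, and loop structure here are illustrative assumptions, not part of the change:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch import fp8_autocast

# Hypothetical layer and microbatch sizes, for illustration only.
layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16, device="cuda")
microbatches = [
    torch.randn(16, 1024, dtype=torch.bfloat16, device="cuda") for _ in range(4)
]

for i, mb in enumerate(microbatches):
    with fp8_autocast(enabled=True):
        # First microbatch: quantize and cache the FP8 weights.
        # Later microbatches: reuse the cached copies.
        out = layer(mb, is_first_microbatch=(i == 0))
    out.sum().backward()
```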
......@@ -436,8 +442,16 @@ def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normaliz
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("microbatching", all_boolean)
def test_sanity_layernorm_linear(
dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, skip_dgrad, normalization
dtype,
fp8_recipe,
model,
skip_wgrad,
zero_centered_gamma,
skip_dgrad,
normalization,
microbatching,
):
config = model_configs[model]
......@@ -463,7 +477,7 @@ def test_sanity_layernorm_linear(
params_dtype=dtype,
device="cuda",
)
_test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
_test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)
@pytest.mark.parametrize("dtype", param_types)
......@@ -471,7 +485,8 @@ def test_sanity_layernorm_linear(
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad):
@pytest.mark.parametrize("microbatching", all_boolean)
def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microbatching):
config = model_configs[model]
if fp8_recipe is not None:
......@@ -494,7 +509,7 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad):
params_dtype=dtype,
device="cuda",
)
_test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
_test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)
@pytest.mark.parametrize("dtype", param_types)
......@@ -593,8 +608,17 @@ def test_sanity_grouped_linear(
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("microbatching", all_boolean)
def test_sanity_layernorm_mlp(
dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, skip_dgrad, activation, normalization
dtype,
fp8_recipe,
model,
skip_wgrad,
zero_centered_gamma,
skip_dgrad,
activation,
normalization,
microbatching,
):
config = model_configs[model]
......@@ -623,7 +647,7 @@ def test_sanity_layernorm_mlp(
params_dtype=dtype,
device="cuda",
)
_test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
_test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)
@pytest.mark.parametrize("dtype", param_types)
......
......@@ -1106,7 +1106,16 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
raise ValueError(
"tensor and quantizer kwargs must be provided to construct FP8 workspace"
)
if cache_name is not None:
# Ensure the tensor in the cache is an instance of torch.Tensor,
# as it persists beyond a single forward pass.
# Setting internal=True would cause the data to be removed in prepare_for_saving(...).
quantizer_internal = quantizer.internal
quantizer.internal = False
out = quantizer.quantize(tensor, dtype=workspace_dtype)
if cache_name is not None:
quantizer.internal = quantizer_internal
# Update cache
if cache_name is not None:
......
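A hedged reading of the comment above, as an annotated sketch; `quantizer`, `tensor`, and `workspace_dtype` are the names from this hunk, and the failure mode described is inferred from the `prepare_for_saving` hunks further down rather than taken from the change itself:

```python
# Hedged sketch, reusing names from the hunk above (not standalone-runnable).
# With quantizer.internal left True, quantize() returns a lightweight
# *TensorBase object; as the tensor-base hunks below show, those objects hand
# their raw buffers to autograd in prepare_for_saving() and keep only None
# references afterwards.
workspace = quantizer.quantize(tensor, dtype=workspace_dtype)  # internal still True
buffers, workspace = workspace.prepare_for_saving()            # buffers move to autograd
# The cached workspace would now be empty on the next forward pass, which is
# why a cache_name-keyed workspace is quantized with internal temporarily
# forced to False.
```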
......@@ -137,8 +137,10 @@ class _LayerNormLinear(torch.autograd.Function):
# Make sure input dimensions are compatible
out_features, in_features = weight.shape
inp_shape = inp.shape
inp_requires_grad = inp.requires_grad
assert inp_shape[-1] == in_features, "GEMM not possible"
inputmat = inp.view((-1, in_features))
inp = inp.view((-1, in_features))
inputmat = inp
if fp8:
assert_dim_for_fp8_exec(inputmat, weight)
......@@ -399,7 +401,7 @@ class _LayerNormLinear(torch.autograd.Function):
)
ctx.save_for_backward(*tensors_to_save)
ctx.tensor_objects = tensor_objects
ctx.requires_dgrad = inp.requires_grad
ctx.requires_dgrad = inp_requires_grad
ctx.requires_wgrad = weight.requires_grad
ctx.quantized_weight = quantized_weight
if fuse_wgrad_accumulation and weight.requires_grad:
......@@ -432,7 +434,7 @@ class _LayerNormLinear(torch.autograd.Function):
ctx.ub_bulk_wgrad = ub_bulk_wgrad
ctx.ub_bulk_dgrad = ub_bulk_dgrad
ctx.ub_name = ub_name
ctx.requires_dgrad = inp.requires_grad
ctx.requires_dgrad = inp_requires_grad
ctx.normalization = normalization
ctx.reduce_and_update_bwd_fp8_tensors = False
if ctx.fp8 and requires_grad(inp, ln_weight, ln_bias, weight, bias):
......@@ -462,7 +464,7 @@ class _LayerNormLinear(torch.autograd.Function):
if return_layernorm_output:
if return_layernorm_output_gathered:
shape = list(inp.shape)
shape = list(inp_shape)
shape[0] *= tp_size
return out, ln_out_return.view(shape)
return out, ln_out_return.view_as(inp)
......
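The `inp_requires_grad` and `inp_shape` captures above matter because `inp` is rebound to a flattened view; a small standalone sketch of the shape part, with toy sizes and hypothetical names:

```python
import torch

inp = torch.randn(2, 3, 8)
inp_shape = inp.shape            # remember the caller's layout up front
inp = inp.view(-1, 8)            # inp is rebound to the 2-D view: [6, 8]

# Anything that must report the caller's layout (e.g. the returned layernorm
# output) has to use the saved inp_shape, not inp.shape.
out = torch.zeros(inp.shape[0], 16)
print(out.view(*inp_shape[:-1], 16).shape)   # torch.Size([2, 3, 16])
```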
......@@ -126,8 +126,7 @@ class _Linear(torch.autograd.Function):
# Make sure input dimensions are compatible
out_features, in_features = weight.shape
inp_shape = inp.shape
assert inp_shape[-1] == in_features, "GEMM not possible"
assert inp.shape[-1] == in_features, "GEMM not possible"
tp_world_size = get_distributed_world_size(tp_group)
backward_needs_input = is_grad_enabled and weight.requires_grad
......@@ -135,7 +134,8 @@ class _Linear(torch.autograd.Function):
# Prepare input tensor
# Note: Cast to expected dtype and perform tensor-parallel communication
nvtx_range_push(f"{nvtx_label}.input_cast_comm")
inputmat = inp.view(-1, in_features)
inputmat = inp
inputmat_total = None
with_input_all_gather_nccl = (
parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop
......@@ -256,7 +256,7 @@ class _Linear(torch.autograd.Function):
if ub_overlap_rs_fprop:
ub_obj = get_ub(ub_name + "_fprop")
ub_type = tex.CommOverlapType.RS
out_shape = [reduce(multiply_op, inp_shape[:-1]) // tp_world_size, out_features]
out_shape = [reduce(multiply_op, inp.shape[:-1]) // tp_world_size, out_features]
rs_out = torch.empty(out_shape, dtype=activation_dtype, device=inputmat_total.device)
elif ub_overlap_ag_fprop:
......@@ -365,7 +365,7 @@ class _Linear(torch.autograd.Function):
ctx.use_bias = bias is not None
ctx.sequence_parallel = sequence_parallel
ctx.tensor_parallel = tensor_parallel
ctx.inp_shape = inp_shape
ctx.inp_shape = inp.shape
ctx.parallel_mode = parallel_mode
ctx.tp_group = tp_group
ctx.ub_overlap_ag = ub_overlap_ag_dgrad
......@@ -377,6 +377,7 @@ class _Linear(torch.autograd.Function):
ctx.requires_dgrad = inp.requires_grad
ctx.requires_wgrad = weight.requires_grad
ctx.reduce_and_update_bwd_fp8_tensors = False
ctx.owns_input = saved_inputmat is not inp
if ctx.fp8 and requires_grad(inp, weight, bias):
_first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE
......@@ -399,7 +400,6 @@ class _Linear(torch.autograd.Function):
out, _ = allreduce(out, tp_group)
nvtx_range_pop(f"{nvtx_label}.row_parallel_comm")
out = out.view(-1, *inp_shape[1:-1], out_features)
return out
@staticmethod
......
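`ctx.owns_input` presumably records whether the saved activation is the caller's own tensor, since dropping a tensor's storage is visible through every reference to that Python object; a standalone sketch of the distinction, with hypothetical helper and variable names:

```python
import torch

def clear_data(t: torch.Tensor) -> None:
    # Same trick the clear()/clear_tensor_data paths use: drop the storage.
    t.data = torch.Tensor()

user_input = torch.randn(32, 64)
saved = user_input                    # saved tensor IS the caller's tensor
clear_data(saved)
print(user_input.shape)               # torch.Size([0]) -- caller's tensor emptied too

user_input = torch.randn(32, 64)
saved = user_input.detach().clone()   # module owns an independent copy
clear_data(saved)
print(user_input.shape)               # torch.Size([32, 64]) -- caller unaffected
```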
......@@ -57,6 +57,17 @@ class Float8BlockwiseQTensorBase:
return instance
def clear(self):
"""Deallocate this tensor's memory. Typically not needed and must be used carefully."""
for t in (
self._rowwise_data,
self._columnwise_data,
self._rowwise_scale_inv,
self._columnwise_scale_inv,
):
if t is not None:
t.data = torch.Tensor()
def get_metadata(self) -> Dict[str, Any]:
"""Get this tensor's metadata."""
return {
......@@ -74,14 +85,17 @@ class Float8BlockwiseQTensorBase:
) -> Tuple[list[Optional[torch.Tensor]], Float8BlockwiseQTensorBase]:
"""
Prepare the tensor base for saving for backward
This does not clear the tensors currently, because with PP config
that clears the weight cache between micro-batches. If the rowwise
data is not required for backward, this is a possible memory
pessimization, but is consistent with the other quantized tensor
classes.
"""
tensors = [self._rowwise_data, self._columnwise_data]
tensors = [
self._rowwise_data,
self._columnwise_data,
self._rowwise_scale_inv,
self._columnwise_scale_inv,
]
self._rowwise_data = None
self._columnwise_data = None
self._rowwise_scale_inv = None
self._columnwise_scale_inv = None
return tensors, self
def restore_from_saved(
......@@ -90,7 +104,9 @@ class Float8BlockwiseQTensorBase:
"""Restore the tensor base data from the saved tensors list."""
self._rowwise_data = tensors[0]
self._columnwise_data = tensors[1]
return tensors[2:]
self._rowwise_scale_inv = tensors[2]
self._columnwise_scale_inv = tensors[3]
return tensors[4:]
def get_data_tensors(self):
"""Get this Tensor's data."""
......
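For orientation, a hedged sketch of the save/restore protocol these methods implement, now carrying the scale-inverse buffers as well; `ctx` stands for a torch.autograd function context and `quantized_input` for a hypothetical Float8BlockwiseQTensorBase instance, so this is annotation rather than runnable code:

```python
# Forward: flatten the quantized tensor into plain torch.Tensors so they take
# part in autograd saving (and hence activation offloading / recompute), and
# keep the emptied shell object on the side.
tensors_to_save, shell = quantized_input.prepare_for_saving()
ctx.save_for_backward(*tensors_to_save)     # rowwise/columnwise data + both scale_invs
ctx.tensor_objects = [shell]

# Backward: hand the saved buffers back; restore_from_saved() consumes the
# first four entries and returns whatever is left for other saved objects.
saved = list(ctx.saved_tensors)
remaining = ctx.tensor_objects[0].restore_from_saved(saved)
assert remaining == []
```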
......@@ -90,6 +90,13 @@ class Float8TensorBase:
return instance
def clear(self):
"""Deallocate this tensor's memory. Typically not needed and must be used carefully."""
for t in (self._data, self._transpose, self._scale_inv):
if t is not None:
t.data = torch.Tensor()
self._transpose_invalid = True
def get_metadata(self) -> Dict[str, Any]:
"""Get this tensor's metadata."""
return {
......@@ -102,7 +109,10 @@ class Float8TensorBase:
def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], Float8TensorBase]:
"""Prepare the tensor base for saving for backward"""
tensors = [self._data, self._transpose]
tensors = [self._data, self._transpose, self._scale_inv]
self._data = None
self._transpose = None
self._scale_inv = None
return tensors, self
def restore_from_saved(
......@@ -111,7 +121,8 @@ class Float8TensorBase:
"""Restore the tensor base data from the saved tensors list"""
self._data = tensors[0]
self._transpose = tensors[1]
return tensors[2:]
self._scale_inv = tensors[2]
return tensors[3:]
def get_data_tensors(self):
"""Get this Tensor's data."""
......
......@@ -81,6 +81,17 @@ class MXFP8TensorBase:
return instance
def clear(self):
"""Deallocate this tensor's memory. Typically not needed and must be used carefully."""
for t in (
self._rowwise_data,
self._columnwise_data,
self._rowwise_scale_inv,
self._columnwise_scale_inv,
):
if t is not None:
t.data = torch.Tensor()
def get_metadata(self) -> Dict[str, Any]:
"""Get this tensor's metadata."""
return {
......@@ -94,7 +105,16 @@ class MXFP8TensorBase:
def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorBase]:
"""Prepare the tensor base for saving for backward"""
tensors = [self._rowwise_data, self._columnwise_data]
tensors = [
self._rowwise_data,
self._columnwise_data,
self._rowwise_scale_inv,
self._columnwise_scale_inv,
]
self._rowwise_data = None
self._columnwise_data = None
self._rowwise_scale_inv = None
self._columnwise_scale_inv = None
return tensors, self
def restore_from_saved(
......@@ -103,7 +123,9 @@ class MXFP8TensorBase:
"""Restore the tensor base data from the saved tensors list."""
self._rowwise_data = tensors[0]
self._columnwise_data = tensors[1]
return tensors[2:]
self._rowwise_scale_inv = tensors[2]
self._columnwise_scale_inv = tensors[3]
return tensors[4:]
def get_data_tensors(self):
"""Get this Tensor's data."""
......
......@@ -432,11 +432,6 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
return self
raise ValueError("Float8BlockwiseQTensor does not support different memory formats!")
def clear(self):
"""Deallocate this tensor's memory. Typically not needed and must be used carefully."""
self._rowwise_data = torch.Tensor() if self._rowwise_data is not None else None
self._columnwise_data = torch.Tensor() if self._columnwise_data is not None else None
@classmethod
def _make_in_reduce_ex(
cls,
......
......@@ -516,12 +516,6 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor):
del self._transpose # explicitly deletes the data for safety
self._transpose = None
def clear(self):
"""Deallocate this tensor's memory. Typically not needed and must be used carefully."""
self._data = torch.Tensor() if self._data is not None else None
self._transpose = torch.Tensor() if self._transpose is not None else None
self._transpose_invalid = True
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
......
......@@ -304,11 +304,6 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
return self
raise ValueError("MXFP8Tensor does not support different memory formats!")
def clear(self):
"""Deallocate this tensor's memory. Typically not needed and must be used carefully."""
self._rowwise_data = torch.Tensor() if self._rowwise_data is not None else None
self._columnwise_data = torch.Tensor() if self._columnwise_data is not None else None
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
......
......@@ -14,8 +14,6 @@ import torch
import transformer_engine.pytorch.cpp_extensions as ext
from ..debug.pytorch.debug_quantization import DebugQuantizedTensor
from .tensor.quantized_tensor import QuantizedTensor
def requires_grad(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
"""Check if any of the given tensors require gradient."""
......@@ -34,7 +32,7 @@ def clear_tensor_data(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
"""
for t in tensors:
if t is not None:
if isinstance(t, QuantizedTensor):
if hasattr(t, "clear"):
t.clear()
else:
t.data = torch.Tensor()
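A hedged usage sketch of `clear_tensor_data` after this change, assuming it remains importable from `transformer_engine.pytorch.utils`: the `hasattr` check matters because the `*TensorBase` classes above expose `clear()` but, unlike the full `Float8Tensor`-style subclasses, are not `QuantizedTensor` instances, so an `isinstance` check would skip them.

```python
import torch
from transformer_engine.pytorch.utils import clear_tensor_data  # assumed location

# Plain tensors fall back to dropping their storage in place.
activation = torch.randn(4096, 4096, device="cuda")
clear_tensor_data(activation)
print(activation.numel())   # 0 -- the buffer is released once nothing else references it

# Quantized tensors (and their base-class counterparts) instead free all of
# their component buffers through their own clear() method.
```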
......@@ -421,10 +419,10 @@ def assert_dim_for_fp8_exec(*tensors: List[torch.Tensor]) -> None:
"""Assert that tensor or tensors dimensions are supported for FP8 TN GEMM."""
for tensor in tensors:
assert tensor.dim() == 2 and tensor.size(0) % 8 == 0 and tensor.size(1) % 16 == 0, (
"FP8 execution requires 2D input matrices with "
"height divisible by 8 and width divisible by 16, "
f"but got tensor with dims={list(tensor.size())}"
assert math.prod(tensor.shape[:-1]) % 8 == 0 and tensor.shape[-1] % 16 == 0, (
"FP8 execution requires the product of all dimensions except the last to be divisible"
" by 8 and the last dimension to be divisible by 16, but got tensor with"
f" dims={list(tensor.size())}"
)
......
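Worked examples of the relaxed shape requirement (the helper name below is illustrative, not part of the change): FP8 GEMM inputs may now be N-dimensional as long as the product of the leading dimensions is divisible by 8 and the last dimension by 16.

```python
import math
import torch

def fp8_shape_ok(t: torch.Tensor) -> bool:
    # Mirrors the assertion above, for illustration.
    return math.prod(t.shape[:-1]) % 8 == 0 and t.shape[-1] % 16 == 0

print(fp8_shape_ok(torch.empty(2, 4, 32)))   # True:  2 * 4 = 8 and 32 % 16 == 0
print(fp8_shape_ok(torch.empty(16, 64)))     # True
print(fp8_shape_ok(torch.empty(6, 32)))      # False: 6 % 8 != 0
print(fp8_shape_ok(torch.empty(8, 24)))      # False: 24 % 16 != 0
```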