Unverified Commit 2f61c401 authored by Paweł Gadziński, committed by GitHub

[PyTorch] Bunch of memory management fixes (#1686)



* init
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* clear() fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* lint fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* revert adding cpu offload tests for mxfp8
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* forgot to uncomment assert
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent c0df246a
@@ -366,7 +366,9 @@ def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad):
     torch.cuda.synchronize()


-def _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad):
+def _test_sanity_common(
+    block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching=True
+):
     if skip_dgrad and skip_wgrad:
         pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")

@@ -382,7 +384,11 @@ def _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad
     use_fp8 = fp8_recipe is not None
     with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
-        te_out = block(te_inp)
+        if not microbatching:
+            te_out = block(te_inp)
+        else:
+            _ = block(te_inp, is_first_microbatch=True)
+            te_out = block(te_inp, is_first_microbatch=False)
     if isinstance(te_out, tuple):
         te_out = te_out[0]
     loss = te_out.sum()
@@ -436,8 +442,16 @@ def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normaliz
 @pytest.mark.parametrize("zero_centered_gamma", all_boolean)
 @pytest.mark.parametrize("skip_dgrad", all_boolean)
 @pytest.mark.parametrize("normalization", all_normalizations)
+@pytest.mark.parametrize("microbatching", all_boolean)
 def test_sanity_layernorm_linear(
-    dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, skip_dgrad, normalization
+    dtype,
+    fp8_recipe,
+    model,
+    skip_wgrad,
+    zero_centered_gamma,
+    skip_dgrad,
+    normalization,
+    microbatching,
 ):
     config = model_configs[model]

@@ -463,7 +477,7 @@ def test_sanity_layernorm_linear(
         params_dtype=dtype,
         device="cuda",
     )
-    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
+    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)


 @pytest.mark.parametrize("dtype", param_types)
@@ -471,7 +485,8 @@ def test_sanity_layernorm_linear(
 @pytest.mark.parametrize("model", ["small", "weird"])
 @pytest.mark.parametrize("skip_wgrad", all_boolean)
 @pytest.mark.parametrize("skip_dgrad", all_boolean)
-def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad):
+@pytest.mark.parametrize("microbatching", all_boolean)
+def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microbatching):
     config = model_configs[model]

     if fp8_recipe is not None:
@@ -494,7 +509,7 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad):
         params_dtype=dtype,
         device="cuda",
     )
-    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
+    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)


 @pytest.mark.parametrize("dtype", param_types)
@@ -593,8 +608,17 @@ def test_sanity_grouped_linear(
 @pytest.mark.parametrize("skip_dgrad", all_boolean)
 @pytest.mark.parametrize("activation", all_activations)
 @pytest.mark.parametrize("normalization", all_normalizations)
+@pytest.mark.parametrize("microbatching", all_boolean)
 def test_sanity_layernorm_mlp(
-    dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, skip_dgrad, activation, normalization
+    dtype,
+    fp8_recipe,
+    model,
+    skip_wgrad,
+    zero_centered_gamma,
+    skip_dgrad,
+    activation,
+    normalization,
+    microbatching,
 ):
     config = model_configs[model]

@@ -623,7 +647,7 @@ def test_sanity_layernorm_mlp(
         params_dtype=dtype,
         device="cuda",
     )
-    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
+    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)


 @pytest.mark.parametrize("dtype", param_types)
......
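The `microbatching` parameter added above exercises Transformer Engine's `is_first_microbatch` hint: the first-microbatch call quantizes the weights and fills the FP8 weight-workspace cache, and later microbatches reuse that cache, which is the code path whose memory management this commit fixes. A minimal standalone sketch of the pattern the tests now cover (layer sizes, recipe, and input shape are illustrative):

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling
from transformer_engine.pytorch import fp8_autocast

block = te.Linear(128, 128, params_dtype=torch.bfloat16, device="cuda")
inp = torch.randn(64, 128, dtype=torch.bfloat16, device="cuda", requires_grad=True)

with fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()):
    _ = block(inp, is_first_microbatch=True)     # quantize weights, populate cache
    out = block(inp, is_first_microbatch=False)  # reuse cached quantized weights
out.sum().backward()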
@@ -1106,7 +1106,16 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
                 raise ValueError(
                     "tensor and quantizer kwargs must be provided to construct FP8 workspace"
                 )
+            if cache_name is not None:
+                # Ensure the tensor in the cache is an instance of torch.Tensor,
+                # as it persists beyond a single forward pass.
+                # Setting internal=True would cause the data to be removed in prepare_for_saving(...).
+                quantizer_internal = quantizer.internal
+                quantizer.internal = False
             out = quantizer.quantize(tensor, dtype=workspace_dtype)
+            if cache_name is not None:
+                quantizer.internal = quantizer_internal

             # Update cache
             if cache_name is not None:
......
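The flag juggling around `quantizer.quantize(...)` above can be read as the following pattern. `quantize_for_cache` is a hypothetical helper written only for illustration, and the `try/finally` is an addition here to make the restore unconditional:

def quantize_for_cache(quantizer, tensor, workspace_dtype, cache_name=None):
    # Cached workspaces outlive one forward pass, so the quantized tensor must
    # keep its storage when prepare_for_saving(...) detaches buffers for autograd.
    saved_internal = quantizer.internal
    if cache_name is not None:
        quantizer.internal = False
    try:
        return quantizer.quantize(tensor, dtype=workspace_dtype)
    finally:
        if cache_name is not None:
            quantizer.internal = saved_internal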
@@ -137,8 +137,10 @@ class _LayerNormLinear(torch.autograd.Function):
         # Make sure input dimensions are compatible
         out_features, in_features = weight.shape
         inp_shape = inp.shape
+        inp_requires_grad = inp.requires_grad
         assert inp_shape[-1] == in_features, "GEMM not possible"
-        inputmat = inp.view((-1, in_features))
+        inp = inp.view((-1, in_features))
+        inputmat = inp

         if fp8:
             assert_dim_for_fp8_exec(inputmat, weight)
@@ -399,7 +401,7 @@ class _LayerNormLinear(torch.autograd.Function):
             )
             ctx.save_for_backward(*tensors_to_save)
             ctx.tensor_objects = tensor_objects
-            ctx.requires_dgrad = inp.requires_grad
+            ctx.requires_dgrad = inp_requires_grad
             ctx.requires_wgrad = weight.requires_grad
             ctx.quantized_weight = quantized_weight
             if fuse_wgrad_accumulation and weight.requires_grad:
@@ -432,7 +434,7 @@ class _LayerNormLinear(torch.autograd.Function):
             ctx.ub_bulk_wgrad = ub_bulk_wgrad
             ctx.ub_bulk_dgrad = ub_bulk_dgrad
             ctx.ub_name = ub_name
-            ctx.requires_dgrad = inp.requires_grad
+            ctx.requires_dgrad = inp_requires_grad
             ctx.normalization = normalization
             ctx.reduce_and_update_bwd_fp8_tensors = False
             if ctx.fp8 and requires_grad(inp, ln_weight, ln_bias, weight, bias):
@@ -462,7 +464,7 @@ class _LayerNormLinear(torch.autograd.Function):
         if return_layernorm_output:
             if return_layernorm_output_gathered:
-                shape = list(inp.shape)
+                shape = list(inp_shape)
                 shape[0] *= tp_size
                 return out, ln_out_return.view(shape)
             return out, ln_out_return.view_as(inp)
......
@@ -126,8 +126,7 @@ class _Linear(torch.autograd.Function):
         # Make sure input dimensions are compatible
         out_features, in_features = weight.shape
-        inp_shape = inp.shape
-        assert inp_shape[-1] == in_features, "GEMM not possible"
+        assert inp.shape[-1] == in_features, "GEMM not possible"

         tp_world_size = get_distributed_world_size(tp_group)
         backward_needs_input = is_grad_enabled and weight.requires_grad
@@ -135,7 +134,8 @@ class _Linear(torch.autograd.Function):
         # Prepare input tensor
         # Note: Cast to expected dtype and perform tensor-parallel communication
         nvtx_range_push(f"{nvtx_label}.input_cast_comm")
-        inputmat = inp.view(-1, in_features)
+        inp = inp.view(-1, in_features)
+        inputmat = inp
         inputmat_total = None
         with_input_all_gather_nccl = (
             parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop
@@ -256,7 +256,7 @@ class _Linear(torch.autograd.Function):
         if ub_overlap_rs_fprop:
             ub_obj = get_ub(ub_name + "_fprop")
             ub_type = tex.CommOverlapType.RS
-            out_shape = [reduce(multiply_op, inp_shape[:-1]) // tp_world_size, out_features]
+            out_shape = [reduce(multiply_op, inp.shape[:-1]) // tp_world_size, out_features]
             rs_out = torch.empty(out_shape, dtype=activation_dtype, device=inputmat_total.device)
         elif ub_overlap_ag_fprop:
@@ -365,7 +365,7 @@ class _Linear(torch.autograd.Function):
             ctx.use_bias = bias is not None
             ctx.sequence_parallel = sequence_parallel
             ctx.tensor_parallel = tensor_parallel
-            ctx.inp_shape = inp_shape
+            ctx.inp_shape = inp.shape
             ctx.parallel_mode = parallel_mode
             ctx.tp_group = tp_group
             ctx.ub_overlap_ag = ub_overlap_ag_dgrad
@@ -377,6 +377,7 @@ class _Linear(torch.autograd.Function):
             ctx.requires_dgrad = inp.requires_grad
             ctx.requires_wgrad = weight.requires_grad
             ctx.reduce_and_update_bwd_fp8_tensors = False
+            ctx.owns_input = saved_inputmat is not inp
             if ctx.fp8 and requires_grad(inp, weight, bias):
                 _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE
@@ -399,7 +400,6 @@ class _Linear(torch.autograd.Function):
             out, _ = allreduce(out, tp_group)
             nvtx_range_pop(f"{nvtx_label}.row_parallel_comm")

-        out = out.view(-1, *inp_shape[1:-1], out_features)
         return out

     @staticmethod
......
@@ -57,6 +57,17 @@ class Float8BlockwiseQTensorBase:

         return instance

+    def clear(self):
+        """Deallocate this tensor's memory. Typically not needed and must be used carefully."""
+        for t in (
+            self._rowwise_data,
+            self._columnwise_data,
+            self._rowwise_scale_inv,
+            self._columnwise_scale_inv,
+        ):
+            if t is not None:
+                t.data = torch.Tensor()
+
     def get_metadata(self) -> Dict[str, Any]:
         """Get this tensor's metadata."""
         return {
@@ -74,14 +85,17 @@ class Float8BlockwiseQTensorBase:
     ) -> Tuple[list[Optional[torch.Tensor]], Float8BlockwiseQTensorBase]:
         """
         Prepare the tensor base for saving for backward
-        This does not clear the tensors currently, because with PP config
-        that clears the weight cache between micro-batches. If the rowwise
-        data is not required for backward, this is a possible memory
-        pessimization, but is consistent with the other quantized tensor
-        classes.
         """
-        tensors = [self._rowwise_data, self._columnwise_data]
+        tensors = [
+            self._rowwise_data,
+            self._columnwise_data,
+            self._rowwise_scale_inv,
+            self._columnwise_scale_inv,
+        ]
+        self._rowwise_data = None
+        self._columnwise_data = None
+        self._rowwise_scale_inv = None
+        self._columnwise_scale_inv = None
         return tensors, self

     def restore_from_saved(
@@ -90,7 +104,9 @@ class Float8BlockwiseQTensorBase:
         """Restore the tensor base data from the saved tensors list."""
         self._rowwise_data = tensors[0]
         self._columnwise_data = tensors[1]
-        return tensors[2:]
+        self._rowwise_scale_inv = tensors[2]
+        self._columnwise_scale_inv = tensors[3]
+        return tensors[4:]

     def get_data_tensors(self):
         """Get this Tensor's data."""
......
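For orientation: `prepare_for_saving` and `restore_from_saved` are the two halves of the protocol that hands a quantized tensor's raw buffers to `ctx.save_for_backward` while keeping the wrapper object itself out of autograd. With this change the wrapper also drops its own references, so the saved list becomes the sole owner of the storages between forward and backward. A hedged sketch of the round trip, with `qtensor` standing in for any of the tensor-base classes touched here:

# `qtensor` is an assumed, already-constructed Float8BlockwiseQTensorBase.
tensors, shell = qtensor.prepare_for_saving()  # wrapper attributes are now None
# forward: ctx.save_for_backward(*tensors); backward gets them back as `tensors`
leftover = shell.restore_from_saved(tensors)   # buffers reattached to the wrapper
assert leftover == []                          # all four saved entries consumed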
@@ -90,6 +90,13 @@ class Float8TensorBase:

         return instance

+    def clear(self):
+        """Deallocate this tensor's memory. Typically not needed and must be used carefully."""
+        for t in (self._data, self._transpose, self._scale_inv):
+            if t is not None:
+                t.data = torch.Tensor()
+        self._transpose_invalid = True
+
     def get_metadata(self) -> Dict[str, Any]:
         """Get this tensor's metadata."""
         return {
@@ -102,7 +109,10 @@ class Float8TensorBase:

     def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], Float8TensorBase]:
         """Prepare the tensor base for saving for backward"""
-        tensors = [self._data, self._transpose]
+        tensors = [self._data, self._transpose, self._scale_inv]
+        self._data = None
+        self._transpose = None
+        self._scale_inv = None
         return tensors, self

     def restore_from_saved(
@@ -111,7 +121,8 @@ class Float8TensorBase:
         """Restore the tensor base data from the saved tensors list"""
         self._data = tensors[0]
         self._transpose = tensors[1]
-        return tensors[2:]
+        self._scale_inv = tensors[2]
+        return tensors[3:]

     def get_data_tensors(self):
         """Get this Tensor's data."""
......
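Note that the base-class `clear()` methods free storage via `t.data = torch.Tensor()` rather than dropping the attributes, so metadata and attribute identity survive while the underlying device memory is released; `clear_tensor_data` in the utils change below dispatches to them by duck typing. A small sketch, assuming `ft` is a `Float8TensorBase` whose buffers were allocated:

ft.clear()
assert ft._data is not None    # the attribute still points at a tensor object...
assert ft._data.numel() == 0   # ...but its storage was replaced by an empty one
assert ft._transpose_invalid   # any cached transpose is marked stale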
@@ -81,6 +81,17 @@ class MXFP8TensorBase:

         return instance

+    def clear(self):
+        """Deallocate this tensor's memory. Typically not needed and must be used carefully."""
+        for t in (
+            self._rowwise_data,
+            self._columnwise_data,
+            self._rowwise_scale_inv,
+            self._columnwise_scale_inv,
+        ):
+            if t is not None:
+                t.data = torch.Tensor()
+
     def get_metadata(self) -> Dict[str, Any]:
         """Get this tensor's metadata."""
         return {
@@ -94,7 +105,16 @@ class MXFP8TensorBase:

     def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorBase]:
         """Prepare the tensor base for saving for backward"""
-        tensors = [self._rowwise_data, self._columnwise_data]
+        tensors = [
+            self._rowwise_data,
+            self._columnwise_data,
+            self._rowwise_scale_inv,
+            self._columnwise_scale_inv,
+        ]
+        self._rowwise_data = None
+        self._columnwise_data = None
+        self._rowwise_scale_inv = None
+        self._columnwise_scale_inv = None
         return tensors, self

     def restore_from_saved(
@@ -103,7 +123,9 @@ class MXFP8TensorBase:
         """Restore the tensor base data from the saved tensors list."""
         self._rowwise_data = tensors[0]
         self._columnwise_data = tensors[1]
-        return tensors[2:]
+        self._rowwise_scale_inv = tensors[2]
+        self._columnwise_scale_inv = tensors[3]
+        return tensors[4:]

     def get_data_tensors(self):
         """Get this Tensor's data."""
......
@@ -432,11 +432,6 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
             return self
         raise ValueError("Float8BlockwiseQTensor does not support different memory formats!")

-    def clear(self):
-        """Deallocate this tensor's memory. Typically not needed and must be used carefully."""
-        self._rowwise_data = torch.Tensor() if self._rowwise_data is not None else None
-        self._columnwise_data = torch.Tensor() if self._columnwise_data is not None else None
-
     @classmethod
     def _make_in_reduce_ex(
         cls,
......
@@ -516,12 +516,6 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor):
         del self._transpose  # explicitly deletes the data for safety
         self._transpose = None

-    def clear(self):
-        """Deallocate this tensor's memory. Typically not needed and must be used carefully."""
-        self._data = torch.Tensor() if self._data is not None else None
-        self._transpose = torch.Tensor() if self._transpose is not None else None
-        self._transpose_invalid = True
-
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs=None):
......
@@ -304,11 +304,6 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
             return self
         raise ValueError("MXFP8Tensor does not support different memory formats!")

-    def clear(self):
-        """Deallocate this tensor's memory. Typically not needed and must be used carefully."""
-        self._rowwise_data = torch.Tensor() if self._rowwise_data is not None else None
-        self._columnwise_data = torch.Tensor() if self._columnwise_data is not None else None
-
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs=None):
......
@@ -14,8 +14,6 @@ import torch
 import transformer_engine.pytorch.cpp_extensions as ext

 from ..debug.pytorch.debug_quantization import DebugQuantizedTensor
-from .tensor.quantized_tensor import QuantizedTensor
-

 def requires_grad(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
     """Check if any of the given tensors require gradient."""
@@ -34,7 +32,7 @@ def clear_tensor_data(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
     """
     for t in tensors:
         if t is not None:
-            if isinstance(t, QuantizedTensor):
+            if hasattr(t, "clear"):
                 t.clear()
             else:
                 t.data = torch.Tensor()
@@ -421,10 +419,10 @@ def assert_dim_for_fp8_exec(*tensors: List[torch.Tensor]) -> None:
     """Assert that tensor or tensors dimensions are supported for FP8 TN GEMM."""
     for tensor in tensors:
-        assert tensor.dim() == 2 and tensor.size(0) % 8 == 0 and tensor.size(1) % 16 == 0, (
-            "FP8 execution requires 2D input matrices with "
-            "height divisible by 8 and width divisible by 16, "
-            f"but got tensor with dims={list(tensor.size())}"
+        assert math.prod(tensor.shape[:-1]) % 8 == 0 and tensor.shape[-1] % 16 == 0, (
+            "FP8 execution requires the product of all dimensions except the last to be divisible"
+            " by 8 and the last dimension to be divisible by 16, but got tensor with"
+            f" dims={list(tensor.size())}"
         )
......
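The relaxed `assert_dim_for_fp8_exec` accepts N-D inputs by flattening all leading dimensions instead of requiring an exact 2D shape. Restated as a standalone predicate for illustration (this re-derivation is not TE code):

import math
import torch

def fp8_dims_ok(t: torch.Tensor) -> bool:
    # Product of leading dims divisible by 8; last dim divisible by 16.
    return math.prod(t.shape[:-1]) % 8 == 0 and t.shape[-1] % 16 == 0

assert fp8_dims_ok(torch.empty(16, 32))        # 2D case, as before
assert fp8_dims_ok(torch.empty(2, 12, 32))     # 3D: 2 * 12 = 24, divisible by 8
assert not fp8_dims_ok(torch.empty(3, 3, 32))  # 3 * 3 = 9 rows after flattening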