Unverified Commit 07afda98 authored by hx, committed by GitHub
Browse files

[PyTorch] Add save_original_input in Linear/GroupedLinear to save memory (#1865)



* save original input
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix input_quantizer usage in Linear bwd
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>

* minor fix
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>

* refine the docstring
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>

* Merge remote-tracking branch 'origin/main' into save_bf16_in_fp8_gemm
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* decouple linear bwd with save_original_input; clean up UTs
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
Signed-off-by: hx <hongxiaob@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent fa91ed72
...@@ -516,8 +516,11 @@ def test_linear(): ...@@ -516,8 +516,11 @@ def test_linear():
{"return_bias": True}, {"return_bias": True},
{"params_dtype": torch.float16}, {"params_dtype": torch.float16},
{"delay_wgrad_compute": True}, {"delay_wgrad_compute": True},
{"save_original_input": True},
] ]
for kwargs in kwargs_list: for kwargs in kwargs_list:
if kwargs.get("save_original_input", False) and QUANTIZATION == "fp8":
continue
for parallel_mode in ["column", "row"]: for parallel_mode in ["column", "row"]:
for sequence_parallel in [False, True]: for sequence_parallel in [False, True]:
_test_linear(parallel_mode, sequence_parallel, **kwargs) _test_linear(parallel_mode, sequence_parallel, **kwargs)
......
...@@ -1055,8 +1055,11 @@ def test_mha_accuracy(dtype, bs, model, mask_type): ...@@ -1055,8 +1055,11 @@ def test_mha_accuracy(dtype, bs, model, mask_type):
assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype]) assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])
def _test_granular_accuracy(block, bs, dtype, config, delay_wgrad_compute=False): def _test_granular_accuracy(block, bs, dtype, config, delay_wgrad_compute=False, recipe=None):
reset_rng_states() reset_rng_states()
fp8 = recipe is not None
if fp8:
FP8GlobalStateManager.reset()
inp_hidden_states = torch.randn( inp_hidden_states = torch.randn(
(config.seq_len, bs, config.hidden_size), (config.seq_len, bs, config.hidden_size),
...@@ -1066,6 +1069,7 @@ def _test_granular_accuracy(block, bs, dtype, config, delay_wgrad_compute=False) ...@@ -1066,6 +1069,7 @@ def _test_granular_accuracy(block, bs, dtype, config, delay_wgrad_compute=False)
) )
inp_hidden_states.retain_grad() inp_hidden_states.retain_grad()
with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
out = block(inp_hidden_states) out = block(inp_hidden_states)
if isinstance(out, (List, Tuple)): if isinstance(out, (List, Tuple)):
out = out[0] out = out[0]
...@@ -1264,6 +1268,64 @@ def test_linear_accuracy_delay_wgrad_compute(dtype, bs, model, bias, fuse_wgrad_ ...@@ -1264,6 +1268,64 @@ def test_linear_accuracy_delay_wgrad_compute(dtype, bs, model, bias, fuse_wgrad_
torch.testing.assert_close(o, o_ref, rtol=0, atol=0) torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("recipe", fp8_recipes + [None])
def test_linear_accuracy_save_original_input(dtype, model, recipe):
    """Check that ``Linear(save_original_input=True)`` matches the default path exactly.

    Two ``Linear`` modules are built that differ only in ``save_original_input``.
    They are given identical weights and identical ``main_grad`` buffers, both
    are run through ``_test_granular_accuracy``, and every pair of returned
    tensors must be equal with zero tolerance.

    Args:
        dtype: Parameter dtype passed to ``Linear`` and to the accuracy helper.
        model: Key into ``model_configs`` selecting the model configuration.
        recipe: Quantization recipe, or ``None`` to run without FP8.

    The test is skipped when the requested recipe is unavailable on this
    system, when the recipe is DelayedScaling, or when FP8 is requested and
    the configured sequence length is not a multiple of 16.
    """
    bs = 1
    fuse_wgrad_accumulation = True
    fp8_model_params = False
    fp8 = recipe is not None

    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if fp8 and recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
        pytest.skip(reason_for_no_fp8_block_scaling)
    if fp8 and recipe.delayed():
        pytest.skip("DelayedScaling recipe is not supported with save_original_input")

    config = model_configs[model]
    if config.seq_len % 16 != 0 and fp8:
        pytest.skip("FP8 requires sequence length to be divisible by 16.")

    # Everything except `save_original_input` is shared by the two modules.
    linear_kwargs = dict(
        bias=False,
        params_dtype=dtype,
        device="cuda",
        fuse_wgrad_accumulation=fuse_wgrad_accumulation,
    )

    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        te_linear_ref = Linear(
            config.hidden_size,
            4 * config.hidden_size,
            save_original_input=False,
            **linear_kwargs,
        ).eval()

        te_linear = Linear(
            config.hidden_size,
            4 * config.hidden_size,
            save_original_input=True,
            **linear_kwargs,
        ).eval()

    # Share params
    with torch.no_grad():
        te_linear_ref.weight = Parameter(te_linear.weight.clone())
        if fuse_wgrad_accumulation:
            # Both modules start from the same `main_grad` contents (the
            # reference gets its own copy) so the tensors compared below
            # begin from identical state.
            weight = te_linear.weight
            weight.main_grad = torch.rand_like(weight, dtype=torch.float32)
            te_linear_ref.weight.main_grad = weight.main_grad.clone()

    te_outputs = _test_granular_accuracy(te_linear, bs, dtype, config, recipe=recipe)
    te_outputs_ref = _test_granular_accuracy(te_linear_ref, bs, dtype, config, recipe=recipe)

    # Should be a bit-wise match
    for o, o_ref in zip(te_outputs, te_outputs_ref):
        torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"]) @pytest.mark.parametrize("model", ["126m"])
...@@ -1764,6 +1826,111 @@ def test_grouped_linear_accuracy( ...@@ -1764,6 +1826,111 @@ def test_grouped_linear_accuracy(
device="cuda", device="cuda",
fuse_wgrad_accumulation=fuse_wgrad_accumulation, fuse_wgrad_accumulation=fuse_wgrad_accumulation,
delay_wgrad_compute=delay_wgrad_compute, delay_wgrad_compute=delay_wgrad_compute,
save_original_input=False,
).eval()
sequential_linear = torch.nn.ModuleList(
[
Linear(
config.hidden_size,
4 * config.hidden_size,
bias=bias,
params_dtype=dtype,
parallel_mode=parallel_mode,
device="cuda",
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
).eval()
for _ in range(num_gemms)
]
)
# Share params
with torch.no_grad():
for i in range(num_gemms):
sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone())
if bias:
sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone())
if fuse_wgrad_accumulation:
weight_i = getattr(grouped_linear, f"weight{i}")
weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32)
sequential_linear[i].weight.main_grad = weight_i.main_grad.clone()
outputs_ref = _test_grouped_linear_accuracy(
sequential_linear,
num_gemms,
bs,
dtype,
config,
recipe,
fp8,
fuse_wgrad_accumulation,
delay_wgrad_compute,
)
outputs = _test_grouped_linear_accuracy(
grouped_linear,
num_gemms,
bs,
dtype,
config,
recipe,
fp8,
fuse_wgrad_accumulation,
delay_wgrad_compute,
)
# Should be a bit-wise match
for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)):
torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
@pytest.mark.parametrize("dtype", param_types, ids=str)
@pytest.mark.parametrize("num_gemms", [3])
@pytest.mark.parametrize("bs", [1])
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("recipe", fp8_recipes + [None])
@pytest.mark.parametrize("fp8_model_params", [False])
@pytest.mark.parametrize("fuse_wgrad_accumulation", [True])
@pytest.mark.parametrize("bias", [False])
@pytest.mark.parametrize("delay_wgrad_compute", [True])
def test_grouped_linear_accuracy_save_original_input(
dtype,
num_gemms,
bs,
model,
recipe,
fp8_model_params,
fuse_wgrad_accumulation,
bias,
delay_wgrad_compute,
parallel_mode=None,
):
fp8 = recipe is not None
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8 and recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("FP8 parameters are not supported in debug mode.")
if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8 and recipe.delayed():
pytest.skip("DelayedScaling recipe is not supported with save_original_input")
config = model_configs[model]
if config.seq_len % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.")
with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
grouped_linear = GroupedLinear(
num_gemms,
config.hidden_size,
4 * config.hidden_size,
bias=bias,
params_dtype=dtype,
parallel_mode=parallel_mode,
device="cuda",
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
delay_wgrad_compute=delay_wgrad_compute,
save_original_input=True,
).eval() ).eval()
sequential_linear = torch.nn.ModuleList( sequential_linear = torch.nn.ModuleList(
[ [
...@@ -1938,7 +2105,89 @@ def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, r ...@@ -1938,7 +2105,89 @@ def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, r
@pytest.mark.parametrize("recipe", fp8_recipes) @pytest.mark.parametrize("recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("fp8_model_params", all_boolean)
def test_padding_grouped_linear_accuracy( def test_padding_grouped_linear_accuracy(
dtype, num_gemms, bs, model, fp8, recipe, fp8_model_params, parallel_mode=None dtype,
num_gemms,
bs,
model,
fp8,
recipe,
fp8_model_params,
parallel_mode=None,
):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("FP8 parameters are not supported in debug mode.")
if recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
config = model_configs[model]
if config.seq_len % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.")
with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
grouped_linear = TorchGroupedLinearWithPadding(
num_gemms,
config.hidden_size,
4 * config.hidden_size,
bias=False,
params_dtype=dtype,
parallel_mode=parallel_mode,
fp8=fp8,
).eval()
with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
ref_grouped_linear = GroupedLinear(
num_gemms,
config.hidden_size,
4 * config.hidden_size,
bias=False,
params_dtype=dtype,
parallel_mode=parallel_mode,
device="cuda",
save_original_input=False,
).eval()
# Share params
with torch.no_grad():
inner_grouped_linear = grouped_linear.linear_fn
for i in range(num_gemms):
setattr(
ref_grouped_linear,
f"weight{i}",
Parameter(getattr(inner_grouped_linear, f"weight{i}").clone()),
)
outputs = _test_padding_grouped_linear_accuracy(
grouped_linear, num_gemms, bs, dtype, config, recipe, fp8
)
outputs_ref = _test_padding_grouped_linear_accuracy(
ref_grouped_linear, num_gemms, bs, dtype, config, recipe, fp8
)
# Should be a bit-wise match
for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)):
torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("num_gemms", [3])
@pytest.mark.parametrize("bs", [1])
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("fp8", [True])
@pytest.mark.parametrize("recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", [False])
def test_padding_grouped_linear_accuracy_save_original_input(
dtype,
num_gemms,
bs,
model,
fp8,
recipe,
fp8_model_params,
parallel_mode=None,
): ):
if fp8 and not fp8_available: if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8) pytest.skip(reason_for_no_fp8)
...@@ -1948,6 +2197,8 @@ def test_padding_grouped_linear_accuracy( ...@@ -1948,6 +2197,8 @@ def test_padding_grouped_linear_accuracy(
pytest.skip("FP8 parameters are not supported in debug mode.") pytest.skip("FP8 parameters are not supported in debug mode.")
if recipe.float8_block_scaling() and not fp8_block_scaling_available: if recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling) pytest.skip(reason_for_no_fp8_block_scaling)
if fp8 and recipe.delayed():
pytest.skip("DelayedScaling recipe is not supported with save_original_input")
config = model_configs[model] config = model_configs[model]
if config.seq_len % 16 != 0 and fp8: if config.seq_len % 16 != 0 and fp8:
...@@ -1973,6 +2224,7 @@ def test_padding_grouped_linear_accuracy( ...@@ -1973,6 +2224,7 @@ def test_padding_grouped_linear_accuracy(
params_dtype=dtype, params_dtype=dtype,
parallel_mode=parallel_mode, parallel_mode=parallel_mode,
device="cuda", device="cuda",
save_original_input=True,
).eval() ).eval()
# Share params # Share params
......
...@@ -160,7 +160,8 @@ void multi_tensor_quantize_impl(const std::vector<TensorWrapper> &input_list, ...@@ -160,7 +160,8 @@ void multi_tensor_quantize_impl(const std::vector<TensorWrapper> &input_list,
with_fused_kernel = false; with_fused_kernel = false;
break; break;
} }
if (nvte_tensor_columnwise_data(output_list[i].data()) == nullptr) { if (nvte_tensor_data(output_list[i].data()) == nullptr ||
nvte_tensor_columnwise_data(output_list[i].data()) == nullptr) {
with_fused_kernel = false; with_fused_kernel = false;
break; break;
} }
......
...@@ -42,6 +42,7 @@ from ..jit import no_torch_dynamo ...@@ -42,6 +42,7 @@ from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing from ..graph import is_graph_capturing
from ..cpu_offload import is_cpu_offload_enabled from ..cpu_offload import is_cpu_offload_enabled
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
from ..tensor.quantized_tensor import ( from ..tensor.quantized_tensor import (
QuantizedTensorBase, QuantizedTensorBase,
Quantizer, Quantizer,
...@@ -78,6 +79,7 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -78,6 +79,7 @@ class _GroupedLinear(torch.autograd.Function):
is_grad_enabled: bool, is_grad_enabled: bool,
module, module,
skip_fp8_weight_update, skip_fp8_weight_update,
save_original_input,
*weights_and_biases, *weights_and_biases,
) -> torch.Tensor: ) -> torch.Tensor:
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
...@@ -89,11 +91,15 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -89,11 +91,15 @@ class _GroupedLinear(torch.autograd.Function):
weight_requires_grad = weights[0].requires_grad weight_requires_grad = weights[0].requires_grad
# Configure quantizers # Configure quantizers
if save_original_input and isinstance(input_quantizers[0], Float8Quantizer):
raise ValueError("DelayedScaling recipe is not supported with save_original_input")
if input_quantizers[0] is not None: if input_quantizers[0] is not None:
for input_quantizer in input_quantizers: for input_quantizer in input_quantizers:
input_quantizer.set_usage( input_quantizer.set_usage(
rowwise=True, rowwise=True,
columnwise=(is_grad_enabled and weight_requires_grad), columnwise=(
is_grad_enabled and weight_requires_grad and not save_original_input
),
) )
columnwise_usage = is_grad_enabled and inp.requires_grad columnwise_usage = is_grad_enabled and inp.requires_grad
if not columnwise_usage: if not columnwise_usage:
...@@ -189,9 +195,15 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -189,9 +195,15 @@ class _GroupedLinear(torch.autograd.Function):
# TODO: update after #1638 is merged. # pylint: disable=fixme # TODO: update after #1638 is merged. # pylint: disable=fixme
if weight_requires_grad: if weight_requires_grad:
if save_original_input:
inputmats = [None] * num_gemms
inputmats[0] = inp
else:
for inputmat in inputmats: for inputmat in inputmats:
if isinstance(inputmat, QuantizedTensorBase): if isinstance(inputmat, QuantizedTensorBase):
inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
else:
inputmats = [None] * num_gemms
if inp.requires_grad: if inp.requires_grad:
for weight in weights_fp8: for weight in weights_fp8:
if isinstance(weight, QuantizedTensorBase): if isinstance(weight, QuantizedTensorBase):
...@@ -241,6 +253,8 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -241,6 +253,8 @@ class _GroupedLinear(torch.autograd.Function):
or FP8GlobalStateManager.is_first_fp8_module() or FP8GlobalStateManager.is_first_fp8_module()
) )
ctx.wgrad_store = wgrad_store ctx.wgrad_store = wgrad_store
ctx.save_original_input = save_original_input
ctx.input_quantizers = input_quantizers
# [*, in_features] -> [*, out_features] except first dimension changes for SP # [*, in_features] -> [*, out_features] except first dimension changes for SP
return out.view(-1, *inp.shape[1:-1], out.shape[-1]) return out.view(-1, *inp.shape[1:-1], out.shape[-1])
...@@ -357,6 +371,27 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -357,6 +371,27 @@ class _GroupedLinear(torch.autograd.Function):
torch.empty(w.size(), dtype=ctx.activation_dtype, device=ctx.device) torch.empty(w.size(), dtype=ctx.activation_dtype, device=ctx.device)
for w in weights for w in weights
] ]
if ctx.save_original_input:
inp = inputmats[0]
in_features = inp.shape[-1]
inp_view = inp.reshape(-1, in_features)
if ctx.input_quantizers[0] is not None:
for input_quantizer in ctx.input_quantizers:
if isinstance(
input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
):
input_quantizer.set_usage(rowwise=True, columnwise=True)
else:
input_quantizer.set_usage(rowwise=False, columnwise=True)
inputmats: list
if ctx.fp8:
inputmats = tex.split_quantize(inp_view, ctx.m_splits, ctx.input_quantizers)
else:
inputmats = torch.split(
cast_if_needed(inp_view, ctx.activation_dtype), ctx.m_splits
)
grouped_gemm_wgrad = functools.partial( grouped_gemm_wgrad = functools.partial(
general_grouped_gemm, general_grouped_gemm,
out_dtype=ctx.activation_dtype, out_dtype=ctx.activation_dtype,
...@@ -448,6 +483,7 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -448,6 +483,7 @@ class _GroupedLinear(torch.autograd.Function):
None, None,
None, None,
None, None,
None,
*wgrad_list, *wgrad_list,
*grad_biases, *grad_biases,
) )
...@@ -498,6 +534,11 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -498,6 +534,11 @@ class GroupedLinear(TransformerEngineBaseModule):
would not fit in GPU memory. would not fit in GPU memory.
delay_wgrad_compute : bool, default = `False` delay_wgrad_compute : bool, default = `False`
Whether to delay weight gradient computation Whether to delay weight gradient computation
save_original_input : bool, default = `False`
If set to `True`, always saves the original input tensor rather than the
cast tensor. In some scenarios, the input tensor is used by multiple modules,
and saving the original input tensor may reduce the memory usage.
Cannot work with FP8 DelayedScaling recipe.
Note: GroupedLinear doesn't really handle the TP communications inside. The `tp_size` and Note: GroupedLinear doesn't really handle the TP communications inside. The `tp_size` and
`parallel_mode` are used to determine the shapes of weights and biases. `parallel_mode` are used to determine the shapes of weights and biases.
...@@ -525,6 +566,7 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -525,6 +566,7 @@ class GroupedLinear(TransformerEngineBaseModule):
ub_overlap_ag: bool = False, ub_overlap_ag: bool = False,
ub_name: Optional[str] = None, ub_name: Optional[str] = None,
delay_wgrad_compute: bool = False, delay_wgrad_compute: bool = False,
save_original_input: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -539,6 +581,7 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -539,6 +581,7 @@ class GroupedLinear(TransformerEngineBaseModule):
self.ub_overlap_rs = ub_overlap_rs self.ub_overlap_rs = ub_overlap_rs
self.ub_overlap_ag = ub_overlap_ag self.ub_overlap_ag = ub_overlap_ag
self.ub_name = ub_name self.ub_name = ub_name
self.save_original_input = save_original_input
assert ( assert (
not ub_overlap_rs and not ub_overlap_ag not ub_overlap_rs and not ub_overlap_ag
), "GroupedLinear doesn't support Userbuffer overlap." ), "GroupedLinear doesn't support Userbuffer overlap."
...@@ -754,6 +797,7 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -754,6 +797,7 @@ class GroupedLinear(TransformerEngineBaseModule):
torch.is_grad_enabled(), torch.is_grad_enabled(),
self, self,
skip_fp8_weight_update, skip_fp8_weight_update,
self.save_original_input,
*weight_tensors, *weight_tensors,
*bias_tensors, *bias_tensors,
) )
......
...@@ -117,6 +117,7 @@ class _Linear(torch.autograd.Function): ...@@ -117,6 +117,7 @@ class _Linear(torch.autograd.Function):
module: torch.nn.Module, module: torch.nn.Module,
skip_fp8_weight_update: bool, skip_fp8_weight_update: bool,
symmetric_ar_type: str, symmetric_ar_type: str,
save_original_input: bool = False,
debug: Optional[bool] = False, debug: Optional[bool] = False,
) -> torch.Tensor: ) -> torch.Tensor:
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
...@@ -157,6 +158,11 @@ class _Linear(torch.autograd.Function): ...@@ -157,6 +158,11 @@ class _Linear(torch.autograd.Function):
own_quantized_input = False own_quantized_input = False
if fp8: if fp8:
assert_dim_for_fp8_exec(inputmat, weight) assert_dim_for_fp8_exec(inputmat, weight)
if save_original_input:
assert not isinstance(
input_quantizer, Float8Quantizer
), "DelayedScaling recipe is not supported with save_original_input"
if with_input_all_gather_nccl or ub_overlap_ag_fprop: # All-gather input tensor if with_input_all_gather_nccl or ub_overlap_ag_fprop: # All-gather input tensor
# Cast local input tensor if needed # Cast local input tensor if needed
...@@ -164,7 +170,9 @@ class _Linear(torch.autograd.Function): ...@@ -164,7 +170,9 @@ class _Linear(torch.autograd.Function):
if input_quantizer is None: if input_quantizer is None:
raise ValueError("Missing quantizer for input tensor") raise ValueError("Missing quantizer for input tensor")
if not isinstance(inputmat, QuantizedTensorBase): if not isinstance(inputmat, QuantizedTensorBase):
input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input) input_quantizer.set_usage(
rowwise=True, columnwise=backward_needs_input and not save_original_input
)
if isinstance( if isinstance(
input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
): ):
...@@ -201,7 +209,9 @@ class _Linear(torch.autograd.Function): ...@@ -201,7 +209,9 @@ class _Linear(torch.autograd.Function):
else: else:
if input_quantizer is None: if input_quantizer is None:
raise ValueError("Missing quantizer for input tensor") raise ValueError("Missing quantizer for input tensor")
input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input) input_quantizer.set_usage(
rowwise=True, columnwise=backward_needs_input and not save_original_input
)
inputmat = input_quantizer(inputmat) inputmat = input_quantizer(inputmat)
own_quantized_input = True own_quantized_input = True
else: else:
...@@ -330,6 +340,9 @@ class _Linear(torch.autograd.Function): ...@@ -330,6 +340,9 @@ class _Linear(torch.autograd.Function):
# ------------------------------------------------------ # ------------------------------------------------------
if is_grad_enabled: if is_grad_enabled:
if save_original_input:
inputmat = inp
ctx.weight_quantizer = weight_quantizer ctx.weight_quantizer = weight_quantizer
saved_inputmat = None saved_inputmat = None
...@@ -338,6 +351,7 @@ class _Linear(torch.autograd.Function): ...@@ -338,6 +351,7 @@ class _Linear(torch.autograd.Function):
) )
if backward_needs_input: if backward_needs_input:
if not save_original_input:
if own_quantized_input and isinstance(inputmat, QuantizedTensorBase): if own_quantized_input and isinstance(inputmat, QuantizedTensorBase):
# For sequence parallel in vanilla FP8, rowwise data is # For sequence parallel in vanilla FP8, rowwise data is
# to gather the input. For MXFP8, columnwise only data # to gather the input. For MXFP8, columnwise only data
...@@ -557,6 +571,24 @@ class _Linear(torch.autograd.Function): ...@@ -557,6 +571,24 @@ class _Linear(torch.autograd.Function):
# -------------------------------------------------- # --------------------------------------------------
inputmat_total = None inputmat_total = None
inputmat_total_work = None inputmat_total_work = None
if ctx.requires_wgrad:
input_is_quantized = isinstance(inputmat, QuantizedTensorBase)
if ctx.fp8 or ctx.debug:
if not input_is_quantized:
quantizer = ctx.input_quantizer
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
quantizer.set_usage(
rowwise=True,
columnwise=not ctx.backward_input_needs_gather,
)
else:
quantizer.set_usage(rowwise=False, columnwise=True)
inputmat = quantizer(inputmat)
else:
if input_is_quantized:
inputmat = inputmat.dequantize(dtype=ctx.activation_dtype)
else:
inputmat = cast_if_needed(inputmat, ctx.activation_dtype)
if ctx.backward_input_needs_gather: if ctx.backward_input_needs_gather:
quantizer = None quantizer = None
if ctx.fp8 or ctx.debug: if ctx.fp8 or ctx.debug:
...@@ -898,6 +930,7 @@ class _Linear(torch.autograd.Function): ...@@ -898,6 +930,7 @@ class _Linear(torch.autograd.Function):
None, # module None, # module
None, # skip_fp8_weight_update None, # skip_fp8_weight_update
None, # symmetric_ar_type None, # symmetric_ar_type
None, # save_original_input
None, # debug None, # debug
) )
...@@ -980,6 +1013,11 @@ class Linear(TransformerEngineBaseModule): ...@@ -980,6 +1013,11 @@ class Linear(TransformerEngineBaseModule):
This can help in latency bound communication situations. This can help in latency bound communication situations.
Requires PyTorch version 2.7.0 or higher. When set to None, standard all-reduce Requires PyTorch version 2.7.0 or higher. When set to None, standard all-reduce
is used. is used.
save_original_input : bool, default = `False`
If set to `True`, always saves the original input tensor rather than the
cast tensor. In some scenarios, the input tensor is used by multiple modules,
and saving the original input tensor may reduce the memory usage.
Cannot work with FP8 DelayedScaling recipe.
""" """
def __init__( def __init__(
...@@ -1007,6 +1045,7 @@ class Linear(TransformerEngineBaseModule): ...@@ -1007,6 +1045,7 @@ class Linear(TransformerEngineBaseModule):
ub_name: Optional[str] = None, ub_name: Optional[str] = None,
delay_wgrad_compute: bool = False, delay_wgrad_compute: bool = False,
symmetric_ar_type: Optional[str] = None, symmetric_ar_type: Optional[str] = None,
save_original_input: bool = False,
name: Optional[str] = None, name: Optional[str] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -1021,6 +1060,7 @@ class Linear(TransformerEngineBaseModule): ...@@ -1021,6 +1060,7 @@ class Linear(TransformerEngineBaseModule):
self.get_rng_state_tracker = get_rng_state_tracker self.get_rng_state_tracker = get_rng_state_tracker
self.rng_tracker_name = rng_tracker_name self.rng_tracker_name = rng_tracker_name
self.symmetric_ar_type = symmetric_ar_type self.symmetric_ar_type = symmetric_ar_type
self.save_original_input = save_original_input
self.name = name self.name = name
if TEDebugState.debug_enabled: if TEDebugState.debug_enabled:
...@@ -1371,6 +1411,7 @@ class Linear(TransformerEngineBaseModule): ...@@ -1371,6 +1411,7 @@ class Linear(TransformerEngineBaseModule):
self, self,
skip_fp8_weight_update, skip_fp8_weight_update,
self.symmetric_ar_type, self.symmetric_ar_type,
self.save_original_input,
debug, debug,
) )
out = linear_fn(*args) out = linear_fn(*args)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment