Unverified Commit 44fbe9e6 authored by Kshiteej K, committed by GitHub

fix: update grad_output quant to avoid redundant work (#1736)



* fix: update grad_output quant to avoid redundant work
  Signed-off-by: kshitij12345 <kshitijkalambarkar@gmail.com>

* add test
  Signed-off-by: kshitij12345 <kshitijkalambarkar@gmail.com>

* don't keep only columnwise quant if requires_dgrad=False
  Signed-off-by: kshitij12345 <kshitijkalambarkar@gmail.com>

* fix stray merge
  Signed-off-by: kshitij12345 <kshitijkalambarkar@gmail.com>

* fix for ctx.use_bias is True case
  Signed-off-by: kshitij12345 <kshitijkalambarkar@gmail.com>

* Skip if FP8 not available
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: kshitij12345 <kshitijkalambarkar@gmail.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent bbddcb92
@@ -995,6 +995,32 @@ def test_sanity_checkpointing_on_callables():
    torch.testing.assert_close(grad_checkpoint, grad_standard)

@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_linear_frozen_weights_memory_default_recipe():
    """Test that memory usage is optimized when weights are frozen for MXFP8."""
    dim = 1024
    linear = Linear(dim, dim, bias=False)
    x = torch.randn(dim, dim, requires_grad=True, device="cuda")

    # Freeze weights
    linear.weight.requires_grad = False

    # Forward and backward pass with FP8
    with fp8_autocast():
        o = linear(x)
    g_o = torch.randn_like(o)
    max_memory_before_backward = torch.cuda.max_memory_allocated()
    o.backward(g_o)
    max_memory_after_backward = torch.cuda.max_memory_allocated()

    memory_diff = (max_memory_after_backward - max_memory_before_backward) / 1e6
    assert memory_diff < 5.5, (
        f"Memory usage with frozen weights ({memory_diff}MB) should be less than 5.5MB as the"
        " grad_output should be quantized only rowwise."
    )

@pytest.mark.parametrize(
    "module_name",
    ("Linear", "LayerNormLinear", "LayerNormMLP", "GroupedLinear", "ops.Linear"),
...
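As a rough sense check of the memory bound in the test above, here is an illustrative back-of-envelope for the tensor sizes involved. Assumptions: one byte per FP8 element, scaling factors and allocator rounding ignored; the 5.5 MB threshold itself is an empirical bound from the test, not derived from these numbers.

```python
# Illustrative arithmetic only; not taken from the PR.
dim = 1024
bytes_per_fp8 = 1    # assumption: 1 byte per FP8 element, scales ignored
bytes_per_bf16 = 2

grad_output_fp8_mb = dim * dim * bytes_per_fp8 / 1e6    # ~1.05 MB per quantized copy of grad_output
grad_output_bf16_mb = dim * dim * bytes_per_bf16 / 1e6  # ~2.10 MB incoming high-precision grad
dgrad_bf16_mb = dim * dim * bytes_per_bf16 / 1e6        # ~2.10 MB dgrad result

print(grad_output_fp8_mb, grad_output_bf16_mb, dgrad_bf16_mb)
```

With tensors of this size, an extra (redundant) quantized copy of grad_output is large enough to show up clearly in the peak-memory delta measured during backward.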
@@ -1320,7 +1320,7 @@ void fp8_quantize_arch_l_100(const Tensor &input, const Tensor *act_input, const
  if (!is_tensor_scaling(output->scaling_mode) || IS_DBIAS) {
    // zhongboz: should we just ignore IS_ACT here?
    NVTE_ERROR("Not implemented scaling mode or fusion: " + to_string(output->scaling_mode) +
-              " on GPU with compute capability < 10.0.");
+              " or IS_DBIAS=true" + " on GPU with compute capability < 10.0.");
  }
  switch (output->scaling_mode) {
    case NVTE_DELAYED_TENSOR_SCALING: {
...
@@ -559,6 +559,19 @@ class _Linear(torch.autograd.Function):
            # usage for only dgrad GEMM.
            quantizer.set_usage(columnwise=False)

        # Adjust the quantization directions for grad_output depending
        # on whether wgrad calculations will be performed.
        # NOTE: If requires_dgrad is False, disabling `rowwise` quantization and keeping only
        # `columnwise` quantization results in `Assertion failed: output_tensor->has_data().
        # Quantizing in only the columnwise direction not supported yet!`
        # NOTE: For `ctx.use_bias is True`, the selected quantize kernel errors with
        # `cast_kernels.cuh:1322 in function fp8_quantize_arch_l_100: Not implemented scaling mode
        # or fusion: NVTE_DELAYED_TENSOR_SCALING or IS_DBIAS=true on GPU with compute capability < 10.0.`
        if (
            not ctx.use_bias
            and not ctx.requires_wgrad
            and ctx.grad_output_quantizer is not None
        ):
            ctx.grad_output_quantizer.set_usage(columnwise=False)

        # Prepare grad output tensor
        nvtx_range_push(f"{nvtx_label}.grad_output_preprocess")
        (
...
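For context on why the columnwise copy is redundant here, below is a minimal sketch of the two backward GEMMs of a linear layer. The function name and the plain-PyTorch formulation are illustrative, not TransformerEngine's actual kernels: the dgrad GEMM reads grad_output in its row-wise (untransposed) layout, while only the wgrad GEMM reads the transposed (column-wise) layout, so with frozen weights the column-wise quantized copy is never consumed.

```python
import torch

def linear_backward_sketch(grad_output, weight, inp, requires_wgrad):
    """Illustrative only: which layout of grad_output each backward GEMM reads.

    Shapes: grad_output [batch, out], weight [out, in], inp [batch, in].
    """
    # dgrad GEMM reads grad_output as-is (row-wise layout).
    dgrad = grad_output @ weight                                # [batch, in]
    # wgrad GEMM reads grad_output transposed (column-wise layout);
    # with frozen weights this GEMM is skipped entirely.
    wgrad = grad_output.t() @ inp if requires_wgrad else None   # [out, in]
    return dgrad, wgrad
```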