Unverified commit 41909dc8 authored by Tim Moon, committed by GitHub

[PyTorch] Linear op avoids saving input tensor if weight grad is not needed (#1817)



* Linear op avoids saving input tensor if weight grad is not needed
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Linear op forward avoids producing quantized tensors with unnecessary usages
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix linter warnings
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Avoid unnecessary usages in fused linear ops
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent c9e8e305
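
The core idea can be sketched outside Transformer Engine with a plain PyTorch custom autograd function (a minimal, hypothetical illustration, not the actual `BasicLinear` implementation): the input activation is only consumed by the weight-gradient GEMM (dW = dY^T X) and the weight is only consumed by the input-gradient GEMM (dX = dY W), so each tensor needs to be kept for the backward pass only when the corresponding gradient will actually be computed.

```python
import torch

class LinearSketch(torch.autograd.Function):
    """Hypothetical linear op that skips saving tensors its backward won't use."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        y = x @ w.t()
        # dW = dY^T @ X consumes the input; dX = dY @ W consumes the weight.
        # Save each tensor only if the matching gradient will be requested.
        ctx.save_for_backward(
            x if w.requires_grad else None,
            w if x.requires_grad else None,
        )
        return y

    @staticmethod
    def backward(ctx, dy: torch.Tensor):
        x, w = ctx.saved_tensors
        dx = dy @ w if ctx.needs_input_grad[0] else None
        dw = dy.t() @ x if ctx.needs_input_grad[1] else None
        return dx, dw

# Frozen weight: the input is not stashed for backward, saving activation memory.
x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(16, 8, requires_grad=False)
y = LinearSketch.apply(x, w)
y.sum().backward()
assert x.grad is not None and w.grad is None
```
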
......@@ -913,6 +913,8 @@ class TestBasicOps:
@pytest.mark.parametrize("bias", (False, True))
@pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
@pytest.mark.parametrize("quantized_weight", (False, True))
@pytest.mark.parametrize("input_requires_grad", (False, True))
@pytest.mark.parametrize("weight_requires_grad", (False, True))
def test_linear(
self,
*,
......@@ -923,6 +925,8 @@ class TestBasicOps:
device: torch.device = "cuda",
quantization: Optional[str],
quantized_weight: bool,
input_requires_grad: bool,
weight_requires_grad: bool,
) -> None:
"""GEMM + bias"""
......@@ -943,9 +947,10 @@ class TestBasicOps:
test_device=device,
test_is_fp8=quantized_compute,
)
if isinstance(x_test, QuantizedTensor):
with torch.no_grad():
x_test = x_test.dequantize().requires_grad_()
if isinstance(x_test, QuantizedTensor):
x_test = x_test.dequantize()
x_test.requires_grad_(requires_grad=input_requires_grad)
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
test_dtype=dtype,
......@@ -986,8 +991,11 @@ class TestBasicOps:
op.bias.copy_(b_test)
del w_test
del b_test
for param in op.parameters():
param.requires_grad_(requires_grad=weight_requires_grad)
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
y_test = op(x_test)
if input_requires_grad or weight_requires_grad:
y_test.backward(dy_test)
# Expected numerical error
......@@ -999,10 +1007,12 @@ class TestBasicOps:
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
dw_test = op.weight.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
if input_requires_grad:
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
if weight_requires_grad:
dw_test = op.weight.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(dw_test, w_ref.grad, **tols)
if bias:
db_test = op.bias.grad.to(dtype=torch.float64, device="cpu")
......
......@@ -349,7 +349,9 @@ class BasicLinear(BasicOperation):
input_quantizer: Optional[Quantizer] = None,
weight_quantizer: Optional[Quantizer] = None,
output_quantizer: Optional[Quantizer] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
input_requires_grad: bool = True,
weight_requires_grad: bool = True,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
"""Functional API for forward pass
Parameters
......@@ -385,17 +387,25 @@ class BasicLinear(BasicOperation):
Builder class for quantized weight tensor.
output_quantizer: Quantizer, optional
Builder class for quantized output tensor.
input_requires_grad: bool, default = `True`
Whether the loss gradient w.r.t. the input tensor is
required in the backward pass.
weight_requires_grad: bool, default = `True`
Whether the loss gradient w.r.t. the weight tensor is
required in the backward pass.
Returns
-------
torch.Tensor
Output tensor
torch.Tensor
Input tensor used in GEMM, possibly cast and reshaped from
provided input tensor
torch.Tensor
Weight tensor used in GEMM, possibly cast and reshaped from
provided weight tensor
torch.Tensor, optional
Input tensor, ready for use in backward pass. `None` is
returned if loss gradient w.r.t. the weight tensor is not
required.
torch.Tensor, optional
Weight tensor, ready for use in backward pass. `None` is
returned if loss gradient w.r.t. the input tensor is not
required.
"""
......@@ -416,7 +426,7 @@ class BasicLinear(BasicOperation):
if with_quantized_compute:
if input_quantizer is None:
raise ValueError("Missing quantizer for input tensor")
input_quantizer.set_usage(rowwise=True)
input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
if with_x_all_gather:
input_quantizer.set_usage(columnwise=False)
x, x_async = gather_along_first_dim(
......@@ -449,7 +459,7 @@ class BasicLinear(BasicOperation):
if with_quantized_compute and not w_is_quantized:
if weight_quantizer is None:
raise ValueError("Missing quantizer for weight tensor")
weight_quantizer.set_usage(rowwise=True)
weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad)
w = weight_quantizer(w)
elif not with_quantized_compute and w_is_quantized:
w = w.dequantize()
......@@ -526,17 +536,25 @@ class BasicLinear(BasicOperation):
else:
torch.distributed.all_reduce(y, group=tensor_parallel_group)
# Detach input tensor if needed
# Note: PyTorch autograd produces esoteric errors if we save
# input tensor as context for backward pass.
# Prepare weight tensor for backward pass
if input_requires_grad:
if w is not weight and with_quantized_compute and isinstance(w, QuantizedTensor):
w.update_usage(rowwise_usage=False, columnwise_usage=True)
else:
w = None
# Prepare input tensor for backward pass
if weight_requires_grad:
if x_local is input:
# PyTorch autograd produces esoteric errors if we
# cache input tensor directly.
x_local = x_local.detach()
# Configure input tensor for backward pass
if with_quantized_compute and isinstance(x_local, QuantizedTensor):
if not (isinstance(x_local, Float8TensorBase) and with_x_all_gather):
# FP8 does not support all-gather of transpose data
x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
else:
x_local = None
return y, x_local, w
......@@ -892,7 +910,7 @@ class BasicLinear(BasicOperation):
dtype = torch.get_autocast_dtype("cuda")
# Linear forward
output, x_local, _ = BasicLinear._functional_forward(
output, x_local, w = BasicLinear._functional_forward(
input=input_,
weight=self.weight,
dtype=dtype,
......@@ -903,10 +921,12 @@ class BasicLinear(BasicOperation):
input_quantizer=input_quantizer,
weight_quantizer=weight_quantizer,
output_quantizer=output_quantizer,
input_requires_grad=input_requires_grad,
weight_requires_grad=weight_requires_grad,
)
# Save state for backward pass
ctx.save_for_backward(x_local)
ctx.save_for_backward(x_local, w)
ctx.with_quantized_compute = with_quantized_compute
ctx.input_quantizer = input_quantizer
ctx.weight_quantizer = weight_quantizer
......@@ -926,7 +946,7 @@ class BasicLinear(BasicOperation):
) -> tuple[torch.Tensor, Iterable[Optional[torch.Tensor]]]:
# Saved tensors from forward pass
(x_local,) = ctx.saved_tensors
(x_local, w) = ctx.saved_tensors
# wgrad fusion
accumulate_into_main_grad = self._accumulate_into_main_grad
......@@ -946,7 +966,7 @@ class BasicLinear(BasicOperation):
grad_input, grad_weight = BasicLinear._functional_backward(
grad_output=grad_output,
input=x_local,
weight=self.weight,
weight=w,
input_requires_grad=ctx.input_requires_grad,
weight_requires_grad=ctx.weight_requires_grad,
dtype=ctx.dtype,
......
......@@ -51,7 +51,7 @@ class BackwardLinearAdd(FusedOperation):
linear_op_ctx = basic_op_ctxs[0]
# Saved tensors from forward pass
(x_local,) = linear_op_ctx.saved_tensors
(x_local, w) = linear_op_ctx.saved_tensors
# wgrad fusion
accumulate_into_main_grad = linear_op._accumulate_into_main_grad
......@@ -72,7 +72,7 @@ class BackwardLinearAdd(FusedOperation):
grad_input, grad_weight = BasicLinear._functional_backward(
grad_output=grad_output,
input=x_local,
weight=linear_op.weight,
weight=w,
input_requires_grad=linear_op_ctx.input_requires_grad,
weight_requires_grad=linear_op_ctx.weight_requires_grad,
dtype=grad_input.dtype,
......
......@@ -82,6 +82,10 @@ class ForwardLinearBiasActivation(FusedOperation):
else:
raise NotImplementedError("Activations are not yet supported")
# Check which grads are required
input_requires_grad = linear_op_ctx.requires_grad and input_.requires_grad
weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad
# FP8 metadata
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
input_quantizer = None
......@@ -106,7 +110,7 @@ class ForwardLinearBiasActivation(FusedOperation):
dtype = torch.get_autocast_dtype("cuda")
# Linear forward
output, x_local, _ = BasicLinear._functional_forward(
output, x_local, w = BasicLinear._functional_forward(
input=input_,
weight=linear_op.weight,
bias=bias,
......@@ -118,18 +122,20 @@ class ForwardLinearBiasActivation(FusedOperation):
input_quantizer=input_quantizer,
weight_quantizer=weight_quantizer,
output_quantizer=output_quantizer,
input_requires_grad=input_requires_grad,
weight_requires_grad=weight_requires_grad,
)
# Save state for backward pass
linear_op_ctx.save_for_backward(x_local)
linear_op_ctx.save_for_backward(x_local, w)
linear_op_ctx.with_quantized_compute = with_quantized_compute
linear_op_ctx.input_quantizer = input_quantizer
linear_op_ctx.weight_quantizer = weight_quantizer
linear_op_ctx.grad_output_quantizer = grad_output_quantizer
linear_op_ctx.grad_input_quantizer = grad_input_quantizer
linear_op_ctx.dtype = dtype
linear_op_ctx.input_requires_grad = input_.requires_grad
linear_op_ctx.weight_requires_grad = linear_op.weight.requires_grad
linear_op_ctx.input_requires_grad = input_requires_grad
linear_op_ctx.weight_requires_grad = weight_requires_grad
linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None
return output, [() for _ in range(len(self.basic_ops))]
......
......@@ -76,6 +76,10 @@ class ForwardLinearBiasAdd(FusedOperation):
if basic_op_kwargs[idx]:
raise ValueError("Bias operation forward does not expect keyword arguments")
# Check which grads are required
input_requires_grad = linear_op_ctx.requires_grad and input_.requires_grad
weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad
# FP8 metadata
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
input_quantizer = None
......@@ -98,7 +102,7 @@ class ForwardLinearBiasAdd(FusedOperation):
# Linear forward
output = basic_op_extra_inputs[self._op_idxs["add"]][0]
output, x_local, _ = BasicLinear._functional_forward(
output, x_local, w = BasicLinear._functional_forward(
input=input_,
weight=linear_op.weight,
bias=bias,
......@@ -111,18 +115,20 @@ class ForwardLinearBiasAdd(FusedOperation):
input_quantizer=input_quantizer,
weight_quantizer=weight_quantizer,
output_quantizer=output_quantizer,
input_requires_grad=input_requires_grad,
weight_requires_grad=weight_requires_grad,
)
# Save state for backward pass
linear_op_ctx.save_for_backward(x_local)
linear_op_ctx.save_for_backward(x_local, w)
linear_op_ctx.with_quantized_compute = with_quantized_compute
linear_op_ctx.input_quantizer = input_quantizer
linear_op_ctx.weight_quantizer = weight_quantizer
linear_op_ctx.grad_output_quantizer = grad_output_quantizer
linear_op_ctx.grad_input_quantizer = grad_input_quantizer
linear_op_ctx.dtype = dtype
linear_op_ctx.input_requires_grad = input_.requires_grad
linear_op_ctx.weight_requires_grad = linear_op.weight.requires_grad
linear_op_ctx.input_requires_grad = input_requires_grad
linear_op_ctx.weight_requires_grad = weight_requires_grad
linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None
return output, [() for _ in range(len(self.basic_ops))]
......
......@@ -500,7 +500,7 @@ class UserbuffersBackwardLinear(FusedOperation):
bias_op = self.basic_ops[idx]
# Saved tensors from forward pass
(x_local,) = linear_op_ctx.saved_tensors
(x_local, w) = linear_op_ctx.saved_tensors
# wgrad fusion
accumulate_into_main_grad = linear_op._accumulate_into_main_grad
......@@ -520,7 +520,7 @@ class UserbuffersBackwardLinear(FusedOperation):
retval = UserbuffersBackwardLinear._functional_backward(
grad_output=grad_output,
input=x_local,
weight=linear_op.weight,
weight=w,
weight_requires_grad=linear_op_ctx.weight_requires_grad,
bias_requires_grad=(bias_op is not None),
dtype=linear_op_ctx.dtype,
......
......@@ -98,6 +98,8 @@ class UserbuffersForwardLinear(FusedOperation):
input_quantizer: Optional[Quantizer] = None,
weight_quantizer: Optional[Quantizer] = None,
output_quantizer: Optional[Quantizer] = None,
input_requires_grad: bool = True,
weight_requires_grad: bool = True,
ub_comm_name: str,
) -> tuple[torch.Tensor, dict]:
"""Functional API for forward pass
......@@ -131,6 +133,12 @@ class UserbuffersForwardLinear(FusedOperation):
Builder class for quantized weight tensor.
output_quantizer: Quantizer, optional
Builder class for quantized output tensor.
input_requires_grad: bool, default = `True`
Whether the loss gradient w.r.t. the input tensor is
required in the backward pass.
weight_requires_grad: bool, default = `True`
Whether the loss gradient w.r.t. the weight tensor is
required in the backward pass.
ub_comm_name: str
Layer type (e.g. "qkv", "proj", "fc1", "fc2"). This is
used to access the corresponding Userbuffers communicators
......@@ -141,8 +149,9 @@ class UserbuffersForwardLinear(FusedOperation):
torch.Tensor
Output tensor
dict
Extra output tensors. "input" is the input tensor,
possibly cast and reshaped from the provided input tensor.
Extra output tensors. "input" is the input tensor and
"weight" is the weight tensor, both ready for use in the
backward pass.
"""
......@@ -198,7 +207,7 @@ class UserbuffersForwardLinear(FusedOperation):
if with_ub_all_gather:
if input_quantizer is not None:
if not isinstance(x_local, QuantizedTensorBase):
input_quantizer.set_usage(rowwise=True, columnwise=True)
input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
if isinstance(input_quantizer, Float8Quantizer):
input_quantizer.set_usage(columnwise=False)
x_local = input_quantizer(x_local)
......@@ -212,7 +221,7 @@ class UserbuffersForwardLinear(FusedOperation):
else:
if with_quantized_compute:
if not isinstance(x_local, QuantizedTensorBase):
input_quantizer.set_usage(rowwise=True, columnwise=True)
input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
x_local = input_quantizer(x_local)
else:
if isinstance(x_local, QuantizedTensorBase):
......@@ -225,7 +234,7 @@ class UserbuffersForwardLinear(FusedOperation):
w = weight
w_is_quantized = isinstance(w, QuantizedTensorBase)
if with_quantized_compute and not w_is_quantized:
weight_quantizer.set_usage(rowwise=True)
weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad)
w = weight_quantizer(w)
elif not with_quantized_compute and w_is_quantized:
w = w.dequantize()
......@@ -258,17 +267,25 @@ class UserbuffersForwardLinear(FusedOperation):
else:
y_local = gemm_output
# Detach input tensor if needed
# Note: PyTorch autograd produces esoteric errors if we save
# input tensor as context for backward pass.
# Prepare weight tensor for backward pass
if input_requires_grad:
if w is not weight and with_quantized_compute and isinstance(w, QuantizedTensorBase):
w.update_usage(rowwise_usage=False, columnwise_usage=True)
else:
w = None
# Prepare input tensor for backward pass
if weight_requires_grad:
if x_local is input:
# PyTorch autograd produces esoteric errors if we
# cache input tensor directly.
x_local = x_local.detach()
# Configure input tensor for backward pass
if with_quantized_compute and isinstance(x_local, QuantizedTensorBase):
if not (isinstance(x_local, Float8TensorBase) and with_ub_all_gather):
# FP8 does not support all-gather of transpose data
x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
else:
x_local = None
# Return cast tensors
extra_outputs = {"input": x_local, "weight": w}
......@@ -298,6 +315,10 @@ class UserbuffersForwardLinear(FusedOperation):
if basic_op_kwargs[idx]:
raise ValueError("Bias operation forward does not expect keyword arguments")
# Check which grads are required
input_requires_grad = linear_op_ctx.requires_grad and input_.requires_grad
weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad
# Quantization metadata
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
input_quantizer = None
......@@ -338,12 +359,15 @@ class UserbuffersForwardLinear(FusedOperation):
input_quantizer=input_quantizer,
weight_quantizer=weight_quantizer,
output_quantizer=None, # Not supported
input_requires_grad=input_requires_grad,
weight_requires_grad=weight_requires_grad,
ub_comm_name=linear_op._userbuffers_options["comm_name"],
)
x_local = extra_outputs["input"]
w = extra_outputs["weight"]
# Save state for backward pass
linear_op_ctx.save_for_backward(x_local)
linear_op_ctx.save_for_backward(x_local, w)
linear_op_ctx.with_quantized_compute = with_quantized_compute
linear_op_ctx.input_quantizer = input_quantizer
linear_op_ctx.weight_quantizer = weight_quantizer
......@@ -351,8 +375,8 @@ class UserbuffersForwardLinear(FusedOperation):
linear_op_ctx.grad_input_quantizer = grad_input_quantizer
linear_op_ctx.dtype = dtype
linear_op_ctx.input_dims = input_.size()
linear_op_ctx.input_requires_grad = input_.requires_grad
linear_op_ctx.weight_requires_grad = linear_op.weight.requires_grad
linear_op_ctx.input_requires_grad = input_requires_grad
linear_op_ctx.weight_requires_grad = weight_requires_grad
linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None
return output, [() for _ in range(len(self.basic_ops))]
......
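
As a usage note, the optimization is driven entirely by the standard `requires_grad` flags, so freezing an op's parameters is enough to trigger it, mirroring the updated test above. A hedged sketch follows (the `linear_op` argument is assumed to be an already constructed Transformer Engine operation-based linear layer, as in the test; its construction is elided in this diff):

```python
import torch

def backward_with_frozen_weights(linear_op: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Run a linear op with frozen parameters so only the input gradient is needed.

    With this change, the forward pass then avoids caching the input activation
    (and, under quantized compute, avoids producing its columnwise/transposed copy).
    """
    # `x` is expected to have requires_grad=True so that dgrad is still computed.
    for param in linear_op.parameters():
        param.requires_grad_(requires_grad=False)
    y = linear_op(x)    # input is no longer saved for the wgrad GEMM
    y.sum().backward()  # computes x.grad only; parameter grads stay None
    return y
```
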