Unverified Commit 41909dc8 authored by Tim Moon, committed by GitHub

[PyTorch] Linear op avoids saving input tensor if weight grad is not needed (#1817)



* Linear op avoids saving input tensor if weight grad is not needed
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Linear op forward avoids producing quantized tensors with unnecessary usages
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix linter warnings
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Avoid unnecessary usages in fused linear ops
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent c9e8e305
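
The user-visible effect of this change: when the weight does not require grad, the linear op no longer saves the input activation for the wgrad GEMM, and when the input does not require grad it no longer saves the weight for dgrad. A minimal sketch of the intended usage, assuming the fusible-ops API `transformer_engine.pytorch.ops.Linear` and a CUDA device (illustrative only, not part of this diff):

import torch
import transformer_engine.pytorch as te

# Hypothetical usage sketch: freeze the weight so only dgrad is needed.
op = te.ops.Linear(1024, 1024, bias=False, device="cuda")
op.weight.requires_grad_(False)

x = torch.randn(32, 1024, device="cuda", requires_grad=True)
y = op(x)

# With this commit, the forward pass skips saving the input tensor
# (it is only needed for wgrad), reducing activation memory.
y.sum().backward()
assert x.grad is not None and op.weight.grad is None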
@@ -913,6 +913,8 @@ class TestBasicOps:
     @pytest.mark.parametrize("bias", (False, True))
     @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
     @pytest.mark.parametrize("quantized_weight", (False, True))
+    @pytest.mark.parametrize("input_requires_grad", (False, True))
+    @pytest.mark.parametrize("weight_requires_grad", (False, True))
     def test_linear(
         self,
         *,
@@ -923,6 +925,8 @@ class TestBasicOps:
         device: torch.device = "cuda",
         quantization: Optional[str],
         quantized_weight: bool,
+        input_requires_grad: bool,
+        weight_requires_grad: bool,
     ) -> None:
         """GEMM + bias"""
@@ -943,9 +947,10 @@ class TestBasicOps:
             test_device=device,
             test_is_fp8=quantized_compute,
         )
-        if isinstance(x_test, QuantizedTensor):
-            with torch.no_grad():
-                x_test = x_test.dequantize().requires_grad_()
+        with torch.no_grad():
+            if isinstance(x_test, QuantizedTensor):
+                x_test = x_test.dequantize()
+            x_test.requires_grad_(requires_grad=input_requires_grad)
         w_ref, w_test = make_reference_and_test_tensors(
             (out_features, in_features),
             test_dtype=dtype,
@@ -986,8 +991,11 @@ class TestBasicOps:
             op.bias.copy_(b_test)
             del w_test
             del b_test
+        for param in op.parameters():
+            param.requires_grad_(requires_grad=weight_requires_grad)
         with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
             y_test = op(x_test)
-        y_test.backward(dy_test)
+        if input_requires_grad or weight_requires_grad:
+            y_test.backward(dy_test)
 
         # Expected numerical error
@@ -999,10 +1007,12 @@ class TestBasicOps:
         # Check results
         y_test = y_test.to(dtype=torch.float64, device="cpu")
-        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
-        dw_test = op.weight.grad.to(dtype=torch.float64, device="cpu")
         torch.testing.assert_close(y_test, y_ref, **tols)
-        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
-        torch.testing.assert_close(dw_test, w_ref.grad, **tols)
+        if input_requires_grad:
+            dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
+            torch.testing.assert_close(dx_test, x_ref.grad, **tols)
+        if weight_requires_grad:
+            dw_test = op.weight.grad.to(dtype=torch.float64, device="cpu")
+            torch.testing.assert_close(dw_test, w_ref.grad, **tols)
         if bias:
             db_test = op.bias.grad.to(dtype=torch.float64, device="cpu")
......
@@ -349,7 +349,9 @@ class BasicLinear(BasicOperation):
         input_quantizer: Optional[Quantizer] = None,
         weight_quantizer: Optional[Quantizer] = None,
         output_quantizer: Optional[Quantizer] = None,
-    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        input_requires_grad: bool = True,
+        weight_requires_grad: bool = True,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
         """Functional API for forward pass
 
         Parameters
@@ -385,17 +387,25 @@ class BasicLinear(BasicOperation):
             Builder class for quantized weight tensor.
         output_quantizer: Quantizer, optional
            Builder class for quantized output tensor.
+        input_requires_grad: bool, default = `True`
+            Whether the loss gradient w.r.t. the input tensor is
+            required in the backward pass.
+        weight_requires_grad: bool, default = `True`
+            Whether the loss gradient w.r.t. the weight tensor is
+            required in the backward pass.
 
         Returns
         -------
         torch.Tensor
             Output tensor
-        torch.Tensor
-            Input tensor used in GEMM, possibly cast and reshaped from
-            provided input tensor
-        torch.Tensor
-            Weight tensor used in GEMM, possibly cast and reshaped from
-            provided weight tensor
+        torch.Tensor, optional
+            Input tensor, ready for use in backward pass. `None` is
+            returned if loss gradient w.r.t. the weight tensor is not
+            required.
+        torch.Tensor, optional
+            Weight tensor, ready for use in backward pass. `None` is
+            returned if loss gradient w.r.t. the input tensor is not
+            required.
 
         """
@@ -416,7 +426,7 @@ class BasicLinear(BasicOperation):
         if with_quantized_compute:
             if input_quantizer is None:
                 raise ValueError("Missing quantizer for input tensor")
-            input_quantizer.set_usage(rowwise=True)
+            input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
             if with_x_all_gather:
                 input_quantizer.set_usage(columnwise=False)
                 x, x_async = gather_along_first_dim(
@@ -449,7 +459,7 @@ class BasicLinear(BasicOperation):
         if with_quantized_compute and not w_is_quantized:
             if weight_quantizer is None:
                 raise ValueError("Missing quantizer for weight tensor")
-            weight_quantizer.set_usage(rowwise=True)
+            weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad)
             w = weight_quantizer(w)
         elif not with_quantized_compute and w_is_quantized:
             w = w.dequantize()
@@ -526,17 +536,25 @@ class BasicLinear(BasicOperation):
         else:
             torch.distributed.all_reduce(y, group=tensor_parallel_group)
 
-        # Detach input tensor if needed
-        # Note: PyTorch autograd produces esoteric errors if we save
-        # input tensor as context for backward pass.
-        if x_local is input:
-            x_local = x_local.detach()
-
-        # Configure input tensor for backward pass
-        if with_quantized_compute and isinstance(x_local, QuantizedTensor):
-            if not (isinstance(x_local, Float8TensorBase) and with_x_all_gather):
-                # FP8 does not support all-gather of transpose data
-                x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
+        # Prepare weight tensor for backward pass
+        if input_requires_grad:
+            if w is not weight and with_quantized_compute and isinstance(w, QuantizedTensor):
+                w.update_usage(rowwise_usage=False, columnwise_usage=True)
+        else:
+            w = None
+
+        # Prepare input tensor for backward pass
+        if weight_requires_grad:
+            if x_local is input:
+                # PyTorch autograd produces esoteric errors if we
+                # cache input tensor directly.
+                x_local = x_local.detach()
+            if with_quantized_compute and isinstance(x_local, QuantizedTensor):
+                if not (isinstance(x_local, Float8TensorBase) and with_x_all_gather):
+                    # FP8 does not support all-gather of transpose data
+                    x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
+        else:
+            x_local = None
 
         return y, x_local, w
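
The branch structure above follows directly from the two backward GEMMs: dgrad needs the weight and wgrad needs the input, so each `*_requires_grad` flag controls exactly one saved tensor. A toy dense reference (plain matmuls on CPU, not the Transformer Engine kernels) of that split:

import torch

def backward_gemms(grad_output, x_local=None, w=None):
    """Toy reference for the dgrad/wgrad split; unneeded tensors are simply None."""
    # dgrad: dX = dY @ W, needs the weight
    grad_input = grad_output @ w if w is not None else None
    # wgrad: dW = dY^T @ X, needs the input
    grad_weight = grad_output.transpose(-2, -1) @ x_local if x_local is not None else None
    return grad_input, grad_weight

dy = torch.randn(32, 64)
w = torch.randn(64, 128)
# Frozen weight: no wgrad, so the forward pass never needed to save the input.
dx, dw = backward_gemms(dy, x_local=None, w=w)
assert dx.shape == (32, 128) and dw is None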
@@ -892,7 +910,7 @@ class BasicLinear(BasicOperation):
             dtype = torch.get_autocast_dtype("cuda")
 
         # Linear forward
-        output, x_local, _ = BasicLinear._functional_forward(
+        output, x_local, w = BasicLinear._functional_forward(
             input=input_,
             weight=self.weight,
             dtype=dtype,
@@ -903,10 +921,12 @@ class BasicLinear(BasicOperation):
             input_quantizer=input_quantizer,
             weight_quantizer=weight_quantizer,
             output_quantizer=output_quantizer,
+            input_requires_grad=input_requires_grad,
+            weight_requires_grad=weight_requires_grad,
         )
 
         # Save state for backward pass
-        ctx.save_for_backward(x_local)
+        ctx.save_for_backward(x_local, w)
         ctx.with_quantized_compute = with_quantized_compute
         ctx.input_quantizer = input_quantizer
         ctx.weight_quantizer = weight_quantizer
@@ -926,7 +946,7 @@ class BasicLinear(BasicOperation):
     ) -> tuple[torch.Tensor, Iterable[Optional[torch.Tensor]]]:
 
         # Saved tensors from forward pass
-        (x_local,) = ctx.saved_tensors
+        (x_local, w) = ctx.saved_tensors
 
         # wgrad fusion
         accumulate_into_main_grad = self._accumulate_into_main_grad
@@ -946,7 +966,7 @@ class BasicLinear(BasicOperation):
         grad_input, grad_weight = BasicLinear._functional_backward(
             grad_output=grad_output,
             input=x_local,
-            weight=self.weight,
+            weight=w,
             input_requires_grad=ctx.input_requires_grad,
             weight_requires_grad=ctx.weight_requires_grad,
             dtype=ctx.dtype,
......
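
Across the basic and fused ops in this PR, the save/restore contract is the same: tensors that are not needed are saved as `None`, which `torch.autograd.Function.save_for_backward` accepts, and the backward unpacks both slots unconditionally. A self-contained sketch of that pattern, independent of Transformer Engine:

import torch

class _FrozenAwareLinear(torch.autograd.Function):
    """Hypothetical demo of saving None for tensors the backward will not use."""

    @staticmethod
    def forward(ctx, x, w, weight_requires_grad: bool):
        # Only keep the input if wgrad will be computed.
        ctx.save_for_backward(x if weight_requires_grad else None, w)
        return x @ w.t()

    @staticmethod
    def backward(ctx, dy):
        x, w = ctx.saved_tensors
        dx = dy @ w
        dw = dy.t() @ x if x is not None else None
        return dx, dw, None

x = torch.randn(8, 4, requires_grad=True)
w = torch.randn(3, 4)  # frozen weight
y = _FrozenAwareLinear.apply(x, w, w.requires_grad)
y.sum().backward()
assert x.grad is not None and w.grad is None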
@@ -51,7 +51,7 @@ class BackwardLinearAdd(FusedOperation):
         linear_op_ctx = basic_op_ctxs[0]
 
         # Saved tensors from forward pass
-        (x_local,) = linear_op_ctx.saved_tensors
+        (x_local, w) = linear_op_ctx.saved_tensors
 
         # wgrad fusion
         accumulate_into_main_grad = linear_op._accumulate_into_main_grad
@@ -72,7 +72,7 @@ class BackwardLinearAdd(FusedOperation):
         grad_input, grad_weight = BasicLinear._functional_backward(
             grad_output=grad_output,
             input=x_local,
-            weight=linear_op.weight,
+            weight=w,
             input_requires_grad=linear_op_ctx.input_requires_grad,
             weight_requires_grad=linear_op_ctx.weight_requires_grad,
             dtype=grad_input.dtype,
......
@@ -82,6 +82,10 @@ class ForwardLinearBiasActivation(FusedOperation):
         else:
             raise NotImplementedError("Activations are not yet supported")
 
+        # Check which grads are required
+        input_requires_grad = linear_op_ctx.requires_grad and input_.requires_grad
+        weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad
+
         # FP8 metadata
         with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
         input_quantizer = None
@@ -106,7 +110,7 @@ class ForwardLinearBiasActivation(FusedOperation):
             dtype = torch.get_autocast_dtype("cuda")
 
         # Linear forward
-        output, x_local, _ = BasicLinear._functional_forward(
+        output, x_local, w = BasicLinear._functional_forward(
             input=input_,
             weight=linear_op.weight,
             bias=bias,
@@ -118,18 +122,20 @@ class ForwardLinearBiasActivation(FusedOperation):
             input_quantizer=input_quantizer,
             weight_quantizer=weight_quantizer,
             output_quantizer=output_quantizer,
+            input_requires_grad=input_requires_grad,
+            weight_requires_grad=weight_requires_grad,
         )
 
         # Save state for backward pass
-        linear_op_ctx.save_for_backward(x_local)
+        linear_op_ctx.save_for_backward(x_local, w)
         linear_op_ctx.with_quantized_compute = with_quantized_compute
         linear_op_ctx.input_quantizer = input_quantizer
         linear_op_ctx.weight_quantizer = weight_quantizer
         linear_op_ctx.grad_output_quantizer = grad_output_quantizer
         linear_op_ctx.grad_input_quantizer = grad_input_quantizer
         linear_op_ctx.dtype = dtype
-        linear_op_ctx.input_requires_grad = input_.requires_grad
-        linear_op_ctx.weight_requires_grad = linear_op.weight.requires_grad
+        linear_op_ctx.input_requires_grad = input_requires_grad
+        linear_op_ctx.weight_requires_grad = weight_requires_grad
         linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None
 
         return output, [() for _ in range(len(self.basic_ops))]
......
@@ -76,6 +76,10 @@ class ForwardLinearBiasAdd(FusedOperation):
             if basic_op_kwargs[idx]:
                 raise ValueError("Bias operation forward does not expect keyword arguments")
 
+        # Check which grads are required
+        input_requires_grad = linear_op_ctx.requires_grad and input_.requires_grad
+        weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad
+
         # FP8 metadata
         with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
         input_quantizer = None
@@ -98,7 +102,7 @@ class ForwardLinearBiasAdd(FusedOperation):
 
         # Linear forward
         output = basic_op_extra_inputs[self._op_idxs["add"]][0]
-        output, x_local, _ = BasicLinear._functional_forward(
+        output, x_local, w = BasicLinear._functional_forward(
             input=input_,
             weight=linear_op.weight,
             bias=bias,
@@ -111,18 +115,20 @@ class ForwardLinearBiasAdd(FusedOperation):
             input_quantizer=input_quantizer,
             weight_quantizer=weight_quantizer,
             output_quantizer=output_quantizer,
+            input_requires_grad=input_requires_grad,
+            weight_requires_grad=weight_requires_grad,
         )
 
         # Save state for backward pass
-        linear_op_ctx.save_for_backward(x_local)
+        linear_op_ctx.save_for_backward(x_local, w)
         linear_op_ctx.with_quantized_compute = with_quantized_compute
         linear_op_ctx.input_quantizer = input_quantizer
         linear_op_ctx.weight_quantizer = weight_quantizer
         linear_op_ctx.grad_output_quantizer = grad_output_quantizer
         linear_op_ctx.grad_input_quantizer = grad_input_quantizer
         linear_op_ctx.dtype = dtype
-        linear_op_ctx.input_requires_grad = input_.requires_grad
-        linear_op_ctx.weight_requires_grad = linear_op.weight.requires_grad
+        linear_op_ctx.input_requires_grad = input_requires_grad
+        linear_op_ctx.weight_requires_grad = weight_requires_grad
         linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None
 
         return output, [() for _ in range(len(self.basic_ops))]
......
@@ -500,7 +500,7 @@ class UserbuffersBackwardLinear(FusedOperation):
             bias_op = self.basic_ops[idx]
 
         # Saved tensors from forward pass
-        (x_local,) = linear_op_ctx.saved_tensors
+        (x_local, w) = linear_op_ctx.saved_tensors
 
         # wgrad fusion
         accumulate_into_main_grad = linear_op._accumulate_into_main_grad
@@ -520,7 +520,7 @@ class UserbuffersBackwardLinear(FusedOperation):
         retval = UserbuffersBackwardLinear._functional_backward(
             grad_output=grad_output,
             input=x_local,
-            weight=linear_op.weight,
+            weight=w,
             weight_requires_grad=linear_op_ctx.weight_requires_grad,
             bias_requires_grad=(bias_op is not None),
             dtype=linear_op_ctx.dtype,
......
@@ -98,6 +98,8 @@ class UserbuffersForwardLinear(FusedOperation):
         input_quantizer: Optional[Quantizer] = None,
         weight_quantizer: Optional[Quantizer] = None,
         output_quantizer: Optional[Quantizer] = None,
+        input_requires_grad: bool = True,
+        weight_requires_grad: bool = True,
         ub_comm_name: str,
     ) -> tuple[torch.Tensor, dict]:
         """Functional API for forward pass
@@ -131,6 +133,12 @@ class UserbuffersForwardLinear(FusedOperation):
             Builder class for quantized weight tensor.
         output_quantizer: Quantizer, optional
             Builder class for quantized output tensor.
+        input_requires_grad: bool, default = `True`
+            Whether the loss gradient w.r.t. the input tensor is
+            required in the backward pass.
+        weight_requires_grad: bool, default = `True`
+            Whether the loss gradient w.r.t. the weight tensor is
+            required in the backward pass.
         ub_comm_name: str
             Layer type (e.g. "qkv", "proj", "fc1", "fc2"). This is
             used to access the corresponding Userbuffers communicators
@@ -141,8 +149,9 @@ class UserbuffersForwardLinear(FusedOperation):
         torch.Tensor
             Output tensor
         dict
-            Extra output tensors. "input" is the input tensor,
-            possibly cast and reshaped from the provided input tensor.
+            Extra output tensors. "input" is the input tensor and
+            "weight" is the weight tensor, both ready for use in the
+            backward pass.
 
         """
@@ -198,7 +207,7 @@ class UserbuffersForwardLinear(FusedOperation):
         if with_ub_all_gather:
             if input_quantizer is not None:
                 if not isinstance(x_local, QuantizedTensorBase):
-                    input_quantizer.set_usage(rowwise=True, columnwise=True)
+                    input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
                     if isinstance(input_quantizer, Float8Quantizer):
                         input_quantizer.set_usage(columnwise=False)
                     x_local = input_quantizer(x_local)
@@ -212,7 +221,7 @@ class UserbuffersForwardLinear(FusedOperation):
         else:
             if with_quantized_compute:
                 if not isinstance(x_local, QuantizedTensorBase):
-                    input_quantizer.set_usage(rowwise=True, columnwise=True)
+                    input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
                     x_local = input_quantizer(x_local)
             else:
                 if isinstance(x_local, QuantizedTensorBase):
@@ -225,7 +234,7 @@ class UserbuffersForwardLinear(FusedOperation):
         w = weight
         w_is_quantized = isinstance(w, QuantizedTensorBase)
         if with_quantized_compute and not w_is_quantized:
-            weight_quantizer.set_usage(rowwise=True)
+            weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad)
             w = weight_quantizer(w)
         elif not with_quantized_compute and w_is_quantized:
             w = w.dequantize()
@@ -258,17 +267,25 @@ class UserbuffersForwardLinear(FusedOperation):
         else:
             y_local = gemm_output
 
-        # Detach input tensor if needed
-        # Note: PyTorch autograd produces esoteric errors if we save
-        # input tensor as context for backward pass.
-        if x_local is input:
-            x_local = x_local.detach()
-
-        # Configure input tensor for backward pass
-        if with_quantized_compute and isinstance(x_local, QuantizedTensorBase):
-            if not (isinstance(x_local, Float8TensorBase) and with_ub_all_gather):
-                # FP8 does not support all-gather of transpose data
-                x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
+        # Prepare weight tensor for backward pass
+        if input_requires_grad:
+            if w is not weight and with_quantized_compute and isinstance(w, QuantizedTensorBase):
+                w.update_usage(rowwise_usage=False, columnwise_usage=True)
+        else:
+            w = None
+
+        # Prepare input tensor for backward pass
+        if weight_requires_grad:
+            if x_local is input:
+                # PyTorch autograd produces esoteric errors if we
+                # cache input tensor directly.
+                x_local = x_local.detach()
+            if with_quantized_compute and isinstance(x_local, QuantizedTensorBase):
+                if not (isinstance(x_local, Float8TensorBase) and with_ub_all_gather):
+                    # FP8 does not support all-gather of transpose data
+                    x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
+        else:
+            x_local = None
 
         # Return cast tensors
         extra_outputs = {"input": x_local, "weight": w}
@@ -298,6 +315,10 @@ class UserbuffersForwardLinear(FusedOperation):
             if basic_op_kwargs[idx]:
                 raise ValueError("Bias operation forward does not expect keyword arguments")
 
+        # Check which grads are required
+        input_requires_grad = linear_op_ctx.requires_grad and input_.requires_grad
+        weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad
+
         # Quantization metadata
         with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
         input_quantizer = None
@@ -338,12 +359,15 @@ class UserbuffersForwardLinear(FusedOperation):
             input_quantizer=input_quantizer,
             weight_quantizer=weight_quantizer,
            output_quantizer=None,  # Not supported
+            input_requires_grad=input_requires_grad,
+            weight_requires_grad=weight_requires_grad,
             ub_comm_name=linear_op._userbuffers_options["comm_name"],
         )
         x_local = extra_outputs["input"]
+        w = extra_outputs["weight"]
 
         # Save state for backward pass
-        linear_op_ctx.save_for_backward(x_local)
+        linear_op_ctx.save_for_backward(x_local, w)
         linear_op_ctx.with_quantized_compute = with_quantized_compute
         linear_op_ctx.input_quantizer = input_quantizer
         linear_op_ctx.weight_quantizer = weight_quantizer
@@ -351,8 +375,8 @@ class UserbuffersForwardLinear(FusedOperation):
         linear_op_ctx.grad_input_quantizer = grad_input_quantizer
         linear_op_ctx.dtype = dtype
         linear_op_ctx.input_dims = input_.size()
-        linear_op_ctx.input_requires_grad = input_.requires_grad
-        linear_op_ctx.weight_requires_grad = linear_op.weight.requires_grad
+        linear_op_ctx.input_requires_grad = input_requires_grad
+        linear_op_ctx.weight_requires_grad = weight_requires_grad
         linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None
 
         return output, [() for _ in range(len(self.basic_ops))]
......