Unverified commit 20e95ba3, authored by Tim Moon, committed by GitHub

[PyTorch] Explicitly specify quantized tensor usages needed for linear op backward (#1646)



Explicitly specify the quantized tensor usages (rowwise and columnwise layouts) that the linear op's backward pass needs, rather than relying on implicit defaults.

Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent 962d9c53
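
Background for the diff below: TransformerEngine quantized tensors (e.g. FP8) can carry their data in a rowwise layout, roughly the copy the forward GEMM consumes, and/or a columnwise (transposed) layout, the copy the backward GEMMs consume. Before this change, parts of the linear op left those usage flags implicit and relied on the right copies already existing; the diff states them explicitly through Quantizer.set_usage and QuantizedTensor.update_usage. A minimal sketch of the recurring pattern, assuming only the two methods that appear in the hunks (the function name and docstring are illustrative, not part of the change):

    def ensure_columnwise(x_local, input_quantizer):
        """Make sure the saved input has the transposed copy backward needs."""
        if isinstance(x_local, QuantizedTensor):
            # Already quantized: request the columnwise copy explicitly
            # instead of assuming it was materialized at quantization time.
            x_local.update_usage(columnwise_usage=True)
        else:
            # Not yet quantized: configure the quantizer to produce the
            # columnwise copy, then quantize.
            input_quantizer.set_usage(columnwise=True)
            x_local = input_quantizer(x_local)
        return x_local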
@@ -1420,15 +1420,17 @@ class TestBasicOps:
             test_device=device,
             test_is_fp8=quantized_compute,
         )
-        if quantized_compute:
-            with torch.no_grad():
-                x_test = x_test.dequantize().requires_grad_()
         dy_ref, dy_test = make_reference_and_test_tensors(
             out_shape,
             test_dtype=dtype,
             test_device=device,
             test_is_fp8=quantized_compute,
             requires_grad=False,
         )
+        if quantized_compute:
+            with torch.no_grad():
+                x_test = x_test.dequantize().requires_grad_()
+                dy_test = dy_test.dequantize()
+
         # Plain PyTorch implementation
         y_ref: torch.Tensor
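
Note on the test change above: the dequantize block moves below the gradient-tensor setup and now also dequantizes dy_test. With quantized compute, the tensors handed to autograd are converted back to plain high-precision tensors, so quantization of the activation gradient is exercised inside the op chain (by the Quantize op added in the next hunk) rather than by feeding an already-quantized dy into backward. Roughly what the test then runs, with names from the surrounding test file and the comparison paraphrased:

    y_test = forward(x_test)    # te_ops path; forward may quantize internally
    y_test.backward(dy_test)    # dy_test is plain; the op chain quantizes it
    # ... y_test and x_test.grad are compared against the plain-PyTorch reference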
@@ -1459,6 +1461,7 @@ class TestBasicOps:
             swiglu=te_ops.SwiGLU,
         )[activation]
         forward = te_ops.Sequential(
+            te_ops.Quantize(forward=False, backward=quantized_compute),
             make_op(),
             te_ops.Quantize(forward=quantized_compute, backward=False),
         )
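
The added te_ops.Quantize(forward=False, backward=quantized_compute) is an identity in the forward pass and quantizes the gradient it emits in the backward pass. Since it sits first in the Sequential, the chain now also exercises producing a quantized input gradient from the activation's backward, mirroring how the existing Quantize op after make_op() exercises a quantized output in forward. Spelled out for one choice from the activation dict above (the comments are interpretation, not code from the diff):

    forward = te_ops.Sequential(
        te_ops.Quantize(forward=False, backward=quantized_compute),
        # ^ no-op in forward; in backward, quantizes the gradient it
        #   passes along, i.e. the chain's input grad
        te_ops.GELU(),
        te_ops.Quantize(forward=quantized_compute, backward=False),
        # ^ quantizes the op's output in forward; no-op in backward
    )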
@@ -523,7 +523,7 @@ class BasicLinear(BasicOperation):
         # Configure input tensor for backward pass
         if own_quantized_x_local:
-            x_local.update_usage(rowwise_usage=False)
+            x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
 
         # Detach input tensor if needed
         # Note: PyTorch autograd produces esoteric errors if we save
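
Previously this call only dropped the rowwise copy and silently assumed the columnwise copy already existed in x_local; the new call states both sides of the intent. Assumed semantics, inferred from the diff rather than from a documented contract:

    # Keep only the layout that backward consumes: create the columnwise
    # (transposed) copy if it is missing and release the rowwise copy
    # that was only needed for the forward GEMM.
    x_local.update_usage(rowwise_usage=False, columnwise_usage=True)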
@@ -679,7 +679,9 @@ class BasicLinear(BasicOperation):
                     quantizer=input_quantizer,
                 )
             else:
-                if not isinstance(x_local, QuantizedTensor):
+                if isinstance(x_local, QuantizedTensor):
+                    x_local.update_usage(columnwise_usage=True)
+                else:
                     x_local = input_quantizer(x_local)
                 x = x_local
         else:
@@ -706,15 +708,19 @@ class BasicLinear(BasicOperation):
                 raise ValueError("Weight tensor is required to compute input grad")
             w = weight
             w_is_quantized = isinstance(w, QuantizedTensor)
-            if with_quantized_compute and not w_is_quantized:
-                if weight_quantizer is None:
-                    raise ValueError("Missing quantizer for weight tensor")
-                weight_quantizer.set_usage(columnwise=True)
-                w = weight_quantizer(w)
-            elif not with_quantized_compute and w_is_quantized:
-                w = w.dequantize()
-            if not with_quantized_compute and w.dtype != dtype:
-                w = w.to(dtype=dtype)
+            if with_quantized_compute:
+                if w_is_quantized:
+                    w.update_usage(columnwise_usage=True)
+                else:
+                    if weight_quantizer is None:
+                        raise ValueError("Missing quantizer for weight tensor")
+                    weight_quantizer.set_usage(columnwise=True)
+                    w = weight_quantizer(w)
+            else:
+                if w_is_quantized:
+                    w = w.dequantize(dtype=dtype)
+                elif w.dtype != dtype:
+                    w = w.to(dtype=dtype)
 
             # Synchronize tensor-parallel communication
             _wait_async(dy_async)
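
The restructured weight handling makes the two regimes symmetric and adds the previously missing case: with quantized compute, an already-quantized weight now gets an explicit update_usage(columnwise_usage=True) instead of being used as-is, and the dequantize path passes dtype directly rather than dequantizing and then casting in a second step. The resulting decision table, restated from the hunk:

    with_quantized_compute, quantized w:   w.update_usage(columnwise_usage=True)
    with_quantized_compute, plain w:       weight_quantizer.set_usage(columnwise=True); w = weight_quantizer(w)
    high-precision compute, quantized w:   w = w.dequantize(dtype=dtype)
    high-precision compute, plain w:       w = w.to(dtype=dtype) if dtypes differ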
@@ -867,8 +873,8 @@ class BasicLinear(BasicOperation):
         # Configure quantizers
         # Note: We cache the quantized input for backward pass,
         # but discard the quantized weights.
-        input_quantizer.set_usage(columnwise=weight_requires_grad)
-        weight_quantizer.set_usage(columnwise=False)
+        input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
+        weight_quantizer.set_usage(rowwise=True, columnwise=False)
 
         # Get autocast dtype if needed
         dtype = None
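
Here the forward pass spells out rowwise usage instead of inheriting whatever the quantizers were last set to: both tensors need rowwise data for the forward GEMM, the input additionally keeps a columnwise copy only if the weight gradient will consume it, and the quantized weight is produced rowwise-only because it is discarded after forward (per the note above) and re-quantized with columnwise usage in backward. A sketch of the effect under that reading:

    # Input: rowwise for the forward GEMM; columnwise copy cached only
    # when backward will compute a weight gradient from it.
    input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
    # Weight: rowwise only; backward re-quantizes with columnwise usage
    # (see the @@ -706 hunk above).
    weight_quantizer.set_usage(rowwise=True, columnwise=False)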