Unverified commit 20e95ba3, authored by Tim Moon, committed by GitHub
Browse files

[PyTorch] Explicitly specify quantized tensor usages needed for linear op backward (#1646)



Explicitly specify quantized tensor usages needed for linear op backward
Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent 962d9c53
...@@ -1420,15 +1420,17 @@ class TestBasicOps: ...@@ -1420,15 +1420,17 @@ class TestBasicOps:
test_device=device, test_device=device,
test_is_fp8=quantized_compute, test_is_fp8=quantized_compute,
) )
if quantized_compute:
with torch.no_grad():
x_test = x_test.dequantize().requires_grad_()
dy_ref, dy_test = make_reference_and_test_tensors( dy_ref, dy_test = make_reference_and_test_tensors(
out_shape, out_shape,
test_dtype=dtype, test_dtype=dtype,
test_device=device, test_device=device,
test_is_fp8=quantized_compute,
requires_grad=False, requires_grad=False,
) )
if quantized_compute:
with torch.no_grad():
x_test = x_test.dequantize().requires_grad_()
dy_test = dy_test.dequantize()
# Plain PyTorch implementation # Plain PyTorch implementation
y_ref: torch.Tensor y_ref: torch.Tensor
...@@ -1459,6 +1461,7 @@ class TestBasicOps: ...@@ -1459,6 +1461,7 @@ class TestBasicOps:
swiglu=te_ops.SwiGLU, swiglu=te_ops.SwiGLU,
)[activation] )[activation]
forward = te_ops.Sequential( forward = te_ops.Sequential(
te_ops.Quantize(forward=False, backward=quantized_compute),
make_op(), make_op(),
te_ops.Quantize(forward=quantized_compute, backward=False), te_ops.Quantize(forward=quantized_compute, backward=False),
) )
......
...@@ -523,7 +523,7 @@ class BasicLinear(BasicOperation): ...@@ -523,7 +523,7 @@ class BasicLinear(BasicOperation):
# Configure input tensor for backward pass # Configure input tensor for backward pass
if own_quantized_x_local: if own_quantized_x_local:
x_local.update_usage(rowwise_usage=False) x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
# Detach input tensor if needed # Detach input tensor if needed
# Note: PyTorch autograd produces esoteric errors if we save # Note: PyTorch autograd produces esoteric errors if we save
...@@ -679,7 +679,9 @@ class BasicLinear(BasicOperation): ...@@ -679,7 +679,9 @@ class BasicLinear(BasicOperation):
quantizer=input_quantizer, quantizer=input_quantizer,
) )
else: else:
if not isinstance(x_local, QuantizedTensor): if isinstance(x_local, QuantizedTensor):
x_local.update_usage(columnwise_usage=True)
else:
x_local = input_quantizer(x_local) x_local = input_quantizer(x_local)
x = x_local x = x_local
else: else:
...@@ -706,14 +708,18 @@ class BasicLinear(BasicOperation): ...@@ -706,14 +708,18 @@ class BasicLinear(BasicOperation):
raise ValueError("Weight tensor is required to compute input grad") raise ValueError("Weight tensor is required to compute input grad")
w = weight w = weight
w_is_quantized = isinstance(w, QuantizedTensor) w_is_quantized = isinstance(w, QuantizedTensor)
if with_quantized_compute and not w_is_quantized: if with_quantized_compute:
if w_is_quantized:
w.update_usage(columnwise_usage=True)
else:
if weight_quantizer is None: if weight_quantizer is None:
raise ValueError("Missing quantizer for weight tensor") raise ValueError("Missing quantizer for weight tensor")
weight_quantizer.set_usage(columnwise=True) weight_quantizer.set_usage(columnwise=True)
w = weight_quantizer(w) w = weight_quantizer(w)
elif not with_quantized_compute and w_is_quantized: else:
w = w.dequantize() if w_is_quantized:
if not with_quantized_compute and w.dtype != dtype: w = w.dequantize(dtype=dtype)
elif w.dtype != dtype:
w = w.to(dtype=dtype) w = w.to(dtype=dtype)
# Synchronize tensor-parallel communication # Synchronize tensor-parallel communication
...@@ -867,8 +873,8 @@ class BasicLinear(BasicOperation): ...@@ -867,8 +873,8 @@ class BasicLinear(BasicOperation):
# Configure quantizers # Configure quantizers
# Note: We cache the quantized input for backward pass, # Note: We cache the quantized input for backward pass,
# but discard the quantized weights. # but discard the quantized weights.
input_quantizer.set_usage(columnwise=weight_requires_grad) input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
weight_quantizer.set_usage(columnwise=False) weight_quantizer.set_usage(rowwise=True, columnwise=False)
# Get autocast dtype if needed # Get autocast dtype if needed
dtype = None dtype = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment