Unverified commit 48f3ca90, authored by Tim Moon, committed by GitHub
Browse files

[PyTorch] Avoid unnecessary tensor usages when caching for linear op backward (#1676)



* Avoid unnecessary tensor usages when caching for linear op backward
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug test failure
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent 4c9626e7
......@@ -413,7 +413,6 @@ class BasicLinear(BasicOperation):
x = None
x_async = None
with_x_all_gather = tensor_parallel_mode == "column" and sequence_parallel
own_quantized_x_local = False
if with_quantized_compute:
if input_quantizer is None:
raise ValueError("Missing quantizer for input tensor")
......@@ -429,7 +428,6 @@ class BasicLinear(BasicOperation):
else:
if not isinstance(x_local, QuantizedTensor):
x_local = input_quantizer(x_local)
own_quantized_x_local = True
x = x_local
else:
if isinstance(x_local, QuantizedTensor):
......@@ -528,16 +526,16 @@ class BasicLinear(BasicOperation):
else:
torch.distributed.all_reduce(y, group=tensor_parallel_group)
# Configure input tensor for backward pass
if own_quantized_x_local:
x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
# Detach input tensor if needed
# Note: PyTorch autograd produces esoteric errors if we save
# input tensor as context for backward pass.
if x_local is input:
x_local = x_local.detach()
# Configure input tensor for backward pass
if with_quantized_compute and isinstance(x_local, QuantizedTensor):
x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
return y, x_local, w
@staticmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.