Unverified commit e4c051f1, authored by Tim Moon, committed by GitHub

[PyTorch] Activation ops support fusing backward pass with quantize (#1804)



Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent 1669b3f4
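
For context, the pattern this commit enables looks roughly like the sketch below: when an activation op sits between two quantized ops in a fusible-ops pipeline, its backward pass can now hand its gradient directly to the preceding op's quantizer instead of emitting a high-precision tensor that is quantized by a separate kernel. The op names (te.ops.Sequential, te.ops.Linear, te.ops.GELU) are assumptions based on recent Transformer Engine releases, not part of this diff.

    # Sketch only; op names are assumed, not taken from this commit.
    import torch
    import transformer_engine.pytorch as te

    model = te.ops.Sequential(
        te.ops.Linear(1024, 4096),  # prev_op: supplies a "backward" quantizer
        te.ops.GELU(),              # its dgrad can now be produced directly in FP8
        te.ops.Linear(4096, 1024),  # next_op: supplies a "forward" quantizer
    )

    x = torch.randn(32, 1024, device="cuda", requires_grad=True)
    with te.fp8_autocast(enabled=True):
        y = model(x)
    y.sum().backward()  # activation backward fuses the quantize into its own kernel
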
@@ -96,12 +96,15 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
         if not x.is_contiguous():
             x = x.contiguous()
 
-        # Check if FP8 is enabled
-        fp8_enabled = FP8GlobalStateManager.is_fp8_enabled()
-        if fp8_enabled and next_op is not None and next_op.num_quantizers("forward") > 0:
+        # Check if quantized compute is enabled
+        quantized_compute_enabled = FP8GlobalStateManager.is_fp8_enabled()
+        quantizer = None
+        if (
+            quantized_compute_enabled
+            and next_op is not None
+            and next_op.num_quantizers("forward") > 0
+        ):
             quantizer = next_op.get_quantizer("forward", 0)
-        else:
-            quantizer = None
 
         # Launch kernel
         y = self._activation_forward_impl(
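
The forward-side change above replaces the old if/else with a single quantizer lookup that defaults to None. Distilled into a standalone helper (the name _resolve_forward_quantizer is hypothetical, for illustration only; the calls mirror the hunk):

    from transformer_engine.pytorch.fp8 import FP8GlobalStateManager

    def _resolve_forward_quantizer(next_op):
        """Return the quantizer the next op expects for its forward input,
        or None when quantized compute is disabled or no quantizer exists.
        Hypothetical helper distilled from the hunk above."""
        if (
            FP8GlobalStateManager.is_fp8_enabled()
            and next_op is not None
            and next_op.num_quantizers("forward") > 0
        ):
            return next_op.get_quantizer("forward", 0)
        return None

Passing this quantizer into _activation_forward_impl lets the activation kernel write its output directly in the next op's quantized format, so no separate cast kernel runs between the two ops.
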
@@ -115,13 +118,13 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
 
         # Quantize input to FP8 before caching if needed
         if self.cache_quantized_input:
-            quantizer = Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, x.device)
-            quantizer.set_usage(rowwise=True, columnwise=False)
-            x = quantizer(x)
+            input_quantizer = Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, x.device)
+            input_quantizer.set_usage(rowwise=True, columnwise=False)
+            x = input_quantizer(x)
 
         # Save state for backward pass
         ctx.save_for_backward(x.detach())
-        ctx.fp8_enabled = fp8_enabled
+        ctx.quantized_compute_enabled = quantized_compute_enabled
         ctx.dtype = dtype
         ctx.prev_op = prev_op
 
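
The rename in this hunk (quantizer to input_quantizer) keeps the input-caching quantizer distinct from the forward-output quantizer chosen earlier. The quantizer calls below mirror the diff exactly; only the import paths are assumptions about the module layout:

    import torch
    import transformer_engine_torch as tex
    from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer

    x = torch.randn(32, 1024, device="cuda")
    # Quantize with per-tensor current scaling; only row-wise data is kept,
    # matching the set_usage call in the diff.
    input_quantizer = Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, x.device)
    input_quantizer.set_usage(rowwise=True, columnwise=False)
    x_fp8 = input_quantizer(x)  # FP8 tensor suitable for ctx.save_for_backward
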
@@ -153,11 +156,20 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
         if not dy.is_contiguous():
             dy = dy.contiguous()
 
+        # Check if quantized compute is enabled
+        quantizer = None
+        if (
+            ctx.quantized_compute_enabled
+            and ctx.prev_op is not None
+            and ctx.prev_op.num_quantizers("backward") > 0
+        ):
+            quantizer = ctx.prev_op.get_quantizer("backward", 0)
+
         # Launch kernel
         dx = self._activation_backward_impl(
             reshape(dy, (-1, dy.size(-1))),
             reshape(x, (-1, x.size(-1))),
-            None,
+            quantizer,
         )
 
         # Check grad input tensor
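
The backward hunk mirrors the forward-side selection, except that everything is read from the autograd ctx and the quantizer is requested from the "backward" slot of the op that produced the activation input. A hypothetical helper equivalent:

    def _resolve_backward_quantizer(ctx):
        """Return the quantizer the previous op expects for its grad output,
        or None. Hypothetical helper distilled from the hunk above."""
        if (
            ctx.quantized_compute_enabled
            and ctx.prev_op is not None
            and ctx.prev_op.num_quantizers("backward") > 0
        ):
            return ctx.prev_op.get_quantizer("backward", 0)
        return None

Threading this quantizer into _activation_backward_impl (in place of the hard-coded None) is the fusion named in the commit title: the activation-backward kernel writes dx directly in the quantized format the previous op needs, eliminating the standalone quantize kernel that would otherwise run on its output.
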