Unverified Commit 66f9b3cb authored by Kirthi Shankar Sivamani, committed by GitHub

[PyTorch] Unblock fused bgrad quantization path for nvfp4 (#2246)



Unblock the path for fusing NVFP4 quantization with bgrad computation.
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent ac5e868f
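For context, the unfused fallback that this commit removes for NVFP4 computes the bias gradient as a plain PyTorch reduction over all non-feature dimensions, separately from the quantize kernel. A minimal sketch of that reduction; the `(batch, seq, hidden)` shape is an illustrative assumption:

```python
import torch

# Unfused bgrad: flatten every leading dimension and reduce over it,
# leaving one gradient value per output feature (the bias shape).
grad_output = torch.randn(8, 128, 1024)  # (batch, seq, hidden), illustrative
grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
assert grad_bias.shape == (1024,)
```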
@@ -40,7 +40,6 @@ from ..distributed import (
 from ..constants import dist_group_type
 from ..tensor.quantized_tensor import QuantizedTensor, QuantizedTensorStorage, Quantizer
 from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
-from ..tensor.nvfp4_tensor import NVFP4Quantizer
 from ..tensor.mxfp8_tensor import MXFP8Quantizer
 from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
 from ..tensor.storage.float8_tensor_storage import Float8TensorStorage
@@ -1229,8 +1228,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         ):
             grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0)
         else:
-            # TODO(ksivaman): Re-add fusion once kernel is available.
-            if isinstance(quantizer, (Float8BlockQuantizer, NVFP4Quantizer)):
+            if isinstance(quantizer, Float8BlockQuantizer):
                 # unfuse bgrad for now until cast_transpose + dgrad calculation is ready for Float8BlockQuantizer.
                 grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
             else:
...
@@ -1037,11 +1037,8 @@ class _LayerNormMLP(torch.autograd.Function):
                 if ctx.fp8:
                     # TODO float8 blockwise current scaling (as well as custom quantizers) has no bgrad fusion for now
-                    # TODO(ksivaman): Re-add fusion once kernel is available.
                     if (
-                        isinstance(
-                            ctx.fc1_grad_output_quantizer, (Float8BlockQuantizer, NVFP4Quantizer)
-                        )
+                        isinstance(ctx.fc1_grad_output_quantizer, Float8BlockQuantizer)
                         or ctx.fp8_recipe.custom()
                     ):
                         fc1_bias_grad = dact.view(-1, dact.shape[-1]).sum(dim=0)
...
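By contrast, the fused path that this commit unblocks produces the bias gradient and the quantized gradient from a single pass over the gradient tensor. A rough sketch of that data flow, where `fused_bgrad_quantize` and the `quantize` callable are hypothetical stand-ins for the actual Transformer Engine kernel, not its real API:

```python
import torch
from typing import Callable, Tuple

def fused_bgrad_quantize(
    grad_output: torch.Tensor,
    quantize: Callable[[torch.Tensor], torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Hypothetical fusion: the real kernel computes both outputs in a
    # single read of grad_output; here the steps are written sequentially.
    x = grad_output.view(-1, grad_output.shape[-1])
    grad_bias = x.sum(dim=0)       # bias gradient
    grad_output_q = quantize(x)    # e.g. NVFP4 quantization of the gradient
    return grad_bias, grad_output_q
```

The payoff is one fewer full read of the gradient tensor in the backward pass, which is why the removed TODOs disabled the fusion only until an NVFP4-capable kernel was available.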