Unverified Commit e1edaaec authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

[PyTorch] Reduce CPU overheads (#2377)



Initial changes to remove pytorch overheads
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 42d22740
...@@ -18,7 +18,6 @@ from ...quantization import FP8GlobalStateManager ...@@ -18,7 +18,6 @@ from ...quantization import FP8GlobalStateManager
from ...module.base import ( from ...module.base import (
fill_userbuffers_buffer_for_all_gather, fill_userbuffers_buffer_for_all_gather,
get_ub, get_ub,
get_workspace,
_2X_ACC_FPROP, _2X_ACC_FPROP,
) )
from ...quantized_tensor import Quantizer from ...quantized_tensor import Quantizer
...@@ -243,7 +242,6 @@ class UserbuffersForwardLinear(FusedOperation): ...@@ -243,7 +242,6 @@ class UserbuffersForwardLinear(FusedOperation):
gemm_output, *_, reduce_scatter_output = general_gemm( gemm_output, *_, reduce_scatter_output = general_gemm(
w, w,
x, x,
get_workspace(),
out_dtype=dtype, out_dtype=dtype,
quantization_params=output_quantizer, quantization_params=output_quantizer,
bias=bias, bias=bias,
......
...@@ -7,7 +7,6 @@ ...@@ -7,7 +7,6 @@
from __future__ import annotations from __future__ import annotations
from typing import Optional, Tuple, Iterable, Any, Dict, Union from typing import Optional, Tuple, Iterable, Any, Dict, Union
import abc import abc
import copy
import warnings import warnings
import math import math
...@@ -297,10 +296,6 @@ class Quantizer(abc.ABC): ...@@ -297,10 +296,6 @@ class Quantizer(abc.ABC):
if columnwise is not None: if columnwise is not None:
self.columnwise_usage = columnwise self.columnwise_usage = columnwise
def copy(self) -> Quantizer:
    """Create a shallow copy of this quantizer via ``copy.copy``.

    Attribute values (including any tensors held by subclasses) are
    shared with the original instance, not duplicated.
    """
    return copy.copy(self)
def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor: def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor:
"""Symbolic function for ONNX export""" """Symbolic function for ONNX export"""
raise NotImplementedError( raise NotImplementedError(
......
...@@ -57,6 +57,22 @@ class Float8BlockQuantizer(Quantizer): ...@@ -57,6 +57,22 @@ class Float8BlockQuantizer(Quantizer):
self.block_scaling_dim = block_scaling_dim self.block_scaling_dim = block_scaling_dim
self.all_gather_usage = all_gather_usage self.all_gather_usage = all_gather_usage
def copy(self) -> Float8BlockQuantizer:
    """Create a shallow copy of this quantizer.

    A new ``Float8BlockQuantizer`` is constructed with identical
    configuration; the ``internal`` flag is carried over explicitly
    because it is not a constructor argument.
    """
    clone = Float8BlockQuantizer(
        fp8_dtype=self.dtype,
        block_scaling_dim=self.block_scaling_dim,
        rowwise=self.rowwise_usage,
        columnwise=self.columnwise_usage,
        force_pow_2_scales=self.force_pow_2_scales,
        amax_epsilon=self.amax_epsilon,
        all_gather_usage=self.all_gather_usage,
    )
    clone.internal = self.internal
    return clone
def update_quantized( def update_quantized(
self, self,
src: torch.Tensor, src: torch.Tensor,
......
...@@ -66,6 +66,20 @@ class Float8Quantizer(Quantizer): ...@@ -66,6 +66,20 @@ class Float8Quantizer(Quantizer):
self.amax = amax self.amax = amax
self.dtype = fp8_dtype self.dtype = fp8_dtype
def copy(self) -> Float8Quantizer:
    """Create a shallow copy of this quantizer.

    The new instance shares the ``scale`` and ``amax`` tensors with the
    original (no tensor data is cloned) and inherits the rowwise and
    columnwise usage flags. The ``internal`` flag is copied explicitly
    since the constructor does not accept it.
    """
    clone = Float8Quantizer(
        scale=self.scale,
        amax=self.amax,
        fp8_dtype=self.dtype,
        columnwise=self.columnwise_usage,
        rowwise=self.rowwise_usage,
    )
    clone.internal = self.internal
    return clone
def update_quantized( def update_quantized(
self, self,
src: torch.Tensor, src: torch.Tensor,
...@@ -245,10 +259,16 @@ class Float8CurrentScalingQuantizer(Quantizer): ...@@ -245,10 +259,16 @@ class Float8CurrentScalingQuantizer(Quantizer):
amax_reduction_group: Optional[dist_group_type] = None, amax_reduction_group: Optional[dist_group_type] = None,
force_pow_2_scales: bool = False, force_pow_2_scales: bool = False,
amax_epsilon: float = 0.0, amax_epsilon: float = 0.0,
scale: Optional[torch.Tensor] = None,
amax: Optional[torch.Tensor] = None,
) -> None: ) -> None:
super().__init__(rowwise=rowwise, columnwise=columnwise) super().__init__(rowwise=rowwise, columnwise=columnwise)
self.scale = torch.empty(1, dtype=torch.float32, device=device) if scale is None:
self.amax = torch.empty(1, dtype=torch.float32, device=device) scale = torch.empty(1, dtype=torch.float32, device=device)
if amax is None:
amax = torch.empty(1, dtype=torch.float32, device=device)
self.scale = scale
self.amax = amax
self.dtype = fp8_dtype self.dtype = fp8_dtype
self.use_existing_amax = use_existing_amax self.use_existing_amax = use_existing_amax
self.with_amax_reduction = with_amax_reduction self.with_amax_reduction = with_amax_reduction
...@@ -256,6 +276,26 @@ class Float8CurrentScalingQuantizer(Quantizer): ...@@ -256,6 +276,26 @@ class Float8CurrentScalingQuantizer(Quantizer):
self.force_pow_2_scales = force_pow_2_scales self.force_pow_2_scales = force_pow_2_scales
self.amax_epsilon = amax_epsilon self.amax_epsilon = amax_epsilon
def copy(self) -> Float8CurrentScalingQuantizer:
    """Create a shallow copy of this quantizer.

    The copy shares the ``scale`` and ``amax`` tensors with the
    original. ``device=0`` is a placeholder: the constructor only
    consults ``device`` when ``scale``/``amax`` are omitted, and both
    are supplied here, so no tensor is ever allocated on device 0.
    """
    clone = Float8CurrentScalingQuantizer(
        fp8_dtype=self.dtype,
        device=0,  # unused: scale/amax below bypass tensor allocation
        scale=self.scale,
        amax=self.amax,
        rowwise=self.rowwise_usage,
        columnwise=self.columnwise_usage,
        use_existing_amax=self.use_existing_amax,
        with_amax_reduction=self.with_amax_reduction,
        amax_reduction_group=self.amax_reduction_group,
        amax_epsilon=self.amax_epsilon,
        force_pow_2_scales=self.force_pow_2_scales,
    )
    clone.internal = self.internal
    return clone
def update_quantized( def update_quantized(
self, self,
src: torch.Tensor, src: torch.Tensor,
......
...@@ -45,6 +45,18 @@ class MXFP8Quantizer(Quantizer): ...@@ -45,6 +45,18 @@ class MXFP8Quantizer(Quantizer):
super().__init__(rowwise=rowwise, columnwise=columnwise) super().__init__(rowwise=rowwise, columnwise=columnwise)
self.dtype = fp8_dtype self.dtype = fp8_dtype
def copy(self) -> MXFP8Quantizer:
    """Create a shallow copy of this quantizer.

    Builds a fresh ``MXFP8Quantizer`` with the same dtype and usage
    flags, then propagates the ``internal`` flag, which is not a
    constructor argument.
    """
    clone = MXFP8Quantizer(
        fp8_dtype=self.dtype,
        columnwise=self.columnwise_usage,
        rowwise=self.rowwise_usage,
    )
    clone.internal = self.internal
    return clone
def update_quantized( def update_quantized(
self, self,
src: torch.Tensor, src: torch.Tensor,
......
...@@ -176,6 +176,26 @@ class NVFP4Quantizer(Quantizer): ...@@ -176,6 +176,26 @@ class NVFP4Quantizer(Quantizer):
return dst return dst
def copy(self) -> NVFP4Quantizer:
    """Create a shallow copy of this quantizer.

    Reconstructs an ``NVFP4Quantizer`` with identical configuration and
    then carries over the attributes the constructor does not accept:
    the ``internal`` flag and the RHT matrix state (``rht_matrix``,
    ``rht_matrix_random_sign_mask_t``), which are shared, not cloned.
    """
    clone = NVFP4Quantizer(
        fp4_dtype=self.dtype,
        rowwise=self.rowwise_usage,
        columnwise=self.columnwise_usage,
        with_rht=self.with_rht,
        with_post_rht_amax=self.with_post_rht_amax,
        with_amax_reduction=self.with_amax_reduction,
        amax_reduction_group=self.amax_reduction_group,
        with_2d_quantization=self.with_2d_quantization,
        stochastic_rounding=self.stochastic_rounding,
    )
    clone.internal = self.internal
    clone.rht_matrix = self.rht_matrix
    clone.rht_matrix_random_sign_mask_t = self.rht_matrix_random_sign_mask_t
    return clone
def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor:
"""Quantize tensor implementation""" """Quantize tensor implementation"""
return tex.quantize(tensor, self) return tex.quantize(tensor, self)
......
...@@ -8,6 +8,7 @@ import functools ...@@ -8,6 +8,7 @@ import functools
import math import math
import os import os
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
from contextlib import nullcontext
import numpy as np import numpy as np
import torch import torch
...@@ -592,6 +593,24 @@ def _nvtx_enabled() -> bool: ...@@ -592,6 +593,24 @@ def _nvtx_enabled() -> bool:
_nvtx_range_messages: list[str] = [] _nvtx_range_messages: list[str] = []
def get_nvtx_range_context(msg: str):
    """Return a context manager tagging a region for NVTX profiling.

    Set ``NVTE_NVTX_ENABLED=1`` in the environment to enable NVTX
    context managers for module-level profiling tags; when disabled,
    a no-op ``nullcontext`` is returned instead.

    Parameters
    ----------
    msg: str
        Message to associate with profiling context.
    """
    return torch.cuda.nvtx.range(msg) if _nvtx_enabled() else nullcontext()
def nvtx_range_push(msg: str) -> None: def nvtx_range_push(msg: str) -> None:
"""Push NVTX range onto stack, if NVTX range profiling is enabled """Push NVTX range onto stack, if NVTX range profiling is enabled
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment