Unverified commit e1edaaec, authored by Kirthi Shankar Sivamani and committed by GitHub

[PyTorch] Reduce CPU overheads (#2377)



Initial changes to reduce PyTorch CPU overheads: replace the generic `copy.copy`-based `Quantizer.copy` with explicit per-class `copy()` methods, drop the per-call `get_workspace()` fetch at the userbuffers GEMM call site, let `Float8CurrentScalingQuantizer` accept pre-allocated scale/amax tensors, and add a `get_nvtx_range_context` helper that returns a `nullcontext` when NVTX profiling is disabled.

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 42d22740
@@ -18,7 +18,6 @@ from ...quantization import FP8GlobalStateManager
 from ...module.base import (
     fill_userbuffers_buffer_for_all_gather,
     get_ub,
-    get_workspace,
     _2X_ACC_FPROP,
 )
 from ...quantized_tensor import Quantizer

@@ -243,7 +242,6 @@ class UserbuffersForwardLinear(FusedOperation):
         gemm_output, *_, reduce_scatter_output = general_gemm(
             w,
             x,
-            get_workspace(),
             out_dtype=dtype,
             quantization_params=output_quantizer,
             bias=bias,

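With the explicit `get_workspace()` argument gone from this call site, each GEMM launch saves a Python-level fetch of the workspace buffer. A minimal sketch of the caching idea such a change relies on, using hypothetical names and an assumed workspace size (not the actual Transformer Engine implementation):

```python
import functools

import torch


# Hypothetical sketch: memoize one workspace tensor per (device, size) so
# repeated GEMM launches skip the per-call Python allocation/lookup logic.
@functools.lru_cache(maxsize=None)
def cached_workspace(device_index: int, nbytes: int = 32 * 1024 * 1024) -> torch.Tensor:
    # First call allocates; every later call returns the same cached tensor.
    return torch.empty(nbytes, dtype=torch.uint8, device=f"cuda:{device_index}")
```
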
@@ -7,7 +7,6 @@
 from __future__ import annotations
 from typing import Optional, Tuple, Iterable, Any, Dict, Union
 import abc
-import copy
 import warnings
 import math

@@ -297,10 +296,6 @@ class Quantizer(abc.ABC):
         if columnwise is not None:
             self.columnwise_usage = columnwise
 
-    def copy(self) -> Quantizer:
-        """Create shallow copy"""
-        return copy.copy(self)
-
     def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor:
         """Symbolic function for ONNX export"""
         raise NotImplementedError(

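The base class previously built copies with `copy.copy`, which for a plain object routes through `__reduce_ex__` and `copyreg._reconstruct`; each subclass now constructs its copy directly instead. A self-contained micro-benchmark sketch of the difference (the class is a stand-in and timings are machine-dependent):

```python
import copy
import timeit


class Q:
    """Stand-in with quantizer-like Python state."""

    def __init__(self, rowwise: bool = True, columnwise: bool = True):
        self.rowwise_usage = rowwise
        self.columnwise_usage = columnwise

    def fast_copy(self) -> "Q":
        # Direct constructor call: skips the generic copy machinery.
        return Q(rowwise=self.rowwise_usage, columnwise=self.columnwise_usage)


q = Q()
print("copy.copy:", timeit.timeit(lambda: copy.copy(q), number=100_000))
print("explicit :", timeit.timeit(lambda: q.fast_copy(), number=100_000))
```
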
@@ -57,6 +57,22 @@ class Float8BlockQuantizer(Quantizer):
         self.block_scaling_dim = block_scaling_dim
         self.all_gather_usage = all_gather_usage
 
+    def copy(self) -> Float8BlockQuantizer:
+        """Create shallow copy"""
+        quantizer = Float8BlockQuantizer(
+            fp8_dtype=self.dtype,
+            rowwise=self.rowwise_usage,
+            columnwise=self.columnwise_usage,
+            block_scaling_dim=self.block_scaling_dim,
+            all_gather_usage=self.all_gather_usage,
+            amax_epsilon=self.amax_epsilon,
+            force_pow_2_scales=self.force_pow_2_scales,
+        )
+        quantizer.internal = self.internal
+        return quantizer
+
     def update_quantized(
         self,
         src: torch.Tensor,

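The hand-written `copy()` keeps the old shallow semantics: plain flags are independent after the copy, tensor attributes are shared aliases, and fields not exposed through the constructor (here `internal`) are assigned afterwards. A self-contained toy illustrating that contract (`ToyQuantizer` is a stand-in, not a Transformer Engine class):

```python
import torch


class ToyQuantizer:
    """Stand-in showing the shallow-copy contract of the new copy() methods."""

    def __init__(self, scale: torch.Tensor, rowwise: bool = True):
        self.scale = scale
        self.rowwise_usage = rowwise
        self.internal = False

    def copy(self) -> "ToyQuantizer":
        q = ToyQuantizer(scale=self.scale, rowwise=self.rowwise_usage)
        q.internal = self.internal  # bookkeeping carried over explicitly
        return q


q1 = ToyQuantizer(torch.ones(1))
q2 = q1.copy()
q2.rowwise_usage = False   # flag changes stay local to the copy...
q1.scale.fill_(0.5)        # ...while tensor state remains shared
assert q2.scale.item() == 0.5 and q1.rowwise_usage
```
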
@@ -66,6 +66,20 @@ class Float8Quantizer(Quantizer):
         self.amax = amax
         self.dtype = fp8_dtype
 
+    def copy(self) -> Float8Quantizer:
+        """Create shallow copy"""
+        quantizer = Float8Quantizer(
+            scale=self.scale,
+            amax=self.amax,
+            fp8_dtype=self.dtype,
+            rowwise=self.rowwise_usage,
+            columnwise=self.columnwise_usage,
+        )
+        quantizer.internal = self.internal
+        return quantizer
+
     def update_quantized(
         self,
         src: torch.Tensor,

@@ -245,10 +259,16 @@ class Float8CurrentScalingQuantizer(Quantizer):
         amax_reduction_group: Optional[dist_group_type] = None,
         force_pow_2_scales: bool = False,
         amax_epsilon: float = 0.0,
+        scale: Optional[torch.Tensor] = None,
+        amax: Optional[torch.Tensor] = None,
     ) -> None:
         super().__init__(rowwise=rowwise, columnwise=columnwise)
-        self.scale = torch.empty(1, dtype=torch.float32, device=device)
-        self.amax = torch.empty(1, dtype=torch.float32, device=device)
+        if scale is None:
+            scale = torch.empty(1, dtype=torch.float32, device=device)
+        if amax is None:
+            amax = torch.empty(1, dtype=torch.float32, device=device)
+        self.scale = scale
+        self.amax = amax
         self.dtype = fp8_dtype
         self.use_existing_amax = use_existing_amax
         self.with_amax_reduction = with_amax_reduction

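Accepting pre-allocated `scale`/`amax` lets the `copy()` below pass its existing tensors straight through, so a copy no longer pays for two `torch.empty` device allocations and keeps its state aliased to the original, just as `copy.copy` did. A standalone restatement of the pattern (`init_scaling_state` is an illustrative helper, not TE API):

```python
from typing import Optional, Tuple

import torch


def init_scaling_state(
    scale: Optional[torch.Tensor] = None,
    amax: Optional[torch.Tensor] = None,
    device: str = "cuda",
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Allocate only when the caller did not supply a tensor; copies pass
    # their existing scale/amax through and skip both allocations.
    if scale is None:
        scale = torch.empty(1, dtype=torch.float32, device=device)
    if amax is None:
        amax = torch.empty(1, dtype=torch.float32, device=device)
    return scale, amax
```
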
@@ -256,6 +276,26 @@ class Float8CurrentScalingQuantizer(Quantizer):
         self.force_pow_2_scales = force_pow_2_scales
         self.amax_epsilon = amax_epsilon
 
+    def copy(self) -> Float8CurrentScalingQuantizer:
+        """Create shallow copy"""
+        quantizer = Float8CurrentScalingQuantizer(
+            fp8_dtype=self.dtype,
+            device=0,
+            rowwise=self.rowwise_usage,
+            columnwise=self.columnwise_usage,
+            with_amax_reduction=self.with_amax_reduction,
+            amax_reduction_group=self.amax_reduction_group,
+            use_existing_amax=self.use_existing_amax,
+            force_pow_2_scales=self.force_pow_2_scales,
+            amax_epsilon=self.amax_epsilon,
+            scale=self.scale,
+            amax=self.amax,
+        )
+        quantizer.internal = self.internal
+        return quantizer
+
     def update_quantized(
         self,
         src: torch.Tensor,

@@ -45,6 +45,18 @@ class MXFP8Quantizer(Quantizer):
         super().__init__(rowwise=rowwise, columnwise=columnwise)
         self.dtype = fp8_dtype
 
+    def copy(self) -> MXFP8Quantizer:
+        """Create shallow copy"""
+        quantizer = MXFP8Quantizer(
+            fp8_dtype=self.dtype,
+            rowwise=self.rowwise_usage,
+            columnwise=self.columnwise_usage,
+        )
+        quantizer.internal = self.internal
+        return quantizer
+
     def update_quantized(
         self,
         src: torch.Tensor,

@@ -176,6 +176,26 @@ class NVFP4Quantizer(Quantizer):
         return dst
 
+    def copy(self) -> NVFP4Quantizer:
+        """Create shallow copy"""
+        quantizer = NVFP4Quantizer(
+            fp4_dtype=self.dtype,
+            rowwise=self.rowwise_usage,
+            columnwise=self.columnwise_usage,
+            with_amax_reduction=self.with_amax_reduction,
+            amax_reduction_group=self.amax_reduction_group,
+            with_rht=self.with_rht,
+            with_post_rht_amax=self.with_post_rht_amax,
+            with_2d_quantization=self.with_2d_quantization,
+            stochastic_rounding=self.stochastic_rounding,
+        )
+        quantizer.internal = self.internal
+        quantizer.rht_matrix = self.rht_matrix
+        quantizer.rht_matrix_random_sign_mask_t = self.rht_matrix_random_sign_mask_t
+        return quantizer
+
     def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor:
         """Quantize tensor implementation"""
         return tex.quantize(tensor, self)

@@ -8,6 +8,7 @@ import functools
 import math
 import os
 from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
+from contextlib import nullcontext
 import numpy as np
 import torch

@@ -592,6 +593,24 @@ def _nvtx_enabled() -> bool:
 _nvtx_range_messages: list[str] = []
 
+def get_nvtx_range_context(msg: str):
+    """Get NVTX context manager to tag module forward and backward passes.
+
+    Set `NVTE_NVTX_ENABLED=1` in the environment to enable NVTX
+    context manager for module level profiling tags.
+
+    Parameters
+    ----------
+    msg: str
+        Message to associate with profiling context.
+    """
+    if _nvtx_enabled():
+        return torch.cuda.nvtx.range(msg)
+    return nullcontext()
+
 def nvtx_range_push(msg: str) -> None:
     """Push NVTX range onto stack, if NVTX range profiling is enabled

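`get_nvtx_range_context` returns a real NVTX range when `NVTE_NVTX_ENABLED=1` and a near-free `nullcontext()` otherwise, so call sites need a single unconditional `with` statement instead of push/pop pairs guarded by enablement checks. A small usage sketch (the import path is assumed from the diff's utils module; the module and message names are illustrative):

```python
import torch

# Import path assumed; the helper is added to the PyTorch utils module above.
from transformer_engine.pytorch.utils import get_nvtx_range_context


def tagged_forward(module: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    # A real NVTX range when profiling is enabled, otherwise a no-op context.
    with get_nvtx_range_context("MyModule.forward"):
        return module(x)
```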