Unverified Commit 3bd8335b authored by Wentao Ye's avatar Wentao Ye Committed by GitHub
Browse files

[Refactor] Refactor for `DeepGemmQuantScaleFMT` using cache (#30898)


Signed-off-by: default avataryewentao256 <zhyanwentao@126.com>
parent 1ab52135
...@@ -31,6 +31,7 @@ from vllm.model_executor.utils import replace_parameter ...@@ -31,6 +31,7 @@ from vllm.model_executor.utils import replace_parameter
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import ( from vllm.utils.deep_gemm import (
DeepGemmQuantScaleFMT,
fp8_gemm_nt, fp8_gemm_nt,
is_deep_gemm_e8m0_used, is_deep_gemm_e8m0_used,
is_deep_gemm_supported, is_deep_gemm_supported,
...@@ -247,7 +248,6 @@ class W8A8BlockFp8LinearOp: ...@@ -247,7 +248,6 @@ class W8A8BlockFp8LinearOp:
self.act_quant_group_shape = act_quant_group_shape self.act_quant_group_shape = act_quant_group_shape
self.is_deep_gemm_supported = is_deep_gemm_supported() self.is_deep_gemm_supported = is_deep_gemm_supported()
self.is_hopper = current_platform.is_device_capability(90) self.is_hopper = current_platform.is_device_capability(90)
self.is_blackwell = current_platform.is_device_capability_family(100)
self.use_deep_gemm_e8m0 = is_deep_gemm_e8m0_used() self.use_deep_gemm_e8m0 = is_deep_gemm_e8m0_used()
# Get the correct blockscale mul and input quant operations. # Get the correct blockscale mul and input quant operations.
...@@ -303,7 +303,7 @@ class W8A8BlockFp8LinearOp: ...@@ -303,7 +303,7 @@ class W8A8BlockFp8LinearOp:
weight: torch.Tensor, weight: torch.Tensor,
weight_scale: torch.Tensor, weight_scale: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
if self.use_deep_gemm_e8m0 and self.is_blackwell: if DeepGemmQuantScaleFMT.from_oracle() == DeepGemmQuantScaleFMT.UE8M0:
q_input, input_scale = per_token_group_quant_fp8_packed_for_deepgemm( q_input, input_scale = per_token_group_quant_fp8_packed_for_deepgemm(
input_2d, input_2d,
group_size=self.act_quant_group_shape.col, group_size=self.act_quant_group_shape.col,
......
...@@ -32,16 +32,35 @@ class DeepGemmQuantScaleFMT(Enum): ...@@ -32,16 +32,35 @@ class DeepGemmQuantScaleFMT(Enum):
# element contains 4 scale values. # element contains 4 scale values.
UE8M0 = 2 UE8M0 = 2
@staticmethod @classmethod
def from_oracle() -> "DeepGemmQuantScaleFMT": def init_oracle_cache(cls) -> None:
if not is_deep_gemm_e8m0_used(): """Initialize the oracle decision and store it in the class cache"""
return DeepGemmQuantScaleFMT.FLOAT32 cached = getattr(cls, "_oracle_cache", None)
return ( if cached is not None:
DeepGemmQuantScaleFMT.UE8M0 return
use_e8m0 = (
envs.VLLM_USE_DEEP_GEMM_E8M0
and is_deep_gemm_supported()
and (_fp8_gemm_nt_impl is not None)
)
if not use_e8m0:
cls._oracle_cache = cls.FLOAT32 # type: ignore
return
cls._oracle_cache = ( # type: ignore
cls.UE8M0
if current_platform.is_device_capability_family(100) if current_platform.is_device_capability_family(100)
else DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0 else cls.FLOAT32_CEIL_UE8M0
) )
@classmethod
def from_oracle(cls) -> "DeepGemmQuantScaleFMT":
"""Return the pre-initialized oracle decision"""
cached = getattr(cls, "_oracle_cache", None)
assert cached is not None, "DeepGemmQuantScaleFMT oracle cache not initialized"
return cached
@functools.cache @functools.cache
def is_deep_gemm_supported() -> bool: def is_deep_gemm_supported() -> bool:
...@@ -149,6 +168,7 @@ def _lazy_init() -> None: ...@@ -149,6 +168,7 @@ def _lazy_init() -> None:
_transform_sf_into_required_layout_impl = getattr( _transform_sf_into_required_layout_impl = getattr(
_dg, "transform_sf_into_required_layout", None _dg, "transform_sf_into_required_layout", None
) )
DeepGemmQuantScaleFMT.init_oracle_cache()
def get_num_sms() -> int: def get_num_sms() -> int:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment