Unverified Commit ad8818bb authored by Lucas Kabela's avatar Lucas Kabela Committed by GitHub
Browse files

[Misc][BE] Type coverage for vllm/compilation [3/3] (#31748)


Signed-off-by: default avatarLucas Kabela <lucaskabela@meta.com>
parent 08e8e99c
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any
import torch import torch
from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._higher_order_ops.auto_functionalize import auto_functionalized
...@@ -52,7 +53,7 @@ class ActivationQuantPattern(ABC): ...@@ -52,7 +53,7 @@ class ActivationQuantPattern(ABC):
def __init__( def __init__(
self, self,
quant_key: QuantKey, quant_key: QuantKey,
): ) -> None:
self.quant_key = quant_key self.quant_key = quant_key
self.quant_dtype = quant_key.dtype self.quant_dtype = quant_key.dtype
...@@ -68,12 +69,12 @@ class ActivationQuantPattern(ABC): ...@@ -68,12 +69,12 @@ class ActivationQuantPattern(ABC):
self.silu_and_mul_matcher = MatcherSiluAndMul() self.silu_and_mul_matcher = MatcherSiluAndMul()
def empty_quant(self, *args, **kwargs): def empty_quant(self, *args: Any, **kwargs: Any) -> torch.Tensor:
kwargs = {"dtype": self.quant_dtype, "device": "cuda", **kwargs} kwargs = {"dtype": self.quant_dtype, "device": "cuda", **kwargs}
return torch.empty(*args, **kwargs) return torch.empty(*args, **kwargs)
@abstractmethod @abstractmethod
def register(self, pm_pass: PatternMatcherPass): def register(self, pm_pass: PatternMatcherPass) -> None:
raise NotImplementedError raise NotImplementedError
...@@ -82,15 +83,22 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern): ...@@ -82,15 +83,22 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
Fusion for SiluMul+Fp8StaticQuant Pattern Fusion for SiluMul+Fp8StaticQuant Pattern
""" """
def __init__(self): def __init__(self) -> None:
super().__init__(kFp8StaticTensorSym) super().__init__(kFp8StaticTensorSym)
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym) self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
def register(self, pm_pass: PatternMatcherPass): def get_inputs(self) -> list[torch.Tensor]:
scale = self.quant_matcher.inputs()[1]
return [
*self.silu_and_mul_matcher.inputs(), # input
scale,
]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern( def pattern(
input: torch.Tensor, input: torch.Tensor,
scale: torch.Tensor, scale: torch.Tensor,
): ) -> torch.Tensor:
result_silu_mul = self.silu_and_mul_matcher(input) result_silu_mul = self.silu_and_mul_matcher(input)
result_quant = self.quant_matcher(result_silu_mul, scale) result_quant = self.quant_matcher(result_silu_mul, scale)
return result_quant[0] return result_quant[0]
...@@ -98,7 +106,7 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern): ...@@ -98,7 +106,7 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
def replacement( def replacement(
input: torch.Tensor, input: torch.Tensor,
scale: torch.Tensor, scale: torch.Tensor,
): ) -> torch.Tensor:
d = input.shape[-1] // 2 d = input.shape[-1] // 2
output_shape = input.shape[:-1] + (d,) output_shape = input.shape[:-1] + (d,)
result = torch.empty( result = torch.empty(
...@@ -109,13 +117,10 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern): ...@@ -109,13 +117,10 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
) )
return at[1] return at[1]
inputs = [ inps = self.get_inputs()
*self.silu_and_mul_matcher.inputs(), # input pattern(*inps)
self.quant_matcher.inputs()[1], # scale
]
pattern(*inputs)
register_replacement(pattern, replacement, inputs, fwd_only, pm_pass) register_replacement(pattern, replacement, inps, fwd_only, pm_pass)
class SiluMulNvfp4QuantPattern(ActivationQuantPattern): class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
...@@ -123,16 +128,23 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern): ...@@ -123,16 +128,23 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
Fusion for SiluMul+Nvfp4Quant Pattern Fusion for SiluMul+Nvfp4Quant Pattern
""" """
def __init__(self): def __init__(self) -> None:
super().__init__(kNvfp4Quant) super().__init__(kNvfp4Quant)
def register(self, pm_pass: PatternMatcherPass): def get_inputs(self) -> list[torch.Tensor]:
result = self.empty_quant(5, 32)
output_scale = empty_i32(128, 4)
input_ = empty_bf16(5, 64)
scale = empty_fp32(1, 1)
return [result, output_scale, input_, scale]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern( def pattern(
result: torch.Tensor, result: torch.Tensor,
output_scale: torch.Tensor, output_scale: torch.Tensor,
input: torch.Tensor, input: torch.Tensor,
scale: torch.Tensor, scale: torch.Tensor,
): ) -> tuple[torch.Tensor, torch.Tensor]:
result_silu_mul = self.silu_and_mul_matcher(input) result_silu_mul = self.silu_and_mul_matcher(input)
at = auto_functionalized( at = auto_functionalized(
self.QUANT_OP, self.QUANT_OP,
...@@ -148,7 +160,7 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern): ...@@ -148,7 +160,7 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
output_scale: torch.Tensor, output_scale: torch.Tensor,
input: torch.Tensor, input: torch.Tensor,
scale: torch.Tensor, scale: torch.Tensor,
): ) -> tuple[torch.Tensor, torch.Tensor]:
at = auto_functionalized( at = auto_functionalized(
self.FUSED_OP, self.FUSED_OP,
result=result, result=result,
...@@ -158,14 +170,7 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern): ...@@ -158,14 +170,7 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
) )
return at[1], at[2] return at[1], at[2]
inputs = [ register_replacement(pattern, replacement, self.get_inputs(), fwd_only, pm_pass)
self.empty_quant(5, 32), # result
empty_i32(128, 4), # output_scale
empty_bf16(5, 64), # input
empty_fp32(1, 1), # scale
]
register_replacement(pattern, replacement, inputs, fwd_only, pm_pass)
class ActivationQuantFusionPass(VllmPatternMatcherPass): class ActivationQuantFusionPass(VllmPatternMatcherPass):
...@@ -179,7 +184,7 @@ class ActivationQuantFusionPass(VllmPatternMatcherPass): ...@@ -179,7 +184,7 @@ class ActivationQuantFusionPass(VllmPatternMatcherPass):
""" """
@enable_fake_mode @enable_fake_mode
def __init__(self, config: VllmConfig): def __init__(self, config: VllmConfig) -> None:
super().__init__(config) super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass( self.patterns: PatternMatcherPass = PatternMatcherPass(
...@@ -196,11 +201,11 @@ class ActivationQuantFusionPass(VllmPatternMatcherPass): ...@@ -196,11 +201,11 @@ class ActivationQuantFusionPass(VllmPatternMatcherPass):
self.dump_patterns(config, self.patterns) self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log @VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.Graph): def __call__(self, graph: torch.fx.Graph) -> None:
self.matched_count = self.patterns.apply(graph) self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count) logger.debug("Replaced %s patterns", self.matched_count)
def uuid(self): def uuid(self) -> str:
return VllmInductorPass.hash_source( return VllmInductorPass.hash_source(
self, self,
ActivationQuantPattern, ActivationQuantPattern,
......
This diff is collapsed.
...@@ -38,19 +38,19 @@ FP8_DTYPE = current_platform.fp8_dtype() ...@@ -38,19 +38,19 @@ FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8 FP4_DTYPE = torch.uint8
def empty_bf16(*args, **kwargs): def empty_bf16(*args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda") return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda")
def empty_fp32(*args, **kwargs): def empty_fp32(*args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda") return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda")
def empty_i32(*args, **kwargs): def empty_i32(*args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, **kwargs, dtype=torch.int32, device="cuda") return torch.empty(*args, **kwargs, dtype=torch.int32, device="cuda")
def empty_i64(*args, **kwargs): def empty_i64(*args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, **kwargs, dtype=torch.int64, device="cuda") return torch.empty(*args, **kwargs, dtype=torch.int64, device="cuda")
...@@ -79,7 +79,7 @@ class FusedRMSQuantKey(NamedTuple): ...@@ -79,7 +79,7 @@ class FusedRMSQuantKey(NamedTuple):
quant: QuantKey quant: QuantKey
fused_add: bool fused_add: bool
def __str__(self): def __str__(self) -> str:
return ( return (
f"FusedQuantKey({self.quant}, with" f"FusedQuantKey({self.quant}, with"
f"{'' if self.fused_add else 'out'} residual)" f"{'' if self.fused_add else 'out'} residual)"
...@@ -121,7 +121,7 @@ class RMSNormQuantPattern: ...@@ -121,7 +121,7 @@ class RMSNormQuantPattern:
key: FusedRMSQuantKey, key: FusedRMSQuantKey,
has_col_major_scales: bool = False, has_col_major_scales: bool = False,
is_e8m0: bool = False, is_e8m0: bool = False,
): ) -> None:
self.epsilon = epsilon self.epsilon = epsilon
self.quant_dtype = key.quant.dtype self.quant_dtype = key.quant.dtype
config = get_current_vllm_config() config = get_current_vllm_config()
...@@ -141,7 +141,9 @@ class RMSNormQuantPattern: ...@@ -141,7 +141,9 @@ class RMSNormQuantPattern:
class RMSNormStaticQuantPattern(RMSNormQuantPattern): class RMSNormStaticQuantPattern(RMSNormQuantPattern):
def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True): def __init__(
self, epsilon: float, quant_dtype: torch.dtype, symmetric: bool = True
) -> None:
fused_key = FusedRMSQuantKey( fused_key = FusedRMSQuantKey(
fused_add=False, fused_add=False,
quant=QuantKey( quant=QuantKey(
...@@ -150,13 +152,17 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern): ...@@ -150,13 +152,17 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
) )
super().__init__(epsilon, fused_key) super().__init__(epsilon, fused_key)
def register(self, pm_pass: PatternMatcherPass): def register(self, pm_pass: PatternMatcherPass) -> None:
# Cannot use methods, as the self argument affects tracing # Cannot use methods, as the self argument affects tracing
def pattern(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): def pattern(
input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
) -> torch.Tensor:
result_rms = self.rmsnorm_matcher(input, weight) result_rms = self.rmsnorm_matcher(input, weight)
return self.quant_matcher(result_rms, scale)[0] return self.quant_matcher(result_rms, scale)[0]
def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): def replacement(
input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
) -> torch.Tensor:
# In case we're matching native rms-norm, conversions might be # In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe. # optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype) input = input.to(dtype=self.model_dtype)
...@@ -187,7 +193,9 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern): ...@@ -187,7 +193,9 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern): class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True): def __init__(
self, epsilon: float, quant_dtype: torch.dtype, symmetric: bool = True
) -> None:
key = FusedRMSQuantKey( key = FusedRMSQuantKey(
fused_add=True, fused_add=True,
quant=QuantKey( quant=QuantKey(
...@@ -196,13 +204,13 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern): ...@@ -196,13 +204,13 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
) )
super().__init__(epsilon, key) super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass): def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern( def pattern(
input: torch.Tensor, input: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
residual: torch.Tensor, residual: torch.Tensor,
scale: torch.Tensor, scale: torch.Tensor,
): ) -> tuple[torch.Tensor, torch.Tensor]:
result_rms, residual = self.rmsnorm_matcher(input, weight, residual) result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
result, _ = self.quant_matcher(result_rms, scale) result, _ = self.quant_matcher(result_rms, scale)
...@@ -213,7 +221,7 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern): ...@@ -213,7 +221,7 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
weight: torch.Tensor, weight: torch.Tensor,
residual: torch.Tensor, residual: torch.Tensor,
scale: torch.Tensor, scale: torch.Tensor,
): ) -> tuple[torch.Tensor, torch.Tensor]:
# In case we're matching native rms-norm, conversions might be # In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe. # optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype) input = input.to(dtype=self.model_dtype)
...@@ -253,10 +261,10 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern): ...@@ -253,10 +261,10 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
epsilon: float, epsilon: float,
quant_dtype: torch.dtype, quant_dtype: torch.dtype,
group_shape: GroupShape, group_shape: GroupShape,
symmetric=True, symmetric: bool = True,
has_col_major_scales: bool = False, has_col_major_scales: bool = False,
is_e8m0: bool = False, is_e8m0: bool = False,
): ) -> None:
scale = ScaleDesc(torch.float32, False, group_shape) scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey( key = FusedRMSQuantKey(
fused_add=True, fused_add=True,
...@@ -269,15 +277,17 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern): ...@@ -269,15 +277,17 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
epsilon, key, has_col_major_scales=has_col_major_scales, is_e8m0=is_e8m0 epsilon, key, has_col_major_scales=has_col_major_scales, is_e8m0=is_e8m0
) )
def register(self, pm_pass: PatternMatcherPass): def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor): def pattern(
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
result_rms, residual = self.rmsnorm_matcher(input, weight, residual) result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
result, scale = self.quant_matcher(result_rms) result, scale = self.quant_matcher(result_rms)
return result, residual, scale return result, residual, scale
def replacement( def replacement(
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
): ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# In case we're matching native rms-norm, conversions might be # In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe. # optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype) input = input.to(dtype=self.model_dtype)
...@@ -315,10 +325,10 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern): ...@@ -315,10 +325,10 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
epsilon: float, epsilon: float,
quant_dtype: torch.dtype, quant_dtype: torch.dtype,
group_shape: GroupShape, group_shape: GroupShape,
symmetric=True, symmetric: bool = True,
has_col_major_scales: bool = False, has_col_major_scales: bool = False,
is_e8m0: bool = False, is_e8m0: bool = False,
): ) -> None:
scale = ScaleDesc(torch.float32, False, group_shape) scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey( key = FusedRMSQuantKey(
fused_add=False, fused_add=False,
...@@ -329,13 +339,17 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern): ...@@ -329,13 +339,17 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
epsilon, key, has_col_major_scales=has_col_major_scales, is_e8m0=is_e8m0 epsilon, key, has_col_major_scales=has_col_major_scales, is_e8m0=is_e8m0
) )
def register(self, pm_pass: PatternMatcherPass): def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(input: torch.Tensor, weight: torch.Tensor): def pattern(
input: torch.Tensor, weight: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
result_rms = self.rmsnorm_matcher(input, weight) result_rms = self.rmsnorm_matcher(input, weight)
result, scale = self.quant_matcher(result_rms) result, scale = self.quant_matcher(result_rms)
return result, scale return result, scale
def replacement(input: torch.Tensor, weight: torch.Tensor): def replacement(
input: torch.Tensor, weight: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
# In case we're matching native rms-norm, conversions might be # In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe. # optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype) input = input.to(dtype=self.model_dtype)
...@@ -375,8 +389,8 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern): ...@@ -375,8 +389,8 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
epsilon: float, epsilon: float,
quant_dtype: torch.dtype, quant_dtype: torch.dtype,
group_shape: GroupShape = GroupShape.PER_TOKEN, group_shape: GroupShape = GroupShape.PER_TOKEN,
symmetric=True, symmetric: bool = True,
): ) -> None:
scale = ScaleDesc(torch.float32, False, group_shape) scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey( key = FusedRMSQuantKey(
fused_add=False, fused_add=False,
...@@ -384,13 +398,17 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern): ...@@ -384,13 +398,17 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
) )
super().__init__(epsilon, key) super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass): def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(input: torch.Tensor, weight: torch.Tensor): def pattern(
input: torch.Tensor, weight: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
result_rms = self.rmsnorm_matcher(input, weight) result_rms = self.rmsnorm_matcher(input, weight)
# result, scale # result, scale
return self.quant_matcher(result_rms) return self.quant_matcher(result_rms) # type: ignore[no-any-return]
def replacement(input: torch.Tensor, weight: torch.Tensor): def replacement(
input: torch.Tensor, weight: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
# In case we're matching native rms-norm, conversions might be # In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe. # optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype) input = input.to(dtype=self.model_dtype)
...@@ -426,8 +444,8 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern): ...@@ -426,8 +444,8 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
epsilon: float, epsilon: float,
quant_dtype: torch.dtype, quant_dtype: torch.dtype,
group_shape: GroupShape = GroupShape.PER_TOKEN, group_shape: GroupShape = GroupShape.PER_TOKEN,
symmetric=True, symmetric: bool = True,
): ) -> None:
scale = ScaleDesc(torch.float32, False, group_shape) scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey( key = FusedRMSQuantKey(
fused_add=True, fused_add=True,
...@@ -435,8 +453,10 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern): ...@@ -435,8 +453,10 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
) )
super().__init__(epsilon, key) super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass): def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor): def pattern(
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
result_rms, residual = self.rmsnorm_matcher(input, weight, residual) result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
result, scale = self.quant_matcher(result_rms) result, scale = self.quant_matcher(result_rms)
...@@ -444,7 +464,7 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern): ...@@ -444,7 +464,7 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
def replacement( def replacement(
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
): ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# In case we're matching native rms-norm, conversions might be # In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe. # optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype) input = input.to(dtype=self.model_dtype)
...@@ -481,7 +501,7 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass): ...@@ -481,7 +501,7 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
""" """
@enable_fake_mode @enable_fake_mode
def __init__(self, config: VllmConfig): def __init__(self, config: VllmConfig) -> None:
super().__init__(config) super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass( self.patterns: PatternMatcherPass = PatternMatcherPass(
...@@ -533,11 +553,11 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass): ...@@ -533,11 +553,11 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
self.dump_patterns(config, self.patterns) self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log @VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph): def __call__(self, graph: fx.Graph) -> None:
self.matched_count = self.patterns.apply(graph) self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count) logger.debug("Replaced %s patterns", self.matched_count)
def uuid(self) -> Any: def uuid(self) -> str:
return self.hash_source( return self.hash_source(
self, self,
RMSNormGroupQuantPattern, RMSNormGroupQuantPattern,
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Callable from collections.abc import Callable
from typing import Any, ParamSpec
import torch import torch
import torch._inductor.pattern_matcher as pm import torch._inductor.pattern_matcher as pm
...@@ -28,7 +29,7 @@ from .matcher_utils import MatcherQuantFP8 ...@@ -28,7 +29,7 @@ from .matcher_utils import MatcherQuantFP8
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
logger = init_logger(__name__) logger = init_logger(__name__)
P = ParamSpec("P")
FP8_DTYPE = current_platform.fp8_dtype() FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8 FP4_DTYPE = torch.uint8
...@@ -47,7 +48,7 @@ class AttentionQuantPattern(ABC): ...@@ -47,7 +48,7 @@ class AttentionQuantPattern(ABC):
layer: Attention, layer: Attention,
quant_key: QuantKey, quant_key: QuantKey,
dtype: torch.dtype, dtype: torch.dtype,
): ) -> None:
self.layer = layer self.layer = layer
self.layer_name = layer.layer_name self.layer_name = layer.layer_name
self.num_heads = layer.num_heads self.num_heads = layer.num_heads
...@@ -61,17 +62,20 @@ class AttentionQuantPattern(ABC): ...@@ -61,17 +62,20 @@ class AttentionQuantPattern(ABC):
) )
self.QUANT_OP = QUANT_OPS[self.quant_key] self.QUANT_OP = QUANT_OPS[self.quant_key]
def empty(self, *args, **kwargs): def empty(self, *args: Any, **kwargs: Any) -> torch.Tensor:
kwargs = {"dtype": self.dtype, "device": "cuda", **kwargs} kwargs = {"dtype": self.dtype, "device": "cuda", **kwargs}
return torch.empty(*args, **kwargs) return torch.empty(*args, **kwargs)
def empty_quant(self, *args, **kwargs): def empty_quant(self, *args: Any, **kwargs: Any) -> torch.Tensor:
kwargs = {"dtype": self.quant_dtype, "device": "cuda", **kwargs} kwargs = {"dtype": self.quant_dtype, "device": "cuda", **kwargs}
return torch.empty(*args, **kwargs) return torch.empty(*args, **kwargs)
@staticmethod @staticmethod
def wrap_trace_fn(trace_fn, *process_fx_fns: Callable[[fx.GraphModule], None]): def wrap_trace_fn(
def wrapped(*args, **kwargs): trace_fn: Callable[P, fx.GraphModule],
*process_fx_fns: Callable[[fx.GraphModule], None],
) -> Callable[P, fx.GraphModule]:
def wrapped(*args: P.args, **kwargs: P.kwargs) -> fx.GraphModule:
gm = trace_fn(*args, **kwargs) gm = trace_fn(*args, **kwargs)
for process_fx in process_fx_fns: for process_fx in process_fx_fns:
process_fx(gm) process_fx(gm)
...@@ -81,13 +85,13 @@ class AttentionQuantPattern(ABC): ...@@ -81,13 +85,13 @@ class AttentionQuantPattern(ABC):
return wrapped return wrapped
@staticmethod @staticmethod
def fx_view_to_reshape(gm: torch.fx.GraphModule): def fx_view_to_reshape(gm: torch.fx.GraphModule) -> None:
from torch._inductor.fx_passes.post_grad import view_to_reshape from torch._inductor.fx_passes.post_grad import view_to_reshape
view_to_reshape(gm) view_to_reshape(gm)
@staticmethod @staticmethod
def remove_noop_permutes(gm: torch.fx.GraphModule): def remove_noop_permutes(gm: torch.fx.GraphModule) -> None:
for node in gm.graph.nodes: for node in gm.graph.nodes:
if not is_func(node, torch.ops.aten.permute.default): if not is_func(node, torch.ops.aten.permute.default):
continue continue
...@@ -100,12 +104,12 @@ class AttentionQuantPattern(ABC): ...@@ -100,12 +104,12 @@ class AttentionQuantPattern(ABC):
node.replace_all_uses_with(node.args[0]) node.replace_all_uses_with(node.args[0])
gm.graph.erase_node(node) gm.graph.erase_node(node)
def register_if_supported(self, pm_pass: PatternMatcherPass): def register_if_supported(self, pm_pass: PatternMatcherPass) -> None:
if self.layer.impl.fused_output_quant_supported(self.quant_key): if self.layer.impl.fused_output_quant_supported(self.quant_key):
self._register(pm_pass) self._register(pm_pass)
@abstractmethod @abstractmethod
def _register(self, pm_pass: PatternMatcherPass): def _register(self, pm_pass: PatternMatcherPass) -> None:
raise NotImplementedError raise NotImplementedError
...@@ -124,21 +128,21 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern): ...@@ -124,21 +128,21 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
layer: Attention, layer: Attention,
dtype: torch.dtype, dtype: torch.dtype,
symmetric: bool = True, symmetric: bool = True,
): ) -> None:
quant_key = QuantKey( quant_key = QuantKey(
dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=symmetric dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=symmetric
) )
super().__init__(layer, quant_key, dtype) super().__init__(layer, quant_key, dtype)
self.quant_matcher = MatcherQuantFP8(quant_key) self.quant_matcher = MatcherQuantFP8(quant_key)
def _register(self, pm_pass: PatternMatcherPass): def _register(self, pm_pass: PatternMatcherPass) -> None:
def pattern( def pattern(
q: torch.Tensor, q: torch.Tensor,
k: torch.Tensor, k: torch.Tensor,
v: torch.Tensor, v: torch.Tensor,
output_attn: torch.Tensor, output_attn: torch.Tensor,
scale: torch.Tensor, scale: torch.Tensor,
): ) -> torch.Tensor:
at1 = auto_functionalized( at1 = auto_functionalized(
ATTN_OP, ATTN_OP,
query=q, query=q,
...@@ -161,7 +165,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern): ...@@ -161,7 +165,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
v: torch.Tensor, v: torch.Tensor,
output_attn: torch.Tensor, output_attn: torch.Tensor,
scale: torch.Tensor, scale: torch.Tensor,
): ) -> torch.Tensor:
# attn output in quant_dtype # attn output in quant_dtype
output_attn = torch.ops.aten.full.default( output_attn = torch.ops.aten.full.default(
[q.shape[0], self.num_heads, self.head_size], [q.shape[0], self.num_heads, self.head_size],
...@@ -212,10 +216,10 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern): ...@@ -212,10 +216,10 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
will be passed into Attention op as the `output_scale` argument. will be passed into Attention op as the `output_scale` argument.
""" """
def __init__(self, layer: Attention, dtype: torch.dtype): def __init__(self, layer: Attention, dtype: torch.dtype) -> None:
super().__init__(layer, kNvfp4Quant, dtype) super().__init__(layer, kNvfp4Quant, dtype)
def _register(self, pm_pass: PatternMatcherPass): def _register(self, pm_pass: PatternMatcherPass) -> None:
def pattern( def pattern(
q: torch.Tensor, q: torch.Tensor,
k: torch.Tensor, k: torch.Tensor,
...@@ -224,7 +228,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern): ...@@ -224,7 +228,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
output_quant: torch.Tensor, output_quant: torch.Tensor,
output_scale: torch.Tensor, output_scale: torch.Tensor,
input_scale: torch.Tensor, input_scale: torch.Tensor,
): ) -> tuple[torch.Tensor, torch.Tensor]:
at1 = auto_functionalized( at1 = auto_functionalized(
ATTN_OP, ATTN_OP,
query=q, query=q,
...@@ -256,7 +260,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern): ...@@ -256,7 +260,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
output_quant: torch.Tensor, output_quant: torch.Tensor,
output_scale: torch.Tensor, output_scale: torch.Tensor,
input_scale: torch.Tensor, input_scale: torch.Tensor,
): ) -> tuple[torch.Tensor, torch.Tensor]:
# attention output in quant_dtype # attention output in quant_dtype
output_attn = torch.ops.aten.full.default( output_attn = torch.ops.aten.full.default(
[q.shape[0], self.num_heads, self.head_size // 2], [q.shape[0], self.num_heads, self.head_size // 2],
...@@ -318,7 +322,7 @@ class AttnFusionPass(VllmPatternMatcherPass): ...@@ -318,7 +322,7 @@ class AttnFusionPass(VllmPatternMatcherPass):
""" """
@enable_fake_mode @enable_fake_mode
def __init__(self, config: VllmConfig): def __init__(self, config: VllmConfig) -> None:
super().__init__(config) super().__init__(config)
self.patterns = PatternMatcherPass(pass_name="attn_fusion_pass") self.patterns = PatternMatcherPass(pass_name="attn_fusion_pass")
...@@ -350,7 +354,7 @@ class AttnFusionPass(VllmPatternMatcherPass): ...@@ -350,7 +354,7 @@ class AttnFusionPass(VllmPatternMatcherPass):
self.matched_count = self.patterns.apply(graph) self.matched_count = self.patterns.apply(graph)
logger.debug("Fused quant onto %s attention nodes", self.matched_count) logger.debug("Fused quant onto %s attention nodes", self.matched_count)
def uuid(self): def uuid(self) -> str:
return VllmInductorPass.hash_source( return VllmInductorPass.hash_source(
self, self,
AttentionQuantPattern, AttentionQuantPattern,
......
...@@ -68,7 +68,7 @@ class InductorPass(CustomGraphPass): # type: ignore[misc] ...@@ -68,7 +68,7 @@ class InductorPass(CustomGraphPass): # type: ignore[misc]
This is defined as a convenience and should work in most cases. This is defined as a convenience and should work in most cases.
""" """
def uuid(self) -> Any: def uuid(self) -> str:
""" """
Provide a unique identifier for the pass, used in Inductor code cache. Provide a unique identifier for the pass, used in Inductor code cache.
This should depend on the pass implementation, so that changes to the This should depend on the pass implementation, so that changes to the
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any
import torch import torch
from torch._higher_order_ops import auto_functionalized from torch._higher_order_ops import auto_functionalized
...@@ -47,7 +48,7 @@ SILU_MUL_OP = torch.ops._C.silu_and_mul.default ...@@ -47,7 +48,7 @@ SILU_MUL_OP = torch.ops._C.silu_and_mul.default
class MatcherCustomOp(ABC): class MatcherCustomOp(ABC):
def __init__(self, enabled: bool): def __init__(self, enabled: bool) -> None:
config = get_current_vllm_config() config = get_current_vllm_config()
self.model_dtype = config.model_config.dtype if config.model_config else None self.model_dtype = config.model_config.dtype if config.model_config else None
self.device = config.device_config.device if config.device_config else None self.device = config.device_config.device if config.device_config else None
...@@ -56,24 +57,24 @@ class MatcherCustomOp(ABC): ...@@ -56,24 +57,24 @@ class MatcherCustomOp(ABC):
self.forward = self.forward_custom if enabled else self.forward_native self.forward = self.forward_custom if enabled else self.forward_native
@abstractmethod @abstractmethod
def forward_custom(self, *args, **kws): def forward_custom(self, *args: Any, **kwargs: Any) -> Any:
pass pass
@abstractmethod @abstractmethod
def forward_native(self, *args, **kws): def forward_native(self, *args: Any, **kwargs: Any) -> Any:
pass pass
def __call__(self, *args, **kws): def __call__(self, *args: Any, **kwargs: Any) -> Any:
return self.forward(*args, **kws) return self.forward(*args, **kwargs)
def empty(self, *args, **kws): def empty(self, *args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, dtype=self.model_dtype, device=self.device, **kws) return torch.empty(*args, dtype=self.model_dtype, device=self.device, **kwargs)
def empty_int64(self, *args, **kws): def empty_int64(self, *args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, dtype=torch.int64, device=self.device, **kws) return torch.empty(*args, dtype=torch.int64, device=self.device, **kwargs)
def empty_f32(self, *args, **kws): def empty_f32(self, *args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, dtype=torch.float32, device=self.device, **kws) return torch.empty(*args, dtype=torch.float32, device=self.device, **kwargs)
def inputs(self) -> list[torch.Tensor]: def inputs(self) -> list[torch.Tensor]:
"""Utility for inputs to the pattern""" """Utility for inputs to the pattern"""
...@@ -157,7 +158,7 @@ class MatcherRMSNorm(MatcherCustomOp): ...@@ -157,7 +158,7 @@ class MatcherRMSNorm(MatcherCustomOp):
epsilon: float, epsilon: float,
enabled: bool | None = None, enabled: bool | None = None,
match_rocm_aiter: bool = False, match_rocm_aiter: bool = False,
): ) -> None:
if enabled is None: if enabled is None:
enabled = RMSNorm.enabled() enabled = RMSNorm.enabled()
...@@ -169,7 +170,7 @@ class MatcherRMSNorm(MatcherCustomOp): ...@@ -169,7 +170,7 @@ class MatcherRMSNorm(MatcherCustomOp):
if match_rocm_aiter: if match_rocm_aiter:
self._rmsnorm_op = rocm_aiter_ops.get_rmsnorm_op() self._rmsnorm_op = rocm_aiter_ops.get_rmsnorm_op()
def inputs(self): def inputs(self) -> list[torch.Tensor]:
input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16) input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
weight = self.empty(16) weight = self.empty(16)
return [input, weight] return [input, weight]
...@@ -220,7 +221,7 @@ class MatcherFusedAddRMSNorm(MatcherCustomOp): ...@@ -220,7 +221,7 @@ class MatcherFusedAddRMSNorm(MatcherCustomOp):
epsilon: float, epsilon: float,
enabled: bool | None = None, enabled: bool | None = None,
match_rocm_aiter: bool = False, match_rocm_aiter: bool = False,
): ) -> None:
if enabled is None: if enabled is None:
enabled = RMSNorm.enabled() enabled = RMSNorm.enabled()
...@@ -233,7 +234,7 @@ class MatcherFusedAddRMSNorm(MatcherCustomOp): ...@@ -233,7 +234,7 @@ class MatcherFusedAddRMSNorm(MatcherCustomOp):
if match_rocm_aiter: if match_rocm_aiter:
self._rmsnorm_op = rocm_aiter_ops.get_rmsnorm_fused_add_op() self._rmsnorm_op = rocm_aiter_ops.get_rmsnorm_fused_add_op()
def inputs(self): def inputs(self) -> list[torch.Tensor]:
input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16) input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
weight = self.empty(16) weight = self.empty(16)
residual = self.empty(5, 16) residual = self.empty(5, 16)
...@@ -245,7 +246,7 @@ class MatcherFusedAddRMSNorm(MatcherCustomOp): ...@@ -245,7 +246,7 @@ class MatcherFusedAddRMSNorm(MatcherCustomOp):
weight: torch.Tensor, weight: torch.Tensor,
residual: torch.Tensor, residual: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
return self._rmsnorm_op( return self._rmsnorm_op( # type: ignore[no-any-return]
x=input, residual=residual, weight=weight, variance_epsilon=self.epsilon x=input, residual=residual, weight=weight, variance_epsilon=self.epsilon
) )
...@@ -287,7 +288,7 @@ class MatcherQuantFP8(MatcherCustomOp): ...@@ -287,7 +288,7 @@ class MatcherQuantFP8(MatcherCustomOp):
has_col_major_scales: bool = False, has_col_major_scales: bool = False,
is_e8m0: bool = False, is_e8m0: bool = False,
match_rocm_aiter: bool = False, match_rocm_aiter: bool = False,
): ) -> None:
if enabled is None: if enabled is None:
enabled = QuantFP8.enabled() enabled = QuantFP8.enabled()
...@@ -340,13 +341,13 @@ class MatcherQuantFP8(MatcherCustomOp): ...@@ -340,13 +341,13 @@ class MatcherQuantFP8(MatcherCustomOp):
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
quant_key_group_shape = self.quant_key.scale.group_shape quant_key_group_shape = self.quant_key.scale.group_shape
if quant_key_group_shape == GroupShape.PER_TOKEN: if quant_key_group_shape == GroupShape.PER_TOKEN:
return self.QUANT_OP( return self.QUANT_OP( # type: ignore[no-any-return]
x=input, x=input,
quant_dtype=self.quant_key.dtype, quant_dtype=self.quant_key.dtype,
scale=scale, scale=scale,
) )
else: else:
return self.QUANT_OP(input, quant_key_group_shape.col) return self.QUANT_OP(input, quant_key_group_shape.col) # type: ignore[no-any-return]
def forward_custom( def forward_custom(
self, self,
...@@ -400,9 +401,9 @@ class MatcherQuantFP8(MatcherCustomOp): ...@@ -400,9 +401,9 @@ class MatcherQuantFP8(MatcherCustomOp):
input: torch.Tensor, input: torch.Tensor,
scale: torch.Tensor | None = None, scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
return self.quant_fp8(input, scale) return self.quant_fp8(input, scale) # type: ignore[no-any-return]
def make_scale(self, input: torch.Tensor, transposed: bool = False): def make_scale(self, input: torch.Tensor, transposed: bool = False) -> torch.Tensor:
normalized_group_shape = _normalize_quant_group_shape( normalized_group_shape = _normalize_quant_group_shape(
input, self.quant_key.scale.group_shape input, self.quant_key.scale.group_shape
) )
...@@ -427,7 +428,7 @@ class MatcherQuantFP8(MatcherCustomOp): ...@@ -427,7 +428,7 @@ class MatcherQuantFP8(MatcherCustomOp):
class MatcherSiluAndMul(MatcherCustomOp): class MatcherSiluAndMul(MatcherCustomOp):
def __init__(self, enabled: bool | None = None): def __init__(self, enabled: bool | None = None) -> None:
if enabled is None: if enabled is None:
enabled = SiluAndMul.enabled() enabled = SiluAndMul.enabled()
super().__init__(enabled) super().__init__(enabled)
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable from collections.abc import Callable
from typing import ParamSpec
import torch import torch
import torch._inductor.pattern_matcher as pm import torch._inductor.pattern_matcher as pm
...@@ -23,6 +24,8 @@ logger = init_logger(__name__) ...@@ -23,6 +24,8 @@ logger = init_logger(__name__)
FUSED_QK_ROPE_OP = torch.ops._C.fused_qk_norm_rope.default FUSED_QK_ROPE_OP = torch.ops._C.fused_qk_norm_rope.default
P = ParamSpec("P")
class QkNormRopePattern: class QkNormRopePattern:
""" """
...@@ -72,7 +75,7 @@ class QkNormRopePattern: ...@@ -72,7 +75,7 @@ class QkNormRopePattern:
use_flashinfer=self.rope_flashinfer, use_flashinfer=self.rope_flashinfer,
) )
def get_inputs(self): def get_inputs(self) -> list[torch.Tensor]:
# Sample inputs to help pattern tracing # Sample inputs to help pattern tracing
T = 5 T = 5
qkv = empty_bf16(T, self.q_size + 2 * self.kv_size) qkv = empty_bf16(T, self.q_size + 2 * self.kv_size)
...@@ -92,8 +95,11 @@ class QkNormRopePattern: ...@@ -92,8 +95,11 @@ class QkNormRopePattern:
] ]
@staticmethod @staticmethod
def wrap_trace_fn(trace_fn, *process_fx_fns: Callable[[fx.GraphModule], None]): def wrap_trace_fn(
def wrapped(*args, **kwargs): trace_fn: Callable[P, fx.GraphModule],
*process_fx_fns: Callable[[fx.GraphModule], None],
) -> Callable[P, fx.GraphModule]:
def wrapped(*args: P.args, **kwargs: P.kwargs) -> fx.GraphModule:
gm = trace_fn(*args, **kwargs) gm = trace_fn(*args, **kwargs)
for process_fx in process_fx_fns: for process_fx in process_fx_fns:
process_fx(gm) process_fx(gm)
...@@ -103,19 +109,19 @@ class QkNormRopePattern: ...@@ -103,19 +109,19 @@ class QkNormRopePattern:
return wrapped return wrapped
@staticmethod @staticmethod
def fx_view_to_reshape(gm: torch.fx.GraphModule): def fx_view_to_reshape(gm: torch.fx.GraphModule) -> None:
from torch._inductor.fx_passes.post_grad import view_to_reshape from torch._inductor.fx_passes.post_grad import view_to_reshape
view_to_reshape(gm) view_to_reshape(gm)
def register(self, pm_pass: PatternMatcherPass): def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern( def pattern(
qkv: torch.Tensor, qkv: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
q_weight: torch.Tensor, q_weight: torch.Tensor,
k_weight: torch.Tensor, k_weight: torch.Tensor,
cos_sin_cache: torch.Tensor, cos_sin_cache: torch.Tensor,
): ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# split qkv -> q,k,v # split qkv -> q,k,v
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
...@@ -143,7 +149,7 @@ class QkNormRopePattern: ...@@ -143,7 +149,7 @@ class QkNormRopePattern:
q_weight: torch.Tensor, q_weight: torch.Tensor,
k_weight: torch.Tensor, k_weight: torch.Tensor,
cos_sin_cache: torch.Tensor, cos_sin_cache: torch.Tensor,
): ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# Run fused qk_norm_rope op # Run fused qk_norm_rope op
result = auto_functionalized( result = auto_functionalized(
FUSED_QK_ROPE_OP, FUSED_QK_ROPE_OP,
...@@ -162,7 +168,7 @@ class QkNormRopePattern: ...@@ -162,7 +168,7 @@ class QkNormRopePattern:
result_qkv = result[1] result_qkv = result[1]
# Split back to q,k,v and return # Split back to q,k,v and return
return result_qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) return result_qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) # type: ignore[no-any-return]
# NOTE: use fx_view_to_reshape to unify view/reshape to simplify # NOTE: use fx_view_to_reshape to unify view/reshape to simplify
# pattern and increase matching opportunities # pattern and increase matching opportunities
...@@ -182,7 +188,7 @@ class QKNormRoPEFusionPass(VllmPatternMatcherPass): ...@@ -182,7 +188,7 @@ class QKNormRoPEFusionPass(VllmPatternMatcherPass):
"""Fuse Q/K RMSNorm + RoPE into fused_qk_norm_rope when the custom op exists.""" """Fuse Q/K RMSNorm + RoPE into fused_qk_norm_rope when the custom op exists."""
@enable_fake_mode @enable_fake_mode
def __init__(self, config: VllmConfig): def __init__(self, config: VllmConfig) -> None:
super().__init__(config) super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass( self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="qk_norm_rope_fusion_pass" pass_name="qk_norm_rope_fusion_pass"
...@@ -234,5 +240,5 @@ class QKNormRoPEFusionPass(VllmPatternMatcherPass): ...@@ -234,5 +240,5 @@ class QKNormRoPEFusionPass(VllmPatternMatcherPass):
self.matched_count = self.patterns.apply(graph) self.matched_count = self.patterns.apply(graph)
logger.debug("Fused QK Norm+RoPE on %s sites", self.matched_count) logger.debug("Fused QK Norm+RoPE on %s sites", self.matched_count)
def uuid(self): def uuid(self) -> str:
return VllmInductorPass.hash_source(self, QkNormRopePattern) return VllmInductorPass.hash_source(self, QkNormRopePattern)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch import torch
import torch._inductor.pattern_matcher as pm import torch._inductor.pattern_matcher as pm
...@@ -65,8 +64,8 @@ class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern): ...@@ -65,8 +64,8 @@ class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
quant_dtype: torch.dtype, quant_dtype: torch.dtype,
match_aiter_quant: bool = True, match_aiter_quant: bool = True,
group_shape: GroupShape = GroupShape.PER_TOKEN, group_shape: GroupShape = GroupShape.PER_TOKEN,
symmetric=True, symmetric: bool = True,
): ) -> None:
scale = ScaleDesc(torch.float32, False, group_shape) scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey( key = FusedRMSQuantKey(
fused_add=False, fused_add=False,
...@@ -75,11 +74,11 @@ class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern): ...@@ -75,11 +74,11 @@ class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
super().__init__(epsilon, key, match_aiter_quant) super().__init__(epsilon, key, match_aiter_quant)
def register(self, pm_pass): def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern( def pattern(
input: torch.Tensor, input: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
): ) -> tuple[torch.Tensor, torch.Tensor]:
result_rms = self.rmsnorm_matcher(input, weight) result_rms = self.rmsnorm_matcher(input, weight)
result, scale = self.quant_matcher(result_rms) result, scale = self.quant_matcher(result_rms)
return result, scale return result, scale
...@@ -87,7 +86,7 @@ class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern): ...@@ -87,7 +86,7 @@ class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
def replacement( def replacement(
input: torch.Tensor, input: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
): ) -> tuple[torch.Tensor, torch.Tensor]:
result = self.FUSED_OP( result = self.FUSED_OP(
x=input, x=input,
weight=weight, weight=weight,
...@@ -117,8 +116,8 @@ class AiterFusedAddRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern): ...@@ -117,8 +116,8 @@ class AiterFusedAddRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
quant_dtype: torch.dtype, quant_dtype: torch.dtype,
match_aiter_quant: bool = True, match_aiter_quant: bool = True,
group_shape: GroupShape = GroupShape.PER_TOKEN, group_shape: GroupShape = GroupShape.PER_TOKEN,
symmetric=True, symmetric: bool = True,
): ) -> None:
scale = ScaleDesc(torch.float32, False, group_shape) scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey( key = FusedRMSQuantKey(
fused_add=True, fused_add=True,
...@@ -127,12 +126,12 @@ class AiterFusedAddRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern): ...@@ -127,12 +126,12 @@ class AiterFusedAddRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
super().__init__(epsilon, key, match_aiter_quant) super().__init__(epsilon, key, match_aiter_quant)
def register(self, pm_pass): def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern( def pattern(
input: torch.Tensor, input: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
residual: torch.Tensor, residual: torch.Tensor,
): ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
result_rms, residual_out = self.rmsnorm_matcher(input, weight, residual) result_rms, residual_out = self.rmsnorm_matcher(input, weight, residual)
result, scale = self.quant_matcher(result_rms) result, scale = self.quant_matcher(result_rms)
...@@ -140,7 +139,7 @@ class AiterFusedAddRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern): ...@@ -140,7 +139,7 @@ class AiterFusedAddRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
def replacement( def replacement(
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
): ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
result = self.FUSED_OP( result = self.FUSED_OP(
x=input, x=input,
residual=residual, residual=residual,
...@@ -174,8 +173,8 @@ class AiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern): ...@@ -174,8 +173,8 @@ class AiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
quant_dtype: torch.dtype, quant_dtype: torch.dtype,
group_shape: GroupShape, group_shape: GroupShape,
match_aiter_quant: bool = True, match_aiter_quant: bool = True,
symmetric=True, symmetric: bool = True,
): ) -> None:
scale = ScaleDesc(torch.float32, False, group_shape) scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey( key = FusedRMSQuantKey(
fused_add=False, fused_add=False,
...@@ -184,11 +183,11 @@ class AiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern): ...@@ -184,11 +183,11 @@ class AiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
super().__init__(epsilon, key, match_aiter_quant) super().__init__(epsilon, key, match_aiter_quant)
def register(self, pm_pass: PatternMatcherPass): def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern( def pattern(
input: torch.Tensor, input: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
): ) -> tuple[torch.Tensor, torch.Tensor]:
result_rms = self.rmsnorm_matcher(input, weight) result_rms = self.rmsnorm_matcher(input, weight)
result, scale = self.quant_matcher(result_rms) result, scale = self.quant_matcher(result_rms)
return result, scale return result, scale
...@@ -196,7 +195,7 @@ class AiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern): ...@@ -196,7 +195,7 @@ class AiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
def replacement( def replacement(
input: torch.Tensor, input: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
): ) -> tuple[torch.Tensor, torch.Tensor]:
at = self.FUSED_OP( at = self.FUSED_OP(
x=input, x=input,
weight=weight, weight=weight,
...@@ -225,8 +224,8 @@ class AiterFusedAddRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern): ...@@ -225,8 +224,8 @@ class AiterFusedAddRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
quant_dtype: torch.dtype, quant_dtype: torch.dtype,
group_shape: GroupShape, group_shape: GroupShape,
match_aiter_quant: bool = True, match_aiter_quant: bool = True,
symmetric=True, symmetric: bool = True,
): ) -> None:
scale = ScaleDesc(torch.float32, False, group_shape) scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey( key = FusedRMSQuantKey(
fused_add=True, fused_add=True,
...@@ -235,12 +234,12 @@ class AiterFusedAddRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern): ...@@ -235,12 +234,12 @@ class AiterFusedAddRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
super().__init__(epsilon, key, match_aiter_quant) super().__init__(epsilon, key, match_aiter_quant)
def register(self, pm_pass: PatternMatcherPass): def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern( def pattern(
input: torch.Tensor, input: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
residual: torch.Tensor, residual: torch.Tensor,
): ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
result_rms, residual_out = self.rmsnorm_matcher(input, weight, residual) result_rms, residual_out = self.rmsnorm_matcher(input, weight, residual)
result, scale = self.quant_matcher(result_rms) result, scale = self.quant_matcher(result_rms)
...@@ -250,7 +249,7 @@ class AiterFusedAddRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern): ...@@ -250,7 +249,7 @@ class AiterFusedAddRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
input: torch.Tensor, input: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
residual: torch.Tensor, residual: torch.Tensor,
): ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
at = self.FUSED_OP( at = self.FUSED_OP(
x=input, x=input,
residual=residual, residual=residual,
...@@ -275,7 +274,7 @@ class RocmAiterRMSNormFusionPass(VllmPatternMatcherPass): ...@@ -275,7 +274,7 @@ class RocmAiterRMSNormFusionPass(VllmPatternMatcherPass):
""" """
@enable_fake_mode @enable_fake_mode
def __init__(self, config: VllmConfig): def __init__(self, config: VllmConfig) -> None:
super().__init__(config) super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass( self.patterns: PatternMatcherPass = PatternMatcherPass(
...@@ -311,11 +310,11 @@ class RocmAiterRMSNormFusionPass(VllmPatternMatcherPass): ...@@ -311,11 +310,11 @@ class RocmAiterRMSNormFusionPass(VllmPatternMatcherPass):
self.dump_patterns(config, self.patterns) self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log @VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph): def __call__(self, graph: fx.Graph) -> None:
self.matched_count = self.patterns.apply(graph) self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count) logger.debug("Replaced %s patterns", self.matched_count)
def uuid(self) -> Any: def uuid(self) -> str:
fusion_patterns = [ fusion_patterns = [
AiterRMSNormDynamicQuantPattern, AiterRMSNormDynamicQuantPattern,
AiterFusedAddRMSNormDynamicQuantPattern, AiterFusedAddRMSNormDynamicQuantPattern,
...@@ -333,29 +332,32 @@ class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern): ...@@ -333,29 +332,32 @@ class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
FUSED_SILU_MUL_QUANT_OP = rocm_aiter_ops.get_act_mul_fused_fp8_group_quant_op() FUSED_SILU_MUL_QUANT_OP = rocm_aiter_ops.get_act_mul_fused_fp8_group_quant_op()
def __init__(self, quant_op: OpOverload): def __init__(self, quant_op: OpOverload) -> None:
self.silu_and_mul_matcher = MatcherSiluAndMul() self.silu_and_mul_matcher = MatcherSiluAndMul()
self.quant_op = quant_op self.quant_op = quant_op
def register(self, pm_pass: PatternMatcherPass): def get_inputs(self) -> list[torch.Tensor]:
return [
self.silu_and_mul_matcher.inputs()[0],
]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern( def pattern(
input: torch.Tensor, input: torch.Tensor,
): ) -> tuple[torch.Tensor, torch.Tensor]:
at1 = self.silu_and_mul_matcher(input) at1 = self.silu_and_mul_matcher(input)
at2 = self.quant_op(at1, 128) at2 = self.quant_op(at1, 128)
return at2[0], at2[1] return at2[0], at2[1]
def replacement( def replacement(
input: torch.Tensor, input: torch.Tensor,
): ) -> tuple[torch.Tensor, torch.Tensor]:
at = self.FUSED_SILU_MUL_QUANT_OP(x=input, group_size=128) at = self.FUSED_SILU_MUL_QUANT_OP(x=input, group_size=128)
return at[0], at[1] return at[0], at[1]
inputs = [ pm.register_replacement(
self.silu_and_mul_matcher.inputs()[0], pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
] )
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass): class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
...@@ -374,7 +376,7 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass): ...@@ -374,7 +376,7 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
QUANT_OPS = [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP] QUANT_OPS = [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]
@enable_fake_mode @enable_fake_mode
def __init__(self, config: VllmConfig): def __init__(self, config: VllmConfig) -> None:
super().__init__(config) super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass( self.patterns: PatternMatcherPass = PatternMatcherPass(
...@@ -387,11 +389,11 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass): ...@@ -387,11 +389,11 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
self.dump_patterns(config, self.patterns) self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log @VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.Graph): def __call__(self, graph: torch.fx.Graph) -> None:
self.matched_count = self.patterns.apply(graph) self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count) logger.debug("Replaced %s patterns", self.matched_count)
def uuid(self): def uuid(self) -> str:
fusion_patterns = [ fusion_patterns = [
ActivationQuantPattern, ActivationQuantPattern,
AiterSiluMulFp8GroupQuantPattern, AiterSiluMulFp8GroupQuantPattern,
......
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools import functools
from collections.abc import Callable, Sequence
from typing import Any
import torch import torch
import torch._inductor.pattern_matcher as pm import torch._inductor.pattern_matcher as pm
...@@ -26,9 +28,11 @@ from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass ...@@ -26,9 +28,11 @@ from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
logger = init_logger(__name__) logger = init_logger(__name__)
def get_first_out_wrapper(fn): def get_first_out_wrapper(
fn: Callable[..., Sequence[torch.Tensor]],
) -> Callable[..., torch.Tensor]:
@functools.wraps(fn) @functools.wraps(fn)
def wrapper(*args): def wrapper(*args: Any) -> torch.Tensor:
return fn(*args)[0] return fn(*args)[0]
return wrapper return wrapper
...@@ -68,7 +72,7 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper): ...@@ -68,7 +72,7 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
super().__init__(epsilon, dtype, device) super().__init__(epsilon, dtype, device)
self.rmsnorm_matcher = MatcherRMSNorm(epsilon) self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
def get_inputs(self): def get_inputs(self) -> list[torch.Tensor]:
input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype) input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
arg3_1 = torch.empty([4], device=self.device, dtype=self.dtype) arg3_1 = torch.empty([4], device=self.device, dtype=self.dtype)
...@@ -78,7 +82,7 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper): ...@@ -78,7 +82,7 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
def pattern( def pattern(
input: torch.Tensor, input: torch.Tensor,
arg3_1: torch.Tensor, arg3_1: torch.Tensor,
): ) -> tuple[torch.Tensor, torch.Tensor]:
all_reduce = self._all_reduce(input) all_reduce = self._all_reduce(input)
rmsnorm = self.rmsnorm_matcher(all_reduce, arg3_1) rmsnorm = self.rmsnorm_matcher(all_reduce, arg3_1)
...@@ -87,7 +91,7 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper): ...@@ -87,7 +91,7 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
def replacement( def replacement(
input: torch.Tensor, input: torch.Tensor,
arg3_1: torch.Tensor, arg3_1: torch.Tensor,
): ) -> tuple[torch.Tensor, torch.Tensor]:
reduce_scatter = self._reduce_scatter(input) reduce_scatter = self._reduce_scatter(input)
rmsnorm = self.rmsnorm_matcher(reduce_scatter, arg3_1) rmsnorm = self.rmsnorm_matcher(reduce_scatter, arg3_1)
...@@ -100,11 +104,11 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper): ...@@ -100,11 +104,11 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper): class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None): def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None:
super().__init__(epsilon, dtype, device) super().__init__(epsilon, dtype, device)
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon) self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
def get_inputs(self): def get_inputs(self) -> list[torch.Tensor]:
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
...@@ -116,7 +120,7 @@ class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper): ...@@ -116,7 +120,7 @@ class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
rms_norm_weights, rms_norm_weights,
] ]
def register(self, pm_pass: PatternMatcherPass): def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern( def pattern(
residual: torch.Tensor, residual: torch.Tensor,
mm_1: torch.Tensor, mm_1: torch.Tensor,
...@@ -163,23 +167,23 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): ...@@ -163,23 +167,23 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
epsilon: float, epsilon: float,
dtype: torch.dtype, dtype: torch.dtype,
device: str | None, device: str | None,
): ) -> None:
super().__init__(epsilon, dtype, device) super().__init__(epsilon, dtype, device)
self.rmsnorm_matcher = MatcherRMSNorm(epsilon) self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym) self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
def get_inputs(self): def get_inputs(self) -> list[torch.Tensor]:
input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype) input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype)
weight = torch.empty([4], device=self.device, dtype=self.dtype) weight = torch.empty([4], device=self.device, dtype=self.dtype)
scale = torch.tensor(1.0, device=self.device, dtype=torch.float32) scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
return [input, weight, scale] return [input, weight, scale]
def register(self, pm_pass: PatternMatcherPass): def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern( def pattern(
input: torch.Tensor, input: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor, scale: torch.Tensor,
): ) -> tuple[torch.Tensor, torch.Tensor]:
all_reduce = self._all_reduce(input) all_reduce = self._all_reduce(input)
rms = self.rmsnorm_matcher(all_reduce, weight) rms = self.rmsnorm_matcher(all_reduce, weight)
quant, _ = self.quant_matcher(rms, scale) quant, _ = self.quant_matcher(rms, scale)
...@@ -189,7 +193,7 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): ...@@ -189,7 +193,7 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
input: torch.Tensor, input: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor, scale: torch.Tensor,
): ) -> tuple[torch.Tensor, torch.Tensor]:
reduce_scatter = self._reduce_scatter(input) reduce_scatter = self._reduce_scatter(input)
rms = self.rmsnorm_matcher(reduce_scatter, weight) rms = self.rmsnorm_matcher(reduce_scatter, weight)
quant, _ = self.quant_matcher(rms, scale) quant, _ = self.quant_matcher(rms, scale)
...@@ -203,12 +207,12 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): ...@@ -203,12 +207,12 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None): def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None:
super().__init__(epsilon, dtype, device) super().__init__(epsilon, dtype, device)
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon) self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym) self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
def get_inputs(self): def get_inputs(self) -> list[torch.Tensor]:
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype) rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
...@@ -216,7 +220,7 @@ class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): ...@@ -216,7 +220,7 @@ class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
return [residual, mm_1, rms_norm_weights, scale] return [residual, mm_1, rms_norm_weights, scale]
def register(self, pm_pass: PatternMatcherPass): def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern( def pattern(
residual: torch.Tensor, residual: torch.Tensor,
mm_1: torch.Tensor, mm_1: torch.Tensor,
...@@ -302,7 +306,7 @@ class SequenceParallelismPass(VllmPatternMatcherPass): ...@@ -302,7 +306,7 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
""" """
@enable_fake_mode @enable_fake_mode
def __init__(self, config: VllmConfig): def __init__(self, config: VllmConfig) -> None:
super().__init__(config) super().__init__(config)
# Used to clean up redundant views created temporarily # Used to clean up redundant views created temporarily
...@@ -357,7 +361,7 @@ class SequenceParallelismPass(VllmPatternMatcherPass): ...@@ -357,7 +361,7 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
return (compile_range.is_single_size()) and (compile_range.end % tp_size == 0) return (compile_range.is_single_size()) and (compile_range.end % tp_size == 0)
@VllmInductorPass.time_and_log @VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph): def __call__(self, graph: fx.Graph) -> None:
self.matched_count = self.patterns.apply(graph) self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count) logger.debug("Replaced %s patterns", self.matched_count)
# Clean up reshape nodes # Clean up reshape nodes
......
...@@ -1529,22 +1529,22 @@ def patch_tensor_parallel_group(tp_group: GroupCoordinator): ...@@ -1529,22 +1529,22 @@ def patch_tensor_parallel_group(tp_group: GroupCoordinator):
_TP = old_tp_group _TP = old_tp_group
def get_tensor_model_parallel_world_size(): def get_tensor_model_parallel_world_size() -> int:
"""Return world size for the tensor model parallel group.""" """Return world size for the tensor model parallel group."""
return get_tp_group().world_size return get_tp_group().world_size
def get_tensor_model_parallel_rank(): def get_tensor_model_parallel_rank() -> int:
"""Return my rank for the tensor model parallel group.""" """Return my rank for the tensor model parallel group."""
return get_tp_group().rank_in_group return get_tp_group().rank_in_group
def get_decode_context_model_parallel_world_size(): def get_decode_context_model_parallel_world_size() -> int:
"""Return world size for the decode context model parallel group.""" """Return world size for the decode context model parallel group."""
return get_dcp_group().world_size return get_dcp_group().world_size
def get_decode_context_model_parallel_rank(): def get_decode_context_model_parallel_rank() -> int:
"""Return my rank for the decode context model parallel group.""" """Return my rank for the decode context model parallel group."""
return get_dcp_group().rank_in_group return get_dcp_group().rank_in_group
......
...@@ -20,7 +20,9 @@ from .phi3_long_rope_scaled_rope import Phi3LongRoPEScaledRotaryEmbedding ...@@ -20,7 +20,9 @@ from .phi3_long_rope_scaled_rope import Phi3LongRoPEScaledRotaryEmbedding
from .xdrope import XDRotaryEmbedding from .xdrope import XDRotaryEmbedding
from .yarn_scaling_rope import YaRNScalingRotaryEmbedding from .yarn_scaling_rope import YaRNScalingRotaryEmbedding
_ROPE_DICT: dict[tuple, RotaryEmbedding] = {} _ROPE_DICT: dict[tuple[Any, ...], RotaryEmbedding] = {}
__all__ = ["RotaryEmbedding"]
def get_rope( def get_rope(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment