Unverified Commit ad8818bb authored by Lucas Kabela's avatar Lucas Kabela Committed by GitHub
Browse files

[Misc][BE] Type coverage for vllm/compilation [3/3] (#31748)


Signed-off-by: default avatarLucas Kabela <lucaskabela@meta.com>
parent 08e8e99c
......@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from typing import Any
import torch
from torch._higher_order_ops.auto_functionalize import auto_functionalized
......@@ -52,7 +53,7 @@ class ActivationQuantPattern(ABC):
def __init__(
self,
quant_key: QuantKey,
):
) -> None:
self.quant_key = quant_key
self.quant_dtype = quant_key.dtype
......@@ -68,12 +69,12 @@ class ActivationQuantPattern(ABC):
self.silu_and_mul_matcher = MatcherSiluAndMul()
def empty_quant(self, *args, **kwargs):
def empty_quant(self, *args: Any, **kwargs: Any) -> torch.Tensor:
kwargs = {"dtype": self.quant_dtype, "device": "cuda", **kwargs}
return torch.empty(*args, **kwargs)
@abstractmethod
def register(self, pm_pass: PatternMatcherPass):
def register(self, pm_pass: PatternMatcherPass) -> None:
raise NotImplementedError
......@@ -82,15 +83,22 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
Fusion for SiluMul+Fp8StaticQuant Pattern
"""
def __init__(self):
def __init__(self) -> None:
super().__init__(kFp8StaticTensorSym)
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
def register(self, pm_pass: PatternMatcherPass):
def get_inputs(self) -> list[torch.Tensor]:
scale = self.quant_matcher.inputs()[1]
return [
*self.silu_and_mul_matcher.inputs(), # input
scale,
]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
scale: torch.Tensor,
):
) -> torch.Tensor:
result_silu_mul = self.silu_and_mul_matcher(input)
result_quant = self.quant_matcher(result_silu_mul, scale)
return result_quant[0]
......@@ -98,7 +106,7 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
def replacement(
input: torch.Tensor,
scale: torch.Tensor,
):
) -> torch.Tensor:
d = input.shape[-1] // 2
output_shape = input.shape[:-1] + (d,)
result = torch.empty(
......@@ -109,13 +117,10 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
)
return at[1]
inputs = [
*self.silu_and_mul_matcher.inputs(), # input
self.quant_matcher.inputs()[1], # scale
]
pattern(*inputs)
inps = self.get_inputs()
pattern(*inps)
register_replacement(pattern, replacement, inputs, fwd_only, pm_pass)
register_replacement(pattern, replacement, inps, fwd_only, pm_pass)
class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
......@@ -123,16 +128,23 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
Fusion for SiluMul+Nvfp4Quant Pattern
"""
def __init__(self):
def __init__(self) -> None:
super().__init__(kNvfp4Quant)
def register(self, pm_pass: PatternMatcherPass):
def get_inputs(self) -> list[torch.Tensor]:
result = self.empty_quant(5, 32)
output_scale = empty_i32(128, 4)
input_ = empty_bf16(5, 64)
scale = empty_fp32(1, 1)
return [result, output_scale, input_, scale]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
result: torch.Tensor,
output_scale: torch.Tensor,
input: torch.Tensor,
scale: torch.Tensor,
):
) -> tuple[torch.Tensor, torch.Tensor]:
result_silu_mul = self.silu_and_mul_matcher(input)
at = auto_functionalized(
self.QUANT_OP,
......@@ -148,7 +160,7 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
output_scale: torch.Tensor,
input: torch.Tensor,
scale: torch.Tensor,
):
) -> tuple[torch.Tensor, torch.Tensor]:
at = auto_functionalized(
self.FUSED_OP,
result=result,
......@@ -158,14 +170,7 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
)
return at[1], at[2]
inputs = [
self.empty_quant(5, 32), # result
empty_i32(128, 4), # output_scale
empty_bf16(5, 64), # input
empty_fp32(1, 1), # scale
]
register_replacement(pattern, replacement, inputs, fwd_only, pm_pass)
register_replacement(pattern, replacement, self.get_inputs(), fwd_only, pm_pass)
class ActivationQuantFusionPass(VllmPatternMatcherPass):
......@@ -179,7 +184,7 @@ class ActivationQuantFusionPass(VllmPatternMatcherPass):
"""
@enable_fake_mode
def __init__(self, config: VllmConfig):
def __init__(self, config: VllmConfig) -> None:
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(
......@@ -196,11 +201,11 @@ class ActivationQuantFusionPass(VllmPatternMatcherPass):
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.Graph):
def __call__(self, graph: torch.fx.Graph) -> None:
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
def uuid(self):
def uuid(self) -> str:
return VllmInductorPass.hash_source(
self,
ActivationQuantPattern,
......
This diff is collapsed.
......@@ -38,19 +38,19 @@ FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
def empty_bf16(*args, **kwargs):
def empty_bf16(*args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda")
def empty_fp32(*args, **kwargs):
def empty_fp32(*args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda")
def empty_i32(*args, **kwargs):
def empty_i32(*args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, **kwargs, dtype=torch.int32, device="cuda")
def empty_i64(*args, **kwargs):
def empty_i64(*args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, **kwargs, dtype=torch.int64, device="cuda")
......@@ -79,7 +79,7 @@ class FusedRMSQuantKey(NamedTuple):
quant: QuantKey
fused_add: bool
def __str__(self):
def __str__(self) -> str:
return (
f"FusedQuantKey({self.quant}, with"
f"{'' if self.fused_add else 'out'} residual)"
......@@ -121,7 +121,7 @@ class RMSNormQuantPattern:
key: FusedRMSQuantKey,
has_col_major_scales: bool = False,
is_e8m0: bool = False,
):
) -> None:
self.epsilon = epsilon
self.quant_dtype = key.quant.dtype
config = get_current_vllm_config()
......@@ -141,7 +141,9 @@ class RMSNormQuantPattern:
class RMSNormStaticQuantPattern(RMSNormQuantPattern):
def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True):
def __init__(
self, epsilon: float, quant_dtype: torch.dtype, symmetric: bool = True
) -> None:
fused_key = FusedRMSQuantKey(
fused_add=False,
quant=QuantKey(
......@@ -150,13 +152,17 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
)
super().__init__(epsilon, fused_key)
def register(self, pm_pass: PatternMatcherPass):
def register(self, pm_pass: PatternMatcherPass) -> None:
# Cannot use methods, as the self argument affects tracing
def pattern(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
def pattern(
input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
) -> torch.Tensor:
result_rms = self.rmsnorm_matcher(input, weight)
return self.quant_matcher(result_rms, scale)[0]
def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
def replacement(
input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
) -> torch.Tensor:
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype)
......@@ -187,7 +193,9 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True):
def __init__(
self, epsilon: float, quant_dtype: torch.dtype, symmetric: bool = True
) -> None:
key = FusedRMSQuantKey(
fused_add=True,
quant=QuantKey(
......@@ -196,13 +204,13 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
)
super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass):
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
scale: torch.Tensor,
):
) -> tuple[torch.Tensor, torch.Tensor]:
result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
result, _ = self.quant_matcher(result_rms, scale)
......@@ -213,7 +221,7 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
weight: torch.Tensor,
residual: torch.Tensor,
scale: torch.Tensor,
):
) -> tuple[torch.Tensor, torch.Tensor]:
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype)
......@@ -253,10 +261,10 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
epsilon: float,
quant_dtype: torch.dtype,
group_shape: GroupShape,
symmetric=True,
symmetric: bool = True,
has_col_major_scales: bool = False,
is_e8m0: bool = False,
):
) -> None:
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(
fused_add=True,
......@@ -269,15 +277,17 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
epsilon, key, has_col_major_scales=has_col_major_scales, is_e8m0=is_e8m0
)
def register(self, pm_pass: PatternMatcherPass):
def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor):
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
result, scale = self.quant_matcher(result_rms)
return result, residual, scale
def replacement(
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
):
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype)
......@@ -315,10 +325,10 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
epsilon: float,
quant_dtype: torch.dtype,
group_shape: GroupShape,
symmetric=True,
symmetric: bool = True,
has_col_major_scales: bool = False,
is_e8m0: bool = False,
):
) -> None:
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(
fused_add=False,
......@@ -329,13 +339,17 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
epsilon, key, has_col_major_scales=has_col_major_scales, is_e8m0=is_e8m0
)
def register(self, pm_pass: PatternMatcherPass):
def pattern(input: torch.Tensor, weight: torch.Tensor):
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor, weight: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
result_rms = self.rmsnorm_matcher(input, weight)
result, scale = self.quant_matcher(result_rms)
return result, scale
def replacement(input: torch.Tensor, weight: torch.Tensor):
def replacement(
input: torch.Tensor, weight: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype)
......@@ -375,8 +389,8 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
epsilon: float,
quant_dtype: torch.dtype,
group_shape: GroupShape = GroupShape.PER_TOKEN,
symmetric=True,
):
symmetric: bool = True,
) -> None:
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(
fused_add=False,
......@@ -384,13 +398,17 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
)
super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass):
def pattern(input: torch.Tensor, weight: torch.Tensor):
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor, weight: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
result_rms = self.rmsnorm_matcher(input, weight)
# result, scale
return self.quant_matcher(result_rms)
return self.quant_matcher(result_rms) # type: ignore[no-any-return]
def replacement(input: torch.Tensor, weight: torch.Tensor):
def replacement(
input: torch.Tensor, weight: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype)
......@@ -426,8 +444,8 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
epsilon: float,
quant_dtype: torch.dtype,
group_shape: GroupShape = GroupShape.PER_TOKEN,
symmetric=True,
):
symmetric: bool = True,
) -> None:
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(
fused_add=True,
......@@ -435,8 +453,10 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
)
super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass):
def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor):
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
result, scale = self.quant_matcher(result_rms)
......@@ -444,7 +464,7 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
def replacement(
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
):
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype)
......@@ -481,7 +501,7 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
"""
@enable_fake_mode
def __init__(self, config: VllmConfig):
def __init__(self, config: VllmConfig) -> None:
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(
......@@ -533,11 +553,11 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph):
def __call__(self, graph: fx.Graph) -> None:
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
def uuid(self) -> Any:
def uuid(self) -> str:
return self.hash_source(
self,
RMSNormGroupQuantPattern,
......
......@@ -3,6 +3,7 @@
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Any, ParamSpec
import torch
import torch._inductor.pattern_matcher as pm
......@@ -28,7 +29,7 @@ from .matcher_utils import MatcherQuantFP8
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
logger = init_logger(__name__)
P = ParamSpec("P")
FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
......@@ -47,7 +48,7 @@ class AttentionQuantPattern(ABC):
layer: Attention,
quant_key: QuantKey,
dtype: torch.dtype,
):
) -> None:
self.layer = layer
self.layer_name = layer.layer_name
self.num_heads = layer.num_heads
......@@ -61,17 +62,20 @@ class AttentionQuantPattern(ABC):
)
self.QUANT_OP = QUANT_OPS[self.quant_key]
def empty(self, *args, **kwargs):
def empty(self, *args: Any, **kwargs: Any) -> torch.Tensor:
kwargs = {"dtype": self.dtype, "device": "cuda", **kwargs}
return torch.empty(*args, **kwargs)
def empty_quant(self, *args, **kwargs):
def empty_quant(self, *args: Any, **kwargs: Any) -> torch.Tensor:
kwargs = {"dtype": self.quant_dtype, "device": "cuda", **kwargs}
return torch.empty(*args, **kwargs)
@staticmethod
def wrap_trace_fn(trace_fn, *process_fx_fns: Callable[[fx.GraphModule], None]):
def wrapped(*args, **kwargs):
def wrap_trace_fn(
trace_fn: Callable[P, fx.GraphModule],
*process_fx_fns: Callable[[fx.GraphModule], None],
) -> Callable[P, fx.GraphModule]:
def wrapped(*args: P.args, **kwargs: P.kwargs) -> fx.GraphModule:
gm = trace_fn(*args, **kwargs)
for process_fx in process_fx_fns:
process_fx(gm)
......@@ -81,13 +85,13 @@ class AttentionQuantPattern(ABC):
return wrapped
@staticmethod
def fx_view_to_reshape(gm: torch.fx.GraphModule):
def fx_view_to_reshape(gm: torch.fx.GraphModule) -> None:
from torch._inductor.fx_passes.post_grad import view_to_reshape
view_to_reshape(gm)
@staticmethod
def remove_noop_permutes(gm: torch.fx.GraphModule):
def remove_noop_permutes(gm: torch.fx.GraphModule) -> None:
for node in gm.graph.nodes:
if not is_func(node, torch.ops.aten.permute.default):
continue
......@@ -100,12 +104,12 @@ class AttentionQuantPattern(ABC):
node.replace_all_uses_with(node.args[0])
gm.graph.erase_node(node)
def register_if_supported(self, pm_pass: PatternMatcherPass):
def register_if_supported(self, pm_pass: PatternMatcherPass) -> None:
if self.layer.impl.fused_output_quant_supported(self.quant_key):
self._register(pm_pass)
@abstractmethod
def _register(self, pm_pass: PatternMatcherPass):
def _register(self, pm_pass: PatternMatcherPass) -> None:
raise NotImplementedError
......@@ -124,21 +128,21 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
layer: Attention,
dtype: torch.dtype,
symmetric: bool = True,
):
) -> None:
quant_key = QuantKey(
dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=symmetric
)
super().__init__(layer, quant_key, dtype)
self.quant_matcher = MatcherQuantFP8(quant_key)
def _register(self, pm_pass: PatternMatcherPass):
def _register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
output_attn: torch.Tensor,
scale: torch.Tensor,
):
) -> torch.Tensor:
at1 = auto_functionalized(
ATTN_OP,
query=q,
......@@ -161,7 +165,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
v: torch.Tensor,
output_attn: torch.Tensor,
scale: torch.Tensor,
):
) -> torch.Tensor:
# attn output in quant_dtype
output_attn = torch.ops.aten.full.default(
[q.shape[0], self.num_heads, self.head_size],
......@@ -212,10 +216,10 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
will be passed into Attention op as the `output_scale` argument.
"""
def __init__(self, layer: Attention, dtype: torch.dtype):
def __init__(self, layer: Attention, dtype: torch.dtype) -> None:
super().__init__(layer, kNvfp4Quant, dtype)
def _register(self, pm_pass: PatternMatcherPass):
def _register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
q: torch.Tensor,
k: torch.Tensor,
......@@ -224,7 +228,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
output_quant: torch.Tensor,
output_scale: torch.Tensor,
input_scale: torch.Tensor,
):
) -> tuple[torch.Tensor, torch.Tensor]:
at1 = auto_functionalized(
ATTN_OP,
query=q,
......@@ -256,7 +260,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
output_quant: torch.Tensor,
output_scale: torch.Tensor,
input_scale: torch.Tensor,
):
) -> tuple[torch.Tensor, torch.Tensor]:
# attention output in quant_dtype
output_attn = torch.ops.aten.full.default(
[q.shape[0], self.num_heads, self.head_size // 2],
......@@ -318,7 +322,7 @@ class AttnFusionPass(VllmPatternMatcherPass):
"""
@enable_fake_mode
def __init__(self, config: VllmConfig):
def __init__(self, config: VllmConfig) -> None:
super().__init__(config)
self.patterns = PatternMatcherPass(pass_name="attn_fusion_pass")
......@@ -350,7 +354,7 @@ class AttnFusionPass(VllmPatternMatcherPass):
self.matched_count = self.patterns.apply(graph)
logger.debug("Fused quant onto %s attention nodes", self.matched_count)
def uuid(self):
def uuid(self) -> str:
return VllmInductorPass.hash_source(
self,
AttentionQuantPattern,
......
......@@ -68,7 +68,7 @@ class InductorPass(CustomGraphPass): # type: ignore[misc]
This is defined as a convenience and should work in most cases.
"""
def uuid(self) -> Any:
def uuid(self) -> str:
"""
Provide a unique identifier for the pass, used in Inductor code cache.
This should depend on the pass implementation, so that changes to the
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from typing import Any
import torch
from torch._higher_order_ops import auto_functionalized
......@@ -47,7 +48,7 @@ SILU_MUL_OP = torch.ops._C.silu_and_mul.default
class MatcherCustomOp(ABC):
def __init__(self, enabled: bool):
def __init__(self, enabled: bool) -> None:
config = get_current_vllm_config()
self.model_dtype = config.model_config.dtype if config.model_config else None
self.device = config.device_config.device if config.device_config else None
......@@ -56,24 +57,24 @@ class MatcherCustomOp(ABC):
self.forward = self.forward_custom if enabled else self.forward_native
@abstractmethod
def forward_custom(self, *args, **kws):
def forward_custom(self, *args: Any, **kwargs: Any) -> Any:
pass
@abstractmethod
def forward_native(self, *args, **kws):
def forward_native(self, *args: Any, **kwargs: Any) -> Any:
pass
def __call__(self, *args, **kws):
return self.forward(*args, **kws)
def __call__(self, *args: Any, **kwargs: Any) -> Any:
return self.forward(*args, **kwargs)
def empty(self, *args, **kws):
return torch.empty(*args, dtype=self.model_dtype, device=self.device, **kws)
def empty(self, *args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, dtype=self.model_dtype, device=self.device, **kwargs)
def empty_int64(self, *args, **kws):
return torch.empty(*args, dtype=torch.int64, device=self.device, **kws)
def empty_int64(self, *args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, dtype=torch.int64, device=self.device, **kwargs)
def empty_f32(self, *args, **kws):
return torch.empty(*args, dtype=torch.float32, device=self.device, **kws)
def empty_f32(self, *args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, dtype=torch.float32, device=self.device, **kwargs)
def inputs(self) -> list[torch.Tensor]:
"""Utility for inputs to the pattern"""
......@@ -157,7 +158,7 @@ class MatcherRMSNorm(MatcherCustomOp):
epsilon: float,
enabled: bool | None = None,
match_rocm_aiter: bool = False,
):
) -> None:
if enabled is None:
enabled = RMSNorm.enabled()
......@@ -169,7 +170,7 @@ class MatcherRMSNorm(MatcherCustomOp):
if match_rocm_aiter:
self._rmsnorm_op = rocm_aiter_ops.get_rmsnorm_op()
def inputs(self):
def inputs(self) -> list[torch.Tensor]:
input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
weight = self.empty(16)
return [input, weight]
......@@ -220,7 +221,7 @@ class MatcherFusedAddRMSNorm(MatcherCustomOp):
epsilon: float,
enabled: bool | None = None,
match_rocm_aiter: bool = False,
):
) -> None:
if enabled is None:
enabled = RMSNorm.enabled()
......@@ -233,7 +234,7 @@ class MatcherFusedAddRMSNorm(MatcherCustomOp):
if match_rocm_aiter:
self._rmsnorm_op = rocm_aiter_ops.get_rmsnorm_fused_add_op()
def inputs(self):
def inputs(self) -> list[torch.Tensor]:
input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
weight = self.empty(16)
residual = self.empty(5, 16)
......@@ -245,7 +246,7 @@ class MatcherFusedAddRMSNorm(MatcherCustomOp):
weight: torch.Tensor,
residual: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
return self._rmsnorm_op(
return self._rmsnorm_op( # type: ignore[no-any-return]
x=input, residual=residual, weight=weight, variance_epsilon=self.epsilon
)
......@@ -287,7 +288,7 @@ class MatcherQuantFP8(MatcherCustomOp):
has_col_major_scales: bool = False,
is_e8m0: bool = False,
match_rocm_aiter: bool = False,
):
) -> None:
if enabled is None:
enabled = QuantFP8.enabled()
......@@ -340,13 +341,13 @@ class MatcherQuantFP8(MatcherCustomOp):
) -> tuple[torch.Tensor, torch.Tensor]:
quant_key_group_shape = self.quant_key.scale.group_shape
if quant_key_group_shape == GroupShape.PER_TOKEN:
return self.QUANT_OP(
return self.QUANT_OP( # type: ignore[no-any-return]
x=input,
quant_dtype=self.quant_key.dtype,
scale=scale,
)
else:
return self.QUANT_OP(input, quant_key_group_shape.col)
return self.QUANT_OP(input, quant_key_group_shape.col) # type: ignore[no-any-return]
def forward_custom(
self,
......@@ -400,9 +401,9 @@ class MatcherQuantFP8(MatcherCustomOp):
input: torch.Tensor,
scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
return self.quant_fp8(input, scale)
return self.quant_fp8(input, scale) # type: ignore[no-any-return]
def make_scale(self, input: torch.Tensor, transposed: bool = False):
def make_scale(self, input: torch.Tensor, transposed: bool = False) -> torch.Tensor:
normalized_group_shape = _normalize_quant_group_shape(
input, self.quant_key.scale.group_shape
)
......@@ -427,7 +428,7 @@ class MatcherQuantFP8(MatcherCustomOp):
class MatcherSiluAndMul(MatcherCustomOp):
def __init__(self, enabled: bool | None = None):
def __init__(self, enabled: bool | None = None) -> None:
if enabled is None:
enabled = SiluAndMul.enabled()
super().__init__(enabled)
......
......@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from typing import ParamSpec
import torch
import torch._inductor.pattern_matcher as pm
......@@ -23,6 +24,8 @@ logger = init_logger(__name__)
FUSED_QK_ROPE_OP = torch.ops._C.fused_qk_norm_rope.default
P = ParamSpec("P")
class QkNormRopePattern:
"""
......@@ -72,7 +75,7 @@ class QkNormRopePattern:
use_flashinfer=self.rope_flashinfer,
)
def get_inputs(self):
def get_inputs(self) -> list[torch.Tensor]:
# Sample inputs to help pattern tracing
T = 5
qkv = empty_bf16(T, self.q_size + 2 * self.kv_size)
......@@ -92,8 +95,11 @@ class QkNormRopePattern:
]
@staticmethod
def wrap_trace_fn(trace_fn, *process_fx_fns: Callable[[fx.GraphModule], None]):
def wrapped(*args, **kwargs):
def wrap_trace_fn(
trace_fn: Callable[P, fx.GraphModule],
*process_fx_fns: Callable[[fx.GraphModule], None],
) -> Callable[P, fx.GraphModule]:
def wrapped(*args: P.args, **kwargs: P.kwargs) -> fx.GraphModule:
gm = trace_fn(*args, **kwargs)
for process_fx in process_fx_fns:
process_fx(gm)
......@@ -103,19 +109,19 @@ class QkNormRopePattern:
return wrapped
@staticmethod
def fx_view_to_reshape(gm: torch.fx.GraphModule):
def fx_view_to_reshape(gm: torch.fx.GraphModule) -> None:
from torch._inductor.fx_passes.post_grad import view_to_reshape
view_to_reshape(gm)
def register(self, pm_pass: PatternMatcherPass):
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
qkv: torch.Tensor,
positions: torch.Tensor,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
cos_sin_cache: torch.Tensor,
):
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# split qkv -> q,k,v
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
......@@ -143,7 +149,7 @@ class QkNormRopePattern:
q_weight: torch.Tensor,
k_weight: torch.Tensor,
cos_sin_cache: torch.Tensor,
):
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# Run fused qk_norm_rope op
result = auto_functionalized(
FUSED_QK_ROPE_OP,
......@@ -162,7 +168,7 @@ class QkNormRopePattern:
result_qkv = result[1]
# Split back to q,k,v and return
return result_qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
return result_qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) # type: ignore[no-any-return]
# NOTE: use fx_view_to_reshape to unify view/reshape to simplify
# pattern and increase matching opportunities
......@@ -182,7 +188,7 @@ class QKNormRoPEFusionPass(VllmPatternMatcherPass):
"""Fuse Q/K RMSNorm + RoPE into fused_qk_norm_rope when the custom op exists."""
@enable_fake_mode
def __init__(self, config: VllmConfig):
def __init__(self, config: VllmConfig) -> None:
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="qk_norm_rope_fusion_pass"
......@@ -234,5 +240,5 @@ class QKNormRoPEFusionPass(VllmPatternMatcherPass):
self.matched_count = self.patterns.apply(graph)
logger.debug("Fused QK Norm+RoPE on %s sites", self.matched_count)
def uuid(self):
def uuid(self) -> str:
return VllmInductorPass.hash_source(self, QkNormRopePattern)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch
import torch._inductor.pattern_matcher as pm
......@@ -65,8 +64,8 @@ class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
quant_dtype: torch.dtype,
match_aiter_quant: bool = True,
group_shape: GroupShape = GroupShape.PER_TOKEN,
symmetric=True,
):
symmetric: bool = True,
) -> None:
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(
fused_add=False,
......@@ -75,11 +74,11 @@ class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
super().__init__(epsilon, key, match_aiter_quant)
def register(self, pm_pass):
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
):
) -> tuple[torch.Tensor, torch.Tensor]:
result_rms = self.rmsnorm_matcher(input, weight)
result, scale = self.quant_matcher(result_rms)
return result, scale
......@@ -87,7 +86,7 @@ class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
def replacement(
input: torch.Tensor,
weight: torch.Tensor,
):
) -> tuple[torch.Tensor, torch.Tensor]:
result = self.FUSED_OP(
x=input,
weight=weight,
......@@ -117,8 +116,8 @@ class AiterFusedAddRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
quant_dtype: torch.dtype,
match_aiter_quant: bool = True,
group_shape: GroupShape = GroupShape.PER_TOKEN,
symmetric=True,
):
symmetric: bool = True,
) -> None:
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(
fused_add=True,
......@@ -127,12 +126,12 @@ class AiterFusedAddRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
super().__init__(epsilon, key, match_aiter_quant)
def register(self, pm_pass):
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
):
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
result_rms, residual_out = self.rmsnorm_matcher(input, weight, residual)
result, scale = self.quant_matcher(result_rms)
......@@ -140,7 +139,7 @@ class AiterFusedAddRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
def replacement(
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
):
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
result = self.FUSED_OP(
x=input,
residual=residual,
......@@ -174,8 +173,8 @@ class AiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
quant_dtype: torch.dtype,
group_shape: GroupShape,
match_aiter_quant: bool = True,
symmetric=True,
):
symmetric: bool = True,
) -> None:
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(
fused_add=False,
......@@ -184,11 +183,11 @@ class AiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
super().__init__(epsilon, key, match_aiter_quant)
def register(self, pm_pass: PatternMatcherPass):
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
):
) -> tuple[torch.Tensor, torch.Tensor]:
result_rms = self.rmsnorm_matcher(input, weight)
result, scale = self.quant_matcher(result_rms)
return result, scale
......@@ -196,7 +195,7 @@ class AiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
def replacement(
input: torch.Tensor,
weight: torch.Tensor,
):
) -> tuple[torch.Tensor, torch.Tensor]:
at = self.FUSED_OP(
x=input,
weight=weight,
......@@ -225,8 +224,8 @@ class AiterFusedAddRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
quant_dtype: torch.dtype,
group_shape: GroupShape,
match_aiter_quant: bool = True,
symmetric=True,
):
symmetric: bool = True,
) -> None:
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(
fused_add=True,
......@@ -235,12 +234,12 @@ class AiterFusedAddRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
super().__init__(epsilon, key, match_aiter_quant)
def register(self, pm_pass: PatternMatcherPass):
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
):
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
result_rms, residual_out = self.rmsnorm_matcher(input, weight, residual)
result, scale = self.quant_matcher(result_rms)
......@@ -250,7 +249,7 @@ class AiterFusedAddRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
):
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
at = self.FUSED_OP(
x=input,
residual=residual,
......@@ -275,7 +274,7 @@ class RocmAiterRMSNormFusionPass(VllmPatternMatcherPass):
"""
@enable_fake_mode
def __init__(self, config: VllmConfig):
def __init__(self, config: VllmConfig) -> None:
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(
......@@ -311,11 +310,11 @@ class RocmAiterRMSNormFusionPass(VllmPatternMatcherPass):
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph):
def __call__(self, graph: fx.Graph) -> None:
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
def uuid(self) -> Any:
def uuid(self) -> str:
fusion_patterns = [
AiterRMSNormDynamicQuantPattern,
AiterFusedAddRMSNormDynamicQuantPattern,
......@@ -333,29 +332,32 @@ class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
FUSED_SILU_MUL_QUANT_OP = rocm_aiter_ops.get_act_mul_fused_fp8_group_quant_op()
def __init__(self, quant_op: OpOverload):
def __init__(self, quant_op: OpOverload) -> None:
self.silu_and_mul_matcher = MatcherSiluAndMul()
self.quant_op = quant_op
def register(self, pm_pass: PatternMatcherPass):
def get_inputs(self) -> list[torch.Tensor]:
return [
self.silu_and_mul_matcher.inputs()[0],
]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
):
) -> tuple[torch.Tensor, torch.Tensor]:
at1 = self.silu_and_mul_matcher(input)
at2 = self.quant_op(at1, 128)
return at2[0], at2[1]
def replacement(
input: torch.Tensor,
):
) -> tuple[torch.Tensor, torch.Tensor]:
at = self.FUSED_SILU_MUL_QUANT_OP(x=input, group_size=128)
return at[0], at[1]
inputs = [
self.silu_and_mul_matcher.inputs()[0],
]
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
......@@ -374,7 +376,7 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
QUANT_OPS = [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]
@enable_fake_mode
def __init__(self, config: VllmConfig):
def __init__(self, config: VllmConfig) -> None:
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(
......@@ -387,11 +389,11 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.Graph):
def __call__(self, graph: torch.fx.Graph) -> None:
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
def uuid(self):
def uuid(self) -> str:
fusion_patterns = [
ActivationQuantPattern,
AiterSiluMulFp8GroupQuantPattern,
......
......@@ -2,6 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from collections.abc import Callable, Sequence
from typing import Any
import torch
import torch._inductor.pattern_matcher as pm
......@@ -26,9 +28,11 @@ from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
logger = init_logger(__name__)
def get_first_out_wrapper(fn):
def get_first_out_wrapper(
fn: Callable[..., Sequence[torch.Tensor]],
) -> Callable[..., torch.Tensor]:
@functools.wraps(fn)
def wrapper(*args):
def wrapper(*args: Any) -> torch.Tensor:
return fn(*args)[0]
return wrapper
......@@ -68,7 +72,7 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
super().__init__(epsilon, dtype, device)
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
def get_inputs(self):
def get_inputs(self) -> list[torch.Tensor]:
input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
arg3_1 = torch.empty([4], device=self.device, dtype=self.dtype)
......@@ -78,7 +82,7 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
def pattern(
input: torch.Tensor,
arg3_1: torch.Tensor,
):
) -> tuple[torch.Tensor, torch.Tensor]:
all_reduce = self._all_reduce(input)
rmsnorm = self.rmsnorm_matcher(all_reduce, arg3_1)
......@@ -87,7 +91,7 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
def replacement(
input: torch.Tensor,
arg3_1: torch.Tensor,
):
) -> tuple[torch.Tensor, torch.Tensor]:
reduce_scatter = self._reduce_scatter(input)
rmsnorm = self.rmsnorm_matcher(reduce_scatter, arg3_1)
......@@ -100,11 +104,11 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None):
def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None:
super().__init__(epsilon, dtype, device)
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
def get_inputs(self):
def get_inputs(self) -> list[torch.Tensor]:
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
......@@ -116,7 +120,7 @@ class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
rms_norm_weights,
]
def register(self, pm_pass: PatternMatcherPass):
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
residual: torch.Tensor,
mm_1: torch.Tensor,
......@@ -163,23 +167,23 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
epsilon: float,
dtype: torch.dtype,
device: str | None,
):
) -> None:
super().__init__(epsilon, dtype, device)
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
def get_inputs(self):
def get_inputs(self) -> list[torch.Tensor]:
input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype)
weight = torch.empty([4], device=self.device, dtype=self.dtype)
scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
return [input, weight, scale]
def register(self, pm_pass: PatternMatcherPass):
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
) -> tuple[torch.Tensor, torch.Tensor]:
all_reduce = self._all_reduce(input)
rms = self.rmsnorm_matcher(all_reduce, weight)
quant, _ = self.quant_matcher(rms, scale)
......@@ -189,7 +193,7 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
) -> tuple[torch.Tensor, torch.Tensor]:
reduce_scatter = self._reduce_scatter(input)
rms = self.rmsnorm_matcher(reduce_scatter, weight)
quant, _ = self.quant_matcher(rms, scale)
......@@ -203,12 +207,12 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None):
def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None:
super().__init__(epsilon, dtype, device)
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
def get_inputs(self):
def get_inputs(self) -> list[torch.Tensor]:
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
......@@ -216,7 +220,7 @@ class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
return [residual, mm_1, rms_norm_weights, scale]
def register(self, pm_pass: PatternMatcherPass):
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
residual: torch.Tensor,
mm_1: torch.Tensor,
......@@ -302,7 +306,7 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
"""
@enable_fake_mode
def __init__(self, config: VllmConfig):
def __init__(self, config: VllmConfig) -> None:
super().__init__(config)
# Used to clean up redundant views created temporarily
......@@ -357,7 +361,7 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
return (compile_range.is_single_size()) and (compile_range.end % tp_size == 0)
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph):
def __call__(self, graph: fx.Graph) -> None:
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
# Clean up reshape nodes
......
......@@ -1529,22 +1529,22 @@ def patch_tensor_parallel_group(tp_group: GroupCoordinator):
_TP = old_tp_group
def get_tensor_model_parallel_world_size():
def get_tensor_model_parallel_world_size() -> int:
"""Return world size for the tensor model parallel group."""
return get_tp_group().world_size
def get_tensor_model_parallel_rank():
def get_tensor_model_parallel_rank() -> int:
"""Return my rank for the tensor model parallel group."""
return get_tp_group().rank_in_group
def get_decode_context_model_parallel_world_size():
def get_decode_context_model_parallel_world_size() -> int:
"""Return world size for the decode context model parallel group."""
return get_dcp_group().world_size
def get_decode_context_model_parallel_rank():
def get_decode_context_model_parallel_rank() -> int:
"""Return my rank for the decode context model parallel group."""
return get_dcp_group().rank_in_group
......
......@@ -20,7 +20,9 @@ from .phi3_long_rope_scaled_rope import Phi3LongRoPEScaledRotaryEmbedding
from .xdrope import XDRotaryEmbedding
from .yarn_scaling_rope import YaRNScalingRotaryEmbedding
_ROPE_DICT: dict[tuple, RotaryEmbedding] = {}
_ROPE_DICT: dict[tuple[Any, ...], RotaryEmbedding] = {}
__all__ = ["RotaryEmbedding"]
def get_rope(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment