Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ad8818bb
Unverified
Commit
ad8818bb
authored
Jan 12, 2026
by
Lucas Kabela
Committed by
GitHub
Jan 12, 2026
Browse files
[Misc][BE] Type coverage for vllm/compilation [3/3] (#31748)
Signed-off-by:
Lucas Kabela
<
lucaskabela@meta.com
>
parent
08e8e99c
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
333 additions
and
280 deletions
+333
-280
vllm/compilation/activation_quant_fusion.py
vllm/compilation/activation_quant_fusion.py
+33
-28
vllm/compilation/collective_fusion.py
vllm/compilation/collective_fusion.py
+111
-102
vllm/compilation/fusion.py
vllm/compilation/fusion.py
+58
-38
vllm/compilation/fusion_attn.py
vllm/compilation/fusion_attn.py
+24
-20
vllm/compilation/inductor_pass.py
vllm/compilation/inductor_pass.py
+1
-1
vllm/compilation/matcher_utils.py
vllm/compilation/matcher_utils.py
+23
-22
vllm/compilation/qk_norm_rope_fusion.py
vllm/compilation/qk_norm_rope_fusion.py
+16
-10
vllm/compilation/rocm_aiter_fusion.py
vllm/compilation/rocm_aiter_fusion.py
+38
-36
vllm/compilation/sequence_parallelism.py
vllm/compilation/sequence_parallelism.py
+22
-18
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+4
-4
vllm/model_executor/layers/rotary_embedding/__init__.py
vllm/model_executor/layers/rotary_embedding/__init__.py
+3
-1
No files found.
vllm/compilation/activation_quant_fusion.py
View file @
ad8818bb
...
...
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
ABC
,
abstractmethod
from
typing
import
Any
import
torch
from
torch._higher_order_ops.auto_functionalize
import
auto_functionalized
...
...
@@ -52,7 +53,7 @@ class ActivationQuantPattern(ABC):
def
__init__
(
self
,
quant_key
:
QuantKey
,
):
)
->
None
:
self
.
quant_key
=
quant_key
self
.
quant_dtype
=
quant_key
.
dtype
...
...
@@ -68,12 +69,12 @@ class ActivationQuantPattern(ABC):
self
.
silu_and_mul_matcher
=
MatcherSiluAndMul
()
def
empty_quant
(
self
,
*
args
,
**
kwargs
)
:
def
empty_quant
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
torch
.
Tensor
:
kwargs
=
{
"dtype"
:
self
.
quant_dtype
,
"device"
:
"cuda"
,
**
kwargs
}
return
torch
.
empty
(
*
args
,
**
kwargs
)
@
abstractmethod
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
raise
NotImplementedError
...
...
@@ -82,15 +83,22 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
Fusion for SiluMul+Fp8StaticQuant Pattern
"""
def
__init__
(
self
):
def
__init__
(
self
)
->
None
:
super
().
__init__
(
kFp8StaticTensorSym
)
self
.
quant_matcher
=
MatcherQuantFP8
(
kFp8StaticTensorSym
)
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
get_inputs
(
self
)
->
list
[
torch
.
Tensor
]:
scale
=
self
.
quant_matcher
.
inputs
()[
1
]
return
[
*
self
.
silu_and_mul_matcher
.
inputs
(),
# input
scale
,
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
input
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
):
)
->
torch
.
Tensor
:
result_silu_mul
=
self
.
silu_and_mul_matcher
(
input
)
result_quant
=
self
.
quant_matcher
(
result_silu_mul
,
scale
)
return
result_quant
[
0
]
...
...
@@ -98,7 +106,7 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
def
replacement
(
input
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
):
)
->
torch
.
Tensor
:
d
=
input
.
shape
[
-
1
]
//
2
output_shape
=
input
.
shape
[:
-
1
]
+
(
d
,)
result
=
torch
.
empty
(
...
...
@@ -109,13 +117,10 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
)
return
at
[
1
]
inputs
=
[
*
self
.
silu_and_mul_matcher
.
inputs
(),
# input
self
.
quant_matcher
.
inputs
()[
1
],
# scale
]
pattern
(
*
inputs
)
inps
=
self
.
get_inputs
()
pattern
(
*
inps
)
register_replacement
(
pattern
,
replacement
,
inp
ut
s
,
fwd_only
,
pm_pass
)
register_replacement
(
pattern
,
replacement
,
inps
,
fwd_only
,
pm_pass
)
class
SiluMulNvfp4QuantPattern
(
ActivationQuantPattern
):
...
...
@@ -123,16 +128,23 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
Fusion for SiluMul+Nvfp4Quant Pattern
"""
def
__init__
(
self
):
def
__init__
(
self
)
->
None
:
super
().
__init__
(
kNvfp4Quant
)
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
get_inputs
(
self
)
->
list
[
torch
.
Tensor
]:
result
=
self
.
empty_quant
(
5
,
32
)
output_scale
=
empty_i32
(
128
,
4
)
input_
=
empty_bf16
(
5
,
64
)
scale
=
empty_fp32
(
1
,
1
)
return
[
result
,
output_scale
,
input_
,
scale
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
result
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
result_silu_mul
=
self
.
silu_and_mul_matcher
(
input
)
at
=
auto_functionalized
(
self
.
QUANT_OP
,
...
...
@@ -148,7 +160,7 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
output_scale
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
at
=
auto_functionalized
(
self
.
FUSED_OP
,
result
=
result
,
...
...
@@ -158,14 +170,7 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
)
return
at
[
1
],
at
[
2
]
inputs
=
[
self
.
empty_quant
(
5
,
32
),
# result
empty_i32
(
128
,
4
),
# output_scale
empty_bf16
(
5
,
64
),
# input
empty_fp32
(
1
,
1
),
# scale
]
register_replacement
(
pattern
,
replacement
,
inputs
,
fwd_only
,
pm_pass
)
register_replacement
(
pattern
,
replacement
,
self
.
get_inputs
(),
fwd_only
,
pm_pass
)
class
ActivationQuantFusionPass
(
VllmPatternMatcherPass
):
...
...
@@ -179,7 +184,7 @@ class ActivationQuantFusionPass(VllmPatternMatcherPass):
"""
@
enable_fake_mode
def
__init__
(
self
,
config
:
VllmConfig
):
def
__init__
(
self
,
config
:
VllmConfig
)
->
None
:
super
().
__init__
(
config
)
self
.
patterns
:
PatternMatcherPass
=
PatternMatcherPass
(
...
...
@@ -196,11 +201,11 @@ class ActivationQuantFusionPass(VllmPatternMatcherPass):
self
.
dump_patterns
(
config
,
self
.
patterns
)
@
VllmInductorPass
.
time_and_log
def
__call__
(
self
,
graph
:
torch
.
fx
.
Graph
):
def
__call__
(
self
,
graph
:
torch
.
fx
.
Graph
)
->
None
:
self
.
matched_count
=
self
.
patterns
.
apply
(
graph
)
logger
.
debug
(
"Replaced %s patterns"
,
self
.
matched_count
)
def
uuid
(
self
):
def
uuid
(
self
)
->
str
:
return
VllmInductorPass
.
hash_source
(
self
,
ActivationQuantPattern
,
...
...
vllm/compilation/collective_fusion.py
View file @
ad8818bb
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
importlib.util
import
find_spec
from
types
import
ModuleType
import
torch
import
torch._inductor.pattern_matcher
as
pm
...
...
@@ -33,15 +34,15 @@ if find_spec("flashinfer"):
try
:
import
flashinfer.comm
as
flashinfer_comm
flashinfer_comm
=
(
flashinfer_comm
:
ModuleType
|
None
=
(
# type: ignore[no-redef]
flashinfer_comm
if
hasattr
(
flashinfer_comm
,
"trtllm_allreduce_fusion"
)
else
None
)
except
ImportError
:
flashinfer_comm
=
None
flashinfer_comm
=
None
# type: ignore[assignment]
else
:
flashinfer_comm
=
None
flashinfer_comm
=
None
# type: ignore[assignment]
logger
=
init_logger
(
__name__
)
...
...
@@ -58,13 +59,13 @@ class BasePattern:
class
GEMMReduceScatterPattern
(
BasePattern
):
def
get_inputs
(
self
):
def
get_inputs
(
self
)
->
list
[
torch
.
Tensor
]
:
mul
=
torch
.
empty
([
16
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
mm_weight
=
torch
.
empty
([
4
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
return
[
mul
,
mm_weight
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
pattern
(
mul
:
torch
.
Tensor
,
mm_weight
:
torch
.
Tensor
):
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
mul
:
torch
.
Tensor
,
mm_weight
:
torch
.
Tensor
)
->
torch
.
Tensor
:
mm
=
torch
.
ops
.
aten
.
mm
.
default
(
mul
,
mm_weight
)
reduce_scatter
=
torch
.
ops
.
vllm
.
reduce_scatter
.
default
(
mm
,
...
...
@@ -74,7 +75,7 @@ class GEMMReduceScatterPattern(BasePattern):
)
return
reduce_scatter
def
replacement
(
mul
:
torch
.
Tensor
,
mm_weight
:
torch
.
Tensor
):
def
replacement
(
mul
:
torch
.
Tensor
,
mm_weight
:
torch
.
Tensor
)
->
torch
.
Tensor
:
gemm_rs
=
torch
.
ops
.
symm_mem
.
fused_matmul_reduce_scatter
(
mul
,
mm_weight
,
...
...
@@ -91,17 +92,17 @@ class GEMMReduceScatterPattern(BasePattern):
class
AllGatherGEMMPattern
(
BasePattern
):
def
get_inputs
(
self
):
def
get_inputs
(
self
)
->
list
[
torch
.
Tensor
]
:
x
=
torch
.
empty
([
4
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
weight
=
torch
.
empty
([
4
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
return
[
x
,
weight
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
)
->
torch
.
Tensor
:
all_gather
=
torch
.
ops
.
vllm
.
all_gather
.
default
(
x
,
dim
=
0
,
...
...
@@ -111,9 +112,7 @@ class AllGatherGEMMPattern(BasePattern):
return
torch
.
ops
.
aten
.
mm
.
default
(
all_gather
,
weight
)
def
replacement
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
def
replacement
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
)
->
torch
.
Tensor
:
ag_output
,
mm_outputs
=
torch
.
ops
.
symm_mem
.
fused_all_gather_matmul
(
x
,
[
weight
],
...
...
@@ -128,7 +127,7 @@ class AllGatherGEMMPattern(BasePattern):
class
ScaledMMReduceScatterPattern
(
BasePattern
):
def
get_inputs
(
self
):
def
get_inputs
(
self
)
->
list
[
torch
.
Tensor
]
:
input
=
torch
.
empty
([
16
,
16
],
device
=
self
.
device
,
dtype
=
FP8_DTYPE
)
mm_weight
=
(
torch
.
empty
([
16
,
16
],
device
=
self
.
device
,
dtype
=
FP8_DTYPE
)
...
...
@@ -139,7 +138,7 @@ class ScaledMMReduceScatterPattern(BasePattern):
scale_b
=
torch
.
empty
([
1
,
16
],
device
=
self
.
device
,
dtype
=
torch
.
float32
)
return
[
input
,
mm_weight
,
scale_a
,
scale_b
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
input
:
torch
.
Tensor
,
mat2
:
torch
.
Tensor
,
...
...
@@ -196,7 +195,7 @@ class ScaledMMReduceScatterPattern(BasePattern):
class
AllGatherScaledMMPattern
(
BasePattern
):
def
get_inputs
(
self
):
def
get_inputs
(
self
)
->
list
[
torch
.
Tensor
]
:
x
=
torch
.
empty
([
8
,
16
],
device
=
self
.
device
,
dtype
=
FP8_DTYPE
)
weight
=
(
torch
.
empty
([
16
,
16
],
device
=
self
.
device
,
dtype
=
FP8_DTYPE
)
...
...
@@ -211,7 +210,7 @@ class AllGatherScaledMMPattern(BasePattern):
return
[
x
,
weight
,
scale_a
,
scale_b
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
...
...
@@ -258,7 +257,7 @@ class AllGatherScaledMMPattern(BasePattern):
class
CutlassScaledMMReduceScatterPattern
(
BasePattern
):
def
get_inputs
(
self
):
def
get_inputs
(
self
)
->
list
[
torch
.
Tensor
]
:
input
=
torch
.
empty
([
16
,
16
],
device
=
self
.
device
,
dtype
=
FP8_DTYPE
)
mm_weight
=
(
torch
.
empty
([
16
,
16
],
device
=
self
.
device
,
dtype
=
FP8_DTYPE
)
...
...
@@ -271,7 +270,7 @@ class CutlassScaledMMReduceScatterPattern(BasePattern):
cutlass_mm_output
=
torch
.
empty
([
16
,
16
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
return
[
input
,
mm_weight
,
scale_a
,
scale_b
,
cutlass_mm_output
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
...
...
@@ -331,7 +330,7 @@ class CutlassScaledMMReduceScatterPattern(BasePattern):
class
AllGatherCutlassScaledMMPattern
(
BasePattern
):
def
get_inputs
(
self
):
def
get_inputs
(
self
)
->
list
[
torch
.
Tensor
]
:
x
=
torch
.
empty
([
8
,
16
],
device
=
self
.
device
,
dtype
=
FP8_DTYPE
)
weight
=
(
torch
.
empty
([
16
,
16
],
device
=
self
.
device
,
dtype
=
FP8_DTYPE
)
...
...
@@ -349,7 +348,7 @@ class AllGatherCutlassScaledMMPattern(BasePattern):
return
[
x
,
weight
,
scale_a
,
scale_b
,
output
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
...
...
@@ -400,7 +399,7 @@ class AllGatherCutlassScaledMMPattern(BasePattern):
class
AsyncTPPass
(
VllmPatternMatcherPass
):
@
enable_fake_mode
def
__init__
(
self
,
config
:
VllmConfig
):
def
__init__
(
self
,
config
:
VllmConfig
)
->
None
:
super
().
__init__
(
config
)
# Enable symmetric memory for the TP process group
...
...
@@ -445,7 +444,7 @@ class AsyncTPPass(VllmPatternMatcherPass):
return
compile_range
.
is_single_size
()
and
compile_range
.
end
%
tp_size
==
0
@
VllmInductorPass
.
time_and_log
def
__call__
(
self
,
graph
:
fx
.
Graph
):
def
__call__
(
self
,
graph
:
fx
.
Graph
)
->
None
:
self
.
matched_count
=
self
.
patterns
.
apply
(
graph
)
logger
.
debug
(
"Replaced %s patterns"
,
self
.
matched_count
)
...
...
@@ -512,11 +511,13 @@ if flashinfer_comm is not None:
f
"max token num
{
max_token_num
}
* hidden size
{
hidden_size
}
* "
f
"element size
{
element_size
}
"
)
device_capability
=
current_platform
.
get_device_capability
().
to_int
()
curr_device
=
current_platform
.
get_device_capability
()
device_capability
=
curr_device
.
to_int
()
if
curr_device
is
not
None
else
None
# Get one shot input size limit for the current world size
# for the current device capability
max_one_shot_size
=
_FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB
.
get
(
device_capability
,
{}
device_capability
,
# type: ignore[arg-type]
{},
).
get
(
world_size
,
None
)
# Use one shot if no max size is specified
use_oneshot
=
(
...
...
@@ -606,7 +607,7 @@ class FlashInferFusedAllReduceParams:
world_size
:
int
,
use_fp32_lamport
:
bool
=
False
,
max_token_num
:
int
=
1024
,
):
)
->
None
:
self
.
rank
=
rank
self
.
world_size
=
world_size
self
.
use_fp32_lamport
=
use_fp32_lamport
...
...
@@ -615,7 +616,7 @@ class FlashInferFusedAllReduceParams:
self
.
fp32_acc
=
True
self
.
max_token_num
=
max_token_num
def
get_trtllm_fused_allreduce_kwargs
(
self
):
def
get_trtllm_fused_allreduce_kwargs
(
self
)
->
dict
[
str
,
bool
|
int
]
:
return
{
"world_rank"
:
self
.
rank
,
"world_size"
:
self
.
world_size
,
...
...
@@ -639,26 +640,30 @@ class AllReduceRMSNormPattern(BasePattern):
dtype
:
torch
.
dtype
,
device
:
str
|
None
,
allreduce_params
:
FlashInferFusedAllReduceParams
,
):
)
->
None
:
super
().
__init__
(
dtype
,
device
)
self
.
epsilon
=
epsilon
self
.
allreduce_params
=
allreduce_params
self
.
rmsnorm_matcher
=
MatcherRMSNorm
(
epsilon
)
def
get_inputs
(
self
):
def
get_inputs
(
self
)
->
list
[
torch
.
Tensor
]
:
input
,
weight
=
self
.
rmsnorm_matcher
.
inputs
()
# input goes through allreduce first, always 16-bit
return
[
input
.
to
(
self
.
dtype
),
weight
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
pattern
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
):
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
allreduce_output
=
tensor_model_parallel_all_reduce
(
input
)
rms
=
self
.
rmsnorm_matcher
(
allreduce_output
,
weight
)
return
rms
,
allreduce_output
def
replacement
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
):
def
replacement
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
residual
=
torch
.
zeros_like
(
input
)
rms_result
=
torch
.
empty_like
(
input
)
allreduce
=
auto_functionalized
(
...
...
@@ -694,27 +699,29 @@ class AllReduceFusedAddRMSNormPattern(BasePattern):
dtype
:
torch
.
dtype
,
device
:
str
|
None
,
allreduce_params
:
FlashInferFusedAllReduceParams
,
):
)
->
None
:
super
().
__init__
(
dtype
,
device
)
self
.
epsilon
=
epsilon
self
.
allreduce_params
=
allreduce_params
self
.
rmsnorm_matcher
=
MatcherFusedAddRMSNorm
(
epsilon
)
def
get_inputs
(
self
):
def
get_inputs
(
self
)
->
list
[
torch
.
Tensor
]
:
input
,
residual
,
weight
=
self
.
rmsnorm_matcher
.
inputs
()
# input goes through allreduce first, always 16-bit
return
[
residual
,
input
.
to
(
self
.
dtype
),
weight
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
pattern
(
residual
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
):
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
residual
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
allreduce_output
=
tensor_model_parallel_all_reduce
(
input
)
rms
,
residual
=
self
.
rmsnorm_matcher
(
allreduce_output
,
weight
,
residual
)
return
rms
,
residual
def
replacement
(
residual
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
allreduce
=
auto_functionalized
(
flashinfer_trtllm_fused_allreduce_norm
,
allreduce_in
=
input
,
...
...
@@ -739,8 +746,8 @@ class AllReduceFusedAddRMSNormPattern(BasePattern):
first_return_only
=
lambda
fn
:
lambda
a
,
b
,
c
:
fn
(
a
,
b
,
c
)[
0
]
pm
.
register_replacement
(
first_return_only
(
pattern
),
first_return_only
(
replacement
),
first_return_only
(
pattern
),
# type: ignore[no-untyped-call]
first_return_only
(
replacement
),
# type: ignore[no-untyped-call]
self
.
get_inputs
(),
pm
.
fwd_only
,
pm_pass
,
...
...
@@ -761,7 +768,7 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
dtype
:
torch
.
dtype
,
device
:
str
|
None
,
allreduce_params
:
FlashInferFusedAllReduceParams
,
):
)
->
None
:
super
().
__init__
(
dtype
,
device
)
self
.
epsilon
=
epsilon
self
.
allreduce_params
=
allreduce_params
...
...
@@ -769,25 +776,27 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
self
.
rmsnorm_matcher
=
MatcherRMSNorm
(
epsilon
)
self
.
quant_matcher
=
MatcherQuantFP8
(
kFp8StaticTensorSym
)
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
get_inputs
():
input
,
weight
=
self
.
rmsnorm_matcher
.
inputs
()
_
,
scale
=
self
.
quant_matcher
.
inputs
()
def
get_inputs
(
self
)
->
list
[
torch
.
Tensor
]:
input
,
weight
=
self
.
rmsnorm_matcher
.
inputs
()
_
,
scale
=
self
.
quant_matcher
.
inputs
()
# input goes through allreduce first, always 16-bit
return
[
input
.
to
(
self
.
dtype
),
weight
,
scale
]
# input goes through allreduce first, always 16-bit
return
[
input
.
to
(
self
.
dtype
),
weight
,
scale
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
all_reduce
=
tensor_model_parallel_all_reduce
(
input
)
rms
=
self
.
rmsnorm_matcher
(
all_reduce
,
weight
)
quant
,
_
=
self
.
quant_matcher
(
rms
,
scale
)
return
quant
,
all_reduce
def
replacement
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
):
def
replacement
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
residual
=
torch
.
zeros_like
(
input
)
result_rms
=
torch
.
empty_like
(
input
)
result_quant
=
torch
.
empty_like
(
input
,
dtype
=
self
.
quant_dtype
)
...
...
@@ -812,7 +821,7 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
return
allreduce
[
4
],
allreduce
[
1
]
pm
.
register_replacement
(
pattern
,
replacement
,
get_inputs
(),
pm
.
fwd_only
,
pm_pass
pattern
,
replacement
,
self
.
get_inputs
(),
pm
.
fwd_only
,
pm_pass
)
...
...
@@ -830,7 +839,7 @@ class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
dtype
:
torch
.
dtype
,
device
:
str
|
None
,
allreduce_params
:
FlashInferFusedAllReduceParams
,
):
)
->
None
:
super
().
__init__
(
dtype
,
device
)
self
.
epsilon
=
epsilon
self
.
allreduce_params
=
allreduce_params
...
...
@@ -839,20 +848,20 @@ class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
self
.
rmsnorm_matcher
=
MatcherFusedAddRMSNorm
(
epsilon
)
self
.
quant_matcher
=
MatcherQuantFP8
(
kFp8StaticTensorSym
)
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
get_inputs
():
input
,
residual
,
weight
=
self
.
rmsnorm_matcher
.
inputs
()
_
,
scale
=
self
.
quant_matcher
.
inputs
()
def
get_inputs
(
self
)
->
list
[
torch
.
Tensor
]:
input
,
residual
,
weight
=
self
.
rmsnorm_matcher
.
inputs
()
_
,
scale
=
self
.
quant_matcher
.
inputs
()
# input goes through allreduce first, always 16-bit
return
[
residual
,
input
.
to
(
self
.
dtype
),
weight
,
scale
]
# input goes through allreduce first, always 16-bit
return
[
residual
,
input
.
to
(
self
.
dtype
),
weight
,
scale
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
residual
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
allreduce_output
=
tensor_model_parallel_all_reduce
(
input
)
rms
,
res
=
self
.
rmsnorm_matcher
(
allreduce_output
,
weight
,
residual
)
quant
,
_
=
self
.
quant_matcher
(
rms
,
scale
)
...
...
@@ -864,7 +873,7 @@ class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
result_quant
=
torch
.
empty_like
(
input
,
dtype
=
self
.
quant_dtype
)
allreduce
=
auto_functionalized
(
flashinfer_trtllm_fused_allreduce_norm
,
...
...
@@ -886,7 +895,7 @@ class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
return
allreduce
[
4
],
allreduce
[
2
]
pm
.
register_replacement
(
pattern
,
replacement
,
get_inputs
(),
pm
.
fwd_only
,
pm_pass
pattern
,
replacement
,
self
.
get_inputs
(),
pm
.
fwd_only
,
pm_pass
)
...
...
@@ -904,31 +913,31 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
dtype
:
torch
.
dtype
,
device
:
str
|
None
,
allreduce_params
:
FlashInferFusedAllReduceParams
,
):
)
->
None
:
super
().
__init__
(
dtype
,
device
)
self
.
epsilon
=
epsilon
self
.
allreduce_params
=
allreduce_params
self
.
rmsnorm_matcher
=
MatcherRMSNorm
(
epsilon
)
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
get_inputs
():
input
=
torch
.
empty
([
1
,
16
,
16
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
quant_result
=
torch
.
empty
((
16
,
8
),
device
=
self
.
device
,
dtype
=
torch
.
uint8
)
input_global_scale
=
torch
.
empty
(
[
1
,
1
],
device
=
self
.
device
,
dtype
=
torch
.
float32
)
weight
=
torch
.
empty
([
16
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
output_scale
=
torch
.
empty
([
128
,
4
],
device
=
self
.
device
,
dtype
=
torch
.
int32
)
def
get_inputs
(
self
)
->
list
[
torch
.
Tensor
]:
input
=
torch
.
empty
([
1
,
16
,
16
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
quant_result
=
torch
.
empty
((
16
,
8
),
device
=
self
.
device
,
dtype
=
torch
.
uint8
)
input_global_scale
=
torch
.
empty
(
[
1
,
1
],
device
=
self
.
device
,
dtype
=
torch
.
float32
)
weight
=
torch
.
empty
([
16
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
output_scale
=
torch
.
empty
([
128
,
4
],
device
=
self
.
device
,
dtype
=
torch
.
int32
)
return
[
input
,
quant_result
,
weight
,
input_global_scale
,
output_scale
]
return
[
input
,
quant_result
,
weight
,
input_global_scale
,
output_scale
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
input
:
torch
.
Tensor
,
quant_result
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
input_global_scale
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
:
all_reduce
=
tensor_model_parallel_all_reduce
(
input
)
rms
=
self
.
rmsnorm_matcher
(
all_reduce
,
weight
)
quant_out_tuple
=
auto_functionalized
(
...
...
@@ -948,7 +957,7 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
weight
:
torch
.
Tensor
,
input_global_scale
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
:
residual
=
torch
.
zeros_like
(
input
)
result_rms
=
torch
.
empty_like
(
input
)
allreduce
=
auto_functionalized
(
...
...
@@ -972,7 +981,7 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
return
allreduce
[
4
],
allreduce
[
1
],
allreduce
[
5
]
pm
.
register_replacement
(
pattern
,
replacement
,
get_inputs
(),
pm
.
fwd_only
,
pm_pass
pattern
,
replacement
,
self
.
get_inputs
(),
pm
.
fwd_only
,
pm_pass
)
...
...
@@ -990,33 +999,33 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
dtype
:
torch
.
dtype
,
device
:
str
|
None
,
allreduce_params
:
FlashInferFusedAllReduceParams
,
):
)
->
None
:
super
().
__init__
(
dtype
,
device
)
self
.
epsilon
=
epsilon
self
.
allreduce_params
=
allreduce_params
self
.
rmsnorm_matcher
=
MatcherFusedAddRMSNorm
(
epsilon
)
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
get_inputs
():
input
=
torch
.
empty
([
16
,
16
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
residual
=
torch
.
empty
([
16
,
16
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
weight
=
torch
.
empty
([
16
,
16
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
quant_result
=
torch
.
empty
((
16
,
8
),
device
=
self
.
device
,
dtype
=
torch
.
uint8
)
input_global_scale
=
torch
.
empty
(
[
1
,
1
],
device
=
self
.
device
,
dtype
=
torch
.
float32
)
output_scale
=
torch
.
empty
([
128
,
4
],
device
=
self
.
device
,
dtype
=
torch
.
int32
)
return
[
quant_result
,
residual
,
input
,
output_scale
,
weight
,
input_global_scale
,
]
def
get_inputs
(
self
)
->
list
[
torch
.
Tensor
]:
input
=
torch
.
empty
([
16
,
16
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
residual
=
torch
.
empty
([
16
,
16
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
weight
=
torch
.
empty
([
16
,
16
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
quant_result
=
torch
.
empty
((
16
,
8
),
device
=
self
.
device
,
dtype
=
torch
.
uint8
)
input_global_scale
=
torch
.
empty
(
[
1
,
1
],
device
=
self
.
device
,
dtype
=
torch
.
float32
)
output_scale
=
torch
.
empty
([
128
,
4
],
device
=
self
.
device
,
dtype
=
torch
.
int32
)
return
[
quant_result
,
residual
,
input
,
output_scale
,
weight
,
input_global_scale
,
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
quant_result
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
...
...
@@ -1024,7 +1033,7 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
output_scale
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
input_global_scale
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
:
allreduce_output
=
tensor_model_parallel_all_reduce
(
input
)
rms
,
residual
=
self
.
rmsnorm_matcher
(
allreduce_output
,
weight
,
residual
)
quant_out_tuple
=
auto_functionalized
(
...
...
@@ -1045,7 +1054,7 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
output_scale
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
input_global_scale
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
:
allreduce
=
auto_functionalized
(
flashinfer_trtllm_fused_allreduce_norm
,
allreduce_in
=
input
,
...
...
@@ -1066,12 +1075,12 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
return
allreduce
[
4
],
allreduce
[
2
],
allreduce
[
5
]
pm
.
register_replacement
(
pattern
,
replacement
,
get_inputs
(),
pm
.
fwd_only
,
pm_pass
pattern
,
replacement
,
self
.
get_inputs
(),
pm
.
fwd_only
,
pm_pass
)
class
AllReduceFusionPass
(
VllmPatternMatcherPass
):
def
__init__
(
self
,
config
:
VllmConfig
):
def
__init__
(
self
,
config
:
VllmConfig
)
->
None
:
super
().
__init__
(
config
)
self
.
disabled
=
True
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
...
...
@@ -1122,7 +1131,7 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
)
self
.
ipc_handles
,
workspace_tensor
=
(
flashinfer_comm
.
trtllm_create_ipc_workspace_for_all_reduce_fusion
(
flashinfer_comm
.
trtllm_create_ipc_workspace_for_all_reduce_fusion
(
# type: ignore[misc]
tp_rank
=
rank
,
tp_size
=
self
.
tp_size
,
max_token_num
=
self
.
max_token_num
,
...
...
@@ -1145,7 +1154,7 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
self
.
dump_patterns
(
config
,
self
.
patterns
)
@
enable_fake_mode
def
register_patterns
(
self
):
def
register_patterns
(
self
)
->
None
:
for
epsilon
in
[
1e-5
,
1e-6
]:
AllReduceFusedRMSNormStaticQuantFP8Pattern
(
epsilon
,
...
...
@@ -1198,7 +1207,7 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
return
compile_range
.
end
<=
self
.
max_token_num
@
VllmInductorPass
.
time_and_log
def
__call__
(
self
,
graph
:
fx
.
Graph
):
def
__call__
(
self
,
graph
:
fx
.
Graph
)
->
None
:
if
self
.
disabled
:
logger
.
debug
(
"AllReduceFusionPass disabled"
)
return
...
...
@@ -1206,7 +1215,7 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
self
.
matched_count
=
self
.
patterns
.
apply
(
graph
)
logger
.
debug
(
"Replaced %s patterns"
,
self
.
matched_count
)
def
__del__
(
self
):
def
__del__
(
self
)
->
None
:
if
getattr
(
self
,
"disabled"
,
True
):
return
if
flashinfer_comm
is
not
None
:
...
...
vllm/compilation/fusion.py
View file @
ad8818bb
...
...
@@ -38,19 +38,19 @@ FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE
=
torch
.
uint8
def
empty_bf16
(
*
args
,
**
kwargs
)
:
def
empty_bf16
(
*
args
:
Any
,
**
kwargs
:
Any
)
->
torch
.
Tensor
:
return
torch
.
empty
(
*
args
,
**
kwargs
,
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
)
def
empty_fp32
(
*
args
,
**
kwargs
)
:
def
empty_fp32
(
*
args
:
Any
,
**
kwargs
:
Any
)
->
torch
.
Tensor
:
return
torch
.
empty
(
*
args
,
**
kwargs
,
dtype
=
torch
.
float32
,
device
=
"cuda"
)
def
empty_i32
(
*
args
,
**
kwargs
)
:
def
empty_i32
(
*
args
:
Any
,
**
kwargs
:
Any
)
->
torch
.
Tensor
:
return
torch
.
empty
(
*
args
,
**
kwargs
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
def
empty_i64
(
*
args
,
**
kwargs
)
:
def
empty_i64
(
*
args
:
Any
,
**
kwargs
:
Any
)
->
torch
.
Tensor
:
return
torch
.
empty
(
*
args
,
**
kwargs
,
dtype
=
torch
.
int64
,
device
=
"cuda"
)
...
...
@@ -79,7 +79,7 @@ class FusedRMSQuantKey(NamedTuple):
quant
:
QuantKey
fused_add
:
bool
def
__str__
(
self
):
def
__str__
(
self
)
->
str
:
return
(
f
"FusedQuantKey(
{
self
.
quant
}
, with"
f
"
{
''
if
self
.
fused_add
else
'out'
}
residual)"
...
...
@@ -121,7 +121,7 @@ class RMSNormQuantPattern:
key
:
FusedRMSQuantKey
,
has_col_major_scales
:
bool
=
False
,
is_e8m0
:
bool
=
False
,
):
)
->
None
:
self
.
epsilon
=
epsilon
self
.
quant_dtype
=
key
.
quant
.
dtype
config
=
get_current_vllm_config
()
...
...
@@ -141,7 +141,9 @@ class RMSNormQuantPattern:
class
RMSNormStaticQuantPattern
(
RMSNormQuantPattern
):
def
__init__
(
self
,
epsilon
:
float
,
quant_dtype
:
torch
.
dtype
,
symmetric
=
True
):
def
__init__
(
self
,
epsilon
:
float
,
quant_dtype
:
torch
.
dtype
,
symmetric
:
bool
=
True
)
->
None
:
fused_key
=
FusedRMSQuantKey
(
fused_add
=
False
,
quant
=
QuantKey
(
...
...
@@ -150,13 +152,17 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
)
super
().
__init__
(
epsilon
,
fused_key
)
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
# Cannot use methods, as the self argument affects tracing
def
pattern
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
):
def
pattern
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
)
->
torch
.
Tensor
:
result_rms
=
self
.
rmsnorm_matcher
(
input
,
weight
)
return
self
.
quant_matcher
(
result_rms
,
scale
)[
0
]
def
replacement
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
):
def
replacement
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input
=
input
.
to
(
dtype
=
self
.
model_dtype
)
...
...
@@ -187,7 +193,9 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
class
FusedAddRMSNormStaticQuantPattern
(
RMSNormQuantPattern
):
def
__init__
(
self
,
epsilon
:
float
,
quant_dtype
:
torch
.
dtype
,
symmetric
=
True
):
def
__init__
(
self
,
epsilon
:
float
,
quant_dtype
:
torch
.
dtype
,
symmetric
:
bool
=
True
)
->
None
:
key
=
FusedRMSQuantKey
(
fused_add
=
True
,
quant
=
QuantKey
(
...
...
@@ -196,13 +204,13 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
)
super
().
__init__
(
epsilon
,
key
)
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
result_rms
,
residual
=
self
.
rmsnorm_matcher
(
input
,
weight
,
residual
)
result
,
_
=
self
.
quant_matcher
(
result_rms
,
scale
)
...
...
@@ -213,7 +221,7 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
weight
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input
=
input
.
to
(
dtype
=
self
.
model_dtype
)
...
...
@@ -253,10 +261,10 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
epsilon
:
float
,
quant_dtype
:
torch
.
dtype
,
group_shape
:
GroupShape
,
symmetric
=
True
,
symmetric
:
bool
=
True
,
has_col_major_scales
:
bool
=
False
,
is_e8m0
:
bool
=
False
,
):
)
->
None
:
scale
=
ScaleDesc
(
torch
.
float32
,
False
,
group_shape
)
key
=
FusedRMSQuantKey
(
fused_add
=
True
,
...
...
@@ -269,15 +277,17 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
epsilon
,
key
,
has_col_major_scales
=
has_col_major_scales
,
is_e8m0
=
is_e8m0
)
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
pattern
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
):
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
result_rms
,
residual
=
self
.
rmsnorm_matcher
(
input
,
weight
,
residual
)
result
,
scale
=
self
.
quant_matcher
(
result_rms
)
return
result
,
residual
,
scale
def
replacement
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
:
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input
=
input
.
to
(
dtype
=
self
.
model_dtype
)
...
...
@@ -315,10 +325,10 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
epsilon
:
float
,
quant_dtype
:
torch
.
dtype
,
group_shape
:
GroupShape
,
symmetric
=
True
,
symmetric
:
bool
=
True
,
has_col_major_scales
:
bool
=
False
,
is_e8m0
:
bool
=
False
,
):
)
->
None
:
scale
=
ScaleDesc
(
torch
.
float32
,
False
,
group_shape
)
key
=
FusedRMSQuantKey
(
fused_add
=
False
,
...
...
@@ -329,13 +339,17 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
epsilon
,
key
,
has_col_major_scales
=
has_col_major_scales
,
is_e8m0
=
is_e8m0
)
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
pattern
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
):
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
result_rms
=
self
.
rmsnorm_matcher
(
input
,
weight
)
result
,
scale
=
self
.
quant_matcher
(
result_rms
)
return
result
,
scale
def
replacement
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
):
def
replacement
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input
=
input
.
to
(
dtype
=
self
.
model_dtype
)
...
...
@@ -375,8 +389,8 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
epsilon
:
float
,
quant_dtype
:
torch
.
dtype
,
group_shape
:
GroupShape
=
GroupShape
.
PER_TOKEN
,
symmetric
=
True
,
):
symmetric
:
bool
=
True
,
)
->
None
:
scale
=
ScaleDesc
(
torch
.
float32
,
False
,
group_shape
)
key
=
FusedRMSQuantKey
(
fused_add
=
False
,
...
...
@@ -384,13 +398,17 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
)
super
().
__init__
(
epsilon
,
key
)
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
pattern
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
):
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
result_rms
=
self
.
rmsnorm_matcher
(
input
,
weight
)
# result, scale
return
self
.
quant_matcher
(
result_rms
)
return
self
.
quant_matcher
(
result_rms
)
# type: ignore[no-any-return]
def
replacement
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
):
def
replacement
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input
=
input
.
to
(
dtype
=
self
.
model_dtype
)
...
...
@@ -426,8 +444,8 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
epsilon
:
float
,
quant_dtype
:
torch
.
dtype
,
group_shape
:
GroupShape
=
GroupShape
.
PER_TOKEN
,
symmetric
=
True
,
):
symmetric
:
bool
=
True
,
)
->
None
:
scale
=
ScaleDesc
(
torch
.
float32
,
False
,
group_shape
)
key
=
FusedRMSQuantKey
(
fused_add
=
True
,
...
...
@@ -435,8 +453,10 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
)
super
().
__init__
(
epsilon
,
key
)
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
pattern
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
):
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
result_rms
,
residual
=
self
.
rmsnorm_matcher
(
input
,
weight
,
residual
)
result
,
scale
=
self
.
quant_matcher
(
result_rms
)
...
...
@@ -444,7 +464,7 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
def
replacement
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
:
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input
=
input
.
to
(
dtype
=
self
.
model_dtype
)
...
...
@@ -481,7 +501,7 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
"""
@
enable_fake_mode
def
__init__
(
self
,
config
:
VllmConfig
):
def
__init__
(
self
,
config
:
VllmConfig
)
->
None
:
super
().
__init__
(
config
)
self
.
patterns
:
PatternMatcherPass
=
PatternMatcherPass
(
...
...
@@ -533,11 +553,11 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
self
.
dump_patterns
(
config
,
self
.
patterns
)
@
VllmInductorPass
.
time_and_log
def
__call__
(
self
,
graph
:
fx
.
Graph
):
def
__call__
(
self
,
graph
:
fx
.
Graph
)
->
None
:
self
.
matched_count
=
self
.
patterns
.
apply
(
graph
)
logger
.
debug
(
"Replaced %s patterns"
,
self
.
matched_count
)
def
uuid
(
self
)
->
Any
:
def
uuid
(
self
)
->
str
:
return
self
.
hash_source
(
self
,
RMSNormGroupQuantPattern
,
...
...
vllm/compilation/fusion_attn.py
View file @
ad8818bb
...
...
@@ -3,6 +3,7 @@
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
Callable
from
typing
import
Any
,
ParamSpec
import
torch
import
torch._inductor.pattern_matcher
as
pm
...
...
@@ -28,7 +29,7 @@ from .matcher_utils import MatcherQuantFP8
from
.vllm_inductor_pass
import
VllmInductorPass
,
VllmPatternMatcherPass
logger
=
init_logger
(
__name__
)
P
=
ParamSpec
(
"P"
)
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
FP4_DTYPE
=
torch
.
uint8
...
...
@@ -47,7 +48,7 @@ class AttentionQuantPattern(ABC):
layer
:
Attention
,
quant_key
:
QuantKey
,
dtype
:
torch
.
dtype
,
):
)
->
None
:
self
.
layer
=
layer
self
.
layer_name
=
layer
.
layer_name
self
.
num_heads
=
layer
.
num_heads
...
...
@@ -61,17 +62,20 @@ class AttentionQuantPattern(ABC):
)
self
.
QUANT_OP
=
QUANT_OPS
[
self
.
quant_key
]
def
empty
(
self
,
*
args
,
**
kwargs
)
:
def
empty
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
torch
.
Tensor
:
kwargs
=
{
"dtype"
:
self
.
dtype
,
"device"
:
"cuda"
,
**
kwargs
}
return
torch
.
empty
(
*
args
,
**
kwargs
)
def
empty_quant
(
self
,
*
args
,
**
kwargs
)
:
def
empty_quant
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
torch
.
Tensor
:
kwargs
=
{
"dtype"
:
self
.
quant_dtype
,
"device"
:
"cuda"
,
**
kwargs
}
return
torch
.
empty
(
*
args
,
**
kwargs
)
@
staticmethod
def
wrap_trace_fn
(
trace_fn
,
*
process_fx_fns
:
Callable
[[
fx
.
GraphModule
],
None
]):
def
wrapped
(
*
args
,
**
kwargs
):
def
wrap_trace_fn
(
trace_fn
:
Callable
[
P
,
fx
.
GraphModule
],
*
process_fx_fns
:
Callable
[[
fx
.
GraphModule
],
None
],
)
->
Callable
[
P
,
fx
.
GraphModule
]:
def
wrapped
(
*
args
:
P
.
args
,
**
kwargs
:
P
.
kwargs
)
->
fx
.
GraphModule
:
gm
=
trace_fn
(
*
args
,
**
kwargs
)
for
process_fx
in
process_fx_fns
:
process_fx
(
gm
)
...
...
@@ -81,13 +85,13 @@ class AttentionQuantPattern(ABC):
return
wrapped
@
staticmethod
def
fx_view_to_reshape
(
gm
:
torch
.
fx
.
GraphModule
):
def
fx_view_to_reshape
(
gm
:
torch
.
fx
.
GraphModule
)
->
None
:
from
torch._inductor.fx_passes.post_grad
import
view_to_reshape
view_to_reshape
(
gm
)
@
staticmethod
def
remove_noop_permutes
(
gm
:
torch
.
fx
.
GraphModule
):
def
remove_noop_permutes
(
gm
:
torch
.
fx
.
GraphModule
)
->
None
:
for
node
in
gm
.
graph
.
nodes
:
if
not
is_func
(
node
,
torch
.
ops
.
aten
.
permute
.
default
):
continue
...
...
@@ -100,12 +104,12 @@ class AttentionQuantPattern(ABC):
node
.
replace_all_uses_with
(
node
.
args
[
0
])
gm
.
graph
.
erase_node
(
node
)
def
register_if_supported
(
self
,
pm_pass
:
PatternMatcherPass
):
def
register_if_supported
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
if
self
.
layer
.
impl
.
fused_output_quant_supported
(
self
.
quant_key
):
self
.
_register
(
pm_pass
)
@
abstractmethod
def
_register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
_register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
raise
NotImplementedError
...
...
@@ -124,21 +128,21 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
layer
:
Attention
,
dtype
:
torch
.
dtype
,
symmetric
:
bool
=
True
,
):
)
->
None
:
quant_key
=
QuantKey
(
dtype
=
FP8_DTYPE
,
scale
=
kStaticTensorScale
,
symmetric
=
symmetric
)
super
().
__init__
(
layer
,
quant_key
,
dtype
)
self
.
quant_matcher
=
MatcherQuantFP8
(
quant_key
)
def
_register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
_register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
output_attn
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
):
)
->
torch
.
Tensor
:
at1
=
auto_functionalized
(
ATTN_OP
,
query
=
q
,
...
...
@@ -161,7 +165,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
v
:
torch
.
Tensor
,
output_attn
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
):
)
->
torch
.
Tensor
:
# attn output in quant_dtype
output_attn
=
torch
.
ops
.
aten
.
full
.
default
(
[
q
.
shape
[
0
],
self
.
num_heads
,
self
.
head_size
],
...
...
@@ -212,10 +216,10 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
will be passed into Attention op as the `output_scale` argument.
"""
def
__init__
(
self
,
layer
:
Attention
,
dtype
:
torch
.
dtype
):
def
__init__
(
self
,
layer
:
Attention
,
dtype
:
torch
.
dtype
)
->
None
:
super
().
__init__
(
layer
,
kNvfp4Quant
,
dtype
)
def
_register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
_register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
...
...
@@ -224,7 +228,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
output_quant
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
,
input_scale
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
at1
=
auto_functionalized
(
ATTN_OP
,
query
=
q
,
...
...
@@ -256,7 +260,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
output_quant
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
,
input_scale
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
# attention output in quant_dtype
output_attn
=
torch
.
ops
.
aten
.
full
.
default
(
[
q
.
shape
[
0
],
self
.
num_heads
,
self
.
head_size
//
2
],
...
...
@@ -318,7 +322,7 @@ class AttnFusionPass(VllmPatternMatcherPass):
"""
@
enable_fake_mode
def
__init__
(
self
,
config
:
VllmConfig
):
def
__init__
(
self
,
config
:
VllmConfig
)
->
None
:
super
().
__init__
(
config
)
self
.
patterns
=
PatternMatcherPass
(
pass_name
=
"attn_fusion_pass"
)
...
...
@@ -350,7 +354,7 @@ class AttnFusionPass(VllmPatternMatcherPass):
self
.
matched_count
=
self
.
patterns
.
apply
(
graph
)
logger
.
debug
(
"Fused quant onto %s attention nodes"
,
self
.
matched_count
)
def
uuid
(
self
):
def
uuid
(
self
)
->
str
:
return
VllmInductorPass
.
hash_source
(
self
,
AttentionQuantPattern
,
...
...
vllm/compilation/inductor_pass.py
View file @
ad8818bb
...
...
@@ -68,7 +68,7 @@ class InductorPass(CustomGraphPass): # type: ignore[misc]
This is defined as a convenience and should work in most cases.
"""
def
uuid
(
self
)
->
Any
:
def
uuid
(
self
)
->
str
:
"""
Provide a unique identifier for the pass, used in Inductor code cache.
This should depend on the pass implementation, so that changes to the
...
...
vllm/compilation/matcher_utils.py
View file @
ad8818bb
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
ABC
,
abstractmethod
from
typing
import
Any
import
torch
from
torch._higher_order_ops
import
auto_functionalized
...
...
@@ -47,7 +48,7 @@ SILU_MUL_OP = torch.ops._C.silu_and_mul.default
class
MatcherCustomOp
(
ABC
):
def
__init__
(
self
,
enabled
:
bool
):
def
__init__
(
self
,
enabled
:
bool
)
->
None
:
config
=
get_current_vllm_config
()
self
.
model_dtype
=
config
.
model_config
.
dtype
if
config
.
model_config
else
None
self
.
device
=
config
.
device_config
.
device
if
config
.
device_config
else
None
...
...
@@ -56,24 +57,24 @@ class MatcherCustomOp(ABC):
self
.
forward
=
self
.
forward_custom
if
enabled
else
self
.
forward_native
@
abstractmethod
def
forward_custom
(
self
,
*
args
,
**
kw
s
)
:
def
forward_custom
(
self
,
*
args
:
Any
,
**
kw
args
:
Any
)
->
Any
:
pass
@
abstractmethod
def
forward_native
(
self
,
*
args
,
**
kw
s
)
:
def
forward_native
(
self
,
*
args
:
Any
,
**
kw
args
:
Any
)
->
Any
:
pass
def
__call__
(
self
,
*
args
,
**
kw
s
)
:
return
self
.
forward
(
*
args
,
**
kws
)
def
__call__
(
self
,
*
args
:
Any
,
**
kw
args
:
Any
)
->
Any
:
return
self
.
forward
(
*
args
,
**
kw
arg
s
)
def
empty
(
self
,
*
args
,
**
kw
s
)
:
return
torch
.
empty
(
*
args
,
dtype
=
self
.
model_dtype
,
device
=
self
.
device
,
**
kws
)
def
empty
(
self
,
*
args
:
Any
,
**
kw
args
:
Any
)
->
torch
.
Tensor
:
return
torch
.
empty
(
*
args
,
dtype
=
self
.
model_dtype
,
device
=
self
.
device
,
**
kw
arg
s
)
def
empty_int64
(
self
,
*
args
,
**
kw
s
)
:
return
torch
.
empty
(
*
args
,
dtype
=
torch
.
int64
,
device
=
self
.
device
,
**
kws
)
def
empty_int64
(
self
,
*
args
:
Any
,
**
kw
args
:
Any
)
->
torch
.
Tensor
:
return
torch
.
empty
(
*
args
,
dtype
=
torch
.
int64
,
device
=
self
.
device
,
**
kw
arg
s
)
def
empty_f32
(
self
,
*
args
,
**
kw
s
)
:
return
torch
.
empty
(
*
args
,
dtype
=
torch
.
float32
,
device
=
self
.
device
,
**
kws
)
def
empty_f32
(
self
,
*
args
:
Any
,
**
kw
args
:
Any
)
->
torch
.
Tensor
:
return
torch
.
empty
(
*
args
,
dtype
=
torch
.
float32
,
device
=
self
.
device
,
**
kw
arg
s
)
def
inputs
(
self
)
->
list
[
torch
.
Tensor
]:
"""Utility for inputs to the pattern"""
...
...
@@ -157,7 +158,7 @@ class MatcherRMSNorm(MatcherCustomOp):
epsilon
:
float
,
enabled
:
bool
|
None
=
None
,
match_rocm_aiter
:
bool
=
False
,
):
)
->
None
:
if
enabled
is
None
:
enabled
=
RMSNorm
.
enabled
()
...
...
@@ -169,7 +170,7 @@ class MatcherRMSNorm(MatcherCustomOp):
if
match_rocm_aiter
:
self
.
_rmsnorm_op
=
rocm_aiter_ops
.
get_rmsnorm_op
()
def
inputs
(
self
):
def
inputs
(
self
)
->
list
[
torch
.
Tensor
]
:
input
=
self
.
empty
(
5
,
16
)
if
self
.
enabled
else
self
.
empty_f32
(
5
,
16
)
weight
=
self
.
empty
(
16
)
return
[
input
,
weight
]
...
...
@@ -220,7 +221,7 @@ class MatcherFusedAddRMSNorm(MatcherCustomOp):
epsilon
:
float
,
enabled
:
bool
|
None
=
None
,
match_rocm_aiter
:
bool
=
False
,
):
)
->
None
:
if
enabled
is
None
:
enabled
=
RMSNorm
.
enabled
()
...
...
@@ -233,7 +234,7 @@ class MatcherFusedAddRMSNorm(MatcherCustomOp):
if
match_rocm_aiter
:
self
.
_rmsnorm_op
=
rocm_aiter_ops
.
get_rmsnorm_fused_add_op
()
def
inputs
(
self
):
def
inputs
(
self
)
->
list
[
torch
.
Tensor
]
:
input
=
self
.
empty
(
5
,
16
)
if
self
.
enabled
else
self
.
empty_f32
(
5
,
16
)
weight
=
self
.
empty
(
16
)
residual
=
self
.
empty
(
5
,
16
)
...
...
@@ -245,7 +246,7 @@ class MatcherFusedAddRMSNorm(MatcherCustomOp):
weight
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
return
self
.
_rmsnorm_op
(
return
self
.
_rmsnorm_op
(
# type: ignore[no-any-return]
x
=
input
,
residual
=
residual
,
weight
=
weight
,
variance_epsilon
=
self
.
epsilon
)
...
...
@@ -287,7 +288,7 @@ class MatcherQuantFP8(MatcherCustomOp):
has_col_major_scales
:
bool
=
False
,
is_e8m0
:
bool
=
False
,
match_rocm_aiter
:
bool
=
False
,
):
)
->
None
:
if
enabled
is
None
:
enabled
=
QuantFP8
.
enabled
()
...
...
@@ -340,13 +341,13 @@ class MatcherQuantFP8(MatcherCustomOp):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
quant_key_group_shape
=
self
.
quant_key
.
scale
.
group_shape
if
quant_key_group_shape
==
GroupShape
.
PER_TOKEN
:
return
self
.
QUANT_OP
(
return
self
.
QUANT_OP
(
# type: ignore[no-any-return]
x
=
input
,
quant_dtype
=
self
.
quant_key
.
dtype
,
scale
=
scale
,
)
else
:
return
self
.
QUANT_OP
(
input
,
quant_key_group_shape
.
col
)
return
self
.
QUANT_OP
(
input
,
quant_key_group_shape
.
col
)
# type: ignore[no-any-return]
def
forward_custom
(
self
,
...
...
@@ -400,9 +401,9 @@ class MatcherQuantFP8(MatcherCustomOp):
input
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
return
self
.
quant_fp8
(
input
,
scale
)
return
self
.
quant_fp8
(
input
,
scale
)
# type: ignore[no-any-return]
def
make_scale
(
self
,
input
:
torch
.
Tensor
,
transposed
:
bool
=
False
):
def
make_scale
(
self
,
input
:
torch
.
Tensor
,
transposed
:
bool
=
False
)
->
torch
.
Tensor
:
normalized_group_shape
=
_normalize_quant_group_shape
(
input
,
self
.
quant_key
.
scale
.
group_shape
)
...
...
@@ -427,7 +428,7 @@ class MatcherQuantFP8(MatcherCustomOp):
class
MatcherSiluAndMul
(
MatcherCustomOp
):
def
__init__
(
self
,
enabled
:
bool
|
None
=
None
):
def
__init__
(
self
,
enabled
:
bool
|
None
=
None
)
->
None
:
if
enabled
is
None
:
enabled
=
SiluAndMul
.
enabled
()
super
().
__init__
(
enabled
)
...
...
vllm/compilation/qk_norm_rope_fusion.py
View file @
ad8818bb
...
...
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Callable
from
typing
import
ParamSpec
import
torch
import
torch._inductor.pattern_matcher
as
pm
...
...
@@ -23,6 +24,8 @@ logger = init_logger(__name__)
FUSED_QK_ROPE_OP
=
torch
.
ops
.
_C
.
fused_qk_norm_rope
.
default
P
=
ParamSpec
(
"P"
)
class
QkNormRopePattern
:
"""
...
...
@@ -72,7 +75,7 @@ class QkNormRopePattern:
use_flashinfer
=
self
.
rope_flashinfer
,
)
def
get_inputs
(
self
):
def
get_inputs
(
self
)
->
list
[
torch
.
Tensor
]
:
# Sample inputs to help pattern tracing
T
=
5
qkv
=
empty_bf16
(
T
,
self
.
q_size
+
2
*
self
.
kv_size
)
...
...
@@ -92,8 +95,11 @@ class QkNormRopePattern:
]
@
staticmethod
def
wrap_trace_fn
(
trace_fn
,
*
process_fx_fns
:
Callable
[[
fx
.
GraphModule
],
None
]):
def
wrapped
(
*
args
,
**
kwargs
):
def
wrap_trace_fn
(
trace_fn
:
Callable
[
P
,
fx
.
GraphModule
],
*
process_fx_fns
:
Callable
[[
fx
.
GraphModule
],
None
],
)
->
Callable
[
P
,
fx
.
GraphModule
]:
def
wrapped
(
*
args
:
P
.
args
,
**
kwargs
:
P
.
kwargs
)
->
fx
.
GraphModule
:
gm
=
trace_fn
(
*
args
,
**
kwargs
)
for
process_fx
in
process_fx_fns
:
process_fx
(
gm
)
...
...
@@ -103,19 +109,19 @@ class QkNormRopePattern:
return
wrapped
@
staticmethod
def
fx_view_to_reshape
(
gm
:
torch
.
fx
.
GraphModule
):
def
fx_view_to_reshape
(
gm
:
torch
.
fx
.
GraphModule
)
->
None
:
from
torch._inductor.fx_passes.post_grad
import
view_to_reshape
view_to_reshape
(
gm
)
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
qkv
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
q_weight
:
torch
.
Tensor
,
k_weight
:
torch
.
Tensor
,
cos_sin_cache
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
:
# split qkv -> q,k,v
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
...
...
@@ -143,7 +149,7 @@ class QkNormRopePattern:
q_weight
:
torch
.
Tensor
,
k_weight
:
torch
.
Tensor
,
cos_sin_cache
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
:
# Run fused qk_norm_rope op
result
=
auto_functionalized
(
FUSED_QK_ROPE_OP
,
...
...
@@ -162,7 +168,7 @@ class QkNormRopePattern:
result_qkv
=
result
[
1
]
# Split back to q,k,v and return
return
result_qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
return
result_qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
# type: ignore[no-any-return]
# NOTE: use fx_view_to_reshape to unify view/reshape to simplify
# pattern and increase matching opportunities
...
...
@@ -182,7 +188,7 @@ class QKNormRoPEFusionPass(VllmPatternMatcherPass):
"""Fuse Q/K RMSNorm + RoPE into fused_qk_norm_rope when the custom op exists."""
@
enable_fake_mode
def
__init__
(
self
,
config
:
VllmConfig
):
def
__init__
(
self
,
config
:
VllmConfig
)
->
None
:
super
().
__init__
(
config
)
self
.
patterns
:
PatternMatcherPass
=
PatternMatcherPass
(
pass_name
=
"qk_norm_rope_fusion_pass"
...
...
@@ -234,5 +240,5 @@ class QKNormRoPEFusionPass(VllmPatternMatcherPass):
self
.
matched_count
=
self
.
patterns
.
apply
(
graph
)
logger
.
debug
(
"Fused QK Norm+RoPE on %s sites"
,
self
.
matched_count
)
def
uuid
(
self
):
def
uuid
(
self
)
->
str
:
return
VllmInductorPass
.
hash_source
(
self
,
QkNormRopePattern
)
vllm/compilation/rocm_aiter_fusion.py
View file @
ad8818bb
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Any
import
torch
import
torch._inductor.pattern_matcher
as
pm
...
...
@@ -65,8 +64,8 @@ class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
quant_dtype
:
torch
.
dtype
,
match_aiter_quant
:
bool
=
True
,
group_shape
:
GroupShape
=
GroupShape
.
PER_TOKEN
,
symmetric
=
True
,
):
symmetric
:
bool
=
True
,
)
->
None
:
scale
=
ScaleDesc
(
torch
.
float32
,
False
,
group_shape
)
key
=
FusedRMSQuantKey
(
fused_add
=
False
,
...
...
@@ -75,11 +74,11 @@ class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
super
().
__init__
(
epsilon
,
key
,
match_aiter_quant
)
def
register
(
self
,
pm_pass
)
:
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
result_rms
=
self
.
rmsnorm_matcher
(
input
,
weight
)
result
,
scale
=
self
.
quant_matcher
(
result_rms
)
return
result
,
scale
...
...
@@ -87,7 +86,7 @@ class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
def
replacement
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
result
=
self
.
FUSED_OP
(
x
=
input
,
weight
=
weight
,
...
...
@@ -117,8 +116,8 @@ class AiterFusedAddRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
quant_dtype
:
torch
.
dtype
,
match_aiter_quant
:
bool
=
True
,
group_shape
:
GroupShape
=
GroupShape
.
PER_TOKEN
,
symmetric
=
True
,
):
symmetric
:
bool
=
True
,
)
->
None
:
scale
=
ScaleDesc
(
torch
.
float32
,
False
,
group_shape
)
key
=
FusedRMSQuantKey
(
fused_add
=
True
,
...
...
@@ -127,12 +126,12 @@ class AiterFusedAddRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
super
().
__init__
(
epsilon
,
key
,
match_aiter_quant
)
def
register
(
self
,
pm_pass
)
:
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
:
result_rms
,
residual_out
=
self
.
rmsnorm_matcher
(
input
,
weight
,
residual
)
result
,
scale
=
self
.
quant_matcher
(
result_rms
)
...
...
@@ -140,7 +139,7 @@ class AiterFusedAddRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
def
replacement
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
:
result
=
self
.
FUSED_OP
(
x
=
input
,
residual
=
residual
,
...
...
@@ -174,8 +173,8 @@ class AiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
quant_dtype
:
torch
.
dtype
,
group_shape
:
GroupShape
,
match_aiter_quant
:
bool
=
True
,
symmetric
=
True
,
):
symmetric
:
bool
=
True
,
)
->
None
:
scale
=
ScaleDesc
(
torch
.
float32
,
False
,
group_shape
)
key
=
FusedRMSQuantKey
(
fused_add
=
False
,
...
...
@@ -184,11 +183,11 @@ class AiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
super
().
__init__
(
epsilon
,
key
,
match_aiter_quant
)
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
result_rms
=
self
.
rmsnorm_matcher
(
input
,
weight
)
result
,
scale
=
self
.
quant_matcher
(
result_rms
)
return
result
,
scale
...
...
@@ -196,7 +195,7 @@ class AiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
def
replacement
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
at
=
self
.
FUSED_OP
(
x
=
input
,
weight
=
weight
,
...
...
@@ -225,8 +224,8 @@ class AiterFusedAddRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
quant_dtype
:
torch
.
dtype
,
group_shape
:
GroupShape
,
match_aiter_quant
:
bool
=
True
,
symmetric
=
True
,
):
symmetric
:
bool
=
True
,
)
->
None
:
scale
=
ScaleDesc
(
torch
.
float32
,
False
,
group_shape
)
key
=
FusedRMSQuantKey
(
fused_add
=
True
,
...
...
@@ -235,12 +234,12 @@ class AiterFusedAddRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
super
().
__init__
(
epsilon
,
key
,
match_aiter_quant
)
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
:
result_rms
,
residual_out
=
self
.
rmsnorm_matcher
(
input
,
weight
,
residual
)
result
,
scale
=
self
.
quant_matcher
(
result_rms
)
...
...
@@ -250,7 +249,7 @@ class AiterFusedAddRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
:
at
=
self
.
FUSED_OP
(
x
=
input
,
residual
=
residual
,
...
...
@@ -275,7 +274,7 @@ class RocmAiterRMSNormFusionPass(VllmPatternMatcherPass):
"""
@
enable_fake_mode
def
__init__
(
self
,
config
:
VllmConfig
):
def
__init__
(
self
,
config
:
VllmConfig
)
->
None
:
super
().
__init__
(
config
)
self
.
patterns
:
PatternMatcherPass
=
PatternMatcherPass
(
...
...
@@ -311,11 +310,11 @@ class RocmAiterRMSNormFusionPass(VllmPatternMatcherPass):
self
.
dump_patterns
(
config
,
self
.
patterns
)
@
VllmInductorPass
.
time_and_log
def
__call__
(
self
,
graph
:
fx
.
Graph
):
def
__call__
(
self
,
graph
:
fx
.
Graph
)
->
None
:
self
.
matched_count
=
self
.
patterns
.
apply
(
graph
)
logger
.
debug
(
"Replaced %s patterns"
,
self
.
matched_count
)
def
uuid
(
self
)
->
Any
:
def
uuid
(
self
)
->
str
:
fusion_patterns
=
[
AiterRMSNormDynamicQuantPattern
,
AiterFusedAddRMSNormDynamicQuantPattern
,
...
...
@@ -333,29 +332,32 @@ class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
FUSED_SILU_MUL_QUANT_OP
=
rocm_aiter_ops
.
get_act_mul_fused_fp8_group_quant_op
()
def
__init__
(
self
,
quant_op
:
OpOverload
):
def
__init__
(
self
,
quant_op
:
OpOverload
)
->
None
:
self
.
silu_and_mul_matcher
=
MatcherSiluAndMul
()
self
.
quant_op
=
quant_op
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
get_inputs
(
self
)
->
list
[
torch
.
Tensor
]:
return
[
self
.
silu_and_mul_matcher
.
inputs
()[
0
],
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
input
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
at1
=
self
.
silu_and_mul_matcher
(
input
)
at2
=
self
.
quant_op
(
at1
,
128
)
return
at2
[
0
],
at2
[
1
]
def
replacement
(
input
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
at
=
self
.
FUSED_SILU_MUL_QUANT_OP
(
x
=
input
,
group_size
=
128
)
return
at
[
0
],
at
[
1
]
inputs
=
[
self
.
silu_and_mul_matcher
.
inputs
()[
0
],
]
pm
.
register_replacement
(
pattern
,
replacement
,
inputs
,
pm
.
fwd_only
,
pm_pass
)
pm
.
register_replacement
(
pattern
,
replacement
,
self
.
get_inputs
(),
pm
.
fwd_only
,
pm_pass
)
class
RocmAiterSiluMulFp8GroupQuantFusionPass
(
VllmPatternMatcherPass
):
...
...
@@ -374,7 +376,7 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
QUANT_OPS
=
[
AITER_GROUP_FP8_QUANT_OP
,
TRITON_GROUP_FP8_QUANT_OP
]
@
enable_fake_mode
def
__init__
(
self
,
config
:
VllmConfig
):
def
__init__
(
self
,
config
:
VllmConfig
)
->
None
:
super
().
__init__
(
config
)
self
.
patterns
:
PatternMatcherPass
=
PatternMatcherPass
(
...
...
@@ -387,11 +389,11 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
self
.
dump_patterns
(
config
,
self
.
patterns
)
@
VllmInductorPass
.
time_and_log
def
__call__
(
self
,
graph
:
torch
.
fx
.
Graph
):
def
__call__
(
self
,
graph
:
torch
.
fx
.
Graph
)
->
None
:
self
.
matched_count
=
self
.
patterns
.
apply
(
graph
)
logger
.
debug
(
"Replaced %s patterns"
,
self
.
matched_count
)
def
uuid
(
self
):
def
uuid
(
self
)
->
str
:
fusion_patterns
=
[
ActivationQuantPattern
,
AiterSiluMulFp8GroupQuantPattern
,
...
...
vllm/compilation/sequence_parallelism.py
View file @
ad8818bb
...
...
@@ -2,6 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
functools
from
collections.abc
import
Callable
,
Sequence
from
typing
import
Any
import
torch
import
torch._inductor.pattern_matcher
as
pm
...
...
@@ -26,9 +28,11 @@ from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
logger
=
init_logger
(
__name__
)
def
get_first_out_wrapper
(
fn
):
def
get_first_out_wrapper
(
fn
:
Callable
[...,
Sequence
[
torch
.
Tensor
]],
)
->
Callable
[...,
torch
.
Tensor
]:
@
functools
.
wraps
(
fn
)
def
wrapper
(
*
args
)
:
def
wrapper
(
*
args
:
Any
)
->
torch
.
Tensor
:
return
fn
(
*
args
)[
0
]
return
wrapper
...
...
@@ -68,7 +72,7 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
super
().
__init__
(
epsilon
,
dtype
,
device
)
self
.
rmsnorm_matcher
=
MatcherRMSNorm
(
epsilon
)
def
get_inputs
(
self
):
def
get_inputs
(
self
)
->
list
[
torch
.
Tensor
]
:
input
=
torch
.
empty
([
1
,
8
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
arg3_1
=
torch
.
empty
([
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
...
...
@@ -78,7 +82,7 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
def
pattern
(
input
:
torch
.
Tensor
,
arg3_1
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
all_reduce
=
self
.
_all_reduce
(
input
)
rmsnorm
=
self
.
rmsnorm_matcher
(
all_reduce
,
arg3_1
)
...
...
@@ -87,7 +91,7 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
def
replacement
(
input
:
torch
.
Tensor
,
arg3_1
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
reduce_scatter
=
self
.
_reduce_scatter
(
input
)
rmsnorm
=
self
.
rmsnorm_matcher
(
reduce_scatter
,
arg3_1
)
...
...
@@ -100,11 +104,11 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
class
MiddleAllReduceRMSNormPattern
(
_SequenceParallelPatternHelper
):
def
__init__
(
self
,
epsilon
:
float
,
dtype
:
torch
.
dtype
,
device
:
str
|
None
):
def
__init__
(
self
,
epsilon
:
float
,
dtype
:
torch
.
dtype
,
device
:
str
|
None
)
->
None
:
super
().
__init__
(
epsilon
,
dtype
,
device
)
self
.
rmsnorm_matcher
=
MatcherFusedAddRMSNorm
(
epsilon
)
def
get_inputs
(
self
):
def
get_inputs
(
self
)
->
list
[
torch
.
Tensor
]
:
mm_1
=
torch
.
empty
([
4
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
residual
=
torch
.
empty
([
4
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
...
...
@@ -116,7 +120,7 @@ class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
rms_norm_weights
,
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
residual
:
torch
.
Tensor
,
mm_1
:
torch
.
Tensor
,
...
...
@@ -163,23 +167,23 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
epsilon
:
float
,
dtype
:
torch
.
dtype
,
device
:
str
|
None
,
):
)
->
None
:
super
().
__init__
(
epsilon
,
dtype
,
device
)
self
.
rmsnorm_matcher
=
MatcherRMSNorm
(
epsilon
)
self
.
quant_matcher
=
MatcherQuantFP8
(
kFp8StaticTensorSym
)
def
get_inputs
(
self
):
def
get_inputs
(
self
)
->
list
[
torch
.
Tensor
]
:
input
=
torch
.
zeros
([
1
,
8
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
weight
=
torch
.
empty
([
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
scale
=
torch
.
tensor
(
1.0
,
device
=
self
.
device
,
dtype
=
torch
.
float32
)
return
[
input
,
weight
,
scale
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
all_reduce
=
self
.
_all_reduce
(
input
)
rms
=
self
.
rmsnorm_matcher
(
all_reduce
,
weight
)
quant
,
_
=
self
.
quant_matcher
(
rms
,
scale
)
...
...
@@ -189,7 +193,7 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
reduce_scatter
=
self
.
_reduce_scatter
(
input
)
rms
=
self
.
rmsnorm_matcher
(
reduce_scatter
,
weight
)
quant
,
_
=
self
.
quant_matcher
(
rms
,
scale
)
...
...
@@ -203,12 +207,12 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
class
MiddleAllReduceRMSNormStaticFP8Pattern
(
_SequenceParallelPatternHelper
):
def
__init__
(
self
,
epsilon
:
float
,
dtype
:
torch
.
dtype
,
device
:
str
|
None
):
def
__init__
(
self
,
epsilon
:
float
,
dtype
:
torch
.
dtype
,
device
:
str
|
None
)
->
None
:
super
().
__init__
(
epsilon
,
dtype
,
device
)
self
.
rmsnorm_matcher
=
MatcherFusedAddRMSNorm
(
epsilon
)
self
.
quant_matcher
=
MatcherQuantFP8
(
kFp8StaticTensorSym
)
def
get_inputs
(
self
):
def
get_inputs
(
self
)
->
list
[
torch
.
Tensor
]
:
mm_1
=
torch
.
empty
([
4
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
residual
=
torch
.
empty
([
4
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
rms_norm_weights
=
torch
.
empty
([
4
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
...
...
@@ -216,7 +220,7 @@ class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
return
[
residual
,
mm_1
,
rms_norm_weights
,
scale
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
residual
:
torch
.
Tensor
,
mm_1
:
torch
.
Tensor
,
...
...
@@ -302,7 +306,7 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
"""
@
enable_fake_mode
def
__init__
(
self
,
config
:
VllmConfig
):
def
__init__
(
self
,
config
:
VllmConfig
)
->
None
:
super
().
__init__
(
config
)
# Used to clean up redundant views created temporarily
...
...
@@ -357,7 +361,7 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
return
(
compile_range
.
is_single_size
())
and
(
compile_range
.
end
%
tp_size
==
0
)
@
VllmInductorPass
.
time_and_log
def
__call__
(
self
,
graph
:
fx
.
Graph
):
def
__call__
(
self
,
graph
:
fx
.
Graph
)
->
None
:
self
.
matched_count
=
self
.
patterns
.
apply
(
graph
)
logger
.
debug
(
"Replaced %s patterns"
,
self
.
matched_count
)
# Clean up reshape nodes
...
...
vllm/distributed/parallel_state.py
View file @
ad8818bb
...
...
@@ -1529,22 +1529,22 @@ def patch_tensor_parallel_group(tp_group: GroupCoordinator):
_TP
=
old_tp_group
def
get_tensor_model_parallel_world_size
():
def
get_tensor_model_parallel_world_size
()
->
int
:
"""Return world size for the tensor model parallel group."""
return
get_tp_group
().
world_size
def
get_tensor_model_parallel_rank
():
def
get_tensor_model_parallel_rank
()
->
int
:
"""Return my rank for the tensor model parallel group."""
return
get_tp_group
().
rank_in_group
def
get_decode_context_model_parallel_world_size
():
def
get_decode_context_model_parallel_world_size
()
->
int
:
"""Return world size for the decode context model parallel group."""
return
get_dcp_group
().
world_size
def
get_decode_context_model_parallel_rank
():
def
get_decode_context_model_parallel_rank
()
->
int
:
"""Return my rank for the decode context model parallel group."""
return
get_dcp_group
().
rank_in_group
...
...
vllm/model_executor/layers/rotary_embedding/__init__.py
View file @
ad8818bb
...
...
@@ -20,7 +20,9 @@ from .phi3_long_rope_scaled_rope import Phi3LongRoPEScaledRotaryEmbedding
from
.xdrope
import
XDRotaryEmbedding
from
.yarn_scaling_rope
import
YaRNScalingRotaryEmbedding
_ROPE_DICT
:
dict
[
tuple
,
RotaryEmbedding
]
=
{}
_ROPE_DICT
:
dict
[
tuple
[
Any
,
...],
RotaryEmbedding
]
=
{}
__all__
=
[
"RotaryEmbedding"
]
def
get_rope
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment