Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ad8818bb
Unverified
Commit
ad8818bb
authored
Jan 12, 2026
by
Lucas Kabela
Committed by
GitHub
Jan 12, 2026
Browse files
[Misc][BE] Type coverage for vllm/compilation [3/3] (#31748)
Signed-off-by:
Lucas Kabela
<
lucaskabela@meta.com
>
parent
08e8e99c
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
333 additions
and
280 deletions
+333
-280
vllm/compilation/activation_quant_fusion.py
vllm/compilation/activation_quant_fusion.py
+33
-28
vllm/compilation/collective_fusion.py
vllm/compilation/collective_fusion.py
+111
-102
vllm/compilation/fusion.py
vllm/compilation/fusion.py
+58
-38
vllm/compilation/fusion_attn.py
vllm/compilation/fusion_attn.py
+24
-20
vllm/compilation/inductor_pass.py
vllm/compilation/inductor_pass.py
+1
-1
vllm/compilation/matcher_utils.py
vllm/compilation/matcher_utils.py
+23
-22
vllm/compilation/qk_norm_rope_fusion.py
vllm/compilation/qk_norm_rope_fusion.py
+16
-10
vllm/compilation/rocm_aiter_fusion.py
vllm/compilation/rocm_aiter_fusion.py
+38
-36
vllm/compilation/sequence_parallelism.py
vllm/compilation/sequence_parallelism.py
+22
-18
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+4
-4
vllm/model_executor/layers/rotary_embedding/__init__.py
vllm/model_executor/layers/rotary_embedding/__init__.py
+3
-1
No files found.
vllm/compilation/activation_quant_fusion.py
View file @
ad8818bb
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
typing
import
Any
import
torch
import
torch
from
torch._higher_order_ops.auto_functionalize
import
auto_functionalized
from
torch._higher_order_ops.auto_functionalize
import
auto_functionalized
...
@@ -52,7 +53,7 @@ class ActivationQuantPattern(ABC):
...
@@ -52,7 +53,7 @@ class ActivationQuantPattern(ABC):
def
__init__
(
def
__init__
(
self
,
self
,
quant_key
:
QuantKey
,
quant_key
:
QuantKey
,
):
)
->
None
:
self
.
quant_key
=
quant_key
self
.
quant_key
=
quant_key
self
.
quant_dtype
=
quant_key
.
dtype
self
.
quant_dtype
=
quant_key
.
dtype
...
@@ -68,12 +69,12 @@ class ActivationQuantPattern(ABC):
...
@@ -68,12 +69,12 @@ class ActivationQuantPattern(ABC):
self
.
silu_and_mul_matcher
=
MatcherSiluAndMul
()
self
.
silu_and_mul_matcher
=
MatcherSiluAndMul
()
def
empty_quant
(
self
,
*
args
,
**
kwargs
)
:
def
empty_quant
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
torch
.
Tensor
:
kwargs
=
{
"dtype"
:
self
.
quant_dtype
,
"device"
:
"cuda"
,
**
kwargs
}
kwargs
=
{
"dtype"
:
self
.
quant_dtype
,
"device"
:
"cuda"
,
**
kwargs
}
return
torch
.
empty
(
*
args
,
**
kwargs
)
return
torch
.
empty
(
*
args
,
**
kwargs
)
@
abstractmethod
@
abstractmethod
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
raise
NotImplementedError
raise
NotImplementedError
...
@@ -82,15 +83,22 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
...
@@ -82,15 +83,22 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
Fusion for SiluMul+Fp8StaticQuant Pattern
Fusion for SiluMul+Fp8StaticQuant Pattern
"""
"""
def
__init__
(
self
):
def
__init__
(
self
)
->
None
:
super
().
__init__
(
kFp8StaticTensorSym
)
super
().
__init__
(
kFp8StaticTensorSym
)
self
.
quant_matcher
=
MatcherQuantFP8
(
kFp8StaticTensorSym
)
self
.
quant_matcher
=
MatcherQuantFP8
(
kFp8StaticTensorSym
)
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
get_inputs
(
self
)
->
list
[
torch
.
Tensor
]:
scale
=
self
.
quant_matcher
.
inputs
()[
1
]
return
[
*
self
.
silu_and_mul_matcher
.
inputs
(),
# input
scale
,
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
def
pattern
(
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
):
)
->
torch
.
Tensor
:
result_silu_mul
=
self
.
silu_and_mul_matcher
(
input
)
result_silu_mul
=
self
.
silu_and_mul_matcher
(
input
)
result_quant
=
self
.
quant_matcher
(
result_silu_mul
,
scale
)
result_quant
=
self
.
quant_matcher
(
result_silu_mul
,
scale
)
return
result_quant
[
0
]
return
result_quant
[
0
]
...
@@ -98,7 +106,7 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
...
@@ -98,7 +106,7 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
def
replacement
(
def
replacement
(
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
):
)
->
torch
.
Tensor
:
d
=
input
.
shape
[
-
1
]
//
2
d
=
input
.
shape
[
-
1
]
//
2
output_shape
=
input
.
shape
[:
-
1
]
+
(
d
,)
output_shape
=
input
.
shape
[:
-
1
]
+
(
d
,)
result
=
torch
.
empty
(
result
=
torch
.
empty
(
...
@@ -109,13 +117,10 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
...
@@ -109,13 +117,10 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
)
)
return
at
[
1
]
return
at
[
1
]
inputs
=
[
inps
=
self
.
get_inputs
()
*
self
.
silu_and_mul_matcher
.
inputs
(),
# input
pattern
(
*
inps
)
self
.
quant_matcher
.
inputs
()[
1
],
# scale
]
pattern
(
*
inputs
)
register_replacement
(
pattern
,
replacement
,
inp
ut
s
,
fwd_only
,
pm_pass
)
register_replacement
(
pattern
,
replacement
,
inps
,
fwd_only
,
pm_pass
)
class
SiluMulNvfp4QuantPattern
(
ActivationQuantPattern
):
class
SiluMulNvfp4QuantPattern
(
ActivationQuantPattern
):
...
@@ -123,16 +128,23 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
...
@@ -123,16 +128,23 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
Fusion for SiluMul+Nvfp4Quant Pattern
Fusion for SiluMul+Nvfp4Quant Pattern
"""
"""
def
__init__
(
self
):
def
__init__
(
self
)
->
None
:
super
().
__init__
(
kNvfp4Quant
)
super
().
__init__
(
kNvfp4Quant
)
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
get_inputs
(
self
)
->
list
[
torch
.
Tensor
]:
result
=
self
.
empty_quant
(
5
,
32
)
output_scale
=
empty_i32
(
128
,
4
)
input_
=
empty_bf16
(
5
,
64
)
scale
=
empty_fp32
(
1
,
1
)
return
[
result
,
output_scale
,
input_
,
scale
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
def
pattern
(
result
:
torch
.
Tensor
,
result
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
result_silu_mul
=
self
.
silu_and_mul_matcher
(
input
)
result_silu_mul
=
self
.
silu_and_mul_matcher
(
input
)
at
=
auto_functionalized
(
at
=
auto_functionalized
(
self
.
QUANT_OP
,
self
.
QUANT_OP
,
...
@@ -148,7 +160,7 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
...
@@ -148,7 +160,7 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
output_scale
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
at
=
auto_functionalized
(
at
=
auto_functionalized
(
self
.
FUSED_OP
,
self
.
FUSED_OP
,
result
=
result
,
result
=
result
,
...
@@ -158,14 +170,7 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
...
@@ -158,14 +170,7 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
)
)
return
at
[
1
],
at
[
2
]
return
at
[
1
],
at
[
2
]
inputs
=
[
register_replacement
(
pattern
,
replacement
,
self
.
get_inputs
(),
fwd_only
,
pm_pass
)
self
.
empty_quant
(
5
,
32
),
# result
empty_i32
(
128
,
4
),
# output_scale
empty_bf16
(
5
,
64
),
# input
empty_fp32
(
1
,
1
),
# scale
]
register_replacement
(
pattern
,
replacement
,
inputs
,
fwd_only
,
pm_pass
)
class
ActivationQuantFusionPass
(
VllmPatternMatcherPass
):
class
ActivationQuantFusionPass
(
VllmPatternMatcherPass
):
...
@@ -179,7 +184,7 @@ class ActivationQuantFusionPass(VllmPatternMatcherPass):
...
@@ -179,7 +184,7 @@ class ActivationQuantFusionPass(VllmPatternMatcherPass):
"""
"""
@
enable_fake_mode
@
enable_fake_mode
def
__init__
(
self
,
config
:
VllmConfig
):
def
__init__
(
self
,
config
:
VllmConfig
)
->
None
:
super
().
__init__
(
config
)
super
().
__init__
(
config
)
self
.
patterns
:
PatternMatcherPass
=
PatternMatcherPass
(
self
.
patterns
:
PatternMatcherPass
=
PatternMatcherPass
(
...
@@ -196,11 +201,11 @@ class ActivationQuantFusionPass(VllmPatternMatcherPass):
...
@@ -196,11 +201,11 @@ class ActivationQuantFusionPass(VllmPatternMatcherPass):
self
.
dump_patterns
(
config
,
self
.
patterns
)
self
.
dump_patterns
(
config
,
self
.
patterns
)
@
VllmInductorPass
.
time_and_log
@
VllmInductorPass
.
time_and_log
def
__call__
(
self
,
graph
:
torch
.
fx
.
Graph
):
def
__call__
(
self
,
graph
:
torch
.
fx
.
Graph
)
->
None
:
self
.
matched_count
=
self
.
patterns
.
apply
(
graph
)
self
.
matched_count
=
self
.
patterns
.
apply
(
graph
)
logger
.
debug
(
"Replaced %s patterns"
,
self
.
matched_count
)
logger
.
debug
(
"Replaced %s patterns"
,
self
.
matched_count
)
def
uuid
(
self
):
def
uuid
(
self
)
->
str
:
return
VllmInductorPass
.
hash_source
(
return
VllmInductorPass
.
hash_source
(
self
,
self
,
ActivationQuantPattern
,
ActivationQuantPattern
,
...
...
vllm/compilation/collective_fusion.py
View file @
ad8818bb
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
importlib.util
import
find_spec
from
importlib.util
import
find_spec
from
types
import
ModuleType
import
torch
import
torch
import
torch._inductor.pattern_matcher
as
pm
import
torch._inductor.pattern_matcher
as
pm
...
@@ -33,15 +34,15 @@ if find_spec("flashinfer"):
...
@@ -33,15 +34,15 @@ if find_spec("flashinfer"):
try
:
try
:
import
flashinfer.comm
as
flashinfer_comm
import
flashinfer.comm
as
flashinfer_comm
flashinfer_comm
=
(
flashinfer_comm
:
ModuleType
|
None
=
(
# type: ignore[no-redef]
flashinfer_comm
flashinfer_comm
if
hasattr
(
flashinfer_comm
,
"trtllm_allreduce_fusion"
)
if
hasattr
(
flashinfer_comm
,
"trtllm_allreduce_fusion"
)
else
None
else
None
)
)
except
ImportError
:
except
ImportError
:
flashinfer_comm
=
None
flashinfer_comm
=
None
# type: ignore[assignment]
else
:
else
:
flashinfer_comm
=
None
flashinfer_comm
=
None
# type: ignore[assignment]
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -58,13 +59,13 @@ class BasePattern:
...
@@ -58,13 +59,13 @@ class BasePattern:
class
GEMMReduceScatterPattern
(
BasePattern
):
class
GEMMReduceScatterPattern
(
BasePattern
):
def
get_inputs
(
self
):
def
get_inputs
(
self
)
->
list
[
torch
.
Tensor
]
:
mul
=
torch
.
empty
([
16
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
mul
=
torch
.
empty
([
16
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
mm_weight
=
torch
.
empty
([
4
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
mm_weight
=
torch
.
empty
([
4
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
return
[
mul
,
mm_weight
]
return
[
mul
,
mm_weight
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
mul
:
torch
.
Tensor
,
mm_weight
:
torch
.
Tensor
):
def
pattern
(
mul
:
torch
.
Tensor
,
mm_weight
:
torch
.
Tensor
)
->
torch
.
Tensor
:
mm
=
torch
.
ops
.
aten
.
mm
.
default
(
mul
,
mm_weight
)
mm
=
torch
.
ops
.
aten
.
mm
.
default
(
mul
,
mm_weight
)
reduce_scatter
=
torch
.
ops
.
vllm
.
reduce_scatter
.
default
(
reduce_scatter
=
torch
.
ops
.
vllm
.
reduce_scatter
.
default
(
mm
,
mm
,
...
@@ -74,7 +75,7 @@ class GEMMReduceScatterPattern(BasePattern):
...
@@ -74,7 +75,7 @@ class GEMMReduceScatterPattern(BasePattern):
)
)
return
reduce_scatter
return
reduce_scatter
def
replacement
(
mul
:
torch
.
Tensor
,
mm_weight
:
torch
.
Tensor
):
def
replacement
(
mul
:
torch
.
Tensor
,
mm_weight
:
torch
.
Tensor
)
->
torch
.
Tensor
:
gemm_rs
=
torch
.
ops
.
symm_mem
.
fused_matmul_reduce_scatter
(
gemm_rs
=
torch
.
ops
.
symm_mem
.
fused_matmul_reduce_scatter
(
mul
,
mul
,
mm_weight
,
mm_weight
,
...
@@ -91,17 +92,17 @@ class GEMMReduceScatterPattern(BasePattern):
...
@@ -91,17 +92,17 @@ class GEMMReduceScatterPattern(BasePattern):
class
AllGatherGEMMPattern
(
BasePattern
):
class
AllGatherGEMMPattern
(
BasePattern
):
def
get_inputs
(
self
):
def
get_inputs
(
self
)
->
list
[
torch
.
Tensor
]
:
x
=
torch
.
empty
([
4
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
x
=
torch
.
empty
([
4
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
weight
=
torch
.
empty
([
4
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
weight
=
torch
.
empty
([
4
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
return
[
x
,
weight
]
return
[
x
,
weight
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
def
pattern
(
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
)
->
torch
.
Tensor
:
all_gather
=
torch
.
ops
.
vllm
.
all_gather
.
default
(
all_gather
=
torch
.
ops
.
vllm
.
all_gather
.
default
(
x
,
x
,
dim
=
0
,
dim
=
0
,
...
@@ -111,9 +112,7 @@ class AllGatherGEMMPattern(BasePattern):
...
@@ -111,9 +112,7 @@ class AllGatherGEMMPattern(BasePattern):
return
torch
.
ops
.
aten
.
mm
.
default
(
all_gather
,
weight
)
return
torch
.
ops
.
aten
.
mm
.
default
(
all_gather
,
weight
)
def
replacement
(
def
replacement
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
)
->
torch
.
Tensor
:
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
ag_output
,
mm_outputs
=
torch
.
ops
.
symm_mem
.
fused_all_gather_matmul
(
ag_output
,
mm_outputs
=
torch
.
ops
.
symm_mem
.
fused_all_gather_matmul
(
x
,
x
,
[
weight
],
[
weight
],
...
@@ -128,7 +127,7 @@ class AllGatherGEMMPattern(BasePattern):
...
@@ -128,7 +127,7 @@ class AllGatherGEMMPattern(BasePattern):
class
ScaledMMReduceScatterPattern
(
BasePattern
):
class
ScaledMMReduceScatterPattern
(
BasePattern
):
def
get_inputs
(
self
):
def
get_inputs
(
self
)
->
list
[
torch
.
Tensor
]
:
input
=
torch
.
empty
([
16
,
16
],
device
=
self
.
device
,
dtype
=
FP8_DTYPE
)
input
=
torch
.
empty
([
16
,
16
],
device
=
self
.
device
,
dtype
=
FP8_DTYPE
)
mm_weight
=
(
mm_weight
=
(
torch
.
empty
([
16
,
16
],
device
=
self
.
device
,
dtype
=
FP8_DTYPE
)
torch
.
empty
([
16
,
16
],
device
=
self
.
device
,
dtype
=
FP8_DTYPE
)
...
@@ -139,7 +138,7 @@ class ScaledMMReduceScatterPattern(BasePattern):
...
@@ -139,7 +138,7 @@ class ScaledMMReduceScatterPattern(BasePattern):
scale_b
=
torch
.
empty
([
1
,
16
],
device
=
self
.
device
,
dtype
=
torch
.
float32
)
scale_b
=
torch
.
empty
([
1
,
16
],
device
=
self
.
device
,
dtype
=
torch
.
float32
)
return
[
input
,
mm_weight
,
scale_a
,
scale_b
]
return
[
input
,
mm_weight
,
scale_a
,
scale_b
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
def
pattern
(
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
mat2
:
torch
.
Tensor
,
mat2
:
torch
.
Tensor
,
...
@@ -196,7 +195,7 @@ class ScaledMMReduceScatterPattern(BasePattern):
...
@@ -196,7 +195,7 @@ class ScaledMMReduceScatterPattern(BasePattern):
class
AllGatherScaledMMPattern
(
BasePattern
):
class
AllGatherScaledMMPattern
(
BasePattern
):
def
get_inputs
(
self
):
def
get_inputs
(
self
)
->
list
[
torch
.
Tensor
]
:
x
=
torch
.
empty
([
8
,
16
],
device
=
self
.
device
,
dtype
=
FP8_DTYPE
)
x
=
torch
.
empty
([
8
,
16
],
device
=
self
.
device
,
dtype
=
FP8_DTYPE
)
weight
=
(
weight
=
(
torch
.
empty
([
16
,
16
],
device
=
self
.
device
,
dtype
=
FP8_DTYPE
)
torch
.
empty
([
16
,
16
],
device
=
self
.
device
,
dtype
=
FP8_DTYPE
)
...
@@ -211,7 +210,7 @@ class AllGatherScaledMMPattern(BasePattern):
...
@@ -211,7 +210,7 @@ class AllGatherScaledMMPattern(BasePattern):
return
[
x
,
weight
,
scale_a
,
scale_b
]
return
[
x
,
weight
,
scale_a
,
scale_b
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
def
pattern
(
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
...
@@ -258,7 +257,7 @@ class AllGatherScaledMMPattern(BasePattern):
...
@@ -258,7 +257,7 @@ class AllGatherScaledMMPattern(BasePattern):
class
CutlassScaledMMReduceScatterPattern
(
BasePattern
):
class
CutlassScaledMMReduceScatterPattern
(
BasePattern
):
def
get_inputs
(
self
):
def
get_inputs
(
self
)
->
list
[
torch
.
Tensor
]
:
input
=
torch
.
empty
([
16
,
16
],
device
=
self
.
device
,
dtype
=
FP8_DTYPE
)
input
=
torch
.
empty
([
16
,
16
],
device
=
self
.
device
,
dtype
=
FP8_DTYPE
)
mm_weight
=
(
mm_weight
=
(
torch
.
empty
([
16
,
16
],
device
=
self
.
device
,
dtype
=
FP8_DTYPE
)
torch
.
empty
([
16
,
16
],
device
=
self
.
device
,
dtype
=
FP8_DTYPE
)
...
@@ -271,7 +270,7 @@ class CutlassScaledMMReduceScatterPattern(BasePattern):
...
@@ -271,7 +270,7 @@ class CutlassScaledMMReduceScatterPattern(BasePattern):
cutlass_mm_output
=
torch
.
empty
([
16
,
16
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
cutlass_mm_output
=
torch
.
empty
([
16
,
16
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
return
[
input
,
mm_weight
,
scale_a
,
scale_b
,
cutlass_mm_output
]
return
[
input
,
mm_weight
,
scale_a
,
scale_b
,
cutlass_mm_output
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
def
pattern
(
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
...
@@ -331,7 +330,7 @@ class CutlassScaledMMReduceScatterPattern(BasePattern):
...
@@ -331,7 +330,7 @@ class CutlassScaledMMReduceScatterPattern(BasePattern):
class
AllGatherCutlassScaledMMPattern
(
BasePattern
):
class
AllGatherCutlassScaledMMPattern
(
BasePattern
):
def
get_inputs
(
self
):
def
get_inputs
(
self
)
->
list
[
torch
.
Tensor
]
:
x
=
torch
.
empty
([
8
,
16
],
device
=
self
.
device
,
dtype
=
FP8_DTYPE
)
x
=
torch
.
empty
([
8
,
16
],
device
=
self
.
device
,
dtype
=
FP8_DTYPE
)
weight
=
(
weight
=
(
torch
.
empty
([
16
,
16
],
device
=
self
.
device
,
dtype
=
FP8_DTYPE
)
torch
.
empty
([
16
,
16
],
device
=
self
.
device
,
dtype
=
FP8_DTYPE
)
...
@@ -349,7 +348,7 @@ class AllGatherCutlassScaledMMPattern(BasePattern):
...
@@ -349,7 +348,7 @@ class AllGatherCutlassScaledMMPattern(BasePattern):
return
[
x
,
weight
,
scale_a
,
scale_b
,
output
]
return
[
x
,
weight
,
scale_a
,
scale_b
,
output
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
def
pattern
(
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
...
@@ -400,7 +399,7 @@ class AllGatherCutlassScaledMMPattern(BasePattern):
...
@@ -400,7 +399,7 @@ class AllGatherCutlassScaledMMPattern(BasePattern):
class
AsyncTPPass
(
VllmPatternMatcherPass
):
class
AsyncTPPass
(
VllmPatternMatcherPass
):
@
enable_fake_mode
@
enable_fake_mode
def
__init__
(
self
,
config
:
VllmConfig
):
def
__init__
(
self
,
config
:
VllmConfig
)
->
None
:
super
().
__init__
(
config
)
super
().
__init__
(
config
)
# Enable symmetric memory for the TP process group
# Enable symmetric memory for the TP process group
...
@@ -445,7 +444,7 @@ class AsyncTPPass(VllmPatternMatcherPass):
...
@@ -445,7 +444,7 @@ class AsyncTPPass(VllmPatternMatcherPass):
return
compile_range
.
is_single_size
()
and
compile_range
.
end
%
tp_size
==
0
return
compile_range
.
is_single_size
()
and
compile_range
.
end
%
tp_size
==
0
@
VllmInductorPass
.
time_and_log
@
VllmInductorPass
.
time_and_log
def
__call__
(
self
,
graph
:
fx
.
Graph
):
def
__call__
(
self
,
graph
:
fx
.
Graph
)
->
None
:
self
.
matched_count
=
self
.
patterns
.
apply
(
graph
)
self
.
matched_count
=
self
.
patterns
.
apply
(
graph
)
logger
.
debug
(
"Replaced %s patterns"
,
self
.
matched_count
)
logger
.
debug
(
"Replaced %s patterns"
,
self
.
matched_count
)
...
@@ -512,11 +511,13 @@ if flashinfer_comm is not None:
...
@@ -512,11 +511,13 @@ if flashinfer_comm is not None:
f
"max token num
{
max_token_num
}
* hidden size
{
hidden_size
}
* "
f
"max token num
{
max_token_num
}
* hidden size
{
hidden_size
}
* "
f
"element size
{
element_size
}
"
f
"element size
{
element_size
}
"
)
)
device_capability
=
current_platform
.
get_device_capability
().
to_int
()
curr_device
=
current_platform
.
get_device_capability
()
device_capability
=
curr_device
.
to_int
()
if
curr_device
is
not
None
else
None
# Get one shot input size limit for the current world size
# Get one shot input size limit for the current world size
# for the current device capability
# for the current device capability
max_one_shot_size
=
_FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB
.
get
(
max_one_shot_size
=
_FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB
.
get
(
device_capability
,
{}
device_capability
,
# type: ignore[arg-type]
{},
).
get
(
world_size
,
None
)
).
get
(
world_size
,
None
)
# Use one shot if no max size is specified
# Use one shot if no max size is specified
use_oneshot
=
(
use_oneshot
=
(
...
@@ -606,7 +607,7 @@ class FlashInferFusedAllReduceParams:
...
@@ -606,7 +607,7 @@ class FlashInferFusedAllReduceParams:
world_size
:
int
,
world_size
:
int
,
use_fp32_lamport
:
bool
=
False
,
use_fp32_lamport
:
bool
=
False
,
max_token_num
:
int
=
1024
,
max_token_num
:
int
=
1024
,
):
)
->
None
:
self
.
rank
=
rank
self
.
rank
=
rank
self
.
world_size
=
world_size
self
.
world_size
=
world_size
self
.
use_fp32_lamport
=
use_fp32_lamport
self
.
use_fp32_lamport
=
use_fp32_lamport
...
@@ -615,7 +616,7 @@ class FlashInferFusedAllReduceParams:
...
@@ -615,7 +616,7 @@ class FlashInferFusedAllReduceParams:
self
.
fp32_acc
=
True
self
.
fp32_acc
=
True
self
.
max_token_num
=
max_token_num
self
.
max_token_num
=
max_token_num
def
get_trtllm_fused_allreduce_kwargs
(
self
):
def
get_trtllm_fused_allreduce_kwargs
(
self
)
->
dict
[
str
,
bool
|
int
]
:
return
{
return
{
"world_rank"
:
self
.
rank
,
"world_rank"
:
self
.
rank
,
"world_size"
:
self
.
world_size
,
"world_size"
:
self
.
world_size
,
...
@@ -639,26 +640,30 @@ class AllReduceRMSNormPattern(BasePattern):
...
@@ -639,26 +640,30 @@ class AllReduceRMSNormPattern(BasePattern):
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
device
:
str
|
None
,
device
:
str
|
None
,
allreduce_params
:
FlashInferFusedAllReduceParams
,
allreduce_params
:
FlashInferFusedAllReduceParams
,
):
)
->
None
:
super
().
__init__
(
dtype
,
device
)
super
().
__init__
(
dtype
,
device
)
self
.
epsilon
=
epsilon
self
.
epsilon
=
epsilon
self
.
allreduce_params
=
allreduce_params
self
.
allreduce_params
=
allreduce_params
self
.
rmsnorm_matcher
=
MatcherRMSNorm
(
epsilon
)
self
.
rmsnorm_matcher
=
MatcherRMSNorm
(
epsilon
)
def
get_inputs
(
self
):
def
get_inputs
(
self
)
->
list
[
torch
.
Tensor
]
:
input
,
weight
=
self
.
rmsnorm_matcher
.
inputs
()
input
,
weight
=
self
.
rmsnorm_matcher
.
inputs
()
# input goes through allreduce first, always 16-bit
# input goes through allreduce first, always 16-bit
return
[
input
.
to
(
self
.
dtype
),
weight
]
return
[
input
.
to
(
self
.
dtype
),
weight
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
):
def
pattern
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
allreduce_output
=
tensor_model_parallel_all_reduce
(
input
)
allreduce_output
=
tensor_model_parallel_all_reduce
(
input
)
rms
=
self
.
rmsnorm_matcher
(
allreduce_output
,
weight
)
rms
=
self
.
rmsnorm_matcher
(
allreduce_output
,
weight
)
return
rms
,
allreduce_output
return
rms
,
allreduce_output
def
replacement
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
):
def
replacement
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
residual
=
torch
.
zeros_like
(
input
)
residual
=
torch
.
zeros_like
(
input
)
rms_result
=
torch
.
empty_like
(
input
)
rms_result
=
torch
.
empty_like
(
input
)
allreduce
=
auto_functionalized
(
allreduce
=
auto_functionalized
(
...
@@ -694,27 +699,29 @@ class AllReduceFusedAddRMSNormPattern(BasePattern):
...
@@ -694,27 +699,29 @@ class AllReduceFusedAddRMSNormPattern(BasePattern):
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
device
:
str
|
None
,
device
:
str
|
None
,
allreduce_params
:
FlashInferFusedAllReduceParams
,
allreduce_params
:
FlashInferFusedAllReduceParams
,
):
)
->
None
:
super
().
__init__
(
dtype
,
device
)
super
().
__init__
(
dtype
,
device
)
self
.
epsilon
=
epsilon
self
.
epsilon
=
epsilon
self
.
allreduce_params
=
allreduce_params
self
.
allreduce_params
=
allreduce_params
self
.
rmsnorm_matcher
=
MatcherFusedAddRMSNorm
(
epsilon
)
self
.
rmsnorm_matcher
=
MatcherFusedAddRMSNorm
(
epsilon
)
def
get_inputs
(
self
):
def
get_inputs
(
self
)
->
list
[
torch
.
Tensor
]
:
input
,
residual
,
weight
=
self
.
rmsnorm_matcher
.
inputs
()
input
,
residual
,
weight
=
self
.
rmsnorm_matcher
.
inputs
()
# input goes through allreduce first, always 16-bit
# input goes through allreduce first, always 16-bit
return
[
residual
,
input
.
to
(
self
.
dtype
),
weight
]
return
[
residual
,
input
.
to
(
self
.
dtype
),
weight
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
residual
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
):
def
pattern
(
residual
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
allreduce_output
=
tensor_model_parallel_all_reduce
(
input
)
allreduce_output
=
tensor_model_parallel_all_reduce
(
input
)
rms
,
residual
=
self
.
rmsnorm_matcher
(
allreduce_output
,
weight
,
residual
)
rms
,
residual
=
self
.
rmsnorm_matcher
(
allreduce_output
,
weight
,
residual
)
return
rms
,
residual
return
rms
,
residual
def
replacement
(
def
replacement
(
residual
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
residual
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
allreduce
=
auto_functionalized
(
allreduce
=
auto_functionalized
(
flashinfer_trtllm_fused_allreduce_norm
,
flashinfer_trtllm_fused_allreduce_norm
,
allreduce_in
=
input
,
allreduce_in
=
input
,
...
@@ -739,8 +746,8 @@ class AllReduceFusedAddRMSNormPattern(BasePattern):
...
@@ -739,8 +746,8 @@ class AllReduceFusedAddRMSNormPattern(BasePattern):
first_return_only
=
lambda
fn
:
lambda
a
,
b
,
c
:
fn
(
a
,
b
,
c
)[
0
]
first_return_only
=
lambda
fn
:
lambda
a
,
b
,
c
:
fn
(
a
,
b
,
c
)[
0
]
pm
.
register_replacement
(
pm
.
register_replacement
(
first_return_only
(
pattern
),
first_return_only
(
pattern
),
# type: ignore[no-untyped-call]
first_return_only
(
replacement
),
first_return_only
(
replacement
),
# type: ignore[no-untyped-call]
self
.
get_inputs
(),
self
.
get_inputs
(),
pm
.
fwd_only
,
pm
.
fwd_only
,
pm_pass
,
pm_pass
,
...
@@ -761,7 +768,7 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
...
@@ -761,7 +768,7 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
device
:
str
|
None
,
device
:
str
|
None
,
allreduce_params
:
FlashInferFusedAllReduceParams
,
allreduce_params
:
FlashInferFusedAllReduceParams
,
):
)
->
None
:
super
().
__init__
(
dtype
,
device
)
super
().
__init__
(
dtype
,
device
)
self
.
epsilon
=
epsilon
self
.
epsilon
=
epsilon
self
.
allreduce_params
=
allreduce_params
self
.
allreduce_params
=
allreduce_params
...
@@ -769,25 +776,27 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
...
@@ -769,25 +776,27 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
self
.
rmsnorm_matcher
=
MatcherRMSNorm
(
epsilon
)
self
.
rmsnorm_matcher
=
MatcherRMSNorm
(
epsilon
)
self
.
quant_matcher
=
MatcherQuantFP8
(
kFp8StaticTensorSym
)
self
.
quant_matcher
=
MatcherQuantFP8
(
kFp8StaticTensorSym
)
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
get_inputs
(
self
)
->
list
[
torch
.
Tensor
]:
def
get_inputs
():
input
,
weight
=
self
.
rmsnorm_matcher
.
inputs
()
input
,
weight
=
self
.
rmsnorm_matcher
.
inputs
()
_
,
scale
=
self
.
quant_matcher
.
inputs
()
_
,
scale
=
self
.
quant_matcher
.
inputs
()
# input goes through allreduce first, always 16-bit
# input goes through allreduce first, always 16-bit
return
[
input
.
to
(
self
.
dtype
),
weight
,
scale
]
return
[
input
.
to
(
self
.
dtype
),
weight
,
scale
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
def
pattern
(
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
all_reduce
=
tensor_model_parallel_all_reduce
(
input
)
all_reduce
=
tensor_model_parallel_all_reduce
(
input
)
rms
=
self
.
rmsnorm_matcher
(
all_reduce
,
weight
)
rms
=
self
.
rmsnorm_matcher
(
all_reduce
,
weight
)
quant
,
_
=
self
.
quant_matcher
(
rms
,
scale
)
quant
,
_
=
self
.
quant_matcher
(
rms
,
scale
)
return
quant
,
all_reduce
return
quant
,
all_reduce
def
replacement
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
):
def
replacement
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
residual
=
torch
.
zeros_like
(
input
)
residual
=
torch
.
zeros_like
(
input
)
result_rms
=
torch
.
empty_like
(
input
)
result_rms
=
torch
.
empty_like
(
input
)
result_quant
=
torch
.
empty_like
(
input
,
dtype
=
self
.
quant_dtype
)
result_quant
=
torch
.
empty_like
(
input
,
dtype
=
self
.
quant_dtype
)
...
@@ -812,7 +821,7 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
...
@@ -812,7 +821,7 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
return
allreduce
[
4
],
allreduce
[
1
]
return
allreduce
[
4
],
allreduce
[
1
]
pm
.
register_replacement
(
pm
.
register_replacement
(
pattern
,
replacement
,
get_inputs
(),
pm
.
fwd_only
,
pm_pass
pattern
,
replacement
,
self
.
get_inputs
(),
pm
.
fwd_only
,
pm_pass
)
)
...
@@ -830,7 +839,7 @@ class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
...
@@ -830,7 +839,7 @@ class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
device
:
str
|
None
,
device
:
str
|
None
,
allreduce_params
:
FlashInferFusedAllReduceParams
,
allreduce_params
:
FlashInferFusedAllReduceParams
,
):
)
->
None
:
super
().
__init__
(
dtype
,
device
)
super
().
__init__
(
dtype
,
device
)
self
.
epsilon
=
epsilon
self
.
epsilon
=
epsilon
self
.
allreduce_params
=
allreduce_params
self
.
allreduce_params
=
allreduce_params
...
@@ -839,20 +848,20 @@ class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
...
@@ -839,20 +848,20 @@ class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
self
.
rmsnorm_matcher
=
MatcherFusedAddRMSNorm
(
epsilon
)
self
.
rmsnorm_matcher
=
MatcherFusedAddRMSNorm
(
epsilon
)
self
.
quant_matcher
=
MatcherQuantFP8
(
kFp8StaticTensorSym
)
self
.
quant_matcher
=
MatcherQuantFP8
(
kFp8StaticTensorSym
)
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
get_inputs
(
self
)
->
list
[
torch
.
Tensor
]:
def
get_inputs
():
input
,
residual
,
weight
=
self
.
rmsnorm_matcher
.
inputs
()
input
,
residual
,
weight
=
self
.
rmsnorm_matcher
.
inputs
()
_
,
scale
=
self
.
quant_matcher
.
inputs
()
_
,
scale
=
self
.
quant_matcher
.
inputs
()
# input goes through allreduce first, always 16-bit
# input goes through allreduce first, always 16-bit
return
[
residual
,
input
.
to
(
self
.
dtype
),
weight
,
scale
]
return
[
residual
,
input
.
to
(
self
.
dtype
),
weight
,
scale
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
def
pattern
(
residual
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
allreduce_output
=
tensor_model_parallel_all_reduce
(
input
)
allreduce_output
=
tensor_model_parallel_all_reduce
(
input
)
rms
,
res
=
self
.
rmsnorm_matcher
(
allreduce_output
,
weight
,
residual
)
rms
,
res
=
self
.
rmsnorm_matcher
(
allreduce_output
,
weight
,
residual
)
quant
,
_
=
self
.
quant_matcher
(
rms
,
scale
)
quant
,
_
=
self
.
quant_matcher
(
rms
,
scale
)
...
@@ -864,7 +873,7 @@ class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
...
@@ -864,7 +873,7 @@ class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
result_quant
=
torch
.
empty_like
(
input
,
dtype
=
self
.
quant_dtype
)
result_quant
=
torch
.
empty_like
(
input
,
dtype
=
self
.
quant_dtype
)
allreduce
=
auto_functionalized
(
allreduce
=
auto_functionalized
(
flashinfer_trtllm_fused_allreduce_norm
,
flashinfer_trtllm_fused_allreduce_norm
,
...
@@ -886,7 +895,7 @@ class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
...
@@ -886,7 +895,7 @@ class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
return
allreduce
[
4
],
allreduce
[
2
]
return
allreduce
[
4
],
allreduce
[
2
]
pm
.
register_replacement
(
pm
.
register_replacement
(
pattern
,
replacement
,
get_inputs
(),
pm
.
fwd_only
,
pm_pass
pattern
,
replacement
,
self
.
get_inputs
(),
pm
.
fwd_only
,
pm_pass
)
)
...
@@ -904,31 +913,31 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
...
@@ -904,31 +913,31 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
device
:
str
|
None
,
device
:
str
|
None
,
allreduce_params
:
FlashInferFusedAllReduceParams
,
allreduce_params
:
FlashInferFusedAllReduceParams
,
):
)
->
None
:
super
().
__init__
(
dtype
,
device
)
super
().
__init__
(
dtype
,
device
)
self
.
epsilon
=
epsilon
self
.
epsilon
=
epsilon
self
.
allreduce_params
=
allreduce_params
self
.
allreduce_params
=
allreduce_params
self
.
rmsnorm_matcher
=
MatcherRMSNorm
(
epsilon
)
self
.
rmsnorm_matcher
=
MatcherRMSNorm
(
epsilon
)
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
get_inputs
(
self
)
->
list
[
torch
.
Tensor
]:
def
get_inputs
():
input
=
torch
.
empty
([
1
,
16
,
16
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
input
=
torch
.
empty
([
1
,
16
,
16
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
quant_result
=
torch
.
empty
((
16
,
8
),
device
=
self
.
device
,
dtype
=
torch
.
uint8
)
quant_result
=
torch
.
empty
((
16
,
8
),
device
=
self
.
device
,
dtype
=
torch
.
uint8
)
input_global_scale
=
torch
.
empty
(
input_global_scale
=
torch
.
empty
(
[
1
,
1
],
device
=
self
.
device
,
dtype
=
torch
.
float32
[
1
,
1
],
device
=
self
.
device
,
dtype
=
torch
.
float32
)
)
weight
=
torch
.
empty
([
16
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
weight
=
torch
.
empty
([
16
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
output_scale
=
torch
.
empty
([
128
,
4
],
device
=
self
.
device
,
dtype
=
torch
.
int32
)
output_scale
=
torch
.
empty
([
128
,
4
],
device
=
self
.
device
,
dtype
=
torch
.
int32
)
return
[
input
,
quant_result
,
weight
,
input_global_scale
,
output_scale
]
return
[
input
,
quant_result
,
weight
,
input_global_scale
,
output_scale
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
def
pattern
(
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
quant_result
:
torch
.
Tensor
,
quant_result
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
input_global_scale
:
torch
.
Tensor
,
input_global_scale
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
:
all_reduce
=
tensor_model_parallel_all_reduce
(
input
)
all_reduce
=
tensor_model_parallel_all_reduce
(
input
)
rms
=
self
.
rmsnorm_matcher
(
all_reduce
,
weight
)
rms
=
self
.
rmsnorm_matcher
(
all_reduce
,
weight
)
quant_out_tuple
=
auto_functionalized
(
quant_out_tuple
=
auto_functionalized
(
...
@@ -948,7 +957,7 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
...
@@ -948,7 +957,7 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
input_global_scale
:
torch
.
Tensor
,
input_global_scale
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
:
residual
=
torch
.
zeros_like
(
input
)
residual
=
torch
.
zeros_like
(
input
)
result_rms
=
torch
.
empty_like
(
input
)
result_rms
=
torch
.
empty_like
(
input
)
allreduce
=
auto_functionalized
(
allreduce
=
auto_functionalized
(
...
@@ -972,7 +981,7 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
...
@@ -972,7 +981,7 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
return
allreduce
[
4
],
allreduce
[
1
],
allreduce
[
5
]
return
allreduce
[
4
],
allreduce
[
1
],
allreduce
[
5
]
pm
.
register_replacement
(
pm
.
register_replacement
(
pattern
,
replacement
,
get_inputs
(),
pm
.
fwd_only
,
pm_pass
pattern
,
replacement
,
self
.
get_inputs
(),
pm
.
fwd_only
,
pm_pass
)
)
...
@@ -990,33 +999,33 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
...
@@ -990,33 +999,33 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
device
:
str
|
None
,
device
:
str
|
None
,
allreduce_params
:
FlashInferFusedAllReduceParams
,
allreduce_params
:
FlashInferFusedAllReduceParams
,
):
)
->
None
:
super
().
__init__
(
dtype
,
device
)
super
().
__init__
(
dtype
,
device
)
self
.
epsilon
=
epsilon
self
.
epsilon
=
epsilon
self
.
allreduce_params
=
allreduce_params
self
.
allreduce_params
=
allreduce_params
self
.
rmsnorm_matcher
=
MatcherFusedAddRMSNorm
(
epsilon
)
self
.
rmsnorm_matcher
=
MatcherFusedAddRMSNorm
(
epsilon
)
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
get_inputs
(
self
)
->
list
[
torch
.
Tensor
]:
def
get_inputs
():
input
=
torch
.
empty
([
16
,
16
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
input
=
torch
.
empty
([
16
,
16
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
residual
=
torch
.
empty
([
16
,
16
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
weight
=
torch
.
empty
([
16
,
16
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
quant_result
=
torch
.
empty
((
16
,
8
),
device
=
self
.
device
,
dtype
=
torch
.
uint8
)
input_global_scale
=
torch
.
empty
(
[
1
,
1
],
device
=
self
.
device
,
dtype
=
torch
.
float32
)
output_scale
=
torch
.
empty
([
128
,
4
],
device
=
self
.
device
,
dtype
=
torch
.
int32
)
return
[
quant_result
,
residual
,
input
,
output_scale
,
weight
,
input_global_scale
,
]
residual
=
torch
.
empty
([
16
,
16
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
weight
=
torch
.
empty
([
16
,
16
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
quant_result
=
torch
.
empty
((
16
,
8
),
device
=
self
.
device
,
dtype
=
torch
.
uint8
)
input_global_scale
=
torch
.
empty
(
[
1
,
1
],
device
=
self
.
device
,
dtype
=
torch
.
float32
)
output_scale
=
torch
.
empty
([
128
,
4
],
device
=
self
.
device
,
dtype
=
torch
.
int32
)
return
[
quant_result
,
residual
,
input
,
output_scale
,
weight
,
input_global_scale
,
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
def
pattern
(
quant_result
:
torch
.
Tensor
,
quant_result
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
...
@@ -1024,7 +1033,7 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
...
@@ -1024,7 +1033,7 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
output_scale
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
input_global_scale
:
torch
.
Tensor
,
input_global_scale
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
:
allreduce_output
=
tensor_model_parallel_all_reduce
(
input
)
allreduce_output
=
tensor_model_parallel_all_reduce
(
input
)
rms
,
residual
=
self
.
rmsnorm_matcher
(
allreduce_output
,
weight
,
residual
)
rms
,
residual
=
self
.
rmsnorm_matcher
(
allreduce_output
,
weight
,
residual
)
quant_out_tuple
=
auto_functionalized
(
quant_out_tuple
=
auto_functionalized
(
...
@@ -1045,7 +1054,7 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
...
@@ -1045,7 +1054,7 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
output_scale
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
input_global_scale
:
torch
.
Tensor
,
input_global_scale
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
:
allreduce
=
auto_functionalized
(
allreduce
=
auto_functionalized
(
flashinfer_trtllm_fused_allreduce_norm
,
flashinfer_trtllm_fused_allreduce_norm
,
allreduce_in
=
input
,
allreduce_in
=
input
,
...
@@ -1066,12 +1075,12 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
...
@@ -1066,12 +1075,12 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
return
allreduce
[
4
],
allreduce
[
2
],
allreduce
[
5
]
return
allreduce
[
4
],
allreduce
[
2
],
allreduce
[
5
]
pm
.
register_replacement
(
pm
.
register_replacement
(
pattern
,
replacement
,
get_inputs
(),
pm
.
fwd_only
,
pm_pass
pattern
,
replacement
,
self
.
get_inputs
(),
pm
.
fwd_only
,
pm_pass
)
)
class
AllReduceFusionPass
(
VllmPatternMatcherPass
):
class
AllReduceFusionPass
(
VllmPatternMatcherPass
):
def
__init__
(
self
,
config
:
VllmConfig
):
def
__init__
(
self
,
config
:
VllmConfig
)
->
None
:
super
().
__init__
(
config
)
super
().
__init__
(
config
)
self
.
disabled
=
True
self
.
disabled
=
True
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
...
@@ -1122,7 +1131,7 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
...
@@ -1122,7 +1131,7 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
)
)
self
.
ipc_handles
,
workspace_tensor
=
(
self
.
ipc_handles
,
workspace_tensor
=
(
flashinfer_comm
.
trtllm_create_ipc_workspace_for_all_reduce_fusion
(
flashinfer_comm
.
trtllm_create_ipc_workspace_for_all_reduce_fusion
(
# type: ignore[misc]
tp_rank
=
rank
,
tp_rank
=
rank
,
tp_size
=
self
.
tp_size
,
tp_size
=
self
.
tp_size
,
max_token_num
=
self
.
max_token_num
,
max_token_num
=
self
.
max_token_num
,
...
@@ -1145,7 +1154,7 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
...
@@ -1145,7 +1154,7 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
self
.
dump_patterns
(
config
,
self
.
patterns
)
self
.
dump_patterns
(
config
,
self
.
patterns
)
@
enable_fake_mode
@
enable_fake_mode
def
register_patterns
(
self
):
def
register_patterns
(
self
)
->
None
:
for
epsilon
in
[
1e-5
,
1e-6
]:
for
epsilon
in
[
1e-5
,
1e-6
]:
AllReduceFusedRMSNormStaticQuantFP8Pattern
(
AllReduceFusedRMSNormStaticQuantFP8Pattern
(
epsilon
,
epsilon
,
...
@@ -1198,7 +1207,7 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
...
@@ -1198,7 +1207,7 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
return
compile_range
.
end
<=
self
.
max_token_num
return
compile_range
.
end
<=
self
.
max_token_num
@
VllmInductorPass
.
time_and_log
@
VllmInductorPass
.
time_and_log
def
__call__
(
self
,
graph
:
fx
.
Graph
):
def
__call__
(
self
,
graph
:
fx
.
Graph
)
->
None
:
if
self
.
disabled
:
if
self
.
disabled
:
logger
.
debug
(
"AllReduceFusionPass disabled"
)
logger
.
debug
(
"AllReduceFusionPass disabled"
)
return
return
...
@@ -1206,7 +1215,7 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
...
@@ -1206,7 +1215,7 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
self
.
matched_count
=
self
.
patterns
.
apply
(
graph
)
self
.
matched_count
=
self
.
patterns
.
apply
(
graph
)
logger
.
debug
(
"Replaced %s patterns"
,
self
.
matched_count
)
logger
.
debug
(
"Replaced %s patterns"
,
self
.
matched_count
)
def
__del__
(
self
):
def
__del__
(
self
)
->
None
:
if
getattr
(
self
,
"disabled"
,
True
):
if
getattr
(
self
,
"disabled"
,
True
):
return
return
if
flashinfer_comm
is
not
None
:
if
flashinfer_comm
is
not
None
:
...
...
vllm/compilation/fusion.py
View file @
ad8818bb
...
@@ -38,19 +38,19 @@ FP8_DTYPE = current_platform.fp8_dtype()
...
@@ -38,19 +38,19 @@ FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE
=
torch
.
uint8
FP4_DTYPE
=
torch
.
uint8
def
empty_bf16
(
*
args
,
**
kwargs
)
:
def
empty_bf16
(
*
args
:
Any
,
**
kwargs
:
Any
)
->
torch
.
Tensor
:
return
torch
.
empty
(
*
args
,
**
kwargs
,
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
)
return
torch
.
empty
(
*
args
,
**
kwargs
,
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
)
def
empty_fp32
(
*
args
,
**
kwargs
)
:
def
empty_fp32
(
*
args
:
Any
,
**
kwargs
:
Any
)
->
torch
.
Tensor
:
return
torch
.
empty
(
*
args
,
**
kwargs
,
dtype
=
torch
.
float32
,
device
=
"cuda"
)
return
torch
.
empty
(
*
args
,
**
kwargs
,
dtype
=
torch
.
float32
,
device
=
"cuda"
)
def
empty_i32
(
*
args
,
**
kwargs
)
:
def
empty_i32
(
*
args
:
Any
,
**
kwargs
:
Any
)
->
torch
.
Tensor
:
return
torch
.
empty
(
*
args
,
**
kwargs
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
return
torch
.
empty
(
*
args
,
**
kwargs
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
def
empty_i64
(
*
args
,
**
kwargs
)
:
def
empty_i64
(
*
args
:
Any
,
**
kwargs
:
Any
)
->
torch
.
Tensor
:
return
torch
.
empty
(
*
args
,
**
kwargs
,
dtype
=
torch
.
int64
,
device
=
"cuda"
)
return
torch
.
empty
(
*
args
,
**
kwargs
,
dtype
=
torch
.
int64
,
device
=
"cuda"
)
...
@@ -79,7 +79,7 @@ class FusedRMSQuantKey(NamedTuple):
...
@@ -79,7 +79,7 @@ class FusedRMSQuantKey(NamedTuple):
quant
:
QuantKey
quant
:
QuantKey
fused_add
:
bool
fused_add
:
bool
def
__str__
(
self
):
def
__str__
(
self
)
->
str
:
return
(
return
(
f
"FusedQuantKey(
{
self
.
quant
}
, with"
f
"FusedQuantKey(
{
self
.
quant
}
, with"
f
"
{
''
if
self
.
fused_add
else
'out'
}
residual)"
f
"
{
''
if
self
.
fused_add
else
'out'
}
residual)"
...
@@ -121,7 +121,7 @@ class RMSNormQuantPattern:
...
@@ -121,7 +121,7 @@ class RMSNormQuantPattern:
key
:
FusedRMSQuantKey
,
key
:
FusedRMSQuantKey
,
has_col_major_scales
:
bool
=
False
,
has_col_major_scales
:
bool
=
False
,
is_e8m0
:
bool
=
False
,
is_e8m0
:
bool
=
False
,
):
)
->
None
:
self
.
epsilon
=
epsilon
self
.
epsilon
=
epsilon
self
.
quant_dtype
=
key
.
quant
.
dtype
self
.
quant_dtype
=
key
.
quant
.
dtype
config
=
get_current_vllm_config
()
config
=
get_current_vllm_config
()
...
@@ -141,7 +141,9 @@ class RMSNormQuantPattern:
...
@@ -141,7 +141,9 @@ class RMSNormQuantPattern:
class
RMSNormStaticQuantPattern
(
RMSNormQuantPattern
):
class
RMSNormStaticQuantPattern
(
RMSNormQuantPattern
):
def
__init__
(
self
,
epsilon
:
float
,
quant_dtype
:
torch
.
dtype
,
symmetric
=
True
):
def
__init__
(
self
,
epsilon
:
float
,
quant_dtype
:
torch
.
dtype
,
symmetric
:
bool
=
True
)
->
None
:
fused_key
=
FusedRMSQuantKey
(
fused_key
=
FusedRMSQuantKey
(
fused_add
=
False
,
fused_add
=
False
,
quant
=
QuantKey
(
quant
=
QuantKey
(
...
@@ -150,13 +152,17 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
...
@@ -150,13 +152,17 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
)
)
super
().
__init__
(
epsilon
,
fused_key
)
super
().
__init__
(
epsilon
,
fused_key
)
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
# Cannot use methods, as the self argument affects tracing
# Cannot use methods, as the self argument affects tracing
def
pattern
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
):
def
pattern
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
)
->
torch
.
Tensor
:
result_rms
=
self
.
rmsnorm_matcher
(
input
,
weight
)
result_rms
=
self
.
rmsnorm_matcher
(
input
,
weight
)
return
self
.
quant_matcher
(
result_rms
,
scale
)[
0
]
return
self
.
quant_matcher
(
result_rms
,
scale
)[
0
]
def
replacement
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
):
def
replacement
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# In case we're matching native rms-norm, conversions might be
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
# optimized out. We convert here just to be safe.
input
=
input
.
to
(
dtype
=
self
.
model_dtype
)
input
=
input
.
to
(
dtype
=
self
.
model_dtype
)
...
@@ -187,7 +193,9 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
...
@@ -187,7 +193,9 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
class
FusedAddRMSNormStaticQuantPattern
(
RMSNormQuantPattern
):
class
FusedAddRMSNormStaticQuantPattern
(
RMSNormQuantPattern
):
def
__init__
(
self
,
epsilon
:
float
,
quant_dtype
:
torch
.
dtype
,
symmetric
=
True
):
def
__init__
(
self
,
epsilon
:
float
,
quant_dtype
:
torch
.
dtype
,
symmetric
:
bool
=
True
)
->
None
:
key
=
FusedRMSQuantKey
(
key
=
FusedRMSQuantKey
(
fused_add
=
True
,
fused_add
=
True
,
quant
=
QuantKey
(
quant
=
QuantKey
(
...
@@ -196,13 +204,13 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
...
@@ -196,13 +204,13 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
)
)
super
().
__init__
(
epsilon
,
key
)
super
().
__init__
(
epsilon
,
key
)
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
def
pattern
(
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
result_rms
,
residual
=
self
.
rmsnorm_matcher
(
input
,
weight
,
residual
)
result_rms
,
residual
=
self
.
rmsnorm_matcher
(
input
,
weight
,
residual
)
result
,
_
=
self
.
quant_matcher
(
result_rms
,
scale
)
result
,
_
=
self
.
quant_matcher
(
result_rms
,
scale
)
...
@@ -213,7 +221,7 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
...
@@ -213,7 +221,7 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
# In case we're matching native rms-norm, conversions might be
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
# optimized out. We convert here just to be safe.
input
=
input
.
to
(
dtype
=
self
.
model_dtype
)
input
=
input
.
to
(
dtype
=
self
.
model_dtype
)
...
@@ -253,10 +261,10 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
...
@@ -253,10 +261,10 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
epsilon
:
float
,
epsilon
:
float
,
quant_dtype
:
torch
.
dtype
,
quant_dtype
:
torch
.
dtype
,
group_shape
:
GroupShape
,
group_shape
:
GroupShape
,
symmetric
=
True
,
symmetric
:
bool
=
True
,
has_col_major_scales
:
bool
=
False
,
has_col_major_scales
:
bool
=
False
,
is_e8m0
:
bool
=
False
,
is_e8m0
:
bool
=
False
,
):
)
->
None
:
scale
=
ScaleDesc
(
torch
.
float32
,
False
,
group_shape
)
scale
=
ScaleDesc
(
torch
.
float32
,
False
,
group_shape
)
key
=
FusedRMSQuantKey
(
key
=
FusedRMSQuantKey
(
fused_add
=
True
,
fused_add
=
True
,
...
@@ -269,15 +277,17 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
...
@@ -269,15 +277,17 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
epsilon
,
key
,
has_col_major_scales
=
has_col_major_scales
,
is_e8m0
=
is_e8m0
epsilon
,
key
,
has_col_major_scales
=
has_col_major_scales
,
is_e8m0
=
is_e8m0
)
)
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
):
def
pattern
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
result_rms
,
residual
=
self
.
rmsnorm_matcher
(
input
,
weight
,
residual
)
result_rms
,
residual
=
self
.
rmsnorm_matcher
(
input
,
weight
,
residual
)
result
,
scale
=
self
.
quant_matcher
(
result_rms
)
result
,
scale
=
self
.
quant_matcher
(
result_rms
)
return
result
,
residual
,
scale
return
result
,
residual
,
scale
def
replacement
(
def
replacement
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
:
# In case we're matching native rms-norm, conversions might be
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
# optimized out. We convert here just to be safe.
input
=
input
.
to
(
dtype
=
self
.
model_dtype
)
input
=
input
.
to
(
dtype
=
self
.
model_dtype
)
...
@@ -315,10 +325,10 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
...
@@ -315,10 +325,10 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
epsilon
:
float
,
epsilon
:
float
,
quant_dtype
:
torch
.
dtype
,
quant_dtype
:
torch
.
dtype
,
group_shape
:
GroupShape
,
group_shape
:
GroupShape
,
symmetric
=
True
,
symmetric
:
bool
=
True
,
has_col_major_scales
:
bool
=
False
,
has_col_major_scales
:
bool
=
False
,
is_e8m0
:
bool
=
False
,
is_e8m0
:
bool
=
False
,
):
)
->
None
:
scale
=
ScaleDesc
(
torch
.
float32
,
False
,
group_shape
)
scale
=
ScaleDesc
(
torch
.
float32
,
False
,
group_shape
)
key
=
FusedRMSQuantKey
(
key
=
FusedRMSQuantKey
(
fused_add
=
False
,
fused_add
=
False
,
...
@@ -329,13 +339,17 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
...
@@ -329,13 +339,17 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
epsilon
,
key
,
has_col_major_scales
=
has_col_major_scales
,
is_e8m0
=
is_e8m0
epsilon
,
key
,
has_col_major_scales
=
has_col_major_scales
,
is_e8m0
=
is_e8m0
)
)
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
):
def
pattern
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
result_rms
=
self
.
rmsnorm_matcher
(
input
,
weight
)
result_rms
=
self
.
rmsnorm_matcher
(
input
,
weight
)
result
,
scale
=
self
.
quant_matcher
(
result_rms
)
result
,
scale
=
self
.
quant_matcher
(
result_rms
)
return
result
,
scale
return
result
,
scale
def
replacement
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
):
def
replacement
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# In case we're matching native rms-norm, conversions might be
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
# optimized out. We convert here just to be safe.
input
=
input
.
to
(
dtype
=
self
.
model_dtype
)
input
=
input
.
to
(
dtype
=
self
.
model_dtype
)
...
@@ -375,8 +389,8 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
...
@@ -375,8 +389,8 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
epsilon
:
float
,
epsilon
:
float
,
quant_dtype
:
torch
.
dtype
,
quant_dtype
:
torch
.
dtype
,
group_shape
:
GroupShape
=
GroupShape
.
PER_TOKEN
,
group_shape
:
GroupShape
=
GroupShape
.
PER_TOKEN
,
symmetric
=
True
,
symmetric
:
bool
=
True
,
):
)
->
None
:
scale
=
ScaleDesc
(
torch
.
float32
,
False
,
group_shape
)
scale
=
ScaleDesc
(
torch
.
float32
,
False
,
group_shape
)
key
=
FusedRMSQuantKey
(
key
=
FusedRMSQuantKey
(
fused_add
=
False
,
fused_add
=
False
,
...
@@ -384,13 +398,17 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
...
@@ -384,13 +398,17 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
)
)
super
().
__init__
(
epsilon
,
key
)
super
().
__init__
(
epsilon
,
key
)
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
):
def
pattern
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
result_rms
=
self
.
rmsnorm_matcher
(
input
,
weight
)
result_rms
=
self
.
rmsnorm_matcher
(
input
,
weight
)
# result, scale
# result, scale
return
self
.
quant_matcher
(
result_rms
)
return
self
.
quant_matcher
(
result_rms
)
# type: ignore[no-any-return]
def
replacement
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
):
def
replacement
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# In case we're matching native rms-norm, conversions might be
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
# optimized out. We convert here just to be safe.
input
=
input
.
to
(
dtype
=
self
.
model_dtype
)
input
=
input
.
to
(
dtype
=
self
.
model_dtype
)
...
@@ -426,8 +444,8 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
...
@@ -426,8 +444,8 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
epsilon
:
float
,
epsilon
:
float
,
quant_dtype
:
torch
.
dtype
,
quant_dtype
:
torch
.
dtype
,
group_shape
:
GroupShape
=
GroupShape
.
PER_TOKEN
,
group_shape
:
GroupShape
=
GroupShape
.
PER_TOKEN
,
symmetric
=
True
,
symmetric
:
bool
=
True
,
):
)
->
None
:
scale
=
ScaleDesc
(
torch
.
float32
,
False
,
group_shape
)
scale
=
ScaleDesc
(
torch
.
float32
,
False
,
group_shape
)
key
=
FusedRMSQuantKey
(
key
=
FusedRMSQuantKey
(
fused_add
=
True
,
fused_add
=
True
,
...
@@ -435,8 +453,10 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
...
@@ -435,8 +453,10 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
)
)
super
().
__init__
(
epsilon
,
key
)
super
().
__init__
(
epsilon
,
key
)
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
):
def
pattern
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
result_rms
,
residual
=
self
.
rmsnorm_matcher
(
input
,
weight
,
residual
)
result_rms
,
residual
=
self
.
rmsnorm_matcher
(
input
,
weight
,
residual
)
result
,
scale
=
self
.
quant_matcher
(
result_rms
)
result
,
scale
=
self
.
quant_matcher
(
result_rms
)
...
@@ -444,7 +464,7 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
...
@@ -444,7 +464,7 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
def
replacement
(
def
replacement
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
:
# In case we're matching native rms-norm, conversions might be
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
# optimized out. We convert here just to be safe.
input
=
input
.
to
(
dtype
=
self
.
model_dtype
)
input
=
input
.
to
(
dtype
=
self
.
model_dtype
)
...
@@ -481,7 +501,7 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
...
@@ -481,7 +501,7 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
"""
"""
@
enable_fake_mode
@
enable_fake_mode
def
__init__
(
self
,
config
:
VllmConfig
):
def
__init__
(
self
,
config
:
VllmConfig
)
->
None
:
super
().
__init__
(
config
)
super
().
__init__
(
config
)
self
.
patterns
:
PatternMatcherPass
=
PatternMatcherPass
(
self
.
patterns
:
PatternMatcherPass
=
PatternMatcherPass
(
...
@@ -533,11 +553,11 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
...
@@ -533,11 +553,11 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
self
.
dump_patterns
(
config
,
self
.
patterns
)
self
.
dump_patterns
(
config
,
self
.
patterns
)
@
VllmInductorPass
.
time_and_log
@
VllmInductorPass
.
time_and_log
def
__call__
(
self
,
graph
:
fx
.
Graph
):
def
__call__
(
self
,
graph
:
fx
.
Graph
)
->
None
:
self
.
matched_count
=
self
.
patterns
.
apply
(
graph
)
self
.
matched_count
=
self
.
patterns
.
apply
(
graph
)
logger
.
debug
(
"Replaced %s patterns"
,
self
.
matched_count
)
logger
.
debug
(
"Replaced %s patterns"
,
self
.
matched_count
)
def
uuid
(
self
)
->
Any
:
def
uuid
(
self
)
->
str
:
return
self
.
hash_source
(
return
self
.
hash_source
(
self
,
self
,
RMSNormGroupQuantPattern
,
RMSNormGroupQuantPattern
,
...
...
vllm/compilation/fusion_attn.py
View file @
ad8818bb
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
Callable
from
collections.abc
import
Callable
from
typing
import
Any
,
ParamSpec
import
torch
import
torch
import
torch._inductor.pattern_matcher
as
pm
import
torch._inductor.pattern_matcher
as
pm
...
@@ -28,7 +29,7 @@ from .matcher_utils import MatcherQuantFP8
...
@@ -28,7 +29,7 @@ from .matcher_utils import MatcherQuantFP8
from
.vllm_inductor_pass
import
VllmInductorPass
,
VllmPatternMatcherPass
from
.vllm_inductor_pass
import
VllmInductorPass
,
VllmPatternMatcherPass
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
P
=
ParamSpec
(
"P"
)
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
FP4_DTYPE
=
torch
.
uint8
FP4_DTYPE
=
torch
.
uint8
...
@@ -47,7 +48,7 @@ class AttentionQuantPattern(ABC):
...
@@ -47,7 +48,7 @@ class AttentionQuantPattern(ABC):
layer
:
Attention
,
layer
:
Attention
,
quant_key
:
QuantKey
,
quant_key
:
QuantKey
,
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
):
)
->
None
:
self
.
layer
=
layer
self
.
layer
=
layer
self
.
layer_name
=
layer
.
layer_name
self
.
layer_name
=
layer
.
layer_name
self
.
num_heads
=
layer
.
num_heads
self
.
num_heads
=
layer
.
num_heads
...
@@ -61,17 +62,20 @@ class AttentionQuantPattern(ABC):
...
@@ -61,17 +62,20 @@ class AttentionQuantPattern(ABC):
)
)
self
.
QUANT_OP
=
QUANT_OPS
[
self
.
quant_key
]
self
.
QUANT_OP
=
QUANT_OPS
[
self
.
quant_key
]
def
empty
(
self
,
*
args
,
**
kwargs
)
:
def
empty
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
torch
.
Tensor
:
kwargs
=
{
"dtype"
:
self
.
dtype
,
"device"
:
"cuda"
,
**
kwargs
}
kwargs
=
{
"dtype"
:
self
.
dtype
,
"device"
:
"cuda"
,
**
kwargs
}
return
torch
.
empty
(
*
args
,
**
kwargs
)
return
torch
.
empty
(
*
args
,
**
kwargs
)
def
empty_quant
(
self
,
*
args
,
**
kwargs
)
:
def
empty_quant
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
torch
.
Tensor
:
kwargs
=
{
"dtype"
:
self
.
quant_dtype
,
"device"
:
"cuda"
,
**
kwargs
}
kwargs
=
{
"dtype"
:
self
.
quant_dtype
,
"device"
:
"cuda"
,
**
kwargs
}
return
torch
.
empty
(
*
args
,
**
kwargs
)
return
torch
.
empty
(
*
args
,
**
kwargs
)
@
staticmethod
@
staticmethod
def
wrap_trace_fn
(
trace_fn
,
*
process_fx_fns
:
Callable
[[
fx
.
GraphModule
],
None
]):
def
wrap_trace_fn
(
def
wrapped
(
*
args
,
**
kwargs
):
trace_fn
:
Callable
[
P
,
fx
.
GraphModule
],
*
process_fx_fns
:
Callable
[[
fx
.
GraphModule
],
None
],
)
->
Callable
[
P
,
fx
.
GraphModule
]:
def
wrapped
(
*
args
:
P
.
args
,
**
kwargs
:
P
.
kwargs
)
->
fx
.
GraphModule
:
gm
=
trace_fn
(
*
args
,
**
kwargs
)
gm
=
trace_fn
(
*
args
,
**
kwargs
)
for
process_fx
in
process_fx_fns
:
for
process_fx
in
process_fx_fns
:
process_fx
(
gm
)
process_fx
(
gm
)
...
@@ -81,13 +85,13 @@ class AttentionQuantPattern(ABC):
...
@@ -81,13 +85,13 @@ class AttentionQuantPattern(ABC):
return
wrapped
return
wrapped
@
staticmethod
@
staticmethod
def
fx_view_to_reshape
(
gm
:
torch
.
fx
.
GraphModule
):
def
fx_view_to_reshape
(
gm
:
torch
.
fx
.
GraphModule
)
->
None
:
from
torch._inductor.fx_passes.post_grad
import
view_to_reshape
from
torch._inductor.fx_passes.post_grad
import
view_to_reshape
view_to_reshape
(
gm
)
view_to_reshape
(
gm
)
@
staticmethod
@
staticmethod
def
remove_noop_permutes
(
gm
:
torch
.
fx
.
GraphModule
):
def
remove_noop_permutes
(
gm
:
torch
.
fx
.
GraphModule
)
->
None
:
for
node
in
gm
.
graph
.
nodes
:
for
node
in
gm
.
graph
.
nodes
:
if
not
is_func
(
node
,
torch
.
ops
.
aten
.
permute
.
default
):
if
not
is_func
(
node
,
torch
.
ops
.
aten
.
permute
.
default
):
continue
continue
...
@@ -100,12 +104,12 @@ class AttentionQuantPattern(ABC):
...
@@ -100,12 +104,12 @@ class AttentionQuantPattern(ABC):
node
.
replace_all_uses_with
(
node
.
args
[
0
])
node
.
replace_all_uses_with
(
node
.
args
[
0
])
gm
.
graph
.
erase_node
(
node
)
gm
.
graph
.
erase_node
(
node
)
def
register_if_supported
(
self
,
pm_pass
:
PatternMatcherPass
):
def
register_if_supported
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
if
self
.
layer
.
impl
.
fused_output_quant_supported
(
self
.
quant_key
):
if
self
.
layer
.
impl
.
fused_output_quant_supported
(
self
.
quant_key
):
self
.
_register
(
pm_pass
)
self
.
_register
(
pm_pass
)
@
abstractmethod
@
abstractmethod
def
_register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
_register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
raise
NotImplementedError
raise
NotImplementedError
...
@@ -124,21 +128,21 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
...
@@ -124,21 +128,21 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
layer
:
Attention
,
layer
:
Attention
,
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
symmetric
:
bool
=
True
,
symmetric
:
bool
=
True
,
):
)
->
None
:
quant_key
=
QuantKey
(
quant_key
=
QuantKey
(
dtype
=
FP8_DTYPE
,
scale
=
kStaticTensorScale
,
symmetric
=
symmetric
dtype
=
FP8_DTYPE
,
scale
=
kStaticTensorScale
,
symmetric
=
symmetric
)
)
super
().
__init__
(
layer
,
quant_key
,
dtype
)
super
().
__init__
(
layer
,
quant_key
,
dtype
)
self
.
quant_matcher
=
MatcherQuantFP8
(
quant_key
)
self
.
quant_matcher
=
MatcherQuantFP8
(
quant_key
)
def
_register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
_register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
def
pattern
(
q
:
torch
.
Tensor
,
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
output_attn
:
torch
.
Tensor
,
output_attn
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
):
)
->
torch
.
Tensor
:
at1
=
auto_functionalized
(
at1
=
auto_functionalized
(
ATTN_OP
,
ATTN_OP
,
query
=
q
,
query
=
q
,
...
@@ -161,7 +165,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
...
@@ -161,7 +165,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
v
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
output_attn
:
torch
.
Tensor
,
output_attn
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
):
)
->
torch
.
Tensor
:
# attn output in quant_dtype
# attn output in quant_dtype
output_attn
=
torch
.
ops
.
aten
.
full
.
default
(
output_attn
=
torch
.
ops
.
aten
.
full
.
default
(
[
q
.
shape
[
0
],
self
.
num_heads
,
self
.
head_size
],
[
q
.
shape
[
0
],
self
.
num_heads
,
self
.
head_size
],
...
@@ -212,10 +216,10 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
...
@@ -212,10 +216,10 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
will be passed into Attention op as the `output_scale` argument.
will be passed into Attention op as the `output_scale` argument.
"""
"""
def
__init__
(
self
,
layer
:
Attention
,
dtype
:
torch
.
dtype
):
def
__init__
(
self
,
layer
:
Attention
,
dtype
:
torch
.
dtype
)
->
None
:
super
().
__init__
(
layer
,
kNvfp4Quant
,
dtype
)
super
().
__init__
(
layer
,
kNvfp4Quant
,
dtype
)
def
_register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
_register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
def
pattern
(
q
:
torch
.
Tensor
,
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
...
@@ -224,7 +228,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
...
@@ -224,7 +228,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
output_quant
:
torch
.
Tensor
,
output_quant
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
,
input_scale
:
torch
.
Tensor
,
input_scale
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
at1
=
auto_functionalized
(
at1
=
auto_functionalized
(
ATTN_OP
,
ATTN_OP
,
query
=
q
,
query
=
q
,
...
@@ -256,7 +260,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
...
@@ -256,7 +260,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
output_quant
:
torch
.
Tensor
,
output_quant
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
,
input_scale
:
torch
.
Tensor
,
input_scale
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
# attention output in quant_dtype
# attention output in quant_dtype
output_attn
=
torch
.
ops
.
aten
.
full
.
default
(
output_attn
=
torch
.
ops
.
aten
.
full
.
default
(
[
q
.
shape
[
0
],
self
.
num_heads
,
self
.
head_size
//
2
],
[
q
.
shape
[
0
],
self
.
num_heads
,
self
.
head_size
//
2
],
...
@@ -318,7 +322,7 @@ class AttnFusionPass(VllmPatternMatcherPass):
...
@@ -318,7 +322,7 @@ class AttnFusionPass(VllmPatternMatcherPass):
"""
"""
@
enable_fake_mode
@
enable_fake_mode
def
__init__
(
self
,
config
:
VllmConfig
):
def
__init__
(
self
,
config
:
VllmConfig
)
->
None
:
super
().
__init__
(
config
)
super
().
__init__
(
config
)
self
.
patterns
=
PatternMatcherPass
(
pass_name
=
"attn_fusion_pass"
)
self
.
patterns
=
PatternMatcherPass
(
pass_name
=
"attn_fusion_pass"
)
...
@@ -350,7 +354,7 @@ class AttnFusionPass(VllmPatternMatcherPass):
...
@@ -350,7 +354,7 @@ class AttnFusionPass(VllmPatternMatcherPass):
self
.
matched_count
=
self
.
patterns
.
apply
(
graph
)
self
.
matched_count
=
self
.
patterns
.
apply
(
graph
)
logger
.
debug
(
"Fused quant onto %s attention nodes"
,
self
.
matched_count
)
logger
.
debug
(
"Fused quant onto %s attention nodes"
,
self
.
matched_count
)
def
uuid
(
self
):
def
uuid
(
self
)
->
str
:
return
VllmInductorPass
.
hash_source
(
return
VllmInductorPass
.
hash_source
(
self
,
self
,
AttentionQuantPattern
,
AttentionQuantPattern
,
...
...
vllm/compilation/inductor_pass.py
View file @
ad8818bb
...
@@ -68,7 +68,7 @@ class InductorPass(CustomGraphPass): # type: ignore[misc]
...
@@ -68,7 +68,7 @@ class InductorPass(CustomGraphPass): # type: ignore[misc]
This is defined as a convenience and should work in most cases.
This is defined as a convenience and should work in most cases.
"""
"""
def
uuid
(
self
)
->
Any
:
def
uuid
(
self
)
->
str
:
"""
"""
Provide a unique identifier for the pass, used in Inductor code cache.
Provide a unique identifier for the pass, used in Inductor code cache.
This should depend on the pass implementation, so that changes to the
This should depend on the pass implementation, so that changes to the
...
...
vllm/compilation/matcher_utils.py
View file @
ad8818bb
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
typing
import
Any
import
torch
import
torch
from
torch._higher_order_ops
import
auto_functionalized
from
torch._higher_order_ops
import
auto_functionalized
...
@@ -47,7 +48,7 @@ SILU_MUL_OP = torch.ops._C.silu_and_mul.default
...
@@ -47,7 +48,7 @@ SILU_MUL_OP = torch.ops._C.silu_and_mul.default
class
MatcherCustomOp
(
ABC
):
class
MatcherCustomOp
(
ABC
):
def
__init__
(
self
,
enabled
:
bool
):
def
__init__
(
self
,
enabled
:
bool
)
->
None
:
config
=
get_current_vllm_config
()
config
=
get_current_vllm_config
()
self
.
model_dtype
=
config
.
model_config
.
dtype
if
config
.
model_config
else
None
self
.
model_dtype
=
config
.
model_config
.
dtype
if
config
.
model_config
else
None
self
.
device
=
config
.
device_config
.
device
if
config
.
device_config
else
None
self
.
device
=
config
.
device_config
.
device
if
config
.
device_config
else
None
...
@@ -56,24 +57,24 @@ class MatcherCustomOp(ABC):
...
@@ -56,24 +57,24 @@ class MatcherCustomOp(ABC):
self
.
forward
=
self
.
forward_custom
if
enabled
else
self
.
forward_native
self
.
forward
=
self
.
forward_custom
if
enabled
else
self
.
forward_native
@
abstractmethod
@
abstractmethod
def
forward_custom
(
self
,
*
args
,
**
kw
s
)
:
def
forward_custom
(
self
,
*
args
:
Any
,
**
kw
args
:
Any
)
->
Any
:
pass
pass
@
abstractmethod
@
abstractmethod
def
forward_native
(
self
,
*
args
,
**
kw
s
)
:
def
forward_native
(
self
,
*
args
:
Any
,
**
kw
args
:
Any
)
->
Any
:
pass
pass
def
__call__
(
self
,
*
args
,
**
kw
s
)
:
def
__call__
(
self
,
*
args
:
Any
,
**
kw
args
:
Any
)
->
Any
:
return
self
.
forward
(
*
args
,
**
kws
)
return
self
.
forward
(
*
args
,
**
kw
arg
s
)
def
empty
(
self
,
*
args
,
**
kw
s
)
:
def
empty
(
self
,
*
args
:
Any
,
**
kw
args
:
Any
)
->
torch
.
Tensor
:
return
torch
.
empty
(
*
args
,
dtype
=
self
.
model_dtype
,
device
=
self
.
device
,
**
kws
)
return
torch
.
empty
(
*
args
,
dtype
=
self
.
model_dtype
,
device
=
self
.
device
,
**
kw
arg
s
)
def
empty_int64
(
self
,
*
args
,
**
kw
s
)
:
def
empty_int64
(
self
,
*
args
:
Any
,
**
kw
args
:
Any
)
->
torch
.
Tensor
:
return
torch
.
empty
(
*
args
,
dtype
=
torch
.
int64
,
device
=
self
.
device
,
**
kws
)
return
torch
.
empty
(
*
args
,
dtype
=
torch
.
int64
,
device
=
self
.
device
,
**
kw
arg
s
)
def
empty_f32
(
self
,
*
args
,
**
kw
s
)
:
def
empty_f32
(
self
,
*
args
:
Any
,
**
kw
args
:
Any
)
->
torch
.
Tensor
:
return
torch
.
empty
(
*
args
,
dtype
=
torch
.
float32
,
device
=
self
.
device
,
**
kws
)
return
torch
.
empty
(
*
args
,
dtype
=
torch
.
float32
,
device
=
self
.
device
,
**
kw
arg
s
)
def
inputs
(
self
)
->
list
[
torch
.
Tensor
]:
def
inputs
(
self
)
->
list
[
torch
.
Tensor
]:
"""Utility for inputs to the pattern"""
"""Utility for inputs to the pattern"""
...
@@ -157,7 +158,7 @@ class MatcherRMSNorm(MatcherCustomOp):
...
@@ -157,7 +158,7 @@ class MatcherRMSNorm(MatcherCustomOp):
epsilon
:
float
,
epsilon
:
float
,
enabled
:
bool
|
None
=
None
,
enabled
:
bool
|
None
=
None
,
match_rocm_aiter
:
bool
=
False
,
match_rocm_aiter
:
bool
=
False
,
):
)
->
None
:
if
enabled
is
None
:
if
enabled
is
None
:
enabled
=
RMSNorm
.
enabled
()
enabled
=
RMSNorm
.
enabled
()
...
@@ -169,7 +170,7 @@ class MatcherRMSNorm(MatcherCustomOp):
...
@@ -169,7 +170,7 @@ class MatcherRMSNorm(MatcherCustomOp):
if
match_rocm_aiter
:
if
match_rocm_aiter
:
self
.
_rmsnorm_op
=
rocm_aiter_ops
.
get_rmsnorm_op
()
self
.
_rmsnorm_op
=
rocm_aiter_ops
.
get_rmsnorm_op
()
def
inputs
(
self
):
def
inputs
(
self
)
->
list
[
torch
.
Tensor
]
:
input
=
self
.
empty
(
5
,
16
)
if
self
.
enabled
else
self
.
empty_f32
(
5
,
16
)
input
=
self
.
empty
(
5
,
16
)
if
self
.
enabled
else
self
.
empty_f32
(
5
,
16
)
weight
=
self
.
empty
(
16
)
weight
=
self
.
empty
(
16
)
return
[
input
,
weight
]
return
[
input
,
weight
]
...
@@ -220,7 +221,7 @@ class MatcherFusedAddRMSNorm(MatcherCustomOp):
...
@@ -220,7 +221,7 @@ class MatcherFusedAddRMSNorm(MatcherCustomOp):
epsilon
:
float
,
epsilon
:
float
,
enabled
:
bool
|
None
=
None
,
enabled
:
bool
|
None
=
None
,
match_rocm_aiter
:
bool
=
False
,
match_rocm_aiter
:
bool
=
False
,
):
)
->
None
:
if
enabled
is
None
:
if
enabled
is
None
:
enabled
=
RMSNorm
.
enabled
()
enabled
=
RMSNorm
.
enabled
()
...
@@ -233,7 +234,7 @@ class MatcherFusedAddRMSNorm(MatcherCustomOp):
...
@@ -233,7 +234,7 @@ class MatcherFusedAddRMSNorm(MatcherCustomOp):
if
match_rocm_aiter
:
if
match_rocm_aiter
:
self
.
_rmsnorm_op
=
rocm_aiter_ops
.
get_rmsnorm_fused_add_op
()
self
.
_rmsnorm_op
=
rocm_aiter_ops
.
get_rmsnorm_fused_add_op
()
def
inputs
(
self
):
def
inputs
(
self
)
->
list
[
torch
.
Tensor
]
:
input
=
self
.
empty
(
5
,
16
)
if
self
.
enabled
else
self
.
empty_f32
(
5
,
16
)
input
=
self
.
empty
(
5
,
16
)
if
self
.
enabled
else
self
.
empty_f32
(
5
,
16
)
weight
=
self
.
empty
(
16
)
weight
=
self
.
empty
(
16
)
residual
=
self
.
empty
(
5
,
16
)
residual
=
self
.
empty
(
5
,
16
)
...
@@ -245,7 +246,7 @@ class MatcherFusedAddRMSNorm(MatcherCustomOp):
...
@@ -245,7 +246,7 @@ class MatcherFusedAddRMSNorm(MatcherCustomOp):
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
return
self
.
_rmsnorm_op
(
return
self
.
_rmsnorm_op
(
# type: ignore[no-any-return]
x
=
input
,
residual
=
residual
,
weight
=
weight
,
variance_epsilon
=
self
.
epsilon
x
=
input
,
residual
=
residual
,
weight
=
weight
,
variance_epsilon
=
self
.
epsilon
)
)
...
@@ -287,7 +288,7 @@ class MatcherQuantFP8(MatcherCustomOp):
...
@@ -287,7 +288,7 @@ class MatcherQuantFP8(MatcherCustomOp):
has_col_major_scales
:
bool
=
False
,
has_col_major_scales
:
bool
=
False
,
is_e8m0
:
bool
=
False
,
is_e8m0
:
bool
=
False
,
match_rocm_aiter
:
bool
=
False
,
match_rocm_aiter
:
bool
=
False
,
):
)
->
None
:
if
enabled
is
None
:
if
enabled
is
None
:
enabled
=
QuantFP8
.
enabled
()
enabled
=
QuantFP8
.
enabled
()
...
@@ -340,13 +341,13 @@ class MatcherQuantFP8(MatcherCustomOp):
...
@@ -340,13 +341,13 @@ class MatcherQuantFP8(MatcherCustomOp):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
quant_key_group_shape
=
self
.
quant_key
.
scale
.
group_shape
quant_key_group_shape
=
self
.
quant_key
.
scale
.
group_shape
if
quant_key_group_shape
==
GroupShape
.
PER_TOKEN
:
if
quant_key_group_shape
==
GroupShape
.
PER_TOKEN
:
return
self
.
QUANT_OP
(
return
self
.
QUANT_OP
(
# type: ignore[no-any-return]
x
=
input
,
x
=
input
,
quant_dtype
=
self
.
quant_key
.
dtype
,
quant_dtype
=
self
.
quant_key
.
dtype
,
scale
=
scale
,
scale
=
scale
,
)
)
else
:
else
:
return
self
.
QUANT_OP
(
input
,
quant_key_group_shape
.
col
)
return
self
.
QUANT_OP
(
input
,
quant_key_group_shape
.
col
)
# type: ignore[no-any-return]
def
forward_custom
(
def
forward_custom
(
self
,
self
,
...
@@ -400,9 +401,9 @@ class MatcherQuantFP8(MatcherCustomOp):
...
@@ -400,9 +401,9 @@ class MatcherQuantFP8(MatcherCustomOp):
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
|
None
=
None
,
scale
:
torch
.
Tensor
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
return
self
.
quant_fp8
(
input
,
scale
)
return
self
.
quant_fp8
(
input
,
scale
)
# type: ignore[no-any-return]
def
make_scale
(
self
,
input
:
torch
.
Tensor
,
transposed
:
bool
=
False
):
def
make_scale
(
self
,
input
:
torch
.
Tensor
,
transposed
:
bool
=
False
)
->
torch
.
Tensor
:
normalized_group_shape
=
_normalize_quant_group_shape
(
normalized_group_shape
=
_normalize_quant_group_shape
(
input
,
self
.
quant_key
.
scale
.
group_shape
input
,
self
.
quant_key
.
scale
.
group_shape
)
)
...
@@ -427,7 +428,7 @@ class MatcherQuantFP8(MatcherCustomOp):
...
@@ -427,7 +428,7 @@ class MatcherQuantFP8(MatcherCustomOp):
class
MatcherSiluAndMul
(
MatcherCustomOp
):
class
MatcherSiluAndMul
(
MatcherCustomOp
):
def
__init__
(
self
,
enabled
:
bool
|
None
=
None
):
def
__init__
(
self
,
enabled
:
bool
|
None
=
None
)
->
None
:
if
enabled
is
None
:
if
enabled
is
None
:
enabled
=
SiluAndMul
.
enabled
()
enabled
=
SiluAndMul
.
enabled
()
super
().
__init__
(
enabled
)
super
().
__init__
(
enabled
)
...
...
vllm/compilation/qk_norm_rope_fusion.py
View file @
ad8818bb
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Callable
from
collections.abc
import
Callable
from
typing
import
ParamSpec
import
torch
import
torch
import
torch._inductor.pattern_matcher
as
pm
import
torch._inductor.pattern_matcher
as
pm
...
@@ -23,6 +24,8 @@ logger = init_logger(__name__)
...
@@ -23,6 +24,8 @@ logger = init_logger(__name__)
FUSED_QK_ROPE_OP
=
torch
.
ops
.
_C
.
fused_qk_norm_rope
.
default
FUSED_QK_ROPE_OP
=
torch
.
ops
.
_C
.
fused_qk_norm_rope
.
default
P
=
ParamSpec
(
"P"
)
class
QkNormRopePattern
:
class
QkNormRopePattern
:
"""
"""
...
@@ -72,7 +75,7 @@ class QkNormRopePattern:
...
@@ -72,7 +75,7 @@ class QkNormRopePattern:
use_flashinfer
=
self
.
rope_flashinfer
,
use_flashinfer
=
self
.
rope_flashinfer
,
)
)
def
get_inputs
(
self
):
def
get_inputs
(
self
)
->
list
[
torch
.
Tensor
]
:
# Sample inputs to help pattern tracing
# Sample inputs to help pattern tracing
T
=
5
T
=
5
qkv
=
empty_bf16
(
T
,
self
.
q_size
+
2
*
self
.
kv_size
)
qkv
=
empty_bf16
(
T
,
self
.
q_size
+
2
*
self
.
kv_size
)
...
@@ -92,8 +95,11 @@ class QkNormRopePattern:
...
@@ -92,8 +95,11 @@ class QkNormRopePattern:
]
]
@
staticmethod
@
staticmethod
def
wrap_trace_fn
(
trace_fn
,
*
process_fx_fns
:
Callable
[[
fx
.
GraphModule
],
None
]):
def
wrap_trace_fn
(
def
wrapped
(
*
args
,
**
kwargs
):
trace_fn
:
Callable
[
P
,
fx
.
GraphModule
],
*
process_fx_fns
:
Callable
[[
fx
.
GraphModule
],
None
],
)
->
Callable
[
P
,
fx
.
GraphModule
]:
def
wrapped
(
*
args
:
P
.
args
,
**
kwargs
:
P
.
kwargs
)
->
fx
.
GraphModule
:
gm
=
trace_fn
(
*
args
,
**
kwargs
)
gm
=
trace_fn
(
*
args
,
**
kwargs
)
for
process_fx
in
process_fx_fns
:
for
process_fx
in
process_fx_fns
:
process_fx
(
gm
)
process_fx
(
gm
)
...
@@ -103,19 +109,19 @@ class QkNormRopePattern:
...
@@ -103,19 +109,19 @@ class QkNormRopePattern:
return
wrapped
return
wrapped
@
staticmethod
@
staticmethod
def
fx_view_to_reshape
(
gm
:
torch
.
fx
.
GraphModule
):
def
fx_view_to_reshape
(
gm
:
torch
.
fx
.
GraphModule
)
->
None
:
from
torch._inductor.fx_passes.post_grad
import
view_to_reshape
from
torch._inductor.fx_passes.post_grad
import
view_to_reshape
view_to_reshape
(
gm
)
view_to_reshape
(
gm
)
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
def
pattern
(
qkv
:
torch
.
Tensor
,
qkv
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
q_weight
:
torch
.
Tensor
,
q_weight
:
torch
.
Tensor
,
k_weight
:
torch
.
Tensor
,
k_weight
:
torch
.
Tensor
,
cos_sin_cache
:
torch
.
Tensor
,
cos_sin_cache
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
:
# split qkv -> q,k,v
# split qkv -> q,k,v
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
...
@@ -143,7 +149,7 @@ class QkNormRopePattern:
...
@@ -143,7 +149,7 @@ class QkNormRopePattern:
q_weight
:
torch
.
Tensor
,
q_weight
:
torch
.
Tensor
,
k_weight
:
torch
.
Tensor
,
k_weight
:
torch
.
Tensor
,
cos_sin_cache
:
torch
.
Tensor
,
cos_sin_cache
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
:
# Run fused qk_norm_rope op
# Run fused qk_norm_rope op
result
=
auto_functionalized
(
result
=
auto_functionalized
(
FUSED_QK_ROPE_OP
,
FUSED_QK_ROPE_OP
,
...
@@ -162,7 +168,7 @@ class QkNormRopePattern:
...
@@ -162,7 +168,7 @@ class QkNormRopePattern:
result_qkv
=
result
[
1
]
result_qkv
=
result
[
1
]
# Split back to q,k,v and return
# Split back to q,k,v and return
return
result_qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
return
result_qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
# type: ignore[no-any-return]
# NOTE: use fx_view_to_reshape to unify view/reshape to simplify
# NOTE: use fx_view_to_reshape to unify view/reshape to simplify
# pattern and increase matching opportunities
# pattern and increase matching opportunities
...
@@ -182,7 +188,7 @@ class QKNormRoPEFusionPass(VllmPatternMatcherPass):
...
@@ -182,7 +188,7 @@ class QKNormRoPEFusionPass(VllmPatternMatcherPass):
"""Fuse Q/K RMSNorm + RoPE into fused_qk_norm_rope when the custom op exists."""
"""Fuse Q/K RMSNorm + RoPE into fused_qk_norm_rope when the custom op exists."""
@
enable_fake_mode
@
enable_fake_mode
def
__init__
(
self
,
config
:
VllmConfig
):
def
__init__
(
self
,
config
:
VllmConfig
)
->
None
:
super
().
__init__
(
config
)
super
().
__init__
(
config
)
self
.
patterns
:
PatternMatcherPass
=
PatternMatcherPass
(
self
.
patterns
:
PatternMatcherPass
=
PatternMatcherPass
(
pass_name
=
"qk_norm_rope_fusion_pass"
pass_name
=
"qk_norm_rope_fusion_pass"
...
@@ -234,5 +240,5 @@ class QKNormRoPEFusionPass(VllmPatternMatcherPass):
...
@@ -234,5 +240,5 @@ class QKNormRoPEFusionPass(VllmPatternMatcherPass):
self
.
matched_count
=
self
.
patterns
.
apply
(
graph
)
self
.
matched_count
=
self
.
patterns
.
apply
(
graph
)
logger
.
debug
(
"Fused QK Norm+RoPE on %s sites"
,
self
.
matched_count
)
logger
.
debug
(
"Fused QK Norm+RoPE on %s sites"
,
self
.
matched_count
)
def
uuid
(
self
):
def
uuid
(
self
)
->
str
:
return
VllmInductorPass
.
hash_source
(
self
,
QkNormRopePattern
)
return
VllmInductorPass
.
hash_source
(
self
,
QkNormRopePattern
)
vllm/compilation/rocm_aiter_fusion.py
View file @
ad8818bb
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Any
import
torch
import
torch
import
torch._inductor.pattern_matcher
as
pm
import
torch._inductor.pattern_matcher
as
pm
...
@@ -65,8 +64,8 @@ class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
...
@@ -65,8 +64,8 @@ class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
quant_dtype
:
torch
.
dtype
,
quant_dtype
:
torch
.
dtype
,
match_aiter_quant
:
bool
=
True
,
match_aiter_quant
:
bool
=
True
,
group_shape
:
GroupShape
=
GroupShape
.
PER_TOKEN
,
group_shape
:
GroupShape
=
GroupShape
.
PER_TOKEN
,
symmetric
=
True
,
symmetric
:
bool
=
True
,
):
)
->
None
:
scale
=
ScaleDesc
(
torch
.
float32
,
False
,
group_shape
)
scale
=
ScaleDesc
(
torch
.
float32
,
False
,
group_shape
)
key
=
FusedRMSQuantKey
(
key
=
FusedRMSQuantKey
(
fused_add
=
False
,
fused_add
=
False
,
...
@@ -75,11 +74,11 @@ class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
...
@@ -75,11 +74,11 @@ class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
super
().
__init__
(
epsilon
,
key
,
match_aiter_quant
)
super
().
__init__
(
epsilon
,
key
,
match_aiter_quant
)
def
register
(
self
,
pm_pass
)
:
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
def
pattern
(
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
result_rms
=
self
.
rmsnorm_matcher
(
input
,
weight
)
result_rms
=
self
.
rmsnorm_matcher
(
input
,
weight
)
result
,
scale
=
self
.
quant_matcher
(
result_rms
)
result
,
scale
=
self
.
quant_matcher
(
result_rms
)
return
result
,
scale
return
result
,
scale
...
@@ -87,7 +86,7 @@ class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
...
@@ -87,7 +86,7 @@ class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
def
replacement
(
def
replacement
(
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
result
=
self
.
FUSED_OP
(
result
=
self
.
FUSED_OP
(
x
=
input
,
x
=
input
,
weight
=
weight
,
weight
=
weight
,
...
@@ -117,8 +116,8 @@ class AiterFusedAddRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
...
@@ -117,8 +116,8 @@ class AiterFusedAddRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
quant_dtype
:
torch
.
dtype
,
quant_dtype
:
torch
.
dtype
,
match_aiter_quant
:
bool
=
True
,
match_aiter_quant
:
bool
=
True
,
group_shape
:
GroupShape
=
GroupShape
.
PER_TOKEN
,
group_shape
:
GroupShape
=
GroupShape
.
PER_TOKEN
,
symmetric
=
True
,
symmetric
:
bool
=
True
,
):
)
->
None
:
scale
=
ScaleDesc
(
torch
.
float32
,
False
,
group_shape
)
scale
=
ScaleDesc
(
torch
.
float32
,
False
,
group_shape
)
key
=
FusedRMSQuantKey
(
key
=
FusedRMSQuantKey
(
fused_add
=
True
,
fused_add
=
True
,
...
@@ -127,12 +126,12 @@ class AiterFusedAddRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
...
@@ -127,12 +126,12 @@ class AiterFusedAddRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
super
().
__init__
(
epsilon
,
key
,
match_aiter_quant
)
super
().
__init__
(
epsilon
,
key
,
match_aiter_quant
)
def
register
(
self
,
pm_pass
)
:
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
def
pattern
(
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
:
result_rms
,
residual_out
=
self
.
rmsnorm_matcher
(
input
,
weight
,
residual
)
result_rms
,
residual_out
=
self
.
rmsnorm_matcher
(
input
,
weight
,
residual
)
result
,
scale
=
self
.
quant_matcher
(
result_rms
)
result
,
scale
=
self
.
quant_matcher
(
result_rms
)
...
@@ -140,7 +139,7 @@ class AiterFusedAddRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
...
@@ -140,7 +139,7 @@ class AiterFusedAddRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
def
replacement
(
def
replacement
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
:
result
=
self
.
FUSED_OP
(
result
=
self
.
FUSED_OP
(
x
=
input
,
x
=
input
,
residual
=
residual
,
residual
=
residual
,
...
@@ -174,8 +173,8 @@ class AiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
...
@@ -174,8 +173,8 @@ class AiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
quant_dtype
:
torch
.
dtype
,
quant_dtype
:
torch
.
dtype
,
group_shape
:
GroupShape
,
group_shape
:
GroupShape
,
match_aiter_quant
:
bool
=
True
,
match_aiter_quant
:
bool
=
True
,
symmetric
=
True
,
symmetric
:
bool
=
True
,
):
)
->
None
:
scale
=
ScaleDesc
(
torch
.
float32
,
False
,
group_shape
)
scale
=
ScaleDesc
(
torch
.
float32
,
False
,
group_shape
)
key
=
FusedRMSQuantKey
(
key
=
FusedRMSQuantKey
(
fused_add
=
False
,
fused_add
=
False
,
...
@@ -184,11 +183,11 @@ class AiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
...
@@ -184,11 +183,11 @@ class AiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
super
().
__init__
(
epsilon
,
key
,
match_aiter_quant
)
super
().
__init__
(
epsilon
,
key
,
match_aiter_quant
)
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
def
pattern
(
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
result_rms
=
self
.
rmsnorm_matcher
(
input
,
weight
)
result_rms
=
self
.
rmsnorm_matcher
(
input
,
weight
)
result
,
scale
=
self
.
quant_matcher
(
result_rms
)
result
,
scale
=
self
.
quant_matcher
(
result_rms
)
return
result
,
scale
return
result
,
scale
...
@@ -196,7 +195,7 @@ class AiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
...
@@ -196,7 +195,7 @@ class AiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
def
replacement
(
def
replacement
(
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
at
=
self
.
FUSED_OP
(
at
=
self
.
FUSED_OP
(
x
=
input
,
x
=
input
,
weight
=
weight
,
weight
=
weight
,
...
@@ -225,8 +224,8 @@ class AiterFusedAddRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
...
@@ -225,8 +224,8 @@ class AiterFusedAddRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
quant_dtype
:
torch
.
dtype
,
quant_dtype
:
torch
.
dtype
,
group_shape
:
GroupShape
,
group_shape
:
GroupShape
,
match_aiter_quant
:
bool
=
True
,
match_aiter_quant
:
bool
=
True
,
symmetric
=
True
,
symmetric
:
bool
=
True
,
):
)
->
None
:
scale
=
ScaleDesc
(
torch
.
float32
,
False
,
group_shape
)
scale
=
ScaleDesc
(
torch
.
float32
,
False
,
group_shape
)
key
=
FusedRMSQuantKey
(
key
=
FusedRMSQuantKey
(
fused_add
=
True
,
fused_add
=
True
,
...
@@ -235,12 +234,12 @@ class AiterFusedAddRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
...
@@ -235,12 +234,12 @@ class AiterFusedAddRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
super
().
__init__
(
epsilon
,
key
,
match_aiter_quant
)
super
().
__init__
(
epsilon
,
key
,
match_aiter_quant
)
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
def
pattern
(
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
:
result_rms
,
residual_out
=
self
.
rmsnorm_matcher
(
input
,
weight
,
residual
)
result_rms
,
residual_out
=
self
.
rmsnorm_matcher
(
input
,
weight
,
residual
)
result
,
scale
=
self
.
quant_matcher
(
result_rms
)
result
,
scale
=
self
.
quant_matcher
(
result_rms
)
...
@@ -250,7 +249,7 @@ class AiterFusedAddRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
...
@@ -250,7 +249,7 @@ class AiterFusedAddRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
:
at
=
self
.
FUSED_OP
(
at
=
self
.
FUSED_OP
(
x
=
input
,
x
=
input
,
residual
=
residual
,
residual
=
residual
,
...
@@ -275,7 +274,7 @@ class RocmAiterRMSNormFusionPass(VllmPatternMatcherPass):
...
@@ -275,7 +274,7 @@ class RocmAiterRMSNormFusionPass(VllmPatternMatcherPass):
"""
"""
@
enable_fake_mode
@
enable_fake_mode
def
__init__
(
self
,
config
:
VllmConfig
):
def
__init__
(
self
,
config
:
VllmConfig
)
->
None
:
super
().
__init__
(
config
)
super
().
__init__
(
config
)
self
.
patterns
:
PatternMatcherPass
=
PatternMatcherPass
(
self
.
patterns
:
PatternMatcherPass
=
PatternMatcherPass
(
...
@@ -311,11 +310,11 @@ class RocmAiterRMSNormFusionPass(VllmPatternMatcherPass):
...
@@ -311,11 +310,11 @@ class RocmAiterRMSNormFusionPass(VllmPatternMatcherPass):
self
.
dump_patterns
(
config
,
self
.
patterns
)
self
.
dump_patterns
(
config
,
self
.
patterns
)
@
VllmInductorPass
.
time_and_log
@
VllmInductorPass
.
time_and_log
def
__call__
(
self
,
graph
:
fx
.
Graph
):
def
__call__
(
self
,
graph
:
fx
.
Graph
)
->
None
:
self
.
matched_count
=
self
.
patterns
.
apply
(
graph
)
self
.
matched_count
=
self
.
patterns
.
apply
(
graph
)
logger
.
debug
(
"Replaced %s patterns"
,
self
.
matched_count
)
logger
.
debug
(
"Replaced %s patterns"
,
self
.
matched_count
)
def
uuid
(
self
)
->
Any
:
def
uuid
(
self
)
->
str
:
fusion_patterns
=
[
fusion_patterns
=
[
AiterRMSNormDynamicQuantPattern
,
AiterRMSNormDynamicQuantPattern
,
AiterFusedAddRMSNormDynamicQuantPattern
,
AiterFusedAddRMSNormDynamicQuantPattern
,
...
@@ -333,29 +332,32 @@ class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
...
@@ -333,29 +332,32 @@ class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
FUSED_SILU_MUL_QUANT_OP
=
rocm_aiter_ops
.
get_act_mul_fused_fp8_group_quant_op
()
FUSED_SILU_MUL_QUANT_OP
=
rocm_aiter_ops
.
get_act_mul_fused_fp8_group_quant_op
()
def
__init__
(
self
,
quant_op
:
OpOverload
):
def
__init__
(
self
,
quant_op
:
OpOverload
)
->
None
:
self
.
silu_and_mul_matcher
=
MatcherSiluAndMul
()
self
.
silu_and_mul_matcher
=
MatcherSiluAndMul
()
self
.
quant_op
=
quant_op
self
.
quant_op
=
quant_op
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
get_inputs
(
self
)
->
list
[
torch
.
Tensor
]:
return
[
self
.
silu_and_mul_matcher
.
inputs
()[
0
],
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
def
pattern
(
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
at1
=
self
.
silu_and_mul_matcher
(
input
)
at1
=
self
.
silu_and_mul_matcher
(
input
)
at2
=
self
.
quant_op
(
at1
,
128
)
at2
=
self
.
quant_op
(
at1
,
128
)
return
at2
[
0
],
at2
[
1
]
return
at2
[
0
],
at2
[
1
]
def
replacement
(
def
replacement
(
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
at
=
self
.
FUSED_SILU_MUL_QUANT_OP
(
x
=
input
,
group_size
=
128
)
at
=
self
.
FUSED_SILU_MUL_QUANT_OP
(
x
=
input
,
group_size
=
128
)
return
at
[
0
],
at
[
1
]
return
at
[
0
],
at
[
1
]
inputs
=
[
pm
.
register_replacement
(
self
.
silu_and_mul_matcher
.
inputs
()[
0
],
pattern
,
replacement
,
self
.
get_inputs
(),
pm
.
fwd_only
,
pm_pass
]
)
pm
.
register_replacement
(
pattern
,
replacement
,
inputs
,
pm
.
fwd_only
,
pm_pass
)
class
RocmAiterSiluMulFp8GroupQuantFusionPass
(
VllmPatternMatcherPass
):
class
RocmAiterSiluMulFp8GroupQuantFusionPass
(
VllmPatternMatcherPass
):
...
@@ -374,7 +376,7 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
...
@@ -374,7 +376,7 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
QUANT_OPS
=
[
AITER_GROUP_FP8_QUANT_OP
,
TRITON_GROUP_FP8_QUANT_OP
]
QUANT_OPS
=
[
AITER_GROUP_FP8_QUANT_OP
,
TRITON_GROUP_FP8_QUANT_OP
]
@
enable_fake_mode
@
enable_fake_mode
def
__init__
(
self
,
config
:
VllmConfig
):
def
__init__
(
self
,
config
:
VllmConfig
)
->
None
:
super
().
__init__
(
config
)
super
().
__init__
(
config
)
self
.
patterns
:
PatternMatcherPass
=
PatternMatcherPass
(
self
.
patterns
:
PatternMatcherPass
=
PatternMatcherPass
(
...
@@ -387,11 +389,11 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
...
@@ -387,11 +389,11 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
self
.
dump_patterns
(
config
,
self
.
patterns
)
self
.
dump_patterns
(
config
,
self
.
patterns
)
@
VllmInductorPass
.
time_and_log
@
VllmInductorPass
.
time_and_log
def
__call__
(
self
,
graph
:
torch
.
fx
.
Graph
):
def
__call__
(
self
,
graph
:
torch
.
fx
.
Graph
)
->
None
:
self
.
matched_count
=
self
.
patterns
.
apply
(
graph
)
self
.
matched_count
=
self
.
patterns
.
apply
(
graph
)
logger
.
debug
(
"Replaced %s patterns"
,
self
.
matched_count
)
logger
.
debug
(
"Replaced %s patterns"
,
self
.
matched_count
)
def
uuid
(
self
):
def
uuid
(
self
)
->
str
:
fusion_patterns
=
[
fusion_patterns
=
[
ActivationQuantPattern
,
ActivationQuantPattern
,
AiterSiluMulFp8GroupQuantPattern
,
AiterSiluMulFp8GroupQuantPattern
,
...
...
vllm/compilation/sequence_parallelism.py
View file @
ad8818bb
...
@@ -2,6 +2,8 @@
...
@@ -2,6 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
functools
import
functools
from
collections.abc
import
Callable
,
Sequence
from
typing
import
Any
import
torch
import
torch
import
torch._inductor.pattern_matcher
as
pm
import
torch._inductor.pattern_matcher
as
pm
...
@@ -26,9 +28,11 @@ from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
...
@@ -26,9 +28,11 @@ from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
def
get_first_out_wrapper
(
fn
):
def
get_first_out_wrapper
(
fn
:
Callable
[...,
Sequence
[
torch
.
Tensor
]],
)
->
Callable
[...,
torch
.
Tensor
]:
@
functools
.
wraps
(
fn
)
@
functools
.
wraps
(
fn
)
def
wrapper
(
*
args
)
:
def
wrapper
(
*
args
:
Any
)
->
torch
.
Tensor
:
return
fn
(
*
args
)[
0
]
return
fn
(
*
args
)[
0
]
return
wrapper
return
wrapper
...
@@ -68,7 +72,7 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
...
@@ -68,7 +72,7 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
super
().
__init__
(
epsilon
,
dtype
,
device
)
super
().
__init__
(
epsilon
,
dtype
,
device
)
self
.
rmsnorm_matcher
=
MatcherRMSNorm
(
epsilon
)
self
.
rmsnorm_matcher
=
MatcherRMSNorm
(
epsilon
)
def
get_inputs
(
self
):
def
get_inputs
(
self
)
->
list
[
torch
.
Tensor
]
:
input
=
torch
.
empty
([
1
,
8
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
input
=
torch
.
empty
([
1
,
8
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
arg3_1
=
torch
.
empty
([
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
arg3_1
=
torch
.
empty
([
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
...
@@ -78,7 +82,7 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
...
@@ -78,7 +82,7 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
def
pattern
(
def
pattern
(
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
arg3_1
:
torch
.
Tensor
,
arg3_1
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
all_reduce
=
self
.
_all_reduce
(
input
)
all_reduce
=
self
.
_all_reduce
(
input
)
rmsnorm
=
self
.
rmsnorm_matcher
(
all_reduce
,
arg3_1
)
rmsnorm
=
self
.
rmsnorm_matcher
(
all_reduce
,
arg3_1
)
...
@@ -87,7 +91,7 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
...
@@ -87,7 +91,7 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
def
replacement
(
def
replacement
(
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
arg3_1
:
torch
.
Tensor
,
arg3_1
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
reduce_scatter
=
self
.
_reduce_scatter
(
input
)
reduce_scatter
=
self
.
_reduce_scatter
(
input
)
rmsnorm
=
self
.
rmsnorm_matcher
(
reduce_scatter
,
arg3_1
)
rmsnorm
=
self
.
rmsnorm_matcher
(
reduce_scatter
,
arg3_1
)
...
@@ -100,11 +104,11 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
...
@@ -100,11 +104,11 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
class
MiddleAllReduceRMSNormPattern
(
_SequenceParallelPatternHelper
):
class
MiddleAllReduceRMSNormPattern
(
_SequenceParallelPatternHelper
):
def
__init__
(
self
,
epsilon
:
float
,
dtype
:
torch
.
dtype
,
device
:
str
|
None
):
def
__init__
(
self
,
epsilon
:
float
,
dtype
:
torch
.
dtype
,
device
:
str
|
None
)
->
None
:
super
().
__init__
(
epsilon
,
dtype
,
device
)
super
().
__init__
(
epsilon
,
dtype
,
device
)
self
.
rmsnorm_matcher
=
MatcherFusedAddRMSNorm
(
epsilon
)
self
.
rmsnorm_matcher
=
MatcherFusedAddRMSNorm
(
epsilon
)
def
get_inputs
(
self
):
def
get_inputs
(
self
)
->
list
[
torch
.
Tensor
]
:
mm_1
=
torch
.
empty
([
4
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
mm_1
=
torch
.
empty
([
4
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
residual
=
torch
.
empty
([
4
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
residual
=
torch
.
empty
([
4
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
...
@@ -116,7 +120,7 @@ class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
...
@@ -116,7 +120,7 @@ class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
rms_norm_weights
,
rms_norm_weights
,
]
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
def
pattern
(
residual
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
mm_1
:
torch
.
Tensor
,
mm_1
:
torch
.
Tensor
,
...
@@ -163,23 +167,23 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
...
@@ -163,23 +167,23 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
epsilon
:
float
,
epsilon
:
float
,
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
device
:
str
|
None
,
device
:
str
|
None
,
):
)
->
None
:
super
().
__init__
(
epsilon
,
dtype
,
device
)
super
().
__init__
(
epsilon
,
dtype
,
device
)
self
.
rmsnorm_matcher
=
MatcherRMSNorm
(
epsilon
)
self
.
rmsnorm_matcher
=
MatcherRMSNorm
(
epsilon
)
self
.
quant_matcher
=
MatcherQuantFP8
(
kFp8StaticTensorSym
)
self
.
quant_matcher
=
MatcherQuantFP8
(
kFp8StaticTensorSym
)
def
get_inputs
(
self
):
def
get_inputs
(
self
)
->
list
[
torch
.
Tensor
]
:
input
=
torch
.
zeros
([
1
,
8
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
input
=
torch
.
zeros
([
1
,
8
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
weight
=
torch
.
empty
([
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
weight
=
torch
.
empty
([
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
scale
=
torch
.
tensor
(
1.0
,
device
=
self
.
device
,
dtype
=
torch
.
float32
)
scale
=
torch
.
tensor
(
1.0
,
device
=
self
.
device
,
dtype
=
torch
.
float32
)
return
[
input
,
weight
,
scale
]
return
[
input
,
weight
,
scale
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
def
pattern
(
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
all_reduce
=
self
.
_all_reduce
(
input
)
all_reduce
=
self
.
_all_reduce
(
input
)
rms
=
self
.
rmsnorm_matcher
(
all_reduce
,
weight
)
rms
=
self
.
rmsnorm_matcher
(
all_reduce
,
weight
)
quant
,
_
=
self
.
quant_matcher
(
rms
,
scale
)
quant
,
_
=
self
.
quant_matcher
(
rms
,
scale
)
...
@@ -189,7 +193,7 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
...
@@ -189,7 +193,7 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
reduce_scatter
=
self
.
_reduce_scatter
(
input
)
reduce_scatter
=
self
.
_reduce_scatter
(
input
)
rms
=
self
.
rmsnorm_matcher
(
reduce_scatter
,
weight
)
rms
=
self
.
rmsnorm_matcher
(
reduce_scatter
,
weight
)
quant
,
_
=
self
.
quant_matcher
(
rms
,
scale
)
quant
,
_
=
self
.
quant_matcher
(
rms
,
scale
)
...
@@ -203,12 +207,12 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
...
@@ -203,12 +207,12 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
class
MiddleAllReduceRMSNormStaticFP8Pattern
(
_SequenceParallelPatternHelper
):
class
MiddleAllReduceRMSNormStaticFP8Pattern
(
_SequenceParallelPatternHelper
):
def
__init__
(
self
,
epsilon
:
float
,
dtype
:
torch
.
dtype
,
device
:
str
|
None
):
def
__init__
(
self
,
epsilon
:
float
,
dtype
:
torch
.
dtype
,
device
:
str
|
None
)
->
None
:
super
().
__init__
(
epsilon
,
dtype
,
device
)
super
().
__init__
(
epsilon
,
dtype
,
device
)
self
.
rmsnorm_matcher
=
MatcherFusedAddRMSNorm
(
epsilon
)
self
.
rmsnorm_matcher
=
MatcherFusedAddRMSNorm
(
epsilon
)
self
.
quant_matcher
=
MatcherQuantFP8
(
kFp8StaticTensorSym
)
self
.
quant_matcher
=
MatcherQuantFP8
(
kFp8StaticTensorSym
)
def
get_inputs
(
self
):
def
get_inputs
(
self
)
->
list
[
torch
.
Tensor
]
:
mm_1
=
torch
.
empty
([
4
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
mm_1
=
torch
.
empty
([
4
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
residual
=
torch
.
empty
([
4
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
residual
=
torch
.
empty
([
4
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
rms_norm_weights
=
torch
.
empty
([
4
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
rms_norm_weights
=
torch
.
empty
([
4
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
...
@@ -216,7 +220,7 @@ class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
...
@@ -216,7 +220,7 @@ class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
return
[
residual
,
mm_1
,
rms_norm_weights
,
scale
]
return
[
residual
,
mm_1
,
rms_norm_weights
,
scale
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
def
pattern
(
residual
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
mm_1
:
torch
.
Tensor
,
mm_1
:
torch
.
Tensor
,
...
@@ -302,7 +306,7 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
...
@@ -302,7 +306,7 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
"""
"""
@
enable_fake_mode
@
enable_fake_mode
def
__init__
(
self
,
config
:
VllmConfig
):
def
__init__
(
self
,
config
:
VllmConfig
)
->
None
:
super
().
__init__
(
config
)
super
().
__init__
(
config
)
# Used to clean up redundant views created temporarily
# Used to clean up redundant views created temporarily
...
@@ -357,7 +361,7 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
...
@@ -357,7 +361,7 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
return
(
compile_range
.
is_single_size
())
and
(
compile_range
.
end
%
tp_size
==
0
)
return
(
compile_range
.
is_single_size
())
and
(
compile_range
.
end
%
tp_size
==
0
)
@
VllmInductorPass
.
time_and_log
@
VllmInductorPass
.
time_and_log
def
__call__
(
self
,
graph
:
fx
.
Graph
):
def
__call__
(
self
,
graph
:
fx
.
Graph
)
->
None
:
self
.
matched_count
=
self
.
patterns
.
apply
(
graph
)
self
.
matched_count
=
self
.
patterns
.
apply
(
graph
)
logger
.
debug
(
"Replaced %s patterns"
,
self
.
matched_count
)
logger
.
debug
(
"Replaced %s patterns"
,
self
.
matched_count
)
# Clean up reshape nodes
# Clean up reshape nodes
...
...
vllm/distributed/parallel_state.py
View file @
ad8818bb
...
@@ -1529,22 +1529,22 @@ def patch_tensor_parallel_group(tp_group: GroupCoordinator):
...
@@ -1529,22 +1529,22 @@ def patch_tensor_parallel_group(tp_group: GroupCoordinator):
_TP
=
old_tp_group
_TP
=
old_tp_group
def
get_tensor_model_parallel_world_size
():
def
get_tensor_model_parallel_world_size
()
->
int
:
"""Return world size for the tensor model parallel group."""
"""Return world size for the tensor model parallel group."""
return
get_tp_group
().
world_size
return
get_tp_group
().
world_size
def
get_tensor_model_parallel_rank
():
def
get_tensor_model_parallel_rank
()
->
int
:
"""Return my rank for the tensor model parallel group."""
"""Return my rank for the tensor model parallel group."""
return
get_tp_group
().
rank_in_group
return
get_tp_group
().
rank_in_group
def
get_decode_context_model_parallel_world_size
():
def
get_decode_context_model_parallel_world_size
()
->
int
:
"""Return world size for the decode context model parallel group."""
"""Return world size for the decode context model parallel group."""
return
get_dcp_group
().
world_size
return
get_dcp_group
().
world_size
def
get_decode_context_model_parallel_rank
():
def
get_decode_context_model_parallel_rank
()
->
int
:
"""Return my rank for the decode context model parallel group."""
"""Return my rank for the decode context model parallel group."""
return
get_dcp_group
().
rank_in_group
return
get_dcp_group
().
rank_in_group
...
...
vllm/model_executor/layers/rotary_embedding/__init__.py
View file @
ad8818bb
...
@@ -20,7 +20,9 @@ from .phi3_long_rope_scaled_rope import Phi3LongRoPEScaledRotaryEmbedding
...
@@ -20,7 +20,9 @@ from .phi3_long_rope_scaled_rope import Phi3LongRoPEScaledRotaryEmbedding
from
.xdrope
import
XDRotaryEmbedding
from
.xdrope
import
XDRotaryEmbedding
from
.yarn_scaling_rope
import
YaRNScalingRotaryEmbedding
from
.yarn_scaling_rope
import
YaRNScalingRotaryEmbedding
_ROPE_DICT
:
dict
[
tuple
,
RotaryEmbedding
]
=
{}
_ROPE_DICT
:
dict
[
tuple
[
Any
,
...],
RotaryEmbedding
]
=
{}
__all__
=
[
"RotaryEmbedding"
]
def
get_rope
(
def
get_rope
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment