Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bb3afd68
Commit
bb3afd68
authored
Jan 21, 2026
by
zhuwenwen
Browse files
Merge branch 'v0.9.2-dev' of
http://10.16.6.30/dcutoolkit/deeplearing/vllm
into v0.9.2-dev
parents
0d5dd2da
beb3aff7
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
616 additions
and
438 deletions
+616
-438
vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=gfx938_64cu.json
...used_moe/configs/E=128,N=192,device_name=gfx938_64cu.json
+146
-0
vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=gfx938_64cu_nn.json
...d_moe/configs/E=128,N=192,device_name=gfx938_64cu_nn.json
+164
-0
vllm/model_executor/layers/fused_moe/configs/E=160,N=320,device_name=gfx938_64cu.json
...used_moe/configs/E=160,N=320,device_name=gfx938_64cu.json
+146
-0
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+5
-5
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+6
-2
vllm/model_executor/layers/quantization/utils/fp8_utils.py
vllm/model_executor/layers/quantization/utils/fp8_utils.py
+39
-404
vllm/model_executor/models/qwen3.py
vllm/model_executor/models/qwen3.py
+95
-15
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+1
-5
vllm/utils/__init__.py
vllm/utils/__init__.py
+14
-7
No files found.
vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=gfx938_64cu.json
0 → 100644
View file @
bb3afd68
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
},
"2"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
},
"48"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"64"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
},
"96"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
},
"128"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
},
"256"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
},
"512"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
32
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"1024"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"1536"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"2048"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"3072"
:
{
"BLOCK_SIZE_M"
:
256
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"4096"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
}
}
vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=gfx938_64cu_nn.json
0 → 100644
View file @
bb3afd68
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
1
},
"2"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
1
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
1
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
1
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
1
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
1
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
1
},
"48"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
1
},
"64"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
2
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
1
},
"96"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
1
},
"128"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
1
},
"256"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
1
},
"512"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
1
},
"1024"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
1
},
"1536"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
1
},
"2048"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
1
},
"3072"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
1
},
"4096"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
1
}
}
vllm/model_executor/layers/fused_moe/configs/E=160,N=320,device_name=gfx938_64cu.json
0 → 100644
View file @
bb3afd68
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"2"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"48"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"64"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"96"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"128"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"256"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"512"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"1024"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"1536"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"2048"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"3072"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"4096"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
}
}
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
bb3afd68
...
@@ -1225,14 +1225,14 @@ def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor,
...
@@ -1225,14 +1225,14 @@ def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor,
token_expert_indices
:
torch
.
Tensor
,
token_expert_indices
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
renormalize
:
bool
)
->
tuple
[
torch
.
Tensor
,
...]:
renormalize
:
bool
)
->
tuple
[
torch
.
Tensor
,
...]:
if
envs
.
VLLM_USE_TOPK_RENORM
:
if
envs
.
VLLM_USE_TOPK_RENORM
and
renormalize
is
True
:
from
lightop
import
op
as
op
from
lightop
import
op
as
op
op
.
topk_softmax
(
op
.
topk_softmax
(
topk_weights
,
topk_weights
,
topk_indices
,
topk_indices
,
token_expert_indices
,
token_expert_indices
,
gating_output
,
gating_output
,
Tru
e
,
renormaliz
e
,
)
)
else
:
else
:
ops
.
topk_softmax
(
ops
.
topk_softmax
(
...
@@ -1791,7 +1791,7 @@ def fused_experts_impl(
...
@@ -1791,7 +1791,7 @@ def fused_experts_impl(
device
=
hidden_states
.
device
,
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
)
dtype
=
hidden_states
.
dtype
)
if
use_int8_w8a8
is
True
:
if
use_int8_w8a8
or
use_fp8_w8a8
:
return
fused_experts_impl_int8
(
hidden_states
=
hidden_states
,
return
fused_experts_impl_int8
(
hidden_states
=
hidden_states
,
w1
=
w1
,
w1
=
w1
,
w2
=
w2
,
w2
=
w2
,
...
@@ -1801,8 +1801,8 @@ def fused_experts_impl(
...
@@ -1801,8 +1801,8 @@ def fused_experts_impl(
inplace
=
inplace
,
inplace
=
inplace
,
activation
=
activation
,
activation
=
activation
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
use_fp8_w8a8
=
False
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
True
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
False
,
use_int8_w8a16
=
False
,
use_int4_w4a16
=
False
,
use_int4_w4a16
=
False
,
per_channel_quant
=
per_channel_quant
,
per_channel_quant
=
per_channel_quant
,
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
bb3afd68
...
@@ -331,8 +331,12 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -331,8 +331,12 @@ class Fp8LinearMethod(LinearMethodBase):
weight
=
layer
.
weight
.
data
weight
=
layer
.
weight
.
data
weight_scale_inv
=
layer
.
weight_scale_inv
.
data
weight_scale_inv
=
layer
.
weight_scale_inv
.
data
weight
=
self
.
_maybe_pad_weight
(
weight
)
if
envs
.
VLLM_W8A8_BACKEND
==
3
:
weight
=
weight
.
T
.
contiguous
()
weight_scale_inv
=
weight_scale_inv
.
T
.
contiguous
()
else
:
weight
=
self
.
_maybe_pad_weight
(
weight
)
# Torch.compile cannot use Parameter subclasses.
# Torch.compile cannot use Parameter subclasses.
layer
.
weight
=
Parameter
(
weight
,
requires_grad
=
False
)
layer
.
weight
=
Parameter
(
weight
,
requires_grad
=
False
)
layer
.
weight_scale_inv
=
Parameter
(
weight_scale_inv
,
layer
.
weight_scale_inv
=
Parameter
(
weight_scale_inv
,
...
...
vllm/model_executor/layers/quantization/utils/fp8_utils.py
View file @
bb3afd68
...
@@ -6,6 +6,8 @@ import functools
...
@@ -6,6 +6,8 @@ import functools
import
json
import
json
import
os
import
os
from
typing
import
Any
,
Callable
,
Optional
,
Union
,
List
from
typing
import
Any
,
Callable
,
Optional
,
Union
,
List
from
lmslim
import
quant_ops
from
lmslim.quantize.quant_ops
import
BlockSize
import
torch
import
torch
...
@@ -19,6 +21,10 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
...
@@ -19,6 +21,10 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
tl
,
triton
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils
import
cdiv
,
direct_register_custom_op
,
has_deep_gemm
from
vllm.utils
import
cdiv
,
direct_register_custom_op
,
has_deep_gemm
try
:
from
lmslim.layers.gemm.fp8_utils
import
per_token_group_quant_fp8
,
w8a8_block_fp8_matmul
except
Exception
:
print
(
"INFO: Please updata lmslim if you want to use fp8_utils.
\n
"
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -83,7 +89,7 @@ if current_platform.is_rocm():
...
@@ -83,7 +89,7 @@ if current_platform.is_rocm():
def
dispatch_w8a8_blockscale_func
(
def
dispatch_w8a8_blockscale_func
(
use_cutlass
:
bool
,
use_aiter_and_is_supported
:
bool
use_cutlass
:
bool
,
use_aiter_and_is_supported
:
bool
,
use_blaslt
:
bool
)
->
Callable
[[
)
->
Callable
[[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
...
@@ -96,6 +102,9 @@ def dispatch_w8a8_blockscale_func(
...
@@ -96,6 +102,9 @@ def dispatch_w8a8_blockscale_func(
return
cutlass_scaled_mm
return
cutlass_scaled_mm
if
(
use_aiter_and_is_supported
):
if
(
use_aiter_and_is_supported
):
return
torch
.
ops
.
vllm
.
rocm_aiter_gemm_w8a8_blockscale
return
torch
.
ops
.
vllm
.
rocm_aiter_gemm_w8a8_blockscale
if
use_blaslt
:
return
hipblaslt_w8a8_block_fp8_matmul
return
w8a8_block_fp8_matmul
return
w8a8_block_fp8_matmul
...
@@ -127,7 +136,11 @@ def apply_w8a8_block_fp8_linear(
...
@@ -127,7 +136,11 @@ def apply_w8a8_block_fp8_linear(
assert
input_scale
is
None
assert
input_scale
is
None
# View input as 2D matrix for fp8 methods
# View input as 2D matrix for fp8 methods
input_2d
=
input
.
view
(
-
1
,
input
.
shape
[
-
1
])
input_2d
=
input
.
view
(
-
1
,
input
.
shape
[
-
1
])
output_shape
=
[
*
input
.
shape
[:
-
1
],
weight
.
shape
[
0
]]
output_shape
=
[]
if
envs
.
VLLM_W8A8_BACKEND
==
3
:
output_shape
=
[
*
input
.
shape
[:
-
1
],
weight
.
shape
[
-
1
]]
else
:
output_shape
=
[
*
input
.
shape
[:
-
1
],
weight
.
shape
[
0
]]
output_dtype
=
input
.
dtype
output_dtype
=
input
.
dtype
if
should_use_deepgemm
(
output_dtype
,
weight
):
if
should_use_deepgemm
(
output_dtype
,
weight
):
...
@@ -166,9 +179,12 @@ def apply_w8a8_block_fp8_linear(
...
@@ -166,9 +179,12 @@ def apply_w8a8_block_fp8_linear(
weight
.
shape
[
0
]
%
128
==
0
and
weight
.
shape
[
1
]
%
128
==
0
)
weight
.
shape
[
0
]
%
128
==
0
and
weight
.
shape
[
1
]
%
128
==
0
)
else
:
else
:
use_cutlass
=
False
use_cutlass
=
False
use_blaslt
=
False
if
envs
.
VLLM_W8A8_BACKEND
==
3
:
use_blaslt
=
True
w8a8_blockscale_func
=
dispatch_w8a8_blockscale_func
(
w8a8_blockscale_func
=
dispatch_w8a8_blockscale_func
(
use_cutlass
,
use_aiter_and_is_supported
)
use_cutlass
,
use_aiter_and_is_supported
,
use_blaslt
)
if
use_cutlass
:
if
use_cutlass
:
q_input
,
x_scale
=
per_token_group_quant_fp8
(
q_input
,
x_scale
=
per_token_group_quant_fp8
(
input_2d
,
block_size
[
1
],
column_major_scales
=
use_cutlass
)
input_2d
,
block_size
[
1
],
column_major_scales
=
use_cutlass
)
...
@@ -197,7 +213,11 @@ def apply_w8a8_block_fp8_linear_fake(
...
@@ -197,7 +213,11 @@ def apply_w8a8_block_fp8_linear_fake(
cutlass_block_fp8_supported
:
bool
=
CUTLASS_BLOCK_FP8_SUPPORTED
,
cutlass_block_fp8_supported
:
bool
=
CUTLASS_BLOCK_FP8_SUPPORTED
,
use_aiter_and_is_supported
:
bool
=
False
,
use_aiter_and_is_supported
:
bool
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
output_shape
=
[
*
input
.
shape
[:
-
1
],
weight
.
shape
[
0
]]
output_shape
=
[]
if
envs
.
VLLM_W8A8_BACKEND
==
3
:
output_shape
=
[
*
input
.
shape
[:
-
1
],
weight
.
shape
[
-
1
]]
else
:
output_shape
=
[
*
input
.
shape
[:
-
1
],
weight
.
shape
[
0
]]
return
torch
.
empty
(
output_shape
,
dtype
=
input
.
dtype
,
device
=
input
.
device
)
return
torch
.
empty
(
output_shape
,
dtype
=
input
.
dtype
,
device
=
input
.
device
)
...
@@ -240,333 +260,9 @@ def block_quant_to_tensor_quant(
...
@@ -240,333 +260,9 @@ def block_quant_to_tensor_quant(
return
x_q_tensor
,
scale
return
x_q_tensor
,
scale
@
triton
.
jit
def
_per_token_group_quant_fp8
(
# Pointers to inputs and output
y_ptr
,
y_q_ptr
,
y_s_ptr
,
group_size
,
# Num columns of y
y_num_columns
,
y_row_stride
,
# Avoid to divide zero
eps
,
# Information for float8
fp8_min
,
fp8_max
,
# Meta-parameters
BLOCK
:
tl
.
constexpr
,
):
"""A Triton-accelerated function to perform per-token-group
quantization on a tensor.
This function converts the tensor values into float8 values.
"""
groups_per_row
=
y_num_columns
//
group_size
# Map the program id to the row of X and Y it should compute.
g_id
=
tl
.
program_id
(
0
)
row
=
g_id
//
groups_per_row
row_g_id
=
g_id
%
groups_per_row
# Ensure offset calculations use int64 to prevent overflow
y_ptr_offset
=
(
row
.
to
(
tl
.
int64
)
*
y_row_stride
)
+
(
row_g_id
.
to
(
tl
.
int64
)
*
group_size
)
y_ptr
+=
y_ptr_offset
y_q_ptr_offset
=
g_id
.
to
(
tl
.
int64
)
*
group_size
y_q_ptr
+=
y_q_ptr_offset
y_s_ptr
+=
g_id
cols
=
tl
.
arange
(
0
,
BLOCK
)
# N <= BLOCK
mask
=
cols
<
group_size
y
=
tl
.
load
(
y_ptr
+
cols
,
mask
=
mask
,
other
=
0.0
).
to
(
tl
.
float32
)
# Quant
_absmax
=
tl
.
maximum
(
tl
.
max
(
tl
.
abs
(
y
)),
eps
)
y_s
=
_absmax
/
fp8_max
y_q
=
tl
.
clamp
(
y
/
y_s
,
fp8_min
,
fp8_max
).
to
(
y_q_ptr
.
dtype
.
element_ty
)
tl
.
store
(
y_q_ptr
+
cols
,
y_q
,
mask
=
mask
)
tl
.
store
(
y_s_ptr
,
y_s
)
@
triton
.
jit
def
_per_token_group_quant_fp8_colmajor
(
# Pointers to inputs and output
y_ptr
,
y_q_ptr
,
y_s_ptr
,
group_size
,
# Num columns of y
y_num_columns
,
y_row_stride
,
# Stride from one column to the next of y_s
y_s_col_stride
,
# Avoid to divide zero
eps
,
# Information for float8
fp8_min
,
fp8_max
,
# Meta-parameters
BLOCK
:
tl
.
constexpr
,
):
"""A Triton-accelerated function to perform per-token-group
quantization on a tensor.
This function converts the tensor values into float8 values.
"""
groups_per_row
=
y_num_columns
//
group_size
# Map the program id to the row of X and Y it should compute.
g_id
=
tl
.
program_id
(
0
)
row
=
g_id
//
groups_per_row
row_g_id
=
g_id
%
groups_per_row
# Ensure offset calculations use int64 to prevent overflow
y_ptr_offset
=
(
row
.
to
(
tl
.
int64
)
*
y_row_stride
)
+
(
row_g_id
.
to
(
tl
.
int64
)
*
group_size
)
y_ptr
+=
y_ptr_offset
y_q_ptr_offset
=
g_id
.
to
(
tl
.
int64
)
*
group_size
y_q_ptr
+=
y_q_ptr_offset
# Convert g_id the flattened block coordinate to 2D so we can index
# into the output y_scales matrix
blocks_per_row
=
y_num_columns
//
group_size
scale_col
=
g_id
%
blocks_per_row
scale_row
=
g_id
//
blocks_per_row
# Ensure offset calculation uses int64 for y_s_ptr
y_s_ptr_offset
=
(
scale_col
.
to
(
tl
.
int64
)
*
y_s_col_stride
)
+
scale_row
.
to
(
tl
.
int64
)
y_s_ptr
+=
y_s_ptr_offset
cols
=
tl
.
arange
(
0
,
BLOCK
)
# group_size <= BLOCK
mask
=
cols
<
group_size
y
=
tl
.
load
(
y_ptr
+
cols
,
mask
=
mask
,
other
=
0.0
).
to
(
tl
.
float32
)
# Quant
_absmax
=
tl
.
maximum
(
tl
.
max
(
tl
.
abs
(
y
)),
eps
)
y_s
=
_absmax
/
fp8_max
y_q
=
tl
.
clamp
(
y
/
y_s
,
fp8_min
,
fp8_max
).
to
(
y_q_ptr
.
dtype
.
element_ty
)
tl
.
store
(
y_q_ptr
+
cols
,
y_q
,
mask
=
mask
)
tl
.
store
(
y_s_ptr
,
y_s
)
def
per_token_group_quant_fp8
(
x
:
torch
.
Tensor
,
group_size
:
int
,
eps
:
float
=
1e-10
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
column_major_scales
:
bool
=
False
,
out_q
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Function to perform per-token-group quantization on an input tensor `x`.
It converts the tensor values into signed float8 values and returns the
quantized tensor along with the scaling factor used for quantization.
Args:
x: The input tensor with ndim >= 2.
group_size: The group size used for quantization.
eps: The minimum to avoid dividing zero.
dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn`
is supported for now.
column_major_scales: Outputs scales in column major.
out_q: Optional output tensor. If not provided, function will create.
Returns:
tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
scaling factor for quantization.
"""
dtype
=
current_platform
.
fp8_dtype
()
if
dtype
is
None
else
dtype
assert
(
x
.
shape
[
-
1
]
%
group_size
==
0
),
(
f
"the last dimension of `x`
{
x
.
shape
[
-
1
]
}
must be divisible "
f
"by `group_size`
{
group_size
}
"
)
assert
x
.
stride
(
-
1
)
==
1
,
"`x` groups must be contiguous"
finfo
=
torch
.
finfo
(
dtype
)
fp8_min
=
finfo
.
min
fp8_max
=
finfo
.
max
assert
out_q
is
None
or
out_q
.
shape
==
x
.
shape
x_q
=
out_q
if
x_q
is
None
:
x_q
=
torch
.
empty_like
(
x
,
device
=
x
.
device
,
dtype
=
dtype
)
M
=
x
.
numel
()
//
group_size
N
=
group_size
if
column_major_scales
:
shape
=
(
x
.
shape
[
-
1
]
//
group_size
,
)
+
x
.
shape
[:
-
1
]
x_s
=
torch
.
empty
(
shape
,
device
=
x
.
device
,
dtype
=
torch
.
float32
).
permute
(
-
1
,
-
2
)
else
:
shape
=
x
.
shape
[:
-
1
]
+
(
x
.
shape
[
-
1
]
//
group_size
,
)
x_s
=
torch
.
empty
(
shape
,
device
=
x
.
device
,
dtype
=
torch
.
float32
)
BLOCK
=
triton
.
next_power_of_2
(
N
)
# heuristics for number of warps
num_warps
=
min
(
max
(
BLOCK
//
256
,
1
),
8
)
num_stages
=
1
if
column_major_scales
:
_per_token_group_quant_fp8_colmajor
[(
M
,
)](
x
,
x_q
,
x_s
,
group_size
,
x
.
shape
[
1
],
x
.
stride
(
0
),
x_s
.
stride
(
1
),
eps
,
fp8_min
=
fp8_min
,
fp8_max
=
fp8_max
,
BLOCK
=
BLOCK
,
num_warps
=
num_warps
,
num_stages
=
num_stages
,
)
else
:
_per_token_group_quant_fp8
[(
M
,
)](
x
,
x_q
,
x_s
,
group_size
,
x
.
shape
[
1
],
x
.
stride
(
0
),
eps
,
fp8_min
=
fp8_min
,
fp8_max
=
fp8_max
,
BLOCK
=
BLOCK
,
num_warps
=
num_warps
,
num_stages
=
num_stages
,
)
return
x_q
,
x_s
def
hipblaslt_w8a8_block_fp8_matmul
(
@
triton
.
jit
def
_w8a8_block_fp8_matmul
(
# Pointers to inputs and output
A
,
B
,
C
,
As
,
Bs
,
# Shape for matmul
M
,
N
,
K
,
# Block size for block-wise quantization
group_n
,
group_k
,
# Stride for inputs and output
stride_am
,
stride_ak
,
stride_bk
,
stride_bn
,
stride_cm
,
stride_cn
,
stride_As_m
,
stride_As_k
,
stride_Bs_k
,
stride_Bs_n
,
# Meta-parameters
BLOCK_SIZE_M
:
tl
.
constexpr
,
BLOCK_SIZE_N
:
tl
.
constexpr
,
BLOCK_SIZE_K
:
tl
.
constexpr
,
GROUP_SIZE_M
:
tl
.
constexpr
,
):
"""Triton-accelerated function used to perform linear operations (dot
product) on input tensors `A` and `B` with block-wise quantization, and
store the result in output tensor `C`.
"""
pid
=
tl
.
program_id
(
axis
=
0
)
num_pid_m
=
tl
.
cdiv
(
M
,
BLOCK_SIZE_M
)
num_pid_n
=
tl
.
cdiv
(
N
,
BLOCK_SIZE_N
)
num_pid_in_group
=
GROUP_SIZE_M
*
num_pid_n
group_id
=
pid
//
num_pid_in_group
first_pid_m
=
group_id
*
GROUP_SIZE_M
group_size_m
=
min
(
num_pid_m
-
first_pid_m
,
GROUP_SIZE_M
)
pid_m
=
first_pid_m
+
(
pid
%
group_size_m
)
pid_n
=
(
pid
%
num_pid_in_group
)
//
group_size_m
offs_am
=
(
pid_m
*
BLOCK_SIZE_M
+
tl
.
arange
(
0
,
BLOCK_SIZE_M
))
%
M
offs_bn
=
(
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
))
%
N
offs_k
=
tl
.
arange
(
0
,
BLOCK_SIZE_K
)
a_ptrs
=
A
+
(
offs_am
[:,
None
]
*
stride_am
+
offs_k
[
None
,
:]
*
stride_ak
)
b_ptrs
=
B
+
(
offs_k
[:,
None
]
*
stride_bk
+
offs_bn
[
None
,
:]
*
stride_bn
)
As_ptrs
=
As
+
offs_am
*
stride_As_m
offs_bsn
=
offs_bn
//
group_n
Bs_ptrs
=
Bs
+
offs_bsn
*
stride_Bs_n
accumulator
=
tl
.
zeros
((
BLOCK_SIZE_M
,
BLOCK_SIZE_N
),
dtype
=
tl
.
float32
)
for
k
in
range
(
0
,
tl
.
cdiv
(
K
,
BLOCK_SIZE_K
)):
a
=
tl
.
load
(
a_ptrs
,
mask
=
offs_k
[
None
,
:]
<
K
-
k
*
BLOCK_SIZE_K
,
other
=
0.0
)
b
=
tl
.
load
(
b_ptrs
,
mask
=
offs_k
[:,
None
]
<
K
-
k
*
BLOCK_SIZE_K
,
other
=
0.0
)
k_start
=
k
*
BLOCK_SIZE_K
offs_ks
=
k_start
//
group_k
a_s
=
tl
.
load
(
As_ptrs
+
offs_ks
*
stride_As_k
)
b_s
=
tl
.
load
(
Bs_ptrs
+
offs_ks
*
stride_Bs_k
)
accumulator
+=
tl
.
dot
(
a
,
b
)
*
a_s
[:,
None
]
*
b_s
[
None
,
:]
a_ptrs
+=
BLOCK_SIZE_K
*
stride_ak
b_ptrs
+=
BLOCK_SIZE_K
*
stride_bk
if
C
.
dtype
.
element_ty
==
tl
.
bfloat16
:
c
=
accumulator
.
to
(
tl
.
bfloat16
)
elif
C
.
dtype
.
element_ty
==
tl
.
float16
:
c
=
accumulator
.
to
(
tl
.
float16
)
else
:
c
=
accumulator
.
to
(
tl
.
float32
)
offs_cm
=
pid_m
*
BLOCK_SIZE_M
+
tl
.
arange
(
0
,
BLOCK_SIZE_M
)
offs_cn
=
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
)
c_ptrs
=
C
+
stride_cm
*
offs_cm
[:,
None
]
+
stride_cn
*
offs_cn
[
None
,
:]
c_mask
=
(
offs_cm
[:,
None
]
<
M
)
&
(
offs_cn
[
None
,
:]
<
N
)
tl
.
store
(
c_ptrs
,
c
,
mask
=
c_mask
)
@
functools
.
lru_cache
def
get_w8a8_block_fp8_configs
(
N
:
int
,
K
:
int
,
block_n
:
int
,
block_k
:
int
)
->
Optional
[
dict
[
int
,
Any
]]:
"""
Return optimized configurations for the w8a8 block fp8 kernel.
The return value will be a dictionary that maps an irregular grid of
batch sizes to configurations of the w8a8 block fp8 kernel. To evaluate the
kernel on a given batch size bs, the closest batch size in the grid should
be picked and the associated configuration chosen to invoke the kernel.
"""
# First look up if an optimized configuration is available in the configs
# directory
device_name
=
current_platform
.
get_device_name
().
replace
(
" "
,
"_"
)
json_file_name
=
f
"N=
{
N
}
,K=
{
K
}
,device_name=
{
device_name
}
,dtype=fp8_w8a8,block_shape=[
{
block_n
}
,
{
block_k
}
].json"
# noqa: E501
config_file_path
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
)),
"configs"
,
json_file_name
)
if
os
.
path
.
exists
(
config_file_path
):
with
open
(
config_file_path
)
as
f
:
logger
.
info
(
"Using configuration from %s for W8A8 Block FP8 kernel."
,
config_file_path
,
)
# If a configuration has been found, return it
return
{
int
(
key
):
val
for
key
,
val
in
json
.
load
(
f
).
items
()}
# If no optimized configuration is available, we will use the default
# configuration
logger
.
warning
(
"Using default W8A8 Block FP8 kernel config. Performance might "
"be sub-optimal! Config file not found at %s"
,
config_file_path
,
)
return
None
def
w8a8_block_fp8_matmul
(
A
:
torch
.
Tensor
,
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
As
:
torch
.
Tensor
,
As
:
torch
.
Tensor
,
...
@@ -574,80 +270,19 @@ def w8a8_block_fp8_matmul(
...
@@ -574,80 +270,19 @@ def w8a8_block_fp8_matmul(
block_size
:
list
[
int
],
block_size
:
list
[
int
],
output_dtype
:
torch
.
dtype
=
torch
.
float16
,
output_dtype
:
torch
.
dtype
=
torch
.
float16
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""This function performs matrix multiplication with block-wise
m
,
k
=
A
.
shape
quantization.
_
,
n
=
B
.
shape
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
enum_block_size
=
BlockSize
.
block_128x128
The output is returned in the specified `output_dtype`.
if
block_size
[
0
]
==
64
:
Args:
enum_block_size
=
BlockSize
.
block_64x64
A: The input tensor, e.g., activation.
elif
block_size
[
0
]
==
128
:
B: The input tensor, e.g., weight.
enum_block_size
=
BlockSize
.
block_128x128
As: The per-token-group quantization scale for `A`.
Bs: The per-block quantization scale for `B`.
block_size: The block size for per-block quantization. It should
be 2-dim, e.g., [128, 128].
output_dytpe: The dtype of the returned tensor.
Returns:
torch.Tensor: The result of matmul.
"""
assert
len
(
block_size
)
==
2
block_n
,
block_k
=
block_size
[
0
],
block_size
[
1
]
assert
A
.
shape
[
-
1
]
==
B
.
shape
[
-
1
]
assert
A
.
shape
[:
-
1
]
==
As
.
shape
[:
-
1
]
and
A
.
is_contiguous
()
assert
triton
.
cdiv
(
A
.
shape
[
-
1
],
block_k
)
==
As
.
shape
[
-
1
]
M
=
A
.
numel
()
//
A
.
shape
[
-
1
]
assert
B
.
ndim
==
2
and
Bs
.
ndim
==
2
N
,
K
=
B
.
shape
assert
triton
.
cdiv
(
N
,
block_n
)
==
Bs
.
shape
[
0
]
assert
triton
.
cdiv
(
K
,
block_k
)
==
Bs
.
shape
[
1
]
C_shape
=
A
.
shape
[:
-
1
]
+
(
N
,
)
C
=
A
.
new_empty
(
C_shape
,
dtype
=
output_dtype
)
configs
=
get_w8a8_block_fp8_configs
(
N
,
K
,
block_size
[
0
],
block_size
[
1
])
if
configs
:
# Get the optimal config if there is one
config
=
configs
[
min
(
configs
.
keys
(),
key
=
lambda
x
:
abs
(
x
-
M
))]
else
:
else
:
# Default config
print
(
f
"[WARN] Unsupported block_size:
{
block_size
}
. Falling back to BlockSize.block_128x128"
)
# Block-wise quant: BLOCK_SIZE_N must be divisible by block_size[0]
# BLOCK_SIZE_K must be divisible by block_size[1]
_
,
d
=
quant_ops
.
hipblaslt_w8a8_blockwise_gemm
(
A
,
B
,
As
,
Bs
,
config
=
{
m
,
n
,
k
,
'NN'
,
output_dtype
,
"BLOCK_SIZE_M"
:
64
,
enum_block_size
,
None
)
"BLOCK_SIZE_N"
:
block_size
[
0
],
return
d
"BLOCK_SIZE_K"
:
block_size
[
1
],
"GROUP_SIZE_M"
:
32
,
"num_warps"
:
4
,
"num_stages"
:
2
,
}
def
grid
(
META
):
return
(
triton
.
cdiv
(
M
,
META
[
"BLOCK_SIZE_M"
])
*
triton
.
cdiv
(
N
,
META
[
"BLOCK_SIZE_N"
]),
)
_w8a8_block_fp8_matmul
[
grid
](
A
,
B
,
C
,
As
,
Bs
,
M
,
N
,
K
,
block_n
,
block_k
,
A
.
stride
(
-
2
),
A
.
stride
(
-
1
),
B
.
stride
(
1
),
B
.
stride
(
0
),
C
.
stride
(
-
2
),
C
.
stride
(
-
1
),
As
.
stride
(
-
2
),
As
.
stride
(
-
1
),
Bs
.
stride
(
1
),
Bs
.
stride
(
0
),
**
config
,
)
return
C
vllm/model_executor/models/qwen3.py
View file @
bb3afd68
...
@@ -52,6 +52,7 @@ from .qwen2 import Qwen2MLP as Qwen3MLP
...
@@ -52,6 +52,7 @@ from .qwen2 import Qwen2MLP as Qwen3MLP
from
.qwen2
import
Qwen2Model
from
.qwen2
import
Qwen2Model
from
.utils
import
AutoWeightsLoader
,
PPMissingLayer
,
maybe_prefix
from
.utils
import
AutoWeightsLoader
,
PPMissingLayer
,
maybe_prefix
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.utils
import
direct_register_custom_op
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -129,6 +130,58 @@ class Qwen3Attention(nn.Module):
...
@@ -129,6 +130,58 @@ class Qwen3Attention(nn.Module):
self
.
q_norm
=
RMSNorm
(
self
.
head_dim
,
eps
=
rms_norm_eps
)
self
.
q_norm
=
RMSNorm
(
self
.
head_dim
,
eps
=
rms_norm_eps
)
self
.
k_norm
=
RMSNorm
(
self
.
head_dim
,
eps
=
rms_norm_eps
)
self
.
k_norm
=
RMSNorm
(
self
.
head_dim
,
eps
=
rms_norm_eps
)
def
rms_rotary_embedding_fuse
(
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
Optional
[
torch
.
Tensor
],
head_size
:
int
,
cos_sin_cache
:
torch
.
Tensor
,
is_neox_style
:
bool
,
q_weight
:
torch
.
Tensor
,
k_weight
:
torch
.
Tensor
,
q_bias
:
Optional
[
torch
.
Tensor
],
k_bias
:
Optional
[
torch
.
Tensor
],
epsilon
:
float
,
)
->
None
:
from
lightop
import
rms_rotary_embedding_fuse
as
fused_kernel
fused_kernel
(
positions
,
query
,
key
,
head_size
,
cos_sin_cache
,
is_neox_style
,
q_weight
,
k_weight
,
q_bias
,
k_bias
,
epsilon
,
)
def
rms_rotary_embedding_fuse_fake
(
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
Optional
[
torch
.
Tensor
],
head_size
:
int
,
cos_sin_cache
:
torch
.
Tensor
,
is_neox_style
:
bool
,
q_weight
:
torch
.
Tensor
,
k_weight
:
torch
.
Tensor
,
q_bias
:
Optional
[
torch
.
Tensor
],
k_bias
:
Optional
[
torch
.
Tensor
],
epsilon
:
float
,
)
->
None
:
# Fake impl intentionally left as no-op for graph tracing modes.
pass
if
not
hasattr
(
torch
.
ops
.
vllm
,
"rms_rotary_embedding_fuse"
):
direct_register_custom_op
(
op_name
=
"rms_rotary_embedding_fuse"
,
op_func
=
rms_rotary_embedding_fuse
,
mutates_args
=
[
"query"
,
"key"
],
fake_impl
=
rms_rotary_embedding_fuse_fake
,
)
def
forward
(
def
forward
(
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
...
@@ -136,22 +189,49 @@ class Qwen3Attention(nn.Module):
...
@@ -136,22 +189,49 @@ class Qwen3Attention(nn.Module):
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
# Add qk-norm
if
envs
.
VLLM_USE_FUSED_RMS_ROPE
:
q_by_head
=
q
.
view
(
*
q
.
shape
[:
-
1
],
q
.
shape
[
-
1
]
//
self
.
head_dim
,
# Fused RMSNorm + RoPE path through custom op.
self
.
head_dim
)
cos_sin_cache
=
self
.
rotary_emb
.
cos_sin_cache
if
envs
.
VLLM_USE_APEX_RN
:
if
(
cos_sin_cache
.
device
!=
q
.
device
q_by_head
=
self
.
q_norm
.
forward_apex
(
q_by_head
)
or
cos_sin_cache
.
dtype
!=
q
.
dtype
):
else
:
cos_sin_cache
=
cos_sin_cache
.
to
(
q
.
device
,
q_by_head
=
self
.
q_norm
.
forward_cuda
(
q_by_head
)
dtype
=
q
.
dtype
,
q
=
q_by_head
.
view
(
q
.
shape
)
non_blocking
=
True
)
k_by_head
=
k
.
view
(
*
k
.
shape
[:
-
1
],
k
.
shape
[
-
1
]
//
self
.
head_dim
,
# Persist the converted cache so we don't re-copy/re-allocate
self
.
head_dim
)
# on every forward when the original buffer starts on CPU.
if
envs
.
VLLM_USE_APEX_RN
:
self
.
rotary_emb
.
cos_sin_cache
=
cos_sin_cache
k_by_head
=
self
.
k_norm
.
forward_apex
(
k_by_head
)
q
=
q
.
contiguous
()
k
=
k
.
contiguous
()
torch
.
ops
.
vllm
.
rms_rotary_embedding_fuse
(
positions
,
q
,
k
,
self
.
head_dim
,
cos_sin_cache
,
self
.
rotary_emb
.
is_neox_style
,
self
.
q_norm
.
weight
,
self
.
k_norm
.
weight
,
None
,
None
,
self
.
q_norm
.
variance_epsilon
,
)
else
:
else
:
k_by_head
=
self
.
k_norm
.
forward_cuda
(
k_by_head
)
# Add qk-norm then RoPE (original path).
k
=
k_by_head
.
view
(
k
.
shape
)
q_by_head
=
q
.
view
(
*
q
.
shape
[:
-
1
],
q
.
shape
[
-
1
]
//
self
.
head_dim
,
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
self
.
head_dim
)
if
envs
.
VLLM_USE_APEX_RN
:
q_by_head
=
self
.
q_norm
.
forward_apex
(
q_by_head
)
else
:
q_by_head
=
self
.
q_norm
.
forward_cuda
(
q_by_head
)
q
=
q_by_head
.
view
(
q
.
shape
)
k_by_head
=
k
.
view
(
*
k
.
shape
[:
-
1
],
k
.
shape
[
-
1
]
//
self
.
head_dim
,
self
.
head_dim
)
if
envs
.
VLLM_USE_APEX_RN
:
k_by_head
=
self
.
k_norm
.
forward_apex
(
k_by_head
)
else
:
k_by_head
=
self
.
k_norm
.
forward_cuda
(
k_by_head
)
k
=
k_by_head
.
view
(
k
.
shape
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
return
output
...
...
vllm/platforms/rocm.py
View file @
bb3afd68
...
@@ -16,11 +16,7 @@ from vllm.utils import cuda_device_count_stateless
...
@@ -16,11 +16,7 @@ from vllm.utils import cuda_device_count_stateless
from
.interface
import
DeviceCapability
,
Platform
,
PlatformEnum
,
_Backend
from
.interface
import
DeviceCapability
,
Platform
,
PlatformEnum
,
_Backend
from
vllm.utils
import
SUPPORT_TC
,
SUPPORT_MOE_MARLIN_W16A16
from
vllm.utils
import
SUPPORT_TC
if
SUPPORT_MOE_MARLIN_W16A16
:
os
.
environ
[
'VLLM_USE_MARLIN_W16A16_MOE'
]
=
'1'
os
.
environ
[
'MOE_NN'
]
=
'0'
if
not
SUPPORT_TC
:
if
not
SUPPORT_TC
:
os
.
environ
[
'VLLM_USE_V1'
]
=
'0'
os
.
environ
[
'VLLM_USE_V1'
]
=
'0'
...
...
vllm/utils/__init__.py
View file @
bb3afd68
...
@@ -87,7 +87,6 @@ MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
...
@@ -87,7 +87,6 @@ MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
GPU_ARCH
=
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
GPU_ARCH
=
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
SUPPORT_TC
=
any
(
arch
in
GPU_ARCH
for
arch
in
[
"gfx928"
,
"gfx936"
,
"gfx938"
])
SUPPORT_TC
=
any
(
arch
in
GPU_ARCH
for
arch
in
[
"gfx928"
,
"gfx936"
,
"gfx938"
])
SUPPORT_MOE_MARLIN_W16A16
=
any
(
arch
in
GPU_ARCH
for
arch
in
[
"gfx936"
])
def
_generate_random_int8
(
def
_generate_random_int8
(
tensor
:
torch
.
Tensor
,
tensor
:
torch
.
Tensor
,
...
@@ -1959,7 +1958,7 @@ class W8a8GetCacheJSON:
...
@@ -1959,7 +1958,7 @@ class W8a8GetCacheJSON:
self
.
moe_weight_shapes
=
[]
self
.
moe_weight_shapes
=
[]
arch_name
=
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
arch_name
=
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
arch_cu
=
torch
.
cuda
.
get_device_properties
(
torch
.
cuda
.
current_device
()).
multi_processor_count
arch_cu
=
torch
.
cuda
.
get_device_properties
(
torch
.
cuda
.
current_device
()).
multi_processor_count
self
.
cache_json_data
=
{}
device_name
=
arch_name
+
'_'
+
str
(
arch_cu
)
+
'cu'
device_name
=
arch_name
+
'_'
+
str
(
arch_cu
)
+
'cu'
self
.
device_name
=
device_name
self
.
device_name
=
device_name
self
.
topk
=
1
self
.
topk
=
1
...
@@ -2061,19 +2060,27 @@ class W8a8GetCacheJSON:
...
@@ -2061,19 +2060,27 @@ class W8a8GetCacheJSON:
return
self
.
triton_json_dir
+
f
"/linear_
{
n
}
_
{
k
}
_block[
{
block_n
}
,
{
block_k
}
]_
{
self
.
device_name
}
.json"
return
self
.
triton_json_dir
+
f
"/linear_
{
n
}
_
{
k
}
_block[
{
block_n
}
,
{
block_k
}
]_
{
self
.
device_name
}
.json"
def
get_moeint8json_name
(
self
,
E
,
N1
,
N2
,
K
,
TOPK
,
def
get_moeint8json_name
(
self
,
E
,
N1
,
N2
,
K
,
TOPK
,
block_size
:
Optional
[
list
]
=
None
,
use_int4_w4a8
:
Optional
[
bool
]
=
False
):
block_size
:
Optional
[
list
]
=
None
,
use_int4_w4a8
:
Optional
[
bool
]
=
False
,
use_int8_w8a8
:
Optional
[
bool
]
=
False
):
if
use_int4_w4a8
:
if
use_int4_w4a8
:
if
block_size
is
not
None
:
if
block_size
is
not
None
:
return
self
.
triton_json_dir
+
f
"/MOE_W4A8INT8[
{
block_size
[
0
]
}
,
{
block_size
[
1
]
}
]_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
return
self
.
triton_json_dir
+
f
"/MOE_W4A8INT8[
{
block_size
[
0
]
}
,
{
block_size
[
1
]
}
]_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
else
:
else
:
return
self
.
triton_json_dir
+
f
"/MOE_W4A8INT8_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
return
self
.
triton_json_dir
+
f
"/MOE_W4A8INT8_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
elif
use_int8_w8a8
:
if
block_size
is
not
None
:
return
self
.
triton_json_dir
+
f
"/MOE_BLOCKINT8[
{
block_size
[
0
]
}
,
{
block_size
[
1
]
}
]_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
else
:
return
self
.
triton_json_dir
+
f
"/MOE_W8A8INT8_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
else
:
else
:
if
block_size
is
not
None
:
if
block_size
is
not
None
:
return
self
.
triton_json_dir
+
f
"/MOE_BLOCK
INT
8[
{
block_size
[
0
]
}
,
{
block_size
[
1
]
}
]_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
return
self
.
triton_json_dir
+
f
"/MOE_BLOCK
FP
8[
{
block_size
[
0
]
}
,
{
block_size
[
1
]
}
]_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
else
:
else
:
return
self
.
triton_json_dir
+
f
"/MOE_W8A8
INT
8_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
return
self
.
triton_json_dir
+
f
"/MOE_W8A8
FP
8_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
def
get_moeint8_triton_cache
(
self
,
file_path
,
E
,
N1
,
N2
,
K
,
TOPK
):
def
get_moeint8_triton_cache
(
self
,
file_path
,
E
,
N1
,
N2
,
K
,
TOPK
):
if
file_path
in
self
.
cache_json_data
:
# 直接返回缓存数据,避免重复读取
return
self
.
cache_json_data
[
file_path
]
cache_json_file
=
file_path
cache_json_file
=
file_path
if
os
.
path
.
exists
(
file_path
):
if
os
.
path
.
exists
(
file_path
):
...
@@ -2089,7 +2096,7 @@ class W8a8GetCacheJSON:
...
@@ -2089,7 +2096,7 @@ class W8a8GetCacheJSON:
for
sub_key
,
sub_value
in
value
.
items
():
for
sub_key
,
sub_value
in
value
.
items
():
configs_key
=
f
"
{
sub_key
}
_
{
key
}
"
configs_key
=
f
"
{
sub_key
}
_
{
key
}
"
configs_dict
[
configs_key
]
=
sub_value
configs_dict
[
configs_key
]
=
sub_value
self
.
cache_json_data
[
file_path
]
=
configs_dict
return
configs_dict
return
configs_dict
# Adapted from: https://stackoverflow.com/a/47212782/5082708
# Adapted from: https://stackoverflow.com/a/47212782/5082708
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment