Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7f74da5a
Commit
7f74da5a
authored
Apr 15, 2026
by
lixh6
Browse files
[FEATURE] 接入Aiter MoE W8A8 量化模型支持 && MQA_logits 修改 (Ref:wanghl)
parent
3842b316
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
322 additions
and
138 deletions
+322
-138
vllm/envs.py
vllm/envs.py
+4
-0
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+72
-30
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py
...ation/compressed_tensors/compressed_tensors_moe_marlin.py
+112
-44
vllm/model_executor/layers/sparse_attn_indexer.py
vllm/model_executor/layers/sparse_attn_indexer.py
+134
-64
No files found.
vllm/envs.py
View file @
7f74da5a
...
@@ -167,6 +167,7 @@ if TYPE_CHECKING:
...
@@ -167,6 +167,7 @@ if TYPE_CHECKING:
VLLM_MOE_USE_DEEP_GEMM
:
bool
=
True
VLLM_MOE_USE_DEEP_GEMM
:
bool
=
True
VLLM_USE_DEEP_GEMM_E8M0
:
bool
=
True
VLLM_USE_DEEP_GEMM_E8M0
:
bool
=
True
VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES
:
bool
=
True
VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES
:
bool
=
True
VLLM_USE_AITER_MOE_W8A8
:
bool
=
True
VLLM_DEEP_GEMM_WARMUP
:
Literal
[
VLLM_DEEP_GEMM_WARMUP
:
Literal
[
"skip"
,
"skip"
,
"full"
,
"full"
,
...
@@ -1287,6 +1288,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -1287,6 +1288,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES"
:
lambda
:
bool
(
"VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES"
,
"1"
))
int
(
os
.
getenv
(
"VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES"
,
"1"
))
),
),
"VLLM_USE_AITER_MOE_W8A8"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_AITER_MOE_W8A8"
,
"1"
))
),
# DeepGemm JITs the kernels on-demand. The warmup attempts to make DeepGemm
# DeepGemm JITs the kernels on-demand. The warmup attempts to make DeepGemm
# JIT all the required kernels before model execution so there is no
# JIT all the required kernels before model execution so there is no
# JIT'ing in the hot-path. However, this warmup increases the engine
# JIT'ing in the hot-path. However, this warmup increases the engine
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
7f74da5a
...
@@ -6,7 +6,11 @@ import functools
...
@@ -6,7 +6,11 @@ import functools
import
json
import
json
import
os
import
os
import
math
import
math
import
sys
import
aiter
from
vllm._aiter_ops
import
rocm_aiter_ops
from
aiter.moe
import
get_aiter_moe_config
,
aiter_moe
,
MoeQuantType
,
MoeSolutionType
from
aiter.ops.shuffle
import
moe_layout_shuffle_gemm1
,
moe_layout_shuffle_gemm2
from
collections.abc
import
Callable
from
collections.abc
import
Callable
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
...
@@ -1858,35 +1862,73 @@ def fused_experts_impl(
...
@@ -1858,35 +1862,73 @@ def fused_experts_impl(
cache13
=
torch
.
empty
(
M
*
top_k_num
*
max
(
N
,
K
if
not
use_nn_moe
else
w2
.
shape
[
2
]),
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
)
cache13
=
torch
.
empty
(
M
*
top_k_num
*
max
(
N
,
K
if
not
use_nn_moe
else
w2
.
shape
[
2
]),
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
)
if
use_int8_w8a8
or
use_fp8_w8a8
:
if
use_int8_w8a8
or
use_fp8_w8a8
:
return
fused_experts_impl_int8
(
hidden_states
=
hidden_states
,
if
envs
.
VLLM_USE_AITER_MOE_W8A8
==
True
:
w1
=
w1
,
K_input
=
hidden_states
.
size
(
1
)
w2
=
w2
,
actual_N2
=
N
//
2
topk_weights
=
topk_weights
,
quant_type
=
MoeQuantType
.
W8A8
topk_ids
=
topk_ids
,
status
,
moe_config
=
get_aiter_moe_config
(
cache13
=
cache13
,
M
=
num_tokens
,
inplace
=
inplace
,
E
=
global_num_experts
,
activation
=
activation
,
N1
=
N
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
N2
=
actual_N2
,
use_fp8_w8a8
=
use_fp8_w8a8
,
K
=
K_input
,
use_int8_w8a8
=
use_int8_w8a8
,
top_k
=
top_k_num
,
use_int8_w8a16
=
False
,
block_size
=
0
,
use_int4_w4a16
=
False
,
dtype
=
hidden_states
.
dtype
,
per_channel_quant
=
per_channel_quant
,
quant_type
=
quant_type
,
global_num_experts
=
global_num_experts
,
)
expert_map
=
expert_map
,
w1_scale
=
w1_scale
,
output
=
aiter_moe
(
w2_scale
=
w2_scale
,
hidden_states
=
hidden_states
,
w1_zp
=
w1_zp
,
w1
=
w1
,
w2_zp
=
w2_zp
,
w2
=
w2
,
a1_scale
=
a1_scale
,
topk_weights
=
topk_weights
,
a2_scale
=
a2_scale
,
topk_ids
=
topk_ids
,
block_shape
=
block_shape
,
moe_config
=
moe_config
,
use_nn_moe
=
False
,
inplace
=
inplace
,
routed_scaling_factor
=
routed_scaling_factor
,
activation
=
activation
,
shared_output
=
shared_output
,
w1_scale
=
w1_scale
,
i_q
=
i_q
,
w2_scale
=
w2_scale
,
i_s
=
i_s
w1_zp
=
w1_zp
,
)
w2_zp
=
w2_zp
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
block_shape
=
None
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
routed_scaling_factor
=
routed_scaling_factor
,
)
return
output
else
:
return
fused_experts_impl_int8
(
hidden_states
=
hidden_states
,
w1
=
w1
,
w2
=
w2
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
cache13
=
cache13
,
inplace
=
inplace
,
activation
=
activation
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
False
,
use_int4_w4a16
=
False
,
per_channel_quant
=
per_channel_quant
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
w1_zp
=
w1_zp
,
w2_zp
=
w2_zp
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
block_shape
=
block_shape
,
use_nn_moe
=
False
,
routed_scaling_factor
=
routed_scaling_factor
,
shared_output
=
shared_output
,
i_q
=
i_q
,
i_s
=
i_s
)
elif
use_int4_w4a8
is
True
:
elif
use_int4_w4a8
is
True
:
return
fused_experts_impl_w4a8
(
hidden_states
=
hidden_states
,
return
fused_experts_impl_w4a8
(
hidden_states
=
hidden_states
,
w1
=
w1
,
w1
=
w1
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py
View file @
7f74da5a
...
@@ -26,6 +26,14 @@ from vllm.model_executor.layers.fused_moe import (
...
@@ -26,6 +26,14 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoEPrepareAndFinalize
,
FusedMoEPrepareAndFinalize
,
FusedMoeWeightScaleSupported
,
FusedMoeWeightScaleSupported
,
)
)
import
aiter
from
aiter.test_common
import
checkAllclose
,
perftest
from
aiter.ops.shuffle
import
moe_layout_shuffle_gemm1
,
moe_layout_shuffle_gemm2
from
aiter.fused_moe
import
fused_topk
,
torch_moe
from
aiter
import
dtypes
,
ActivationType
from
aiter.moe
import
get_aiter_moe_config
,
aiter_moe
,
MoeSolutionType
,
MoeQuantType
try
:
try
:
from
lmslim.layers.fused_moe.fuse_moe_int8_marlin
import
fused_experts_impl_int8_marlin
from
lmslim.layers.fused_moe.fuse_moe_int8_marlin
import
fused_experts_impl_int8_marlin
from
lmslim.layers.fused_moe.fuse_moe_fp8_marlin
import
fused_experts_impl_fp8_marlin
from
lmslim.layers.fused_moe.fuse_moe_fp8_marlin
import
fused_experts_impl_fp8_marlin
...
@@ -369,28 +377,48 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
...
@@ -369,28 +377,48 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
layer
.
w13_input_scale
=
None
layer
.
w13_input_scale
=
None
layer
.
w2_input_scale
=
None
layer
.
w2_input_scale
=
None
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
def
shuffle_w8a8_gemm1
(
self
,
weight_data
):
w1_marlin_list
=
[]
w_i8
=
weight_data
.
to
(
torch
.
int8
)
for
ii
in
range
(
layer
.
w13_weight
.
shape
[
0
]):
return
moe_layout_shuffle_gemm1
(
w_i8
)
if
not
self
.
use_deepep
:
w1_marlin_in
=
get_w8a8_int8_marlin_weights
(
layer
.
w13_weight
[
ii
])
else
:
w1_marlin_in
=
w8a8_nt_kpack2_marlin_weight
(
layer
.
w13_weight
[
ii
])
w1_marlin_list
.
append
(
w1_marlin_in
)
w1_marlin
=
torch
.
stack
(
w1_marlin_list
,
dim
=
0
)
del
w1_marlin_list
def
shuffle_w8a8_gemm2
(
self
,
weight_data
):
w2_marlin_list
=
[]
w_i8
=
weight_data
.
to
(
torch
.
int8
)
for
ii
in
range
(
layer
.
w2_weight
.
shape
[
0
]):
return
moe_layout_shuffle_gemm2
(
w_i8
)
if
not
self
.
use_deepep
:
w2_marlin_in
=
get_w8a8_int8_marlin_weights
(
layer
.
w2_weight
[
ii
])
else
:
w2_marlin_in
=
w8a8_nt_kpack2_marlin_weight
(
layer
.
w2_weight
[
ii
])
w2_marlin_list
.
append
(
w2_marlin_in
)
w2_marlin
=
torch
.
stack
(
w2_marlin_list
,
dim
=
0
)
layer
.
w13_weight
=
Parameter
(
w1_marlin
,
requires_grad
=
False
)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
layer
.
w2_weight
=
Parameter
(
w2_marlin
,
requires_grad
=
False
)
if
envs
.
VLLM_USE_AITER_MOE_W8A8
==
True
:
E
,
N13
,
K
=
layer
.
w13_weight
.
shape
_
,
K_w2
,
N2
=
layer
.
w2_weight
.
shape
layer
.
w13_weight_scale
=
Parameter
(
layer
.
w13_weight_scale
.
data
,
requires_grad
=
False
)
layer
.
w2_weight_scale
=
Parameter
(
layer
.
w2_weight_scale
.
data
,
requires_grad
=
False
)
shuffled_w13
=
self
.
shuffle_w8a8_gemm1
(
layer
.
w13_weight
)
layer
.
w13_weight
=
Parameter
(
shuffled_w13
.
view
(
*
layer
.
w13_weight
.
shape
),
requires_grad
=
False
)
shuffled_w2
=
self
.
shuffle_w8a8_gemm2
(
layer
.
w2_weight
)
layer
.
w2_weight
=
Parameter
(
shuffled_w2
.
view
(
*
layer
.
w2_weight
.
shape
),
requires_grad
=
False
)
else
:
w1_marlin_list
=
[]
for
ii
in
range
(
layer
.
w13_weight
.
shape
[
0
]):
if
not
self
.
use_deepep
:
w1_marlin_in
=
get_w8a8_int8_marlin_weights
(
layer
.
w13_weight
[
ii
])
else
:
w1_marlin_in
=
w8a8_nt_kpack2_marlin_weight
(
layer
.
w13_weight
[
ii
])
w1_marlin_list
.
append
(
w1_marlin_in
)
w1_marlin
=
torch
.
stack
(
w1_marlin_list
,
dim
=
0
)
del
w1_marlin_list
w2_marlin_list
=
[]
for
ii
in
range
(
layer
.
w2_weight
.
shape
[
0
]):
if
not
self
.
use_deepep
:
w2_marlin_in
=
get_w8a8_int8_marlin_weights
(
layer
.
w2_weight
[
ii
])
else
:
w2_marlin_in
=
w8a8_nt_kpack2_marlin_weight
(
layer
.
w2_weight
[
ii
])
w2_marlin_list
.
append
(
w2_marlin_in
)
w2_marlin
=
torch
.
stack
(
w2_marlin_list
,
dim
=
0
)
layer
.
w13_weight
=
Parameter
(
w1_marlin
,
requires_grad
=
False
)
layer
.
w2_weight
=
Parameter
(
w2_marlin
,
requires_grad
=
False
)
def
apply
(
def
apply
(
self
,
self
,
...
@@ -405,31 +433,71 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
...
@@ -405,31 +433,71 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
routed_scaling_factor
:
Optional
[
float
]
=
1.0
,
routed_scaling_factor
:
Optional
[
float
]
=
1.0
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
from
vllm.model_executor.layers.fused_moe
import
fused_experts
if
envs
.
VLLM_USE_AITER_MOE_W8A8
==
True
:
m_flat
=
x
.
view
(
-
1
,
x
.
shape
[
-
1
])
M
=
m_flat
.
shape
[
0
]
E
=
layer
.
w13_weight
.
size
(
0
)
K
=
x
.
size
(
-
1
)
N1
=
layer
.
w13_weight
.
size
(
1
)
topk
=
topk_ids
.
size
(
1
)
w1_input
=
layer
.
w13_weight
.
view
(
E
,
N1
,
K
)
w2_input
=
layer
.
w2_weight
.
view
(
E
,
K
,
N1
//
2
)
_
,
moe_cfg
=
get_aiter_moe_config
(
M
=
M
,
E
=
E
,
N1
=
N1
,
N2
=
N1
//
2
,
K
=
K
,
top_k
=
topk
,
block_size
=
0
,
dtype
=
x
.
dtype
,
quant_type
=
MoeQuantType
.
W8A8
,
)
return
fused_experts_impl_int8_marlin
(
output
=
aiter_moe
(
hidden_states
=
x
,
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
w1
=
w1_input
,
w2
=
layer
.
w2_weight
,
w2
=
w2_input
,
topk_weights
=
topk_weights
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
topk_ids
=
topk_ids
,
inplace
=
True
,
moe_config
=
moe_cfg
,
activation
=
layer
.
activation
,
inplace
=
False
,
apply_router_weight_on_input
=
layer
.
apply_router_weight_on_input
,
activation
=
getattr
(
layer
,
"activation"
,
"silu"
),
use_int8_w8a8
=
True
,
w1_scale
=
layer
.
w13_weight_scale
,
per_channel_quant
=
True
,
w2_scale
=
layer
.
w2_weight_scale
,
global_num_experts
=
layer
.
global_num_experts
,
a1_scale
=
getattr
(
layer
,
"w13_input_scale"
,
None
),
expert_map
=
layer
.
expert_map
,
a2_scale
=
getattr
(
layer
,
"w2_input_scale"
,
None
),
quant_config
=
self
.
moe_quant_config
,
global_num_experts
=
E
,
w1_scale
=
layer
.
w13_weight_scale
,
expert_map
=
getattr
(
layer
,
"expert_map"
,
None
),
w2_scale
=
layer
.
w2_weight_scale
,
routed_scaling_factor
=
routed_scaling_factor
,
a1_scale
=
layer
.
w13_input_scale
,
)
a2_scale
=
layer
.
w2_input_scale
,
return
output
use_nn_moe
=
False
,
else
:
i_q
=
i_q
,
return
fused_experts_impl_int8_marlin
(
i_s
=
i_s
,
hidden_states
=
x
,
shared_output
=
shared_output
,
w1
=
layer
.
w13_weight
,
routed_scaling_factor
=
routed_scaling_factor
,
w2
=
layer
.
w2_weight
,
)
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
,
activation
=
layer
.
activation
,
apply_router_weight_on_input
=
layer
.
apply_router_weight_on_input
,
use_int8_w8a8
=
True
,
per_channel_quant
=
True
,
global_num_experts
=
layer
.
global_num_experts
,
expert_map
=
layer
.
expert_map
,
quant_config
=
self
.
moe_quant_config
,
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
use_nn_moe
=
False
,
i_q
=
i_q
,
i_s
=
i_s
,
shared_output
=
shared_output
,
routed_scaling_factor
=
routed_scaling_factor
,
)
def
select_gemm_impl
(
def
select_gemm_impl
(
self
,
self
,
...
...
vllm/model_executor/layers/sparse_attn_indexer.py
View file @
7f74da5a
...
@@ -30,6 +30,7 @@ elif current_platform.is_xpu():
...
@@ -30,6 +30,7 @@ elif current_platform.is_xpu():
from
vllm._ipex_ops
import
ipex_ops
as
ops
from
vllm._ipex_ops
import
ipex_ops
as
ops
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
_GLOBAL_LOGITS_BUFFERS
=
{}
@
maybe_transfer_kv_layer
@
maybe_transfer_kv_layer
def
sparse_attn_indexer
(
def
sparse_attn_indexer
(
...
@@ -50,7 +51,21 @@ def sparse_attn_indexer(
...
@@ -50,7 +51,21 @@ def sparse_attn_indexer(
# careful! this will be None in dummy run
# careful! this will be None in dummy run
attn_metadata
=
get_forward_context
().
attn_metadata
attn_metadata
=
get_forward_context
().
attn_metadata
fp8_dtype
=
current_platform
.
fp8_dtype
()
fp8_dtype
=
current_platform
.
fp8_dtype
()
if
q_fp8
.
dtype
==
fp8_dtype
:
MAX_ELEMENTS
=
65536
*
65536
elif
q_fp8
.
dtype
in
(
torch
.
bfloat16
,
torch
.
float16
):
MAX_ELEMENTS
=
16384
*
32768
else
:
MAX_ELEMENTS
=
16384
*
32768
device
=
q_fp8
.
device
if
device
not
in
_GLOBAL_LOGITS_BUFFERS
or
_GLOBAL_LOGITS_BUFFERS
[
device
].
numel
()
<
MAX_ELEMENTS
:
_GLOBAL_LOGITS_BUFFERS
[
device
]
=
torch
.
empty
(
MAX_ELEMENTS
,
dtype
=
torch
.
float32
,
device
=
device
)
logits_buffer
=
_GLOBAL_LOGITS_BUFFERS
[
device
]
# assert isinstance(attn_metadata, dict)
# assert isinstance(attn_metadata, dict)
if
not
isinstance
(
attn_metadata
,
dict
):
if
not
isinstance
(
attn_metadata
,
dict
):
# Reserve workspace for indexer during profiling run
# Reserve workspace for indexer during profiling run
...
@@ -75,7 +90,14 @@ def sparse_attn_indexer(
...
@@ -75,7 +90,14 @@ def sparse_attn_indexer(
)
)
attn_metadata
=
attn_metadata
[
layer_name
]
attn_metadata
=
attn_metadata
[
layer_name
]
assert
isinstance
(
attn_metadata
,
DeepseekV32IndexerMetadata
)
assert
isinstance
(
attn_metadata
,
DeepseekV32IndexerMetadata
)
slot_mapping
=
attn_metadata
.
slot_mapping
# slot_mapping = attn_metadata.slot_mapping[:attn_metadata.num_kv_actual_tokens]
if
hasattr
(
attn_metadata
,
'num_kv_actual_tokens'
):
num_kv_tokens
=
attn_metadata
.
num_kv_actual_tokens
elif
hasattr
(
attn_metadata
,
'num_prefills'
)
and
attn_metadata
.
num_prefills
>
0
:
num_kv_tokens
=
getattr
(
attn_metadata
,
'num_prefill_tokens'
,
attn_metadata
.
slot_mapping
.
shape
[
0
])
else
:
num_kv_tokens
=
attn_metadata
.
slot_mapping
.
shape
[
0
]
slot_mapping
=
attn_metadata
.
slot_mapping
[:
num_kv_tokens
]
has_decode
=
attn_metadata
.
num_decodes
>
0
has_decode
=
attn_metadata
.
num_decodes
>
0
has_prefill
=
attn_metadata
.
num_prefills
>
0
has_prefill
=
attn_metadata
.
num_prefills
>
0
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
...
@@ -116,14 +138,6 @@ def sparse_attn_indexer(
...
@@ -116,14 +138,6 @@ def sparse_attn_indexer(
chunk
.
block_table
,
chunk
.
block_table
,
chunk
.
cu_seq_lens
,
chunk
.
cu_seq_lens
,
)
)
logits
=
fp8_mqa_logits
(
q_fp8
[
chunk
.
token_start
:
chunk
.
token_end
],
(
k_fp8
,
k_scale
.
view
(
torch
.
float32
).
flatten
()),
weights
[
chunk
.
token_start
:
chunk
.
token_end
],
chunk
.
cu_seqlen_ks
,
chunk
.
cu_seqlen_ke
,
)
elif
get_gcn_arch_name
()
==
"gfx938"
:
elif
get_gcn_arch_name
()
==
"gfx938"
:
k_fp8
=
k_fp8_full
[:
chunk
.
total_seq_lens
]
k_fp8
=
k_fp8_full
[:
chunk
.
total_seq_lens
]
k_scale
=
k_scale_full
[:
chunk
.
total_seq_lens
]
k_scale
=
k_scale_full
[:
chunk
.
total_seq_lens
]
...
@@ -134,19 +148,6 @@ def sparse_attn_indexer(
...
@@ -134,19 +148,6 @@ def sparse_attn_indexer(
chunk
.
block_table
,
chunk
.
block_table
,
chunk
.
cu_seq_lens
,
chunk
.
cu_seq_lens
,
)
)
logits
=
op
.
mqa_logits
(
q_fp8
[
chunk
.
token_start
:
chunk
.
token_end
],
k_fp8
,
weights
[
chunk
.
token_start
:
chunk
.
token_end
],
chunk
.
cu_seqlen_ks
,
chunk
.
cu_seqlen_ke
,
q_fp8
[
chunk
.
token_start
:
chunk
.
token_end
].
shape
[
0
],
k_fp8
.
shape
[
0
],
q_fp8
.
shape
[
1
],
q_fp8
.
shape
[
2
],
k_scale
.
view
(
torch
.
float32
).
flatten
(),
True
)
else
:
else
:
k_fp8
=
k_fp8_full
[:
chunk
.
total_seq_lens
]
k_fp8
=
k_fp8_full
[:
chunk
.
total_seq_lens
]
k_scale
=
k_scale_full
[:
chunk
.
total_seq_lens
]
k_scale
=
k_scale_full
[:
chunk
.
total_seq_lens
]
...
@@ -156,46 +157,117 @@ def sparse_attn_indexer(
...
@@ -156,46 +157,117 @@ def sparse_attn_indexer(
chunk
.
block_table
,
chunk
.
block_table
,
chunk
.
cu_seq_lens
,
chunk
.
cu_seq_lens
,
)
)
logits
=
op
.
mqa_logits
(
q_fp8
[
chunk
.
token_start
:
chunk
.
token_end
],
q_all
=
q_fp8
[
chunk
.
token_start
:
chunk
.
token_end
]
k_fp8
,
weights_all
=
weights
[
chunk
.
token_start
:
chunk
.
token_end
]
weights
[
chunk
.
token_start
:
chunk
.
token_end
].
to
(
torch
.
float32
),
ks_all
=
chunk
.
cu_seqlen_ks
chunk
.
cu_seqlen_ks
,
ke_all
=
chunk
.
cu_seqlen_ke
chunk
.
cu_seqlen_ke
,
q_fp8
[
chunk
.
token_start
:
chunk
.
token_end
].
shape
[
0
],
num_q
=
q_all
.
shape
[
0
]
k_fp8
.
shape
[
0
],
num_k
=
k_fp8
.
shape
[
0
]
q_fp8
.
shape
[
1
],
q_fp8
.
shape
[
2
],
is_q_fp16_bf16
=
q_all
.
dtype
in
(
torch
.
float16
,
torch
.
bfloat16
)
None
,
align_size
=
128
if
is_q_fp16_bf16
else
1
True
kv_seq_len_aligned
=
(
num_k
+
align_size
-
1
)
//
align_size
*
align_size
current_capacity
=
logits_buffer
.
numel
()
MAX_Q_CHUNK
=
current_capacity
//
max
(
1
,
kv_seq_len_aligned
)
if
align_size
>
1
:
MAX_Q_CHUNK
=
(
MAX_Q_CHUNK
//
align_size
)
*
align_size
MAX_Q_CHUNK
=
max
(
1
,
MAX_Q_CHUNK
)
slices
=
[]
for
start_idx
in
range
(
0
,
num_q
,
MAX_Q_CHUNK
):
end_idx
=
min
(
start_idx
+
MAX_Q_CHUNK
,
num_q
)
slices
.
append
((
start_idx
,
end_idx
))
for
q_start
,
q_end
in
slices
:
if
q_end
<=
q_start
:
continue
q_slice
=
q_all
[
q_start
:
q_end
]
weights_slice
=
weights_all
[
q_start
:
q_end
]
ks_slice
=
ks_all
[
q_start
:
q_end
]
ke_slice
=
ke_all
[
q_start
:
q_end
]
q_len
=
q_end
-
q_start
q_seq_len_aligned
=
(
q_len
+
align_size
-
1
)
//
align_size
*
align_size
required_size
=
q_seq_len_aligned
*
kv_seq_len_aligned
logits_slice_view
=
logits_buffer
[:
required_size
].
view
(
q_seq_len_aligned
,
kv_seq_len_aligned
)
if
not
current_platform
.
is_rocm
():
logits_slice
=
fp8_mqa_logits
(
q_slice
,
(
k_fp8
,
k_scale
.
view
(
torch
.
float32
).
flatten
()),
weights_slice
,
ks_slice
,
ke_slice
,
)
elif
get_gcn_arch_name
()
==
"gfx938"
:
op
.
mqa_logits
(
q_slice
,
k_fp8
,
weights_slice
,
ks_slice
,
ke_slice
,
q_slice
.
shape
[
0
],
k_fp8
.
shape
[
0
],
q_slice
.
shape
[
1
],
q_slice
.
shape
[
2
],
k_scale
.
view
(
torch
.
float32
).
flatten
(),
True
,
logits_slice_view
)
logits_slice
=
logits_slice_view
[:
q_len
,
:
num_k
]
else
:
op
.
mqa_logits
(
q_slice
,
k_fp8
,
weights_slice
.
to
(
torch
.
float32
),
ks_slice
,
ke_slice
,
q_slice
.
shape
[
0
],
k_fp8
.
shape
[
0
],
q_slice
.
shape
[
1
],
q_slice
.
shape
[
2
],
None
,
True
,
logits_slice_view
)
logits_slice
=
logits_slice_view
[:
q_len
,
:
num_k
]
num_rows_slice
=
logits_slice
.
shape
[
0
]
topk_indices_slice
=
topk_indices_buffer
[
chunk
.
token_start
+
q_start
:
chunk
.
token_start
+
q_end
,
:
topk_tokens
]
if
not
envs
.
USE_LIGHTOP_TOPK
:
torch
.
ops
.
_C
.
top_k_per_row_prefill
(
logits_slice
,
ks_slice
,
ke_slice
,
topk_indices_slice
,
num_rows_slice
,
logits_slice
.
stride
(
0
),
logits_slice
.
stride
(
1
),
topk_tokens
,
)
else
:
op
.
top_k_per_row_prefill
(
logits_slice
,
ks_slice
,
ke_slice
,
topk_indices_slice
,
num_rows_slice
,
logits_slice
.
stride
(
0
),
logits_slice
.
stride
(
1
),
topk_tokens
,
)
)
num_rows
=
logits
.
shape
[
0
]
topk_indices
=
topk_indices_buffer
[
chunk
.
token_start
:
chunk
.
token_end
,
:
topk_tokens
]
if
not
envs
.
USE_LIGHTOP_TOPK
:
torch
.
ops
.
_C
.
top_k_per_row_prefill
(
logits
,
chunk
.
cu_seqlen_ks
,
chunk
.
cu_seqlen_ke
,
topk_indices
,
num_rows
,
logits
.
stride
(
0
),
logits
.
stride
(
1
),
topk_tokens
,
)
else
:
op
.
top_k_per_row_prefill
(
logits
,
chunk
.
cu_seqlen_ks
,
chunk
.
cu_seqlen_ke
,
topk_indices
,
num_rows
,
logits
.
stride
(
0
),
logits
.
stride
(
1
),
topk_tokens
,
)
if
has_decode
:
if
has_decode
:
decode_metadata
=
attn_metadata
.
decode
decode_metadata
=
attn_metadata
.
decode
...
@@ -423,6 +495,4 @@ class SparseAttnIndexer(CustomOp):
...
@@ -423,6 +495,4 @@ class SparseAttnIndexer(CustomOp):
self
.
max_model_len
,
self
.
max_model_len
,
self
.
max_total_seq_len
,
self
.
max_total_seq_len
,
self
.
topk_indices_buffer
,
self
.
topk_indices_buffer
,
)
)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment