Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c637d1aa
Commit
c637d1aa
authored
Apr 15, 2026
by
王敏
Browse files
Merge remote-tracking branch 'origin/v0.15.1-dev-pcp' into v0.15.1-dev-pcp
parents
f3d1f95b
263d6216
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
290 additions
and
32 deletions
+290
-32
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+2
-1
vllm/model_executor/layers/quantization/slimquant_w4a8.py
vllm/model_executor/layers/quantization/slimquant_w4a8.py
+229
-1
vllm/model_executor/layers/quantization/slimquant_w4a8_marlin.py
...del_executor/layers/quantization/slimquant_w4a8_marlin.py
+5
-2
vllm/model_executor/layers/sparse_attn_indexer.py
vllm/model_executor/layers/sparse_attn_indexer.py
+54
-28
No files found.
vllm/model_executor/layers/fused_moe/layer.py
View file @
c637d1aa
...
@@ -733,7 +733,8 @@ class FusedMoE(CustomOp):
...
@@ -733,7 +733,8 @@ class FusedMoE(CustomOp):
if
(
self
.
quant_method
.
__class__
.
__name__
in
(
"BlockInt8MoEMethod"
,
if
(
self
.
quant_method
.
__class__
.
__name__
in
(
"BlockInt8MoEMethod"
,
"SlimQuantW4A8Int8MoEMethod"
,
"SlimQuantW4A8Int8MoEMethod"
,
"SlimQuantW4A8Int8MarlinMoEMethod"
)):
"SlimQuantW4A8Int8MarlinMoEMethod"
,
"SlimQuantW4A8Int8AiterMoEMethod"
)):
moe_quant_params
[
"intermediate_size"
]
=
self
.
intermediate_size_per_partition
moe_quant_params
[
"intermediate_size"
]
=
self
.
intermediate_size_per_partition
...
...
vllm/model_executor/layers/quantization/slimquant_w4a8.py
View file @
c637d1aa
...
@@ -25,6 +25,21 @@ import os
...
@@ -25,6 +25,21 @@ import os
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm
import
envs
from
vllm
import
envs
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
try
:
from
aiter.ops.shuffle
import
w4a8_moe_layout_shuffle_gemm1
,
w4a8_moe_layout_shuffle_gemm2
from
aiter.moe
import
(
get_aiter_moe_config
,
aiter_moe
,
MoeSolutionType
,
MoeQuantType
,
)
from
aiter
import
dtypes
,
ActivationType
except
ImportError
as
e
:
print
(
"Import error msg: import aiter"
)
W8A8_TRITONJSON
=
W8a8GetCacheJSON
()
W8A8_TRITONJSON
=
W8a8GetCacheJSON
()
def
baseline_scaled_mm
(
a
:
torch
.
Tensor
,
def
baseline_scaled_mm
(
a
:
torch
.
Tensor
,
...
@@ -82,7 +97,10 @@ class SlimQuantW4A8Int8Config(QuantizationConfig):
...
@@ -82,7 +97,10 @@ class SlimQuantW4A8Int8Config(QuantizationConfig):
if
isinstance
(
layer
,
LinearBase
):
if
isinstance
(
layer
,
LinearBase
):
return
SlimQuantW4A8Int8LinearMethod
(
self
)
return
SlimQuantW4A8Int8LinearMethod
(
self
)
elif
isinstance
(
layer
,
FusedMoE
):
elif
isinstance
(
layer
,
FusedMoE
):
return
SlimQuantW4A8Int8MoEMethod
(
self
,
layer
.
moe_config
)
if
envs
.
VLLM_ROCM_USE_AITER_MOE
:
return
SlimQuantW4A8Int8AiterMoEMethod
(
self
,
layer
.
moe_config
)
else
:
return
SlimQuantW4A8Int8MoEMethod
(
self
,
layer
.
moe_config
)
return
None
return
None
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
...
@@ -328,4 +346,214 @@ class SlimQuantW4A8Int8MoEMethod:
...
@@ -328,4 +346,214 @@ class SlimQuantW4A8Int8MoEMethod:
use_nn_moe
=
use_nn_moe
,
use_nn_moe
=
use_nn_moe
,
shared_output
=
shared_output
,
shared_output
=
shared_output
,
routed_scaling_factor
=
routed_scaling_factor
,
routed_scaling_factor
=
routed_scaling_factor
,
)
class
SlimQuantW4A8Int8AiterMoEMethod
:
"""MoE method for W4A8INT8.
Supports loading INT8 checkpoints with static weight scale and
dynamic/static activation scale.
Also supports loading quantized FP16/BF16 model checkpoints with dynamic
activation scaling. The weight scaling factor will be initialized after
the model weights are loaded.
Args:
quant_config: The quantization config.
"""
def
__new__
(
cls
,
*
args
,
**
kwargs
):
if
not
hasattr
(
cls
,
"_initialized"
):
original_init
=
cls
.
__init__
new_cls
=
type
(
cls
.
__name__
,
(
FusedMoEMethodBase
,),
{
"__init__"
:
original_init
,
**
{
k
:
v
for
k
,
v
in
cls
.
__dict__
.
items
()
if
k
!=
"__dict__"
},
},
)
obj
=
super
(
new_cls
,
new_cls
).
__new__
(
new_cls
)
obj
.
__init__
(
*
args
,
**
kwargs
)
return
obj
return
super
().
__new__
(
cls
)
def
__init__
(
self
,
quant_config
,
moe
):
self
.
moe
=
moe
self
.
quant_config
=
quant_config
self
.
tritonsingleton
=
W8a8GetCacheJSON
()
self
.
moe_quant_config
:
Optional
[
FusedMoEQuantConfig
]
=
None
self
.
moe_mk
:
Optional
[
FusedMoEModularKernel
]
=
None
def
get_fused_moe_quant_config
(
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
FusedMoEQuantConfig
]:
self
.
moe_quant_config
=
FusedMoEQuantConfig
.
make
(
torch
.
int8
,
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
per_act_token_quant
=
True
,
per_out_ch_quant
=
False
,
block_shape
=
None
,
weight_dtype
=
'int4'
)
)
return
self
.
moe_quant_config
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
tp_size
=
get_tensor_model_parallel_world_size
()
# WEIGHTS
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
2
*
intermediate_size
,
hidden_size
//
2
,
dtype
=
torch
.
int8
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w13_weight"
,
w13_weight
)
set_weight_attrs
(
w13_weight
,
extra_weight_attrs
)
w2_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
hidden_size
,
intermediate_size
//
2
,
dtype
=
torch
.
int8
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
w13_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
2
*
intermediate_size
,
1
,
dtype
=
torch
.
float32
),
requires_grad
=
False
,
)
w2_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
hidden_size
,
1
,
dtype
=
torch
.
float32
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w13_weight_scale"
,
w13_weight_scale
)
layer
.
register_parameter
(
"w2_weight_scale"
,
w2_weight_scale
)
extra_weight_attrs
.
update
(
{
"quant_method"
:
FusedMoeWeightScaleSupported
.
CHANNEL
.
value
}
)
set_weight_attrs
(
w13_weight_scale
,
extra_weight_attrs
)
set_weight_attrs
(
w2_weight_scale
,
extra_weight_attrs
)
w13_input_scale
=
None
layer
.
register_parameter
(
"w13_input_scale"
,
w13_input_scale
)
w2_input_scale
=
None
layer
.
register_parameter
(
"w2_input_scale"
,
w2_input_scale
)
def
repack_and_shuffle_w4a8
(
self
,
weight_data
,
E
):
"""
逐 expert 处理 [n, k_half]
处理完直接写回 weight_data[i]
"""
# 原始 shape: [E, n, k_half]
for
i
in
range
(
E
):
# 1. 取当前 expert [n, k_half]
expert
=
weight_data
[
i
]
n
,
k_half
=
expert
.
shape
# 2. repack 逻辑(连续 → blocked)
w_u8
=
expert
.
to
(
torch
.
uint8
)
# 解包 1byte → 2个4bit
w_unpacked
=
torch
.
stack
([
(
w_u8
>>
4
)
&
0x0F
,
w_u8
&
0x0F
],
dim
=-
1
).
view
(
n
,
-
1
)
# 8个4bit分块重排
blocks
=
w_unpacked
.
view
(
n
,
-
1
,
8
)
w_low
=
blocks
[...,
:
4
]
w_high
=
blocks
[...,
4
:]
packed
=
(
w_low
<<
4
)
|
w_high
packed
=
packed
.
view
(
n
,
k_half
)
# 3. shuffle
w_marlin_in
=
w4a8_moe_layout_shuffle_gemm2
(
packed
)
w_marlin_in
=
w_marlin_in
.
reshape
(
n
,
k_half
)
# 4. 直接写回
weight_data
[
i
]
=
w_marlin_in
return
weight_data
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
E
=
layer
.
w13_weight
.
shape
[
0
]
layer
.
w13_weight_scale
=
Parameter
(
layer
.
w13_weight_scale
.
data
,
requires_grad
=
False
)
layer
.
w2_weight_scale
=
Parameter
(
layer
.
w2_weight_scale
.
data
,
requires_grad
=
False
)
layer
.
w13_weight
=
Parameter
(
self
.
repack_and_shuffle_w4a8
(
layer
.
w13_weight
.
data
,
E
),
requires_grad
=
False
)
layer
.
w2_weight
=
Parameter
(
self
.
repack_and_shuffle_w4a8
(
layer
.
w2_weight
.
data
,
E
),
requires_grad
=
False
)
def
apply
(
self
,
layer
:
FusedMoE
,
x
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
use_nn_moe
:
bool
|
None
=
False
,
use_fused_gate
:
bool
|
None
=
False
,
i_q
:
torch
.
Tensor
|
None
=
None
,
i_s
:
torch
.
Tensor
|
None
=
None
,
shared_output
:
torch
.
Tensor
|
None
=
None
,
routed_scaling_factor
:
float
=
1.0
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
E
=
layer
.
w13_weight
.
size
(
0
)
K
=
x
.
size
(
-
1
)
N1
=
layer
.
w13_weight
.
size
(
1
)
if
x
.
dim
()
==
2
:
# Make sure we are using the correct a1 (pre-permute).
M
=
x
.
size
(
0
)
else
:
assert
x
.
dim
()
==
3
assert
x
.
size
(
0
)
==
E
,
f
"
{
x
.
size
(
0
)
}
==
{
E
}
"
M
=
x
.
size
(
1
)
topk
=
topk_ids
.
size
(
1
)
status
,
moe_cfg
=
get_aiter_moe_config
(
M
=
M
,
E
=
E
,
N1
=
N1
,
N2
=
N1
//
2
,
K
=
K
,
top_k
=
topk
,
block_size
=
None
,
dtype
=
dtypes
.
bf16
,
quant_type
=
MoeQuantType
.
W4A8
,
)
if
not
status
:
assert
moe_cfg
.
solution_type
is
None
assert
moe_cfg
.
config
is
None
logger
.
info
(
f
"[get_config_w4a8]
{
M
=
}
, no solution found"
)
return
aiter_moe
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
topk_weights
,
topk_ids
,
moe_cfg
,
layer
.
w13_weight_scale
,
layer
.
w2_weight_scale
,
w1_zp
=
None
,
w2_zp
=
None
,
a1_scale
=
None
,
a2_scale
=
None
,
block_shape
=
None
,
global_num_experts
=
E
,
expert_map
=
None
,
activation
=
"silu"
)
\ No newline at end of file
vllm/model_executor/layers/quantization/slimquant_w4a8_marlin.py
View file @
c637d1aa
...
@@ -16,7 +16,7 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
...
@@ -16,7 +16,7 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEQuantConfig
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEQuantConfig
from
vllm.model_executor.layers.fused_moe.modular_kernel
import
(
from
vllm.model_executor.layers.fused_moe.modular_kernel
import
(
FusedMoEModularKernel
)
FusedMoEModularKernel
)
from
vllm.model_executor.layers.quantization.slimquant_w4a8
import
SlimQuantW4A8Int8LinearMethod
from
vllm.model_executor.layers.quantization.slimquant_w4a8
import
SlimQuantW4A8Int8LinearMethod
,
SlimQuantW4A8Int8AiterMoEMethod
from
vllm.model_executor.layers.fused_moe.fused_moe
import
get_moe_cache
from
vllm.model_executor.layers.fused_moe.fused_moe
import
get_moe_cache
try
:
try
:
from
lmslim.layers.fused_moe.fuse_moe_w4a8_marlin
import
fused_experts_impl_w4a8_marlin
from
lmslim.layers.fused_moe.fuse_moe_w4a8_marlin
import
fused_experts_impl_w4a8_marlin
...
@@ -111,7 +111,10 @@ class SlimQuantW4A8Int8MarlinConfig(QuantizationConfig):
...
@@ -111,7 +111,10 @@ class SlimQuantW4A8Int8MarlinConfig(QuantizationConfig):
if
isinstance
(
layer
,
LinearBase
):
if
isinstance
(
layer
,
LinearBase
):
return
SlimQuantW4A8Int8LinearMethod
(
self
)
return
SlimQuantW4A8Int8LinearMethod
(
self
)
elif
isinstance
(
layer
,
FusedMoE
):
elif
isinstance
(
layer
,
FusedMoE
):
return
SlimQuantW4A8Int8MarlinMoEMethod
(
self
,
layer
.
moe_config
)
if
envs
.
VLLM_ROCM_USE_AITER_MOE
:
return
SlimQuantW4A8Int8AiterMoEMethod
(
self
,
layer
.
moe_config
)
else
:
return
SlimQuantW4A8Int8MarlinMoEMethod
(
self
,
layer
.
moe_config
)
return
None
return
None
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
...
...
vllm/model_executor/layers/sparse_attn_indexer.py
View file @
c637d1aa
...
@@ -30,6 +30,7 @@ elif current_platform.is_xpu():
...
@@ -30,6 +30,7 @@ elif current_platform.is_xpu():
from
vllm._ipex_ops
import
ipex_ops
as
ops
from
vllm._ipex_ops
import
ipex_ops
as
ops
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
_GLOBAL_LOGITS_BUFFERS
=
{}
@
maybe_transfer_kv_layer
@
maybe_transfer_kv_layer
def
sparse_attn_indexer
(
def
sparse_attn_indexer
(
...
@@ -50,7 +51,21 @@ def sparse_attn_indexer(
...
@@ -50,7 +51,21 @@ def sparse_attn_indexer(
# careful! this will be None in dummy run
# careful! this will be None in dummy run
attn_metadata
=
get_forward_context
().
attn_metadata
attn_metadata
=
get_forward_context
().
attn_metadata
fp8_dtype
=
current_platform
.
fp8_dtype
()
fp8_dtype
=
current_platform
.
fp8_dtype
()
if
q_fp8
.
dtype
==
fp8_dtype
:
MAX_ELEMENTS
=
65536
*
65536
elif
q_fp8
.
dtype
in
(
torch
.
bfloat16
,
torch
.
float16
):
MAX_ELEMENTS
=
16384
*
32768
else
:
MAX_ELEMENTS
=
16384
*
32768
device
=
q_fp8
.
device
if
device
not
in
_GLOBAL_LOGITS_BUFFERS
or
_GLOBAL_LOGITS_BUFFERS
[
device
].
numel
()
<
MAX_ELEMENTS
:
_GLOBAL_LOGITS_BUFFERS
[
device
]
=
torch
.
empty
(
MAX_ELEMENTS
,
dtype
=
torch
.
float32
,
device
=
device
)
logits_buffer
=
_GLOBAL_LOGITS_BUFFERS
[
device
]
# assert isinstance(attn_metadata, dict)
# assert isinstance(attn_metadata, dict)
if
not
isinstance
(
attn_metadata
,
dict
):
if
not
isinstance
(
attn_metadata
,
dict
):
# Reserve workspace for indexer during profiling run
# Reserve workspace for indexer during profiling run
...
@@ -140,18 +155,21 @@ def sparse_attn_indexer(
...
@@ -140,18 +155,21 @@ def sparse_attn_indexer(
weights_all
=
weights
[
chunk
.
token_start
:
chunk
.
token_end
]
weights_all
=
weights
[
chunk
.
token_start
:
chunk
.
token_end
]
ks_all
=
chunk
.
cu_seqlen_ks
ks_all
=
chunk
.
cu_seqlen_ks
ke_all
=
chunk
.
cu_seqlen_ke
ke_all
=
chunk
.
cu_seqlen_ke
num_q
=
q_all
.
shape
[
0
]
num_q
=
q_all
.
shape
[
0
]
num_k
=
k_fp8
.
shape
[
0
]
num_k
=
k_fp8
.
shape
[
0
]
MAX_ELEMENTS
=
1024
*
1024
*
1024
# 4GB
is_q_fp16_bf16
=
q_all
.
dtype
in
(
torch
.
float16
,
torch
.
bfloat16
)
if
(
num_q
<=
65536
and
num_k
<=
65536
):
# if num_q <= 65536 and num_k <= 65536 and (num_q * num_k <= MAX_ELEMENTS):
align_size
=
128
if
is_q_fp16_bf16
else
1
MAX_Q_CHUNK
=
max
(
1
,
num_q
)
else
:
kv_seq_len_aligned
=
(
num_k
+
align_size
-
1
)
//
align_size
*
align_size
MAX_Q_CHUNK
=
max
(
1024
,
MAX_ELEMENTS
//
max
(
1
,
num_k
))
MAX_Q_CHUNK
=
min
(
MAX_Q_CHUNK
,
max
(
1
,
num_q
))
current_capacity
=
logits_buffer
.
numel
()
MAX_Q_CHUNK
=
current_capacity
//
max
(
1
,
kv_seq_len_aligned
)
if
align_size
>
1
:
MAX_Q_CHUNK
=
(
MAX_Q_CHUNK
//
align_size
)
*
align_size
MAX_Q_CHUNK
=
max
(
1
,
MAX_Q_CHUNK
)
#存储q的起始和终止地址
slices
=
[]
slices
=
[]
for
start_idx
in
range
(
0
,
num_q
,
MAX_Q_CHUNK
):
for
start_idx
in
range
(
0
,
num_q
,
MAX_Q_CHUNK
):
...
@@ -161,13 +179,19 @@ def sparse_attn_indexer(
...
@@ -161,13 +179,19 @@ def sparse_attn_indexer(
for
q_start
,
q_end
in
slices
:
for
q_start
,
q_end
in
slices
:
if
q_end
<=
q_start
:
if
q_end
<=
q_start
:
continue
continue
q_slice
=
q_all
[
q_start
:
q_end
]
q_slice
=
q_all
[
q_start
:
q_end
]
weights_slice
=
weights_all
[
q_start
:
q_end
]
weights_slice
=
weights_all
[
q_start
:
q_end
]
ks_slice
=
ks_all
[
q_start
:
q_end
]
ks_slice
=
ks_all
[
q_start
:
q_end
]
ke_slice
=
ke_all
[
q_start
:
q_end
]
ke_slice
=
ke_all
[
q_start
:
q_end
]
q_len
=
q_end
-
q_start
q_seq_len_aligned
=
(
q_len
+
align_size
-
1
)
//
align_size
*
align_size
required_size
=
q_seq_len_aligned
*
kv_seq_len_aligned
logits_slice_view
=
logits_buffer
[:
required_size
].
view
(
q_seq_len_aligned
,
kv_seq_len_aligned
)
if
not
current_platform
.
is_rocm
():
if
not
current_platform
.
is_rocm
():
logits_slice
=
fp8_mqa_logits
(
logits_slice
=
fp8_mqa_logits
(
q_slice
,
q_slice
,
...
@@ -177,40 +201,44 @@ def sparse_attn_indexer(
...
@@ -177,40 +201,44 @@ def sparse_attn_indexer(
ke_slice
,
ke_slice
,
)
)
elif
get_gcn_arch_name
()
==
"gfx938"
:
elif
get_gcn_arch_name
()
==
"gfx938"
:
logits_slice
=
op
.
mqa_logits
(
op
.
mqa_logits
(
q_slice
,
q_slice
,
k_fp8
,
k_fp8
,
weights_slice
,
weights_slice
,
ks_slice
,
ks_slice
,
ke_slice
,
ke_slice
,
q_slice
.
shape
[
0
],
q_slice
.
shape
[
0
],
k_fp8
.
shape
[
0
],
k_fp8
.
shape
[
0
],
q_slice
.
shape
[
1
],
q_slice
.
shape
[
1
],
q_slice
.
shape
[
2
],
q_slice
.
shape
[
2
],
k_scale
.
view
(
torch
.
float32
).
flatten
(),
k_scale
.
view
(
torch
.
float32
).
flatten
(),
True
True
,
logits_slice_view
)
)
logits_slice
=
logits_slice_view
[:
q_len
,
:
num_k
]
else
:
else
:
logits_slice
=
op
.
mqa_logits
(
op
.
mqa_logits
(
q_slice
,
q_slice
,
k_fp8
,
k_fp8
,
weights_slice
.
to
(
torch
.
float32
),
weights_slice
.
to
(
torch
.
float32
),
ks_slice
,
ks_slice
,
ke_slice
,
ke_slice
,
q_slice
.
shape
[
0
],
q_slice
.
shape
[
0
],
k_fp8
.
shape
[
0
],
k_fp8
.
shape
[
0
],
q_slice
.
shape
[
1
],
q_slice
.
shape
[
1
],
q_slice
.
shape
[
2
],
q_slice
.
shape
[
2
],
None
,
None
,
True
True
,
logits_slice_view
)
)
logits_slice
=
logits_slice_view
[:
q_len
,
:
num_k
]
num_rows_slice
=
logits_slice
.
shape
[
0
]
num_rows_slice
=
logits_slice
.
shape
[
0
]
topk_indices_slice
=
topk_indices_buffer
[
topk_indices_slice
=
topk_indices_buffer
[
chunk
.
token_start
+
q_start
:
chunk
.
token_start
+
q_end
,
:
topk_tokens
chunk
.
token_start
+
q_start
:
chunk
.
token_start
+
q_end
,
:
topk_tokens
]
]
if
not
envs
.
USE_LIGHTOP_TOPK
:
if
not
envs
.
USE_LIGHTOP_TOPK
:
torch
.
ops
.
_C
.
top_k_per_row_prefill
(
torch
.
ops
.
_C
.
top_k_per_row_prefill
(
logits_slice
,
logits_slice
,
...
@@ -460,6 +488,4 @@ class SparseAttnIndexer(CustomOp):
...
@@ -460,6 +488,4 @@ class SparseAttnIndexer(CustomOp):
self
.
max_model_len
,
self
.
max_model_len
,
self
.
max_total_seq_len
,
self
.
max_total_seq_len
,
self
.
topk_indices_buffer
,
self
.
topk_indices_buffer
,
)
)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment