Commit 7f74da5a authored by lixh6
Browse files

[FEATURE] 接入Aiter MoE W8A8 量化模型支持 && MQA_logits 修改 (Ref:wanghl)

parent 3842b316
...@@ -167,6 +167,7 @@ if TYPE_CHECKING: ...@@ -167,6 +167,7 @@ if TYPE_CHECKING:
VLLM_MOE_USE_DEEP_GEMM: bool = True VLLM_MOE_USE_DEEP_GEMM: bool = True
VLLM_USE_DEEP_GEMM_E8M0: bool = True VLLM_USE_DEEP_GEMM_E8M0: bool = True
VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES: bool = True VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES: bool = True
VLLM_USE_AITER_MOE_W8A8: bool = True
VLLM_DEEP_GEMM_WARMUP: Literal[ VLLM_DEEP_GEMM_WARMUP: Literal[
"skip", "skip",
"full", "full",
...@@ -1287,6 +1288,9 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1287,6 +1288,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES": lambda: bool( "VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES": lambda: bool(
int(os.getenv("VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES", "1")) int(os.getenv("VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES", "1"))
), ),
"VLLM_USE_AITER_MOE_W8A8": lambda: bool(
int(os.getenv("VLLM_USE_AITER_MOE_W8A8", "1"))
),
# DeepGemm JITs the kernels on-demand. The warmup attempts to make DeepGemm # DeepGemm JITs the kernels on-demand. The warmup attempts to make DeepGemm
# JIT all the required kernels before model execution so there is no # JIT all the required kernels before model execution so there is no
# JIT'ing in the hot-path. However, this warmup increases the engine # JIT'ing in the hot-path. However, this warmup increases the engine
......
...@@ -6,7 +6,11 @@ import functools ...@@ -6,7 +6,11 @@ import functools
import json import json
import os import os
import math import math
import sys
import aiter
from vllm._aiter_ops import rocm_aiter_ops
from aiter.moe import get_aiter_moe_config, aiter_moe, MoeQuantType, MoeSolutionType
from aiter.ops.shuffle import moe_layout_shuffle_gemm1, moe_layout_shuffle_gemm2
from collections.abc import Callable from collections.abc import Callable
from typing import Any, Callable, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional
...@@ -1858,35 +1862,73 @@ def fused_experts_impl( ...@@ -1858,35 +1862,73 @@ def fused_experts_impl(
cache13 = torch.empty(M * top_k_num * max(N, K if not use_nn_moe else w2.shape[2]), device=hidden_states.device, dtype=hidden_states.dtype) cache13 = torch.empty(M * top_k_num * max(N, K if not use_nn_moe else w2.shape[2]), device=hidden_states.device, dtype=hidden_states.dtype)
if use_int8_w8a8 or use_fp8_w8a8: if use_int8_w8a8 or use_fp8_w8a8:
return fused_experts_impl_int8(hidden_states=hidden_states, if envs.VLLM_USE_AITER_MOE_W8A8==True:
w1=w1, K_input = hidden_states.size(1)
w2=w2, actual_N2 = N // 2
topk_weights=topk_weights, quant_type = MoeQuantType.W8A8
topk_ids=topk_ids, status, moe_config = get_aiter_moe_config(
cache13=cache13, M=num_tokens,
inplace=inplace, E=global_num_experts,
activation=activation, N1=N,
apply_router_weight_on_input=apply_router_weight_on_input, N2=actual_N2,
use_fp8_w8a8=use_fp8_w8a8, K=K_input,
use_int8_w8a8=use_int8_w8a8, top_k=top_k_num,
use_int8_w8a16=False, block_size=0,
use_int4_w4a16=False, dtype=hidden_states.dtype,
per_channel_quant=per_channel_quant, quant_type=quant_type,
global_num_experts=global_num_experts, )
expert_map=expert_map,
w1_scale=w1_scale, output = aiter_moe(
w2_scale=w2_scale, hidden_states=hidden_states,
w1_zp=w1_zp, w1=w1,
w2_zp=w2_zp, w2=w2,
a1_scale=a1_scale, topk_weights=topk_weights,
a2_scale=a2_scale, topk_ids=topk_ids,
block_shape=block_shape, moe_config=moe_config,
use_nn_moe=False, inplace=inplace,
routed_scaling_factor=routed_scaling_factor, activation=activation,
shared_output=shared_output, w1_scale=w1_scale,
i_q=i_q, w2_scale=w2_scale,
i_s=i_s w1_zp=w1_zp,
) w2_zp=w2_zp,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=None,
global_num_experts=global_num_experts,
expert_map=expert_map,
routed_scaling_factor=routed_scaling_factor,
)
return output
else:
return fused_experts_impl_int8(hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
cache13=cache13,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=False,
use_int4_w4a16=False,
per_channel_quant=per_channel_quant,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
use_nn_moe=False,
routed_scaling_factor=routed_scaling_factor,
shared_output=shared_output,
i_q=i_q,
i_s=i_s
)
elif use_int4_w4a8 is True: elif use_int4_w4a8 is True:
return fused_experts_impl_w4a8(hidden_states=hidden_states, return fused_experts_impl_w4a8(hidden_states=hidden_states,
w1=w1, w1=w1,
......
...@@ -26,6 +26,14 @@ from vllm.model_executor.layers.fused_moe import ( ...@@ -26,6 +26,14 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoEPrepareAndFinalize, FusedMoEPrepareAndFinalize,
FusedMoeWeightScaleSupported, FusedMoeWeightScaleSupported,
) )
import aiter
from aiter.test_common import checkAllclose, perftest
from aiter.ops.shuffle import moe_layout_shuffle_gemm1, moe_layout_shuffle_gemm2
from aiter.fused_moe import fused_topk, torch_moe
from aiter import dtypes, ActivationType
from aiter.moe import get_aiter_moe_config, aiter_moe, MoeSolutionType, MoeQuantType
try: try:
from lmslim.layers.fused_moe.fuse_moe_int8_marlin import fused_experts_impl_int8_marlin from lmslim.layers.fused_moe.fuse_moe_int8_marlin import fused_experts_impl_int8_marlin
from lmslim.layers.fused_moe.fuse_moe_fp8_marlin import fused_experts_impl_fp8_marlin from lmslim.layers.fused_moe.fuse_moe_fp8_marlin import fused_experts_impl_fp8_marlin
...@@ -369,28 +377,48 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -369,28 +377,48 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
layer.w13_input_scale = None layer.w13_input_scale = None
layer.w2_input_scale = None layer.w2_input_scale = None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def shuffle_w8a8_gemm1(self, weight_data):
w1_marlin_list = [] w_i8 = weight_data.to(torch.int8)
for ii in range(layer.w13_weight.shape[0]): return moe_layout_shuffle_gemm1(w_i8)
if not self.use_deepep:
w1_marlin_in = get_w8a8_int8_marlin_weights(layer.w13_weight[ii])
else:
w1_marlin_in = w8a8_nt_kpack2_marlin_weight(layer.w13_weight[ii])
w1_marlin_list.append(w1_marlin_in)
w1_marlin = torch.stack(w1_marlin_list, dim=0)
del w1_marlin_list def shuffle_w8a8_gemm2(self, weight_data):
w2_marlin_list = [] w_i8 = weight_data.to(torch.int8)
for ii in range(layer.w2_weight.shape[0]): return moe_layout_shuffle_gemm2(w_i8)
if not self.use_deepep:
w2_marlin_in = get_w8a8_int8_marlin_weights(layer.w2_weight[ii])
else:
w2_marlin_in = w8a8_nt_kpack2_marlin_weight(layer.w2_weight[ii])
w2_marlin_list.append(w2_marlin_in)
w2_marlin = torch.stack(w2_marlin_list, dim=0)
layer.w13_weight = Parameter(w1_marlin, requires_grad=False) def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.w2_weight = Parameter(w2_marlin, requires_grad=False) if envs.VLLM_USE_AITER_MOE_W8A8==True:
E, N13, K = layer.w13_weight.shape
_, K_w2, N2 = layer.w2_weight.shape
layer.w13_weight_scale = Parameter(layer.w13_weight_scale.data, requires_grad=False)
layer.w2_weight_scale = Parameter(layer.w2_weight_scale.data, requires_grad=False)
shuffled_w13 = self.shuffle_w8a8_gemm1(layer.w13_weight)
layer.w13_weight = Parameter(shuffled_w13.view(*layer.w13_weight.shape), requires_grad=False)
shuffled_w2 = self.shuffle_w8a8_gemm2(layer.w2_weight)
layer.w2_weight = Parameter(shuffled_w2.view(*layer.w2_weight.shape), requires_grad=False)
else:
w1_marlin_list = []
for ii in range(layer.w13_weight.shape[0]):
if not self.use_deepep:
w1_marlin_in = get_w8a8_int8_marlin_weights(layer.w13_weight[ii])
else:
w1_marlin_in = w8a8_nt_kpack2_marlin_weight(layer.w13_weight[ii])
w1_marlin_list.append(w1_marlin_in)
w1_marlin = torch.stack(w1_marlin_list, dim=0)
del w1_marlin_list
w2_marlin_list = []
for ii in range(layer.w2_weight.shape[0]):
if not self.use_deepep:
w2_marlin_in = get_w8a8_int8_marlin_weights(layer.w2_weight[ii])
else:
w2_marlin_in = w8a8_nt_kpack2_marlin_weight(layer.w2_weight[ii])
w2_marlin_list.append(w2_marlin_in)
w2_marlin = torch.stack(w2_marlin_list, dim=0)
layer.w13_weight = Parameter(w1_marlin, requires_grad=False)
layer.w2_weight = Parameter(w2_marlin, requires_grad=False)
def apply( def apply(
self, self,
...@@ -405,31 +433,71 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -405,31 +433,71 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
routed_scaling_factor: Optional[float] = 1.0, routed_scaling_factor: Optional[float] = 1.0,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
if envs.VLLM_USE_AITER_MOE_W8A8==True:
m_flat = x.view(-1, x.shape[-1])
M = m_flat.shape[0]
E = layer.w13_weight.size(0)
K = x.size(-1)
N1 = layer.w13_weight.size(1)
topk = topk_ids.size(1)
w1_input = layer.w13_weight.view(E, N1, K)
w2_input = layer.w2_weight.view(E, K, N1 // 2)
_, moe_cfg = get_aiter_moe_config(
M=M,
E=E,
N1=N1,
N2=N1 // 2,
K=K,
top_k=topk,
block_size=0,
dtype=x.dtype,
quant_type=MoeQuantType.W8A8,
)
return fused_experts_impl_int8_marlin( output = aiter_moe(
hidden_states=x, hidden_states=x,
w1=layer.w13_weight, w1=w1_input,
w2=layer.w2_weight, w2=w2_input,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, moe_config=moe_cfg,
activation=layer.activation, inplace=False,
apply_router_weight_on_input=layer.apply_router_weight_on_input, activation=getattr(layer, "activation", "silu"),
use_int8_w8a8=True, w1_scale=layer.w13_weight_scale,
per_channel_quant=True, w2_scale=layer.w2_weight_scale,
global_num_experts=layer.global_num_experts, a1_scale=getattr(layer, "w13_input_scale", None),
expert_map=layer.expert_map, a2_scale=getattr(layer, "w2_input_scale", None),
quant_config=self.moe_quant_config, global_num_experts=E,
w1_scale=layer.w13_weight_scale, expert_map=getattr(layer, "expert_map", None),
w2_scale=layer.w2_weight_scale, routed_scaling_factor=routed_scaling_factor,
a1_scale=layer.w13_input_scale, )
a2_scale=layer.w2_input_scale, return output
use_nn_moe=False, else:
i_q=i_q, return fused_experts_impl_int8_marlin(
i_s=i_s, hidden_states=x,
shared_output=shared_output, w1=layer.w13_weight,
routed_scaling_factor=routed_scaling_factor, w2=layer.w2_weight,
) topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
use_int8_w8a8=True,
per_channel_quant=True,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
quant_config=self.moe_quant_config,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
use_nn_moe=False,
i_q=i_q,
i_s=i_s,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
)
def select_gemm_impl( def select_gemm_impl(
self, self,
......
...@@ -30,6 +30,7 @@ elif current_platform.is_xpu(): ...@@ -30,6 +30,7 @@ elif current_platform.is_xpu():
from vllm._ipex_ops import ipex_ops as ops from vllm._ipex_ops import ipex_ops as ops
logger = init_logger(__name__) logger = init_logger(__name__)
_GLOBAL_LOGITS_BUFFERS = {}
@maybe_transfer_kv_layer @maybe_transfer_kv_layer
def sparse_attn_indexer( def sparse_attn_indexer(
...@@ -50,7 +51,21 @@ def sparse_attn_indexer( ...@@ -50,7 +51,21 @@ def sparse_attn_indexer(
# careful! this will be None in dummy run # careful! this will be None in dummy run
attn_metadata = get_forward_context().attn_metadata attn_metadata = get_forward_context().attn_metadata
fp8_dtype = current_platform.fp8_dtype() fp8_dtype = current_platform.fp8_dtype()
if q_fp8.dtype == fp8_dtype:
MAX_ELEMENTS = 65536 * 65536
elif q_fp8.dtype in (torch.bfloat16, torch.float16):
MAX_ELEMENTS = 16384 * 32768
else:
MAX_ELEMENTS = 16384 * 32768
device = q_fp8.device
if device not in _GLOBAL_LOGITS_BUFFERS or _GLOBAL_LOGITS_BUFFERS[device].numel() < MAX_ELEMENTS:
_GLOBAL_LOGITS_BUFFERS[device] = torch.empty(
MAX_ELEMENTS,
dtype=torch.float32,
device=device
)
logits_buffer = _GLOBAL_LOGITS_BUFFERS[device]
# assert isinstance(attn_metadata, dict) # assert isinstance(attn_metadata, dict)
if not isinstance(attn_metadata, dict): if not isinstance(attn_metadata, dict):
# Reserve workspace for indexer during profiling run # Reserve workspace for indexer during profiling run
...@@ -75,7 +90,14 @@ def sparse_attn_indexer( ...@@ -75,7 +90,14 @@ def sparse_attn_indexer(
) )
attn_metadata = attn_metadata[layer_name] attn_metadata = attn_metadata[layer_name]
assert isinstance(attn_metadata, DeepseekV32IndexerMetadata) assert isinstance(attn_metadata, DeepseekV32IndexerMetadata)
slot_mapping = attn_metadata.slot_mapping # slot_mapping = attn_metadata.slot_mapping[:attn_metadata.num_kv_actual_tokens]
if hasattr(attn_metadata, 'num_kv_actual_tokens'):
num_kv_tokens = attn_metadata.num_kv_actual_tokens
elif hasattr(attn_metadata, 'num_prefills') and attn_metadata.num_prefills > 0:
num_kv_tokens = getattr(attn_metadata, 'num_prefill_tokens', attn_metadata.slot_mapping.shape[0])
else:
num_kv_tokens = attn_metadata.slot_mapping.shape[0]
slot_mapping = attn_metadata.slot_mapping[:num_kv_tokens]
has_decode = attn_metadata.num_decodes > 0 has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0 has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens num_decode_tokens = attn_metadata.num_decode_tokens
...@@ -116,14 +138,6 @@ def sparse_attn_indexer( ...@@ -116,14 +138,6 @@ def sparse_attn_indexer(
chunk.block_table, chunk.block_table,
chunk.cu_seq_lens, chunk.cu_seq_lens,
) )
logits = fp8_mqa_logits(
q_fp8[chunk.token_start : chunk.token_end],
(k_fp8, k_scale.view(torch.float32).flatten()),
weights[chunk.token_start : chunk.token_end],
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
)
elif get_gcn_arch_name() == "gfx938": elif get_gcn_arch_name() == "gfx938":
k_fp8 = k_fp8_full[: chunk.total_seq_lens] k_fp8 = k_fp8_full[: chunk.total_seq_lens]
k_scale = k_scale_full[: chunk.total_seq_lens] k_scale = k_scale_full[: chunk.total_seq_lens]
...@@ -134,19 +148,6 @@ def sparse_attn_indexer( ...@@ -134,19 +148,6 @@ def sparse_attn_indexer(
chunk.block_table, chunk.block_table,
chunk.cu_seq_lens, chunk.cu_seq_lens,
) )
logits = op.mqa_logits(
q_fp8[chunk.token_start:chunk.token_end],
k_fp8,
weights[chunk.token_start:chunk.token_end],
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
q_fp8[chunk.token_start:chunk.token_end].shape[0],
k_fp8.shape[0],
q_fp8.shape[1],
q_fp8.shape[2],
k_scale.view(torch.float32).flatten(),
True
)
else: else:
k_fp8 = k_fp8_full[: chunk.total_seq_lens] k_fp8 = k_fp8_full[: chunk.total_seq_lens]
k_scale = k_scale_full[: chunk.total_seq_lens] k_scale = k_scale_full[: chunk.total_seq_lens]
...@@ -156,46 +157,117 @@ def sparse_attn_indexer( ...@@ -156,46 +157,117 @@ def sparse_attn_indexer(
chunk.block_table, chunk.block_table,
chunk.cu_seq_lens, chunk.cu_seq_lens,
) )
logits = op.mqa_logits(
q_fp8[chunk.token_start:chunk.token_end], q_all = q_fp8[chunk.token_start:chunk.token_end]
k_fp8, weights_all = weights[chunk.token_start:chunk.token_end]
weights[chunk.token_start:chunk.token_end].to(torch.float32), ks_all = chunk.cu_seqlen_ks
chunk.cu_seqlen_ks, ke_all = chunk.cu_seqlen_ke
chunk.cu_seqlen_ke,
q_fp8[chunk.token_start:chunk.token_end].shape[0], num_q = q_all.shape[0]
k_fp8.shape[0], num_k = k_fp8.shape[0]
q_fp8.shape[1],
q_fp8.shape[2], is_q_fp16_bf16 = q_all.dtype in (torch.float16, torch.bfloat16)
None, align_size = 128 if is_q_fp16_bf16 else 1
True
kv_seq_len_aligned = (num_k + align_size - 1) // align_size * align_size
current_capacity = logits_buffer.numel()
MAX_Q_CHUNK = current_capacity // max(1, kv_seq_len_aligned)
if align_size > 1:
MAX_Q_CHUNK = (MAX_Q_CHUNK // align_size) * align_size
MAX_Q_CHUNK = max(1, MAX_Q_CHUNK)
slices = []
for start_idx in range(0, num_q, MAX_Q_CHUNK):
end_idx = min(start_idx + MAX_Q_CHUNK, num_q)
slices.append((start_idx, end_idx))
for q_start, q_end in slices:
if q_end <= q_start:
continue
q_slice = q_all[q_start:q_end]
weights_slice = weights_all[q_start:q_end]
ks_slice = ks_all[q_start:q_end]
ke_slice = ke_all[q_start:q_end]
q_len = q_end - q_start
q_seq_len_aligned = (q_len + align_size - 1) // align_size * align_size
required_size = q_seq_len_aligned * kv_seq_len_aligned
logits_slice_view = logits_buffer[:required_size].view(q_seq_len_aligned, kv_seq_len_aligned)
if not current_platform.is_rocm():
logits_slice = fp8_mqa_logits(
q_slice,
(k_fp8, k_scale.view(torch.float32).flatten()),
weights_slice,
ks_slice,
ke_slice,
)
elif get_gcn_arch_name() == "gfx938":
op.mqa_logits(
q_slice,
k_fp8,
weights_slice,
ks_slice,
ke_slice,
q_slice.shape[0],
k_fp8.shape[0],
q_slice.shape[1],
q_slice.shape[2],
k_scale.view(torch.float32).flatten(),
True,
logits_slice_view
)
logits_slice = logits_slice_view[:q_len, :num_k]
else:
op.mqa_logits(
q_slice,
k_fp8,
weights_slice.to(torch.float32),
ks_slice,
ke_slice,
q_slice.shape[0],
k_fp8.shape[0],
q_slice.shape[1],
q_slice.shape[2],
None,
True,
logits_slice_view
)
logits_slice = logits_slice_view[:q_len, :num_k]
num_rows_slice = logits_slice.shape[0]
topk_indices_slice = topk_indices_buffer[
chunk.token_start + q_start : chunk.token_start + q_end, :topk_tokens
]
if not envs.USE_LIGHTOP_TOPK:
torch.ops._C.top_k_per_row_prefill(
logits_slice,
ks_slice,
ke_slice,
topk_indices_slice,
num_rows_slice,
logits_slice.stride(0),
logits_slice.stride(1),
topk_tokens,
)
else:
op.top_k_per_row_prefill(
logits_slice,
ks_slice,
ke_slice,
topk_indices_slice,
num_rows_slice,
logits_slice.stride(0),
logits_slice.stride(1),
topk_tokens,
) )
num_rows = logits.shape[0]
topk_indices = topk_indices_buffer[
chunk.token_start : chunk.token_end, :topk_tokens
]
if not envs.USE_LIGHTOP_TOPK:
torch.ops._C.top_k_per_row_prefill(
logits,
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
topk_indices,
num_rows,
logits.stride(0),
logits.stride(1),
topk_tokens,
)
else:
op.top_k_per_row_prefill(
logits,
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
topk_indices,
num_rows,
logits.stride(0),
logits.stride(1),
topk_tokens,
)
if has_decode: if has_decode:
decode_metadata = attn_metadata.decode decode_metadata = attn_metadata.decode
...@@ -423,6 +495,4 @@ class SparseAttnIndexer(CustomOp): ...@@ -423,6 +495,4 @@ class SparseAttnIndexer(CustomOp):
self.max_model_len, self.max_model_len,
self.max_total_seq_len, self.max_total_seq_len,
self.topk_indices_buffer, self.topk_indices_buffer,
) )
\ No newline at end of file
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment