Commit 561b6cbb authored by 王敏's avatar 王敏
Browse files

merge dev主干代码

parents 0beafe40 ce47a56e
...@@ -7,7 +7,7 @@ requests >= 2.26.0 ...@@ -7,7 +7,7 @@ requests >= 2.26.0
tqdm tqdm
blake3 blake3
py-cpuinfo py-cpuinfo
transformers >= 4.56.0, < 5 transformers == 5.2.0
tokenizers >= 0.21.1 # Required for fast incremental detokenization. tokenizers >= 0.21.1 # Required for fast incremental detokenization.
protobuf >= 6.33.5 # Required by LlamaTokenizer, gRPC. CVE-2026-0994 protobuf >= 6.33.5 # Required by LlamaTokenizer, gRPC. CVE-2026-0994
fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint. fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint.
......
...@@ -26,7 +26,7 @@ fastrlock==0.8.3 ...@@ -26,7 +26,7 @@ fastrlock==0.8.3
torch == 2.9.0 torch == 2.9.0
triton == 3.3.0 triton == 3.3.0
flash_attn == 2.6.1 flash_attn == 2.8.3
flash_mla == 1.0.0 flash_mla == 1.0.0
lightop == 0.6.0 lightop == 0.6.0
lmslim == 0.3.1 lmslim == 0.3.1
...@@ -370,7 +370,8 @@ def rms_norm_opt_fake( ...@@ -370,7 +370,8 @@ def rms_norm_opt_fake(
def fused_add_rms_norm_opt(input: torch.Tensor, residual: torch.Tensor, def fused_add_rms_norm_opt(input: torch.Tensor, residual: torch.Tensor,
weight: torch.Tensor, epsilon: float, training: Optional[bool]=False, inplace: Optional[bool]=True) -> None: weight: torch.Tensor, epsilon: float, training: Optional[bool]=False, inplace: Optional[bool]=True) -> None:
op.rn_add_forward_autograd(input, residual, weight, epsilon, training, inplace) op.fused_add_rms_norm_opt(input, residual, weight, epsilon)
#op.rn_add_forward_autograd(input, residual, weight, epsilon, training, inplace)
def fused_add_rms_norm_opt_fake( def fused_add_rms_norm_opt_fake(
input: torch.Tensor, input: torch.Tensor,
...@@ -379,8 +380,8 @@ def fused_add_rms_norm_opt_fake( ...@@ -379,8 +380,8 @@ def fused_add_rms_norm_opt_fake(
epsilon: float, epsilon: float,
training: Optional[bool] = False, training: Optional[bool] = False,
inplace: Optional[bool] = False, inplace: Optional[bool] = False,
) -> torch.Tensor: ) -> None:
return torch.empty_like(input) return None
def fused_qk_norm_rope( def fused_qk_norm_rope(
qkv: torch.Tensor, qkv: torch.Tensor,
...@@ -3626,7 +3627,7 @@ direct_register_custom_op( ...@@ -3626,7 +3627,7 @@ direct_register_custom_op(
direct_register_custom_op( direct_register_custom_op(
op_name="fused_add_rms_norm_opt", op_name="fused_add_rms_norm_opt",
op_func=fused_add_rms_norm_opt, op_func=fused_add_rms_norm_opt,
mutates_args=[], mutates_args=["input", "residual"],
fake_impl=fused_add_rms_norm_opt_fake, fake_impl=fused_add_rms_norm_opt_fake,
) )
......
...@@ -245,7 +245,6 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -245,7 +245,6 @@ class Attention(nn.Module, AttentionLayerBase):
use_mla=False, use_mla=False,
has_sink=self.has_sink, has_sink=self.has_sink,
use_mm_prefix=self.use_mm_prefix, use_mm_prefix=self.use_mm_prefix,
use_alibi_sqrt=bool(use_alibi_sqrt),
attn_type=attn_type, attn_type=attn_type,
) )
else: else:
...@@ -1274,4 +1273,4 @@ direct_register_custom_op( ...@@ -1274,4 +1273,4 @@ direct_register_custom_op(
mutates_args=["qkv", "positions"], mutates_args=["qkv", "positions"],
fake_impl=fused_qkv_split_rmsnorm_rope_kv_store_fake, fake_impl=fused_qkv_split_rmsnorm_rope_kv_store_fake,
tags=(torch.Tag.needs_fixed_stride_order,), tags=(torch.Tag.needs_fixed_stride_order,),
) )
\ No newline at end of file
...@@ -323,10 +323,9 @@ if TYPE_CHECKING: ...@@ -323,10 +323,9 @@ if TYPE_CHECKING:
USE_LIGHTOP_PER_TOKEN_GROUP_QUANT_FP8: bool = False USE_LIGHTOP_PER_TOKEN_GROUP_QUANT_FP8: bool = False
USE_LIGHTOP_TOPK: bool = False USE_LIGHTOP_TOPK: bool = False
USE_LIGHTOP_CONVERT_REQ_INDEX_TO_GLOBAL_INDEX: bool = False USE_LIGHTOP_CONVERT_REQ_INDEX_TO_GLOBAL_INDEX: bool = False
VLLM_DISABLE_DSA: bool = False
VLLM_MLA_CP: bool = False VLLM_MLA_CP: bool = False
VLLM_MLA_CPLB: bool = False VLLM_MLA_CPLB: bool = False
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
"XDG_CACHE_HOME", "XDG_CACHE_HOME",
...@@ -2005,16 +2004,18 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -2005,16 +2004,18 @@ environment_variables: dict[str, Callable[[], Any]] = {
("true", "1")), ("true", "1")),
"USE_LIGHTOP_CONVERT_REQ_INDEX_TO_GLOBAL_INDEX": "USE_LIGHTOP_CONVERT_REQ_INDEX_TO_GLOBAL_INDEX":
lambda: (os.environ.get("USE_LIGHTOP_CONVERT_REQ_INDEX_TO_GLOBAL_INDEX", "False").lower() in lambda: (os.environ.get("USE_LIGHTOP_CONVERT_REQ_INDEX_TO_GLOBAL_INDEX", "False").lower() in
("true", "1")), ("true", "1")),
#If set to 1/True, disenable the DSA.
# If set to 1/True, enable mla context parallel "VLLM_DISABLE_DSA":
lambda: (os.environ.get("VLLM_DISABLE_DSA", "False").lower() in
("true", "1")),
# If set to 1/True, enable mla context parallel
"VLLM_MLA_CP": "VLLM_MLA_CP":
lambda: (os.environ.get("VLLM_MLA_CP", "False").lower() in lambda: (os.environ.get("VLLM_MLA_CP", "False").lower() in
("true", "1")), ("true", "1")),
"VLLM_MLA_CPLB": "VLLM_MLA_CPLB":
lambda: (os.environ.get("VLLM_MLA_CPLB", "False").lower() in lambda: (os.environ.get("VLLM_MLA_CPLB", "False").lower() in
("true", "1")), ("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -98,6 +98,8 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): ...@@ -98,6 +98,8 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
use_nn_moe: bool | None = False, use_nn_moe: bool | None = False,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self.fused_experts( return self.fused_experts(
hidden_states=x, hidden_states=x,
...@@ -110,4 +112,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): ...@@ -110,4 +112,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
apply_router_weight_on_input=layer.apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
expert_map=None if self.disable_expert_map else layer.expert_map, expert_map=None if self.disable_expert_map else layer.expert_map,
) use_nn_moe=use_nn_moe,
\ No newline at end of file shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
)
...@@ -735,6 +735,8 @@ class FusedMoEPermuteExpertsUnpermute(ABC): ...@@ -735,6 +735,8 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_tokens_meta: ExpertTokensMetadata | None, expert_tokens_meta: ExpertTokensMetadata | None,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> None: ) -> None:
""" """
This function computes the intermediate result of a Mixture of Experts This function computes the intermediate result of a Mixture of Experts
...@@ -1155,6 +1157,8 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1155,6 +1157,8 @@ class FusedMoEModularKernel(torch.nn.Module):
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
expert_tokens_meta: ExpertTokensMetadata | None, expert_tokens_meta: ExpertTokensMetadata | None,
use_nn_moe: bool | None = False, use_nn_moe: bool | None = False,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> torch.Tensor: ) -> torch.Tensor:
_, M_full, N, K, top_k = self.fused_experts.moe_problem_size( _, M_full, N, K, top_k = self.fused_experts.moe_problem_size(
a1q, w1, w2, topk_ids a1q, w1, w2, topk_ids
...@@ -1216,7 +1220,13 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1216,7 +1220,13 @@ class FusedMoEModularKernel(torch.nn.Module):
c_fused_out = self._slice_output_tensor( c_fused_out = self._slice_output_tensor(
fused_out, chunk_idx, num_chunks, CHUNK_SIZE, M_full fused_out, chunk_idx, num_chunks, CHUNK_SIZE, M_full
) )
c_shared_output = (
None
if shared_output is None
else self._slice_output_tensor(
shared_output, chunk_idx, num_chunks, CHUNK_SIZE, M_full
)
)
self.fused_experts.apply( self.fused_experts.apply(
output=c_fused_out, output=c_fused_out,
hidden_states=a1q[s:e], hidden_states=a1q[s:e],
...@@ -1234,6 +1244,8 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1234,6 +1244,8 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_tokens_meta=c_expert_tokens_meta, expert_tokens_meta=c_expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
use_nn_moe=use_nn_moe, use_nn_moe=use_nn_moe,
shared_output=c_shared_output,
routed_scaling_factor=routed_scaling_factor,
) )
return fused_out return fused_out
...@@ -1246,13 +1258,12 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1246,13 +1258,12 @@ class FusedMoEModularKernel(torch.nn.Module):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
shared_output: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
""" """
The _finalize method is a wrapper around self.prepare_finalize.finalize The _finalize method is a wrapper around self.prepare_finalize.finalize
that handles DBO, async and shared expert overlap. that handles DBO, async and shared expert overlap.
""" """
shared_output: torch.Tensor | None = None
if not self.prepare_finalize.supports_async(): if not self.prepare_finalize.supports_async():
assert not dbo_enabled() assert not dbo_enabled()
...@@ -1264,11 +1275,11 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1264,11 +1275,11 @@ class FusedMoEModularKernel(torch.nn.Module):
apply_router_weight_on_input, apply_router_weight_on_input,
self.fused_experts.finalize_weight_and_reduce_impl(), self.fused_experts.finalize_weight_and_reduce_impl(),
) )
if self.shared_experts is not None: if shared_output is None and self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states) shared_output = self.shared_experts(hidden_states)
else: else:
self.alt_event.record() self.alt_event.record()
if self.shared_experts is not None: if shared_output is None and self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states) shared_output = self.shared_experts(hidden_states)
current_stream = torch.cuda.current_stream() current_stream = torch.cuda.current_stream()
...@@ -1327,6 +1338,8 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1327,6 +1338,8 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_map: torch.Tensor | None = None, expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
use_nn_moe: bool | None = False, use_nn_moe: bool | None = False,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
""" """
This function computes a Mixture of Experts (MoE) layer using two sets This function computes a Mixture of Experts (MoE) layer using two sets
...@@ -1389,6 +1402,8 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1389,6 +1402,8 @@ class FusedMoEModularKernel(torch.nn.Module):
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
expert_tokens_meta=expert_tokens_meta, expert_tokens_meta=expert_tokens_meta,
use_nn_moe=use_nn_moe, use_nn_moe=use_nn_moe,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
) )
return self._finalize( return self._finalize(
...@@ -1398,4 +1413,5 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1398,4 +1413,5 @@ class FusedMoEModularKernel(torch.nn.Module):
topk_weights, topk_weights,
topk_ids, topk_ids,
apply_router_weight_on_input, apply_router_weight_on_input,
shared_output=shared_output,
) )
\ No newline at end of file
...@@ -57,9 +57,8 @@ def fused_add_rms_norm( ...@@ -57,9 +57,8 @@ def fused_add_rms_norm(
return rms_norm_batch_invariant( return rms_norm_batch_invariant(
x + residual, weight, variance_epsilon x + residual, weight, variance_epsilon
), x + residual ), x + residual
# if envs.VLLM_USE_OPT_OP: if envs.VLLM_USE_OPT_OP:
if False: torch.ops.vllm.fused_add_rms_norm_opt(
ops.fused_add_rms_norm_opt(
x, x,
residual, residual,
weight, weight,
......
...@@ -271,7 +271,7 @@ class UnquantizedLinearMethod(LinearMethodBase): ...@@ -271,7 +271,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: torch.Tensor | None = None, bias: torch.Tensor | None = None, **_
) -> torch.Tensor: ) -> torch.Tensor:
if self.use_llama_nn: if self.use_llama_nn:
# if os.environ['GEMM_PAD'] == '1' and gemm_bank_conf(layer.weight.shape[1] - 32): # if os.environ['GEMM_PAD'] == '1' and gemm_bank_conf(layer.weight.shape[1] - 32):
...@@ -458,11 +458,15 @@ class ReplicatedLinear(LinearBase): ...@@ -458,11 +458,15 @@ class ReplicatedLinear(LinearBase):
def forward( def forward(
self, self,
x: torch.Tensor, x: torch.Tensor,
*,
iqis: tuple[torch.Tensor, torch.Tensor] | None = None
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]: ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
bias = self.bias if not self.skip_bias_add else None bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None assert self.quant_method is not None
if envs.USE_FUSED_RMS_QUANT and iqis is not None and iqis[0] is not None:
output = self.quant_method.apply(self, x, bias) output = self.quant_method.apply(self, x, bias, input_quant_args=iqis)
else:
output = self.quant_method.apply(self, x, bias)
if not self.return_bias: if not self.return_bias:
return output return output
......
...@@ -181,9 +181,10 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer): ...@@ -181,9 +181,10 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
) )
if self.indexer and self.is_sparse: if self.indexer and self.is_sparse:
_topk_indices = self.indexer( if envs.USE_FUSED_RMS_QUANT and iqis is not None:
hidden_states, q_c, positions, self.indexer_rope_emb _topk_indices = self.indexer(hidden_states, q_c, positions, self.indexer_rope_emb, iqis=iqis)
) else:
_topk_indices = self.indexer(hidden_states, q_c, positions, self.indexer_rope_emb)
if llama_4_scaling is not None: if llama_4_scaling is not None:
q *= llama_4_scaling q *= llama_4_scaling
......
...@@ -31,7 +31,6 @@ elif current_platform.is_xpu(): ...@@ -31,7 +31,6 @@ elif current_platform.is_xpu():
logger = init_logger(__name__) logger = init_logger(__name__)
@maybe_transfer_kv_layer @maybe_transfer_kv_layer
def sparse_attn_indexer( def sparse_attn_indexer(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -141,18 +140,18 @@ def sparse_attn_indexer( ...@@ -141,18 +140,18 @@ def sparse_attn_indexer(
weights_all = weights[chunk.token_start:chunk.token_end] weights_all = weights[chunk.token_start:chunk.token_end]
ks_all = chunk.cu_seqlen_ks ks_all = chunk.cu_seqlen_ks
ke_all = chunk.cu_seqlen_ke ke_all = chunk.cu_seqlen_ke
num_q = q_all.shape[0] num_q = q_all.shape[0]
num_k = k_fp8.shape[0] num_k = k_fp8.shape[0]
MAX_ELEMENTS = 1024 * 1024 * 1024 # 4GB MAX_ELEMENTS = 1024 * 1024 * 1024 # 4GB
if (num_q <= 65536 and num_k <= 65536): # if num_q <= 65536 and num_k <= 65536 and (num_q * num_k <= MAX_ELEMENTS): if (num_q <= 65536 and num_k <= 65536): # if num_q <= 65536 and num_k <= 65536 and (num_q * num_k <= MAX_ELEMENTS):
MAX_Q_CHUNK = max(1, num_q) MAX_Q_CHUNK = max(1, num_q)
else: else:
MAX_Q_CHUNK = max(1024, MAX_ELEMENTS // max(1, num_k)) MAX_Q_CHUNK = max(1024, MAX_ELEMENTS // max(1, num_k))
MAX_Q_CHUNK = min(MAX_Q_CHUNK, max(1, num_q)) MAX_Q_CHUNK = min(MAX_Q_CHUNK, max(1, num_q))
#存储q的起始和终止地址 #存储q的起始和终止地址
slices = [] slices = []
for start_idx in range(0, num_q, MAX_Q_CHUNK): for start_idx in range(0, num_q, MAX_Q_CHUNK):
...@@ -162,7 +161,7 @@ def sparse_attn_indexer( ...@@ -162,7 +161,7 @@ def sparse_attn_indexer(
for q_start, q_end in slices: for q_start, q_end in slices:
if q_end <= q_start: if q_end <= q_start:
continue continue
q_slice = q_all[q_start:q_end] q_slice = q_all[q_start:q_end]
weights_slice = weights_all[q_start:q_end] weights_slice = weights_all[q_start:q_end]
...@@ -179,10 +178,10 @@ def sparse_attn_indexer( ...@@ -179,10 +178,10 @@ def sparse_attn_indexer(
) )
elif get_gcn_arch_name() == "gfx938": elif get_gcn_arch_name() == "gfx938":
logits_slice = op.mqa_logits( logits_slice = op.mqa_logits(
q_slice, q_slice,
k_fp8, k_fp8,
weights_slice, weights_slice,
ks_slice, ks_slice,
ke_slice, ke_slice,
q_slice.shape[0], q_slice.shape[0],
k_fp8.shape[0], k_fp8.shape[0],
...@@ -193,10 +192,10 @@ def sparse_attn_indexer( ...@@ -193,10 +192,10 @@ def sparse_attn_indexer(
) )
else: else:
logits_slice = op.mqa_logits( logits_slice = op.mqa_logits(
q_slice, q_slice,
k_fp8, k_fp8,
weights_slice.to(torch.float32), weights_slice.to(torch.float32),
ks_slice, ks_slice,
ke_slice, ke_slice,
q_slice.shape[0], q_slice.shape[0],
k_fp8.shape[0], k_fp8.shape[0],
...@@ -207,11 +206,11 @@ def sparse_attn_indexer( ...@@ -207,11 +206,11 @@ def sparse_attn_indexer(
) )
num_rows_slice = logits_slice.shape[0] num_rows_slice = logits_slice.shape[0]
topk_indices_slice = topk_indices_buffer[ topk_indices_slice = topk_indices_buffer[
chunk.token_start + q_start : chunk.token_start + q_end, :topk_tokens chunk.token_start + q_start : chunk.token_start + q_end, :topk_tokens
] ]
if not envs.USE_LIGHTOP_TOPK: if not envs.USE_LIGHTOP_TOPK:
torch.ops._C.top_k_per_row_prefill( torch.ops._C.top_k_per_row_prefill(
logits_slice, logits_slice,
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
from copy import deepcopy from copy import deepcopy
from math import lcm from math import lcm
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from vllm import envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models import ModelRegistry
...@@ -554,7 +555,8 @@ class DeepseekV32ForCausalLM(VerifyAndUpdateConfig): ...@@ -554,7 +555,8 @@ class DeepseekV32ForCausalLM(VerifyAndUpdateConfig):
# For DeepSeekV3.2, a custom fp8 format is used when fp8 kv-cache is enabled. # For DeepSeekV3.2, a custom fp8 format is used when fp8 kv-cache is enabled.
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
if cache_config.cache_dtype.startswith("fp8"): force_disable_dsa = envs.VLLM_DISABLE_DSA
if cache_config.cache_dtype.startswith("fp8") and not force_disable_dsa:
cache_config.cache_dtype = "fp8_ds_mla" cache_config.cache_dtype = "fp8_ds_mla"
logger.info("Using custom fp8 kv-cache format for DeepSeekV3.2") logger.info("Using custom fp8 kv-cache format for DeepSeekV3.2")
if cache_config.cache_dtype == "bfloat16": if cache_config.cache_dtype == "bfloat16":
......
...@@ -84,7 +84,7 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module): ...@@ -84,7 +84,7 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
self.device = current_platform.device_type self.device = current_platform.device_type
#添加判断,默认开启DSA #添加判断,默认开启DSA
force_disable_dsa = os.environ.get("VLLM_DISABLE_DSA", "0") == "1" force_disable_dsa = envs.VLLM_DISABLE_DSA
self.is_v32 = hasattr(config, "index_topk") and not force_disable_dsa self.is_v32 = hasattr(config, "index_topk") and not force_disable_dsa
if self.is_v32: if self.is_v32:
topk_tokens = config.index_topk topk_tokens = config.index_topk
......
...@@ -815,15 +815,18 @@ class Indexer(nn.Module): ...@@ -815,15 +815,18 @@ class Indexer(nn.Module):
) )
def forward( def forward(
self, hidden_states: torch.Tensor, qr: torch.Tensor, positions, rotary_emb self, hidden_states: torch.Tensor, qr: torch.Tensor, positions, rotary_emb,
*, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
) -> torch.Tensor: ) -> torch.Tensor:
q, _ = self.wq_b(qr) q, _ = self.wq_b(qr)
q = q.view(-1, self.n_head, self.head_dim) q = q.view(-1, self.n_head, self.head_dim)
q_pe, q_nope = torch.split( q_pe, q_nope = torch.split(
q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1 q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1
) )
if envs.USE_FUSED_RMS_QUANT and self.wk.weight.dtype == torch.int8 and iqis is not None:
k, _ = self.wk(hidden_states) k, _ = self.wk(hidden_states, iqis=iqis)
else:
k, _ = self.wk(hidden_states)
k = self.k_norm(k) k = self.k_norm(k)
k_pe, k_nope = torch.split( k_pe, k_nope = torch.split(
k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1 k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1
...@@ -861,7 +864,10 @@ class Indexer(nn.Module): ...@@ -861,7 +864,10 @@ class Indexer(nn.Module):
else: else:
q_fp8 = q q_fp8 = q
weights, _ = self.weights_proj(hidden_states) if envs.USE_FUSED_RMS_QUANT and self.weights_proj.weight.dtype == torch.int8 and iqis is not None:
weights, _ = self.weights_proj(hidden_states, iqis=iqis)
else:
weights, _ = self.weights_proj(hidden_states)
if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938": if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
weights = ( weights = (
weights.unsqueeze(-1) * q_scale * self.softmax_scale * self.n_head**-0.5 weights.unsqueeze(-1) * q_scale * self.softmax_scale * self.n_head**-0.5
...@@ -997,7 +1003,7 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -997,7 +1003,7 @@ class DeepseekV2MLAAttention(nn.Module):
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.scaling = self.scaling * mscale * mscale self.scaling = self.scaling * mscale * mscale
#添加判断,默认开启DSA #添加判断,默认开启DSA
force_disable_dsa = os.environ.get("VLLM_DISABLE_DSA", "0") == "1" force_disable_dsa = envs.VLLM_DISABLE_DSA
self.is_v32 = hasattr(config, "index_topk") and not force_disable_dsa self.is_v32 = hasattr(config, "index_topk") and not force_disable_dsa
if self.is_v32: if self.is_v32:
...@@ -1169,19 +1175,21 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1169,19 +1175,21 @@ class DeepseekV2DecoderLayer(nn.Module):
# Fix residual FP16 overflow # Fix residual FP16 overflow
residual_fix_overflow = False residual_fix_overflow = False
assert self.input_layernorm.has_weight is True assert self.input_layernorm.has_weight is True
# DSA should set update_input True
_dsa_flag = hasattr(self.self_attn, "indexer") and self.self_attn.indexer is not None
if residual is None: if residual is None:
residual = hidden_states.clone() residual = hidden_states.clone()
i_q, i_s, _ = self.input_layernorm(x=hidden_states, i_q, i_s, _ = self.input_layernorm(x=hidden_states,
residual=None, residual=None,
quant_dtype=torch.int8, quant_dtype=torch.int8,
update_input=False update_input=_dsa_flag
) )
residual_fix_overflow = True residual_fix_overflow = True
else: else:
i_q, i_s, residual = self.input_layernorm(x=hidden_states, i_q, i_s, residual = self.input_layernorm(x=hidden_states,
residual=residual, residual=residual,
quant_dtype=torch.int8, quant_dtype=torch.int8,
update_input=False update_input=_dsa_flag
) )
attn_kwargs = { attn_kwargs = {
"positions": positions, "positions": positions,
...@@ -1318,7 +1326,7 @@ class DeepseekV2Model(nn.Module): ...@@ -1318,7 +1326,7 @@ class DeepseekV2Model(nn.Module):
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
#添加判断,默认开启DSA #添加判断,默认开启DSA
force_disable_dsa = os.environ.get("VLLM_DISABLE_DSA", "0") == "1" force_disable_dsa = envs.VLLM_DISABLE_DSA
self.is_v32 = hasattr(config, "index_topk") and not force_disable_dsa self.is_v32 = hasattr(config, "index_topk") and not force_disable_dsa
if self.is_v32: if self.is_v32:
......
...@@ -262,7 +262,6 @@ class RocmPlatform(Platform): ...@@ -262,7 +262,6 @@ class RocmPlatform(Platform):
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
block_size = attn_selector_config.block_size block_size = attn_selector_config.block_size
head_size = attn_selector_config.head_size
kv_cache_dtype = attn_selector_config.kv_cache_dtype kv_cache_dtype = attn_selector_config.kv_cache_dtype
if attn_selector_config.use_sparse: if attn_selector_config.use_sparse:
...@@ -305,36 +304,9 @@ class RocmPlatform(Platform): ...@@ -305,36 +304,9 @@ class RocmPlatform(Platform):
f"is not MLA type while requested for MLA backend." f"is not MLA type while requested for MLA backend."
) )
is_non64_block_multiple_64 = (
block_size != 64 if envs.VLLM_USE_FLASH_ATTN_PA and block_size == 64:
and block_size % 64 == 0 logger.info_once("Using Flash Attention backend on V1 engine. (only supports block size 64)")
)
use_unified_flash = (
is_non64_block_multiple_64
and head_size == 256
)
if (
envs.VLLM_USE_FLASH_ATTN_PA
and is_non64_block_multiple_64
and head_size != 256
):
logger.info_once(
"Skip unified varlen kernel on V1 engine: head size %d is "
"unsupported (requires 256).",
head_size,
)
if envs.VLLM_USE_FLASH_ATTN_PA and (block_size == 64 or use_unified_flash):
if use_unified_flash and block_size != 64:
logger.info_once(
"Using Flash Attention backend with unified varlen kernel on "
"V1 engine. (block size %d, requires block size divisible by 64)",
block_size,
)
else:
logger.info_once(
"Using Flash Attention backend on V1 engine. "
"(only supports block size 64)"
)
return AttentionBackendEnum.FLASH_ATTN.get_path() return AttentionBackendEnum.FLASH_ATTN.get_path()
else: else:
os.environ['VLLM_USE_FLASH_ATTN_PA'] = '0' os.environ['VLLM_USE_FLASH_ATTN_PA'] = '0'
......
...@@ -243,7 +243,10 @@ class Qwen3CoderToolParser(ToolParser): ...@@ -243,7 +243,10 @@ class Qwen3CoderToolParser(ToolParser):
self, function_call_str: str, tools: list[ChatCompletionToolsParam] | None self, function_call_str: str, tools: list[ChatCompletionToolsParam] | None
) -> ToolCall | None: ) -> ToolCall | None:
# Extract function name # Extract function name
end_index = function_call_str.index(">") end_index = function_call_str.find(">")
# If there's no ">" character, this is not a valid xml function call
if end_index == -1:
return None
function_name = function_call_str[:end_index] function_name = function_call_str[:end_index]
param_config = self._get_arguments_config(function_name, tools) param_config = self._get_arguments_config(function_name, tools)
parameters = function_call_str[end_index + 1 :] parameters = function_call_str[end_index + 1 :]
...@@ -327,10 +330,10 @@ class Qwen3CoderToolParser(ToolParser): ...@@ -327,10 +330,10 @@ class Qwen3CoderToolParser(ToolParser):
idx = model_output.find(self.tool_call_prefix) idx = model_output.find(self.tool_call_prefix)
content_index = content_index if content_index >= 0 else idx content_index = content_index if content_index >= 0 else idx
content = model_output[:content_index] # .rstrip() content = model_output[:content_index] # .rstrip()
valid_tool_calls = [tc for tc in tool_calls if tc is not None]
return ExtractedToolCallInformation( return ExtractedToolCallInformation(
tools_called=(len(tool_calls) > 0), tools_called=(len(valid_tool_calls) > 0),
tool_calls=tool_calls, tool_calls=valid_tool_calls,
content=content if content else None, content=content if content else None,
) )
......
...@@ -225,7 +225,6 @@ class AttentionBackend(ABC): ...@@ -225,7 +225,6 @@ class AttentionBackend(ABC):
has_sink: bool, has_sink: bool,
use_sparse: bool, use_sparse: bool,
use_mm_prefix: bool, use_mm_prefix: bool,
use_alibi_sqrt: bool,
device_capability: "DeviceCapability", device_capability: "DeviceCapability",
attn_type: str, attn_type: str,
) -> list[str]: ) -> list[str]:
...@@ -242,8 +241,6 @@ class AttentionBackend(ABC): ...@@ -242,8 +241,6 @@ class AttentionBackend(ABC):
invalid_reasons.append( invalid_reasons.append(
"partial multimodal token full attention not supported" "partial multimodal token full attention not supported"
) )
if use_alibi_sqrt and not cls.supports_alibi_sqrt():
invalid_reasons.append("use_alibi_sqrt not supported")
if use_mla != cls.is_mla(): if use_mla != cls.is_mla():
if use_mla: if use_mla:
invalid_reasons.append("MLA not supported") invalid_reasons.append("MLA not supported")
......
...@@ -33,13 +33,6 @@ if is_flash_attn_varlen_func_available(): ...@@ -33,13 +33,6 @@ if is_flash_attn_varlen_func_available():
vllm_flash_attn_varlen_func, vllm_flash_attn_varlen_func,
reshape_and_cache_cuda, reshape_and_cache_cuda,
) )
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash,
)
try:
from flash_attn import varlen_fwd_unified
except Exception:
varlen_fwd_unified = None
else: else:
from vllm.v1.attention.backends.fa_utils import ( from vllm.v1.attention.backends.fa_utils import (
flash_attn_supports_sinks, flash_attn_supports_sinks,
...@@ -119,38 +112,6 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -119,38 +112,6 @@ class FlashAttentionBackend(AttentionBackend):
@staticmethod @staticmethod
def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]: def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]:
return FlashAttentionMetadataBuilder return FlashAttentionMetadataBuilder
@classmethod
def supports_alibi_sqrt(cls) -> bool:
return True
@classmethod
def supports_mm_prefix(cls) -> bool:
return True
@staticmethod
def _use_rocm_unified_kv_layout(
block_size: int | None = None,
key_cache: torch.Tensor | None = None,
value_cache: torch.Tensor | None = None,
) -> bool:
if not current_platform.is_rocm():
return False
if block_size is None:
if key_cache is not None and value_cache is not None:
if key_cache.ndim != 4 or value_cache.ndim != 4:
return False
if key_cache.shape != value_cache.shape:
return False
block_size = key_cache.shape[1]
else:
try:
block_size = get_current_vllm_config().cache_config.block_size
except Exception:
return False
return block_size is not None and block_size != 64 and block_size % 64 == 0
if current_platform.is_rocm(): if current_platform.is_rocm():
@staticmethod @staticmethod
...@@ -163,9 +124,6 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -163,9 +124,6 @@ class FlashAttentionBackend(AttentionBackend):
) -> tuple[tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...]]:
if block_size % 16 != 0: if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.") raise ValueError("Block size must be a multiple of 16.")
if FlashAttentionBackend._use_rocm_unified_kv_layout(block_size):
unified_shape = (num_blocks, block_size, num_kv_heads, head_size)
return (unified_shape, unified_shape)
return ( return (
(num_blocks, num_kv_heads, block_size, head_size), (num_blocks, num_kv_heads, block_size, head_size),
(num_blocks, num_kv_heads, head_size, block_size), (num_blocks, num_kv_heads, head_size, block_size),
...@@ -178,31 +136,20 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -178,31 +136,20 @@ class FlashAttentionBackend(AttentionBackend):
# `stride_order` indicates the permutation that gets # `stride_order` indicates the permutation that gets
# us from `get_kv_cache_shape` to the actual memory layout we want. # us from `get_kv_cache_shape` to the actual memory layout we want.
cache_layout = get_kv_cache_layout() cache_layout = get_kv_cache_layout()
if FlashAttentionBackend._use_rocm_unified_kv_layout(): if cache_layout == "NHD" and include_num_layers_dimension:
if cache_layout != "NHD": # (num_blocks, num_layers, block_size, num_kv_heads, head_size)
raise RuntimeError( return (1, 0, 3, 2, 5), (1, 0, 4, 2, 3)
"ROCm unified KV layout currently supports NHD only." elif cache_layout == "NHD":
)
if include_num_layers_dimension:
# (num_blocks, num_layers, block_size, num_kv_heads, head_size)
return (1, 0, 2, 3, 4), (1, 0, 2, 3, 4)
key_stride_order = (0, 1, 2, 3) key_stride_order = (0, 1, 2, 3)
value_stride_order = (0, 1, 2, 3) value_stride_order = (0, 1, 2, 3)
elif cache_layout == "HND" and include_num_layers_dimension:
# (num_blocks, num_kv_heads, num_layers, block_size, head_size)
return (1, 2, 0, 3, 4), (1, 2, 0, 4, 3)
elif cache_layout == "HND":
key_stride_order = (0, 1, 2, 3)
value_stride_order = (0, 1, 3, 2)
else: else:
if cache_layout == "NHD" and include_num_layers_dimension: raise ValueError(f"Unknown cache layout format {cache_layout}.")
# (num_blocks, num_layers, block_size, num_kv_heads, head_size)
return (1, 0, 3, 2, 5), (1, 0, 4, 2, 3)
elif cache_layout == "NHD":
key_stride_order = (0, 1, 2, 3)
value_stride_order = (0, 1, 2, 3)
elif cache_layout == "HND" and include_num_layers_dimension:
# (num_blocks, num_kv_heads, num_layers, block_size, head_size)
return (1, 2, 0, 3, 4), (1, 2, 0, 4, 3)
elif cache_layout == "HND":
key_stride_order = (0, 1, 2, 3)
value_stride_order = (0, 1, 3, 2)
else:
raise ValueError(f"Unknown cache layout format {cache_layout}.")
return key_stride_order, value_stride_order return key_stride_order, value_stride_order
else: else:
@staticmethod @staticmethod
...@@ -324,34 +271,8 @@ class FlashAttentionMetadata: ...@@ -324,34 +271,8 @@ class FlashAttentionMetadata:
prefix_scheduler_metadata: torch.Tensor | None = None prefix_scheduler_metadata: torch.Tensor | None = None
max_num_splits: int = 0 max_num_splits: int = 0
mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None
qq_bias: torch.Tensor | None = None
causal: bool = True causal: bool = True
@property
def mm_prefix_range_tensor(self) -> torch.Tensor | None:
if self.mm_prefix_range is None:
return None
num_seqs = self.seq_lens.shape[0]
device = self.seq_lens.device
range_lists = [
self.mm_prefix_range.get(i, [(0, 0)]) or [(0, 0)]
for i in range(num_seqs)
]
if all(r == [(0, 0)] for r in range_lists):
return None
range_tensors = [
torch.tensor(r, dtype=torch.int32, device=device).view(-1, 2)
for r in range_lists
]
return torch.nested.nested_tensor(
range_tensors, layout=torch.jagged
).to_padded_tensor(0)
def _get_sliding_window_configs( def _get_sliding_window_configs(
vllm_config: VllmConfig, vllm_config: VllmConfig,
...@@ -676,7 +597,6 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -676,7 +597,6 @@ class FlashAttentionImpl(AttentionImpl):
attn_type: AttentionType = AttentionType.DECODER, attn_type: AttentionType = AttentionType.DECODER,
kv_sharing_target_layer_name: str | None = None, kv_sharing_target_layer_name: str | None = None,
sinks: torch.Tensor | None = None, sinks: torch.Tensor | None = None,
use_alibi_sqrt: bool = False,
) -> None: ) -> None:
self.num_heads = num_heads self.num_heads = num_heads
self.head_size = head_size self.head_size = head_size
...@@ -702,7 +622,6 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -702,7 +622,6 @@ class FlashAttentionImpl(AttentionImpl):
self.attn_type = attn_type self.attn_type = attn_type
self.vllm_flash_attn_version = get_flash_attn_version() self.vllm_flash_attn_version = get_flash_attn_version()
self.use_alibi_sqrt = use_alibi_sqrt
# Cache the batch invariant result for use in forward passes # Cache the batch invariant result for use in forward passes
self.batch_invariant_enabled = vllm_is_batch_invariant() self.batch_invariant_enabled = vllm_is_batch_invariant()
...@@ -729,14 +648,6 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -729,14 +648,6 @@ class FlashAttentionImpl(AttentionImpl):
else False else False
) )
def _get_unified_extras(
self,
attn_metadata: FlashAttentionMetadata,
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
mm_prefix_range_tensor = attn_metadata.mm_prefix_range_tensor
qq_bias = attn_metadata.qq_bias
return mm_prefix_range_tensor, qq_bias
def forward( def forward(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -863,60 +774,30 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -863,60 +774,30 @@ class FlashAttentionImpl(AttentionImpl):
print(f"q.shape = {query[:num_actual_tokens].shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}") print(f"q.shape = {query[:num_actual_tokens].shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
print(f"cu_seqlens_q.shape = {cu_seqlens_q.shape}, max_seqlen_q = {max_seqlen_q}, seqused_k.shape = {seqused_k.shape}, max_seqlen_k = {max_seqlen_k}") print(f"cu_seqlens_q.shape = {cu_seqlens_q.shape}, max_seqlen_q = {max_seqlen_q}, seqused_k.shape = {seqused_k.shape}, max_seqlen_k = {max_seqlen_k}")
print(f"softmax_scale = {self.scale:.3f}, alibi_slopes = {self.alibi_slopes}, window_size = {self.sliding_window}, block_tables.shape = {block_table.shape}, softcap = {self.logits_soft_cap}, scheduler_metadata = {scheduler_metadata}") print(f"softmax_scale = {self.scale:.3f}, alibi_slopes = {self.alibi_slopes}, window_size = {self.sliding_window}, block_tables.shape = {block_table.shape}, softcap = {self.logits_soft_cap}, scheduler_metadata = {scheduler_metadata}")
use_unified_kv_layout = ( vllm_flash_attn_varlen_func(
FlashAttentionBackend._use_rocm_unified_kv_layout( q=query[:num_actual_tokens],
key_cache=key_cache, value_cache=value_cache) k=key_cache,
v=value_cache,
out=output[:num_actual_tokens],
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
seqused_k=seqused_k,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
causal=attn_metadata.causal,
alibi_slopes=self.alibi_slopes,
window_size=sliding_window_size,
block_table=block_table,
softcap=self.logits_soft_cap,
scheduler_metadata=scheduler_metadata,
fa_version=self.vllm_flash_attn_version,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
# num_splits=attn_metadata.max_num_splits,
s_aux=self.sinks,
is_prefix_cache=True,
) )
if use_unified_kv_layout:
mm_prefix_range_tensor, qq_bias = self._get_unified_extras(
attn_metadata
)
varlen_fwd_unified(
q=query[:num_actual_tokens],
k=key_cache,
v=value_cache,
cu_seqlens_q=cu_seqlens_q,
seqused_k=seqused_k,
block_table=block_table,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
causal=attn_metadata.causal,
softcap=self.logits_soft_cap,
window_size=tuple(self.sliding_window),
alibi_slopes=self.alibi_slopes,
use_alibi_sqrt=self.use_alibi_sqrt,
qq_bias=qq_bias,
s_aux=self.sinks,
mm_prefix_range=mm_prefix_range_tensor,
return_softmax_lse=False,
out=output[:num_actual_tokens],
)
else:
vllm_flash_attn_varlen_func(
q=query[:num_actual_tokens],
k=key_cache,
v=value_cache,
out=output[:num_actual_tokens],
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
seqused_k=seqused_k,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
causal=attn_metadata.causal,
alibi_slopes=self.alibi_slopes,
window_size=sliding_window_size,
block_table=block_table,
softcap=self.logits_soft_cap,
scheduler_metadata=scheduler_metadata,
fa_version=self.vllm_flash_attn_version,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
# num_splits=attn_metadata.max_num_splits,
s_aux=self.sinks,
is_prefix_cache=True,
)
else: else:
flash_attn_varlen_func( flash_attn_varlen_func(
q=query[:num_actual_tokens], q=query[:num_actual_tokens],
...@@ -1008,11 +889,21 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -1008,11 +889,21 @@ class FlashAttentionImpl(AttentionImpl):
# op uses the slot_mapping's shape to determine the number of # op uses the slot_mapping's shape to determine the number of
# actual tokens. # actual tokens.
if current_platform.is_rocm(): if current_platform.is_rocm():
if FlashAttentionBackend._use_rocm_unified_kv_layout( if envs.VLLM_USE_OPT_RESHAPE_AND_CACHE:
key_cache=key_cache, from lightop import reshape_and_cache_cuda
value_cache=value_cache, reshape_and_cache_cuda(
): key,
triton_reshape_and_cache_flash( value,
key_cache,
value_cache,
slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale
)
else:
from vllm.v1.attention.backends.fa_utils import reshape_and_cache_cuda
reshape_and_cache_cuda(
key, key,
value, value,
key_cache, key_cache,
...@@ -1022,32 +913,6 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -1022,32 +913,6 @@ class FlashAttentionImpl(AttentionImpl):
layer._k_scale, layer._k_scale,
layer._v_scale, layer._v_scale,
) )
else:
if envs.VLLM_USE_OPT_RESHAPE_AND_CACHE:
from lightop import reshape_and_cache_cuda
reshape_and_cache_cuda(
key,
value,
key_cache,
value_cache,
slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale
)
else:
from vllm.v1.attention.backends.fa_utils import reshape_and_cache_cuda
reshape_and_cache_cuda(
key,
value,
key_cache,
value_cache,
slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
else: else:
reshape_and_cache_flash( reshape_and_cache_flash(
key, key,
......
...@@ -12,6 +12,10 @@ import torch ...@@ -12,6 +12,10 @@ import torch
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
try:
from flash_attn import varlen_fwd_unified
except Exception:
varlen_fwd_unified = None
logger = init_logger(__name__) logger = init_logger(__name__)
float8_info = torch.finfo(current_platform.fp8_dtype()) float8_info = torch.finfo(current_platform.fp8_dtype())
...@@ -983,61 +987,92 @@ def unified_attention( ...@@ -983,61 +987,92 @@ def unified_attention(
or num_seqs > seq_threshold_3D or num_seqs > seq_threshold_3D
): ):
# print(f"[2D Triton] k shape: {k.shape}, v shape: {v.shape}") # print(f"[2D Triton] k shape: {k.shape}, v shape: {v.shape}")
kernel_unified_attention_2d[ use_fa_unified_2d = (
( current_platform.is_rocm()
total_num_q_blocks, and varlen_fwd_unified is not None
num_kv_heads, and block_size % 64 == 0
) and head_size == 256
](
output_ptr=out,
query_ptr=q,
key_cache_ptr=k,
value_cache_ptr=v,
sink_ptr=sinks,
block_tables_ptr=block_table,
seq_lens_ptr=seqused_k,
alibi_slopes_ptr=alibi_slopes,
qq_bias_ptr=qq_bias,
scale=softmax_scale,
k_scale=k_descale,
v_scale=v_descale,
out_scale=1 / output_scale if output_scale is not None else 1.0,
softcap=softcap,
num_query_heads=num_query_heads,
num_queries_per_kv=num_queries_per_kv,
block_table_stride=block_table.stride(0),
query_stride_0=q.stride(0),
query_stride_1=q.stride(1),
output_stride_0=out.stride(0),
output_stride_1=out.stride(1),
qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0,
BLOCK_SIZE=block_size,
TILE_SIZE=TILE_SIZE_PREFILL,
HEAD_SIZE=head_size,
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
USE_ALIBI_SLOPES=use_alibi_slopes,
USE_ALIBI_SQRT=use_alibi_sqrt,
USE_QQ_BIAS=use_qq_bias,
USE_SOFTCAP=(softcap > 0),
USE_SINKS=(sinks is not None),
USE_MM_PREFIX=use_mm_prefix,
MAX_MM_RANGES=max_mm_ranges,
mm_prefix_range_ptr=mm_prefix_range,
SLIDING_WINDOW=(1 + window_size[0]),
stride_k_cache_0=k.stride(0),
stride_k_cache_1=k.stride(1),
stride_k_cache_2=k.stride(2),
stride_k_cache_3=k.stride(3),
stride_v_cache_0=v.stride(0),
stride_v_cache_1=v.stride(1),
stride_v_cache_2=v.stride(2),
stride_v_cache_3=v.stride(3),
query_start_len_ptr=cu_seqlens_q,
BLOCK_Q=BLOCK_Q,
num_seqs=num_seqs,
BLOCK_M=BLOCK_M,
USE_FP8=output_scale is not None,
) )
if not use_fa_unified_2d:
# print("Running Triton kernel")
kernel_unified_attention_2d[
(
total_num_q_blocks,
num_kv_heads,
)
](
output_ptr=out,
query_ptr=q,
key_cache_ptr=k,
value_cache_ptr=v,
sink_ptr=sinks,
block_tables_ptr=block_table,
seq_lens_ptr=seqused_k,
alibi_slopes_ptr=alibi_slopes,
qq_bias_ptr=qq_bias,
scale=softmax_scale,
k_scale=k_descale,
v_scale=v_descale,
out_scale=1 / output_scale if output_scale is not None else 1.0,
softcap=softcap,
num_query_heads=num_query_heads,
num_queries_per_kv=num_queries_per_kv,
block_table_stride=block_table.stride(0),
query_stride_0=q.stride(0),
query_stride_1=q.stride(1),
output_stride_0=out.stride(0),
output_stride_1=out.stride(1),
qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0,
BLOCK_SIZE=block_size,
TILE_SIZE=TILE_SIZE_PREFILL,
HEAD_SIZE=head_size,
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
USE_ALIBI_SLOPES=use_alibi_slopes,
USE_ALIBI_SQRT=use_alibi_sqrt,
USE_QQ_BIAS=use_qq_bias,
USE_SOFTCAP=(softcap > 0),
USE_SINKS=(sinks is not None),
USE_MM_PREFIX=use_mm_prefix,
MAX_MM_RANGES=max_mm_ranges,
mm_prefix_range_ptr=mm_prefix_range,
SLIDING_WINDOW=(1 + window_size[0]),
stride_k_cache_0=k.stride(0),
stride_k_cache_1=k.stride(1),
stride_k_cache_2=k.stride(2),
stride_k_cache_3=k.stride(3),
stride_v_cache_0=v.stride(0),
stride_v_cache_1=v.stride(1),
stride_v_cache_2=v.stride(2),
stride_v_cache_3=v.stride(3),
query_start_len_ptr=cu_seqlens_q,
BLOCK_Q=BLOCK_Q,
num_seqs=num_seqs,
BLOCK_M=BLOCK_M,
USE_FP8=output_scale is not None,
)
else:
# print("Running FA kernel")
varlen_fwd_unified(
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens_q,
seqused_k=seqused_k,
block_table=block_table,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
softmax_scale=softmax_scale,
causal=causal,
softcap=softcap,
window_size=window_size,
alibi_slopes=alibi_slopes,
use_alibi_sqrt=use_alibi_sqrt,
qq_bias=qq_bias,
s_aux=sinks,
mm_prefix_range=mm_prefix_range,
return_softmax_lse=False,
out=out,
)
else: else:
# print(f"[3D Triton] k shape: {k.shape}, v shape: {v.shape}") # print(f"[3D Triton] k shape: {k.shape}, v shape: {v.shape}")
kernel_unified_attention_3d[ kernel_unified_attention_3d[
......
...@@ -27,7 +27,6 @@ class AttentionSelectorConfig(NamedTuple): ...@@ -27,7 +27,6 @@ class AttentionSelectorConfig(NamedTuple):
has_sink: bool = False has_sink: bool = False
use_sparse: bool = False use_sparse: bool = False
use_mm_prefix: bool = False use_mm_prefix: bool = False
use_alibi_sqrt: bool = False
attn_type: str = AttentionType.DECODER attn_type: str = AttentionType.DECODER
def __repr__(self): def __repr__(self):
...@@ -40,7 +39,6 @@ class AttentionSelectorConfig(NamedTuple): ...@@ -40,7 +39,6 @@ class AttentionSelectorConfig(NamedTuple):
f"has_sink={self.has_sink}, " f"has_sink={self.has_sink}, "
f"use_sparse={self.use_sparse}, " f"use_sparse={self.use_sparse}, "
f"use_mm_prefix={self.use_mm_prefix}, " f"use_mm_prefix={self.use_mm_prefix}, "
f"use_alibi_sqrt={self.use_alibi_sqrt}, "
f"attn_type={self.attn_type})" f"attn_type={self.attn_type})"
) )
...@@ -54,7 +52,6 @@ def get_attn_backend( ...@@ -54,7 +52,6 @@ def get_attn_backend(
has_sink: bool = False, has_sink: bool = False,
use_sparse: bool = False, use_sparse: bool = False,
use_mm_prefix: bool = False, use_mm_prefix: bool = False,
use_alibi_sqrt: bool = False,
attn_type: str | None = None, attn_type: str | None = None,
) -> type[AttentionBackend]: ) -> type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it.""" """Selects which attention backend to use and lazily imports it."""
...@@ -80,7 +77,6 @@ def get_attn_backend( ...@@ -80,7 +77,6 @@ def get_attn_backend(
has_sink=has_sink, has_sink=has_sink,
use_sparse=use_sparse, use_sparse=use_sparse,
use_mm_prefix=use_mm_prefix, use_mm_prefix=use_mm_prefix,
use_alibi_sqrt=use_alibi_sqrt,
attn_type=attn_type or AttentionType.DECODER, attn_type=attn_type or AttentionType.DECODER,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment