Commit 624eab7c authored by laibao's avatar laibao
Browse files

[BUGFIX] 修复 Qwen3.5 在新版 transformers 下的配置兼容问题并统一 ROCm unified attention 路由

目的:
  修复 Qwen3.5 / Qwen3.5-MoE 在升级 transformers 后的配置解析兼容问题,并优化 ROCm 下 unified attention 的路由策略,避免 prefill 和 decode 落到不同的实现路径上,从而降低后续排查和行为不一致的成本。
parent e220b38b
...@@ -5,6 +5,7 @@ import os ...@@ -5,6 +5,7 @@ import os
from collections.abc import Callable from collections.abc import Callable
from dataclasses import asdict from dataclasses import asdict
from functools import cache, partial from functools import cache, partial
from importlib.util import find_spec
from importlib.metadata import version from importlib.metadata import version
from pathlib import Path from pathlib import Path
from typing import Any, Literal, TypeAlias from typing import Any, Literal, TypeAlias
...@@ -44,10 +45,14 @@ try: ...@@ -44,10 +45,14 @@ try:
# Transformers v5 # Transformers v5
from transformers.configuration_utils import ALLOWED_ATTENTION_LAYER_TYPES from transformers.configuration_utils import ALLOWED_ATTENTION_LAYER_TYPES
except ImportError: except ImportError:
# Transformers v4 try:
# Transformers v4.52+
from transformers.configuration_utils import ( from transformers.configuration_utils import (
ALLOWED_LAYER_TYPES as ALLOWED_ATTENTION_LAYER_TYPES, ALLOWED_LAYER_TYPES as ALLOWED_ATTENTION_LAYER_TYPES,
) )
except ImportError:
# Transformers v4.51 and earlier: neither symbol exists, use empty set
ALLOWED_ATTENTION_LAYER_TYPES: set = set()
if envs.VLLM_USE_MODELSCOPE: if envs.VLLM_USE_MODELSCOPE:
...@@ -60,6 +65,18 @@ MISTRAL_CONFIG_NAME = "params.json" ...@@ -60,6 +65,18 @@ MISTRAL_CONFIG_NAME = "params.json"
logger = init_logger(__name__) logger = init_logger(__name__)
def _hf_has_native_qwen3_5_config(model_type: str | None) -> bool:
if model_type == "qwen3_5":
return find_spec(
"transformers.models.qwen3_5.configuration_qwen3_5"
) is not None
if model_type == "qwen3_5_moe":
return find_spec(
"transformers.models.qwen3_5_moe.configuration_qwen3_5_moe"
) is not None
return False
class LazyConfigDict(dict): class LazyConfigDict(dict):
def __getitem__(self, key): def __getitem__(self, key):
if isinstance(value := super().__getitem__(key), type): if isinstance(value := super().__getitem__(key), type):
...@@ -99,6 +116,8 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict( ...@@ -99,6 +116,8 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
step3p5="Step3p5Config", step3p5="Step3p5Config",
qwen3_asr="Qwen3ASRConfig", qwen3_asr="Qwen3ASRConfig",
qwen3_next="Qwen3NextConfig", qwen3_next="Qwen3NextConfig",
qwen3_5="Qwen3_5Config",
qwen3_5_moe="Qwen3_5MoeConfig",
lfm2_moe="Lfm2MoeConfig", lfm2_moe="Lfm2MoeConfig",
tarsier2="Tarsier2Config", tarsier2="Tarsier2Config",
) )
...@@ -151,7 +170,9 @@ class HFConfigParser(ConfigParserBase): ...@@ -151,7 +170,9 @@ class HFConfigParser(ConfigParserBase):
if (hf_overrides := kwargs.pop("hf_overrides", None)) is not None: if (hf_overrides := kwargs.pop("hf_overrides", None)) is not None:
model_type = hf_overrides.get("model_type", model_type) model_type = hf_overrides.get("model_type", model_type)
if model_type in _CONFIG_REGISTRY: if model_type in _CONFIG_REGISTRY and not _hf_has_native_qwen3_5_config(
model_type
):
config_class = _CONFIG_REGISTRY[model_type] config_class = _CONFIG_REGISTRY[model_type]
config = config_class.from_pretrained( config = config_class.from_pretrained(
model, model,
......
...@@ -68,10 +68,10 @@ class Qwen3_5TextConfig(PretrainedConfig): ...@@ -68,10 +68,10 @@ class Qwen3_5TextConfig(PretrainedConfig):
eos_token_id=None, eos_token_id=None,
**kwargs, **kwargs,
): ):
kwargs["ignore_keys_at_rope_validation"] = [ kwargs["ignore_keys_at_rope_validation"] = {
"mrope_section", "mrope_section",
"mrope_interleaved", "mrope_interleaved",
] }
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size self.hidden_size = hidden_size
......
...@@ -75,10 +75,10 @@ class Qwen3_5MoeTextConfig(PretrainedConfig): ...@@ -75,10 +75,10 @@ class Qwen3_5MoeTextConfig(PretrainedConfig):
eos_token_id=None, eos_token_id=None,
**kwargs, **kwargs,
): ):
kwargs["ignore_keys_at_rope_validation"] = [ kwargs["ignore_keys_at_rope_validation"] = {
"mrope_section", "mrope_section",
"mrope_interleaved", "mrope_interleaved",
] }
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size self.hidden_size = hidden_size
......
...@@ -973,11 +973,24 @@ def unified_attention( ...@@ -973,11 +973,24 @@ def unified_attention(
is_prefill=False, is_prefill=False,
) )
# On ROCm, prefer the unified FA kernel whenever the cache layout and
# headdim match its supported path. This keeps both prefill and decode on
# the same implementation instead of routing decode-only batches to the 3D
# Triton kernel.
use_fa_unified_2d = (
current_platform.is_rocm()
and varlen_fwd_unified is not None
and block_size % 64 == 0
and head_size == 256
)
# Launch the 2D kernel if # Launch the 2D kernel if
# 1. No intermediate tiled softmax buffers for the 3D kernel have been allocated, or # 1. No intermediate tiled softmax buffers for the 3D kernel have been allocated, or
# 2. The batch includes at least one prefill request, or # 2. The batch includes at least one prefill request, or
# 3. The number of sequences exceeds the configured threshold # 3. The number of sequences exceeds the configured threshold
if ( if (
use_fa_unified_2d
or
seq_threshold_3D is None seq_threshold_3D is None
or num_par_softmax_segments is None or num_par_softmax_segments is None
or softmax_segm_output is None or softmax_segm_output is None
...@@ -987,12 +1000,6 @@ def unified_attention( ...@@ -987,12 +1000,6 @@ def unified_attention(
or num_seqs > seq_threshold_3D or num_seqs > seq_threshold_3D
): ):
# print(f"[2D Triton] k shape: {k.shape}, v shape: {v.shape}") # print(f"[2D Triton] k shape: {k.shape}, v shape: {v.shape}")
use_fa_unified_2d = (
current_platform.is_rocm()
and varlen_fwd_unified is not None
and block_size % 64 == 0
and head_size == 256
)
if not use_fa_unified_2d: if not use_fa_unified_2d:
# print("Running Triton kernel") # print("Running Triton kernel")
kernel_unified_attention_2d[ kernel_unified_attention_2d[
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment