Commit 624eab7c authored by laibao's avatar laibao
Browse files

[BUGFIX] 修复 Qwen3.5 在新版 transformers 下的配置兼容问题并统一 ROCm unified attention 路由

目的:
  修复 Qwen3.5 / Qwen3.5-MoE 在升级 transformers 后的配置解析兼容问题,并优化 ROCm 下 unified attention 的路由策略,避免prefill 和 decode落到不同实现路径上,降低后续排查和行为不一致的成本
parent e220b38b
......@@ -5,6 +5,7 @@ import os
from collections.abc import Callable
from dataclasses import asdict
from functools import cache, partial
from importlib.util import find_spec
from importlib.metadata import version
from pathlib import Path
from typing import Any, Literal, TypeAlias
......@@ -44,10 +45,14 @@ try:
# Transformers v5
from transformers.configuration_utils import ALLOWED_ATTENTION_LAYER_TYPES
except ImportError:
# Transformers v4
from transformers.configuration_utils import (
ALLOWED_LAYER_TYPES as ALLOWED_ATTENTION_LAYER_TYPES,
)
try:
# Transformers v4.52+
from transformers.configuration_utils import (
ALLOWED_LAYER_TYPES as ALLOWED_ATTENTION_LAYER_TYPES,
)
except ImportError:
# Transformers v4.51 and earlier: neither symbol exists, use empty set
ALLOWED_ATTENTION_LAYER_TYPES: set = set()
if envs.VLLM_USE_MODELSCOPE:
......@@ -60,6 +65,18 @@ MISTRAL_CONFIG_NAME = "params.json"
logger = init_logger(__name__)
def _hf_has_native_qwen3_5_config(model_type: str | None) -> bool:
if model_type == "qwen3_5":
return find_spec(
"transformers.models.qwen3_5.configuration_qwen3_5"
) is not None
if model_type == "qwen3_5_moe":
return find_spec(
"transformers.models.qwen3_5_moe.configuration_qwen3_5_moe"
) is not None
return False
class LazyConfigDict(dict):
def __getitem__(self, key):
if isinstance(value := super().__getitem__(key), type):
......@@ -99,6 +116,8 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
step3p5="Step3p5Config",
qwen3_asr="Qwen3ASRConfig",
qwen3_next="Qwen3NextConfig",
qwen3_5="Qwen3_5Config",
qwen3_5_moe="Qwen3_5MoeConfig",
lfm2_moe="Lfm2MoeConfig",
tarsier2="Tarsier2Config",
)
......@@ -151,7 +170,9 @@ class HFConfigParser(ConfigParserBase):
if (hf_overrides := kwargs.pop("hf_overrides", None)) is not None:
model_type = hf_overrides.get("model_type", model_type)
if model_type in _CONFIG_REGISTRY:
if model_type in _CONFIG_REGISTRY and not _hf_has_native_qwen3_5_config(
model_type
):
config_class = _CONFIG_REGISTRY[model_type]
config = config_class.from_pretrained(
model,
......
......@@ -68,10 +68,10 @@ class Qwen3_5TextConfig(PretrainedConfig):
eos_token_id=None,
**kwargs,
):
kwargs["ignore_keys_at_rope_validation"] = [
kwargs["ignore_keys_at_rope_validation"] = {
"mrope_section",
"mrope_interleaved",
]
}
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
......
......@@ -75,10 +75,10 @@ class Qwen3_5MoeTextConfig(PretrainedConfig):
eos_token_id=None,
**kwargs,
):
kwargs["ignore_keys_at_rope_validation"] = [
kwargs["ignore_keys_at_rope_validation"] = {
"mrope_section",
"mrope_interleaved",
]
}
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
......
......@@ -973,11 +973,24 @@ def unified_attention(
is_prefill=False,
)
# On ROCm, prefer the unified FA kernel whenever the cache layout and
# headdim match its supported path. This keeps both prefill and decode on
# the same implementation instead of routing decode-only batches to the 3D
# Triton kernel.
use_fa_unified_2d = (
current_platform.is_rocm()
and varlen_fwd_unified is not None
and block_size % 64 == 0
and head_size == 256
)
# Launch the 2D kernel if
# 1. No intermediate tiled softmax buffers for the 3D kernel have been allocated, or
# 2. The batch includes at least one prefill request, or
# 3. The number of sequences exceeds the configured threshold
if (
use_fa_unified_2d
or
seq_threshold_3D is None
or num_par_softmax_segments is None
or softmax_segm_output is None
......@@ -987,12 +1000,6 @@ def unified_attention(
or num_seqs > seq_threshold_3D
):
# print(f"[2D Triton] k shape: {k.shape}, v shape: {v.shape}")
use_fa_unified_2d = (
current_platform.is_rocm()
and varlen_fwd_unified is not None
and block_size % 64 == 0
and head_size == 256
)
if not use_fa_unified_2d:
# print("Running Triton kernel")
kernel_unified_attention_2d[
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment