"ssh:/git@developer.sourcefind.cn:2222/OpenDAS/dynamo.git" did not exist on "9bdc8b730345b43e58f30816036cc7462c5872af"
Unverified commit ee3cf457, authored by Yan Ma and committed by GitHub
Browse files

[XPU] Initial support for GDN attention on Qwen3-next/Qwen3.5 (#33657)


Signed-off-by: Yan Ma <yan.ma@intel.com>
Signed-off-by: Chendi Xue <chendi.xue@intel.com>
Co-authored-by: Chendi Xue <chendi.xue@intel.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
parent 05e68e1f
...@@ -560,6 +560,11 @@ class RMSNormGated(CustomOp): ...@@ -560,6 +560,11 @@ class RMSNormGated(CustomOp):
activation=self.activation, activation=self.activation,
) )
def forward_xpu(
self, x: torch.Tensor, z: torch.Tensor | None = None
) -> torch.Tensor:
return self.forward_cuda(x, z)
class LayerNorm(nn.Module): class LayerNorm(nn.Module):
""" """
......
...@@ -262,6 +262,9 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase): ...@@ -262,6 +262,9 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
else 0 else 0
) )
self.gqa_interleaved_layout = gqa_interleaved_layout self.gqa_interleaved_layout = gqa_interleaved_layout
self._forward_method = (
self.forward_xpu if current_platform.is_xpu() else self.forward_cuda
)
# QKV # QKV
self.conv_dim = self.key_dim * 2 + self.value_dim self.conv_dim = self.key_dim * 2 + self.value_dim
...@@ -493,6 +496,13 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase): ...@@ -493,6 +496,13 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
output: torch.Tensor, output: torch.Tensor,
):
self._forward_method(hidden_states, output)
def forward_cuda(
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
): ):
""" """
Forward pass with three parts: Forward pass with three parts:
...@@ -567,6 +577,90 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase): ...@@ -567,6 +577,90 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)") core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")
output[:num_tokens], _ = self.out_proj(core_attn_out) output[:num_tokens], _ = self.out_proj(core_attn_out)
def forward_xpu(
    self,
    hidden_states: torch.Tensor,
    output: torch.Tensor,
):
    """
    XPU forward pass for gated delta-net attention.

    Runs in three stages:
    1. Input projection (``in_proj_qkvz`` and ``in_proj_ba``)
    2. Core attention via the ``torch.ops._xpu_C.gdn_attention`` custom op
    3. Gated normalization followed by the output projection

    The result is written in place into ``output[:num_tokens]``; nothing is
    returned.
    """
    num_tokens = hidden_states.size(0)
    assert not hasattr(self, "in_proj_qkv"), "lora isn't supported on XPU."

    # ------------------------------------------------------------
    # Stage 1: input projection
    # ------------------------------------------------------------
    qkvz_proj, _ = self.in_proj_qkvz(hidden_states)
    ba_proj, _ = self.in_proj_ba(hidden_states)

    # ------------------------------------------------------------
    # Stage 2: core attention
    # ------------------------------------------------------------
    metadata_by_layer = get_forward_context().attn_metadata
    # Output buffers handed to the custom op. The attention output starts
    # zero-filled; the gate buffer is left uninitialized.
    attn_out = torch.zeros(
        (num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim),
        dtype=hidden_states.dtype,
        device=hidden_states.device,
    )
    gate = torch.empty_like(attn_out)

    # When the forward context carries no attention metadata (presumably
    # profiling / warm-up runs -- confirm), the kernel is skipped and the
    # buffers above flow straight into stage 3.
    if metadata_by_layer is not None:
        layer_metadata = metadata_by_layer[self.prefix]
        # TODO: xpu does not support this param yet
        assert layer_metadata.spec_sequence_masks is None

        # Drop the middle dimension of the conv1d weight: (dim0, dim2).
        conv_weights_2d = self.conv1d.weight.view(
            self.conv1d.weight.size(0), self.conv1d.weight.size(2)
        )
        conv_state, ssm_state = self.kv_cache[0], self.kv_cache[1]

        torch.ops._xpu_C.gdn_attention(
            attn_out,
            gate,
            qkvz_proj,
            ba_proj,
            self.num_k_heads,
            self.num_v_heads,
            self.head_k_dim,
            self.head_v_dim,
            conv_state=conv_state,
            ssm_state=ssm_state,
            conv_weights=conv_weights_2d,
            conv_bias=self.conv1d.bias,
            activation=self.activation,
            A_log=self.A_log,
            dt_bias=self.dt_bias,
            num_prefills=layer_metadata.num_prefills,
            num_decodes=layer_metadata.num_decodes,
            has_initial_state=layer_metadata.has_initial_state,
            non_spec_query_start_loc=layer_metadata.non_spec_query_start_loc,
            non_spec_state_indices_tensor=(
                layer_metadata.non_spec_state_indices_tensor
            ),
            num_actual_tokens=layer_metadata.num_actual_tokens,
            tp_size=self.tp_size,
            reorder_input=not self.gqa_interleaved_layout,
        )

    # ------------------------------------------------------------
    # Stage 3: output projection
    # ------------------------------------------------------------
    gate_shape = gate.shape
    # The norm is applied to 2D views; the original shape is restored after.
    attn_out_2d = attn_out.reshape(-1, attn_out.shape[-1])
    gate_2d = gate.reshape(-1, gate.shape[-1])
    normed = self.norm(attn_out_2d, gate_2d).reshape(gate_shape)
    # Merge the head and head-dim axes before projecting out.
    merged_heads = rearrange(normed, "... h d -> ... (h d)")
    output[:num_tokens], _ = self.out_proj(merged_heads)
def _warmup_prefill_kernels(self, mixed_qkv: torch.Tensor) -> None: def _warmup_prefill_kernels(self, mixed_qkv: torch.Tensor) -> None:
"""Warm up GDN prefill kernels during V1 profiling. """Warm up GDN prefill kernels during V1 profiling.
......
...@@ -218,6 +218,57 @@ class XPUPlatform(Platform): ...@@ -218,6 +218,57 @@ class XPUPlatform(Platform):
# ref. https://openucx.readthedocs.io/en/master/faq.html # ref. https://openucx.readthedocs.io/en/master/faq.html
os.environ["UCX_MEMTYPE_CACHE"] = "n" os.environ["UCX_MEMTYPE_CACHE"] = "n"
@classmethod
def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None:
    """Round the cache block size up to a multiple of 64 when GDN is used.

    After running the parent implementation, this checks whether any
    attention layer uses the ``GDN_ATTN`` backend. If so, and
    ``cache_config.block_size`` is not already a multiple of 64, the block
    size is rounded up and the dependent mamba settings are updated:

    * ``mamba_block_size`` is set to the new block size when
      ``mamba_cache_mode`` is ``"align"``.
    * ``mamba_page_size_padded`` (when set) is rescaled using the per-token
      size derived from the old block size.

    Args:
        vllm_config: Configuration whose ``cache_config`` is updated in place.
    """
    super().update_block_size_for_backend(vllm_config)

    from vllm.config.vllm import get_layers_from_vllm_config
    from vllm.model_executor.layers.attention_layer_base import (
        AttentionLayerBase,
    )
    from vllm.utils.math_utils import cdiv

    cache_config = vllm_config.cache_config
    # Special fix for GDN: the kernel only supports block sizes divisible by 64.
    attn_layers = get_layers_from_vllm_config(
        vllm_config,
        AttentionLayerBase,  # type: ignore[type-abstract]
    )
    uses_gdn = any(
        layer.get_attn_backend().get_name() == "GDN_ATTN"
        for layer in attn_layers.values()
    )
    if not uses_gdn:
        return
    kernel_block_size = 64

    new_block_size = (
        cdiv(cache_config.block_size, kernel_block_size) * kernel_block_size
    )
    if new_block_size == cache_config.block_size:
        return

    if cache_config.mamba_cache_mode == "align":
        cache_config.mamba_block_size = new_block_size

    original_mamba_page_size_padded = cache_config.mamba_page_size_padded
    if original_mamba_page_size_padded is not None:
        # Must run before block_size is overwritten: the per-token size is
        # derived from the old block size.
        attn_page_size_1_token = (
            original_mamba_page_size_padded // cache_config.block_size
        )
        cache_config.mamba_page_size_padded = (
            new_block_size * attn_page_size_1_token
        )
    cache_config.block_size = new_block_size

    if original_mamba_page_size_padded is None:
        # No padded page size is configured; formatting None with %d would
        # make the logging call fail, so report only the block size.
        logger.info(
            "[XPU]Setting attention block size to %d tokens to ensure multiple "
            "of %d.",
            new_block_size,
            kernel_block_size,
        )
    else:
        logger.info(
            "[XPU]Setting attention block size to %d tokens to ensure multiple of %d, "
            "set mamba_page_size_padded to %d bytes accordingly, before was %d bytes.",
            new_block_size,
            kernel_block_size,
            cache_config.mamba_page_size_padded,
            original_mamba_page_size_padded,
        )
@classmethod
def support_hybrid_kv_cache(cls) -> bool:
    """Report whether this platform supports a hybrid KV cache (always True)."""
    return True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment