Unverified Commit 4c34b2f6 authored by Yuwen Zhou's avatar Yuwen Zhou Committed by GitHub
Browse files

[XPU] Enable torch.compile for XPU GDN attention (#39466)


Signed-off-by: default avataryuwenzho <yuwen.zhou@intel.com>
Signed-off-by: default avatarYuwen Zhou <yuwen.zhou@intel.com>
Co-authored-by: default avatarKunshang Ji <kunshang.ji@intel.com>
parent cf8a613a
...@@ -92,6 +92,72 @@ if hasattr(torch.ops._xpu_C, "int4_gemm_w4a16"): ...@@ -92,6 +92,72 @@ if hasattr(torch.ops._xpu_C, "int4_gemm_w4a16"):
return torch.empty((M, N), dtype=input.dtype, device=input.device) return torch.empty((M, N), dtype=input.dtype, device=input.device)
def _gdn_attention_core_xpu_impl(
core_attn_out: torch.Tensor,
z: torch.Tensor,
projected_states_qkvz: torch.Tensor,
projected_states_ba: torch.Tensor,
layer_name: str,
) -> None:
"""Custom op wrapping the XPU SYCL GDN kernel for torch.compile."""
from vllm.forward_context import get_forward_context
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
forward_context = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
attn_metadata_raw = forward_context.attn_metadata
if attn_metadata_raw is None:
return
assert isinstance(attn_metadata_raw, dict)
attn_metadata = attn_metadata_raw[self.prefix]
assert isinstance(attn_metadata, GDNAttentionMetadata)
# TODO: xpu does not support speculative decoding yet
assert attn_metadata.spec_sequence_masks is None # type: ignore[attr-defined]
conv_weights = self.conv1d.weight.view(
self.conv1d.weight.size(0), self.conv1d.weight.size(2)
)
torch.ops._xpu_C.gdn_attention(
core_attn_out,
z,
projected_states_qkvz,
projected_states_ba,
self.num_k_heads,
self.num_v_heads,
self.head_k_dim,
self.head_v_dim,
conv_state=self.kv_cache[0],
ssm_state=self.kv_cache[1],
conv_weights=conv_weights,
conv_bias=self.conv1d.bias,
activation=self.activation,
A_log=self.A_log,
dt_bias=self.dt_bias,
num_prefills=attn_metadata.num_prefills, # type: ignore[attr-defined]
num_decodes=attn_metadata.num_decodes, # type: ignore[attr-defined]
has_initial_state=attn_metadata.has_initial_state, # type: ignore[attr-defined]
non_spec_query_start_loc=attn_metadata.non_spec_query_start_loc, # type: ignore[attr-defined]
non_spec_state_indices_tensor=attn_metadata.non_spec_state_indices_tensor, # type: ignore[attr-defined]
num_actual_tokens=attn_metadata.num_actual_tokens, # type: ignore[attr-defined]
tp_size=self.tp_size,
reorder_input=not self.gqa_interleaved_layout,
)
def _gdn_attention_core_xpu_fake(
core_attn_out: torch.Tensor,
z: torch.Tensor,
projected_states_qkvz: torch.Tensor,
projected_states_ba: torch.Tensor,
layer_name: str,
) -> None:
return
def _xpu_ops_deepseek_scaling_rope_impl( def _xpu_ops_deepseek_scaling_rope_impl(
positions: torch.Tensor, positions: torch.Tensor,
query: torch.Tensor, query: torch.Tensor,
...@@ -618,6 +684,13 @@ class xpu_ops: ...@@ -618,6 +684,13 @@ class xpu_ops:
fake_impl=_xpu_mxfp4_quantize_fake, fake_impl=_xpu_mxfp4_quantize_fake,
) )
direct_register_custom_op(
op_name="gdn_attention_core_xpu",
op_func=_gdn_attention_core_xpu_impl,
mutates_args=["core_attn_out", "z"],
fake_impl=_gdn_attention_core_xpu_fake,
)
_OPS_REGISTERED = True _OPS_REGISTERED = True
......
...@@ -744,6 +744,7 @@ class CompilationConfig: ...@@ -744,6 +744,7 @@ class CompilationConfig:
"vllm::linear_attention", "vllm::linear_attention",
"vllm::plamo2_mamba_mixer", "vllm::plamo2_mamba_mixer",
"vllm::gdn_attention_core", "vllm::gdn_attention_core",
"vllm::gdn_attention_core_xpu",
"vllm::olmo_hybrid_gdn_full_forward", "vllm::olmo_hybrid_gdn_full_forward",
"vllm::kda_attention", "vllm::kda_attention",
"vllm::sparse_attn_indexer", "vllm::sparse_attn_indexer",
......
...@@ -618,53 +618,19 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase): ...@@ -618,53 +618,19 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
# ============================================================ # ============================================================
# Part 2: Core Attention # Part 2: Core Attention
# ============================================================ # ============================================================
forward_context = get_forward_context()
attn_metadata_raw = forward_context.attn_metadata
core_attn_out = torch.zeros( core_attn_out = torch.zeros(
(num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim), (num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim),
dtype=hidden_states.dtype, dtype=hidden_states.dtype,
device=hidden_states.device, device=hidden_states.device,
) )
z = torch.empty_like(core_attn_out) z = torch.empty_like(core_attn_out)
if attn_metadata_raw is not None:
assert isinstance(attn_metadata_raw, dict)
attn_metadata = attn_metadata_raw[self.prefix]
# TODO: xpu does not support this param yet
spec_sequence_masks = attn_metadata.spec_sequence_masks # type: ignore[attr-defined]
assert spec_sequence_masks is None
conv_weights = self.conv1d.weight.view(
self.conv1d.weight.size(0), self.conv1d.weight.size(2)
)
conv_state = self.kv_cache[0]
ssm_state = self.kv_cache[1]
torch.ops._xpu_C.gdn_attention( torch.ops.vllm.gdn_attention_core_xpu(
core_attn_out, core_attn_out,
z, z,
projected_states_qkvz, projected_states_qkvz,
projected_states_ba, projected_states_ba,
self.num_k_heads, self.prefix,
self.num_v_heads,
self.head_k_dim,
self.head_v_dim,
conv_state=conv_state,
ssm_state=ssm_state,
conv_weights=conv_weights,
conv_bias=self.conv1d.bias,
activation=self.activation,
A_log=self.A_log,
dt_bias=self.dt_bias,
num_prefills=attn_metadata.num_prefills, # type: ignore[attr-defined]
num_decodes=attn_metadata.num_decodes, # type: ignore[attr-defined]
has_initial_state=attn_metadata.has_initial_state, # type: ignore[attr-defined]
non_spec_query_start_loc=attn_metadata.non_spec_query_start_loc, # type: ignore[attr-defined]
non_spec_state_indices_tensor=attn_metadata.non_spec_state_indices_tensor, # type: ignore[attr-defined]
num_actual_tokens=attn_metadata.num_actual_tokens, # type: ignore[attr-defined]
tp_size=self.tp_size,
reorder_input=not self.gqa_interleaved_layout,
) )
# ============================================================ # ============================================================
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment