Unverified Commit 888cb175 authored by YanbingJiang, committed by GitHub

Add intel_amx backend for Radix Attention for CPU (#6408)


Co-authored-by: Chunyuan WU <chunyuan.wu@intel.com>
Co-authored-by: Thien Tran <gau.nernst@yahoo.com.sg>
parent e39bca07
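
This commit adds an IntelAMXAttnBackend that dispatches Radix Attention to sgl_kernel's CPU (Intel AMX) decode/extend kernels, introduces a support_triton() hook so non-Triton backends can skip Triton-only code paths, makes intel_amx the default attention backend when the device is CPU, and falls back to torch_native when the platform lacks AMX support. The following is a minimal, self-contained sketch (not part of the commit) of that selection and fallback behaviour; Args and pick_attention_backend are illustrative stand-ins for sglang's ServerArgs defaulting and ModelRunner fallback.

# A sketch of the backend selection behaviour added by this commit; Args and
# pick_attention_backend are illustrative stand-ins, not sglang code.
from dataclasses import dataclass
from typing import Optional


@dataclass
class Args:
    device: str = "cpu"
    attention_backend: Optional[str] = None


def pick_attention_backend(args: Args, has_amx: bool) -> str:
    # ServerArgs: default to intel_amx when running on CPU with no explicit choice.
    if args.device == "cpu" and args.attention_backend is None:
        args.attention_backend = "intel_amx"
    # ModelRunner: fall back to torch_native when AMX (or the sgl_kernel CPU ops)
    # is unavailable on the current platform.
    if (
        args.device == "cpu"
        and args.attention_backend == "intel_amx"
        and not has_amx
    ):
        args.attention_backend = "torch_native"
    return args.attention_backend


print(pick_attention_backend(Args(), has_amx=True))   # intel_amx
print(pick_attention_backend(Args(), has_amx=False))  # torch_native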

@@ -109,3 +109,7 @@ class AttentionBackend(ABC):
     ):
         """Run a forward for extend."""
         raise NotImplementedError()
+
+    def support_triton(self):
+        """Check if the current backend supports triton."""
+        return True

from __future__ import annotations

from typing import TYPE_CHECKING

import torch

from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.model_executor.forward_batch_info import ForwardBatch

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner


class IntelAMXAttnBackend(AttentionBackend):
    def __init__(self, model_runner: ModelRunner):
        import sgl_kernel

        super().__init__()
        self.forward_metadata = None
        self.device = model_runner.device

        self.num_head = (
            model_runner.model_config.num_attention_heads // model_runner.tp_size
        )

        self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]

        self.decode_attention_fwd = torch.ops.sgl_kernel.decode_attention_cpu
        self.extend_attention_fwd = torch.ops.sgl_kernel.extend_attention_cpu

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Init the metadata for a forward pass."""
        bs = forward_batch.batch_size
        attn_logits = torch.zeros(
            (
                bs,
                self.num_head,
                8,  # self.num_kv_splits,
                self.v_head_dim + 1,
            ),
            dtype=torch.float32,
            device=self.device,
        )

        if forward_batch.forward_mode.is_decode_or_idle():
            max_extend_len = None
        else:
            max_extend_len = torch.max(forward_batch.extend_seq_lens).item()

        self.forward_metadata = (attn_logits, max_extend_len)

    def forward_extend(
        self,
        q,
        k,
        v,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        if save_kv_cache:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, v
            )

        _, max_extend_len = self.forward_metadata
        self.extend_attention_fwd(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            k,
            v,
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
            forward_batch.req_to_token_pool.req_to_token,
            forward_batch.req_pool_indices,
            forward_batch.seq_lens,
            forward_batch.extend_seq_lens,
            forward_batch.extend_start_loc,
            max_extend_len,
            layer.scaling,
            layer.logit_cap,
        )
        return o

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        attn_logits, _ = self.forward_metadata

        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)

        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        self.decode_attention_fwd(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            k,
            v,
            forward_batch.out_cache_loc,
            attn_logits,
            forward_batch.req_to_token_pool.req_to_token,
            forward_batch.req_pool_indices,
            forward_batch.seq_lens,
            layer.scaling,
            layer.logit_cap,
        )
        return o

    def support_triton(self):
        return False
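
A detail worth noting in both forward paths above: the output buffer is sized by the value head dim, which can differ from the query/key head dim, as in MLA models (the only MLA path permitted on CPU by the ModelRunner change below). A small, self-contained sketch of that allocation logic, with illustrative sizes not taken from any particular model:

# Illustrative only: why the output is allocated with v_head_dim instead of
# reusing q's shape when qk_head_dim != v_head_dim.
import torch

num_tokens, tp_q_head_num, qk_head_dim, v_head_dim = 4, 16, 192, 128
q = torch.randn(num_tokens, tp_q_head_num * qk_head_dim)

if qk_head_dim != v_head_dim:
    o = q.new_empty((q.shape[0], tp_q_head_num * v_head_dim))
else:
    o = torch.empty_like(q)

print(q.shape, o.shape)  # torch.Size([4, 3072]) torch.Size([4, 2048])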

@@ -265,3 +265,6 @@ class TorchNativeAttnBackend(AttentionBackend):
         )

         return o
+
+    def support_triton(self):
+        return False

@@ -60,7 +60,7 @@ from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, Forw
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import flatten_nested_list, get_compiler_backend
+from sglang.srt.utils import flatten_nested_list, get_compiler_backend, support_triton

 if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput

@@ -1257,7 +1257,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.extend_input_logprob_token_ids = extend_input_logprob_token_ids

         # Write to req_to_token_pool
-        if global_server_args_dict["attention_backend"] != "torch_native":
+        if support_triton(global_server_args_dict.get("attention_backend")):
             # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
             write_req_to_token_pool_triton[(bs,)](

@@ -39,7 +39,7 @@ import triton
 import triton.language as tl

 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
-from sglang.srt.utils import flatten_nested_list, get_compiler_backend
+from sglang.srt.utils import flatten_nested_list, get_compiler_backend, support_triton

 if TYPE_CHECKING:
     from sglang.srt.layers.attention.base_attn_backend import AttentionBackend

@@ -351,7 +351,7 @@ class ForwardBatch:
             ret.extend_prefix_lens = torch.tensor(
                 batch.extend_prefix_lens, dtype=torch.int32
             ).to(device, non_blocking=True)
-            if model_runner.server_args.attention_backend != "torch_native":
+            if support_triton(model_runner.server_args.attention_backend):
                 ret.extend_num_tokens = batch.extend_num_tokens
                 positions, ret.extend_start_loc = compute_position_triton(
                     ret.extend_prefix_lens,

@@ -91,6 +91,7 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
     MultiprocessingSerializer,
+    cpu_has_amx_support,
     enable_show_time_cost,
     get_available_gpu_memory,
     get_bool_env_var,

@@ -317,6 +318,16 @@ class ModelRunner:
     def model_specific_adjustment(self):
         server_args = self.server_args

+        if (
+            server_args.attention_backend == "intel_amx"
+            and server_args.device == "cpu"
+            and not cpu_has_amx_support()
+        ):
+            logger.info(
+                "The current platform does not support Intel AMX; falling back to the torch_native backend."
+            )
+            server_args.attention_backend = "torch_native"
+
         if server_args.attention_backend is None:
             """
             Auto select the fastest attention backend.

@@ -369,7 +380,10 @@ class ModelRunner:
                     f"Invalid attention backend for MLA: {server_args.attention_backend}"
                 )
             else:
-                raise ValueError("MLA optimization not supported on CPU.")
+                if server_args.attention_backend != "intel_amx":
+                    raise ValueError(
+                        "MLA optimization is not supported on CPU except for the intel_amx backend."
+                    )

         if (
             server_args.attention_backend == "fa3"

@@ -1067,6 +1081,13 @@ class ModelRunner:
                 )
             return CutlassMLABackend(self)
+        elif self.server_args.attention_backend == "intel_amx":
+            from sglang.srt.layers.attention.intel_amx_backend import (
+                IntelAMXAttnBackend,
+            )
+
+            logger.info("Intel AMX attention backend is enabled.")
+            return IntelAMXAttnBackend(self)
         else:
             raise ValueError(
                 f"Invalid attention backend: {self.server_args.attention_backend}"

@@ -323,6 +323,11 @@ class ServerArgs:
             self.sampling_backend = "pytorch"

         # Set kernel backends
+        if self.device == "cpu":
+            if self.attention_backend is None:
+                self.attention_backend = "intel_amx"
+            self.sampling_backend = "pytorch"
+
         if self.sampling_backend is None:
             self.sampling_backend = (
                 "flashinfer" if is_flashinfer_available() else "pytorch"

@@ -993,6 +998,7 @@ class ServerArgs:
                 "fa3",
                 "flashmla",
                 "cutlass_mla",
+                "intel_amx",
             ],
             default=ServerArgs.attention_backend,
             help="Choose the kernels for attention layers.",

@@ -2225,3 +2225,21 @@ def bind_or_assign(target, source):
         return target
     else:
         return source
+
+
+def support_triton(backend: str) -> bool:
+    return backend not in ["torch_native", "intel_amx"]
+
+
+try:
+    import sgl_kernel
+
+    is_intel_amx_backend_available = hasattr(
+        torch.ops.sgl_kernel, "convert_weight_packed"
+    )
+except Exception:
+    is_intel_amx_backend_available = False
+
+
+def cpu_has_amx_support():
+    return torch._C._cpu._is_amx_tile_supported() and is_intel_amx_backend_available
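
A quick way to check whether the new backend is usable on a given host, using the helper added above (this assumes sglang is installed; if the sgl_kernel CPU ops are missing, cpu_has_amx_support() simply returns False):

# Availability check for the intel_amx backend on the current machine.
import torch
from sglang.srt.utils import cpu_has_amx_support

print("CPU exposes AMX tiles:", torch._C._cpu._is_amx_tile_supported())
print("intel_amx backend usable:", cpu_has_amx_support())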