Unverified Commit b113c72e authored by Meng, Hengyu, committed by GitHub

Init attention backend for Intel XPU (#10656)


Co-authored-by: guangyey <guangye.yu@intel.com>
Co-authored-by: DiweiSun <105627594+DiweiSun@users.noreply.github.com>
parent fb6cc7b0
...@@ -24,6 +24,8 @@ FILES_TO_UPDATE = docker/Dockerfile.rocm \
docs/get_started/install.md \
docs/platforms/amd_gpu.md \
docs/platforms/ascend_npu.md \
docs/platforms/cpu_server.md \
docs/platforms/xpu.md \
benchmark/deepseek_v3/README.md
update: ## Update version numbers across project files. Usage: make update <new_version>
...
...@@ -48,7 +48,7 @@ RUN --mount=type=secret,id=github_token \
cd /home/sdp && \
. /home/sdp/miniforge3/bin/activate && \
conda activate py${PYTHON_VERSION} && \
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/xpu
pip3 install torch==2.8.0+xpu torchao torchvision torchaudio pytorch-triton-xpu==3.4.0 --index-url https://download.pytorch.org/whl/xpu
RUN --mount=type=secret,id=github_token \
cd /home/sdp && \
...@@ -59,13 +59,8 @@ RUN --mount=type=secret,id=github_token \
cd sglang && cd python && \
cp pyproject_xpu.toml pyproject.toml && \
pip install . && \
echo "Cloning ${SG_LANG_KERNEL_REPO} from ${SG_LANG_KERNEL_BRANCH}" && \
pip install xgrammar --no-deps && \
git clone --branch ${SG_LANG_KERNEL_BRANCH} --single-branch ${SG_LANG_KERNEL_REPO} && \
cd sgl-kernel-xpu && \
pip install -v . && \
pip install msgspec blake3 py-cpuinfo compressed_tensors gguf partial_json_parser einops --root-user-action=ignore && \
pip uninstall pytorch-triton-xpu -y && \
pip install --pre pytorch-triton-xpu --index-url https://download.pytorch.org/whl/xpu && \
conda install libsqlite=3.48.0 -y && \
# Add environment setup commands to .bashrc again (in case it was overwritten)
echo ". /home/sdp/miniforge3/bin/activate; conda activate py${PYTHON_VERSION}; cd /home/sdp" >> /home/sdp/.bashrc
...
...@@ -26,6 +26,7 @@ The support matrix is split into two parts: MHA (standard attention) and MLA (mu
| **AITER (ROCm)** | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
| **Wave (ROCm)** | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
| **Ascend (NPU)** | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
| **Intel XPU** | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
### MLA Backends
...@@ -190,6 +191,13 @@ python3 -m sglang.launch_server \
--attention-backend ascend
```
- Intel XPU
```bash
python3 -m sglang.launch_server \
--model meta-llama/Meta-Llama-3.1-8B-Instruct \
--attention-backend intel_xpu
```
- Wave
```bash
python3 -m sglang.launch_server \
...
...@@ -75,6 +75,7 @@ Its core features include:
platforms/tpu.md
platforms/nvidia_jetson.md
platforms/ascend_npu.md
platforms/xpu.md
.. toctree::
:maxdepth: 1
...
# XPU
This document describes how to set up the [SGLang](https://github.com/sgl-project/sglang) environment and run LLM inference on Intel GPUs; see [Intel GPU support within the PyTorch ecosystem](https://docs.pytorch.org/docs/stable/notes/get_start_xpu.html) for more context.
Specifically, SGLang is optimized for [Intel® Arc™ Pro B-Series Graphics](https://www.intel.com/content/www/us/en/ark/products/series/242616/intel-arc-pro-b-series-graphics.html) and [Intel® Arc™ B-Series Graphics](https://www.intel.com/content/www/us/en/ark/products/series/240391/intel-arc-b-series-graphics.html).
## Optimized Model List
The following LLMs have been optimized on Intel GPU, and more are on the way:
| Model Name | BF16 |
|:---:|:---:|
| Llama-3.2-3B | [meta-llama/Llama-3.2-3B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct) |
| Llama-3.1-8B | [meta-llama/Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) |
| Qwen2.5-1.5B | [Qwen/Qwen2.5-1.5B](https://huggingface.co/Qwen/Qwen2.5-1.5B) |
**Note:** The model identifiers listed in the table above
have been verified on [Intel® Arc™ B580 Graphics](https://www.intel.com/content/www/us/en/products/sku/241598/intel-arc-b580-graphics/specifications.html).
## Installation
### Install From Source
Currently, SGLang on XPU only supports installation from source. Please refer to ["Getting Started on Intel GPU"](https://docs.pytorch.org/docs/stable/notes/get_start_xpu.html) to install the XPU dependencies.
```bash
# Create and activate a conda environment
conda create -n sgl-xpu python=3.12 -y
conda activate sgl-xpu
# Set the PyTorch XPU index as the primary pip install channel to avoid installing the larger CUDA-enabled wheels and to prevent potential runtime issues.
pip3 install torch==2.8.0+xpu torchao torchvision torchaudio pytorch-triton-xpu==3.4.0 --index-url https://download.pytorch.org/whl/xpu
pip3 install xgrammar --no-deps # xgrammar will introduce CUDA-enabled triton which might conflict with XPU
# Clone the SGLang code
git clone https://github.com/sgl-project/sglang.git
cd sglang
git checkout <YOUR-DESIRED-VERSION>
# Use dedicated toml file
cd python
cp pyproject_xpu.toml pyproject.toml
# Install SGLang dependent libs, and build SGLang main package
pip install --upgrade pip setuptools
pip install -v .
```
### Install Using Docker
The Docker image for XPU is under active development. Please stay tuned.
## Launch of the Serving Engine
Example command to launch SGLang serving:
```bash
# --tp 2: run on multiple GPUs with tensor parallelism
# --attention-backend intel_xpu: use the Intel-optimized XPU attention backend
# --page-size: the intel_xpu attention backend supports 32, 64, or 128
python -m sglang.launch_server \
    --model <MODEL_ID_OR_PATH> \
    --trust-remote-code \
    --disable-overlap-schedule \
    --device xpu \
    --host 0.0.0.0 \
    --tp 2 \
    --attention-backend intel_xpu \
    --page-size 128
```
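For reference, a concrete invocation for one of the models verified above might look like the following. This is only an illustrative sketch: it assumes two Arc GPUs (adjust `--tp` to your GPU count) and uses one of the page sizes accepted by the `intel_xpu` backend.
```bash
python -m sglang.launch_server \
    --model meta-llama/Llama-3.1-8B-Instruct \
    --trust-remote-code \
    --disable-overlap-schedule \
    --device xpu \
    --host 0.0.0.0 \
    --tp 2 \
    --attention-backend intel_xpu \
    --page-size 128
```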
## Benchmarking with Requests
You can benchmark performance via the `bench_serving` script.
Run the following command in another terminal.
```bash
python -m sglang.bench_serving \
--dataset-name random \
--random-input-len 1024 \
--random-output-len 1024 \
--num-prompts 1 \
--request-rate inf \
--random-range-ratio 1.0
```
Detailed explanations of the parameters can be obtained with:
```bash
python -m sglang.bench_serving -h
```
Additionally, requests can be formed with the
[OpenAI Completions API](https://docs.sglang.ai/basic_usage/openai_api_completions.html)
and sent via the command line (e.g. using `curl`) or from your own script.
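For example, a minimal `curl` request might look like the sketch below. It assumes the server was launched locally on the default port 30000 and that the model name matches the one passed to `--model`; adjust the host, port, and model identifier for your setup.
```bash
curl http://localhost:30000/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "meta-llama/Llama-3.1-8B-Instruct",
        "prompt": "The capital of France is",
        "max_tokens": 32,
        "temperature": 0
    }'
```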
# xpu is not enabled in public vllm and torch whl,
# need to follow https://docs.vllm.ai/en/latest/getting_started/xpu-installation.html install vllm
[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"
...@@ -17,6 +15,10 @@ classifiers = [
]
dependencies = [
"torch==2.8.0",
"torchaudio==2.8.0",
"torchvision",
"sgl-kernel @ git+https://github.com/sgl-project/sgl-kernel-xpu.git",
"IPython", "IPython",
"aiohttp", "aiohttp",
"anthropic>=0.20.0", "anthropic>=0.20.0",
...@@ -61,7 +63,7 @@ dependencies = [ ...@@ -61,7 +63,7 @@ dependencies = [
"transformers==4.57.1", "transformers==4.57.1",
"uvicorn", "uvicorn",
"uvloop", "uvloop",
"xgrammar==0.1.25", # "xgrammar==0.1.24", , xgrammar depends on CUDA PyTorch and Triton only
"grpcio==1.75.1", # keep it align with compile_proto.py "grpcio==1.75.1", # keep it align with compile_proto.py
"grpcio-tools==1.75.1", # keep it align with compile_proto.py "grpcio-tools==1.75.1", # keep it align with compile_proto.py
"grpcio-reflection==1.75.1", # required by srt/entrypoints/grpc_server.py "grpcio-reflection==1.75.1", # required by srt/entrypoints/grpc_server.py
......
...@@ -272,7 +272,7 @@ def prepare_synthetic_inputs_for_latency_test(
def extend(reqs, model_runner):
# Create dummy tree_cache for benchmarks (no prefix caching, just allocation)
dummy_tree_cache = SimpleNamespace(
page_size=1,
page_size=model_runner.server_args.page_size,
device=model_runner.device,
token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,
)
...
...@@ -50,11 +50,13 @@ from sglang.srt.utils import (
is_hip,
is_npu,
is_shm_available,
is_xpu,
supports_custom_op,
)
_is_npu = is_npu()
_is_cpu = is_cpu()
_is_xpu = is_xpu()
_supports_custom_op = supports_custom_op()
...@@ -694,7 +696,7 @@ class GroupCoordinator:
)
def all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
if _is_npu or not _supports_custom_op:
if _is_npu or _is_xpu or not _supports_custom_op:
self._all_gather_into_tensor(output, input)
else:
torch.ops.sglang.reg_all_gather_into_tensor(
...@@ -1298,7 +1300,7 @@ def init_model_parallel_group(
group_ranks=group_ranks,
local_rank=local_rank,
torch_distributed_backend=backend,
use_pynccl=not _is_npu,
use_pynccl=not (_is_npu or _is_xpu),
use_pymscclpp=use_mscclpp_allreduce,
use_custom_allreduce=use_custom_allreduce,
use_torch_symm_mem=use_symm_mem_allreduce,
...
...@@ -217,3 +217,10 @@ def attn_backend_wrapper(runner: "ModelRunner", full_attn_backend: "AttentionBac
)
return full_attn_backend
@register_attention_backend("intel_xpu")
def create_intel_xpu_backend(runner):
from sglang.srt.layers.attention.xpu_backend import XPUAttentionBackend
return XPUAttentionBackend(runner)
...@@ -12,6 +12,8 @@ import triton
import triton.language as tl
from einops import rearrange
from sglang.srt.utils import device_context
def rms_norm_ref(
x,
...@@ -157,7 +159,7 @@ def _layer_norm_fwd(
# heuristics for number of warps
num_warps = min(max(BLOCK_N // 256, 1), 8)
grid = (M, ngroups)
with torch.get_device_module(x.device).device(x.device.index):
with device_context(x.device):
_layer_norm_fwd_1pass_kernel[grid](
x,
out,
...
from __future__ import annotations
from typing import TYPE_CHECKING, Optional
import torch
from sglang.srt.configs.model_config import AttentionArch
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.flashattention_backend import (
FlashAttentionMetadata,
make_local_attention_virtual_batches,
merge_state_v2_wrapper,
prepare_swa_spec_page_table_triton,
)
from sglang.srt.managers.schedule_batch import get_global_server_args
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
from sgl_kernel import merge_state_v2
from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
class XPUAttentionBackend(AttentionBackend):
"""XPU FlashAttention backend, currently based on FlashAttentionBackend, will be refactored later.
TODO:
- Prefill and Decode disaggregation, currently only chunked prefill is supported
- Speculative Decoding support
- XPU Graph support, see https://github.com/pytorch/pytorch/issues/162143
- MLA support
"""
def __init__(
self,
model_runner: ModelRunner,
skip_prefill: bool = False,
speculative_step_id=0,
topk=0,
speculative_num_steps=0,
):
super().__init__()
assert not (
model_runner.sliding_window_size is not None
and model_runner.model_config.is_encoder_decoder
), "Sliding window and cross attention are not supported together"
self.forward_metadata: FlashAttentionMetadata = None
# extra metadata for handling speculative decoding topk > 1, extended draft decode and verify
self.forward_metadata_spec_decode_expand: FlashAttentionMetadata = None
self.max_context_len = model_runner.model_config.context_len
self.device = model_runner.device
self.decode_cuda_graph_metadata = {}
self.target_verify_metadata = {}
self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.kv_cache_dtype = model_runner.kv_cache_dtype
self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype
self.page_size = model_runner.page_size
self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
assert (
self.use_mla is False
), "XPUAttentionBackend doesn't support MLA yet, please use --attention-backend triton instead."
self.skip_prefill = skip_prefill
self.is_hybrid = model_runner.is_hybrid
if self.is_hybrid:
self.full_to_swa_index_mapping = (
model_runner.token_to_kv_pool.full_to_swa_index_mapping
)
self.topk = model_runner.server_args.speculative_eagle_topk or 0
self.speculative_num_steps = speculative_num_steps
self.speculative_num_draft_tokens = (
model_runner.server_args.speculative_num_draft_tokens
)
self.speculative_step_id = speculative_step_id
# Local attention settings
self.attention_chunk_size = (
model_runner.attention_chunk_size
if hasattr(model_runner, "attention_chunk_size")
else None
)
# For each layer, the sliding_window_size can be different. This is only used for preparing SWA metadata.
# We use `layer.sliding_window_size` to decide whether to use SWA for each layer.
self.sliding_window_size = model_runner.sliding_window_size
self.has_swa = (
self.sliding_window_size is not None and self.sliding_window_size > -1
)
def init_forward_metadata(self, forward_batch: ForwardBatch):
"""Initialize forward metadata hence all layers in the forward pass can reuse it."""
metadata = FlashAttentionMetadata()
seqlens_in_batch = forward_batch.seq_lens
batch_size = forward_batch.batch_size
device = seqlens_in_batch.device
if forward_batch.forward_mode.is_decode_or_idle():
# Draft Decode
if forward_batch.spec_info is not None:
assert (
False
), "XPUAttentionBackend doesn't support speculative decoding yet, please use --attention-backend triton instead."
if self.topk <= 1:
metadata.cache_seqlens_int32 = (
seqlens_in_batch + (self.speculative_step_id + 1)
).to(torch.int32)
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
self.speculative_step_id + 1
)
metadata.cu_seqlens_q = torch.arange(
0, batch_size + 1, dtype=torch.int32, device=device
)
metadata.cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
),
(1, 0),
)
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
else:
metadata.cache_seqlens_int32 = (seqlens_in_batch).to(torch.int32)
metadata.max_seq_len_q = self.topk
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
metadata.cu_seqlens_q = torch.arange(
0,
batch_size * self.topk + 1,
step=self.topk,
dtype=torch.int32,
device=device,
)
metadata.cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
),
(1, 0),
)
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
metadata_expand = FlashAttentionMetadata()
decode_length = self.speculative_step_id + 1
metadata_expand.cache_seqlens_int32 = torch.full(
(seqlens_in_batch.numel() * self.topk,),
decode_length,
device=device,
dtype=torch.int32,
)
metadata_expand.max_seq_len_q = 1
metadata_expand.cu_seqlens_q = torch.arange(
0,
metadata_expand.cache_seqlens_int32.numel() + 1,
dtype=torch.int32,
device=device,
)
metadata_expand.cu_seqlens_k = torch.arange(
0,
metadata_expand.cache_seqlens_int32.numel() * decode_length + 1,
step=decode_length,
dtype=torch.int32,
device=device,
)
# shape: [bs, num_steps, topk] -> [bs x topk, num_steps]
cache_loc = forward_batch.out_cache_loc.view(
-1, self.speculative_num_steps
)
metadata_expand.page_table = (
cache_loc[:, :decode_length].contiguous().to(torch.int32)
)
self.forward_metadata_spec_decode_expand = metadata_expand
else:
# Normal Decode
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
metadata.cu_seqlens_q = torch.arange(
0, batch_size + 1, dtype=torch.int32, device=device
)
metadata.cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
)
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
# TODO: we need to test this part for llama 4 eagle case
self._init_local_attn_metadata(forward_batch, metadata, device)
elif forward_batch.forward_mode.is_target_verify():
if self.topk <= 1:
metadata.cache_seqlens_int32 = (
forward_batch.seq_lens + self.speculative_num_draft_tokens
).to(torch.int32)
metadata.max_seq_len_q = self.speculative_num_draft_tokens
metadata.max_seq_len_k = (
forward_batch.seq_lens_cpu.max().item()
+ self.speculative_num_draft_tokens
)
metadata.cu_seqlens_q = torch.arange(
0,
batch_size * self.speculative_num_draft_tokens + 1,
self.speculative_num_draft_tokens,
dtype=torch.int32,
device=device,
)
metadata.cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
),
(1, 0),
)
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
self._init_local_attn_metadata(forward_batch, metadata, device)
else:
metadata.cache_seqlens_int32 = forward_batch.seq_lens.to(torch.int32)
metadata.max_seq_len_q = self.speculative_num_draft_tokens
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
metadata.cu_seqlens_q = torch.arange(
0,
batch_size * self.speculative_num_draft_tokens + 1,
step=self.speculative_num_draft_tokens,
dtype=torch.int32,
device=device,
)
metadata.cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
),
(1, 0),
)
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
metadata_expand = FlashAttentionMetadata()
metadata_expand.max_seq_len_q = 1
metadata_expand.cu_seqlens_q = torch.arange(
0,
forward_batch.seq_lens.numel() * self.speculative_num_draft_tokens
+ 1,
dtype=torch.int32,
device=device,
)
# create expand page table
offsets = torch.arange(
self.speculative_num_draft_tokens, device=device
).unsqueeze(
0
) # shape: (1, self.speculative_num_draft_tokens)
cols = offsets.expand(
forward_batch.seq_lens.numel(), -1
) + forward_batch.seq_lens.unsqueeze(1)
cum_len = torch.nn.functional.pad(
torch.cumsum(
(
forward_batch.seq_lens + self.speculative_num_draft_tokens
).repeat_interleave(self.speculative_num_draft_tokens),
dim=0,
),
(1, 0),
)[:-1]
mask_extraction_indices = (
cols.repeat_interleave(self.speculative_num_draft_tokens, dim=0)
+ cum_len[:, None]
).view(1, -1)
mask = forward_batch.spec_info.custom_mask[
mask_extraction_indices
].view(
-1, self.speculative_num_draft_tokens
) # (bsz * draft_num, draft_num)
# shift table indices to avoid padding
# non_masked_page_table [[8, 9, 10], mask (display with int format) [[1, 0, 0],
# [8, 9, 10], [1, 1, 0],
# [8, 9, 10]] [1, 0, 1]]
# if masked with padding [[8, 0, 0], our mask without padding [[8, 9, 10],
# [8, 9, 0], [8, 9, 10],
# [8, 0, 10]] [8, 10, 9]]
# note here cache_seqlens_int32 is [1, 2, 2] so extra page indices will be ignored in each row
col_indices = offsets.expand(
mask.shape[0], self.speculative_num_draft_tokens
)
# Build keys: if an entry is valid (mask==True), keep its original index;
# if not, add self.speculative_num_draft_tokens so that it sorts after all valid entries.
keys = torch.where(
mask, col_indices, col_indices + self.speculative_num_draft_tokens
)
_, sort_order = torch.sort(keys, dim=1)
non_masked_page_table = (
forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, :
]
.gather(1, cols)
.repeat_interleave(self.speculative_num_draft_tokens, dim=0)
) # (bsz, draft_num)
metadata_expand.page_table = non_masked_page_table.gather(1, sort_order)
metadata_expand.cache_seqlens_int32 = mask.sum(dim=1).to(torch.int32)
metadata_expand.cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(
metadata_expand.cache_seqlens_int32, dim=0, dtype=torch.int32
),
(1, 0),
)
self.forward_metadata_spec_decode_expand = metadata_expand
if self.has_swa:
self._init_sliding_window_attn_spec_metadata(
metadata, metadata_expand
)
elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
metadata.cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
)
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
if (
any(forward_batch.extend_prefix_lens_cpu)
or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
):
extend_seq_lens = forward_batch.extend_seq_lens
metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
metadata.cu_seqlens_q = torch.nn.functional.pad(
torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
)
else:
metadata.max_seq_len_q = metadata.max_seq_len_k
metadata.cu_seqlens_q = metadata.cu_seqlens_k
# Setup local attention if enabled
if forward_batch.forward_mode == ForwardMode.EXTEND:
self._init_local_attn_metadata(forward_batch, metadata, device)
# Encoder metadata for cross attention
if forward_batch.encoder_lens is not None:
assert (
forward_batch.encoder_lens.numel() == 1
), "Only encoder size 1 is supported for now"
metadata.encoder_lens_int32 = forward_batch.encoder_lens.to(torch.int32)
metadata.encoder_cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32),
(1, 0),
)
metadata.encoder_max_seq_len_k = metadata.encoder_lens_int32.max().item()
metadata.encoder_page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.encoder_max_seq_len_k
]
# Currently only support forward_batch.encoder_lens.numel() == 1
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices,
metadata.encoder_max_seq_len_k : (
metadata.encoder_max_seq_len_k + metadata.max_seq_len_k
),
]
# Convert the page table to a strided format which is needed by FA3 API
if self.page_size > 1:
self.strided_indices = torch.arange(
0, metadata.page_table.shape[1], self.page_size, device=self.device
)
metadata.page_table = (
metadata.page_table[:, self.strided_indices] // self.page_size
)
self.forward_metadata = metadata
def forward_extend(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
# For multi-head latent attention
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
sinks: Optional[torch.Tensor] = None,
):
if k is not None:
assert v is not None
if save_kv_cache:
cache_loc = (
forward_batch.out_cache_loc
if not layer.is_cross_attention
else forward_batch.encoder_out_cache_loc
)
if not self.use_mla:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
)
else:
forward_batch.token_to_kv_pool.set_mla_kv_buffer(
layer,
cache_loc,
k,
k_rope,
)
# Use precomputed metadata across all layers
metadata = self.forward_metadata
# Calculate window size (can be moved to metadata if layer properties don't change)
# We don't subtract 1 from layer.sliding_window_size here because
# model.get_attention_sliding_window_size() already subtracts 1; the window is inclusive on both sides.
is_swa = (
layer.sliding_window_size is not None and layer.sliding_window_size > -1
)
window_size = (layer.sliding_window_size, 0) if is_swa else (-1, -1)
# currently no FP8 KV cache supported
k_descale, v_descale = None, None
# # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
# # has corresponding quantization method so that layer.k_scale is not None,
# # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case.
# if self.kv_cache_dtype_str != "auto" and layer.head_dim <= 256:
# if layer.k_scale is not None:
# descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
# k_descale = layer.k_scale.expand(descale_shape)
# v_descale = layer.v_scale.expand(descale_shape)
# q = q.to(self.kv_cache_dtype)
# q_rope = q_rope.to(self.kv_cache_dtype) if q_rope is not None else None
# k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None
causal = not layer.is_cross_attention
# Check if we should use local attention
use_local_attn = (
self.attention_chunk_size is not None
and metadata.local_attn_metadata is not None
and (hasattr(layer, "use_irope") and layer.use_irope)
)
# We do cascade attention for Target Verify with topk > 1
# We don't use cascade attention for Sliding Window Attention:
# - Different window sizes should be passed in for each q in the first stage of cascade attention, but FA3 interface doesn't support pass in a list of window sizes.
# - The overhead of duplicated computation of the common prefix part is small for sliding window layers (seq_len <= window_size), so we can just expand it.
use_cascade_attn = (
forward_batch.forward_mode.is_target_verify()
and self.topk > 1
and not is_swa
)
# For fa3 interface version compatibility, we put new fields into conditional keyword args
kwargs = {}
if sinks is not None:
kwargs["sinks"] = sinks
# Get the appropriate page table based on whether we're using local attention
if use_local_attn:
local_metadata = metadata.local_attn_metadata
page_table = local_metadata.local_block_table
cu_seqlens_q = local_metadata.local_query_start_loc
cache_seqlens = local_metadata.local_seqused_k
max_seqlen_q = local_metadata.local_max_query_len
elif is_swa and metadata.swa_spec_metadata is not None:
swa_spec_metadata = metadata.swa_spec_metadata
page_table = swa_spec_metadata.page_table
cu_seqlens_q = swa_spec_metadata.cu_seqlens_q
cache_seqlens = swa_spec_metadata.cache_seqlens_int32
max_seqlen_q = swa_spec_metadata.max_seq_len_q
cu_seqlens_k = swa_spec_metadata.cu_seqlens_k
else:
page_table = metadata.page_table
cu_seqlens_q = metadata.cu_seqlens_q
cache_seqlens = metadata.cache_seqlens_int32
max_seqlen_q = metadata.max_seq_len_q
cu_seqlens_k = metadata.cu_seqlens_k
# Use Flash Attention for prefill
if not self.use_mla:
# Do multi-head attention
key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
layer.layer_id
)
key_cache = key_cache.view(
-1, self.page_size, layer.tp_k_head_num, layer.head_dim
)
value_cache = value_cache.view(
-1, self.page_size, layer.tp_v_head_num, layer.head_dim
)
if layer.is_cross_attention:
page_table = metadata.encoder_page_table
cache_seqlens = metadata.encoder_lens_int32
cu_seqlens_k = metadata.encoder_cu_seqlens_k
window_size = (-1, -1)
result = flash_attn_with_kvcache(
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
k_cache=key_cache,
v_cache=value_cache,
page_table=page_table,
cache_seqlens=cache_seqlens,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
max_seqlen_q=max_seqlen_q,
softmax_scale=layer.scaling,
causal=False if use_cascade_attn else causal,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=use_cascade_attn,
**kwargs,
)
if use_cascade_attn:
o, softmax_lse, *rest = result
o_expand, softmax_lse_expand, *rest_expand = flash_attn_with_kvcache(
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
k_cache=key_cache,
v_cache=value_cache,
page_table=self.forward_metadata_spec_decode_expand.page_table,
cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
softmax_scale=layer.scaling,
causal=False,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=True,
**kwargs,
)
o, _ = merge_state_v2_wrapper(
o,
softmax_lse.T.contiguous(),
o_expand,
softmax_lse_expand.T.contiguous(),
)
else:
o = result
else:
if (
forward_batch.attn_attend_prefix_cache is not None
and not forward_batch.forward_mode.is_target_verify()
and not forward_batch.forward_mode.is_draft_extend()
):
# Do multi-head attention with chunked prefix cache
if forward_batch.attn_attend_prefix_cache:
assert not get_global_server_args().disable_chunked_prefix_cache
# MHA for chunked prefix kv cache when running model with MLA
assert forward_batch.prefix_chunk_idx is not None
assert forward_batch.prefix_chunk_cu_seq_lens is not None
assert forward_batch.prefix_chunk_max_seq_lens is not None
chunk_idx = forward_batch.prefix_chunk_idx
assert chunk_idx >= 0
assert forward_batch.mha_return_lse
output = flash_attn_varlen_func(
q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
max_seqlen_q=metadata.max_seq_len_q,
max_seqlen_k=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
softmax_scale=layer.scaling,
causal=False,
return_softmax_lse=True,
)
else:
# MHA for extend part of sequence without attending prefix kv cache
output = flash_attn_varlen_func(
q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k=metadata.cu_seqlens_q,
max_seqlen_q=metadata.max_seq_len_q,
max_seqlen_k=metadata.max_seq_len_q,
softmax_scale=layer.scaling,
causal=True,
return_softmax_lse=forward_batch.mha_return_lse,
)
if forward_batch.mha_return_lse:
output, lse, *rest = output
lse = torch.transpose(lse, 0, 1).contiguous()
return output, lse
return output
else:
# Do absorbed multi-latent attention
kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(
layer.layer_id
).to(q.dtype)
k_rope = kv_cache[:, :, layer.v_head_dim :]
c_kv = kv_cache[:, :, : layer.v_head_dim]
k_rope_cache = k_rope.view(
-1,
self.page_size,
layer.tp_k_head_num,
layer.head_dim - layer.v_head_dim,
)
c_kv_cache = c_kv.view(
-1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
)
if q_rope is not None:
q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
q_rope = q_rope.view(
-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
)
else:
q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
q_nope = q_all[:, :, : layer.v_head_dim]
q_rope = q_all[:, :, layer.v_head_dim :]
result = flash_attn_with_kvcache(
q=q_rope,
k_cache=k_rope_cache,
v_cache=c_kv_cache,
qv=q_nope,
page_table=page_table,
cache_seqlens=cache_seqlens,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
max_seqlen_q=max_seqlen_q,
softmax_scale=layer.scaling,
causal=False if use_cascade_attn else causal,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=use_cascade_attn,
)
if use_cascade_attn:
o, softmax_lse, *rest = result
o_expand, softmax_lse_expand, *rest_expand = (
flash_attn_with_kvcache(
q=q_rope,
k_cache=k_rope_cache,
v_cache=c_kv_cache,
qv=q_nope,
page_table=self.forward_metadata_spec_decode_expand.page_table,
cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
softmax_scale=layer.scaling,
causal=False,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=True,
)
)
o, _ = merge_state_v2_wrapper(
o,
softmax_lse.T.contiguous(),
o_expand,
softmax_lse_expand.T.contiguous(),
)
else:
o = result
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
def forward_decode(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
# For multi-head latent attention
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
sinks: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if k is not None:
assert v is not None
if save_kv_cache:
cache_loc = (
forward_batch.out_cache_loc
if not layer.is_cross_attention
else forward_batch.encoder_out_cache_loc
)
if not self.use_mla:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
)
else:
forward_batch.token_to_kv_pool.set_mla_kv_buffer(
layer,
cache_loc,
k,
k_rope,
)
# Use precomputed metadata across all layers
metadata = self.forward_metadata
local_attn_metadata = getattr(metadata, "local_attn_metadata", None)
use_local_attn = (
self.attention_chunk_size is not None
and local_attn_metadata is not None
and (hasattr(layer, "use_irope") and layer.use_irope)
)
# When Spec Decode is enabled, forward_decode is called in two modes:
# 1. DRAFT_DECODE: we enable cascade attention when top_k > 1
# 2. IDLE: we don't need cascade attention; spec_info will be None in this case
use_cascade_attn = forward_batch.spec_info is not None and self.topk > 1
# Calculate window size (can be moved to metadata if layer properties don't change)
# We don't subtract 1 from layer.sliding_window_size here because
# model.get_attention_sliding_window_size() already subtracts 1; the window is inclusive on both sides.
window_size = (
(layer.sliding_window_size, 0)
if layer.sliding_window_size is not None and layer.sliding_window_size > -1
else (-1, -1)
)
causal = not layer.is_cross_attention
# For fa3 interface version compatibility, we put new fields into conditional keyword args
kwargs = {}
if sinks is not None:
kwargs["sinks"] = sinks
k_descale, v_descale = None, None
# only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
# has a corresponding quantization method so that layer.k_scale is not None,
# 3) layer.head_dim <= 256 since the fa3 kernel requires fp16/bf16 data types in this case.
if self.kv_cache_dtype_str != "auto" and layer.head_dim <= 256:
if layer.k_scale is not None:
descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
k_descale = layer.k_scale.expand(descale_shape)
v_descale = layer.v_scale.expand(descale_shape)
q = q.to(self.kv_cache_dtype)
q_rope = q_rope.to(self.kv_cache_dtype) if q_rope is not None else None
k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None
if not self.use_mla:
# Do multi-head attention
key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
layer.layer_id
)
key_cache = key_cache.view(
-1, self.page_size, layer.tp_k_head_num, layer.head_dim
)
value_cache = value_cache.view(
-1, self.page_size, layer.tp_v_head_num, layer.head_dim
)
if layer.is_cross_attention:
# Always use non-chunked logic for cross-attention
o = flash_attn_with_kvcache(
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
k_cache=key_cache,
v_cache=value_cache,
page_table=metadata.encoder_page_table,
cache_seqlens=metadata.encoder_lens_int32,
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k_new=metadata.encoder_cu_seqlens_k,
max_seqlen_q=1,
softmax_scale=layer.scaling,
causal=False,
window_size=(-1, -1),
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
**kwargs,
)
elif use_local_attn:
# Use chunked (local) attention batching for self-attention
o = flash_attn_with_kvcache(
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
k_cache=key_cache,
v_cache=value_cache,
page_table=local_attn_metadata.local_block_table,
cache_seqlens=local_attn_metadata.local_seqused_k,
cu_seqlens_q=local_attn_metadata.local_query_start_loc,
cu_seqlens_k_new=None,
max_seqlen_q=local_attn_metadata.local_max_query_len,
softmax_scale=layer.scaling,
causal=True,
window_size=(-1, -1),
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
**kwargs,
)
else:
page_table = metadata.page_table
cache_seqlens = metadata.cache_seqlens_int32
cu_seqlens_k = metadata.cu_seqlens_k
max_seqlen_q = metadata.max_seq_len_q
q_reshaped = q.contiguous().view(
-1, layer.tp_q_head_num, layer.head_dim
)
# Default: single-token self-attention
result = flash_attn_with_kvcache(
q=q_reshaped,
k_cache=key_cache,
v_cache=value_cache,
page_table=page_table,
cache_seqlens=cache_seqlens,
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k_new=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
softmax_scale=layer.scaling,
causal=False if use_cascade_attn else causal,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=use_cascade_attn,
**kwargs,
)
if use_cascade_attn:
o, softmax_lse, *rest = result
o_expand, softmax_lse_expand, *rest_expand = (
flash_attn_with_kvcache(
q=q_reshaped,
k_cache=key_cache,
v_cache=value_cache,
page_table=self.forward_metadata_spec_decode_expand.page_table,
cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
softmax_scale=layer.scaling,
causal=False,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=True,
**kwargs,
)
)
o, _ = merge_state_v2(
o,
softmax_lse.T.contiguous(),
o_expand,
softmax_lse_expand.T.contiguous(),
)
else:
o = result
else:
# Do absorbed multi-latent attention
kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
q.dtype
)
k_rope = kv_cache[:, :, layer.v_head_dim :]
c_kv = kv_cache[:, :, : layer.v_head_dim]
k_rope_cache = k_rope.view(
-1,
self.page_size,
layer.tp_k_head_num,
layer.head_dim - layer.v_head_dim,
)
c_kv_cache = c_kv.view(
-1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
)
if q_rope is not None:
q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
q_rope = q_rope.view(
-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
)
else:
q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
q_nope = q_all[:, :, : layer.v_head_dim]
q_rope = q_all[:, :, layer.v_head_dim :]
max_seqlen_q = metadata.max_seq_len_q
result = flash_attn_with_kvcache(
q=q_rope,
k_cache=k_rope_cache,
v_cache=c_kv_cache,
qv=q_nope,
page_table=metadata.page_table,
cache_seqlens=metadata.cache_seqlens_int32,
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k_new=metadata.cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
softmax_scale=layer.scaling,
causal=False if use_cascade_attn else causal,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=use_cascade_attn, # softmax_lse is needed for merge states
)
if use_cascade_attn:
o, softmax_lse, *rest = result
o_expand, softmax_lse_expand, *rest_expand = flash_attn_with_kvcache(
q=q_rope,
k_cache=k_rope_cache,
v_cache=c_kv_cache,
qv=q_nope,
page_table=self.forward_metadata_spec_decode_expand.page_table,
cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
softmax_scale=layer.scaling,
causal=False,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=True,
)
o, _ = merge_state_v2(
o,
softmax_lse.T.contiguous(),
o_expand,
softmax_lse_expand.T.contiguous(),
)
else:
o = result
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
def get_cuda_graph_seq_len_fill_value(self):
"""Get the fill value for sequence length in CUDA graph."""
return 1
def _init_local_attn_metadata(
self, forwardbatch: ForwardBatch, metadata: FlashAttentionMetadata, device
):
"""Centralized utility to initialize local_attn_metadata if chunked attention is enabled."""
if self.attention_chunk_size is None:
metadata.local_attn_metadata = None
return
cu_seqlens_q = metadata.cu_seqlens_q
cache_seqlens_int32 = metadata.cache_seqlens_int32
if self.is_hybrid:
page_table = self.full_to_swa_index_mapping[metadata.page_table].to(
torch.int32
)
else:
page_table = metadata.page_table
if cu_seqlens_q is None or cache_seqlens_int32 is None or page_table is None:
metadata.local_attn_metadata = None
return
cu_seqlens_q_np = cu_seqlens_q.cpu().numpy()
seq_lens_np = cache_seqlens_int32.cpu().numpy()
(
seqlens_q_local_np,
cu_seqlens_q_local_np,
seqlens_k_local_np,
block_table_local,
) = make_local_attention_virtual_batches(
self.attention_chunk_size,
cu_seqlens_q_np,
seq_lens_np,
page_table,
self.page_size,
)
local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(device),
local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),
local_block_table=block_table_local.to(device),
local_max_query_len=int(seqlens_q_local_np.max()),
local_max_seq_len=int(seqlens_k_local_np.max()),
)
metadata.local_attn_metadata = local_metadata
def _init_sliding_window_attn_spec_metadata(
self,
metadata: FlashAttentionMetadata,
metadata_expand: FlashAttentionMetadata,
metadata_swa: Optional[FlashAttentionMetadata] = None,
):
# TODO: support page_size > 1 for swa spec
assert (
self.page_size == 1
), "FlashAttention backend doesn't support topk > 1 speculative decoding with page size > 1 sliding window attention"
cache_seqlens_int32 = (
metadata.cache_seqlens_int32.repeat_interleave(
self.speculative_num_draft_tokens
)
+ metadata_expand.cache_seqlens_int32
)
cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(cache_seqlens_int32, dim=0, dtype=torch.int32), (1, 0)
)
bs = cache_seqlens_int32.shape[0]
page_table = (
metadata.page_table.new_zeros(
(bs, metadata.max_seq_len_k + metadata_expand.page_table.shape[1])
)
if metadata_swa is None
else metadata_swa.page_table
)
prepare_swa_spec_page_table_triton(
page_table,
metadata.page_table,
metadata_expand.page_table,
metadata.cache_seqlens_int32,
metadata_expand.cache_seqlens_int32,
self.speculative_num_draft_tokens,
)
if metadata_swa is None:
metadata_swa = FlashAttentionMetadata()
metadata_swa.max_seq_len_q = 1
metadata_swa.cu_seqlens_q = metadata_expand.cu_seqlens_q
metadata_swa.cache_seqlens_int32 = cache_seqlens_int32
metadata_swa.cu_seqlens_k = cu_seqlens_k
metadata_swa.page_table = page_table
else:
metadata_swa.cache_seqlens_int32.copy_(cache_seqlens_int32)
metadata_swa.cu_seqlens_k.copy_(cu_seqlens_k)
metadata.swa_spec_metadata = metadata_swa
...@@ -42,7 +42,7 @@ _is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
_is_xpu = is_xpu()
if _is_cuda:
if _is_cuda or _is_xpu:
# if _is_flashinfer_available:
# from flashinfer.norm import fused_add_rmsnorm
# else:
...@@ -52,13 +52,6 @@ if _is_cuda:
gemma_rmsnorm,
rmsnorm,
)
elif _is_xpu:
from sgl_kernel import (
fused_add_rmsnorm,
gemma_fused_add_rmsnorm,
gemma_rmsnorm,
rmsnorm,
)
if _use_aiter:
from aiter import rmsnorm2d_fwd as rms_norm
from aiter import rmsnorm2d_fwd_with_add as fused_add_rms_norm
...
...@@ -39,10 +39,11 @@ if TYPE_CHECKING:
CombineInput,
)
from sglang.srt.utils import is_cuda, is_hip
from sglang.srt.utils import is_cuda, is_hip, is_xpu
_is_cuda = is_cuda()
_is_hip = is_hip()
_is_xpu = is_xpu()
if _is_cuda:
from sgl_kernel import (
awq_dequantize,
...@@ -58,8 +59,12 @@ elif _is_hip:
)
warnings.warn(f"HIP does not support fused_marlin_moe currently.")
elif _is_xpu:
from sgl_kernel import awq_dequantize
warnings.warn(f"XPU does not support fused_marlin_moe currently.")
else:
warnings.warn(f"Only CUDA and HIP support AWQ currently.")
warnings.warn(f"Only CUDA, HIP and XPU support AWQ currently.")
logger = logging.getLogger(__name__)
...
...@@ -115,7 +115,7 @@ class RotaryEmbedding(CustomOp):
if dtype == torch.float32 or (
(not (_is_cuda or _is_npu) or self.head_size not in [64, 128, 256, 512])
and not (_is_cpu and _is_cpu_amx_available)
and not _is_xpu
and not (_is_xpu)
):
from vllm._custom_ops import rotary_embedding
...@@ -302,6 +302,7 @@ class RotaryEmbedding(CustomOp):
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# TODO: make a wrapper, and XPU will implement this kernel later.
self.cos_sin_cache = self.cos_sin_cache.to(query.device)
return self.forward_native(positions, query, key, offsets)
...
...@@ -142,6 +142,7 @@ from sglang.srt.utils import (
monkey_patch_vllm_gguf_config,
set_cuda_arch,
slow_rank_detector,
xpu_has_xmx_support,
)
from sglang.srt.utils.offloader import (
create_offloader_from_server_args,
...@@ -195,6 +196,7 @@ def add_chunked_prefix_cache_attention_backend(backend_name):
_is_hip = is_hip()
_is_npu = is_npu()
_is_cpu_amx_available = cpu_has_amx_support()
_is_xpu_xmx_available = xpu_has_xmx_support()
# Use a small KV cache pool size for tests in CI
SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
...@@ -505,6 +507,16 @@ class ModelRunner:
)
server_args.attention_backend = "torch_native"
if (
server_args.attention_backend == "intel_xpu"
and server_args.device == "xpu"
and not _is_xpu_xmx_available
):
logger.info(
"The current platform does not support Intel XMX, will fallback to triton backend."
)
server_args.attention_backend = "triton"
if server_args.prefill_attention_backend is not None and (
server_args.prefill_attention_backend
== server_args.decode_attention_backend
...
...@@ -114,6 +114,7 @@ ATTENTION_BACKEND_CHOICES = [
# Other platforms
"intel_amx",
"ascend",
"intel_xpu",
]
LORA_BACKEND_CHOICES = ["triton", "csgmv"]
...@@ -1098,6 +1099,12 @@ class ServerArgs:
self.enable_mixed_chunk = False
self.disable_radix_cache = True
if self.attention_backend == "intel_xpu":
if self.page_size not in [32, 64, 128]:
logger.warning(
f"Intel XPU attention backend only supports page_size of 32, 64 or 128, changing page_size from {self.page_size} to 128."
)
self.page_size = 128
if self.attention_backend == "fa4" or self.decode_attention_backend == "fa4": if self.attention_backend == "fa4" or self.decode_attention_backend == "fa4":
raise ValueError( raise ValueError(
"FA4 backend is only supported for prefill. Please use `--prefill-attention-backend fa4` instead." "FA4 backend is only supported for prefill. Please use `--prefill-attention-backend fa4` instead."
......
...@@ -163,6 +163,20 @@ def _check(cc_major):
) >= (12, 3)
@contextmanager
def device_context(device: torch.device):
if device.type == "cpu" and is_cpu():
with torch.device("cpu"):
yield
else:
module = torch.get_device_module(device)
if module is not None:
with module.device(device.index):
yield
else:
raise ValueError(f"Unknown device module: {device}")
is_ampere_with_cuda_12_3 = lambda: _check(8)
is_hopper_with_cuda_12_3 = lambda: _check(9)
...@@ -263,6 +277,14 @@ def use_intel_amx_backend(layer):
return getattr(layer, "use_intel_amx_backend", False)
def xpu_has_xmx_support():
# TODO: update with an XPU capability query
if is_xpu():
# currently only PVC/LNL/BMG support FP64, so we only support these for now
return torch.xpu.get_device_properties().has_fp64
return False
def is_flashinfer_available():
"""
Check whether flashinfer is available.
...
...@@ -8,6 +8,7 @@ import unittest
from functools import wraps
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE,
DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN,
CustomTestCase,
is_in_ci,
...@@ -55,6 +56,10 @@ class TestIntelXPUBackend(CustomTestCase):
def test_latency_qwen_model(self):
return DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN
@intel_xpu_benchmark(["--attention-backend", "intel_xpu", "--page-size", "128"])
def test_attention_backend(self):
return DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE
if __name__ == "__main__":
unittest.main()