Unverified commit b7cd7430, authored by PGFLMG and committed by GitHub

[Feat] QWen-1M context support[2/2]: Update block sparse attention backend (#5949)

parent a69b6370
"""
Usage:
python3 offline_batch_inference.py
"""
from urllib.request import urlopen
import sglang as sgl
def load_prompt() -> str:
# Test cases with various lengths can be found at:
#
# https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/64k.txt
# https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/200k.txt
# https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/600k.txt
# https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/1m.txt
with urlopen(
"https://qianwen-res.oss-cn-beijing.aliyuncs.com"
"/Qwen2.5-1M/test-data/64k.txt",
timeout=5,
) as response:
prompt = response.read().decode("utf-8")
return prompt
# Processing the prompt.
def process_requests(llm: sgl.Engine, prompts: list[str]) -> None:
# Create a sampling params object.
sampling_params = {
"temperature": 0.7,
"top_p": 0.8,
"top_k": 20,
"repetition_penalty": 1.05,
"max_new_tokens": 256,
}
# Generate texts from the prompts.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt_token_ids = output["meta_info"]["prompt_tokens"]
generated_text = output["text"]
print(
f"Prompt length: {prompt_token_ids}, " f"Generated text: {generated_text!r}"
)
# Create an LLM.
def initialize_engine() -> sgl.Engine:
llm = sgl.Engine(
model_path="Qwen/Qwen2.5-7B-Instruct-1M",
context_length=1048576,
page_size=256,
attention_backend="dual_chunk_flash_attn",
tp_size=4,
disable_radix_cache=True,
enable_mixed_chunk=False,
enable_torch_compile=False,
chunked_prefill_size=131072,
mem_fraction_static=0.6,
log_level="DEBUG",
)
return llm
def main():
llm = initialize_engine()
prompt = load_prompt()
process_requests(llm, [prompt])
if __name__ == "__main__":
main()
@@ -27,6 +27,7 @@ from sglang.srt.hf_transformers_utils import (
     get_context_length,
     get_generation_config,
     get_hf_text_config,
+    get_sparse_attention_config,
 )
 from sglang.srt.layers.quantization import QUANTIZATION_METHODS
 from sglang.srt.server_args import ServerArgs

@@ -270,6 +271,9 @@ class ModelConfig:
         # Verify quantization
         self._verify_quantization()

+        # Verify dual-chunk attention config
+        self._verify_dual_chunk_attention_config()
+
         # Cache attributes
         self.hf_eos_token_id = self.get_hf_eos_token_id()

@@ -297,6 +301,13 @@ class ModelConfig:
             **kwargs,
         )

+    def get_total_num_attention_heads(self) -> int:
+        return self.num_attention_heads
+
+    def get_num_attention_heads(self, tensor_parallel_size) -> int:
+        total_num_attention_heads = self.num_attention_heads
+        return max(1, total_num_attention_heads // tensor_parallel_size)
+
     # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
     def get_total_num_kv_heads(self) -> int:
         """Returns the total number of KV heads."""

@@ -484,6 +495,23 @@ class ModelConfig:
             self.quantization,
         )

+    def _verify_dual_chunk_attention_config(self) -> None:
+        if hasattr(self.hf_config, "dual_chunk_attention_config"):
+            # Try loading the sparse attention config
+            sparse_attn_config = get_sparse_attention_config(self.model_path)
+            if not sparse_attn_config:
+                return
+            self.hf_config.dual_chunk_attention_config["sparse_attention_config"] = (
+                sparse_attn_config
+            )
+            if (
+                "sparse_attention_enabled"
+                not in self.hf_config.dual_chunk_attention_config
+            ):
+                self.hf_config.dual_chunk_attention_config[
+                    "sparse_attention_enabled"
+                ] = True
+
     def get_hf_eos_token_id(self) -> Optional[Set[int]]:
         eos_ids = getattr(self.hf_config, "eos_token_id", None)
         if eos_ids is not None:
...
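Editorial note, not part of the diff: a hedged sketch of what `hf_config.dual_chunk_attention_config` might contain after `_verify_dual_chunk_attention_config` runs. The field names follow the code above; the numeric values are illustrative assumptions in the spirit of the Qwen2.5-1M configs, and the real ones come from the checkpoint's config files.

# Hypothetical contents after _verify_dual_chunk_attention_config has run.
# chunk_size / local_size are assumed values; sparse_attention_config is whatever
# was loaded from sparse_attention_config.json (empty placeholder here).
dual_chunk_attention_config = {
    "chunk_size": 262144,              # assumed chunk length
    "local_size": 8192,                # assumed local window size
    "sparse_attention_config": {},     # placeholder for the loaded JSON
    "sparse_attention_enabled": True,  # injected if the key was missing
}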
@@ -76,6 +76,9 @@ class ScheduleBatchDisaggregationDecodeMixin:
             req_pool_indices, dtype=torch.int64, device=self.device
         )
         self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
+        self.orig_seq_lens = torch.tensor(
+            seq_lens, dtype=torch.int32, device=self.device
+        )
         self.out_cache_loc = out_cache_loc
         self.seq_lens_sum = sum(seq_lens)
...
@@ -14,10 +14,11 @@
 """Utilities for Huggingface Transformers."""

 import contextlib
+import json
 import os
 import warnings
 from pathlib import Path
-from typing import Dict, Optional, Type, Union
+from typing import Any, Dict, Optional, Type, Union

 import torch
 from huggingface_hub import snapshot_download

@@ -62,11 +63,17 @@ for name, cls in _CONFIG_REGISTRY.items():
     AutoConfig.register(name, cls)


-def download_from_hf(model_path: str):
+def download_from_hf(
+    model_path: str,
+    allow_patterns: Optional[Union[str, list]] = None,
+):
     if os.path.exists(model_path):
         return model_path

-    return snapshot_download(model_path, allow_patterns=["*.json", "*.bin", "*.model"])
+    if not allow_patterns:
+        allow_patterns = ["*.json", "*.bin", "*.model"]
+
+    return snapshot_download(model_path, allow_patterns=allow_patterns)


 def get_hf_text_config(config: PretrainedConfig):

@@ -171,6 +178,26 @@ def get_generation_config(
     return None


+# Qwen-1M related
+def get_sparse_attention_config(
+    model: str,
+    sparse_attention_config_filename: str = "sparse_attention_config.json",
+) -> Dict[str, Any]:
+    is_local = os.path.isdir(model)
+    if not is_local:
+        # Download the config files.
+        model = download_from_hf(model, allow_patterns=["*.json"])
+
+    config_file = os.path.join(model, sparse_attention_config_filename)
+    if not os.path.exists(config_file):
+        return {}
+
+    # Load the sparse attention config.
+    with open(config_file) as f:
+        config = json.load(f)
+    return config
+
+
 # Models don't use the same configuration key for determining the maximum
 # context length. Store them here so we can sanely check them.
 # NOTE: The ordering here is important. Some models have two of these and we
...
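Editorial note, not part of the diff: a minimal usage sketch of the new helper, assuming the module path shown in the imports above and that the target checkpoint may or may not ship a `sparse_attention_config.json` (the model path is only an example).

# Sketch: fetch the optional sparse attention config for a Qwen-1M checkpoint.
# get_sparse_attention_config returns {} when the file is absent.
from sglang.srt.hf_transformers_utils import get_sparse_attention_config

sparse_cfg = get_sparse_attention_config("Qwen/Qwen2.5-7B-Instruct-1M")
if sparse_cfg:
    print("Loaded sparse attention config with keys:", sorted(sparse_cfg))
else:
    print("No sparse_attention_config.json found for this model.")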
@@ -1172,6 +1172,202 @@ class MRotaryEmbedding(RotaryEmbedding):
         )


+class DualChunkRotaryEmbedding(CustomOp):
+    """Rotary positional embedding for Dual Chunk Attention."""
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        dtype: torch.dtype,
+        chunk_size: int,
+        local_size: int,
+    ) -> None:
+        super().__init__()
+        self.head_size = head_size
+        self.rotary_dim = rotary_dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        self.is_neox_style = is_neox_style
+        self.chunk_size = chunk_size
+        self.local_size = local_size
+        self.dtype = dtype
+        self.device = torch.device(f"cuda:{torch.cuda.current_device()}")
+        (q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache) = (
+            self._compute_cos_sin_cache()
+        )
+        self.register_buffer("cos_sin_q_cache", q_cache, persistent=False)
+        self.register_buffer("cos_sin_qc_cache", qc_cache, persistent=False)
+        self.register_buffer("cos_sin_k_cache", k_cache, persistent=False)
+        self.register_buffer(
+            "cos_sin_qc_no_clamp_cache", qc_no_clamp_cache, persistent=False
+        )
+        self.register_buffer("cos_sin_q_inter_cache", q_inter_cache, persistent=False)
+
+    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
+        """Compute the inverse frequency."""
+        # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
+        # However, we use `torch.arange(..., dtype=torch.float)` instead to
+        # avoid numerical issues with large base values (e.g., 10000000).
+        # This may cause a slight numerical difference between the HF
+        # implementation and ours.
+        # NOTE(woosuk): To exactly match the HF implementation, we need to
+        # use CPU to compute the cache and then move it to GPU. However, we
+        # create the cache on GPU for faster initialization. This may cause
+        # a slight numerical difference between the HF implementation and ours.
+        inv_freq = 1.0 / (
+            base
+            ** (
+                torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
+            )
+        )
+        return inv_freq
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        """Compute the cos and sin cache."""
+        inv_freq = self._compute_inv_freq(self.base)
+        chunk_len = self.chunk_size - self.local_size
+        q_t = torch.arange(chunk_len, dtype=torch.float)
+        qc_t = (torch.arange(chunk_len, dtype=torch.float) + chunk_len).clamp(
+            max=self.chunk_size
+        )
+        k_t = torch.arange(self.max_position_embeddings, dtype=torch.float) % chunk_len
+        # count from chunk_len, no clamp(self.chunk_size) restriction
+        qc_no_clamp_t = torch.arange(chunk_len, dtype=torch.float) + chunk_len
+        # count from self.chunk_size for q_inter's rope
+        q_inter_t = torch.arange(chunk_len, dtype=torch.float) + self.chunk_size
+
+        q_freqs = torch.outer(q_t, inv_freq)
+        qc_freqs = torch.outer(qc_t, inv_freq)
+        k_freqs = torch.outer(k_t, inv_freq)
+        qc_no_clamp_freqs = torch.outer(qc_no_clamp_t, inv_freq)
+        q_inter_freqs = torch.outer(q_inter_t, inv_freq)
+
+        q_cos = q_freqs.cos()
+        q_sin = q_freqs.sin()
+        qc_cos = qc_freqs.cos()
+        qc_sin = qc_freqs.sin()
+        k_cos = k_freqs.cos()
+        k_sin = k_freqs.sin()
+        qc_no_clamp_cos = qc_no_clamp_freqs.cos()
+        qc_no_clamp_sin = qc_no_clamp_freqs.sin()
+        q_inter_cos = q_inter_freqs.cos()
+        q_inter_sin = q_inter_freqs.sin()
+
+        q_cache = torch.cat((q_cos, q_sin), dim=-1).to(
+            dtype=self.dtype, device=self.device
+        )
+        qc_cache = torch.cat((qc_cos, qc_sin), dim=-1).to(
+            dtype=self.dtype, device=self.device
+        )
+        k_cache = torch.cat((k_cos, k_sin), dim=-1).to(
+            dtype=self.dtype, device=self.device
+        )
+        qc_no_clamp_cache = torch.cat((qc_no_clamp_cos, qc_no_clamp_sin), dim=-1).to(
+            dtype=self.dtype, device=self.device
+        )
+        q_inter_cache = torch.cat((q_inter_cos, q_inter_sin), dim=-1).to(
+            dtype=self.dtype, device=self.device
+        )
+        return q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        query = query.view(*query.shape[:-1], -1, self.head_size)
+        key = key.view(*key.shape[:-1], -1, self.head_size)
+        query_rot = query[..., : self.rotary_dim]
+        key_rot = key[..., : self.rotary_dim]
+        if self.rotary_dim < self.head_size:
+            query_pass = query[..., self.rotary_dim :]
+            key_pass = key[..., self.rotary_dim :]
+        else:
+            query_pass = None
+            key_pass = None
+        positions_with_offsets = (
+            torch.add(positions, offsets) if offsets is not None else positions
+        )
+        key = self._apply_rotary_embedding(
+            self.cos_sin_k_cache[positions_with_offsets], key_rot, key_pass
+        )
+        chunk_len = self.chunk_size - self.local_size
+        query = self._apply_rotary_embedding(
+            self.cos_sin_q_cache[positions_with_offsets % chunk_len],
+            query_rot,
+            query_pass,
+        )
+        query_succ = self._apply_rotary_embedding(
+            self.cos_sin_qc_cache[positions_with_offsets % chunk_len],
+            query_rot,
+            query_pass,
+        )
+        query_inter = self._apply_rotary_embedding(
+            self.cos_sin_qc_cache[chunk_len - 1].repeat(positions.shape[0], 1),
+            query_rot,
+            query_pass,
+        )
+        query_succ_critical = self._apply_rotary_embedding(
+            self.cos_sin_qc_no_clamp_cache[positions_with_offsets % chunk_len],
+            query_rot,
+            query_pass,
+        )
+        query_inter_critical = self._apply_rotary_embedding(
+            self.cos_sin_q_inter_cache[positions_with_offsets % chunk_len],
+            query_rot,
+            query_pass,
+        )
+        # merge query into one tensor to simplify the interfaces
+        query = torch.cat(
+            (
+                query,
+                query_succ,
+                query_inter,
+                query_succ_critical,
+                query_inter_critical,
+            ),
+            dim=-1,
+        )
+        return query, key
+
+    def _apply_rotary_embedding(self, cos_sin, hidden_rot, hidden_pass):
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        if self.is_neox_style:
+            # NOTE(woosuk): Here we assume that the positions tensor has the
+            # shape [batch_size, seq_len].
+            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
+            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
+        else:
+            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
+            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
+        rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
+        hidden_rot = hidden_rot * cos + rotate_fn(hidden_rot) * sin
+        if self.rotary_dim < self.head_size:
+            hidden = torch.cat((hidden_rot, hidden_pass), dim=-1)
+        else:
+            hidden = hidden_rot
+        return hidden.flatten(-2).squeeze(0)
+
+    def extra_repr(self) -> str:
+        s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
+        s += f", max_position_embeddings={self.max_position_embeddings}"
+        s += f", base={self.base}, is_neox_style={self.is_neox_style}"
+        s += f", chunk_size={self.chunk_size}, local_size={self.local_size}"
+        return s
+
+
 _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}

@@ -1184,6 +1380,7 @@ def get_rope(
     rope_scaling: Optional[Dict[str, Any]] = None,
     dtype: Optional[torch.dtype] = None,
     partial_rotary_factor: float = 1.0,
+    dual_chunk_attention_config: Optional[Dict[str, Any]] = None,
 ) -> RotaryEmbedding:
     if dtype is None:
         dtype = torch.get_default_dtype()

@@ -1195,6 +1392,17 @@ def get_rope(
         rope_scaling_args = tuple(rope_scaling_tuple.items())
     else:
         rope_scaling_args = None
+
+    if dual_chunk_attention_config is not None:
+        dual_chunk_attention_tuple = {
+            k: tuple(v) if isinstance(v, list) else v
+            for k, v in dual_chunk_attention_config.items()
+            if k != "sparse_attention_config"
+        }
+        dual_chunk_attention_args = tuple(dual_chunk_attention_tuple.items())
+    else:
+        dual_chunk_attention_args = None
+
     if partial_rotary_factor < 1.0:
         rotary_dim = int(rotary_dim * partial_rotary_factor)
     key = (

@@ -1204,12 +1412,28 @@ def get_rope(
         base,
         is_neox_style,
         rope_scaling_args,
+        dual_chunk_attention_args,
         dtype,
     )
     if key in _ROPE_DICT:
         return _ROPE_DICT[key]
-    if rope_scaling is None:
+
+    if dual_chunk_attention_config is not None:
+        extra_kwargs = {
+            k: v
+            for k, v in dual_chunk_attention_config.items()
+            if k in ("chunk_size", "local_size")
+        }
+        rotary_emb = DualChunkRotaryEmbedding(
+            head_size,
+            rotary_dim,
+            max_position,
+            base,
+            is_neox_style,
+            dtype,
+            **extra_kwargs,
+        )
+    elif rope_scaling is None:
         rotary_emb = RotaryEmbedding(
             head_size, rotary_dim, max_position, base, is_neox_style, dtype
         )
...
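Editorial note, not part of the diff: a hedged sketch of how `get_rope` now dispatches to `DualChunkRotaryEmbedding` when a dual-chunk config is supplied. The keyword names follow the signature shown in the hunks above, the numeric values are illustrative assumptions, and building the caches requires a CUDA device.

# Sketch: requesting a dual-chunk rotary embedding through get_rope.
# All numeric values are illustrative; real ones come from the model config.
import torch

from sglang.srt.layers.rotary_embedding import get_rope

rope = get_rope(
    head_size=128,
    rotary_dim=128,
    max_position=1048576,
    base=10000000,
    is_neox_style=True,
    dtype=torch.bfloat16,
    dual_chunk_attention_config={"chunk_size": 262144, "local_size": 8192},
)
# rope is a DualChunkRotaryEmbedding; forward() returns the rotated key plus a
# query tensor that concatenates five rotated variants (q, q_succ, q_inter,
# q_succ_critical, q_inter_critical) along the last dimension.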
@@ -846,6 +846,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # The sum of all sequence lengths
     seq_lens_sum: int = None

+    # The original sequence lengths, Qwen-1M related
+    orig_seq_lens: torch.Tensor = None  # shape: [b], int32
+
     # For DP attention
     global_num_tokens: Optional[List[int]] = None

@@ -1131,6 +1133,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
         extend_num_tokens = sum(len(ids) for ids in input_ids)
         seq_lens = [len(r.fill_ids) for r in reqs]
+        orig_seq_lens = [max(len(r.fill_ids), len(r.origin_input_ids)) for r in reqs]
         prefix_lens = [len(r.prefix_indices) for r in reqs]
         extend_lens = [r.extend_input_len for r in reqs]

@@ -1147,6 +1150,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
+        orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
         prefix_lens_tensor = torch.tensor(
             prefix_lens, dtype=torch.int64, device=self.device
         )

@@ -1260,6 +1266,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.input_ids = input_ids_tensor
         self.req_pool_indices = req_pool_indices_tensor
         self.seq_lens = seq_lens_tensor
+        self.orig_seq_lens = orig_seq_lens_tensor
         self.out_cache_loc = out_cache_loc
         self.input_embeds = (
             torch.tensor(input_embeds).to(self.device, non_blocking=True)

@@ -1507,6 +1514,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.forward_mode = ForwardMode.IDLE
         self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device)
         self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
+        self.orig_seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
         self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
         self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
         self.seq_lens_sum = 0

@@ -1561,9 +1569,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         if self.enable_overlap:
             # Do not use in-place operations in the overlap mode
             self.seq_lens = self.seq_lens + 1
+            self.orig_seq_lens = self.orig_seq_lens + 1
         else:
             # A faster in-place version
             self.seq_lens.add_(1)
+            self.orig_seq_lens.add_(1)
         self.seq_lens_sum += bs

         # free memory

@@ -1627,6 +1637,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
         self.req_pool_indices = self.req_pool_indices[keep_indices_device]
         self.seq_lens = self.seq_lens[keep_indices_device]
+        self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
         self.out_cache_loc = None
         self.seq_lens_sum = self.seq_lens.sum().item()
         self.output_ids = self.output_ids[keep_indices_device]

@@ -1659,6 +1670,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             [self.req_pool_indices, other.req_pool_indices]
         )
         self.seq_lens = torch.cat([self.seq_lens, other.seq_lens])
+        self.orig_seq_lens = torch.cat([self.orig_seq_lens, other.orig_seq_lens])
         self.out_cache_loc = None
         self.seq_lens_sum += other.seq_lens_sum
         if self.output_ids is not None:

@@ -1733,6 +1745,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             input_ids=self.input_ids,
             req_pool_indices=self.req_pool_indices,
             seq_lens=self.seq_lens,
+            orig_seq_lens=self.orig_seq_lens,
             out_cache_loc=self.out_cache_loc,
             seq_lens_cpu=seq_lens_cpu,
             seq_lens_sum=self.seq_lens_sum,

@@ -1900,6 +1913,9 @@ class ModelWorkerBatch:
     # Sampling info
     sampling_info: SamplingBatchInfo

+    # The original sequence lengths, Qwen-1M related
+    orig_seq_lens: Optional[torch.Tensor] = None
+
     # The input Embeds
     input_embeds: Optional[torch.Tensor] = None
...
@@ -589,6 +589,7 @@ class CudaGraphRunner:
             req_pool_indices=req_pool_indices,
             seq_lens=seq_lens,
             next_token_logits_buffer=next_token_logits_buffer,
+            orig_seq_lens=seq_lens,
             req_to_token_pool=self.model_runner.req_to_token_pool,
             token_to_kv_pool=self.model_runner.token_to_kv_pool,
             attn_backend=self.model_runner.attn_backend,
...
@@ -180,6 +180,9 @@ class ForwardBatch:
     # The sum of all sequence lengths
     seq_lens_sum: int

+    # The original sequence length without being chunked. Qwen-1M related.
+    orig_seq_lens: Optional[torch.Tensor] = None
+
     # Optional seq_lens on cpu
     seq_lens_cpu: Optional[torch.Tensor] = None

@@ -321,6 +324,7 @@ class ForwardBatch:
             encoder_out_cache_loc=batch.encoder_out_cache_loc,
             seq_lens_sum=batch.seq_lens_sum,
             seq_lens_cpu=batch.seq_lens_cpu,
+            orig_seq_lens=batch.orig_seq_lens,
             return_logprob=batch.return_logprob,
             top_logprobs_nums=batch.top_logprobs_nums,
             token_ids_logprobs=batch.token_ids_logprobs,
...
@@ -1467,6 +1467,12 @@ class ModelRunner:
             logger.info(f"Intel AMX attention backend is enabled.")
             return IntelAMXAttnBackend(self)
+        elif self.server_args.attention_backend == "dual_chunk_flash_attn":
+            from sglang.srt.layers.attention.dual_chunk_flashattention_backend import (
+                DualChunkFlashAttentionBackend,
+            )
+
+            return DualChunkFlashAttentionBackend(self)
         else:
             raise ValueError(f"Invalid attention backend: {backend_str}")
...
@@ -107,6 +107,7 @@ class Qwen2Attention(nn.Module):
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 32768,
         quant_config: Optional[QuantizationConfig] = None,
+        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
         prefix: str = "",
     ) -> None:
         super().__init__()

@@ -158,6 +159,7 @@ class Qwen2Attention(nn.Module):
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
+            dual_chunk_attention_config=dual_chunk_attention_config,
         )
         self.attn = RadixAttention(
             self.num_heads,

@@ -198,6 +200,9 @@ class Qwen2DecoderLayer(nn.Module):
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 32768)
         head_dim = getattr(config, "head_dim", None)
+        dual_chunk_attention_config = getattr(
+            config, "dual_chunk_attention_config", None
+        )
         self.self_attn = Qwen2Attention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,

@@ -208,6 +213,7 @@ class Qwen2DecoderLayer(nn.Module):
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
+            dual_chunk_attention_config=dual_chunk_attention_config,
             prefix=add_prefix("self_attn", prefix),
         )
         self.mlp = Qwen2MLP(
...
@@ -210,6 +210,7 @@ class Qwen2MoeAttention(nn.Module):
         max_position_embeddings: int = 8192,
         qkv_bias: int = True,
         quant_config: Optional[QuantizationConfig] = None,
+        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
         prefix: str = "",
     ) -> None:
         super().__init__()

@@ -267,6 +268,7 @@ class Qwen2MoeAttention(nn.Module):
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
+            dual_chunk_attention_config=dual_chunk_attention_config,
         )
         self.attn = RadixAttention(
             self.num_heads,

@@ -308,6 +310,9 @@ class Qwen2MoeDecoderLayer(nn.Module):
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         qkv_bias = getattr(config, "qkv_bias", True)
+        dual_chunk_attention_config = getattr(
+            config, "dual_chunk_attention_config", None
+        )
         self.self_attn = Qwen2MoeAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,

@@ -317,6 +322,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
+            dual_chunk_attention_config=dual_chunk_attention_config,
             qkv_bias=qkv_bias,
             prefix=add_prefix("self_attn", prefix),
         )
...
@@ -295,6 +295,7 @@ class Qwen3MoeAttention(nn.Module):
         attention_bias: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
         alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()

@@ -353,6 +354,7 @@ class Qwen3MoeAttention(nn.Module):
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
+            dual_chunk_attention_config=dual_chunk_attention_config,
         )
         self.attn = RadixAttention(
             self.num_heads,

@@ -458,6 +460,9 @@ class Qwen3MoeDecoderLayer(nn.Module):
         )
         rms_norm_eps = config.rms_norm_eps
         attention_bias = config.attention_bias
+        dual_chunk_attention_config = getattr(
+            config, "dual_chunk_attention_config", None
+        )
         self.self_attn = Qwen3MoeAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,

@@ -471,6 +476,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
             attention_bias=attention_bias,
             quant_config=quant_config,
             prefix=add_prefix("self_attn", prefix),
+            dual_chunk_attention_config=dual_chunk_attention_config,
             alt_stream=alt_stream,
         )
...
@@ -502,6 +502,20 @@ class ServerArgs:
             # use bf16 for mxfp4 triton kernels
             self.dtype = "bfloat16"

+        if self.attention_backend == "dual_chunk_flash_attn":
+            logger.warning(
+                "Mixed chunk is disabled because of using dual chunk flash attention backend"
+            )
+            logger.warning(
+                "Radix cache is disabled because of using dual chunk flash attention backend"
+            )
+            logger.warning(
+                "Cuda graph is disabled because of using dual chunk flash attention backend"
+            )
+            self.enable_mixed_chunk = False
+            self.disable_cuda_graph = True
+            self.disable_radix_cache = True
+
         # Set page size
         if self.page_size is None:
             self.page_size = 1

@@ -1337,6 +1351,7 @@ class ServerArgs:
                "triton",
                "trtllm_mla",
                "trtllm_mha",
+               "dual_chunk_flash_attn",
            ],
            default=ServerArgs.attention_backend,
            help="Choose the kernels for attention layers.",
...
@@ -661,6 +661,7 @@ class TboForwardBatchPreparer:
             "padded_static_len",
             "mrope_positions",  # only used by qwen2-vl, thus not care
             "split_index",  # for split prefill
+            "orig_seq_lens",  # only used by qwen-1m, thus not care
         ]:
             output_dict[key] = getattr(batch, key)
         if not batch.forward_mode.is_target_verify():
...