Unverified commit 3bface15 authored by woodx, committed by GitHub

Feat/support encoder model (like bert) (#4887)

parent 6fb29ffd
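
The diff below teaches the torch-native and Triton attention backends to run non-causal (bidirectional) attention, adds an `AttentionType` enum to `RadixAttention`, registers a BERT-style encoder model for embeddings, and adds an encoder embedding test. As a rough usage sketch only: `is_embedding` and `attention_backend` are existing server arguments, but the exact return shape of `encode` shown here is an assumption, not something taken from this commit.

# Hypothetical offline usage once this commit lands; not part of the diff.
import sglang as sgl

if __name__ == "__main__":
    engine = sgl.Engine(
        model_path="BAAI/bge-small-en",  # the model exercised by the new test
        is_embedding=True,
        attention_backend="triton",      # or "torch_native"
    )
    out = engine.encode(["The capital of France is Paris."])
    # Field names below are assumptions; the new unit test goes through
    # SRTRunner instead and reads `embed_logits` from its output.
    print(len(out[0]["embedding"]))
    engine.shutdown()
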
......@@ -6,6 +6,7 @@ import torch
from torch.nn.functional import scaled_dot_product_attention
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.radix_attention import AttentionType
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
if TYPE_CHECKING:
......@@ -202,6 +203,10 @@ class TorchNativeAttnBackend(AttentionBackend):
q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)
causal = True
if layer.is_cross_attention or layer.attn_type == AttentionType.ENCODER_ONLY:
causal = False
self._run_sdpa_forward_extend(
q_,
o_,
......@@ -214,7 +219,7 @@ class TorchNativeAttnBackend(AttentionBackend):
forward_batch.extend_seq_lens,
scaling=layer.scaling,
enable_gqa=use_gqa,
causal=not layer.is_cross_attention,
causal=causal,
)
return o
......
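
For reference, a minimal standalone sketch (not from the diff) of what the new `causal` flag changes in the torch-native path: `scaled_dot_product_attention` applies a lower-triangular mask when `is_causal=True`, while an encoder-only layer needs every token to attend to the whole sequence.

import torch
from torch.nn.functional import scaled_dot_product_attention

# (batch, heads, seq, head_dim)
q = torch.randn(1, 4, 8, 16)
k = torch.randn(1, 4, 8, 16)
v = torch.randn(1, 4, 8, 16)

decoder_out = scaled_dot_product_attention(q, k, v, is_causal=True)   # causal mask
encoder_out = scaled_dot_product_attention(q, k, v, is_causal=False)  # bidirectional
assert decoder_out.shape == encoder_out.shape == q.shape
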
......@@ -10,6 +10,7 @@ import triton.language as tl
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.layers.radix_attention import AttentionType
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import get_bool_env_var, get_device_core_count
......@@ -528,6 +529,10 @@ class TritonAttnBackend(AttentionBackend):
layer, forward_batch.out_cache_loc, k, v
)
causal = True
if layer.attn_type == AttentionType.ENCODER_ONLY:
causal = False
self.extend_attention_fwd(
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
k.contiguous(),
......@@ -539,6 +544,7 @@ class TritonAttnBackend(AttentionBackend):
self.forward_metadata.kv_indptr,
self.forward_metadata.kv_indices,
self.forward_metadata.custom_mask,
causal,
self.forward_metadata.mask_indptr,
self.forward_metadata.max_extend_len,
layer.scaling,
......
......@@ -74,6 +74,7 @@ def _fwd_kernel(
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
USE_CUSTOM_MASK: tl.constexpr,
IS_CAUSAL: tl.constexpr,
SKIP_PREFIX_CUSTOM_MASK: tl.constexpr,
STORE_TRANSPOSE: tl.constexpr,
):
......@@ -129,6 +130,7 @@ def _fwd_kernel(
for start_n in range(0, cur_seq_len_prefix, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
mask_n = (start_n + offs_n) < cur_seq_len_prefix
offs_kv_loc = tl.load(
kv_indices + cur_seq_kv_start_idx + start_n + offs_n, mask=mask_n, other=0
)
......@@ -196,7 +198,11 @@ def _fwd_kernel(
# stage 2: compute the triangle part
cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
cur_block_m_end = (
cur_seq_len_extend
if not IS_CAUSAL
else tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
)
for start_n in range(0, cur_block_m_end, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
mask_n = (start_n + offs_n) < cur_block_m_end
......@@ -243,12 +249,15 @@ def _fwd_kernel(
)
custom_mask &= mask_m[:, None] & mask_n[None, :]
qk = tl.where(custom_mask, qk, float("-inf"))
else:
elif IS_CAUSAL:
mask_causal = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
start_n + offs_n[None, :]
)
mask_causal &= mask_m[:, None] & mask_n[None, :]
qk = tl.where(mask_causal, qk, float("-inf"))
else:
mask_non_causal = mask_m[:, None] & mask_n[None, :]
qk = tl.where(mask_non_causal, qk, float("-inf"))
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
re_scale = tl.exp(e_max - n_e_max)
......@@ -299,6 +308,7 @@ def extend_attention_fwd(
kv_indptr,
kv_indices,
custom_mask,
is_causal,
mask_indptr,
max_len_extend,
sm_scale=None,
......@@ -411,6 +421,7 @@ def extend_attention_fwd(
Lq=Lq,
Lv=Lv,
USE_CUSTOM_MASK=USE_CUSTOM_MASK,
IS_CAUSAL=is_causal,
SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK,
STORE_TRANSPOSE=_is_hip,
num_warps=num_warps,
......
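
The kernel change above makes the "triangle" stage cover the full extend length when IS_CAUSAL is false and replaces the causal triangle with a pure boundary mask. Below is a small PyTorch reference sketch of the two mask shapes; the block-local index math is mirrored from the kernel for illustration only and is not the Triton code itself.

import torch

def stage2_mask(seq_len_extend: int, block_m: int, block_n: int,
                cur_block_m: int, start_n: int, is_causal: bool) -> torch.Tensor:
    # Row/column indices of one (BLOCK_M, BLOCK_N) tile in the extend part.
    offs_m = cur_block_m * block_m + torch.arange(block_m)
    offs_n = start_n + torch.arange(block_n)
    # Boundary mask: keep only in-range rows and columns (applied in both modes).
    boundary = (offs_m < seq_len_extend)[:, None] & (offs_n < seq_len_extend)[None, :]
    if is_causal:
        # Causal triangle: a query may only attend to itself and earlier tokens.
        return (offs_m[:, None] >= offs_n[None, :]) & boundary
    # Encoder-only: no triangle, just the boundary mask.
    return boundary

# With is_causal=False the first row block already sees later tokens:
print(stage2_mask(6, 4, 4, cur_block_m=0, start_n=4, is_causal=False))
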
......@@ -13,6 +13,7 @@
# ==============================================================================
"""Radix attention."""
from enum import Enum
from typing import Optional
from torch import nn
......@@ -22,6 +23,18 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
class AttentionType(Enum):
"""
Attention type.
String values are used to stay compatible with `torch.compile`.
"""
# Decoder self-attention: causal, each token attends only to earlier tokens.
DECODER = "decoder"
# Encoder-only self-attention: bidirectional, no causal mask is applied.
ENCODER_ONLY = "encoder_only"
class RadixAttention(nn.Module):
"""
The attention layer implementation.
......@@ -39,6 +52,7 @@ class RadixAttention(nn.Module):
sliding_window_size: int = -1,
is_cross_attention: bool = False,
quant_config: Optional[QuantizationConfig] = None,
attn_type: AttentionType = AttentionType.DECODER,
prefix: str = "",
use_irope: bool = False,
):
......@@ -64,6 +78,7 @@ class RadixAttention(nn.Module):
self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
if self.quant_method is not None:
self.quant_method.create_weights(self)
self.attn_type = attn_type
def forward(
self,
......
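
Declaring an encoder-only layer only requires the new constructor argument; the attention backends read `layer.attn_type` to decide whether to apply a causal mask. A minimal sketch mirroring the call made by the `BertSelfAttention` module added below (head counts are placeholder values):

from sglang.srt.layers.radix_attention import AttentionType, RadixAttention

attn = RadixAttention(
    num_heads=12,
    head_dim=64,
    scaling=64 ** -0.5,
    num_kv_heads=12,
    layer_id=0,
    attn_type=AttentionType.ENCODER_ONLY,  # default remains AttentionType.DECODER
)
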
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, Iterable, Optional, Set, Tuple
import torch
from torch import nn
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import get_act_fn
from sglang.srt.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import AttentionType, RadixAttention
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
# Placeholder: the actual transformers.BertConfig object is supplied by the
# model loader at runtime; the name is referenced here only in type annotations.
BertConfig = None
class BertEmbedding(nn.Module):
def __init__(self, config: BertConfig):
super().__init__()
self.size = config.hidden_size
self.word_embeddings = VocabParallelEmbedding(
config.vocab_size, config.hidden_size
)
self.position_embeddings = VocabParallelEmbedding(
config.max_position_embeddings, config.hidden_size
)
self.token_type_embeddings = VocabParallelEmbedding(
config.type_vocab_size, config.hidden_size
)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
# Not used in forward; kept as a parameter so checkpoints that still ship an
# "embeddings.position_ids" entry map onto an existing key during load_weights.
self.position_ids = nn.Parameter(
torch.empty((1, config.max_position_embeddings)),
)
self.position_embedding_type = config.position_embedding_type
if self.position_embedding_type != "absolute":
raise ValueError("Only 'absolute' position_embedding_type is supported")
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
) -> torch.Tensor:
input_shape = input_ids.size()
# Input embeddings.
inputs_embeds = self.word_embeddings(input_ids)
# Position embeddings.
position_embeddings = self.position_embeddings(position_ids)
# Token type embeddings: this port assumes a single segment, so all
# token_type_ids are zero.
token_type_ids = torch.zeros(
input_shape, dtype=torch.long, device=inputs_embeds.device
)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = inputs_embeds + token_type_embeddings + position_embeddings
embeddings = self.LayerNorm(embeddings)
return embeddings
class BertEncoder(nn.Module):
def __init__(
self,
config: BertConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
self.quant_config = quant_config
self.layer = nn.ModuleList(
[
BertLayer(
config=config,
layer_id=layer_idx,
quant_config=quant_config,
prefix=f"{prefix}.layer.{layer_idx}",
)
for layer_idx in range(config.num_hidden_layers)
]
)
def forward(
self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
) -> torch.Tensor:
for layer in self.layer:
hidden_states = layer(hidden_states, forward_batch)
return hidden_states
class BertLayer(nn.Module):
def __init__(
self,
config: BertConfig,
layer_id: int = 0,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.attention = BertAttention(
hidden_size=config.hidden_size,
num_attention_heads=config.num_attention_heads,
layer_id=layer_id,
layer_norm_eps=config.layer_norm_eps,
quant_config=quant_config,
prefix=f"{prefix}.attention",
)
self.intermediate = BertIntermediate(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.intermediate",
)
self.output = BertOutput(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
layer_norm_eps=config.layer_norm_eps,
quant_config=quant_config,
prefix=f"{prefix}.output",
)
def forward(self, hidden_states: torch.Tensor, forward_batch: ForwardBatch):
attn_output = self.attention(hidden_states, forward_batch)
intermediate_output = self.intermediate(attn_output)
output = self.output(intermediate_output, attn_output)
return output
class BertAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
layer_norm_eps: float,
layer_id: int = 0,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.self_attn = BertSelfAttention(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
layer_id=layer_id,
quant_config=quant_config,
prefix=f"{prefix}.output",
)
self.output = BertSelfOutput(
hidden_size=hidden_size,
layer_norm_eps=layer_norm_eps,
quant_config=quant_config,
prefix=f"{prefix}.output",
)
def forward(
self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
) -> torch.Tensor:
self_output = self.self_attn(hidden_states, forward_batch)
return self.output(self_output, hidden_states)
class BertSelfAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
layer_id: int = 0,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_attention_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = self.total_num_heads
self.head_dim = self.hidden_size // self.total_num_heads
assert self.head_dim * self.total_num_heads == self.hidden_size
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.qkv_proj = QKVParallelLinear(
hidden_size=self.hidden_size,
head_size=self.head_dim,
total_num_heads=self.total_num_heads,
total_num_kv_heads=self.total_num_kv_heads,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.attn = RadixAttention(
num_heads=self.num_heads,
head_dim=self.head_dim,
scaling=self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
prefix=f"{prefix}.attn",
attn_type=AttentionType.ENCODER_ONLY,
)
def forward(
self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
output = self.attn(q, k, v, forward_batch)
return output
class BertSelfOutput(nn.Module):
def __init__(
self,
hidden_size: int,
layer_norm_eps: float,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.dense = RowParallelLinear(
input_size=hidden_size,
output_size=hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.dense",
)
self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
def forward(
self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
) -> torch.Tensor:
hidden_states, _ = self.dense(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class BertIntermediate(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.dense = ColumnParallelLinear(
input_size=hidden_size,
output_size=intermediate_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.dense",
)
self.intermediate_act_fn = get_act_fn(hidden_act)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class BertOutput(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
layer_norm_eps: float,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.dense = RowParallelLinear(
input_size=intermediate_size,
output_size=hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.dense",
)
self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
def forward(
self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
) -> torch.Tensor:
hidden_states, _ = self.dense(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class BertModel(nn.Module):
def __init__(
self,
*,
config: BertConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
self.embeddings = BertEmbedding(config)
self.encoder = BertEncoder(
config=config, quant_config=quant_config, prefix="encoder"
)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
# self.pooler = BertPooler(config)
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: Optional[torch.Tensor] = None,
get_embedding: bool = False,
) -> torch.Tensor:
assert get_embedding, "BertModel is only used for embedding requests"
hidden_states = self.embeddings(
input_ids=input_ids,
position_ids=positions,
)
hidden_states = self.encoder(hidden_states, forward_batch=forward_batch)
return self.pooler(hidden_states, forward_batch)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "query", "q"),
("qkv_proj", "key", "k"),
("qkv_proj", "value", "v"),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
name = name.replace("self", "self_attn")
if "pooler" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
class Contriever(BertModel):
pass
EntryClass = [BertModel, Contriever]
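
For readers of `load_weights` above, a self-contained sketch (not part of the diff) of how representative Hugging Face BERT checkpoint keys are renamed before the lookup in `params_dict`:

# Mirrors the renaming logic in BertModel.load_weights for illustration.
stacked_params_mapping = [
    ("qkv_proj", "query", "q"),
    ("qkv_proj", "key", "k"),
    ("qkv_proj", "value", "v"),
]

def remap(name: str):
    name = name.replace("self", "self_attn")
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            return name.replace(weight_name, param_name), shard_id
    return name, None

# "encoder.layer.0.attention.self.query.weight"
# -> ("encoder.layer.0.attention.self_attn.qkv_proj.weight", "q")
print(remap("encoder.layer.0.attention.self.query.weight"))
# Non-stacked weights pass through with only the self -> self_attn rename:
print(remap("encoder.layer.0.attention.output.dense.weight"))
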
......@@ -51,6 +51,8 @@ NUM_TOP_LOGPROBS = 5
def get_dtype_str(torch_dtype):
if torch_dtype is torch.float16:
return "float16"
elif torch_dtype is torch.float32:
return "float32"
else:
raise NotImplementedError()
......@@ -447,6 +449,7 @@ class SRTRunner:
port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
lora_paths: List[str] = None,
max_loras_per_batch: int = 4,
attention_backend: Optional[str] = None,
lora_backend: str = "triton",
disable_cuda_graph: bool = False,
disable_radix_cache: bool = False,
......@@ -487,6 +490,7 @@ class SRTRunner:
lora_paths=lora_paths,
max_loras_per_batch=max_loras_per_batch,
lora_backend=lora_backend,
attention_backend=attention_backend,
disable_cuda_graph=disable_cuda_graph,
disable_radix_cache=disable_radix_cache,
chunked_prefill_size=chunked_prefill_size,
......
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# python -m unittest test_encoder_embedding_models.TestEncoderEmbeddingModels.test_prefill_logits
import multiprocessing as mp
import random
import time
import unittest
import torch
from transformers import AutoConfig, AutoTokenizer
from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
from sglang.test.test_utils import CustomTestCase, get_similarities, is_in_ci
MODELS = [("BAAI/bge-small-en", 1, 1e-5)]
ATTENTION_BACKEND = ["torch_native", "triton"]
BATCH_SIZE = [30]
TORCH_DTYPES = [torch.float32]
sgl_to_st_ratio = []
class TestEncoderEmbeddingModels(CustomTestCase):
@classmethod
def setUpClass(cls):
mp.set_start_method("spawn", force=True)
def _truncate_prompts(self, prompts, model_path):
config = AutoConfig.from_pretrained(model_path)
max_length = getattr(config, "max_position_embeddings", 512) - 20
tokenizer = AutoTokenizer.from_pretrained(model_path)
truncated_prompts = []
for prompt in prompts:
tokens = tokenizer(prompt, return_tensors="pt", truncation=False)
if len(tokens.input_ids[0]) > max_length:
truncated_text = tokenizer.decode(
tokens.input_ids[0][: max_length - 1], skip_special_tokens=True
)
truncated_prompts.append(truncated_text)
else:
truncated_prompts.append(prompt)
return truncated_prompts
def assert_close_prefill_logits(
self,
prompts,
model_path,
tp_size,
torch_dtype,
prefill_tolerance,
attention_backend,
batch_size,
) -> None:
truncated_prompts = self._truncate_prompts(prompts, model_path)
truncated_prompts = truncated_prompts * batch_size
with HFRunner(
model_path,
torch_dtype=torch_dtype,
model_type="embedding",
) as hf_runner:
# warm up
hf_outputs = hf_runner.forward(truncated_prompts)
st_start_time = time.time()
hf_outputs = hf_runner.forward(truncated_prompts)
st_end_time = time.time()
with SRTRunner(
model_path,
tp_size=tp_size,
torch_dtype=torch_dtype,
model_type="embedding",
attention_backend=attention_backend,
chunked_prefill_size=-1,
disable_radix_cache=True,
) as srt_runner:
# warm up
srt_outputs = srt_runner.forward(truncated_prompts)
sgl_start_time = time.time()
srt_outputs = srt_runner.forward(truncated_prompts)
sgl_end_time = time.time()
transformer_time = st_end_time - st_start_time
sgl_time = sgl_end_time - sgl_start_time
sgl_to_st_ratio.append(sgl_time / transformer_time)
for i in range(len(truncated_prompts)):
hf_logits = torch.Tensor(hf_outputs.embed_logits[i])
srt_logits = torch.Tensor(srt_outputs.embed_logits[i])
similarity = torch.tensor(get_similarities(hf_logits, srt_logits))
# If something is wrong, uncomment this to observe similarity.
# print("similarity diff", abs(similarity - 1))
if len(truncated_prompts[i]) <= 1000:
assert torch.all(
abs(similarity - 1) < prefill_tolerance
), "embeddings are not all close"
def test_prefill_logits(self):
models_to_test = MODELS
if is_in_ci():
models_to_test = [random.choice(MODELS)]
for model, tp_size, prefill_tolerance in models_to_test:
for attention_backend in ATTENTION_BACKEND:
for batch_size in BATCH_SIZE:
for torch_dtype in TORCH_DTYPES:
self.assert_close_prefill_logits(
DEFAULT_PROMPTS,
model,
tp_size,
torch_dtype,
prefill_tolerance,
attention_backend,
batch_size,
)
for i in range(len(BATCH_SIZE)):
print(
"bacth size: ",
BATCH_SIZE[i] * 5,
"sgl_time/st_time",
round(sgl_to_st_ratio[i], 3),
)
if __name__ == "__main__":
unittest.main()
......@@ -116,6 +116,7 @@ class TestTritonAttention(CustomTestCase):
kv_indptr,
kv_indices,
custom_mask,
True,  # is_causal
mask_indptr,
max_len_extend,
)
......@@ -150,6 +151,7 @@ class TestTritonAttention(CustomTestCase):
kv_indptr,
kv_indices,
custom_mask,
True,  # is_causal
mask_indptr,
max_len_extend,
)
......