Unverified commit 164302c7 authored by Christian Bahls, committed by GitHub

Implement BGE-M3 Sparse Embeddings in SGLang (#10869)


Co-authored-by: Christian Bahls <christian.bahls@planet-ai.de>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent 5dccf697
@@ -237,6 +237,9 @@ class Envs:
    SGLANG_KT_AMX_METHOD = EnvStr(None)
    SGLANG_KT_MOE_CHUNKED_PREFILL_SIZE = EnvInt(None)

    # Sparse Embeddings
    SGLANG_EMBEDDINGS_SPARSE_HEAD = EnvStr(None)
    # fmt: on
......
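A minimal usage sketch for the new environment variable, assuming the envs helper reads the variable at access time and that the sparse head file name ("sparse_linear.pt" below) matches whatever file the model repository actually ships:

import os

# Point SGLang at the sparse head weight file before the server / engine starts.
# "sparse_linear.pt" is an assumed file name; use the file shipped with your model.
os.environ["SGLANG_EMBEDDINGS_SPARSE_HEAD"] = "sparse_linear.pt"

from sglang.srt.environ import envs

assert envs.SGLANG_EMBEDDINGS_SPARSE_HEAD.is_set()
print(envs.SGLANG_EMBEDDINGS_SPARSE_HEAD.value)  # "sparse_linear.pt"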
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig

from sglang.srt.model_executor.model_runner import ForwardBatch


@dataclass
class SparseEmbeddingOutput:
    embeddings: torch.Tensor  # [batch_size, vocab_size]


class SparsePooler(nn.Module):
    """A layer that pools hidden states into sparse vocabulary-space embeddings.

    This layer does the following:
    1. Applies a linear transformation + ReLU to get token-level weights
    2. Maps these weights to vocabulary positions using token IDs
    3. Aggregates weights for repeated tokens using max pooling
    4. Returns sparse embeddings in vocabulary space

    Attributes:
        config: Model configuration containing vocab_size and hidden_size
        sparse_linear: Linear layer for computing token weights
        vocab_size: Size of vocabulary for output embeddings
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__()

        # Validate required attributes
        if not hasattr(config, "vocab_size"):
            raise AttributeError(
                f"Config {type(config)} missing required 'vocab_size' attribute"
            )
        if not hasattr(config, "hidden_size"):
            raise AttributeError(
                f"Config {type(config)} missing required 'hidden_size' attribute"
            )

        self.vocab_size = config.vocab_size
        self.sparse_linear = nn.Linear(config.hidden_size, 1)
        self._weights_loaded = False

    def forward(
        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
    ) -> SparseEmbeddingOutput:
        """
        Forward pass for sparse pooling.

        Args:
            hidden_states: Packed sequence hidden states [total_tokens, hidden_size]
            forward_batch: Batch information with sequence lengths and input_ids

        Returns:
            SparseEmbeddingOutput with embeddings of shape [batch_size, vocab_size]
        """
        if not self._weights_loaded:
            raise ValueError(
                "Sparse pooling weights not loaded. Call load_weights() first"
            )

        # Apply sparse linear + ReLU to get token weights
        token_weights = F.relu(self.sparse_linear(hidden_states)).squeeze(
            -1
        )  # [total_tokens]

        # Create batch indices for packed sequences
        batch_indices = torch.repeat_interleave(
            torch.arange(
                len(forward_batch.extend_seq_lens), device=hidden_states.device
            ),
            forward_batch.extend_seq_lens,
        )

        # Initialize sparse embedding output
        sparse_embedding = torch.zeros(
            len(forward_batch.extend_seq_lens),
            self.vocab_size,
            dtype=token_weights.dtype,
            device=token_weights.device,
        )

        # Map to vocabulary space using scatter_reduce with amax
        flat_indices = batch_indices * self.vocab_size + forward_batch.input_ids
        sparse_embedding.view(-1).scatter_reduce_(
            0, flat_indices, token_weights, reduce="amax"
        )

        return SparseEmbeddingOutput(embeddings=sparse_embedding)

    def load_weights(self, state_dict: dict):
        """Load weights from state dict (called by the model)."""
        self.sparse_linear.load_state_dict(state_dict)
        self._weights_loaded = True
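A self-contained toy run of the pooling scheme SparsePooler.forward implements (packed sequences, ReLU'd token weights, max-pooled into vocabulary space). All tensors and sizes below are invented for illustration:

import torch

vocab_size = 10
seq_lens = torch.tensor([3, 2])                           # two packed sequences
input_ids = torch.tensor([4, 7, 4, 1, 9])                 # token 4 repeats in sequence 0
token_weights = torch.tensor([0.2, 0.5, 0.9, 0.3, 0.0])   # already ReLU'd per-token weights

batch_indices = torch.repeat_interleave(torch.arange(len(seq_lens)), seq_lens)
sparse_embedding = torch.zeros(len(seq_lens), vocab_size)

# Same trick as the layer: flatten (batch, vocab) and max-reduce repeated tokens.
flat_indices = batch_indices * vocab_size + input_ids
sparse_embedding.view(-1).scatter_reduce_(0, flat_indices, token_weights, reduce="amax")

print(sparse_embedding[0, 4])  # tensor(0.9000) -- max of the two weights for token 4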
@@ -961,7 +961,7 @@ class BatchEmbeddingOutput(BaseBatchReq):
    # The finish reason
    finished_reasons: List[BaseFinishReason]
    # The output embedding
-    embeddings: List[List[float]]
+    embeddings: Union[List[List[float]], List[Dict[int, float]]]
    # Token counts
    prompt_tokens: List[int]
    cached_tokens: List[int]
......
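For orientation, the two payload shapes the widened embeddings field can now carry look like this (values are made up):

# Dense path: one fixed-length vector per request (List[List[float]]).
dense_batch = [[0.12, -0.03, 0.88], [0.05, 0.41, -0.27]]

# Sparse path: one {token_id: weight} mapping per request (List[Dict[int, float]]).
sparse_batch = [{4: 0.9, 7: 0.5}, {1: 0.3}]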
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import torch

from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.environ import envs
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
    AbortReq,
@@ -175,7 +176,21 @@ class SchedulerOutputProcessorMixin:
                    logprob_pt += num_input_logprobs
        else:  # embedding or reward model
-            embeddings = result.embeddings.tolist()
+            is_sparse = envs.SGLANG_EMBEDDINGS_SPARSE_HEAD.is_set()
+            embeddings = result.embeddings
+            if is_sparse:
+                batch_ids, token_ids = embeddings.indices()
+                values = embeddings.values()
+                embeddings = [{} for _ in range(embeddings.size(0))]
+                for i in range(batch_ids.shape[0]):
+                    embeddings[batch_ids[i].item()][token_ids[i].item()] = values[
+                        i
+                    ].item()
+            else:
+                embeddings = embeddings.tolist()

            # Check finish conditions
            for i, req in enumerate(batch.reqs):
......
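The conversion above can be exercised in isolation. Here is a toy reproduction of turning a batch of sparse COO embeddings into per-request dictionaries; the tensor contents are invented for the example:

import torch

# Two requests over a vocab of size 10; dense rows with mostly zeros.
dense = torch.tensor([[0.0, 0.0, 0.0, 0.0, 0.75, 0.0, 0.0, 0.5, 0.0, 0.0],
                      [0.0, 0.25, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]])
sparse = dense.to_sparse()  # coalesced COO tensor, as the model's forward pass returns

batch_ids, token_ids = sparse.indices()  # indices() has shape [2, nnz]: batch row, vocab column
values = sparse.values()

result = [{} for _ in range(sparse.size(0))]
for i in range(batch_ids.shape[0]):
    result[batch_ids[i].item()][token_ids[i].item()] = values[i].item()

print(result)  # [{4: 0.75, 7: 0.5}, {1: 0.25}]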
@@ -77,6 +77,7 @@ from sglang.srt.model_loader.utils import (
DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION = (
    0.8  # Reserve 20% GPU memory headroom for ModelOpt calibration
)
from sglang.srt.environ import envs
from sglang.srt.model_loader.weight_utils import (
    download_safetensors_index_file_from_hf,
    download_weights_from_hf,
@@ -244,10 +245,19 @@ def _initialize_model(
    quant_config = _get_quantization_config(
        model_config, load_config, packed_modules_mapping
    )
-    return model_class(
-        config=model_config.hf_config,
-        quant_config=quant_config,
-    )
+    # Build kwargs conditionally
+    kwargs = {
+        "config": model_config.hf_config,
+        "quant_config": quant_config,
+    }
+    # Only add sparse head kwargs if envs.SGLANG_EMBEDDINGS_SPARSE_HEAD.is_set()
+    if envs.SGLANG_EMBEDDINGS_SPARSE_HEAD.is_set():
+        kwargs["sparse_head"] = envs.SGLANG_EMBEDDINGS_SPARSE_HEAD.value
+        kwargs["model_path"] = model_config.model_path
+    return model_class(**kwargs)


class BaseModelLoader(ABC):
......
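Because the extra keyword arguments are only injected when the environment variable is set, model classes that never declare them stay untouched; only constructors that accept sparse_head / model_path (such as XLMRobertaModel below) ever receive them. A toy illustration of that dispatch, using stand-in classes that are not SGLang code:

# Stand-in classes; file name and path are made up for the example.
class DenseOnlyModel:
    def __init__(self, config, quant_config=None):
        self.config = config

class SparseCapableModel:
    def __init__(self, config, quant_config=None, sparse_head=None, model_path=None):
        self.uses_sparse_head = sparse_head is not None

kwargs = {"config": object(), "quant_config": None}
DenseOnlyModel(**kwargs)  # works: no sparse kwargs were added

kwargs.update(sparse_head="sparse_linear.pt", model_path="/models/bge-m3")
print(SparseCapableModel(**kwargs).uses_sparse_head)  # True
# DenseOnlyModel(**kwargs) would raise TypeError -- the env var should only be
# set for models that actually implement a sparse head.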
# SPDX-License-Identifier: Apache-2.0
import os
from typing import Iterable, Optional, Tuple
import torch
@@ -7,10 +8,12 @@ from torch import nn
from sglang.srt.layers.pooler import CrossEncodingPooler, Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.sparse_pooler import SparsePooler
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.bert import BertEncoder
from sglang.srt.utils.hf_transformers_utils import download_from_hf
RobertaConfig = None
@@ -205,11 +208,28 @@ class XLMRobertaModel(nn.Module):
        config: RobertaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        sparse_head: Optional[str] = None,
        model_path: Optional[str] = None,
    ):
        super().__init__()
        self.roberta = XLMRobertaBaseModel(
            config=config, quant_config=quant_config, prefix=prefix
        )

        if sparse_head is not None:
            self._is_sparse = True
            self._model_path = model_path
            self._sparse_head = sparse_head
            self.pooler = SparsePooler(config=config)

            # Zero out special tokens
            self._special_tokens = [
                config.bos_token_id,
                config.eos_token_id,
                config.pad_token_id,
                # self.config.unk_token_id  # not available in the XLMRobertaConfig
            ]
            self._special_tokens = [t for t in self._special_tokens if t is not None]
        else:
            self._is_sparse = False
            self.pooler = Pooler(pooling_type=PoolingType.CLS, normalize=True)

    def forward(
@@ -223,11 +243,44 @@ class XLMRobertaModel(nn.Module):
        hidden_states = self.roberta(
            input_ids, positions, forward_batch, input_embeds, get_embedding
        )
-        return self.pooler(hidden_states, forward_batch)
+        embeddings = self.pooler(hidden_states, forward_batch)
+
+        if self._is_sparse:
+            for token_id in self._special_tokens:
+                embeddings.embeddings[:, token_id] = 0.0
+            embeddings.embeddings = embeddings.embeddings.to_sparse()
+
+        return embeddings

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        self.roberta.load_weights(weights)

        if self._is_sparse:
            sparse_dict = XLMRobertaModel._load_sparse_linear(
                self._model_path, self._sparse_head
            )
            self.pooler.load_weights(sparse_dict)

    @staticmethod
    def _load_sparse_linear(model_path_or_dir: str, sparse_head: str) -> dict:
        """
        Load sparse_head from local dir or HF Hub.
        Returns a state_dict suitable for nn.Linear.load_state_dict().
        """
        if os.path.isdir(model_path_or_dir):
            path = os.path.join(model_path_or_dir, sparse_head)
            if not os.path.exists(path):
                raise FileNotFoundError(
                    f"'{sparse_head}' not found in {model_path_or_dir}"
                )
        else:
            # remote → use SGLang HF utility
            local_dir = download_from_hf(model_path_or_dir, allow_patterns=sparse_head)
            path = os.path.join(local_dir, sparse_head)

        state_dict = torch.load(path)
        return state_dict


class XLMRobertaForSequenceClassification(nn.Module):
    def __init__(
......
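To make the sparse post-processing above concrete, here is a small standalone sketch. It assumes the sparse head state dict holds plain nn.Linear "weight"/"bias" entries (which is what _load_sparse_linear hands to load_state_dict), then zeroes the special-token columns and converts the dense [batch, vocab] matrix to a COO tensor, mirroring XLMRobertaModel.forward. Shapes, token ids, and weights are illustrative:

import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_size, vocab_size = 8, 12
sparse_linear = nn.Linear(hidden_size, 1)
# Assumed checkpoint layout: plain nn.Linear "weight"/"bias" entries.
sparse_linear.load_state_dict(
    {"weight": torch.randn(1, hidden_size), "bias": torch.zeros(1)}
)

with torch.no_grad():
    hidden_states = torch.randn(3, hidden_size)             # 3 tokens of a single request
    token_weights = F.relu(sparse_linear(hidden_states)).squeeze(-1)

    embeddings = torch.zeros(1, vocab_size)                  # [batch=1, vocab]
    embeddings[0, torch.tensor([0, 5, 9])] = token_weights   # token ids 0, 5, 9 (invented)

    special_tokens = [0, 9]                                  # e.g. bos/eos ids (invented)
    for token_id in special_tokens:
        embeddings[:, token_id] = 0.0                        # mask special tokens

    sparse = embeddings.to_sparse()                          # COO tensor handed to the scheduler
    print(sparse.indices(), sparse.values())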