Unverified commit e040a245, authored by Ying Sheng, committed by GitHub

Add e5-mistral embedding model - step 3/3 (#988)

parent 9f662501
@@ -35,6 +35,7 @@ jobs:
         pip install -e "python[all]"
         pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall
         pip install accelerate
+        pip install sentence_transformers
     - name: Test Frontend Language
       run: |
...
@@ -25,7 +25,11 @@ import zmq
 import zmq.asyncio
 from sglang.srt.hf_transformers_utils import get_tokenizer
-from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
+from sglang.srt.managers.io_struct import (
+    BatchEmbeddingOut,
+    BatchStrOut,
+    BatchTokenIDOut,
+)
 from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.utils import find_printable_text, get_exception_traceback, graceful_registry
@@ -66,6 +70,18 @@ class DetokenizerManager:
     async def handle_loop(self):
         while True:
             recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
+            if isinstance(recv_obj, BatchEmbeddingOut):
+                self.send_to_tokenizer.send_pyobj(
+                    BatchEmbeddingOut(
+                        rids=recv_obj.rids,
+                        embeddings=recv_obj.embeddings,
+                        meta_info=recv_obj.meta_info,
+                        finished_reason=recv_obj.finished_reason,
+                    )
+                )
+                continue
             assert isinstance(recv_obj, BatchTokenIDOut)
             bs = len(recv_obj.rids)
...
@@ -143,6 +143,7 @@ class Req:
         # Logprobs
         self.return_logprob = False
+        self.embedding = None
         self.logprob_start_len = 0
         self.top_logprobs_num = 0
         self.normalized_prompt_logprob = None
...
@@ -21,7 +21,7 @@ import dataclasses
 import logging
 import multiprocessing as mp
 import os
-from typing import Dict, List, Tuple
+from typing import Dict, List, Tuple, Union
 import numpy as np
 import transformers
@@ -38,16 +38,19 @@ from sglang.srt.hf_transformers_utils import (
 )
 from sglang.srt.managers.io_struct import (
     AbortReq,
+    BatchEmbeddingOut,
     BatchStrOut,
     BatchTokenIDOut,
+    EmbeddingReqInput,
     FlushCacheReq,
     GenerateReqInput,
+    TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
 )
 from sglang.srt.mm_utils import expand2square, process_anyres_image
 from sglang.srt.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import is_multimodal_model, load_image
+from sglang.srt.utils import is_generation_model, is_multimodal_model, load_image
 from sglang.utils import get_exception_traceback
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -85,6 +88,7 @@ class TokenizerManager:
             trust_remote_code=server_args.trust_remote_code,
             model_overide_args=model_overide_args,
         )
+        self.is_generation = is_generation_model(self.hf_config.architectures)
         if server_args.context_length is not None:
             self.context_len = server_args.context_length
@@ -133,7 +137,9 @@ class TokenizerManager:
             image_data, aspect_ratio, grid_pinpoints, self.processor
         )
-    async def generate_request(self, obj: GenerateReqInput, request=None):
+    async def generate_request(
+        self, obj: Union[GenerateReqInput, EmbeddingReqInput], request=None
+    ):
         if self.to_create_loop:
             self.create_handle_loop()
@@ -144,6 +150,8 @@ class TokenizerManager:
             async for response in self._handle_single_request(obj, request):
                 yield response
         else:
+            if isinstance(obj, EmbeddingReqInput):
+                raise NotImplementedError("Please send only one prompt in each request")
             if obj.stream:
                 raise ValueError("Do not support stream for batch mode.")
@@ -151,26 +159,29 @@ class TokenizerManager:
             yield response
     async def _handle_single_request(
-        self, obj, request, index=None, is_cache_for_prefill=False
+        self,
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
+        request,
+        index=None,
+        is_cache_for_prefill=False,
     ):
         if not is_cache_for_prefill:  # The normal case with a single prompt
             not_use_index = index is None
             rid = obj.rid if not_use_index else obj.rid[index]
             input_text = obj.text if not_use_index else obj.text[index]
-            input_ids = (
-                self.tokenizer.encode(input_text)
-                if obj.input_ids is None
-                else obj.input_ids
-            )
-            if not not_use_index and obj.input_ids:
-                input_ids = obj.input_ids[index]
+            if obj.input_ids is None:
+                input_ids = self.tokenizer.encode(input_text)
+            else:
+                input_ids = obj.input_ids if not_use_index else obj.input_ids[index]
             self._validate_input_length(input_ids)
             sampling_params = self._get_sampling_params(
                 obj.sampling_params if not_use_index else obj.sampling_params[index]
             )
+            if self.is_generation:
                 pixel_values, image_hash, image_size = await self._get_pixel_values(
                     obj.image_data if not_use_index else obj.image_data[index]
                 )
@@ -178,12 +189,17 @@ class TokenizerManager:
                     obj.return_logprob if not_use_index else obj.return_logprob[index]
                 )
                 logprob_start_len = (
-                    obj.logprob_start_len if not_use_index else obj.logprob_start_len[index]
+                    obj.logprob_start_len
+                    if not_use_index
+                    else obj.logprob_start_len[index]
                 )
                 top_logprobs_num = (
-                    obj.top_logprobs_num if not_use_index else obj.top_logprobs_num[index]
+                    obj.top_logprobs_num
+                    if not_use_index
+                    else obj.top_logprobs_num[index]
                 )
         else:  # A prefill request to cache the common prompt for parallel sampling
+            assert self.is_generation
             if obj.text is not None:
                 if isinstance(obj.text, list):
                     input_text = obj.text[index]
@@ -213,6 +229,7 @@ class TokenizerManager:
             logprob_start_len = obj.logprob_start_len[0]
             top_logprobs_num = obj.top_logprobs_num[0]
+        if self.is_generation:
             tokenized_obj = TokenizedGenerateReqInput(
                 rid,
                 input_text,
@@ -226,6 +243,14 @@ class TokenizerManager:
                 top_logprobs_num,
                 obj.stream,
             )
+        else:  # is embedding
+            tokenized_obj = TokenizedEmbeddingReqInput(
+                rid,
+                input_text,
+                input_ids,
+                sampling_params,
+            )
         self.send_to_router.send_pyobj(tokenized_obj)
         event = asyncio.Event()
@@ -368,7 +393,7 @@ class TokenizerManager:
         self,
         event: asyncio.Event,
         state: ReqState,
-        obj: GenerateReqInput,
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
         rid: str,
         request,
     ):
@@ -381,12 +406,15 @@ class TokenizerManager:
                     raise ValueError(f"Abort request {rid}")
                 continue
+            if self.is_generation:
                 out = self.convert_logprob_style(
                     state.out_list[-1],
                     obj.return_logprob,
                     obj.top_logprobs_num,
                     obj.return_text_in_logprobs,
                 )
+            else:  # isinstance(obj, EmbeddingReqInput)
+                out = state.out_list[-1]
             # Log requests
             if self.server_args.log_requests and state.finished:
@@ -459,8 +487,10 @@ class TokenizerManager:
     async def handle_loop(self):
         while True:
-            recv_obj: BatchStrOut = await self.recv_from_detokenizer.recv_pyobj()
-            assert isinstance(recv_obj, BatchStrOut)
+            recv_obj: Union[BatchStrOut, BatchEmbeddingOut] = (
+                await self.recv_from_detokenizer.recv_pyobj()
+            )
+            assert isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut))
             for i, rid in enumerate(recv_obj.rids):
                 state = self.rid_to_state.get(rid, None)
@@ -468,10 +498,17 @@ class TokenizerManager:
                     continue
                 recv_obj.meta_info[i]["id"] = rid
+                if isinstance(recv_obj, BatchStrOut):
                     out_dict = {
                         "text": recv_obj.output_strs[i],
                         "meta_info": recv_obj.meta_info[i],
                     }
+                else:
+                    assert isinstance(recv_obj, BatchEmbeddingOut)
+                    out_dict = {
+                        "embedding": recv_obj.embeddings[i],
+                        "meta_info": recv_obj.meta_info[i],
+                    }
                 state.out_list.append(out_dict)
                 state.finished = recv_obj.finished_reason[i] is not None
                 state.event.set()
...
@@ -20,7 +20,7 @@ import multiprocessing
 import pickle
 import time
 import warnings
-from typing import List, Optional
+from typing import List, Optional, Union
 import torch
 import torch.distributed as dist
@@ -31,8 +31,10 @@ from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.io_struct import (
     AbortReq,
+    BatchEmbeddingOut,
     BatchTokenIDOut,
     FlushCacheReq,
+    TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
 )
 from sglang.srt.managers.policy_scheduler import PolicyScheduler, PrefillAdder
@@ -205,7 +207,9 @@ class ModelTpServer:
         try:
             # Recv requests
             for recv_req in recv_reqs:
-                if isinstance(recv_req, TokenizedGenerateReqInput):
+                if isinstance(
+                    recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
+                ):
                     self.handle_generate_request(recv_req)
                 elif isinstance(recv_req, FlushCacheReq):
                     self.flush_cache()
@@ -297,9 +301,12 @@ class ModelTpServer:
     def handle_generate_request(
         self,
-        recv_req: TokenizedGenerateReqInput,
+        recv_req: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
     ):
         req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
+        req.tokenizer = self.tokenizer
+        req.sampling_params = recv_req.sampling_params
+        if self.model_runner.is_generation:
             req.pixel_values = recv_req.pixel_values
             if req.pixel_values is not None:
                 req.pad_value = [
@@ -318,12 +325,10 @@ class ModelTpServer:
                     req.pixel_values.shape,
                     req.image_size,
                 )
-            req.sampling_params = recv_req.sampling_params
             req.return_logprob = recv_req.return_logprob
             req.logprob_start_len = recv_req.logprob_start_len
             req.top_logprobs_num = recv_req.top_logprobs_num
             req.stream = recv_req.stream
-        req.tokenizer = self.tokenizer
         # Init regex fsm
         if req.sampling_params.regex is not None:
@@ -340,6 +345,8 @@ class ModelTpServer:
                 "the max context length. Truncated!!!"
             )
             req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
+        if self.model_runner.is_generation:
            req.sampling_params.max_new_tokens = min(
                (
                    req.sampling_params.max_new_tokens
@@ -348,6 +355,7 @@ class ModelTpServer:
                ),
                self.max_req_input_len - 1 - len(req.origin_input_ids),
            )
         self.waiting_queue.append(req)
     def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
@@ -439,6 +447,7 @@ class ModelTpServer:
             self.model_config.vocab_size, self.int_token_logit_bias
         )
+        if self.model_runner.is_generation:
            # Forward and sample the next tokens
            if batch.extend_num_tokens != 0:
                output = self.model_runner.forward(batch, ForwardMode.EXTEND)
@@ -480,6 +489,26 @@ class ModelTpServer:
                if req.return_logprob:
                    self.add_logprob_return_values(i, req, pt, next_token_ids, output)
                    pt += req.extend_input_len
+        else:
+            assert batch.extend_num_tokens != 0
+            output = self.model_runner.forward(batch, ForwardMode.EXTEND)
+            embeddings = output.embeddings.tolist()
+            # Check finish conditions
+            for i, req in enumerate(batch.reqs):
+                req.embedding = embeddings[i]
+                if req is not self.current_inflight_req:
+                    # Inflight reqs' prefill is not finished
+                    req.check_finished()
+                if req.finished():
+                    self.tree_cache.cache_finished_req(req)
+                else:
+                    self.tree_cache.cache_unfinished_req(req)
+                if req is self.current_inflight_req:
+                    # Inflight request would get a new req idx
+                    self.req_to_token_pool.free(req.req_pool_idx)
         self.handle_finished_requests(batch)
@@ -596,15 +625,19 @@ class ModelTpServer:
     def handle_finished_requests(self, batch: ScheduleBatch):
         output_rids = []
+        output_meta_info = []
+        output_finished_reason: List[BaseFinishReason] = []
+        if self.model_runner.is_generation:
            output_vids = []
            decoded_texts = []
            output_read_ids = []
            output_read_offsets = []
            output_skip_special_tokens = []
            output_spaces_between_special_tokens = []
-        output_meta_info = []
-        output_finished_reason: List[BaseFinishReason] = []
+        else:  # for embedding model
+            output_embeddings = []
         unfinished_indices = []
         for i, req in enumerate(batch.reqs):
             if not req.finished() and req is not self.current_inflight_req:
                 unfinished_indices.append(i)
@@ -619,6 +652,8 @@ class ModelTpServer:
                 )
             ):
                 output_rids.append(req.rid)
+                output_finished_reason.append(req.finished_reason)
+                if self.model_runner.is_generation:
                    output_vids.append(req.vid)
                    decoded_texts.append(req.decoded_text)
                    read_ids, read_offset = req.init_incremental_detokenize()
@@ -652,10 +687,16 @@ class ModelTpServer:
                        req.normalized_prompt_logprob,
                    )
                    output_meta_info.append(meta_info)
-                output_finished_reason.append(req.finished_reason)
+                else:  # for embedding model
+                    output_embeddings.append(req.embedding)
+                    meta_info = {
+                        "prompt_tokens": len(req.origin_input_ids),
+                    }
+                    output_meta_info.append(meta_info)
         # Send to detokenizer
         if output_rids:
+            if self.model_runner.is_generation:
                self.out_pyobjs.append(
                    BatchTokenIDOut(
                        output_rids,
@@ -669,6 +710,15 @@ class ModelTpServer:
                        output_finished_reason,
                    )
                )
+            else:  # for embedding model
+                self.out_pyobjs.append(
+                    BatchEmbeddingOut(
+                        output_rids,
+                        output_embeddings,
+                        output_meta_info,
+                        output_finished_reason,
+                    )
+                )
         # Remove finished reqs: update batch tensors
         batch.filter_batch(unfinished_indices)
...
@@ -52,6 +52,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetad
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
+    is_generation_model,
     is_llama3_405b_fp8,
     is_multimodal_model,
     monkey_patch_vllm_dummy_weight_loader,
@@ -132,6 +133,8 @@ class ModelRunner:
         self.init_cublas()
         self.init_flashinfer()
+        if self.is_generation:
+            # FIXME Currently, cuda graph only capture decode steps, which only exists in causal models
            # Capture cuda graphs
            self.init_cuda_graphs()
@@ -184,6 +187,10 @@ class ModelRunner:
             scheduler_config=None,
             cache_config=None,
         )
+        self.is_generation = is_generation_model(
+            self.model_config.hf_config.architectures
+        )
         logger.info(
             f"[gpu={self.gpu_id}] Load weight end. "
             f"type={type(self.model).__name__}, "
@@ -406,8 +413,10 @@ def import_model_classes():
                 entry, list
             ):  # To support multiple model classes in one module
                 for tmp in entry:
+                    assert tmp.__name__ not in model_arch_name_to_cls
                     model_arch_name_to_cls[tmp.__name__] = tmp
             else:
+                assert entry.__name__ not in model_arch_name_to_cls
                 model_arch_name_to_cls[entry.__name__] = entry
     # compat: some models such as chatglm has incorrect class set in config.json
@@ -417,6 +426,7 @@ def import_model_classes():
         ):
             for remap in module.EntryClassRemapping:
                 if isinstance(remap, tuple) and len(remap) == 2:
+                    assert remap[0] not in model_arch_name_to_cls
                     model_arch_name_to_cls[remap[0]] = remap[1]
     return model_arch_name_to_cls
...
@@ -84,3 +84,5 @@ class LlamaEmbeddingModel(nn.Module):
 EntryClass = LlamaEmbeddingModel
+# compat: e5-mistral model.config class == MistralModel
+EntryClassRemapping = [("MistralModel", LlamaEmbeddingModel)]
@@ -52,7 +52,7 @@ from sglang.srt.managers.controller_single import (
     start_controller_process as start_controller_process_single,
 )
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
-from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.openai_api.adapter import (
     load_chat_template_for_openai_api,
@@ -97,6 +97,7 @@ async def health() -> Response:
 async def get_model_info():
     result = {
         "model_path": tokenizer_manager.model_path,
+        "is_generation": tokenizer_manager.is_generation,
     }
     return result
@@ -148,6 +149,21 @@ app.post("/generate")(generate_request)
 app.put("/generate")(generate_request)
+async def encode_request(obj: EmbeddingReqInput, request: Request):
+    """Handle an embedding request."""
+    try:
+        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
+        return ret
+    except ValueError as e:
+        return JSONResponse(
+            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+        )
+app.post("/encode")(encode_request)
+app.put("/encode")(encode_request)
 @app.post("/v1/completions")
 async def openai_v1_completions(raw_request: Request):
     return await v1_completions(tokenizer_manager, raw_request)
@@ -380,6 +396,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
     except (AssertionError, requests.exceptions.RequestException) as e:
         last_traceback = get_exception_traceback()
         pass
+    model_info = res.json()
     if not success:
         if pipe_finish_writer is not None:
@@ -388,15 +405,17 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
         sys.exit(1)
     # Send a warmup request
+    request_name = "/generate" if model_info["is_generation"] else "/encode"
+    max_new_tokens = 8 if model_info["is_generation"] else 0
     try:
         for _ in range(server_args.dp_size):
             res = requests.post(
-                url + "/generate",
+                url + request_name,
                 json={
                     "text": "The capital city of France is",
                     "sampling_params": {
                         "temperature": 0,
-                        "max_new_tokens": 8,
+                        "max_new_tokens": max_new_tokens,
                     },
                 },
                 headers=headers,
@@ -529,5 +548,18 @@ class Runtime:
         )
         return json.dumps(response.json())
+    def encode(
+        self,
+        prompt: str,
+    ):
+        json_data = {
+            "text": prompt,
+        }
+        response = requests.post(
+            self.url + "/encode",
+            json=json_data,
+        )
+        return json.dumps(response.json())
     def __del__(self):
         self.shutdown()
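For reference, a minimal client-side sketch of the new `/encode` route added above; the host, port, and launch command are placeholders, while the payload (`{"text": ...}`) and the response keys (`embedding`, `meta_info`) follow the `out_dict` built in the tokenizer manager in this commit:

```python
import requests

# Assumes a server is already running with an embedding model
# (e.g. intfloat/e5-mistral-7b-instruct) on localhost:30000 -- both are placeholders.
resp = requests.post(
    "http://127.0.0.1:30000/encode",              # new /encode route added in this commit
    json={"text": "The capital city of France is"},
)
data = resp.json()
print(len(data["embedding"]))   # embedding vector for the single prompt
print(data["meta_info"])        # includes "prompt_tokens" and the request "id"
```

The same payload is what the new `Runtime.encode` posts internally; note that the tokenizer manager currently rejects batched `EmbeddingReqInput` requests with `NotImplementedError`, so send one prompt per call.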
@@ -223,6 +223,15 @@ def is_multimodal_model(model):
     raise ValueError("unrecognized type")
+def is_generation_model(model_architectures):
+    if (
+        "LlamaEmbeddingModel" in model_architectures
+        or "MistralModel" in model_architectures
+    ):
+        return False
+    return True
 def decode_video_base64(video_base64):
     from PIL import Image
...
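For clarity, the new helper routes purely on architecture names from the Hugging Face config: the two embedding entries return False, and everything else falls through to True. A couple of illustrative calls (the causal-LM architecture name below is only an example, not part of this commit):

```python
from sglang.srt.utils import is_generation_model

assert is_generation_model(["MistralModel"]) is False         # e5-mistral's config architecture
assert is_generation_model(["LlamaEmbeddingModel"]) is False  # the embedding entry class above
assert is_generation_model(["LlamaForCausalLM"]) is True      # any other architecture keeps the generation path
```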
@@ -23,6 +23,7 @@ import torch.nn.functional as F
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from sglang.srt.server import Runtime
+from sglang.srt.utils import is_generation_model
 DEFAULT_PROMPTS = [
     "The capital of France is",
@@ -33,13 +34,6 @@ DEFAULT_PROMPTS = [
 NUM_TOP_LOGPROBS = 5
-def is_embedding_model(model_path):
-    # FIXME incomplete list
-    if "e5-mistral-7b-instruct" in model_path.lower():
-        return True
-    return False
 def get_dtype_str(torch_dtype):
     if torch_dtype is torch.float16:
         return "float16"
@@ -60,7 +54,7 @@ class HFRunner:
         self,
         model_path,
         torch_dtype=torch.float16,
-        is_embedding_model=None,
+        is_generation_model=None,
     ):
         self.in_queue = multiprocessing.Queue()
         self.out_queue = multiprocessing.Queue()
@@ -72,13 +66,13 @@ class HFRunner:
                 self.out_queue,
                 model_path,
                 torch_dtype,
-                is_embedding_model,
+                is_generation_model,
             ),
         )
         self.model_proc.start()
     def start_model_process(
-        self, in_queue, out_queue, model_path, torch_dtype, is_embedding_model
+        self, in_queue, out_queue, model_path, torch_dtype, is_generation_model
     ):
         self.tokenizer = AutoTokenizer.from_pretrained(
             model_path,
@@ -86,12 +80,12 @@ class HFRunner:
             trust_remote_code=True,
         )
-        self.is_embedding_model = (
-            is_embedding_model(model_path)
-            if is_embedding_model is None
-            else is_embedding_model
+        self.is_generation_model = (
+            is_generation_model(model_path)
+            if is_generation_model is None
+            else is_generation_model
         )
-        if not self.is_embedding_model:
+        if self.is_generation_model:
             self.model = AutoModelForCausalLM.from_pretrained(
                 model_path,
                 torch_dtype=torch_dtype,
@@ -103,13 +97,13 @@ class HFRunner:
             self.model = SentenceTransformer(
                 model_path,
-                device="cpu",
-            ).to(dtype=torch_dtype)
+                model_kwargs={"torch_dtype": torch_dtype},
+            )
         while True:
             prompts, max_new_tokens = in_queue.get()
             if prompts is not None:
-                if not self.is_embedding_model:
+                if self.is_generation_model:
                     output_strs = []
                     prefill_logprobs = []
                     for p in prompts:
@@ -144,7 +138,6 @@ class HFRunner:
                     )
                 else:
-                    assert isinstance(prompts, List[str])
                     logits = self.model.encode(prompts).tolist()
                     out_queue.put(ModelOutput(embed_logits=logits))
@@ -175,16 +168,13 @@ class SRTRunner:
         model_path,
         tp_size=1,
         torch_dtype=torch.float16,
-        is_embedding_model=None,
+        is_generation_model=None,
     ):
-        self.is_embedding_model = (
-            is_embedding_model(model_path)
-            if is_embedding_model is None
-            else is_embedding_model
+        self.is_generation_model = (
+            is_generation_model(model_path)
+            if is_generation_model is None
+            else is_generation_model
         )
-        if self.is_embedding_model:
-            raise NotImplementedError()
         self.runtime = Runtime(
             model_path=model_path,
             tp_size=tp_size,
@@ -196,6 +186,7 @@ class SRTRunner:
         prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
         max_new_tokens=64,
     ):
+        if self.is_generation_model:
            # the return value contains logprobs from prefill
            output_strs = []
            top_input_logprobs = []
@@ -223,11 +214,17 @@ class SRTRunner:
                    ]
                ]
            )
-            # print(response["meta_info"]["output_top_logprobs"][0])
            return ModelOutput(
                output_strs=output_strs, top_input_logprobs=top_input_logprobs
            )
+        else:
+            logits = []
+            for prompt in prompts:
+                response = self.runtime.encode(prompt)
+                response = json.loads(response)
+                logits.append(response["embedding"])
+            return ModelOutput(embed_logits=logits)
     def __enter__(self):
         return self
...
@@ -12,6 +12,8 @@ from typing import Callable, List, Optional
 import numpy as np
 import requests
+import torch
+import torch.nn.functional as F
 from sglang.global_config import global_config
 from sglang.lang.backend.openai import OpenAI
@@ -492,3 +494,7 @@ def run_unittest_files(files: List[str], timeout_per_file: float):
     print(f"Fail. Time elapsed: {time.time() - tic:.2f}s")
     return 0 if success else -1
+def get_similarities(vec1, vec2):
+    return F.cosine_similarity(torch.tensor(vec1), torch.tensor(vec2), dim=0)
...
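As a quick sanity check of the helper (toy vectors, purely illustrative): `get_similarities` computes the cosine similarity of two full embedding vectors, so identical vectors give 1.0 and orthogonal vectors give 0.0. The import path matches the one used by the new test file below.

```python
from sglang.test.test_utils import get_similarities

print(float(get_similarities([1.0, 0.0], [1.0, 0.0])))  # 1.0: identical direction
print(float(get_similarities([1.0, 0.0], [0.0, 1.0])))  # 0.0: orthogonal vectors
```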
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest

import torch

from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
from sglang.test.test_utils import get_similarities

MODELS = [("intfloat/e5-mistral-7b-instruct", 1)]
TORCH_DTYPES = [torch.float16]


class TestEmbeddingModels(unittest.TestCase):
    def assert_close_prefill_logits(
        self,
        prompts,
        model_path,
        tp_size,
        torch_dtype,
    ) -> None:
        with HFRunner(
            model_path, torch_dtype=torch_dtype, is_generation_model=False
        ) as hf_runner:
            hf_outputs = hf_runner.forward(prompts)

        with SRTRunner(
            model_path,
            tp_size=tp_size,
            torch_dtype=torch_dtype,
            is_generation_model=False,
        ) as srt_runner:
            srt_outputs = srt_runner.forward(prompts)

        for i in range(len(prompts)):
            hf_logits = torch.Tensor(hf_outputs.embed_logits[i])
            srt_logits = torch.Tensor(srt_outputs.embed_logits[i])

            similarities = torch.tensor(get_similarities(hf_logits, srt_logits))

            tolerance = 1e-2
            assert torch.all(
                abs(similarities - 1) < tolerance
            ), "embeddings are not all close"

    def test_prefill_logits(self):
        for model, tp_size in MODELS:
            for torch_dtype in TORCH_DTYPES:
                self.assert_close_prefill_logits(
                    DEFAULT_PROMPTS, model, tp_size, torch_dtype
                )


if __name__ == "__main__":
    unittest.main(warnings="ignore")
@@ -3,7 +3,9 @@ Copyright 2023-2024 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
 http://www.apache.org/licenses/LICENSE-2.0
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -33,7 +35,7 @@ class TestCausalModels(unittest.TestCase):
         torch_dtype,
     ) -> None:
         with HFRunner(
-            model_path, torch_dtype=torch_dtype, is_embedding_model=False
+            model_path, torch_dtype=torch_dtype, is_generation_model=True
         ) as hf_runner:
             hf_outputs = hf_runner.forward(prompts)
@@ -41,7 +43,7 @@ class TestCausalModels(unittest.TestCase):
             model_path,
             tp_size=tp_size,
             torch_dtype=torch_dtype,
-            is_embedding_model=False,
+            is_generation_model=True,
         ) as srt_runner:
             srt_outputs = srt_runner.forward(prompts)
...
@@ -10,7 +10,8 @@ suites = {
         "test_vision_openai_server.py",
         "test_chunked_prefill.py",
         "test_torch_compile.py",
-        "models/test_causal_models.py",
+        "models/test_generation_models.py",
+        "models/test_embedding_models.py",
         "sampling/penaltylib",
     ],
     "sampling/penaltylib": glob.glob(
...