"tools/vscode:/vscode.git/clone" did not exist on "2b9838697f3dacad7cd8a6c4fcd7489d53711729"
Unverified commit e040a245, authored by Ying Sheng, committed by GitHub

Add e5-mistral embedding model - step 3/3 (#988)

parent 9f662501
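
Context for the diff below: this final step wires the embedding path through the tokenizer manager, model server, and HTTP layer, adding an /encode route next to /generate and a Runtime.encode() helper. A minimal sketch of how the new route could be exercised once a server is running an embedding model such as intfloat/e5-mistral-7b-instruct follows; the host/port and the /get_model_info path are assumptions for illustration, not part of this commit.

import requests

base_url = "http://127.0.0.1:30000"  # assumed default sglang server address

# get_model_info() now also reports whether the loaded model is a generation
# model; embedding models such as e5-mistral report False.
info = requests.get(f"{base_url}/get_model_info").json()
print(info["is_generation"])

# /encode accepts a single prompt and returns its embedding plus meta info
# (the batch path raises NotImplementedError in this commit).
out = requests.post(
    f"{base_url}/encode",
    json={"text": "The capital of France is"},
).json()
print(len(out["embedding"]), out["meta_info"]["prompt_tokens"])

This is the same round trip that the Runtime.encode() method added in this diff performs internally.
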
......@@ -35,6 +35,7 @@ jobs:
pip install -e "python[all]"
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall
pip install accelerate
pip install sentence_transformers
- name: Test Frontend Language
run: |
......
......@@ -25,7 +25,11 @@ import zmq
import zmq.asyncio
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
from sglang.srt.managers.io_struct import (
BatchEmbeddingOut,
BatchStrOut,
BatchTokenIDOut,
)
from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.utils import find_printable_text, get_exception_traceback, graceful_registry
......@@ -66,6 +70,18 @@ class DetokenizerManager:
async def handle_loop(self):
while True:
recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
if isinstance(recv_obj, BatchEmbeddingOut):
self.send_to_tokenizer.send_pyobj(
BatchEmbeddingOut(
rids=recv_obj.rids,
embeddings=recv_obj.embeddings,
meta_info=recv_obj.meta_info,
finished_reason=recv_obj.finished_reason,
)
)
continue
assert isinstance(recv_obj, BatchTokenIDOut)
bs = len(recv_obj.rids)
......
......@@ -143,6 +143,7 @@ class Req:
# Logprobs
self.return_logprob = False
self.embedding = None
self.logprob_start_len = 0
self.top_logprobs_num = 0
self.normalized_prompt_logprob = None
......
......@@ -21,7 +21,7 @@ import dataclasses
import logging
import multiprocessing as mp
import os
from typing import Dict, List, Tuple
from typing import Dict, List, Tuple, Union
import numpy as np
import transformers
......@@ -38,16 +38,19 @@ from sglang.srt.hf_transformers_utils import (
)
from sglang.srt.managers.io_struct import (
AbortReq,
BatchEmbeddingOut,
BatchStrOut,
BatchTokenIDOut,
EmbeddingReqInput,
FlushCacheReq,
GenerateReqInput,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
)
from sglang.srt.mm_utils import expand2square, process_anyres_image
from sglang.srt.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import is_multimodal_model, load_image
from sglang.srt.utils import is_generation_model, is_multimodal_model, load_image
from sglang.utils import get_exception_traceback
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
......@@ -85,6 +88,7 @@ class TokenizerManager:
trust_remote_code=server_args.trust_remote_code,
model_overide_args=model_overide_args,
)
self.is_generation = is_generation_model(self.hf_config.architectures)
if server_args.context_length is not None:
self.context_len = server_args.context_length
......@@ -133,7 +137,9 @@ class TokenizerManager:
image_data, aspect_ratio, grid_pinpoints, self.processor
)
async def generate_request(self, obj: GenerateReqInput, request=None):
async def generate_request(
self, obj: Union[GenerateReqInput, EmbeddingReqInput], request=None
):
if self.to_create_loop:
self.create_handle_loop()
......@@ -144,6 +150,8 @@ class TokenizerManager:
async for response in self._handle_single_request(obj, request):
yield response
else:
if isinstance(obj, EmbeddingReqInput):
raise NotImplementedError("Please send only one prompt in each request")
if obj.stream:
raise ValueError("Do not support stream for batch mode.")
......@@ -151,39 +159,47 @@ class TokenizerManager:
yield response
async def _handle_single_request(
self, obj, request, index=None, is_cache_for_prefill=False
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
request,
index=None,
is_cache_for_prefill=False,
):
if not is_cache_for_prefill: # The normal case with a single prompt
not_use_index = index is None
rid = obj.rid if not_use_index else obj.rid[index]
input_text = obj.text if not_use_index else obj.text[index]
input_ids = (
self.tokenizer.encode(input_text)
if obj.input_ids is None
else obj.input_ids
)
if not not_use_index and obj.input_ids:
input_ids = obj.input_ids[index]
if obj.input_ids is None:
input_ids = self.tokenizer.encode(input_text)
else:
input_ids = obj.input_ids if not_use_index else obj.input_ids[index]
self._validate_input_length(input_ids)
sampling_params = self._get_sampling_params(
obj.sampling_params if not_use_index else obj.sampling_params[index]
)
pixel_values, image_hash, image_size = await self._get_pixel_values(
obj.image_data if not_use_index else obj.image_data[index]
)
return_logprob = (
obj.return_logprob if not_use_index else obj.return_logprob[index]
)
logprob_start_len = (
obj.logprob_start_len if not_use_index else obj.logprob_start_len[index]
)
top_logprobs_num = (
obj.top_logprobs_num if not_use_index else obj.top_logprobs_num[index]
)
if self.is_generation:
pixel_values, image_hash, image_size = await self._get_pixel_values(
obj.image_data if not_use_index else obj.image_data[index]
)
return_logprob = (
obj.return_logprob if not_use_index else obj.return_logprob[index]
)
logprob_start_len = (
obj.logprob_start_len
if not_use_index
else obj.logprob_start_len[index]
)
top_logprobs_num = (
obj.top_logprobs_num
if not_use_index
else obj.top_logprobs_num[index]
)
else: # A prefill request to cache the common prompt for parallel sampling
assert self.is_generation
if obj.text is not None:
if isinstance(obj.text, list):
input_text = obj.text[index]
......@@ -213,19 +229,28 @@ class TokenizerManager:
logprob_start_len = obj.logprob_start_len[0]
top_logprobs_num = obj.top_logprobs_num[0]
tokenized_obj = TokenizedGenerateReqInput(
rid,
input_text,
input_ids,
pixel_values,
image_hash,
image_size,
sampling_params,
return_logprob,
logprob_start_len,
top_logprobs_num,
obj.stream,
)
if self.is_generation:
tokenized_obj = TokenizedGenerateReqInput(
rid,
input_text,
input_ids,
pixel_values,
image_hash,
image_size,
sampling_params,
return_logprob,
logprob_start_len,
top_logprobs_num,
obj.stream,
)
else: # is embedding
tokenized_obj = TokenizedEmbeddingReqInput(
rid,
input_text,
input_ids,
sampling_params,
)
self.send_to_router.send_pyobj(tokenized_obj)
event = asyncio.Event()
......@@ -368,7 +393,7 @@ class TokenizerManager:
self,
event: asyncio.Event,
state: ReqState,
obj: GenerateReqInput,
obj: Union[GenerateReqInput, EmbeddingReqInput],
rid: str,
request,
):
......@@ -381,12 +406,15 @@ class TokenizerManager:
raise ValueError(f"Abort request {rid}")
continue
out = self.convert_logprob_style(
state.out_list[-1],
obj.return_logprob,
obj.top_logprobs_num,
obj.return_text_in_logprobs,
)
if self.is_generation:
out = self.convert_logprob_style(
state.out_list[-1],
obj.return_logprob,
obj.top_logprobs_num,
obj.return_text_in_logprobs,
)
else: # isinstance(obj, EmbeddingReqInput)
out = state.out_list[-1]
# Log requests
if self.server_args.log_requests and state.finished:
......@@ -459,8 +487,10 @@ class TokenizerManager:
async def handle_loop(self):
while True:
recv_obj: BatchStrOut = await self.recv_from_detokenizer.recv_pyobj()
assert isinstance(recv_obj, BatchStrOut)
recv_obj: Union[BatchStrOut, BatchEmbeddingOut] = (
await self.recv_from_detokenizer.recv_pyobj()
)
assert isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut))
for i, rid in enumerate(recv_obj.rids):
state = self.rid_to_state.get(rid, None)
......@@ -468,10 +498,17 @@ class TokenizerManager:
continue
recv_obj.meta_info[i]["id"] = rid
out_dict = {
"text": recv_obj.output_strs[i],
"meta_info": recv_obj.meta_info[i],
}
if isinstance(recv_obj, BatchStrOut):
out_dict = {
"text": recv_obj.output_strs[i],
"meta_info": recv_obj.meta_info[i],
}
else:
assert isinstance(recv_obj, BatchEmbeddingOut)
out_dict = {
"embedding": recv_obj.embeddings[i],
"meta_info": recv_obj.meta_info[i],
}
state.out_list.append(out_dict)
state.finished = recv_obj.finished_reason[i] is not None
state.event.set()
......
......@@ -20,7 +20,7 @@ import multiprocessing
import pickle
import time
import warnings
from typing import List, Optional
from typing import List, Optional, Union
import torch
import torch.distributed as dist
......@@ -31,8 +31,10 @@ from sglang.srt.constrained.jump_forward import JumpForwardCache
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.io_struct import (
AbortReq,
BatchEmbeddingOut,
BatchTokenIDOut,
FlushCacheReq,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
)
from sglang.srt.managers.policy_scheduler import PolicyScheduler, PrefillAdder
......@@ -205,7 +207,9 @@ class ModelTpServer:
try:
# Recv requests
for recv_req in recv_reqs:
if isinstance(recv_req, TokenizedGenerateReqInput):
if isinstance(
recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
):
self.handle_generate_request(recv_req)
elif isinstance(recv_req, FlushCacheReq):
self.flush_cache()
......@@ -297,41 +301,42 @@ class ModelTpServer:
def handle_generate_request(
self,
recv_req: TokenizedGenerateReqInput,
recv_req: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
):
req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
req.pixel_values = recv_req.pixel_values
if req.pixel_values is not None:
req.pad_value = [
(recv_req.image_hash) % self.model_config.vocab_size,
(recv_req.image_hash >> 16) % self.model_config.vocab_size,
(recv_req.image_hash >> 32) % self.model_config.vocab_size,
(recv_req.image_hash >> 64) % self.model_config.vocab_size,
]
req.image_size = recv_req.image_size
(
req.origin_input_ids,
req.image_offset,
) = self.model_runner.model.pad_input_ids(
req.origin_input_ids_unpadded,
req.pad_value,
req.pixel_values.shape,
req.image_size,
)
req.sampling_params = recv_req.sampling_params
req.return_logprob = recv_req.return_logprob
req.logprob_start_len = recv_req.logprob_start_len
req.top_logprobs_num = recv_req.top_logprobs_num
req.stream = recv_req.stream
req.tokenizer = self.tokenizer
# Init regex fsm
if req.sampling_params.regex is not None:
req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
if not self.disable_regex_jump_forward:
req.jump_forward_map = self.jump_forward_cache.query(
req.sampling_params.regex
req.sampling_params = recv_req.sampling_params
if self.model_runner.is_generation:
req.pixel_values = recv_req.pixel_values
if req.pixel_values is not None:
req.pad_value = [
(recv_req.image_hash) % self.model_config.vocab_size,
(recv_req.image_hash >> 16) % self.model_config.vocab_size,
(recv_req.image_hash >> 32) % self.model_config.vocab_size,
(recv_req.image_hash >> 64) % self.model_config.vocab_size,
]
req.image_size = recv_req.image_size
(
req.origin_input_ids,
req.image_offset,
) = self.model_runner.model.pad_input_ids(
req.origin_input_ids_unpadded,
req.pad_value,
req.pixel_values.shape,
req.image_size,
)
req.return_logprob = recv_req.return_logprob
req.logprob_start_len = recv_req.logprob_start_len
req.top_logprobs_num = recv_req.top_logprobs_num
req.stream = recv_req.stream
# Init regex fsm
if req.sampling_params.regex is not None:
req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
if not self.disable_regex_jump_forward:
req.jump_forward_map = self.jump_forward_cache.query(
req.sampling_params.regex
)
# Truncate prompts that are too long
if len(req.origin_input_ids) >= self.max_req_input_len:
......@@ -340,14 +345,17 @@ class ModelTpServer:
"the max context length. Truncated!!!"
)
req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
req.sampling_params.max_new_tokens = min(
(
req.sampling_params.max_new_tokens
if req.sampling_params.max_new_tokens is not None
else 1 << 30
),
self.max_req_input_len - 1 - len(req.origin_input_ids),
)
if self.model_runner.is_generation:
req.sampling_params.max_new_tokens = min(
(
req.sampling_params.max_new_tokens
if req.sampling_params.max_new_tokens is not None
else 1 << 30
),
self.max_req_input_len - 1 - len(req.origin_input_ids),
)
self.waiting_queue.append(req)
def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
......@@ -439,47 +447,68 @@ class ModelTpServer:
self.model_config.vocab_size, self.int_token_logit_bias
)
# Forward and sample the next tokens
if batch.extend_num_tokens != 0:
output = self.model_runner.forward(batch, ForwardMode.EXTEND)
next_token_ids = batch.sample(output.next_token_logits)
# Move logprobs to cpu
if output.next_token_logprobs is not None:
output.next_token_logprobs = output.next_token_logprobs[
torch.arange(len(next_token_ids), device=next_token_ids.device),
next_token_ids,
].tolist()
output.input_token_logprobs = output.input_token_logprobs.tolist()
output.normalized_prompt_logprobs = (
output.normalized_prompt_logprobs.tolist()
)
if self.model_runner.is_generation:
# Forward and sample the next tokens
if batch.extend_num_tokens != 0:
output = self.model_runner.forward(batch, ForwardMode.EXTEND)
next_token_ids = batch.sample(output.next_token_logits)
# Move logprobs to cpu
if output.next_token_logprobs is not None:
output.next_token_logprobs = output.next_token_logprobs[
torch.arange(len(next_token_ids), device=next_token_ids.device),
next_token_ids,
].tolist()
output.input_token_logprobs = output.input_token_logprobs.tolist()
output.normalized_prompt_logprobs = (
output.normalized_prompt_logprobs.tolist()
)
next_token_ids = next_token_ids.tolist()
else:
next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
next_token_ids = next_token_ids.tolist()
else:
next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
# Check finish conditions
pt = 0
for i, req in enumerate(batch.reqs):
if req is not self.current_inflight_req:
# Inflight reqs' prefill is not finished
req.completion_tokens_wo_jump_forward += 1
req.output_ids.append(next_token_ids[i])
req.check_finished()
if req.finished():
self.tree_cache.cache_finished_req(req)
else:
self.tree_cache.cache_unfinished_req(req)
# Check finish conditions
pt = 0
for i, req in enumerate(batch.reqs):
if req is not self.current_inflight_req:
# Inflight reqs' prefill is not finished
req.completion_tokens_wo_jump_forward += 1
req.output_ids.append(next_token_ids[i])
req.check_finished()
if req is self.current_inflight_req:
# Inflight request would get a new req idx
self.req_to_token_pool.free(req.req_pool_idx)
if req.finished():
self.tree_cache.cache_finished_req(req)
else:
self.tree_cache.cache_unfinished_req(req)
if req.return_logprob:
self.add_logprob_return_values(i, req, pt, next_token_ids, output)
pt += req.extend_input_len
else:
assert batch.extend_num_tokens != 0
output = self.model_runner.forward(batch, ForwardMode.EXTEND)
embeddings = output.embeddings.tolist()
if req is self.current_inflight_req:
# Inflight request would get a new req idx
self.req_to_token_pool.free(req.req_pool_idx)
# Check finish conditions
for i, req in enumerate(batch.reqs):
req.embedding = embeddings[i]
if req is not self.current_inflight_req:
# Inflight reqs' prefill is not finished
req.check_finished()
if req.return_logprob:
self.add_logprob_return_values(i, req, pt, next_token_ids, output)
pt += req.extend_input_len
if req.finished():
self.tree_cache.cache_finished_req(req)
else:
self.tree_cache.cache_unfinished_req(req)
if req is self.current_inflight_req:
# Inflight request would get a new req idx
self.req_to_token_pool.free(req.req_pool_idx)
self.handle_finished_requests(batch)
......@@ -596,15 +625,19 @@ class ModelTpServer:
def handle_finished_requests(self, batch: ScheduleBatch):
output_rids = []
output_vids = []
decoded_texts = []
output_read_ids = []
output_read_offsets = []
output_skip_special_tokens = []
output_spaces_between_special_tokens = []
output_meta_info = []
output_finished_reason: List[BaseFinishReason] = []
if self.model_runner.is_generation:
output_vids = []
decoded_texts = []
output_read_ids = []
output_read_offsets = []
output_skip_special_tokens = []
output_spaces_between_special_tokens = []
else: # for embedding model
output_embeddings = []
unfinished_indices = []
for i, req in enumerate(batch.reqs):
if not req.finished() and req is not self.current_inflight_req:
unfinished_indices.append(i)
......@@ -619,56 +652,73 @@ class ModelTpServer:
)
):
output_rids.append(req.rid)
output_vids.append(req.vid)
decoded_texts.append(req.decoded_text)
read_ids, read_offset = req.init_incremental_detokenize()
output_read_ids.append(read_ids)
output_read_offsets.append(read_offset)
output_skip_special_tokens.append(
req.sampling_params.skip_special_tokens
)
output_spaces_between_special_tokens.append(
req.sampling_params.spaces_between_special_tokens
)
meta_info = {
"prompt_tokens": len(req.origin_input_ids),
"completion_tokens": len(req.output_ids),
"completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
"finish_reason": str(req.finished_reason),
}
if req.return_logprob:
(
meta_info["input_token_logprobs"],
meta_info["output_token_logprobs"],
meta_info["input_top_logprobs"],
meta_info["output_top_logprobs"],
meta_info["normalized_prompt_logprob"],
) = (
req.input_token_logprobs,
req.output_token_logprobs,
req.input_top_logprobs,
req.output_top_logprobs,
req.normalized_prompt_logprob,
)
output_meta_info.append(meta_info)
output_finished_reason.append(req.finished_reason)
if self.model_runner.is_generation:
output_vids.append(req.vid)
decoded_texts.append(req.decoded_text)
read_ids, read_offset = req.init_incremental_detokenize()
output_read_ids.append(read_ids)
output_read_offsets.append(read_offset)
output_skip_special_tokens.append(
req.sampling_params.skip_special_tokens
)
output_spaces_between_special_tokens.append(
req.sampling_params.spaces_between_special_tokens
)
meta_info = {
"prompt_tokens": len(req.origin_input_ids),
"completion_tokens": len(req.output_ids),
"completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
"finish_reason": str(req.finished_reason),
}
if req.return_logprob:
(
meta_info["input_token_logprobs"],
meta_info["output_token_logprobs"],
meta_info["input_top_logprobs"],
meta_info["output_top_logprobs"],
meta_info["normalized_prompt_logprob"],
) = (
req.input_token_logprobs,
req.output_token_logprobs,
req.input_top_logprobs,
req.output_top_logprobs,
req.normalized_prompt_logprob,
)
output_meta_info.append(meta_info)
else: # for embedding model
output_embeddings.append(req.embedding)
meta_info = {
"prompt_tokens": len(req.origin_input_ids),
}
output_meta_info.append(meta_info)
# Send to detokenizer
if output_rids:
self.out_pyobjs.append(
BatchTokenIDOut(
output_rids,
output_vids,
decoded_texts,
output_read_ids,
output_read_offsets,
output_skip_special_tokens,
output_spaces_between_special_tokens,
output_meta_info,
output_finished_reason,
if self.model_runner.is_generation:
self.out_pyobjs.append(
BatchTokenIDOut(
output_rids,
output_vids,
decoded_texts,
output_read_ids,
output_read_offsets,
output_skip_special_tokens,
output_spaces_between_special_tokens,
output_meta_info,
output_finished_reason,
)
)
else: # for embedding model
self.out_pyobjs.append(
BatchEmbeddingOut(
output_rids,
output_embeddings,
output_meta_info,
output_finished_reason,
)
)
)
# Remove finished reqs: update batch tensors
batch.filter_batch(unfinished_indices)
......
......@@ -52,6 +52,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetad
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
get_available_gpu_memory,
is_generation_model,
is_llama3_405b_fp8,
is_multimodal_model,
monkey_patch_vllm_dummy_weight_loader,
......@@ -132,8 +133,10 @@ class ModelRunner:
self.init_cublas()
self.init_flashinfer()
# Capture cuda graphs
self.init_cuda_graphs()
if self.is_generation:
# FIXME Currently, cuda graph only capture decode steps, which only exists in causal models
# Capture cuda graphs
self.init_cuda_graphs()
def load_model(self):
logger.info(
......@@ -184,6 +187,10 @@ class ModelRunner:
scheduler_config=None,
cache_config=None,
)
self.is_generation = is_generation_model(
self.model_config.hf_config.architectures
)
logger.info(
f"[gpu={self.gpu_id}] Load weight end. "
f"type={type(self.model).__name__}, "
......@@ -406,8 +413,10 @@ def import_model_classes():
entry, list
): # To support multiple model classes in one module
for tmp in entry:
assert tmp.__name__ not in model_arch_name_to_cls
model_arch_name_to_cls[tmp.__name__] = tmp
else:
assert entry.__name__ not in model_arch_name_to_cls
model_arch_name_to_cls[entry.__name__] = entry
# compat: some models such as chatglm has incorrect class set in config.json
......@@ -417,6 +426,7 @@ def import_model_classes():
):
for remap in module.EntryClassRemapping:
if isinstance(remap, tuple) and len(remap) == 2:
assert remap[0] not in model_arch_name_to_cls
model_arch_name_to_cls[remap[0]] = remap[1]
return model_arch_name_to_cls
......
......@@ -84,3 +84,5 @@ class LlamaEmbeddingModel(nn.Module):
EntryClass = LlamaEmbeddingModel
# compat: e5-mistral model.config class == MistralModel
EntryClassRemapping = [("MistralModel", LlamaEmbeddingModel)]
......@@ -52,7 +52,7 @@ from sglang.srt.managers.controller_single import (
start_controller_process as start_controller_process_single,
)
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.openai_api.adapter import (
load_chat_template_for_openai_api,
......@@ -97,6 +97,7 @@ async def health() -> Response:
async def get_model_info():
result = {
"model_path": tokenizer_manager.model_path,
"is_generation": tokenizer_manager.is_generation,
}
return result
......@@ -148,6 +149,21 @@ app.post("/generate")(generate_request)
app.put("/generate")(generate_request)
async def encode_request(obj: EmbeddingReqInput, request: Request):
"""Handle an embedding request."""
try:
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
return ret
except ValueError as e:
return JSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
)
app.post("/encode")(encode_request)
app.put("/encode")(encode_request)
@app.post("/v1/completions")
async def openai_v1_completions(raw_request: Request):
return await v1_completions(tokenizer_manager, raw_request)
......@@ -380,6 +396,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
except (AssertionError, requests.exceptions.RequestException) as e:
last_traceback = get_exception_traceback()
pass
model_info = res.json()
if not success:
if pipe_finish_writer is not None:
......@@ -388,15 +405,17 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
sys.exit(1)
# Send a warmup request
request_name = "/generate" if model_info["is_generation"] else "/encode"
max_new_tokens = 8 if model_info["is_generation"] else 0
try:
for _ in range(server_args.dp_size):
res = requests.post(
url + "/generate",
url + request_name,
json={
"text": "The capital city of France is",
"sampling_params": {
"temperature": 0,
"max_new_tokens": 8,
"max_new_tokens": max_new_tokens,
},
},
headers=headers,
......@@ -529,5 +548,18 @@ class Runtime:
)
return json.dumps(response.json())
def encode(
self,
prompt: str,
):
json_data = {
"text": prompt,
}
response = requests.post(
self.url + "/encode",
json=json_data,
)
return json.dumps(response.json())
def __del__(self):
self.shutdown()
......@@ -223,6 +223,15 @@ def is_multimodal_model(model):
raise ValueError("unrecognized type")
def is_generation_model(model_architectures):
if (
"LlamaEmbeddingModel" in model_architectures
or "MistralModel" in model_architectures
):
return False
return True
def decode_video_base64(video_base64):
from PIL import Image
......
......@@ -23,6 +23,7 @@ import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from sglang.srt.server import Runtime
from sglang.srt.utils import is_generation_model
DEFAULT_PROMPTS = [
"The capital of France is",
......@@ -33,13 +34,6 @@ DEFAULT_PROMPTS = [
NUM_TOP_LOGPROBS = 5
def is_embedding_model(model_path):
# FIXME incomplete list
if "e5-mistral-7b-instruct" in model_path.lower():
return True
return False
def get_dtype_str(torch_dtype):
if torch_dtype is torch.float16:
return "float16"
......@@ -60,7 +54,7 @@ class HFRunner:
self,
model_path,
torch_dtype=torch.float16,
is_embedding_model=None,
is_generation_model=None,
):
self.in_queue = multiprocessing.Queue()
self.out_queue = multiprocessing.Queue()
......@@ -72,13 +66,13 @@ class HFRunner:
self.out_queue,
model_path,
torch_dtype,
is_embedding_model,
is_generation_model,
),
)
self.model_proc.start()
def start_model_process(
self, in_queue, out_queue, model_path, torch_dtype, is_embedding_model
self, in_queue, out_queue, model_path, torch_dtype, is_generation_model
):
self.tokenizer = AutoTokenizer.from_pretrained(
model_path,
......@@ -86,12 +80,12 @@ class HFRunner:
trust_remote_code=True,
)
self.is_embedding_model = (
is_embedding_model(model_path)
if is_embedding_model is None
else is_embedding_model
self.is_generation_model = (
is_generation_model(model_path)
if is_generation_model is None
else is_generation_model
)
if not self.is_embedding_model:
if self.is_generation_model:
self.model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch_dtype,
......@@ -103,13 +97,13 @@ class HFRunner:
self.model = SentenceTransformer(
model_path,
device="cpu",
).to(dtype=torch_dtype)
model_kwargs={"torch_dtype": torch_dtype},
)
while True:
prompts, max_new_tokens = in_queue.get()
if prompts is not None:
if not self.is_embedding_model:
if self.is_generation_model:
output_strs = []
prefill_logprobs = []
for p in prompts:
......@@ -144,7 +138,6 @@ class HFRunner:
)
else:
assert isinstance(prompts, List[str])
logits = self.model.encode(prompts).tolist()
out_queue.put(ModelOutput(embed_logits=logits))
......@@ -175,16 +168,13 @@ class SRTRunner:
model_path,
tp_size=1,
torch_dtype=torch.float16,
is_embedding_model=None,
is_generation_model=None,
):
self.is_embedding_model = (
is_embedding_model(model_path)
if is_embedding_model is None
else is_embedding_model
self.is_generation_model = (
is_generation_model(model_path)
if is_generation_model is None
else is_generation_model
)
if self.is_embedding_model:
raise NotImplementedError()
self.runtime = Runtime(
model_path=model_path,
tp_size=tp_size,
......@@ -196,38 +186,45 @@ class SRTRunner:
prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
max_new_tokens=64,
):
# the return value contains logprobs from prefill
output_strs = []
top_input_logprobs = []
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
for prompt in prompts:
response = self.runtime.generate(
prompt,
sampling_params=sampling_params,
return_logprob=True,
top_logprobs_num=NUM_TOP_LOGPROBS,
)
response = json.loads(response)
output_strs.append(response["text"])
top_input_logprobs.append(
[
[tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
for x in response["meta_info"]["input_top_logprobs"][1:]
]
+ [
if self.is_generation_model:
# the return value contains logprobs from prefill
output_strs = []
top_input_logprobs = []
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
for prompt in prompts:
response = self.runtime.generate(
prompt,
sampling_params=sampling_params,
return_logprob=True,
top_logprobs_num=NUM_TOP_LOGPROBS,
)
response = json.loads(response)
output_strs.append(response["text"])
top_input_logprobs.append(
[
tup[0]
for tup in response["meta_info"]["output_top_logprobs"][0][
:NUM_TOP_LOGPROBS
[tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
for x in response["meta_info"]["input_top_logprobs"][1:]
]
+ [
[
tup[0]
for tup in response["meta_info"]["output_top_logprobs"][0][
:NUM_TOP_LOGPROBS
]
]
]
]
)
# print(response["meta_info"]["output_top_logprobs"][0])
)
return ModelOutput(
output_strs=output_strs, top_input_logprobs=top_input_logprobs
)
return ModelOutput(
output_strs=output_strs, top_input_logprobs=top_input_logprobs
)
else:
logits = []
for prompt in prompts:
response = self.runtime.encode(prompt)
response = json.loads(response)
logits.append(response["embedding"])
return ModelOutput(embed_logits=logits)
def __enter__(self):
return self
......
......@@ -12,6 +12,8 @@ from typing import Callable, List, Optional
import numpy as np
import requests
import torch
import torch.nn.functional as F
from sglang.global_config import global_config
from sglang.lang.backend.openai import OpenAI
......@@ -492,3 +494,7 @@ def run_unittest_files(files: List[str], timeout_per_file: float):
print(f"Fail. Time elapsed: {time.time() - tic:.2f}s")
return 0 if success else -1
def get_similarities(vec1, vec2):
return F.cosine_similarity(torch.tensor(vec1), torch.tensor(vec2), dim=0)
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import torch
from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
from sglang.test.test_utils import get_similarities
MODELS = [("intfloat/e5-mistral-7b-instruct", 1)]
TORCH_DTYPES = [torch.float16]
class TestEmbeddingModels(unittest.TestCase):
def assert_close_prefill_logits(
self,
prompts,
model_path,
tp_size,
torch_dtype,
) -> None:
with HFRunner(
model_path, torch_dtype=torch_dtype, is_generation_model=False
) as hf_runner:
hf_outputs = hf_runner.forward(prompts)
with SRTRunner(
model_path,
tp_size=tp_size,
torch_dtype=torch_dtype,
is_generation_model=False,
) as srt_runner:
srt_outputs = srt_runner.forward(prompts)
for i in range(len(prompts)):
hf_logits = torch.Tensor(hf_outputs.embed_logits[i])
srt_logits = torch.Tensor(srt_outputs.embed_logits[i])
similarities = torch.tensor(get_similarities(hf_logits, srt_logits))
tolerance = 1e-2
assert torch.all(
abs(similarities - 1) < tolerance
), f"embeddings not all close"
def test_prefill_logits(self):
for model, tp_size in MODELS:
for torch_dtype in TORCH_DTYPES:
self.assert_close_prefill_logits(
DEFAULT_PROMPTS, model, tp_size, torch_dtype
)
if __name__ == "__main__":
unittest.main(warnings="ignore")
......@@ -3,7 +3,9 @@ Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......@@ -33,7 +35,7 @@ class TestCausalModels(unittest.TestCase):
torch_dtype,
) -> None:
with HFRunner(
model_path, torch_dtype=torch_dtype, is_embedding_model=False
model_path, torch_dtype=torch_dtype, is_generation_model=True
) as hf_runner:
hf_outputs = hf_runner.forward(prompts)
......@@ -41,7 +43,7 @@ class TestCausalModels(unittest.TestCase):
model_path,
tp_size=tp_size,
torch_dtype=torch_dtype,
is_embedding_model=False,
is_generation_model=True,
) as srt_runner:
srt_outputs = srt_runner.forward(prompts)
......
......@@ -10,7 +10,8 @@ suites = {
"test_vision_openai_server.py",
"test_chunked_prefill.py",
"test_torch_compile.py",
"models/test_causal_models.py",
"models/test_generation_models.py",
"models/test_embedding_models.py",
"sampling/penaltylib",
],
"sampling/penaltylib": glob.glob(
......