"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "74d902eb59f873b6156621220937f8e2521dfdc0"
Unverified commit 9aa6553d, authored by Ying Sheng and committed by GitHub

[Feature] Support reward model LxzGordon/URM-LLaMa-3.1-8B (#1525)

parent b1e330bc
@@ -29,6 +29,7 @@ jobs:
        run: |
          pip install --upgrade pip
          pip install -e "python[dev]"
+         pip install transformers==4.44
          pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall
      - name: Run test
@@ -48,6 +49,7 @@ jobs:
        run: |
          pip install --upgrade pip
          pip install -e "python[dev]"
+         pip install transformers==4.44
          pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall
      - name: Run test
@@ -67,6 +69,7 @@ jobs:
        run: |
          pip install --upgrade pip
          pip install -e "python[dev]"
+         pip install transformers==4.44
          pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall
      - name: Run test
@@ -86,6 +89,7 @@ jobs:
        run: |
          pip install --upgrade pip
          pip install -e "python[dev]"
+         pip install transformers==4.44
          pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall
      - name: Run test
@@ -105,6 +109,7 @@ jobs:
        run: |
          pip install --upgrade pip
          pip install -e "python[all]"
+         pip install transformers==4.44
          pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall
      - name: Benchmark Single Latency
@@ -136,6 +141,7 @@ jobs:
        run: |
          pip install --upgrade pip
          pip install -e "python[all]"
+         pip install transformers==4.44
          pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall
      - name: Benchmark Offline Throughput (w/o RadixAttention)
@@ -167,6 +173,7 @@ jobs:
        run: |
          pip install --upgrade pip
          pip install -e "python[all]"
+         pip install transformers==4.44
          pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall
      - name: Benchmark Offline Throughput (TP=2)
@@ -198,6 +205,7 @@ jobs:
        run: |
          pip install --upgrade pip
          pip install -e "python[all]"
+         pip install transformers==4.44
          pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall
          git clone https://github.com/merrymercy/human-eval.git
@@ -221,6 +229,7 @@ jobs:
        run: |
          pip install --upgrade pip
          pip install -e "python[all]"
+         pip install transformers==4.44
          pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall
          git clone https://github.com/merrymercy/human-eval.git
# launch server
# python -m sglang.launch_server --model LxzGordon/URM-LLaMa-3.1-8B --is-embedding
import json
import requests
url = "http://127.0.0.1:30000"
PROMPT = (
"What is the range of the numeric output of a sigmoid node in a neural network?"
)
RESPONSE1 = "The output of a sigmoid node is bounded between -1 and 1."
RESPONSE2 = "The output of a sigmoid node is bounded between 0 and 1."
json_data = {
"conv": [
[
{"role": "user", "content": PROMPT},
{"role": "assistant", "content": RESPONSE1},
],
[
{"role": "user", "content": PROMPT},
{"role": "assistant", "content": RESPONSE2},
],
],
}
response = requests.post(
url + "/judge",
json=json_data,
).json()
print(response)
print("scores:", [x["embedding"] for x in response])
@@ -215,12 +215,11 @@ class EmbeddingReqInput:
            raise ValueError("Either text or input_ids should be provided.")

        if self.text is not None:
-           is_single = isinstance(self.text, str)
+           self.is_single = isinstance(self.text, str)
        else:
-           is_single = isinstance(self.input_ids[0], int)
-       self.is_single = is_single
+           self.is_single = isinstance(self.input_ids[0], int)

-       if is_single:
+       if self.is_single:
            if self.rid is None:
                self.rid = uuid.uuid4().hex
            if self.sampling_params is None:

@@ -254,6 +253,52 @@ class TokenizedEmbeddingReqInput:
    sampling_params: SamplingParams
@dataclass
class RewardReqInput:
# The input prompt in the chat format. It can be a single prompt or a batch of prompts.
conv: Union[List[List[Dict]], List[Dict]]
# The request id.
rid: Optional[Union[List[str], str]] = None
# Dummy sampling params for compatibility
sampling_params: Union[List[Dict], Dict] = None
is_single: bool = True
def post_init(self):
self.is_single = isinstance(self.conv[0], dict)
if self.is_single:
if self.rid is None:
self.rid = uuid.uuid4().hex
if self.sampling_params is None:
self.sampling_params = {}
self.sampling_params["max_new_tokens"] = 1
else:
# support select operation
self.batch_size = len(self.conv)
if self.rid is None:
self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
else:
if not isinstance(self.rid, list):
raise ValueError("The rid should be a list.")
if self.sampling_params is None:
self.sampling_params = [{}] * self.batch_size
for i in range(self.batch_size):
self.sampling_params[i]["max_new_tokens"] = 1
@dataclass
class TokenizedRewardReqInput:
# The request id
rid: str
# The input text
input_text: str
# The input token ids
input_ids: List[int]
# Dummy sampling params for compatibility
sampling_params: SamplingParams
@dataclass
class BatchTokenIDOut:
    # The request id
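As a quick illustration of the new RewardReqInput above, here is a standalone sketch that only exercises post_init (no server involved); a batch of two conversations is normalized like this:

from sglang.srt.managers.io_struct import RewardReqInput

req = RewardReqInput(
    conv=[
        [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hello"}],
        [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hey"}],
    ]
)
req.post_init()
print(req.is_single)        # False: conv[0] is a list, so this is a batch of conversations
print(len(req.rid))         # 2 auto-generated request ids
print(req.sampling_params)  # [{'max_new_tokens': 1}, {'max_new_tokens': 1}]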
@@ -46,8 +46,10 @@ from sglang.srt.managers.io_struct import (
    EmbeddingReqInput,
    FlushCacheReq,
    GenerateReqInput,
+   RewardReqInput,
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
+   TokenizedRewardReqInput,
    UpdateWeightReqInput,
    UpdateWeightReqOutput,
)
@@ -142,7 +144,7 @@ class TokenizerManager:
    async def generate_request(
        self,
-       obj: Union[GenerateReqInput, EmbeddingReqInput],
+       obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
        request: Optional[fastapi.Request] = None,
    ):
        if self.to_create_loop:
@@ -163,7 +165,7 @@ class TokenizerManager:
    async def _handle_single_request(
        self,
-       obj: Union[GenerateReqInput, EmbeddingReqInput],
+       obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
        request: Optional[fastapi.Request] = None,
        index: Optional[int] = None,
        is_cache_for_prefill: Optional[bool] = False,
@@ -173,7 +175,13 @@ class TokenizerManager:
        rid = obj.rid if not_use_index else obj.rid[index]
        input_text = obj.text if not_use_index else obj.text[index]
-       if obj.input_ids is None:
+       if hasattr(obj, "conv"):
+           # reward model
+           assert self.tokenizer is not None
+           conv = obj.conv if not_use_index else obj.conv[index]
+           input_text = self.tokenizer.apply_chat_template(conv, tokenize=False)
+           input_ids = self.tokenizer.encode(input_text)
+       elif obj.input_ids is None:
            assert self.tokenizer is not None
            input_ids = self.tokenizer.encode(input_text)
        else:
@@ -269,13 +277,21 @@ class TokenizerManager:
                    else obj.lora_path
                ),
            )
-       else:  # is embedding
+       elif isinstance(obj, EmbeddingReqInput):
            tokenized_obj = TokenizedEmbeddingReqInput(
                rid,
                input_text,
                input_ids,
                sampling_params,
            )
+       else:
+           assert isinstance(obj, RewardReqInput)
+           tokenized_obj = TokenizedRewardReqInput(
+               rid,
+               input_text,
+               input_ids,
+               sampling_params,
+           )
        self.send_to_controller.send_pyobj(tokenized_obj)

        # Recv results
@@ -292,7 +308,7 @@ class TokenizerManager:
    async def _handle_batch_request(
        self,
-       obj: Union[GenerateReqInput, EmbeddingReqInput],
+       obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
        request: Optional[fastapi.Request] = None,
    ):
        batch_size = obj.batch_size
@@ -329,9 +345,16 @@ class TokenizerManager:
                rid = obj.rid[index]
                if parallel_sample_num == 1:
                    ## select operation
-                   if obj.input_ids is None:
+                   if hasattr(obj, "conv"):
+                       # reward model
+                       conv = obj.conv[i]
+                       input_text = self.tokenizer.apply_chat_template(
+                           conv, tokenize=False
+                       )
+                       input_ids = self.tokenizer.encode(input_text)
+                   elif obj.input_ids is None:
                        input_text = obj.text[i]
-                       input_ids = self.tokenizer.encode(obj.text[i])
+                       input_ids = self.tokenizer.encode(input_text)
                    else:
                        input_text = None
                        input_ids = obj.input_ids[i]
@@ -370,13 +393,21 @@ class TokenizerManager:
                            else obj.lora_path
                        ),
                    )
-               else:
+               elif isinstance(obj, EmbeddingReqInput):
                    tokenized_obj = TokenizedEmbeddingReqInput(
                        rid,
                        input_text,
                        input_ids,
                        sampling_params,
                    )
+               else:
+                   assert isinstance(obj, RewardReqInput)
+                   tokenized_obj = TokenizedRewardReqInput(
+                       rid,
+                       input_text,
+                       input_ids,
+                       sampling_params,
+                   )
                self.send_to_controller.send_pyobj(tokenized_obj)

                event = asyncio.Event()
@@ -442,7 +473,7 @@ class TokenizerManager:
    async def _wait_for_response(
        self,
        state: ReqState,
-       obj: Union[GenerateReqInput, EmbeddingReqInput],
+       obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
        rid: str,
        request: Optional[fastapi.Request] = None,
        index: Optional[int] = None,

@@ -469,7 +500,7 @@ class TokenizerManager:
                ),
                obj.return_text_in_logprobs,
            )
-       else:  # isinstance(obj, EmbeddingReqInput)
+       else:  # isinstance(obj, (EmbeddingReqInput, RewardReqInput))
            out = state.out_list[-1]
            out["index"] = response_index
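The tokenizer-manager changes above turn a conversation into token ids in two steps: render the model's chat template as plain text, then encode that text. The same conversion can be reproduced standalone; a small sketch, assuming the tokenizer of the model from this PR is available locally or via the Hub:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("LxzGordon/URM-LLaMa-3.1-8B")
conv = [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "4"},
]
# Step 1: render the chat template to text only (no tokenization yet).
input_text = tokenizer.apply_chat_template(conv, tokenize=False)
# Step 2: encode the rendered text, exactly as _handle_single_request does.
input_ids = tokenizer.encode(input_text)
print(input_text)
print(len(input_ids), "tokens")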
@@ -22,7 +22,7 @@ import os
import pickle
import time
import warnings
-from typing import Any, List, Optional
+from typing import Any, List, Optional, Union

import torch
import torch.distributed

@@ -41,6 +41,7 @@ from sglang.srt.managers.io_struct import (
    FlushCacheReq,
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
+   TokenizedRewardReqInput,
    UpdateWeightReqInput,
    UpdateWeightReqOutput,
)

@@ -223,7 +224,9 @@ class ModelTpServer:
            if isinstance(recv_req, TokenizedGenerateReqInput):
                self.handle_generate_request(recv_req)
                self.do_not_get_new_batch = False
-           elif isinstance(recv_req, TokenizedEmbeddingReqInput):
+           elif isinstance(
+               recv_req, (TokenizedEmbeddingReqInput, TokenizedRewardReqInput)
+           ):
                self.handle_embedding_request(recv_req)
                self.do_not_get_new_batch = False
            elif isinstance(recv_req, FlushCacheReq):

@@ -407,7 +410,7 @@ class ModelTpServer:
    def handle_embedding_request(
        self,
-       recv_req: TokenizedEmbeddingReqInput,
+       recv_req: Union[TokenizedEmbeddingReqInput, TokenizedRewardReqInput],
    ):
        req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
        req.tokenizer = self.tokenizer
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Iterable, Optional, Tuple
import torch
from torch import nn
from transformers import LlamaConfig
from vllm.config import CacheConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
class LlamaForSequenceClassification(nn.Module):
def __init__(
self,
config: LlamaConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config: Optional[CacheConfig] = None,
) -> None:
super().__init__()
self.config = config
self.torchao_config = None
self.quant_config = quant_config
self.num_labels = config.num_labels
self.model = LlamaModel(config, quant_config=quant_config)
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False)
self.eos_token_id = config.eos_token_id
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
input_embeds: torch.Tensor = None,
) -> EmbeddingPoolerOutput:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
scores = self.score(hidden_states)
return self.pooler(scores, input_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "classification_head" in name:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
elif "lm_head" in name:
continue
else:
LlamaForCausalLM.load_weights(self, [(name, loaded_weight)])
class LlamaForSequenceClassificationWithNormal_Weights(LlamaForSequenceClassification):
class Weights(torch.nn.Module):
def __init__(self, hidden_size, num_label):
super().__init__()
self.fc = torch.nn.Sequential(
torch.nn.Linear(hidden_size, hidden_size, dtype=torch.float16),
torch.nn.SELU(),
torch.nn.Linear(hidden_size, hidden_size, dtype=torch.float16),
torch.nn.SELU(),
torch.nn.Linear(hidden_size, num_label // 2, dtype=torch.float16),
)
def forward(self, x):
return self.fc(x.to(torch.float16))
def __init__(
self,
config: LlamaConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config: Optional[CacheConfig] = None,
) -> None:
super().__init__(config, quant_config, cache_config)
self.weights = self.Weights(config.hidden_size, self.num_labels)
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
input_embeds: torch.Tensor = None,
get_embedding: bool = True,
) -> EmbeddingPoolerOutput:
assert (
get_embedding
), "LlamaForSequenceClassification is only used for embedding"
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
logits = self.score(hidden_states)
weights = self.weights(hidden_states)
pooled_logits = self.pooler(logits, input_metadata).embeddings
pooled_weights = self.pooler(weights, input_metadata).embeddings
rews = pooled_logits.view(-1, self.num_labels // 2, 2)[:, :, 0].view(
-1, self.num_labels // 2
)
scores = (rews * pooled_weights).sum(dim=-1).view(-1, 1)
return EmbeddingPoolerOutput(scores)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "classification_head" in name:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
elif "lm_head" in name:
continue
else:
LlamaForCausalLM.load_weights(self, [(name, loaded_weight)])
EntryClass = [
LlamaForSequenceClassification,
LlamaForSequenceClassificationWithNormal_Weights,
]
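To make the score computation in LlamaForSequenceClassificationWithNormal_Weights.forward concrete, here is a toy-shaped sketch with made-up tensors. num_labels is a tiny placeholder (the real model uses config.num_labels), and pooled_weights stands in for the output of the Weights MLP rather than random noise:

import torch

num_labels = 6                                    # must be even: the head emits num_labels // 2 pairs
pooled_logits = torch.randn(1, num_labels)        # stands in for pooler(score(hidden_states))
pooled_weights = torch.randn(1, num_labels // 2)  # stands in for pooler(weights(hidden_states))

# Keep only the first element of each pair, mirroring
# pooled_logits.view(-1, num_labels // 2, 2)[:, :, 0] in the forward pass above.
rews = pooled_logits.view(-1, num_labels // 2, 2)[:, :, 0].view(-1, num_labels // 2)

# Final scalar reward: weighted sum of the per-attribute values.
scores = (rews * pooled_weights).sum(dim=-1).view(-1, 1)
print(scores.shape)  # torch.Size([1, 1])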
@@ -54,6 +54,7 @@ from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
from sglang.srt.managers.io_struct import (
    EmbeddingReqInput,
    GenerateReqInput,
+   RewardReqInput,
    UpdateWeightReqInput,
)
from sglang.srt.managers.tokenizer_manager import TokenizerManager

@@ -213,6 +214,21 @@ app.post("/encode")(encode_request)
app.put("/encode")(encode_request)
async def judge_request(obj: RewardReqInput, request: Request):
"""Handle an embedding request."""
try:
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
return ret
except ValueError as e:
return JSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
)
app.post("/judge")(judge_request)
app.put("/judge")(judge_request)
@app.post("/v1/completions") @app.post("/v1/completions")
async def openai_v1_completions(raw_request: Request): async def openai_v1_completions(raw_request: Request):
return await v1_completions(tokenizer_manager, raw_request) return await v1_completions(tokenizer_manager, raw_request)
@@ -635,15 +651,26 @@ class Runtime:
    def encode(
        self,
-       prompt: Union[str, List[str]],
+       prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
    ):
-       json_data = {
-           "text": prompt,
-       }
-       response = requests.post(
-           self.url + "/encode",
-           json=json_data,
-       )
+       if isinstance(prompt, str) or isinstance(prompt[0], str):
+           # embedding
+           json_data = {
+               "text": prompt,
+           }
+           response = requests.post(
+               self.url + "/encode",
+               json=json_data,
+           )
+       else:
+           # reward
+           json_data = {
+               "conv": prompt,
+           }
+           response = requests.post(
+               self.url + "/judge",
+               json=json_data,
+           )
        return json.dumps(response.json())

    def __del__(self):
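With the branch added to Runtime.encode above, plain strings still go to /encode while chat-style conversations are routed to /judge. A hedged offline sketch of the reward path (it launches its own server, so it needs the model weights and a GPU; the constructor arguments mirror how SRTRunner below builds its Runtime):

import json
import sglang as sgl

runtime = sgl.Runtime(
    model_path="LxzGordon/URM-LLaMa-3.1-8B",
    is_embedding=True,  # reward models are served through the embedding path
)

conv = [
    {"role": "user", "content": "What is the range of a sigmoid output?"},
    {"role": "assistant", "content": "Between 0 and 1."},
]

# A list of conversations selects the /judge branch; a list of strings would hit /encode.
out = json.loads(runtime.encode([conv]))
print("score:", out[0]["embedding"][0])

runtime.shutdown()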
@@ -219,6 +219,8 @@ def is_generation_model(model_architectures, is_embedding: bool = False):
    if (
        "LlamaEmbeddingModel" in model_architectures
        or "MistralModel" in model_architectures
+       or "LlamaForSequenceClassification" in model_architectures
+       or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
    ):
        return False
    else:
@@ -65,6 +65,7 @@ class ModelOutput:
    top_input_logprobs: List[torch.Tensor] = None
    top_output_logprobs: List[torch.Tensor] = None
    embed_logits: List[torch.Tensor] = None
+   scores: List[float] = None


class HFRunner:
@@ -72,10 +73,10 @@ class HFRunner:
        self,
        model_path,
        torch_dtype,
-       is_generation,
+       model_type="generation",
        output_str_only=False,
    ):
-       self.is_generation = is_generation
+       self.model_type = model_type
        self.output_str_only = output_str_only

        self.in_queue = mp.Queue()
@@ -92,22 +93,41 @@
            )
        )
        self.model_proc.start()

+   def needs_trust_remote_code(self, model_path):
+       models_needs_trust_remote = [
+           "LxzGordon/URM-LLaMa-3.1-8B",
+       ]
+       if model_path in models_needs_trust_remote:
+           return True
+       return False
+
    def start_model_process(self, in_queue, out_queue, model_path, torch_dtype):
-       self.tokenizer = get_tokenizer(model_path)
+       self.tokenizer = get_tokenizer(model_path, torch_dtype=torch_dtype)

-       if self.is_generation:
+       if self.model_type == "generation":
            self.base_model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch_dtype,
                trust_remote_code=False,
                low_cpu_mem_usage=True,
            ).cuda()
-       else:
+       elif self.model_type == "embedding":
            from sentence_transformers import SentenceTransformer

            self.model = SentenceTransformer(
                model_path,
                model_kwargs={"torch_dtype": torch_dtype},
-           )
+           ).cuda()
+       elif self.model_type == "reward":
+           from transformers import AutoModelForSequenceClassification
+
+           self.model = AutoModelForSequenceClassification.from_pretrained(
+               model_path,
+               torch_dtype=torch_dtype,
+               trust_remote_code=self.needs_trust_remote_code(model_path),
+           ).cuda()
+       else:
+           raise Exception(f"Unrecognized model type {self.model_type}")

        while True:
            prompts, max_new_tokens, lora_paths = in_queue.get()
@@ -115,7 +135,7 @@ class HFRunner:
                assert len(prompts) == len(lora_paths)

            if prompts is not None:
-               if self.is_generation:
+               if self.model_type == "generation":
                    output_strs = []
                    top_input_logprobs = []
                    top_output_logprobs = []
@@ -179,11 +199,27 @@ class HFRunner:
                            )
                        )

-               else:
+               elif self.model_type == "embedding":
                    assert not self.output_str_only
                    logits = self.model.encode(prompts).tolist()
                    out_queue.put(ModelOutput(embed_logits=logits))

+               elif self.model_type == "reward":
+                   scores = []
+                   for conv in prompts:
+                       conv_formatted = self.tokenizer.apply_chat_template(
+                           conv, tokenize=False
+                       )
+                       conv_tokenized = self.tokenizer(
+                           conv_formatted, return_tensors="pt"
+                       ).to("cuda")
+                       scores.append(
+                           float(self.model(**conv_tokenized).logits[0][0].item())
+                       )
+                   out_queue.put(ModelOutput(scores=scores))
+
+               else:
+                   raise Exception(f"Unrecognized model type {self.model_type}")

    def forward(
        self,
        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
@@ -210,7 +246,7 @@ class SRTRunner:
        self,
        model_path,
        torch_dtype,
-       is_generation,
+       model_type,
        tp_size=1,
        port=DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
        lora_paths=None,

@@ -218,13 +254,14 @@ class SRTRunner:
        disable_cuda_graph=False,
        disable_radix_cache=False,
    ):
-       self.is_generation = is_generation
+       self.model_type = model_type
+       self.is_generation = model_type == "generation"
        self.runtime = Runtime(
            model_path=model_path,
            tp_size=tp_size,
            dtype=get_dtype_str(torch_dtype),
            port=port,
-           mem_fraction_static=0.69,
+           mem_fraction_static=0.65,
            trust_remote_code=False,
            is_embedding=not self.is_generation,
            lora_paths=lora_paths,
@@ -285,8 +322,12 @@ class SRTRunner:
        else:
            response = self.runtime.encode(prompts)
            response = json.loads(response)
-           logits = [x["embedding"] for x in response]
-           return ModelOutput(embed_logits=logits)
+           if self.model_type == "embedding":
+               logits = [x["embedding"] for x in response]
+               return ModelOutput(embed_logits=logits)
+           else:
+               scores = [x["embedding"][0] for x in response]
+               return ModelOutput(scores=scores)

    def batch_forward(
        self,
@@ -316,8 +357,12 @@ class SRTRunner:
        else:
            response = self.runtime.encode(prompts)
            response = json.loads(response)
-           logits = [x["embedding"] for x in response]
-           return ModelOutput(embed_logits=logits)
+           if self.model_type == "embedding":
+               logits = [x["embedding"] for x in response]
+               return ModelOutput(embed_logits=logits)
+           else:
+               scores = [x["embedding"][0] for x in response]
+               return ModelOutput(scores=scores)

    def __enter__(self):
        return self
@@ -39,7 +39,9 @@ class TestEmbeddingModels(unittest.TestCase):
        prefill_tolerance,
    ) -> None:
        with HFRunner(
-           model_path, torch_dtype=torch_dtype, is_generation=False
+           model_path,
+           torch_dtype=torch_dtype,
+           model_type="embedding",
        ) as hf_runner:
            hf_outputs = hf_runner.forward(prompts)

@@ -47,7 +49,7 @@ class TestEmbeddingModels(unittest.TestCase):
            model_path,
            tp_size=tp_size,
            torch_dtype=torch_dtype,
-           is_generation=False,
+           model_type="embedding",
        ) as srt_runner:
            srt_outputs = srt_runner.forward(prompts)
@@ -73,7 +73,9 @@ class TestGenerationModels(unittest.TestCase):
        max_new_tokens = 32

        with HFRunner(
-           model_path, torch_dtype=torch_dtype, is_generation=True
+           model_path,
+           torch_dtype=torch_dtype,
+           model_type="generation",
        ) as hf_runner:
            hf_outputs = hf_runner.forward(prompts, max_new_tokens=max_new_tokens)

@@ -81,7 +83,7 @@ class TestGenerationModels(unittest.TestCase):
            model_path,
            tp_size=model_case.tp_size,
            torch_dtype=torch_dtype,
-           is_generation=True,
+           model_type="generation",
        ) as srt_runner:
            srt_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens)
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import multiprocessing as mp
import unittest
import torch
from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
MODELS = [
("LxzGordon/URM-LLaMa-3.1-8B", 1, 2e-2),
]
TORCH_DTYPES = [torch.float16]
# PROMPT = "Jane has 12 apples. She gives 4 apples to her friend Mark, then buys 1 more apple, and finally splits all her apples equally among herself and her 2 siblings. How many apples does each person get?"
# RESPONSE1 = "1. Jane starts with 12 apples and gives 4 to Mark. 12 - 4 = 8. Jane now has 8 apples.\n2. Jane buys 1 more apple. 8 + 1 = 9. Jane now has 9 apples.\n3. Jane splits the 9 apples equally among herself and her 2 siblings (3 people in total). 9 ÷ 3 = 3 apples each. Each person gets 3 apples."
# RESPONSE2 = "1. Jane starts with 12 apples and gives 4 to Mark. 12 - 4 = 8. Jane now has 8 apples.\n2. Jane buys 1 more apple. 8 + 1 = 9. Jane now has 9 apples.\n3. Jane splits the 9 apples equally among her 2 siblings (2 people in total). 9 ÷ 2 = 4.5 apples each. Each person gets 4 apples."
PROMPT = (
"What is the range of the numeric output of a sigmoid node in a neural network?"
)
RESPONSE1 = "The output of a sigmoid node is bounded between -1 and 1."
RESPONSE2 = "The output of a sigmoid node is bounded between 0 and 1."
CONVS = [
[{"role": "user", "content": PROMPT}, {"role": "assistant", "content": RESPONSE1}],
[{"role": "user", "content": PROMPT}, {"role": "assistant", "content": RESPONSE2}],
]
class TestRewardModels(unittest.TestCase):
def assert_close_reward_scores(
self,
convs,
model_path,
tp_size,
torch_dtype,
tolerance,
) -> None:
with HFRunner(
model_path,
torch_dtype=torch_dtype,
model_type="reward",
) as hf_runner:
hf_outputs = hf_runner.forward(convs)
with SRTRunner(
model_path,
torch_dtype=torch_dtype,
model_type="reward",
) as srt_runner:
srt_outputs = srt_runner.forward(convs)
hf_scores = torch.tensor(hf_outputs.scores)
srt_scores = torch.tensor(srt_outputs.scores)
print(hf_scores)
print(srt_scores)
assert torch.all(
abs(hf_scores - srt_scores) < tolerance
), "reward scores are not all close"
def test_reward_scores(self):
for model, tp_size, tolerance in MODELS:
for torch_dtype in TORCH_DTYPES:
self.assert_close_reward_scores(
CONVS, model, tp_size, torch_dtype, tolerance
)
if __name__ == "__main__":
try:
mp.set_start_method("spawn")
except RuntimeError:
pass
unittest.main()
@@ -7,6 +7,7 @@ suites = {
    "minimal": [
        "models/test_embedding_models.py",
        "models/test_generation_models.py",
+       "models/test_reward_models.py",
        "sampling/penaltylib",
        "test_chunked_prefill.py",
        "test_embedding_openai_server.py",