"vscode:/vscode.git/clone" did not exist on "743516f37bf3e12b8ed79553e3ccbe4e60b8b374"
Unverified commit 9aa6553d authored by Ying Sheng, committed by GitHub

[Feature] Support reward model LxzGordon/URM-LLaMa-3.1-8B (#1525)

parent b1e330bc
......@@ -29,6 +29,7 @@ jobs:
run: |
pip install --upgrade pip
pip install -e "python[dev]"
pip install transformers==4.44
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall
- name: Run test
......@@ -48,6 +49,7 @@ jobs:
run: |
pip install --upgrade pip
pip install -e "python[dev]"
pip install transformers==4.44
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall
- name: Run test
......@@ -67,6 +69,7 @@ jobs:
run: |
pip install --upgrade pip
pip install -e "python[dev]"
pip install transformers==4.44
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall
- name: Run test
......@@ -86,6 +89,7 @@ jobs:
run: |
pip install --upgrade pip
pip install -e "python[dev]"
pip install transformers==4.44
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall
- name: Run test
......@@ -105,6 +109,7 @@ jobs:
run: |
pip install --upgrade pip
pip install -e "python[all]"
pip install transformers==4.44
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall
- name: Benchmark Single Latency
......@@ -136,6 +141,7 @@ jobs:
run: |
pip install --upgrade pip
pip install -e "python[all]"
pip install transformers==4.44
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall
- name: Benchmark Offline Throughput (w/o RadixAttention)
......@@ -167,6 +173,7 @@ jobs:
run: |
pip install --upgrade pip
pip install -e "python[all]"
pip install transformers==4.44
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall
- name: Benchmark Offline Throughput (TP=2)
......@@ -198,6 +205,7 @@ jobs:
run: |
pip install --upgrade pip
pip install -e "python[all]"
pip install transformers==4.44
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall
git clone https://github.com/merrymercy/human-eval.git
......@@ -221,6 +229,7 @@ jobs:
run: |
pip install --upgrade pip
pip install -e "python[all]"
pip install transformers==4.44
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall
git clone https://github.com/merrymercy/human-eval.git
......
# launch server
# python -m sglang.launch_server --model-path LxzGordon/URM-LLaMa-3.1-8B --is-embedding
import json
import requests
url = "http://127.0.0.1:30000"
PROMPT = (
"What is the range of the numeric output of a sigmoid node in a neural network?"
)
RESPONSE1 = "The output of a sigmoid node is bounded between -1 and 1."
RESPONSE2 = "The output of a sigmoid node is bounded between 0 and 1."
json_data = {
"conv": [
[
{"role": "user", "content": PROMPT},
{"role": "assistant", "content": RESPONSE1},
],
[
{"role": "user", "content": PROMPT},
{"role": "assistant", "content": RESPONSE2},
],
],
}
response = requests.post(
url + "/judge",
json=json_data,
).json()
print(response)
print("scores:", [x["embedding"] for x in response])
......@@ -215,12 +215,11 @@ class EmbeddingReqInput:
raise ValueError("Either text or input_ids should be provided.")
if self.text is not None:
is_single = isinstance(self.text, str)
self.is_single = isinstance(self.text, str)
else:
is_single = isinstance(self.input_ids[0], int)
self.is_single = is_single
self.is_single = isinstance(self.input_ids[0], int)
if is_single:
if self.is_single:
if self.rid is None:
self.rid = uuid.uuid4().hex
if self.sampling_params is None:
......@@ -254,6 +253,52 @@ class TokenizedEmbeddingReqInput:
sampling_params: SamplingParams
@dataclass
class RewardReqInput:
# The input prompt in the chat format. It can be a single prompt or a batch of prompts.
conv: Union[List[List[Dict]], List[Dict]]
# The request id.
rid: Optional[Union[List[str], str]] = None
# Dummy sampling params for compatibility
sampling_params: Union[List[Dict], Dict] = None
is_single: bool = True
def post_init(self):
self.is_single = isinstance(self.conv[0], dict)
if self.is_single:
if self.rid is None:
self.rid = uuid.uuid4().hex
if self.sampling_params is None:
self.sampling_params = {}
self.sampling_params["max_new_tokens"] = 1
else:
# support select operation
self.batch_size = len(self.conv)
if self.rid is None:
self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
else:
if not isinstance(self.rid, list):
raise ValueError("The rid should be a list.")
if self.sampling_params is None:
self.sampling_params = [{} for _ in range(self.batch_size)]
for i in range(self.batch_size):
self.sampling_params[i]["max_new_tokens"] = 1
@dataclass
class TokenizedRewardReqInput:
# The request id
rid: str
# The input text
input_text: str
# The input token ids
input_ids: List[int]
# Dummy sampling params for compatibility
sampling_params: SamplingParams
@dataclass
class BatchTokenIDOut:
# The request id
......
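
To see what `post_init` does with the two accepted shapes of `conv`, here is a small sketch (conversation content is made up; the assertions restate the logic of the dataclass above):

```python
from sglang.srt.managers.io_struct import RewardReqInput

conv = [
    {"role": "user", "content": "Is the sky blue?"},
    {"role": "assistant", "content": "Yes, on a clear day."},
]

# Single conversation: is_single stays True, a request id is generated,
# and max_new_tokens is pinned to 1 (scoring needs no decoding).
single = RewardReqInput(conv=conv)
single.post_init()
assert single.is_single
assert single.sampling_params["max_new_tokens"] == 1

# Batch of conversations: one request id and one sampling_params dict
# per conversation.
batch = RewardReqInput(conv=[conv, conv])
batch.post_init()
assert not batch.is_single
assert len(batch.rid) == batch.batch_size == 2
```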
......@@ -46,8 +46,10 @@ from sglang.srt.managers.io_struct import (
EmbeddingReqInput,
FlushCacheReq,
GenerateReqInput,
RewardReqInput,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
TokenizedRewardReqInput,
UpdateWeightReqInput,
UpdateWeightReqOutput,
)
......@@ -142,7 +144,7 @@ class TokenizerManager:
async def generate_request(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
request: Optional[fastapi.Request] = None,
):
if self.to_create_loop:
......@@ -163,7 +165,7 @@ class TokenizerManager:
async def _handle_single_request(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
request: Optional[fastapi.Request] = None,
index: Optional[int] = None,
is_cache_for_prefill: Optional[bool] = False,
......@@ -173,7 +175,13 @@ class TokenizerManager:
rid = obj.rid if not_use_index else obj.rid[index]
input_text = obj.text if not_use_index else obj.text[index]
if obj.input_ids is None:
if hasattr(obj, "conv"):
# reward model
assert self.tokenizer is not None
conv = obj.conv if not_use_index else obj.conv[index]
input_text = self.tokenizer.apply_chat_template(conv, tokenize=False)
input_ids = self.tokenizer.encode(input_text)
elif obj.input_ids is None:
assert self.tokenizer is not None
input_ids = self.tokenizer.encode(input_text)
else:
......@@ -269,13 +277,21 @@ class TokenizerManager:
else obj.lora_path
),
)
else: # is embedding
elif isinstance(obj, EmbeddingReqInput):
tokenized_obj = TokenizedEmbeddingReqInput(
rid,
input_text,
input_ids,
sampling_params,
)
else:
assert isinstance(obj, RewardReqInput)
tokenized_obj = TokenizedRewardReqInput(
rid,
input_text,
input_ids,
sampling_params,
)
self.send_to_controller.send_pyobj(tokenized_obj)
# Recv results
......@@ -292,7 +308,7 @@ class TokenizerManager:
async def _handle_batch_request(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
request: Optional[fastapi.Request] = None,
):
batch_size = obj.batch_size
......@@ -329,9 +345,16 @@ class TokenizerManager:
rid = obj.rid[index]
if parallel_sample_num == 1:
## select operation
if obj.input_ids is None:
if hasattr(obj, "conv"):
# reward model
conv = obj.conv[i]
input_text = self.tokenizer.apply_chat_template(
conv, tokenize=False
)
input_ids = self.tokenizer.encode(input_text)
elif obj.input_ids is None:
input_text = obj.text[i]
input_ids = self.tokenizer.encode(obj.text[i])
input_ids = self.tokenizer.encode(input_text)
else:
input_text = None
input_ids = obj.input_ids[i]
......@@ -370,13 +393,21 @@ class TokenizerManager:
else obj.lora_path
),
)
else:
elif isinstance(obj, EmbeddingReqInput):
tokenized_obj = TokenizedEmbeddingReqInput(
rid,
input_text,
input_ids,
sampling_params,
)
else:
assert isinstance(obj, RewardReqInput)
tokenized_obj = TokenizedRewardReqInput(
rid,
input_text,
input_ids,
sampling_params,
)
self.send_to_controller.send_pyobj(tokenized_obj)
event = asyncio.Event()
......@@ -442,7 +473,7 @@ class TokenizerManager:
async def _wait_for_response(
self,
state: ReqState,
obj: Union[GenerateReqInput, EmbeddingReqInput],
obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
rid: str,
request: Optional[fastapi.Request] = None,
index: Optional[int] = None,
......@@ -469,7 +500,7 @@ class TokenizerManager:
),
obj.return_text_in_logprobs,
)
else: # isinstance(obj, EmbeddingReqInput)
else: # isinstance(obj, (EmbeddingReqInput, RewardReqInput))
out = state.out_list[-1]
out["index"] = response_index
......
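
The only tokenization difference for reward requests is that the conversation is first rendered with the tokenizer's chat template before being encoded, which is exactly the pair of calls added above. A standalone sketch of that step (any chat-capable tokenizer works; the model name is the one targeted by this PR):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("LxzGordon/URM-LLaMa-3.1-8B")

conv = [
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "The capital of France is Paris."},
]

# Same two calls the TokenizerManager makes for a RewardReqInput:
input_text = tokenizer.apply_chat_template(conv, tokenize=False)
input_ids = tokenizer.encode(input_text)
print(len(input_ids))
```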
......@@ -22,7 +22,7 @@ import os
import pickle
import time
import warnings
from typing import Any, List, Optional
from typing import Any, List, Optional, Union
import torch
import torch.distributed
......@@ -41,6 +41,7 @@ from sglang.srt.managers.io_struct import (
FlushCacheReq,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
TokenizedRewardReqInput,
UpdateWeightReqInput,
UpdateWeightReqOutput,
)
......@@ -223,7 +224,9 @@ class ModelTpServer:
if isinstance(recv_req, TokenizedGenerateReqInput):
self.handle_generate_request(recv_req)
self.do_not_get_new_batch = False
elif isinstance(recv_req, TokenizedEmbeddingReqInput):
elif isinstance(
recv_req, (TokenizedEmbeddingReqInput, TokenizedRewardReqInput)
):
self.handle_embedding_request(recv_req)
self.do_not_get_new_batch = False
elif isinstance(recv_req, FlushCacheReq):
......@@ -407,7 +410,7 @@ class ModelTpServer:
def handle_embedding_request(
self,
recv_req: TokenizedEmbeddingReqInput,
recv_req: Union[TokenizedEmbeddingReqInput, TokenizedRewardReqInput],
):
req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
req.tokenizer = self.tokenizer
......
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Iterable, Optional, Tuple
import torch
from torch import nn
from transformers import LlamaConfig
from vllm.config import CacheConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
class LlamaForSequenceClassification(nn.Module):
def __init__(
self,
config: LlamaConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config: Optional[CacheConfig] = None,
) -> None:
super().__init__()
self.config = config
self.torchao_config = None
self.quant_config = quant_config
self.num_labels = config.num_labels
self.model = LlamaModel(config, quant_config=quant_config)
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False)
self.eos_token_id = config.eos_token_id
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
input_embeds: torch.Tensor = None,
) -> EmbeddingPoolerOutput:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
scores = self.score(hidden_states)
return self.pooler(scores, input_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "classification_head" in name:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
elif "lm_head" in name:
continue
else:
LlamaForCausalLM.load_weights(self, [(name, loaded_weight)])
class LlamaForSequenceClassificationWithNormal_Weights(LlamaForSequenceClassification):
class Weights(torch.nn.Module):
def __init__(self, hidden_size, num_label):
super().__init__()
self.fc = torch.nn.Sequential(
torch.nn.Linear(hidden_size, hidden_size, dtype=torch.float16),
torch.nn.SELU(),
torch.nn.Linear(hidden_size, hidden_size, dtype=torch.float16),
torch.nn.SELU(),
torch.nn.Linear(hidden_size, num_label // 2, dtype=torch.float16),
)
def forward(self, x):
return self.fc(x.to(torch.float16))
def __init__(
self,
config: LlamaConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config: Optional[CacheConfig] = None,
) -> None:
super().__init__(config, quant_config, cache_config)
self.weights = self.Weights(config.hidden_size, self.num_labels)
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
input_embeds: torch.Tensor = None,
get_embedding: bool = True,
) -> EmbeddingPoolerOutput:
assert (
get_embedding
), "LlamaForSequenceClassification is only used for embedding"
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
logits = self.score(hidden_states)
weights = self.weights(hidden_states)
pooled_logits = self.pooler(logits, input_metadata).embeddings
pooled_weights = self.pooler(weights, input_metadata).embeddings
rews = pooled_logits.view(-1, self.num_labels // 2, 2)[:, :, 0].view(
-1, self.num_labels // 2
)
scores = (rews * pooled_weights).sum(dim=-1).view(-1, 1)
return EmbeddingPoolerOutput(scores)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "classification_head" in name:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
elif "lm_head" in name:
continue
else:
LlamaForCausalLM.load_weights(self, [(name, loaded_weight)])
EntryClass = [
LlamaForSequenceClassification,
LlamaForSequenceClassificationWithNormal_Weights,
]
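
The score in `LlamaForSequenceClassificationWithNormal_Weights.forward` is not a raw classification logit: the pooled score-head output is viewed as `num_labels // 2` pairs, only the first element of each pair is kept as a per-attribute reward, and those rewards are combined with the gating vector produced by the small `Weights` MLP. A toy recreation of just that arithmetic, with made-up sizes and random tensors:

```python
import torch

num_labels, batch_size = 10, 2
pooled_logits = torch.randn(batch_size, num_labels)        # pooled self.score output
pooled_weights = torch.randn(batch_size, num_labels // 2)  # pooled self.weights output

# Keep the first element of each pair of score-head outputs ...
rews = pooled_logits.view(-1, num_labels // 2, 2)[:, :, 0].view(-1, num_labels // 2)
# ... and take the weight-gated sum to get one scalar score per sequence.
scores = (rews * pooled_weights).sum(dim=-1).view(-1, 1)
print(scores.shape)  # torch.Size([2, 1])
```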
......@@ -54,6 +54,7 @@ from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
from sglang.srt.managers.io_struct import (
EmbeddingReqInput,
GenerateReqInput,
RewardReqInput,
UpdateWeightReqInput,
)
from sglang.srt.managers.tokenizer_manager import TokenizerManager
......@@ -213,6 +214,21 @@ app.post("/encode")(encode_request)
app.put("/encode")(encode_request)
async def judge_request(obj: RewardReqInput, request: Request):
"""Handle an embedding request."""
try:
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
return ret
except ValueError as e:
return JSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
)
app.post("/judge")(judge_request)
app.put("/judge")(judge_request)
@app.post("/v1/completions")
async def openai_v1_completions(raw_request: Request):
return await v1_completions(tokenizer_manager, raw_request)
......@@ -635,15 +651,26 @@ class Runtime:
def encode(
self,
prompt: Union[str, List[str]],
prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
):
json_data = {
"text": prompt,
}
response = requests.post(
self.url + "/encode",
json=json_data,
)
if isinstance(prompt, str) or isinstance(prompt[0], str):
# embedding
json_data = {
"text": prompt,
}
response = requests.post(
self.url + "/encode",
json=json_data,
)
else:
# reward
json_data = {
"conv": prompt,
}
response = requests.post(
self.url + "/judge",
json=json_data,
)
return json.dumps(response.json())
def __del__(self):
......
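
With this dispatch, the same `Runtime.encode` entry point serves both embedding and reward models: plain strings go to `/encode`, lists of message dicts go to `/judge`. A usage sketch (illustrative only; an 8B reward model needs a GPU to actually load):

```python
import json
from sglang.srt.server import Runtime

runtime = Runtime(
    model_path="LxzGordon/URM-LLaMa-3.1-8B",
    is_embedding=True,
)

conv = [
    {"role": "user", "content": "Name a prime number."},
    {"role": "assistant", "content": "7 is a prime number."},
]

# A list of conversations is routed to /judge; each item in the result
# carries its score under the "embedding" key.
out = json.loads(runtime.encode([conv]))
print([x["embedding"][0] for x in out])

runtime.shutdown()
```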
......@@ -219,6 +219,8 @@ def is_generation_model(model_architectures, is_embedding: bool = False):
if (
"LlamaEmbeddingModel" in model_architectures
or "MistralModel" in model_architectures
or "LlamaForSequenceClassification" in model_architectures
or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
):
return False
else:
......
......@@ -65,6 +65,7 @@ class ModelOutput:
top_input_logprobs: List[torch.Tensor] = None
top_output_logprobs: List[torch.Tensor] = None
embed_logits: List[torch.Tensor] = None
scores: List[float] = None
class HFRunner:
......@@ -72,10 +73,10 @@ class HFRunner:
self,
model_path,
torch_dtype,
is_generation,
model_type="generation",
output_str_only=False,
):
self.is_generation = is_generation
self.model_type = model_type
self.output_str_only = output_str_only
self.in_queue = mp.Queue()
......@@ -92,22 +93,41 @@ class HFRunner:
)
self.model_proc.start()
def needs_trust_remote_code(self, model_path):
models_needs_trust_remote = [
"LxzGordon/URM-LLaMa-3.1-8B",
]
if model_path in models_needs_trust_remote:
return True
return False
def start_model_process(self, in_queue, out_queue, model_path, torch_dtype):
self.tokenizer = get_tokenizer(model_path)
if self.is_generation:
self.tokenizer = get_tokenizer(model_path, torch_dtype=torch_dtype)
if self.model_type == "generation":
self.base_model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch_dtype,
trust_remote_code=False,
low_cpu_mem_usage=True,
).cuda()
else:
elif self.model_type == "embedding":
from sentence_transformers import SentenceTransformer
self.model = SentenceTransformer(
model_path,
model_kwargs={"torch_dtype": torch_dtype},
)
).cuda()
elif self.model_type == "reward":
from transformers import AutoModelForSequenceClassification
self.model = AutoModelForSequenceClassification.from_pretrained(
model_path,
torch_dtype=torch_dtype,
trust_remote_code=self.needs_trust_remote_code(model_path),
).cuda()
else:
raise Exception(f"Unrecognized model type {self.model_type}")
while True:
prompts, max_new_tokens, lora_paths = in_queue.get()
......@@ -115,7 +135,7 @@ class HFRunner:
assert len(prompts) == len(lora_paths)
if prompts is not None:
if self.is_generation:
if self.model_type == "generation":
output_strs = []
top_input_logprobs = []
top_output_logprobs = []
......@@ -179,11 +199,27 @@ class HFRunner:
)
)
else:
elif self.model_type == "embedding":
assert not self.output_str_only
logits = self.model.encode(prompts).tolist()
out_queue.put(ModelOutput(embed_logits=logits))
elif self.model_type == "reward":
scores = []
for conv in prompts:
conv_formatted = self.tokenizer.apply_chat_template(
conv, tokenize=False
)
conv_tokenized = self.tokenizer(
conv_formatted, return_tensors="pt"
).to("cuda")
scores.append(
float(self.model(**conv_tokenized).logits[0][0].item())
)
out_queue.put(ModelOutput(scores=scores))
else:
raise Exception(f"Unrecognized model type {self.model_type}")
def forward(
self,
prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
......@@ -210,7 +246,7 @@ class SRTRunner:
self,
model_path,
torch_dtype,
is_generation,
model_type,
tp_size=1,
port=DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
lora_paths=None,
......@@ -218,13 +254,14 @@ class SRTRunner:
disable_cuda_graph=False,
disable_radix_cache=False,
):
self.is_generation = is_generation
self.model_type = model_type
self.is_generation = model_type == "generation"
self.runtime = Runtime(
model_path=model_path,
tp_size=tp_size,
dtype=get_dtype_str(torch_dtype),
port=port,
mem_fraction_static=0.69,
mem_fraction_static=0.65,
trust_remote_code=False,
is_embedding=not self.is_generation,
lora_paths=lora_paths,
......@@ -285,8 +322,12 @@ class SRTRunner:
else:
response = self.runtime.encode(prompts)
response = json.loads(response)
logits = [x["embedding"] for x in response]
return ModelOutput(embed_logits=logits)
if self.model_type == "embedding":
logits = [x["embedding"] for x in response]
return ModelOutput(embed_logits=logits)
else:
scores = [x["embedding"][0] for x in response]
return ModelOutput(scores=scores)
def batch_forward(
self,
......@@ -316,8 +357,12 @@ class SRTRunner:
else:
response = self.runtime.encode(prompts)
response = json.loads(response)
logits = [x["embedding"] for x in response]
return ModelOutput(embed_logits=logits)
if self.model_type == "embedding":
logits = [x["embedding"] for x in response]
return ModelOutput(embed_logits=logits)
else:
scores = [x["embedding"][0] for x in response]
return ModelOutput(scores=scores)
def __enter__(self):
return self
......
......@@ -39,7 +39,9 @@ class TestEmbeddingModels(unittest.TestCase):
prefill_tolerance,
) -> None:
with HFRunner(
model_path, torch_dtype=torch_dtype, is_generation=False
model_path,
torch_dtype=torch_dtype,
model_type="embedding",
) as hf_runner:
hf_outputs = hf_runner.forward(prompts)
......@@ -47,7 +49,7 @@ class TestEmbeddingModels(unittest.TestCase):
model_path,
tp_size=tp_size,
torch_dtype=torch_dtype,
is_generation=False,
model_type="embedding",
) as srt_runner:
srt_outputs = srt_runner.forward(prompts)
......
......@@ -73,7 +73,9 @@ class TestGenerationModels(unittest.TestCase):
max_new_tokens = 32
with HFRunner(
model_path, torch_dtype=torch_dtype, is_generation=True
model_path,
torch_dtype=torch_dtype,
model_type="generation",
) as hf_runner:
hf_outputs = hf_runner.forward(prompts, max_new_tokens=max_new_tokens)
......@@ -81,7 +83,7 @@ class TestGenerationModels(unittest.TestCase):
model_path,
tp_size=model_case.tp_size,
torch_dtype=torch_dtype,
is_generation=True,
model_type="generation",
) as srt_runner:
srt_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens)
......
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import multiprocessing as mp
import unittest
import torch
from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
MODELS = [
("LxzGordon/URM-LLaMa-3.1-8B", 1, 2e-2),
]
TORCH_DTYPES = [torch.float16]
# PROMPT = "Jane has 12 apples. She gives 4 apples to her friend Mark, then buys 1 more apple, and finally splits all her apples equally among herself and her 2 siblings. How many apples does each person get?"
# RESPONSE1 = "1. Jane starts with 12 apples and gives 4 to Mark. 12 - 4 = 8. Jane now has 8 apples.\n2. Jane buys 1 more apple. 8 + 1 = 9. Jane now has 9 apples.\n3. Jane splits the 9 apples equally among herself and her 2 siblings (3 people in total). 9 ÷ 3 = 3 apples each. Each person gets 3 apples."
# RESPONSE2 = "1. Jane starts with 12 apples and gives 4 to Mark. 12 - 4 = 8. Jane now has 8 apples.\n2. Jane buys 1 more apple. 8 + 1 = 9. Jane now has 9 apples.\n3. Jane splits the 9 apples equally among her 2 siblings (2 people in total). 9 ÷ 2 = 4.5 apples each. Each person gets 4 apples."
PROMPT = (
"What is the range of the numeric output of a sigmoid node in a neural network?"
)
RESPONSE1 = "The output of a sigmoid node is bounded between -1 and 1."
RESPONSE2 = "The output of a sigmoid node is bounded between 0 and 1."
CONVS = [
[{"role": "user", "content": PROMPT}, {"role": "assistant", "content": RESPONSE1}],
[{"role": "user", "content": PROMPT}, {"role": "assistant", "content": RESPONSE2}],
]
class TestRewardModels(unittest.TestCase):
def assert_close_reward_scores(
self,
convs,
model_path,
tp_size,
torch_dtype,
tolerance,
) -> None:
with HFRunner(
model_path,
torch_dtype=torch_dtype,
model_type="reward",
) as hf_runner:
hf_outputs = hf_runner.forward(convs)
with SRTRunner(
model_path,
torch_dtype=torch_dtype,
model_type="reward",
) as srt_runner:
srt_outputs = srt_runner.forward(convs)
hf_scores = torch.tensor(hf_outputs.scores)
srt_scores = torch.tensor(srt_outputs.scores)
print(hf_scores)
print(srt_scores)
assert torch.all(
abs(hf_scores - srt_scores) < tolerance
), "reward scores are not all close"
def test_reward_scores(self):
for model, tp_size, tolerance in MODELS:
for torch_dtype in TORCH_DTYPES:
self.assert_close_reward_scores(
CONVS, model, tp_size, torch_dtype, tolerance
)
if __name__ == "__main__":
try:
mp.set_start_method("spawn")
except RuntimeError:
pass
unittest.main()
......@@ -7,6 +7,7 @@ suites = {
"minimal": [
"models/test_embedding_models.py",
"models/test_generation_models.py",
"models/test_reward_models.py",
"sampling/penaltylib",
"test_chunked_prefill.py",
"test_embedding_openai_server.py",
......