Unverified Commit 70cc0749 authored by Ying Sheng, committed by GitHub

Add model accuracy test - step 1 (#866)

parent 7dd8a7e6
@@ -35,6 +35,7 @@ jobs:
          pip install -e "python[all]"
          pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.3/ --force-reinstall
          pip install --upgrade transformers
          pip install accelerate
      - name: Test Frontend Language with SRT Backend
        run: |
@@ -50,6 +51,7 @@ jobs:
        run: |
          cd test/srt
          python3 test_eval_accuracy.py
          python3 models/test_causal_models.py
      - name: Test Frontend Language with OpenAI Backend
        run: |
...
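Since the CI step above now installs accelerate alongside transformers and flashinfer, here is a hedged sketch (not part of the commit) for checking locally that the same prerequisites are importable before running the new accuracy test:

# Sketch only: confirm the packages installed by the CI step can be imported.
import importlib

for pkg in ("transformers", "accelerate", "flashinfer"):
    module = importlib.import_module(pkg)
    print(pkg, getattr(module, "__version__", "unknown version"))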
@@ -28,7 +28,7 @@ import sys
import threading
import time
from http import HTTPStatus
from typing import Dict, List, Optional, Union

# Fix a bug of Python threading
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -481,10 +481,10 @@ class Runtime:
            trust_remote_code=self.server_args.trust_remote_code,
        )

    async def async_generate(
        self,
        prompt: str,
        sampling_params: Optional[Dict] = None,
    ):
        json_data = {
            "text": prompt,
@@ -507,5 +507,26 @@ class Runtime:
                            yield cur
                        pos += len(cur)

    add_request = async_generate

    def generate(
        self,
        prompt: str,
        sampling_params: Optional[Dict] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
    ):
        json_data = {
            "text": prompt,
            "sampling_params": sampling_params,
            "return_logprob": return_logprob,
            "top_logprobs_num": top_logprobs_num,
        }
        response = requests.post(
            self.url + "/generate",
            json=json_data,
        )
        return json.dumps(response.json())

    def __del__(self):
        self.shutdown()
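The hunk above adds a blocking Runtime.generate alongside the streaming async_generate (add_request is kept as an alias). A minimal sketch of how it might be called; the model path, prompt, and sampling parameters mirror the tests added below, and the meta_info fields shown are only the ones those tests read back:

import json

from sglang.srt.server import Runtime

# Sketch only: model path and sampling parameters are taken from the new tests,
# not a fixed API contract.
runtime = Runtime(
    model_path="meta-llama/Meta-Llama-3.1-8B-Instruct",
    tp_size=1,
    dtype="float16",
)

response = runtime.generate(
    "The capital of France is",
    sampling_params={"max_new_tokens": 16, "temperature": 0},
    return_logprob=True,
    top_logprobs_num=5,
)
# generate() returns a JSON string, so decode it before indexing.
response = json.loads(response)
print(response["text"])
# Top-5 logprobs for the second prompt token, as consumed by the runners below.
print(response["meta_info"]["input_top_logprobs"][1])

runtime.shutdown()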
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import json
import multiprocessing
from dataclasses import dataclass
from typing import List, Union

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

from sglang.srt.server import Runtime

DEFAULT_PROMPTS = [
    "The capital of France is",
    "The capital of the United Kingdom is",
    "Today is a sunny day and I like",
]

NUM_TOP_LOGPROBS = 5
def _is_embedding_model(model_path):
    # Named with a leading underscore so it is not shadowed by the
    # `is_embedding_model` argument of HFRunner / SRTRunner.
    # FIXME: incomplete list
    if "e5-mistral-7b-instruct" in model_path.lower():
        return True
    return False


def get_dtype_str(torch_dtype):
    if torch_dtype is torch.float16:
        return "float16"
    else:
        raise NotImplementedError()
@dataclass
class ModelOutput:
    output_strs: List[str] = None
    top_input_logprobs: List = None
    top_output_logprobs: List = None
    embed_logits: List = None
class HFRunner:
    def __init__(
        self,
        model_path,
        torch_dtype=torch.float16,
        is_embedding_model=None,
    ):
        self.in_queue = multiprocessing.Queue()
        self.out_queue = multiprocessing.Queue()

        self.model_proc = multiprocessing.Process(
            target=self.start_model_process,
            args=(
                self.in_queue,
                self.out_queue,
                model_path,
                torch_dtype,
                is_embedding_model,
            ),
        )
        self.model_proc.start()
    def start_model_process(
        self, in_queue, out_queue, model_path, torch_dtype, is_embedding_model
    ):
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
        )

        self.is_embedding_model = (
            _is_embedding_model(model_path)
            if is_embedding_model is None
            else is_embedding_model
        )
        if not self.is_embedding_model:
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch_dtype,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
            ).cuda()
        else:
            from sentence_transformers import SentenceTransformer

            self.model = SentenceTransformer(
                model_path,
                device="cpu",
            ).to(dtype=torch_dtype)

        while True:
            prompts, max_new_tokens = in_queue.get()
            if prompts is not None:
                if not self.is_embedding_model:
                    output_strs = []
                    prefill_logprobs = []
                    for p in prompts:
                        if isinstance(p, str):
                            input_ids = self.tokenizer.encode(
                                p, return_tensors="pt"
                            ).cuda()
                        else:
                            input_ids = torch.tensor([p], device="cuda")

                        output_ids = self.model.generate(
                            input_ids, do_sample=False, max_new_tokens=max_new_tokens
                        )
                        output_strs.append(self.tokenizer.decode(output_ids[0]))

                        logits = self.model.forward(input_ids).logits[0]
                        logprobs = F.log_softmax(
                            logits, dim=-1, dtype=torch.float32
                        ).tolist()
                        # Keep only the top NUM_TOP_LOGPROBS logprobs per position.
                        logprobs = [
                            sorted(token_logprobs, reverse=True)[:NUM_TOP_LOGPROBS]
                            for token_logprobs in logprobs
                        ]
                        prefill_logprobs.append(logprobs)

                    out_queue.put(
                        ModelOutput(
                            output_strs=output_strs, top_input_logprobs=prefill_logprobs
                        )
                    )
                else:
                    assert isinstance(prompts, list) and all(
                        isinstance(p, str) for p in prompts
                    )
                    logits = self.model.encode(prompts).tolist()
                    out_queue.put(ModelOutput(embed_logits=logits))
    def forward(
        self,
        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
        max_new_tokens=64,
    ):
        self.in_queue.put((prompts, max_new_tokens))
        return self.out_queue.get()

    def terminate(self):
        self.model_proc.terminate()
        self.in_queue = self.out_queue = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.model_proc.terminate()
        self.in_queue = self.out_queue = None
class SRTRunner:
    def __init__(
        self,
        model_path,
        tp_size=1,
        torch_dtype=torch.float16,
        is_embedding_model=None,
    ):
        self.is_embedding_model = (
            _is_embedding_model(model_path)
            if is_embedding_model is None
            else is_embedding_model
        )
        if self.is_embedding_model:
            raise NotImplementedError()

        self.runtime = Runtime(
            model_path=model_path,
            tp_size=tp_size,
            dtype=get_dtype_str(torch_dtype),
        )

    def forward(
        self,
        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
        max_new_tokens=64,
    ):
        # The return value contains the logprobs from prefill.
        output_strs = []
        top_input_logprobs = []
        sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
        for prompt in prompts:
            response = self.runtime.generate(
                prompt,
                sampling_params=sampling_params,
                return_logprob=True,
                top_logprobs_num=NUM_TOP_LOGPROBS,
            )
            response = json.loads(response)
            output_strs.append(response["text"])
            top_input_logprobs.append(
                [
                    [tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
                    for x in response["meta_info"]["input_top_logprobs"][1:]
                ]
                + [
                    [
                        tup[0]
                        for tup in response["meta_info"]["output_top_logprobs"][0][
                            :NUM_TOP_LOGPROBS
                        ]
                    ]
                ]
            )

        return ModelOutput(
            output_strs=output_strs, top_input_logprobs=top_input_logprobs
        )

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.runtime.shutdown()
        del self.runtime
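A small sketch of driving the new sglang.test.runners helpers outside of unittest; the model path comes from the test below, while the reduced max_new_tokens and the printed summary are illustrative assumptions:

import torch

from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner

model_path = "meta-llama/Meta-Llama-3.1-8B-Instruct"  # same model the test uses

with HFRunner(model_path, torch_dtype=torch.float16, is_embedding_model=False) as hf_runner:
    hf_outputs = hf_runner.forward(DEFAULT_PROMPTS, max_new_tokens=8)

with SRTRunner(
    model_path, tp_size=1, torch_dtype=torch.float16, is_embedding_model=False
) as srt_runner:
    srt_outputs = srt_runner.forward(DEFAULT_PROMPTS, max_new_tokens=8)

# Both runners return ModelOutput objects whose top_input_logprobs hold the
# top-5 prefill logprobs per token, so they can be compared position by position.
for hf_lp, srt_lp in zip(hf_outputs.top_input_logprobs, srt_outputs.top_input_logprobs):
    diff = (torch.Tensor(hf_lp) - torch.Tensor(srt_lp)).abs().max().item()
    print(f"max |HF - SRT| prefill logprob difference: {diff:.4f}")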
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest

import torch

from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner

MODELS = [
    ("meta-llama/Meta-Llama-3.1-8B-Instruct", 1),
    # ("meta-llama/Meta-Llama-3.1-8B-Instruct", 2),
]
TORCH_DTYPES = [torch.float16]
class TestCausalModels(unittest.TestCase):
    def assert_close_prefill_logits(
        self,
        prompts,
        model_path,
        tp_size,
        torch_dtype,
    ) -> None:
        with HFRunner(
            model_path, torch_dtype=torch_dtype, is_embedding_model=False
        ) as hf_runner:
            hf_outputs = hf_runner.forward(prompts)

        with SRTRunner(
            model_path,
            tp_size=tp_size,
            torch_dtype=torch_dtype,
            is_embedding_model=False,
        ) as srt_runner:
            srt_outputs = srt_runner.forward(prompts)

        for i in range(len(prompts)):
            hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
            srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])

            tolerance = 2e-2
            assert torch.all(
                abs(hf_logprobs - srt_logprobs) < tolerance
            ), f"prefill logprobs are not all close (prompt {i}, model {model_path})"

    def test_prefill_logits(self):
        for model, tp_size in MODELS:
            for torch_dtype in TORCH_DTYPES:
                self.assert_close_prefill_logits(
                    DEFAULT_PROMPTS, model, tp_size, torch_dtype
                )


if __name__ == "__main__":
    unittest.main(warnings="ignore")
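The CI step added above runs this file as cd test/srt && python3 models/test_causal_models.py. A hedged sketch of an equivalent programmatic invocation, assuming the working directory is test/srt:

import unittest

# Load and run only the new causal-model accuracy tests; "models.test_causal_models"
# is the module path relative to test/srt, mirroring the CI step.
suite = unittest.defaultTestLoader.loadTestsFromName("models.test_causal_models")
unittest.TextTestRunner(verbosity=2).run(suite)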