Unverified Commit 167591e8 authored by Lianmin Zheng, committed by GitHub

Better unit tests for adding a new model (#1488)

parent 441c22db
@@ -90,9 +90,9 @@ docker run --gpus all \
<summary>More</summary>
> This method is recommended if you plan to serve it as a service.
> A better approach is to use the [k8s-sglang-service.yaml](docker/k8s-sglang-service.yaml).
1. Copy the [compose.yml](docker/compose.yaml) to your local machine
2. Execute the command `docker compose up -d` in your terminal.
</details>
@@ -271,7 +271,7 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
- gte-Qwen2
  - `python -m sglang.launch_server --model-path Alibaba-NLP/gte-Qwen2-7B-instruct --is-embedding`
Instructions for supporting a new model are [here](docs/en/model_support.md).
#### Use Models From ModelScope
<details>
@@ -566,7 +566,7 @@ def chat_example(s):
Learn more at this [blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/).
## Roadmap
[Development Roadmap (2024 Q4)](https://github.com/sgl-project/sglang/issues/1487)
## Citation And Acknowledgment
Please cite our paper, [SGLang: Efficient Execution of Structured Language Model Programs](https://arxiv.org/abs/2312.07104), if you find the project useful.
......
@@ -8,8 +8,7 @@ The core features include:
- **Fast Backend Runtime**: Provides efficient serving with RadixAttention for prefix caching, jump-forward constrained decoding, continuous batching, token attention (paged attention), tensor parallelism, FlashInfer kernels, chunked prefill, and quantization (INT4/FP8/AWQ/GPTQ).
- **Flexible Frontend Language**: Offers an intuitive interface for programming LLM applications, including chained generation calls, advanced prompting, control flow, multi-modal inputs, parallelism, and external interactions.
- **Extensive Model Support**: Supports a wide range of generative models (Llama 3, Gemma 2, Mistral, QWen, DeepSeek, LLaVA, etc.) and embedding models (e5-mistral), with easy extensibility for integrating new models.
- **Active Community**: SGLang is open-source and backed by an active community with industry adoption.

.. toctree::
   :maxdepth: 1
......
# How to Support a New Model
To support a new model in SGLang, you only need to add a single file under the [SGLang Models Directory](https://github.com/sgl-project/sglang/tree/main/python/sglang/srt/models).
You can learn from existing model implementations and create new files for the new models.
For most models, you should be able to find a similar model to start with (e.g., starting from Llama).
## Test the correctness

### Interactive debugging
For interactive debugging, you can compare the outputs of huggingface/transformers and SGLang.
The following two commands should give the same text output and very similar prefill logits.
- Get the reference output by `python3 scripts/playground/reference_hf.py --model [new model]`
- Get the SGLang output by `python3 -m sglang.bench_latency --correct --model [new model]`
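If you prefer to run this comparison from Python instead of the shell, the test runners used by the unit tests can be driven directly. The following is a minimal sketch based on the `HFRunner`/`SRTRunner` interfaces from `sglang.test.runners` (as used in `test_generation_models.py`); argument names and defaults may differ slightly across versions, and `Qwen/Qwen2-1.5B` is only a placeholder model.

```python
# Minimal sketch: compare huggingface/transformers and SGLang outputs programmatically.
import torch

from sglang.test.runners import HFRunner, SRTRunner

model_path = "Qwen/Qwen2-1.5B"  # placeholder; use your new model
prompts = ["The capital of the United Kingdom is"]

with HFRunner(model_path, torch_dtype=torch.float16, is_generation=True) as hf_runner:
    hf_outputs = hf_runner.forward(prompts, max_new_tokens=32)

with SRTRunner(model_path, tp_size=1, torch_dtype=torch.float16, is_generation=True) as srt_runner:
    srt_outputs = srt_runner.forward(prompts, max_new_tokens=32)

# The greedy outputs should match and the prefill logprobs should be very close.
print(hf_outputs.output_strs)
print(srt_outputs.output_strs)
hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[0])
srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[0])
print("prefill logprobs max_diff:", torch.max(abs(hf_logprobs - srt_logprobs)))
```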
### Add the model to the test suite
To make sure the new model is well maintained in the future, it is better to add it to the test suite.
You can add it to the `ALL_OTHER_MODELS` list in [test_generation_models.py](https://github.com/sgl-project/sglang/blob/main/test/srt/models/test_generation_models.py) and run the following command to test it.
For example, if the model is `Qwen/Qwen2-1.5B`, run:
```
ONLY_RUN=Qwen/Qwen2-1.5B python3 -m unittest test_generation_models.TestGenerationModels.test_others
```
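Each entry in `ALL_OTHER_MODELS` is a `ModelCase`, which also lets you relax the comparison tolerances or set a tensor-parallel size for your model. A sketch of what a registration might look like (the second entry and its values are purely illustrative):

```python
# In test/srt/models/test_generation_models.py (illustrative values only).
ALL_OTHER_MODELS = [
    ModelCase("Qwen/Qwen2-1.5B"),
    # Hypothetical new model with custom settings:
    ModelCase(
        "my-org/my-new-model",
        tp_size=2,
        prefill_tolerance=5e-2,
        decode_tolerance=5e-2,
        rouge_l_tolerance=1,
    ),
]
```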
## Port a model from vLLM to SGLang
Another valuable resource is the [vLLM Models Directory](https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models). vLLM has extensive coverage of models, and SGLang reuses vLLM's interface and some layers to implement the models. This similarity makes it easy to port many models from vLLM to SGLang.
To port a model from vLLM to SGLang, you can compare these two files: the [SGLang Llama Implementation](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/llama.py) and the [vLLM Llama Implementation](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama.py). This comparison will help you understand how to convert a model implementation from vLLM to SGLang. The major difference is the replacement of `Attention` with `RadixAttention`. The other parts are almost identical. Specifically (a short sketch follows the list):
- Replace vLLM's `Attention` with `RadixAttention`. Note that you need to pass `layer_id` all the way to `RadixAttention`.
- Replace vLLM's `LogitsProcessor` with SGLang's `LogitsProcessor`.
- Replace other vLLM layers with SGLang layers (e.g., `RMSNorm`, `SiluAndMul`).
- Remove `Sample`.
- Change the `forward()` functions, and add `input_metadata`.
- Add `EntryClass` at the end.
- Update [Supported Models](https://github.com/sgl-project/sglang/tree/main?tab=readme-ov-file#supported-models) at [README](https://github.com/sgl-project/sglang/blob/main/README.md).
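The sketch below illustrates the attention-related part of a ported decoder layer. It is a rough structural outline based on the Llama implementation, not a drop-in module: treat the `RadixAttention` constructor arguments and the `forward(q, k, v, input_metadata)` call as assumptions and verify them against `python/sglang/srt/models/llama.py` in your checkout.

```python
# Structural sketch only; verify signatures against llama.py in your SGLang version.
import torch.nn as nn

from sglang.srt.layers.radix_attention import RadixAttention


class MyAttention(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int, num_kv_heads: int, layer_id: int):
        super().__init__()
        head_dim = hidden_size // num_heads
        # vLLM's `Attention` becomes `RadixAttention`; `layer_id` is threaded
        # through from the decoder layer so the KV cache knows which layer this is.
        self.attn = RadixAttention(
            num_heads,
            head_dim,
            head_dim**-0.5,  # scaling
            num_kv_heads,
            layer_id,
        )

    def forward(self, q, k, v, input_metadata):
        # The q/k/v projections and rotary embeddings stay as in the vLLM model;
        # the attention call itself takes `input_metadata` instead of vLLM's metadata.
        return self.attn(q, k, v, input_metadata)


# At the bottom of the model file, export the entry point, e.g.:
# EntryClass = MyModelForCausalLM
```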
@@ -21,19 +21,18 @@ from typing import List, Union
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

from sglang.srt.server import Runtime
from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER

DEFAULT_PROMPTS = [
    "Apple is red. Banana is Yellow. " * 800 + "Apple is",
    "The capital of the United Kingdom is",
    "Today is a sunny day and I like",
    "AI is a field of computer science focused on",
    # the output of gemma-2-2b from SRT is unstable on the commented prompt
    # "The capital of France is",
]

dirpath = os.path.dirname(__file__)
@@ -132,6 +131,8 @@ class HFRunner:
                input_ids = torch.tensor([p], device="cuda")

            if lora_paths is not None and lora_paths[i] is not None:
                from peft import PeftModel

                self.model = PeftModel.from_pretrained(
                    self.base_model,
                    lora_paths[i],
......
@@ -587,3 +587,37 @@ def run_bench_latency(model, other_args):
        kill_child_process(process.pid)

    return output_throughput

def lcs(X, Y):
    # Length of the longest common subsequence, computed character-wise with
    # standard O(m * n) dynamic programming.
    m = len(X)
    n = len(Y)
    L = [[0] * (n + 1) for _ in range(m + 1)]

    for i in range(m + 1):
        for j in range(n + 1):
            if i == 0 or j == 0:
                L[i][j] = 0
            elif X[i - 1] == Y[j - 1]:
                L[i][j] = L[i - 1][j - 1] + 1
            else:
                L[i][j] = max(L[i - 1][j], L[i][j - 1])

    return L[m][n]


def calculate_rouge_l(output_strs_list1, output_strs_list2):
    """Calculate the ROUGE-L (LCS-based F-measure) score for each pair of strings."""
    rouge_l_scores = []

    for s1, s2 in zip(output_strs_list1, output_strs_list2):
        lcs_len = lcs(s1, s2)
        precision = lcs_len / len(s1) if len(s1) > 0 else 0
        recall = lcs_len / len(s2) if len(s2) > 0 else 0
        if precision + recall > 0:
            fmeasure = (2 * precision * recall) / (precision + recall)
        else:
            fmeasure = 0.0
        rouge_l_scores.append(fmeasure)

    return rouge_l_scores
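A quick usage example for the helper above: `lcs` compares strings character by character, so `calculate_rouge_l` is a character-level LCS F-measure (this snippet is an illustration, with the expected values worked out by hand).

```python
# lcs("abcd", "abd") == 3, precision = 3/4, recall = 3/3, F-measure = 6/7 ≈ 0.857
assert abs(calculate_rouge_l(["abcd"], ["abd"])[0] - 6 / 7) < 1e-9
# Identical strings score a perfect 1.0
assert calculate_rouge_l(["hello"], ["hello"]) == [1.0]
```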
@@ -39,20 +39,21 @@ def normal_text(args):
        device_map="auto",
        trust_remote_code=True,
    )
    m.cuda()

    prompts = [
        "The capital of France is",
        "The capital of the United Kindom is",
        "Today is a sunny day and I like",
    ]
    max_new_tokens = 17

    torch.cuda.set_device(0)
    for i, p in enumerate(prompts):
        if isinstance(p, str):
            input_ids = t.encode(p, return_tensors="pt").to("cuda:0")
        else:
            input_ids = torch.tensor([p], device="cuda:0")

        output_ids = m.generate(
            input_ids, do_sample=False, max_new_tokens=max_new_tokens
......
"""
Usage:
To test a specific model:
1. Add it to ALL_OTHER_MODELS
2. Run `ONLY_RUN=Qwen/Qwen2-1.5B python3 -m unittest test_generation_models.TestGenerationModels.test_others`
"""
""" """
Copyright 2023-2024 SGLang Team Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
...@@ -13,69 +21,55 @@ See the License for the specific language governing permissions and ...@@ -13,69 +21,55 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
import dataclasses
import multiprocessing as mp
import os
import unittest
from typing import List

import torch

from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
from sglang.test.test_utils import calculate_rouge_l
@dataclasses.dataclass
class ModelCase:
    model_path: str
    tp_size: int = 1
    prefill_tolerance: float = 5e-2
    decode_tolerance: float = 5e-2
    rouge_l_tolerance: float = 1


# Popular models that run on CI
CI_MODELS = [
    ModelCase("meta-llama/Meta-Llama-3.1-8B-Instruct"),
    ModelCase("google/gemma-2-2b"),
]

# All other models
ALL_OTHER_MODELS = [
    ModelCase("Qwen/Qwen2-1.5B"),
]

TORCH_DTYPES = [torch.float16]


class TestGenerationModels(unittest.TestCase):
    def assert_close_logits_and_output_strs(
        self,
        prompts: List[str],
        model_case: ModelCase,
        torch_dtype: torch.dtype,
    ) -> None:
        model_path = model_case.model_path
        prefill_tolerance, decode_tolerance, rouge_l_tolerance = (
            model_case.prefill_tolerance,
            model_case.decode_tolerance,
            model_case.rouge_l_tolerance,
        )
        max_new_tokens = 32

        with HFRunner(
            model_path, torch_dtype=torch_dtype, is_generation=True
@@ -84,14 +78,14 @@ class TestGenerationModels(unittest.TestCase):

        with SRTRunner(
            model_path,
            tp_size=model_case.tp_size,
            torch_dtype=torch_dtype,
            is_generation=True,
        ) as srt_runner:
            srt_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens)

        for i in range(len(prompts)):
            # Compare input logprobs
            hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
            srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
            input_len = hf_logprobs.shape[0]
@@ -99,67 +93,56 @@ class TestGenerationModels(unittest.TestCase):
                "prefill logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
            )
            if input_len <= 100:
                assert torch.all(abs(hf_logprobs - srt_logprobs) < prefill_tolerance), (
                    f"prefill logprobs are not all close with model_path={model_path} prompts={prompts} "
                    f"prefill_tolerance={prefill_tolerance}."
                    f"{hf_logprobs=}, {srt_logprobs=}"
                )

            # Compare output logprobs
            hf_logprobs = torch.Tensor(hf_outputs.top_output_logprobs[i])
            srt_logprobs = torch.Tensor(srt_outputs.top_output_logprobs[i])
            print(
                "decode logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
            )
            if input_len <= 100:
                assert torch.all(abs(hf_logprobs - srt_logprobs) < decode_tolerance), (
                    f"decode logprobs are not all close with model_path={model_path} prompts={prompts} "
                    f"decode_tolerance={decode_tolerance}."
                    f"{hf_logprobs=}, {srt_logprobs=}"
                )

            # Compare output strings
            print(f"{hf_outputs.output_strs=}")
            print(f"{srt_outputs.output_strs=}")
            rouge_l_scores = calculate_rouge_l(
                hf_outputs.output_strs, srt_outputs.output_strs
            )
            print(f"{rouge_l_scores=}")
            assert all(
                score >= rouge_l_tolerance for score in rouge_l_scores
            ), f"Not all ROUGE-L scores are greater than rouge_l_tolerance={rouge_l_tolerance}"

    def test_ci_models(self):
        for model_case in CI_MODELS:
            for torch_dtype in TORCH_DTYPES:
                self.assert_close_logits_and_output_strs(
                    DEFAULT_PROMPTS, model_case, torch_dtype
                )

    def test_others(self):
        for model_case in ALL_OTHER_MODELS:
            if (
                "ONLY_RUN" in os.environ
                and os.environ["ONLY_RUN"] != model_case.model_path
            ):
                continue
            self.assert_close_logits_and_output_strs(
                DEFAULT_PROMPTS, model_case, torch.float16
            )


if __name__ == "__main__":
    try:
        mp.set_start_method("spawn")
    except RuntimeError:
        pass

    unittest.main()
@@ -12,7 +12,7 @@ from sglang.test.test_utils import (
)


class TestUpdateWeights(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
@@ -33,13 +33,7 @@ class TestReplaceWeights(unittest.TestCase):
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 32,
                },
            },
        )
        print(json.dumps(response.json()))
@@ -64,7 +58,7 @@ class TestReplaceWeights(unittest.TestCase):
        print(json.dumps(response.json()))
        return ret

    def test_update_weights(self):
        origin_model_path = self.get_model_info()
        print(f"origin_model_path: {origin_model_path}")
        origin_response = self.run_decode()
@@ -92,7 +86,7 @@ class TestReplaceWeights(unittest.TestCase):
        updated_response = self.run_decode()
        assert origin_response[:32] == updated_response[:32]

    def test_update_weights_unexist_model(self):
        origin_model_path = self.get_model_info()
        print(f"origin_model_path: {origin_model_path}")
        origin_response = self.run_decode()
......