Unverified Commit 167591e8 authored by Lianmin Zheng, committed by GitHub

Better unit tests for adding a new model (#1488)

parent 441c22db
......@@ -90,9 +90,9 @@ docker run --gpus all \
<summary>More</summary>
> This method is recommended if you plan to serve it as a service.
> A better approach is to use the [k8s-sglang-service.yaml](./docker/k8s-sglang-service.yaml).
> A better approach is to use the [k8s-sglang-service.yaml](docker/k8s-sglang-service.yaml).
1. Copy the [compose.yml](./docker/compose.yaml) to your local machine
1. Copy the [compose.yml](docker/compose.yaml) to your local machine
2. Execute the command `docker compose up -d` in your terminal.
</details>
......@@ -271,7 +271,7 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
- gte-Qwen2
- `python -m sglang.launch_server --model-path Alibaba-NLP/gte-Qwen2-7B-instruct --is-embedding`
Instructions for supporting a new model are [here](https://github.com/sgl-project/sglang/blob/main/docs/en/model_support.md).
Instructions for supporting a new model are [here](docs/en/model_support.md).
#### Use Models From ModelScope
<details>
......@@ -566,7 +566,7 @@ def chat_example(s):
Learn more at this [blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/).
## Roadmap
[Development Roadmap (2024 Q3)](https://github.com/sgl-project/sglang/issues/634)
[Development Roadmap (2024 Q4)](https://github.com/sgl-project/sglang/issues/1487)
## Citation And Acknowledgment
Please cite our paper, [SGLang: Efficient Execution of Structured Language Model Programs](https://arxiv.org/abs/2312.07104), if you find the project useful.
......
......@@ -8,8 +8,7 @@ The core features include:
- **Fast Backend Runtime**: Provides efficient serving with RadixAttention for prefix caching, jump-forward constrained decoding, continuous batching, token attention (paged attention), tensor parallelism, FlashInfer kernels, chunked prefill, and quantization (INT4/FP8/AWQ/GPTQ).
- **Flexible Frontend Language**: Offers an intuitive interface for programming LLM applications, including chained generation calls, advanced prompting, control flow, multi-modal inputs, parallelism, and external interactions.
- **Extensive Model Support**: Supports a wide range of generative models (Llama 3, Gemma 2, Mistral, QWen, DeepSeek, LLaVA, etc.) and embedding models (e5-mistral), with easy extensibility for integrating new models.
- **Active Community**: SGLang is open-source and backed by an active community with industry adoption, welcoming contributions to improve LLM and VLM serving.
- **Active Community**: SGLang is open-source and backed by an active community with industry adoption.
.. toctree::
:maxdepth: 1
......
# How to Support a New Model
To support a new model in SGLang, you only need to add a single file under [SGLang Models Directory](https://github.com/sgl-project/sglang/tree/main/python/sglang/srt/models). You can learn from existing model implementations and create new files for the new models. Most models are based on the transformer architecture, making them very similar.
To support a new model in SGLang, you only need to add a single file under [SGLang Models Directory](https://github.com/sgl-project/sglang/tree/main/python/sglang/srt/models).
You can learn from the existing model implementations and create a new file for your model.
For most models, you should be able to find a similar model to start from (e.g., Llama).
Another valuable resource is the [vLLM Models Directory](https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models). vLLM has extensive coverage of models, and SGLang has reused vLLM for most parts of the model implementations. This similarity makes it easy to port many models from vLLM to SGLang.
## Test the correctness
To port a model from vLLM to SGLang, you can compare these two files [SGLang LLaMA Implementation](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/llama.py) and [vLLM LLaMA Implementation](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama.py). This comparison will help you understand how to convert a model implementation from vLLM to SGLang. The major difference is the replacement of PagedAttention with RadixAttention. The other parts are almost identical. Specifically,
### Interactive debugging
For interactive debugging, you can compare the outputs of huggingface/transformers and SGLang.
The following two commands should give the same text output and very similar prefill logits.
- Get the reference output by `python3 scripts/playground/reference_hf.py --model [new model]`
- Get the SGLang output by `python3 -m sglang.bench_latency --correct --model [new model]`
### Add the model to the test suite
To make sure the new model stays well maintained, it is better to add it to the test suite.
You can add it to the `ALL_OTHER_MODELS` list in [test_generation_models.py](https://github.com/sgl-project/sglang/blob/main/test/srt/models/test_generation_models.py) and run the following command to test it.
For example, if the model is `Qwen/Qwen2-1.5B`:
```
ONLY_RUN=Qwen/Qwen2-1.5B python3 -m unittest test_generation_models.TestGenerationModels.test_others
```
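For reference, registering the model is a one-line change. A minimal sketch is below; the `ModelCase` fields mirror the dataclass added in this commit, and `my-org/my-new-model` is a hypothetical placeholder:
```python
# In test/srt/models/test_generation_models.py (sketch only).
ALL_OTHER_MODELS = [
    ModelCase("Qwen/Qwen2-1.5B"),
    # Hypothetical new entry; per-model tolerances can be overridden if needed.
    ModelCase("my-org/my-new-model", tp_size=1, prefill_tolerance=5e-2),
]
```
The `ONLY_RUN` filter above then selects your entry by its `model_path`.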
## Port a model from vLLM to SGLang
Another valuable resource is the [vLLM Models Directory](https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models). vLLM has extensive coverage of models, and SGLang reuses vLLM's interface and some layers to implement the models. This similarity makes it easy to port many models from vLLM to SGLang.
To port a model from vLLM to SGLang, compare these two files: the [SGLang Llama Implementation](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/llama.py) and the [vLLM Llama Implementation](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama.py). This comparison will help you understand how to convert a model implementation from vLLM to SGLang. The major difference is the replacement of `Attention` with `RadixAttention`; the other parts are almost identical. Specifically (a minimal skeleton follows this checklist):
- Replace vLLM's `Attention` with `RadixAttention`. Note that you need to pass `layer_id` all the way to `RadixAttention`.
- Replace vLLM's `LogitsProcessor` with SGLang's `LogitsProcessor`.
- Replace other vLLM layers with SGLang layers (e.g., `RMSNorm`, `SiluAndMul`).
- Remove `Sample`.
- Change the `forward()` functions and add `input_metadata`.
- Add `EntryClass` at the end.
- Test correctness by comparing the final logits and outputs of the following two commands:
- `python3 scripts/playground/reference_hf.py --model [new model]`
- `python3 -m sglang.bench_latency --model [new model] --correct --output-len 16 --trust-remote-code`
- Update [Supported Models](https://github.com/sgl-project/sglang/tree/main?tab=readme-ov-file#supported-models) at [README](https://github.com/sgl-project/sglang/blob/main/README.md).
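The skeleton below strings the checklist together in one hypothetical model file. It is a sketch, not a working implementation: the import paths and constructor arguments are assumptions modeled on `llama.py`, and the decoder stack itself is omitted, so copy the exact signatures from the Llama implementation rather than from this snippet.
```python
# Hypothetical file: python/sglang/srt/models/my_model.py (illustrative sketch).
# Import paths and argument names are assumptions based on llama.py; verify them there.
from torch import nn

from sglang.srt.layers.logits_processor import LogitsProcessor  # assumed path
from sglang.srt.layers.radix_attention import RadixAttention    # assumed path


class MyAttention(nn.Module):
    def __init__(self, config, layer_id: int):
        super().__init__()
        head_dim = config.hidden_size // config.num_attention_heads
        # RadixAttention replaces vLLM's Attention; layer_id must be threaded down to it.
        self.attn = RadixAttention(
            config.num_attention_heads,
            head_dim,
            head_dim**-0.5,
            num_kv_heads=config.num_key_value_heads,
            layer_id=layer_id,
        )


class MyModelForCausalLM(nn.Module):
    def __init__(self, config):
        super().__init__()
        # The decoder stack (embeddings + attention/MLP layers, each given its
        # layer_id) is omitted here; see LlamaModel in llama.py for the real wiring.
        self.model = None  # placeholder
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.logits_processor = LogitsProcessor(config)  # SGLang's, not vLLM's

    def forward(self, input_ids, positions, input_metadata):
        # forward() takes input_metadata instead of vLLM's attention metadata.
        hidden_states = self.model(input_ids, positions, input_metadata)
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head.weight, input_metadata
        )


# EntryClass at the end of the file tells SGLang which class to load for this model.
EntryClass = MyModelForCausalLM
```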
......@@ -21,19 +21,18 @@ from typing import List, Union
import torch
import torch.nn.functional as F
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
from sglang.srt.server import Runtime
from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER
DEFAULT_PROMPTS = [
# the output of gemma-2-2b from SRT is unstable on the commented prompt
# "The capital of France is",
"Apple is red. Banana is Yellow. " * 800 + "Apple is",
"The capital of the United Kingdom is",
"Today is a sunny day and I like",
"AI is a field of computer science focused on",
# the output of gemma-2-2b from SRT is unstable on the commented prompt
# "The capital of France is",
]
dirpath = os.path.dirname(__file__)
......@@ -132,6 +131,8 @@ class HFRunner:
input_ids = torch.tensor([p], device="cuda")
if lora_paths is not None and lora_paths[i] is not None:
from peft import PeftModel
self.model = PeftModel.from_pretrained(
self.base_model,
lora_paths[i],
......
......@@ -587,3 +587,37 @@ def run_bench_latency(model, other_args):
kill_child_process(process.pid)
return output_throughput
def lcs(X, Y):
    m = len(X)
    n = len(Y)
    L = [[0] * (n + 1) for _ in range(m + 1)]

    for i in range(m + 1):
        for j in range(n + 1):
            if i == 0 or j == 0:
                L[i][j] = 0
            elif X[i - 1] == Y[j - 1]:
                L[i][j] = L[i - 1][j - 1] + 1
            else:
                L[i][j] = max(L[i - 1][j], L[i][j - 1])

    return L[m][n]


def calculate_rouge_l(output_strs_list1, output_strs_list2):
    """calculate the ROUGE-L score"""
    rouge_l_scores = []

    for s1, s2 in zip(output_strs_list1, output_strs_list2):
        lcs_len = lcs(s1, s2)
        precision = lcs_len / len(s1) if len(s1) > 0 else 0
        recall = lcs_len / len(s2) if len(s2) > 0 else 0
        if precision + recall > 0:
            fmeasure = (2 * precision * recall) / (precision + recall)
        else:
            fmeasure = 0.0
        rouge_l_scores.append(fmeasure)

    return rouge_l_scores
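As a quick, hypothetical sanity check of the helper above (the strings are invented; the score is a character-level LCS F-measure, so identical strings give exactly 1.0):
```python
# Illustrative only; run next to the definitions above.
ref_strs = ["The capital of France is Paris."]
out_strs = ["The capital of France is Paris, a busy city."]
print(calculate_rouge_l(ref_strs, out_strs))  # one score per pair, around 0.8 here
```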
......@@ -39,20 +39,21 @@ def normal_text(args):
device_map="auto",
trust_remote_code=True,
)
m.cuda()
prompts = [
"The capital of France is",
"The capital of the United Kindom is",
"Today is a sunny day and I like",
]
max_new_tokens = 16
max_new_tokens = 17
torch.cuda.set_device(0)
for i, p in enumerate(prompts):
if isinstance(p, str):
input_ids = t.encode(p, return_tensors="pt").cuda()
input_ids = t.encode(p, return_tensors="pt").to("cuda:0")
else:
input_ids = torch.tensor([p], device="cuda")
input_ids = torch.tensor([p], device="cuda:0")
output_ids = m.generate(
input_ids, do_sample=False, max_new_tokens=max_new_tokens
......
"""
Usage:
To test a specific model:
1. Add it to ALL_OTHER_MODELS
2. Run `ONLY_RUN=Qwen/Qwen2-1.5B python3 -m unittest test_generation_models.TestGenerationModels.test_others`
"""
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
......@@ -13,69 +21,55 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
import dataclasses
import multiprocessing as mp
import os
import unittest
from typing import List
import torch
from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
from sglang.test.test_utils import calculate_rouge_l
MODELS = [
("meta-llama/Meta-Llama-3.1-8B-Instruct", 1, 1.1, 3e-2, 4e-2, 1),
("google/gemma-2-2b", 1, 3, 3e-2, 5e-2, 1),
("Alibaba-NLP/gte-Qwen2-1.5B-instruct", 1, None, 6e-2, 4e-2, 1),
]
TORCH_DTYPES = [torch.float16]
def lcs(X, Y):
m = len(X)
n = len(Y)
L = [[0] * (n + 1) for _ in range(m + 1)]
for i in range(m + 1):
for j in range(n + 1):
if i == 0 or j == 0:
L[i][j] = 0
elif X[i - 1] == Y[j - 1]:
L[i][j] = L[i - 1][j - 1] + 1
else:
L[i][j] = max(L[i - 1][j], L[i][j - 1])
@dataclasses.dataclass
class ModelCase:
model_path: str
tp_size: int = 1
prefill_tolerance: float = 5e-2
decode_tolerance: float = 5e-2
rouge_l_tolerance: float = 1
return L[m][n]
# Popular models that run on CI
CI_MODELS = [
ModelCase("meta-llama/Meta-Llama-3.1-8B-Instruct"),
ModelCase("google/gemma-2-2b"),
]
def calculate_rouge_l(output_strs_list1, output_strs_list2):
rouge_l_scores = []
for s1, s2 in zip(output_strs_list1, output_strs_list2):
lcs_len = lcs(s1, s2)
precision = lcs_len / len(s1) if len(s1) > 0 else 0
recall = lcs_len / len(s2) if len(s2) > 0 else 0
if precision + recall > 0:
fmeasure = (2 * precision * recall) / (precision + recall)
else:
fmeasure = 0.0
rouge_l_scores.append(fmeasure)
# All other models
ALL_OTHER_MODELS = [
ModelCase("Qwen/Qwen2-1.5B"),
]
return rouge_l_scores
TORCH_DTYPES = [torch.float16]
class TestGenerationModels(unittest.TestCase):
def assert_close_prefill_logits_and_output_strs(
def assert_close_logits_and_output_strs(
self,
prompts,
model_path,
tp_size,
torch_dtype,
max_new_tokens,
prefill_tolerance,
output_tolerance,
rouge_threshold,
long_context_tolerance,
prompts: List[str],
model_case: ModelCase,
torch_dtype: torch.dtype,
) -> None:
if model_path == "Alibaba-NLP/gte-Qwen2-1.5B-instruct":
prompts = prompts[:-1]
model_path = model_case.model_path
prefill_tolerance, decode_tolerance, rouge_l_tolerance = (
model_case.prefill_tolerance,
model_case.decode_tolerance,
model_case.rouge_l_tolerance,
)
max_new_tokens = 32
with HFRunner(
model_path, torch_dtype=torch_dtype, is_generation=True
......@@ -84,14 +78,14 @@ class TestGenerationModels(unittest.TestCase):
with SRTRunner(
model_path,
tp_size=tp_size,
tp_size=model_case.tp_size,
torch_dtype=torch_dtype,
is_generation=True,
) as srt_runner:
srt_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens)
for i in range(len(prompts)):
# input logprobs comparison
# Compare input logprobs
hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
input_len = hf_logprobs.shape[0]
......@@ -99,67 +93,56 @@ class TestGenerationModels(unittest.TestCase):
"prefill logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
)
if input_len <= 100:
assert torch.all(
abs(hf_logprobs - srt_logprobs) < prefill_tolerance
), f"prefill logprobs are not all close with model_path={model_path} prompts={prompts} prefill_tolerance={prefill_tolerance}"
assert torch.all(abs(hf_logprobs - srt_logprobs) < prefill_tolerance), (
f"prefill logprobs are not all close with model_path={model_path} prompts={prompts} "
f"prefill_tolerance={prefill_tolerance}."
f"{hf_logprobs=}, {srt_logprobs=}"
)
# output logprobs comparison
# Compare output logprobs
hf_logprobs = torch.Tensor(hf_outputs.top_output_logprobs[i])
srt_logprobs = torch.Tensor(srt_outputs.top_output_logprobs[i])
# print(
# "output logprobs diff",
# [
# float(torch.max(abs(hf_logprobs[j] - srt_logprobs[j])))
# for j in range(max_new_tokens)
# ],
# )
print(
"output logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
"decode logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
)
if input_len <= 100:
assert torch.all(
abs(hf_logprobs - srt_logprobs) < output_tolerance
), f"output logprobs are not all close with model_path={model_path} prompts={prompts}... output_tolerance={output_tolerance}"
assert torch.all(abs(hf_logprobs - srt_logprobs) < decode_tolerance), (
f"decode logprobs are not all close with model_path={model_path} prompts={prompts} "
f"decode_tolerance={decode_tolerance}."
f"{hf_logprobs=}, {srt_logprobs=}"
)
# output strings comparison
print(f"hf_outputs.output_strs={hf_outputs.output_strs}")
print(f"srt_outputs.output_strs={srt_outputs.output_strs}")
# Compare output strings
print(f"{hf_outputs.output_strs=}")
print(f"{srt_outputs.output_strs=}")
rouge_l_scores = calculate_rouge_l(
hf_outputs.output_strs, srt_outputs.output_strs
)
print(f"rouge_l_scores={rouge_l_scores}")
print(f"{rouge_l_scores=}")
assert all(
score >= rouge_threshold for score in rouge_l_scores
), f"Not all ROUGE-L scores are greater than rouge_threshold={rouge_threshold}"
def test_prefill_logits_and_output_strs(self):
for (
model,
tp_size,
long_context_tolerance,
prefill_tolerance,
output_tolerance,
rouge_threshold,
) in MODELS:
score >= rouge_l_tolerance for score in rouge_l_scores
), f"Not all ROUGE-L scores are greater than rouge_l_tolerance={rouge_l_tolerance}"
def test_ci_models(self):
for model_case in CI_MODELS:
for torch_dtype in TORCH_DTYPES:
max_new_tokens = 32
self.assert_close_prefill_logits_and_output_strs(
DEFAULT_PROMPTS,
model,
tp_size,
torch_dtype,
max_new_tokens,
prefill_tolerance=prefill_tolerance,
output_tolerance=output_tolerance,
rouge_threshold=rouge_threshold,
long_context_tolerance=long_context_tolerance,
self.assert_close_logits_and_output_strs(
DEFAULT_PROMPTS, model_case, torch_dtype
)
def test_others(self):
for model_case in ALL_OTHER_MODELS:
if (
"ONLY_RUN" in os.environ
and os.environ["ONLY_RUN"] != model_case.model_path
):
continue
self.assert_close_logits_and_output_strs(
DEFAULT_PROMPTS, model_case, torch.float16
)
if __name__ == "__main__":
try:
mp.set_start_method("spawn")
except RuntimeError:
pass
if __name__ == "__main__":
mp.set_start_method("spawn")
unittest.main()
......@@ -12,7 +12,7 @@ from sglang.test.test_utils import (
)
class TestReplaceWeights(unittest.TestCase):
class TestUpdateWeights(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
......@@ -33,13 +33,7 @@ class TestReplaceWeights(unittest.TestCase):
"sampling_params": {
"temperature": 0,
"max_new_tokens": 32,
"n": 1,
},
"stream": False,
"return_logprob": False,
"top_logprobs_num": 0,
"return_text_in_logprobs": False,
"logprob_start_len": 0,
},
)
print(json.dumps(response.json()))
......@@ -64,7 +58,7 @@ class TestReplaceWeights(unittest.TestCase):
print(json.dumps(response.json()))
return ret
def test_replace_weights(self):
def test_update_weights(self):
origin_model_path = self.get_model_info()
print(f"origin_model_path: {origin_model_path}")
origin_response = self.run_decode()
......@@ -92,7 +86,7 @@ class TestReplaceWeights(unittest.TestCase):
updated_response = self.run_decode()
assert origin_response[:32] == updated_response[:32]
def test_replace_weights_unexist_model(self):
def test_update_weights_unexist_model(self):
origin_model_path = self.get_model_info()
print(f"origin_model_path: {origin_model_path}")
origin_response = self.run_decode()
......