Unverified Commit 30b4f771 authored by Chayenne, committed by GitHub

Support Alibaba-NLP/gte-Qwen2-7B-instruct embedding Model (#1186)


Co-authored-by: Ying Sheng <sqy1415@gmail.com>
parent 66e7dcaf
@@ -43,4 +43,4 @@ jobs:
       run: |
         cd test/srt
         python3 test_eval_accuracy_large.py
-      timeout-minutes: 10
+      timeout-minutes: 20
@@ -41,7 +41,7 @@ jobs:
       run: |
         cd test/srt
         python3 run_suite.py --suite minimal
-      timeout-minutes: 18
+      timeout-minutes: 20
     - name: Test Frontend Language
       run: |
......
@@ -187,6 +187,13 @@ response = client.chat.completions.create(
     max_tokens=64,
 )
 print(response)
+
+# Text embedding
+response = client.embeddings.create(
+    model="default",
+    input="How are you today",
+)
+print(response)
 ```
 It supports streaming, vision, and most features of the Chat/Completions/Models/Batch endpoints specified by the [OpenAI API Reference](https://platform.openai.com/docs/api-reference/).
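For readers skimming this hunk: the embedding vector itself lives in `response.data[*].embedding`. A minimal sketch of reading it back, assuming the standard OpenAI Python client pointed at a locally running sglang server (the base_url/port below are illustrative, not part of this commit):

```python
import openai

# Assumed local server address; match it to your launch command.
client = openai.OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

response = client.embeddings.create(model="default", input="How are you today")
vector = response.data[0].embedding  # one entry per input string
print(len(vector))  # dimensionality of the served embedding model
```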
@@ -223,6 +230,8 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
 ### Supported Models
 
+**Generative Models**
+
 - Llama / Llama 2 / Llama 3 / Llama 3.1
 - Mistral / Mixtral / Mistral NeMo
 - Gemma / Gemma 2
@@ -243,6 +252,12 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
 - ChatGLM
 - InternLM 2
+
+**Embedding Models**
+
+- e5-mistral
+- gte-Qwen2
+  - `python -m sglang.launch_server --model-path Alibaba-NLP/gte-Qwen2-7B-instruct --is-embedding`
 
 Instructions for supporting a new model are [here](https://github.com/sgl-project/sglang/blob/main/docs/en/model_support.md).
 
 #### Use Models From ModelScope
......
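As a usage note for the `--is-embedding` launch command listed above, embeddings returned by the server can be compared client-side. The sketch below assumes the server is already running locally (port and `api_key` are illustrative) and uses a hand-rolled cosine similarity purely for demonstration:

```python
import math

import openai

client = openai.OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")


def embed(text: str) -> list[float]:
    # One request per input for clarity; the endpoint also accepts a list of strings.
    return client.embeddings.create(model="default", input=text).data[0].embedding


def cosine(a: list[float], b: list[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(x * x for x in b)))


# Semantically close sentences should score noticeably higher than unrelated ones.
print(cosine(embed("How are you today"), embed("How is your day going")))
print(cosine(embed("How are you today"), embed("The capital of France is Paris")))
```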
@@ -94,7 +94,10 @@ class TokenizerManager:
             trust_remote_code=server_args.trust_remote_code,
             model_overide_args=model_overide_args,
         )
-        self.is_generation = is_generation_model(self.hf_config.architectures)
+        self.is_generation = is_generation_model(
+            self.hf_config.architectures, self.server_args.is_embedding
+        )
 
         if server_args.context_length is not None:
             self.context_len = server_args.context_length
......
@@ -94,6 +94,7 @@ class ModelTpServer:
             context_length=server_args.context_length,
             model_overide_args=model_overide_args,
         )
+
         self.model_runner = ModelRunner(
             model_config=self.model_config,
             mem_fraction_static=server_args.mem_fraction_static,
......
@@ -204,7 +204,7 @@ class ModelRunner:
             else None
         )
         self.is_generation = is_generation_model(
-            self.model_config.hf_config.architectures
+            self.model_config.hf_config.architectures, self.server_args.is_embedding
         )
 
         logger.info(
@@ -522,9 +522,18 @@
             batch,
             forward_mode=ForwardMode.EXTEND,
         )
-        return self.model.forward(
-            batch.input_ids, input_metadata.positions, input_metadata
-        )
+        if self.is_generation:
+            return self.model.forward(
+                batch.input_ids, input_metadata.positions, input_metadata
+            )
+        else:
+            # Only embedding models have get_embedding parameter
+            return self.model.forward(
+                batch.input_ids,
+                input_metadata.positions,
+                input_metadata,
+                get_embedding=True,
+            )
 
     @torch.inference_mode()
     def forward_extend_multi_modal(self, batch: ScheduleBatch):
......
@@ -29,7 +29,11 @@ class LlamaEmbeddingModel(nn.Module):
         positions: torch.Tensor,
         input_metadata: InputMetadata,
         input_embeds: torch.Tensor = None,
+        get_embedding: bool = True,
     ) -> EmbeddingPoolerOutput:
+        assert (
+            get_embedding
+        ), "LlamaEmbeddingModel / MistralModel is only used for embedding"
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
         return self.pooler(hidden_states, input_metadata)
......
@@ -38,6 +38,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -275,6 +276,7 @@ class Qwen2ForCausalLM(nn.Module):
         self.model = Qwen2Model(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
 
     @torch.no_grad()
     def forward(
@@ -283,11 +285,15 @@
         positions: torch.Tensor,
         input_metadata: InputMetadata,
         input_embeds: torch.Tensor = None,
+        get_embedding: bool = False,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, input_metadata
-        )
+        if not get_embedding:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head.weight, input_metadata
+            )
+        else:
+            return self.pooler(hidden_states, input_metadata)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
......
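For context on the `Pooler(pooling_type=PoolingType.LAST, normalize=True)` line added above: conceptually it takes the hidden state at each sequence's last token and L2-normalizes it. A standalone sketch of that computation on a padded batch; the real Pooler works on sglang's flattened batch layout via `input_metadata`, so the names and shapes here are illustrative assumptions:

```python
import torch


def last_token_pool(hidden_states: torch.Tensor, seq_lens: torch.Tensor) -> torch.Tensor:
    """hidden_states: [batch, seq, hidden] (padded); seq_lens: [batch] true lengths."""
    batch = torch.arange(hidden_states.size(0), device=hidden_states.device)
    # Pick the hidden state at the last real token of every sequence.
    pooled = hidden_states[batch, seq_lens - 1]
    # L2-normalize so dot products between embeddings become cosine similarities.
    return torch.nn.functional.normalize(pooled, p=2, dim=-1)


hidden = torch.randn(2, 5, 8)
lens = torch.tensor([5, 3])
print(last_token_pool(hidden, lens).shape)  # torch.Size([2, 8])
```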
@@ -333,11 +333,13 @@ def launch_server(
         start_process = start_controller_process_single
     else:
         start_process = start_controller_process_multi
+
     proc_controller = mp.Process(
         target=start_process,
         args=(server_args, port_args, pipe_controller_writer, model_overide_args),
     )
     proc_controller.start()
+
     proc_detoken = mp.Process(
         target=start_detokenizer_process,
         args=(
@@ -515,6 +517,7 @@ class Runtime:
         self.pid = None
         pipe_reader, pipe_writer = mp.Pipe(duplex=False)
+
         proc = mp.Process(
             target=launch_server,
             args=(self.server_args, model_overide_args, pipe_writer),
......
@@ -38,6 +38,7 @@ class ServerArgs:
     quantization: Optional[str] = None
     served_model_name: Optional[str] = None
     chat_template: Optional[str] = None
+    is_embedding: bool = False
 
     # Port
     host: str = "127.0.0.1"
@@ -200,6 +201,11 @@
             action="store_true",
             help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
         )
+        parser.add_argument(
+            "--is-embedding",
+            action="store_true",
+            help="Whether to use a CausalLM as an embedding model.",
+        )
         parser.add_argument(
             "--context-length",
             type=int,
@@ -458,6 +464,11 @@
         assert not (
             self.dp_size > 1 and self.node_rank is not None
         ), "multi-node data parallel is not supported"
+        if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
+            logger.info(
+                "The tokenizer appends an extra token to the end of the prompt when trust_remote_code=True (reason unclear), so trust_remote_code is disabled for this model."
+            )
+            self.trust_remote_code = False
         if "gemma-2" in self.model_path.lower():
             logger.info("When using sliding window in gemma-2, turn on flashinfer.")
             self.disable_flashinfer = False
......
@@ -224,13 +224,18 @@ def is_multimodal_model(model):
         raise ValueError("unrecognized type")
 
 
-def is_generation_model(model_architectures):
+def is_generation_model(model_architectures, is_embedding: bool = False):
+    # We have two ways to determine whether a model is a generative model:
+    # 1. check the model architecture
+    # 2. check the `is_embedding` server arg
     if (
         "LlamaEmbeddingModel" in model_architectures
         or "MistralModel" in model_architectures
     ):
         return False
-    return True
+    else:
+        return not is_embedding
 
 
 def decode_video_base64(video_base64):
......
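The decision logic added above is small enough to sanity-check inline. A hypothetical snippet, assuming the function is imported from `sglang.srt.utils`, where this hunk lives:

```python
from sglang.srt.utils import is_generation_model  # assumed import path

# Architectures that only exist for embedding are never generative.
assert is_generation_model(["LlamaEmbeddingModel"]) is False
assert is_generation_model(["MistralModel"], is_embedding=True) is False

# A CausalLM architecture stays generative unless --is-embedding is passed,
# which is how Qwen2ForCausalLM serves gte-Qwen2 as an embedding model here.
assert is_generation_model(["Qwen2ForCausalLM"]) is True
assert is_generation_model(["Qwen2ForCausalLM"], is_embedding=True) is False
```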
@@ -14,7 +14,7 @@ limitations under the License.
 """
 
 import json
-import multiprocessing
+import multiprocessing as mp
 import os
 from dataclasses import dataclass
 from typing import List, Union
@@ -63,37 +63,35 @@ class HFRunner:
         self,
         model_path,
         torch_dtype,
-        is_generation_model,
+        is_generation,
     ):
-        self.in_queue = multiprocessing.Queue()
-        self.out_queue = multiprocessing.Queue()
-        self.model_proc = multiprocessing.Process(
+        self.is_generation = is_generation
+
+        self.in_queue = mp.Queue()
+        self.out_queue = mp.Queue()
+        self.model_proc = mp.Process(
             target=self.start_model_process,
             args=(
                 self.in_queue,
                 self.out_queue,
                 model_path,
                 torch_dtype,
-                is_generation_model,
             ),
         )
         self.model_proc.start()
 
-    def start_model_process(
-        self, in_queue, out_queue, model_path, torch_dtype, is_generation_model
-    ):
+    def start_model_process(self, in_queue, out_queue, model_path, torch_dtype):
         self.tokenizer = AutoTokenizer.from_pretrained(
             model_path,
             torch_dtype=torch_dtype,
         )
-        self.is_generation_model = is_generation_model
-        if self.is_generation_model:
+        if self.is_generation:
             self.model = AutoModelForCausalLM.from_pretrained(
                 model_path,
                 torch_dtype=torch_dtype,
+                trust_remote_code=False,
                 low_cpu_mem_usage=True,
             ).cuda()
         else:
@@ -107,7 +105,7 @@ class HFRunner:
         while True:
             prompts, max_new_tokens = in_queue.get()
             if prompts is not None:
-                if self.is_generation_model:
+                if self.is_generation:
                     output_strs = []
                     prefill_logprobs = []
                     for p in prompts:
@@ -171,17 +169,19 @@ class SRTRunner:
         self,
         model_path,
         torch_dtype,
-        is_generation_model,
+        is_generation,
         tp_size=1,
         port=5157,
     ):
-        self.is_generation_model = is_generation_model
+        self.is_generation = is_generation
         self.runtime = Runtime(
             model_path=model_path,
             tp_size=tp_size,
             dtype=get_dtype_str(torch_dtype),
             port=port,
             mem_fraction_static=0.7,
+            trust_remote_code=False,
+            is_embedding=not self.is_generation,
         )
 
     def forward(
@@ -189,7 +189,7 @@
         prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
         max_new_tokens=8,
     ):
-        if self.is_generation_model:
+        if self.is_generation:
             # the return value contains logprobs from prefill
             output_strs = []
             top_input_logprobs = []
......
@@ -20,7 +20,10 @@ import torch
 from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
 from sglang.test.test_utils import get_similarities
 
-MODELS = [("intfloat/e5-mistral-7b-instruct", 1, 0.2)]
+MODELS = [
+    ("Alibaba-NLP/gte-Qwen2-1.5B-instruct", 1, 1e-5),
+    ("intfloat/e5-mistral-7b-instruct", 1, 1e-5),
+]
 TORCH_DTYPES = [torch.float16]
@@ -32,10 +35,10 @@ class TestEmbeddingModels(unittest.TestCase):
         model_path,
         tp_size,
         torch_dtype,
-        long_context_tolerance,
+        prefill_tolerance,
     ) -> None:
         with HFRunner(
-            model_path, torch_dtype=torch_dtype, is_generation_model=False
+            model_path, torch_dtype=torch_dtype, is_generation=False
         ) as hf_runner:
             hf_outputs = hf_runner.forward(prompts)
 
@@ -43,11 +46,9 @@
             model_path,
             tp_size=tp_size,
             torch_dtype=torch_dtype,
-            is_generation_model=False,
+            is_generation=False,
         ) as srt_runner:
-            srt_outputs = srt_runner.forward(
-                prompts,
-            )
+            srt_outputs = srt_runner.forward(prompts)
 
             for i in range(len(prompts)):
                 hf_logits = torch.Tensor(hf_outputs.embed_logits[i])
@@ -57,18 +58,15 @@
                 print("similarity diff", abs(similarity - 1))
 
                 if len(prompts[i]) <= 1000:
-                    tolerance = 1e-5
-                else:
-                    tolerance = long_context_tolerance
-                assert torch.all(
-                    abs(similarity - 1) < tolerance
-                ), "embeddings are not all close"
+                    assert torch.all(
+                        abs(similarity - 1) < prefill_tolerance
+                    ), "embeddings are not all close"
 
     def test_prefill_logits(self):
-        for model, tp_size, long_context_tolerance in MODELS:
+        for model, tp_size, prefill_tolerance in MODELS:
             for torch_dtype in TORCH_DTYPES:
                 self.assert_close_prefill_logits(
-                    DEFAULT_PROMPTS, model, tp_size, torch_dtype, long_context_tolerance
+                    DEFAULT_PROMPTS, model, tp_size, torch_dtype, prefill_tolerance
                 )
......
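The assertion above compares each HF embedding with its SRT counterpart through `get_similarities` and requires the result to sit within `prefill_tolerance` of 1. A rough stand-alone equivalent, assuming the helper computes a cosine similarity (an assumption about `sglang.test.test_utils.get_similarities`, not something this commit states):

```python
import torch
import torch.nn.functional as F


def embeddings_close(hf_vec: torch.Tensor, srt_vec: torch.Tensor, tol: float = 1e-5) -> bool:
    # Cosine similarity is exactly 1.0 when the two vectors point the same way;
    # the test tolerates a small per-model deviation (prefill_tolerance in MODELS).
    similarity = F.cosine_similarity(hf_vec.unsqueeze(0), srt_vec.unsqueeze(0)).item()
    return abs(similarity - 1) < tol


a = torch.randn(4096)
print(embeddings_close(a, a.clone()))                    # True
print(embeddings_close(a, a + 0.1 * torch.randn(4096)))  # almost certainly False at 1e-5
```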
@@ -20,12 +20,46 @@ import torch
 from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
 
 MODELS = [
-    ("meta-llama/Meta-Llama-3.1-8B-Instruct", 1, 1.1),
-    ("google/gemma-2-2b", 1, 3),
+    ("meta-llama/Meta-Llama-3.1-8B-Instruct", 1, 1.1, 3e-2, 1),
+    ("google/gemma-2-2b", 1, 3, 3e-2, 1),
+    ("Alibaba-NLP/gte-Qwen2-1.5B-instruct", 1, None, 6e-2, 1),
 ]
 TORCH_DTYPES = [torch.float16]
 
+
+def lcs(X, Y):
+    m = len(X)
+    n = len(Y)
+    L = [[0] * (n + 1) for _ in range(m + 1)]
+
+    for i in range(m + 1):
+        for j in range(n + 1):
+            if i == 0 or j == 0:
+                L[i][j] = 0
+            elif X[i - 1] == Y[j - 1]:
+                L[i][j] = L[i - 1][j - 1] + 1
+            else:
+                L[i][j] = max(L[i - 1][j], L[i][j - 1])
+
+    return L[m][n]
+
+
+def calculate_rouge_l(output_strs_list1, output_strs_list2):
+    rouge_l_scores = []
+
+    for s1, s2 in zip(output_strs_list1, output_strs_list2):
+        lcs_len = lcs(s1, s2)
+        precision = lcs_len / len(s1) if len(s1) > 0 else 0
+        recall = lcs_len / len(s2) if len(s2) > 0 else 0
+        if precision + recall > 0:
+            fmeasure = (2 * precision * recall) / (precision + recall)
+        else:
+            fmeasure = 0.0
+        rouge_l_scores.append(fmeasure)
+
+    return rouge_l_scores
+
+
 class TestGenerationModels(unittest.TestCase):
     def assert_close_prefill_logits_and_output_strs(
@@ -35,10 +69,14 @@ class TestGenerationModels(unittest.TestCase):
         tp_size,
         torch_dtype,
         max_new_tokens,
+        prefill_tolerance,
+        rouge_threshold,
         long_context_tolerance,
     ) -> None:
+        if model_path == "Alibaba-NLP/gte-Qwen2-1.5B-instruct":
+            prompts = prompts[:-1]
         with HFRunner(
-            model_path, torch_dtype=torch_dtype, is_generation_model=True
+            model_path, torch_dtype=torch_dtype, is_generation=True
         ) as hf_runner:
             hf_outputs = hf_runner.forward(prompts, max_new_tokens=max_new_tokens)
@@ -46,7 +84,7 @@
             model_path,
             tp_size=tp_size,
             torch_dtype=torch_dtype,
-            is_generation_model=True,
+            is_generation=True,
         ) as srt_runner:
             srt_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens)
@@ -56,17 +94,34 @@
                 print("max_diff", torch.max(abs(hf_logprobs - srt_logprobs)))
 
                 if hf_logprobs.shape[0] <= 100:
-                    tolerance = 3e-2
                     assert torch.all(
-                        abs(hf_logprobs - srt_logprobs) < tolerance
+                        abs(hf_logprobs - srt_logprobs) < prefill_tolerance
                     ), "prefill logprobs are not all close"
 
             print(hf_outputs.output_strs)
             print(srt_outputs.output_strs)
-            assert hf_outputs.output_strs == srt_outputs.output_strs
+            rouge_l_scores = calculate_rouge_l(
+                hf_outputs.output_strs, srt_outputs.output_strs
+            )
+            assert all(
+                score >= rouge_threshold for score in rouge_l_scores
+            ), f"Not all ROUGE-L scores are greater than {rouge_threshold}"
 
     def test_prefill_logits_and_output_strs(self):
-        for model, tp_size, long_context_tolerance in MODELS:
+        import multiprocessing as mp
+
+        try:
+            mp.set_start_method("spawn")
+        except RuntimeError:
+            pass
+
+        for (
+            model,
+            tp_size,
+            long_context_tolerance,
+            prefill_tolerance,
+            rouge_threshold,
+        ) in MODELS:
             for torch_dtype in TORCH_DTYPES:
                 max_new_tokens = 8
                 self.assert_close_prefill_logits_and_output_strs(
@@ -75,6 +130,8 @@ class TestGenerationModels(unittest.TestCase):
                     tp_size,
                     torch_dtype,
                     max_new_tokens,
+                    prefill_tolerance=prefill_tolerance,
+                    rouge_threshold=rouge_threshold,
                     long_context_tolerance=long_context_tolerance,
                 )
......
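A worked example of the character-level ROUGE-L check introduced above; the import path is an assumption (the functions are defined in the test file shown in this hunk, under test/srt/models/):

```python
# Assumed import path, relative to test/srt/.
from models.test_generation_models import calculate_rouge_l

# "hello world" (11 chars) is a character subsequence of "hello there world" (17 chars),
# so the LCS length is 11: precision = 11/11, recall = 11/17,
# and the F-measure is 2 * 1.0 * (11/17) / (1.0 + 11/17) = 22/28 ≈ 0.786.
print(calculate_rouge_l(["hello world"], ["hello there world"]))  # ≈ [0.7857]

# The per-model rouge_threshold of 1 in MODELS therefore only passes when the
# HF and SRT output strings are identical; lower thresholds would tolerate paraphrases.
```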
@@ -5,6 +5,9 @@ from sglang.test.test_utils import run_unittest_files
 
 suites = {
     "minimal": [
+        "models/test_embedding_models.py",
+        "models/test_generation_models.py",
+        "sampling/penaltylib",
         "test_chunked_prefill.py",
         "test_embedding_openai_server.py",
         "test_eval_accuracy_mini.py",
@@ -13,11 +16,8 @@ suites = {
         "test_skip_tokenizer_init.py",
         "test_torch_compile.py",
         "test_triton_attn_backend.py",
-        "test_vision_openai_server.py",
         "test_update_weights.py",
-        "models/test_generation_models.py",
-        "models/test_embedding_models.py",
-        "sampling/penaltylib",
+        "test_vision_openai_server.py",
     ],
     "sampling/penaltylib": glob.glob(
         "sampling/penaltylib/**/test_*.py", recursive=True
......