Unverified commit 37f15475 authored by Marc Sun, committed by GitHub

[FEAT] Add transformers backend support (#5929)

parent 8a548052
...@@ -63,6 +63,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| `kv_cache_dtype` | Dtype of the kv cache. | `auto` |
| `context_length` | The model's maximum context length. Defaults to None (will use the value from the model's config.json instead). Note that extending the default might lead to strange behavior. | None |
| `device` | The device we put the model. | None |
| `impl` | The implementation of the model to use. Defaults to the SGLang implementation and falls back to Transformers if needed. | `auto` |
| `served_model_name` | Override the model name returned by the v1/models endpoint in OpenAI API server.| None |
| `is_embedding` | Set to `true` to perform [embedding](./openai_api_embeddings.ipynb) / [encode](https://docs.sglang.ai/backend/native_api#Encode-(embedding-model)) and [reward](https://docs.sglang.ai/backend/native_api#Classify-(reward-model)) tasks. | `False` |
| `revision` | Adjust if a specific version of the model should be used. | None |
...
...@@ -47,6 +47,7 @@ The core features include:
supported_models/embedding_models.md
supported_models/reward_models.md
supported_models/support_new_models.md
supported_models/transformers_fallback.md
.. toctree::
   :maxdepth: 1
...
# Transformers fallback in SGLang
SGLang can fall back to model implementations that are available in `transformers`. This works for most decoder-style language models; support for vision-language models is coming soon!
## Example launch command
By default, SGLang uses its own implementation of a model when one is available and falls back to the Transformers implementation otherwise. You can also force the Transformers implementation by setting `impl` to `transformers`.
```shell
python3 -m sglang.launch_server \
--model-path meta-llama/Llama-3.2-1B-Instruct \
--host 0.0.0.0 \
--port 30000 \
--impl transformers
```
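Once the server is up, it behaves like any other SGLang server; the fallback only changes the model implementation. Below is a minimal client sketch, assuming the native `/generate` endpoint at the host and port used above:
```python
import requests

# Assumes the launch command above is serving on localhost:30000.
response = requests.post(
    "http://127.0.0.1:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"max_new_tokens": 16, "temperature": 0},
    },
)
print(response.json()["text"])
```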
## Supported features
### Quantization
The Transformers fallback supports most of the quantization methods available in SGLang (except GGUF). See the [Quantization page](https://docs.sglang.ai/backend/quantization.html) for more information about supported quantization in SGLang.
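For example, the tests added in this commit exercise the fallback together with torchao int4 weight-only quantization (`--impl transformers --torchao-config int4wo-128`). A sketch of the same combination through the offline engine API, assuming `sglang.Engine` forwards the `impl` and `torchao_config` options just like the server flags:
```python
import sglang as sgl

# Sketch: Transformers fallback + torchao int4 weight-only quantization.
# Assumes Engine accepts the same options as the server flags used in the tests.
llm = sgl.Engine(
    model_path="meta-llama/Llama-3.2-1B-Instruct",
    impl="transformers",
    torchao_config="int4wo-128",
)
outputs = llm.generate(["The capital of France is"], {"max_new_tokens": 16})
print(outputs[0]["text"])
llm.shutdown()
```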
### Remote code
This fallback also means that any model on the Hugging Face Hub that works with `transformers` and `trust_remote_code=True`, and that implements attention correctly, can be used in production!
A model just needs the following two things:
```python
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


class MyAttention(nn.Module):

    def forward(self, hidden_states, **kwargs):  # <- kwargs are required
        ...
        # Dispatch through the attention interface selected by the config,
        # which SGLang sets to "sglang" when the fallback is used.
        attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            **kwargs,
        )
        ...


class MyModel(PreTrainedModel):
    _supports_attention_backend = True
```
Here is what happens in the background:
1. The config is loaded.
2. The `MyModel` Python class is loaded from the `auto_map`, and we check that the model declares `_supports_attention_backend`.
3. The `TransformersForCausalLM` backend is used (see `srt/models/transformers.py`), which sets `self.config._attn_implementation = "sglang"`; this is why attention must go through `ALL_ATTENTION_FUNCTIONS`.

That's it!
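Serving such a custom Hub model then works like any other launch. A sketch using the offline engine API; `your-org/your-custom-model` is a placeholder, not a real repository, and the sketch assumes `impl` and `trust_remote_code` are forwarded the same way as the server flags:
```python
import sglang as sgl

# Placeholder repo name: substitute a Hub model that ships its own modeling code
# and sets `_supports_attention_backend = True`.
llm = sgl.Engine(
    model_path="your-org/your-custom-model",
    impl="auto",              # "auto" falls back to Transformers only when needed
    trust_remote_code=True,   # required to load custom modeling code from the Hub
)
print(llm.generate(["Hello, my name is"], {"max_new_tokens": 16})[0]["text"])
llm.shutdown()
```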
...@@ -16,7 +16,7 @@ import json
import logging
import math
import os
from enum import Enum, IntEnum, auto
from typing import List, Optional, Set, Union
import torch
...@@ -39,6 +39,12 @@ class AttentionArch(IntEnum):
MHA = auto()
class ModelImpl(str, Enum):
AUTO = "auto"
SGLANG = "sglang"
TRANSFORMERS = "transformers"
class ModelConfig:
def __init__(
self,
...@@ -53,11 +59,13 @@ class ModelConfig:
quantization: Optional[str] = None,
override_config_file: Optional[str] = None,
is_draft_model: bool = False,
impl: Union[str, ModelImpl] = ModelImpl.AUTO,
) -> None:
self.model_path = model_path
self.revision = revision
self.quantization = quantization
self.impl = impl
# Parse args
self.maybe_pull_model_tokenizer_from_remote()
...@@ -256,6 +264,7 @@ class ModelConfig:
enable_multimodal=server_args.enable_multimodal,
dtype=server_args.dtype,
quantization=server_args.quantization,
impl=server_args.impl,
**kwargs,
)
...
...@@ -2,12 +2,17 @@
"""Utilities for selecting and loading models."""
import contextlib
import logging
from typing import Tuple, Type
import torch
import transformers
from torch import nn
from transformers.dynamic_module_utils import get_class_from_dynamic_module
from sglang.srt.configs.model_config import ModelConfig, ModelImpl
logger = logging.getLogger(__name__)
@contextlib.contextmanager
...@@ -19,6 +24,61 @@ def set_default_torch_dtype(dtype: torch.dtype):
torch.set_default_dtype(old_dtype)
def resolve_transformers_arch(model_config: ModelConfig, architectures: list[str]):
for i, arch in enumerate(architectures):
if arch == "TransformersForCausalLM":
continue
auto_map: dict[str, str] = (
getattr(model_config.hf_config, "auto_map", None) or dict()
)
# Make sure that config class is always initialized before model class,
# otherwise the model class won't be able to access the config class,
# the expected auto_map should have correct order like:
# "auto_map": {
# "AutoConfig": "<your-repo-name>--<config-name>",
# "AutoModel": "<your-repo-name>--<config-name>",
# "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
# },
auto_modules = {
name: get_class_from_dynamic_module(
module, model_config.model_path, revision=model_config.revision
)
for name, module in sorted(auto_map.items(), key=lambda x: x[0])
}
model_module = getattr(transformers, arch, None)
if model_module is None:
if "AutoModel" not in auto_map:
raise ValueError(
f"Cannot find model module. '{arch}' is not a registered "
"model in the Transformers library (only relevant if the "
"model is meant to be in Transformers) and 'AutoModel' is "
"not present in the model config's 'auto_map' (relevant "
"if the model is custom)."
)
model_module = auto_modules["AutoModel"]
if model_config.impl == ModelImpl.TRANSFORMERS:
if not model_module.is_backend_compatible():
raise ValueError(
f"The Transformers implementation of {arch} is not "
"compatible with vLLM."
)
architectures[i] = "TransformersForCausalLM"
if model_config.impl == ModelImpl.AUTO:
if not model_module.is_backend_compatible():
raise ValueError(
f"{arch} has no SGlang implementation and the Transformers "
"implementation is not compatible with SGLang."
)
logger.warning(
"%s has no SGLang implementation, falling back to Transformers "
"implementation. Some features may not be supported and "
"performance may not be optimal.",
arch,
)
architectures[i] = "TransformersForCausalLM"
return architectures
def get_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
from sglang.srt.models.registry import ModelRegistry
...@@ -34,6 +94,12 @@ def get_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module],
):
architectures = ["QuantMixtralForCausalLM"]
supported_archs = ModelRegistry.get_supported_archs()
is_native_supported = any(arch in supported_archs for arch in architectures)
if not is_native_supported or model_config.impl == ModelImpl.TRANSFORMERS:
architectures = resolve_transformers_arch(model_config, architectures)
return ModelRegistry.resolve_model_cls(architectures)
...
...@@ -49,7 +49,15 @@ class _ModelRegistry:
if not architectures:
logger.warning("No model architectures are specified")
# filter out supported architectures
normalized_arch = list(
filter(lambda model: model in self.models, architectures)
)
# make sure Transformers backend is put at the last as a fallback
if len(normalized_arch) != len(architectures):
normalized_arch.append("TransformersForCausalLM")
return normalized_arch
def resolve_model_cls(
self,
...
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Adapted from
# https://github.com/vllm-project/vllm/blob/a1a2aaadb9122f05667140e39cf67e5736c8b6d6/vllm/model_executor/models/transformers.py
"""Wrapper around `transformers` models"""
import logging
import re
from typing import Iterable, Literal, Optional, Tuple, Union
import torch
from torch import nn
from transformers import AutoModel, PretrainedConfig, PreTrainedModel
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from sglang.srt.distributed import divide, get_tensor_model_parallel_world_size
from sglang.srt.layers.linear import (
ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
logger = logging.getLogger(__name__)
def maybe_prefix(prefix: str, name: str) -> str:
"""Add a prefix to a name if the prefix is non-empty.
Args:
prefix: The prefix to add. If empty, no prefix will be added.
name: The name to potentially prefix.
Returns:
The string "prefix.name" if prefix was non-empty, otherwise just "name".
"""
return name if not prefix else f"{prefix}.{name}"
def sglang_flash_attention_forward(
# Transformers args
module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: torch.Tensor,
# sglang kwargs
forward_batch: ForwardBatch,
# Transformers kwargs
scaling: float = None,
attention_instances: list[RadixAttention] = None,
**kwargs,
):
self_attn: RadixAttention = attention_instances[module.layer_idx]
if scaling is not None:
self_attn.scaling = float(scaling)
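# Transformers hands us (batch, num_heads, seq_len, head_dim) tensors (batch is 1
# here, see TransformersForCausalLM.forward); RadixAttention expects tokens
# flattened to (num_tokens, num_heads * head_dim), so transpose and reshape below.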
hidden = query.shape[-2]
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
query, key, value = (x.reshape(hidden, -1) for x in (query, key, value))
return self_attn.forward(query, key, value, forward_batch=forward_batch), None
ALL_ATTENTION_FUNCTIONS["sglang"] = sglang_flash_attention_forward
class HFColumnParallelLinear(ColumnParallelLinear):
def forward(self, input: torch.Tensor) -> torch.Tensor:
return super().forward(input)[0]
class HFRowParallelLinear(RowParallelLinear):
def forward(self, input: torch.Tensor) -> torch.Tensor:
return super().forward(input)[0]
def replace_linear_class(
linear: nn.Linear,
style: Literal["colwise", "rowwise"],
quant_config: QuantizationConfig,
) -> Union[ColumnParallelLinear, RowParallelLinear]:
"""
Replace nn.Linear with one of SGLang's tensor parallel linear classes.
Args:
linear (nn.Linear): `nn.Linear` to be replaced.
style (str): Tensor parallel style of the new linear, e.g. "colwise".
quant_config (QuantConfig): Quantization config for the new linear.
Returns:
Union[ColumnParallelLinear, RowParallelLinear]: The new linear.
"""
if not isinstance(style, str):
raise ValueError(f"Unsupported parallel style type {type(style)}, expected str")
sglang_linear_cls = {
"colwise": ColumnParallelLinear,
"rowwise": RowParallelLinear,
}.get(style, ReplicatedLinear)
class HFCompatibleLinear(sglang_linear_cls):
"""
Wrapper class that removes `output_bias` from returned output.
"""
@property
def parent_cls(self) -> type:
return sglang_linear_cls
def forward(self, input: torch.Tensor) -> torch.Tensor:
return super().forward(input)[0]
return HFCompatibleLinear(
input_size=linear.in_features,
output_size=linear.out_features,
bias=linear.bias is not None,
quant_config=quant_config,
)
class TransformersForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
logger.info("Using Transformers backend.")
self.quant_config = quant_config
self.config = config
self.vocab_size = config.vocab_size
self.unpadded_vocab_size = config.vocab_size
# model is loaded under set_default_torch_dtype(model_config.dtype)
self.model: PreTrainedModel = AutoModel.from_config(
self.config,
torch_dtype=torch.get_default_dtype(),
attn_implementation="sglang",
trust_remote_code=True,
)
# Attention modifications (assumes 1 attention op per hidden layer)
tp_size = get_tensor_model_parallel_world_size()
# MLP modifications
self.tensor_parallel(tp_size)
head_dim = (
(config.hidden_size // config.num_attention_heads)
if not hasattr(config, "head_dim")
else config.head_dim
)
self.attention_instances = [
RadixAttention(
num_heads=divide(config.num_attention_heads, tp_size),
head_dim=head_dim,
# NOTE: We use Llama scale as default, if it's set by
# Transformers, it's updated in sglang_flash_attention_forward
scaling=head_dim**-0.5,
num_kv_heads=divide(config.num_key_value_heads, tp_size),
layer_id=i,
quant_config=self.quant_config,
prefix=f"{i}.attn",
)
for i in range(config.num_hidden_layers)
]
# Model modifications
self.replace_vocab_embed_class(self.model)
# ForCausalLM modifications
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=self.quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
if config.tie_word_embeddings:
self.lm_head.weight = self.model.get_input_embeddings().weight
self.logits_processor = LogitsProcessor(config)
def log_replacement(self, name: str, old_module: nn.Module, new_module: nn.Module):
logger.debug("%s: %s -> %s", name, old_module, new_module)
def tensor_parallel(self, tp_size: int):
"""
Apply the model's tensor parallelization plan.
Currently only supports linear layers.
"""
if not self.model.supports_tp_plan:
if tp_size <= 1:
return
raise ValueError(
f"{type(self.model)} does not support tensor parallel yet!"
)
tp_plan = self.model._tp_plan
def _tensor_parallel(module: nn.Module, prefix: str = ""):
for child_name, child_module in module.named_children():
qual_name = maybe_prefix(prefix, child_name)
for pattern, style in tp_plan.items():
if re.match(pattern, qual_name) and isinstance(
child_module, nn.Linear
):
new_module = replace_linear_class(
child_module, style, self.quant_config
)
setattr(module, child_name, new_module)
self.log_replacement(qual_name, child_module, new_module)
else:
_tensor_parallel(child_module, prefix=qual_name)
_tensor_parallel(self.model)
def replace_vocab_embed_class(self, module: nn.Module):
# Use native set input embeddings
new_module = VocabParallelEmbedding(
self.vocab_size,
self.config.hidden_size,
org_num_embeddings=self.config.vocab_size,
quant_config=None,
)
self.log_replacement(
"input embedding", self.model.get_input_embeddings(), new_module
)
self.model.set_input_embeddings(new_module)
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
get_embedding: bool = False,
) -> LogitsProcessorOutput:
assert get_embedding is False, "embedding is not supported yet"
aux_hidden_states = None
hidden_states = self.model(
input_ids[None, ...],
use_cache=False,
position_ids=positions[None, ...],
forward_batch=forward_batch,
attention_instances=self.attention_instances,
return_dict=False,
)[0][
0, ...
] # we remove batch dimension for now
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if name not in params_dict:
name = f"{self.model.base_model_prefix}.{name}"
if name in params_dict:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
EntryClass = [TransformersForCausalLM]
...@@ -61,6 +61,7 @@ class ServerArgs:
is_embedding: bool = False
enable_multimodal: Optional[bool] = None
revision: Optional[str] = None
impl: str = "auto"
# Port for the HTTP server
host: str = "127.0.0.1"
...@@ -726,6 +727,18 @@ class ServerArgs:
default=ServerArgs.page_size,
help="The number of tokens in a page.",
)
parser.add_argument(
"--impl",
type=str,
default=ServerArgs.impl,
help="Which implementation of the model to use.\n\n"
'* "auto" will try to use the SGLang implementation if it exists '
"and fall back to the Transformers implementation if no SGLang "
"implementation is available.\n"
'* "sglang" will use the SGLang model implementation.\n'
'* "transformers" will use the Transformers model '
"implementation.\n",
)
# Other runtime options
parser.add_argument(
...
...@@ -455,6 +455,7 @@ class SRTRunner:
torch_dtype: torch.dtype,
model_type: str,
tp_size: int = 1,
impl: str = "auto",
port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
lora_paths: List[str] = None,
max_loras_per_batch: int = 4,
...@@ -475,6 +476,7 @@ class SRTRunner:
speculative_num_draft_tokens: Optional[int] = None,
disable_overlap_schedule: bool = False,
disable_custom_all_reduce: bool = False,
torchao_config: Optional[str] = None,
):
self.model_type = model_type
self.is_generation = model_type == "generation"
...@@ -493,6 +495,8 @@ class SRTRunner:
tp_size=tp_size,
dtype=get_dtype_str(torch_dtype),
port=port,
impl=impl,
torchao_config=torchao_config,
mem_fraction_static=mem_fraction_static,
trust_remote_code=trust_remote_code,
is_embedding=not self.is_generation,
...
import dataclasses
import multiprocessing as mp
import unittest
from types import SimpleNamespace
from typing import List
import torch
from sglang.srt.utils import kill_process_tree
from sglang.test.runners import DEFAULT_PROMPTS, SRTRunner, check_close_model_outputs
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
is_in_ci,
popen_launch_server,
)
class TestTransformersFallbackEndpoint(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=["--impl", "transformers"],
)
cls.mmlu_lower_bound = 0.65
cls.gsm8k_lower_bound = 0.65
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_mmlu(self):
args = SimpleNamespace(
base_url=self.base_url,
model=self.model,
eval_name="mmlu",
num_examples=64,
num_threads=32,
)
from sglang.test.run_eval import run_eval
metrics = run_eval(args)
self.assertGreaterEqual(metrics["score"], self.mmlu_lower_bound)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
from sglang.test.few_shot_gsm8k import run_eval
metrics = run_eval(args)
print(f"{metrics=}")
self.assertGreater(metrics["accuracy"], self.gsm8k_lower_bound)
class TestTransformersFallbackTorchAO(TestTransformersFallbackEndpoint):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--impl",
"transformers",
"--torchao-config",
"int4wo-128",
],
)
cls.mmlu_lower_bound = 0.65
cls.gsm8k_lower_bound = 0.65
@dataclasses.dataclass
class ModelCase:
model_path: str
tp_size: int = 1
prefill_tolerance: float = 5e-2
decode_tolerance: float = 5e-2
rouge_l_tolerance: float = 1
skip_long_prompt: bool = False
trust_remote_code: bool = False
torchao_config: str = None
torch_dtype: torch.dtype = torch.float16
# Popular models that run on the CI
CI_MODELS = [
ModelCase(DEFAULT_MODEL_NAME_FOR_TEST),
]
ALL_OTHER_MODELS = [
ModelCase(DEFAULT_MODEL_NAME_FOR_TEST, tp_size=2),
]
class TestTransformersFallbackEngine(CustomTestCase):
@classmethod
def setUpClass(cls):
mp.set_start_method("spawn", force=True)
def assert_close_logits_and_output_strs(
self,
prompts: List[str],
model_case: ModelCase,
) -> None:
model_path = model_case.model_path
max_new_tokens = 32
# force to use transformers impl
with SRTRunner(
model_path,
tp_size=model_case.tp_size,
torch_dtype=model_case.torch_dtype,
model_type="generation",
impl="transformers",
trust_remote_code=model_case.trust_remote_code,
torchao_config=model_case.torchao_config,
) as srt_runner:
srt_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens)
with SRTRunner(
model_path,
tp_size=model_case.tp_size,
torch_dtype=model_case.torch_dtype,
model_type="generation",
trust_remote_code=model_case.trust_remote_code,
torchao_config=model_case.torchao_config,
) as srt_runner:
srt_transformers_outputs = srt_runner.forward(
prompts, max_new_tokens=max_new_tokens
)
check_close_model_outputs(
hf_outputs=srt_transformers_outputs,
srt_outputs=srt_outputs,
prefill_tolerance=model_case.prefill_tolerance,
decode_tolerance=model_case.decode_tolerance,
rouge_l_tolerance=model_case.rouge_l_tolerance,
debug_text=f"model_path={model_path} prompts={prompts}",
)
def test_ci_models(self):
for model_case in CI_MODELS:
# Skip long prompts for models that do not have a long context
prompts = DEFAULT_PROMPTS
if model_case.skip_long_prompt:
prompts = [p for p in DEFAULT_PROMPTS if len(p) < 1000]
# Assert the logits and output strs are close
self.assert_close_logits_and_output_strs(prompts, model_case)
def test_others(self):
if is_in_ci():
return
# Skip long prompts for models that do not have a long context
prompts = DEFAULT_PROMPTS
for model_case in ALL_OTHER_MODELS:
if model_case.skip_long_prompt:
prompts = [p for p in DEFAULT_PROMPTS if len(p) < 1000]
# Assert the logits and output strs are close
self.assert_close_logits_and_output_strs(prompts, model_case)
if __name__ == "__main__":
unittest.main()
...@@ -26,6 +26,7 @@ suites = {
TestFile("models/test_qwen_models.py", 82),
TestFile("models/test_reward_models.py", 132),
TestFile("models/test_vlm_models.py", 437),
TestFile("models/test_transformers_models.py", 320),
TestFile("test_abort.py", 51), TestFile("test_abort.py", 51),
TestFile("test_block_int8.py", 22), TestFile("test_block_int8.py", 22),
TestFile("test_create_kvindices.py", 2), TestFile("test_create_kvindices.py", 2),
......