"docs/vscode:/vscode.git/clone" did not exist on "4684ea2fe8d568f44c491068c3eb94aac27045f3"
Unverified Commit 22c3702e authored by Ximingwang-09's avatar Ximingwang-09 Committed by GitHub
Browse files

[Model] Support Qwen2ForSequenceClassification (#4609)


Co-authored-by: ximing.wxm <ximing.wxm@antgroup.com>
parent 4c584fc6
......@@ -54,6 +54,8 @@
- `python -m sglang.launch_server --model-path internlm/internlm2-7b-reward --is-embedding --trust-remote-code`
- Qwen2ForRewardModel
- `python -m sglang.launch_server --model-path Qwen/Qwen2.5-Math-RM-72B --is-embedding --trust-remote-code --tp-size=4`
- Qwen2ForSequenceClassification
- `python -m sglang.launch_server --model-path jason9693/Qwen2.5-1.5B-apeach --is-embedding --trust-remote-code`
## How to Support a New Language Model
To support a new model in SGLang, you only need to add a single file under [SGLang Models Directory](https://github.com/sgl-project/sglang/tree/main/python/sglang/srt/models).
......
......@@ -453,6 +453,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
or "InternLM2ForRewardModel" in model_architectures
or "Qwen2ForRewardModel" in model_architectures
or "Qwen2ForSequenceClassification" in model_architectures
):
return False
else:
......
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Iterable, Optional, Tuple
import torch
from torch import nn
from transformers import Qwen2Config
from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.qwen2 import Qwen2ForCausalLM, Qwen2Model
from sglang.srt.utils import add_prefix
class Qwen2ForSequenceClassification(nn.Module):
    """Qwen2 backbone topped with a linear classification head.

    The model is served through the embedding path: ``forward`` returns the
    pooled per-sequence class scores wrapped in an ``EmbeddingPoolerOutput``
    rather than next-token logits.
    """

    def __init__(
        self,
        config: Qwen2Config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the backbone, the scoring head and the pooler.

        Args:
            config: Qwen2 configuration; ``hidden_size``, ``num_labels`` and
                ``eos_token_id`` are read from it here.
            quant_config: Optional quantization settings, forwarded to the
                backbone only (the scoring head is a plain ``nn.Linear``).
            prefix: Prefix combined with ``"model"`` via ``add_prefix`` and
                handed to the backbone.
        """
        super().__init__()
        self.config = config
        self.quant_config = quant_config

        # Submodules are created in the order model -> score -> pooler.
        backbone_prefix = add_prefix("model", prefix)
        self.model = Qwen2Model(
            config,
            quant_config=quant_config,
            prefix=backbone_prefix,
        )
        # Maps each hidden state to one score per label.
        self.score = nn.Linear(config.hidden_size, config.num_labels)
        # Pooling type LAST, with normalization switched off.
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False)

        self.eos_token_id = config.eos_token_id

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        get_embedding: bool = True,
    ) -> EmbeddingPoolerOutput:
        """Run the backbone, score its hidden states and pool the scores.

        Args:
            input_ids: Token ids passed to the backbone.
            positions: Position tensor passed to the backbone.
            forward_batch: Batch metadata, used by both backbone and pooler.
            input_embeds: Optional embeddings passed through to the backbone.
            get_embedding: Must be truthy; this model only supports the
                embedding path.

        Returns:
            An ``EmbeddingPoolerOutput`` holding the pooled class scores.

        Raises:
            AssertionError: If ``get_embedding`` is falsy.
        """
        assert (
            get_embedding
        ), "Qwen2ForSequenceClassification is only used for embedding"

        backbone_out = self.model(input_ids, positions, forward_batch, input_embeds)
        # The head is applied to the backbone output first; pooling happens
        # afterwards on the resulting scores.
        label_scores = self.score(backbone_out)
        pooled = self.pooler(label_scores, forward_batch)
        return EmbeddingPoolerOutput(pooled.embeddings)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights, skipping any whose name starts with ``lm_head``.

        The remaining ``(name, tensor)`` pairs are loaded by reusing
        ``Qwen2ForCausalLM.load_weights`` with this instance as ``self``.

        Args:
            weights: Iterable of ``(parameter_name, tensor)`` pairs.

        Returns:
            Whatever ``Qwen2ForCausalLM.load_weights`` returns.
        """
        kept = []
        for param_name, tensor in weights:
            # This class defines no lm_head, so those entries are dropped.
            if param_name.startswith("lm_head"):
                continue
            kept.append((param_name, tensor))
        return Qwen2ForCausalLM.load_weights(self, kept)
# Model classes exported by this module.
# NOTE(review): presumably read by the model loader to register architectures - confirm.
EntryClass = [Qwen2ForSequenceClassification]
......@@ -13,18 +13,20 @@
# ==============================================================================
import multiprocessing as mp
import random
import unittest
import torch
from transformers import AutoConfig, AutoTokenizer
from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
from sglang.test.test_utils import get_similarities
from sglang.test.test_utils import get_similarities, is_in_ci
# Each entry is (model, tp_size, prefill_tolerance), the order in which
# test_prefill_logits unpacks it.
MODELS = [
("Alibaba-NLP/gte-Qwen2-1.5B-instruct", 1, 1e-5),
("intfloat/e5-mistral-7b-instruct", 1, 1e-5),
("marco/mcdse-2b-v1", 1, 1e-5),
("jason9693/Qwen2.5-1.5B-apeach", 1, 1e-5),
]
# Torch dtypes iterated for every model under test.
TORCH_DTYPES = [torch.float16]
......@@ -91,7 +93,12 @@ class TestEmbeddingModels(unittest.TestCase):
), "embeddings are not all close"
def test_prefill_logits(self):
for model, tp_size, prefill_tolerance in MODELS:
models_to_test = MODELS
if is_in_ci():
models_to_test = [random.choice(MODELS)]
for model, tp_size, prefill_tolerance in models_to_test:
for torch_dtype in TORCH_DTYPES:
self.assert_close_prefill_logits(
DEFAULT_PROMPTS, model, tp_size, torch_dtype, prefill_tolerance
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment