Unverified Commit 22c3702e authored by Ximingwang-09's avatar Ximingwang-09 Committed by GitHub
Browse files

[Model] Support Qwen2ForSequenceClassification (#4609)


Co-authored-by: default avatarximing.wxm <ximing.wxm@antgroup.com>
parent 4c584fc6
...@@ -54,6 +54,8 @@ ...@@ -54,6 +54,8 @@
- `python -m sglang.launch_server --model-path internlm/internlm2-7b-reward --is-embedding --trust-remote-code` - `python -m sglang.launch_server --model-path internlm/internlm2-7b-reward --is-embedding --trust-remote-code`
- Qwen2ForRewardModel - Qwen2ForRewardModel
- `python -m sglang.launch_server --model-path Qwen/Qwen2.5-Math-RM-72B --is-embedding --trust-remote-code --tp-size=4` - `python -m sglang.launch_server --model-path Qwen/Qwen2.5-Math-RM-72B --is-embedding --trust-remote-code --tp-size=4`
- Qwen2ForSequenceClassification
- `python -m sglang.launch_server --model-path jason9693/Qwen2.5-1.5B-apeach --is-embedding --trust-remote-code`
## How to Support a New Language Model ## How to Support a New Language Model
To support a new model in SGLang, you only need to add a single file under [SGLang Models Directory](https://github.com/sgl-project/sglang/tree/main/python/sglang/srt/models). To support a new model in SGLang, you only need to add a single file under [SGLang Models Directory](https://github.com/sgl-project/sglang/tree/main/python/sglang/srt/models).
......
...@@ -453,6 +453,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal ...@@ -453,6 +453,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
or "InternLM2ForRewardModel" in model_architectures or "InternLM2ForRewardModel" in model_architectures
or "Qwen2ForRewardModel" in model_architectures or "Qwen2ForRewardModel" in model_architectures
or "Qwen2ForSequenceClassification" in model_architectures
): ):
return False return False
else: else:
......
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Iterable, Optional, Tuple
import torch
from torch import nn
from transformers import Qwen2Config
from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.qwen2 import Qwen2ForCausalLM, Qwen2Model
from sglang.srt.utils import add_prefix
class Qwen2ForSequenceClassification(nn.Module):
def __init__(
self,
config: Qwen2Config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = Qwen2Model(
config, quant_config=quant_config, prefix=add_prefix("model", prefix)
)
self.score = nn.Linear(config.hidden_size, config.num_labels)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False)
self.eos_token_id = config.eos_token_id
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
get_embedding: bool = True,
) -> EmbeddingPoolerOutput:
assert (
get_embedding
), "Qwen2ForSequenceClassification is only used for embedding"
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
logits = self.score(hidden_states)
pooled_logits = self.pooler(logits, forward_batch).embeddings
return EmbeddingPoolerOutput(pooled_logits)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# Filter out lm_head weights of Qwen2ForCausalLM
filtered_weights = [
(name, w) for name, w in weights if not name.startswith("lm_head")
]
return Qwen2ForCausalLM.load_weights(self, filtered_weights)
EntryClass = [
Qwen2ForSequenceClassification,
]
...@@ -13,18 +13,20 @@ ...@@ -13,18 +13,20 @@
# ============================================================================== # ==============================================================================
import multiprocessing as mp import multiprocessing as mp
import random
import unittest import unittest
import torch import torch
from transformers import AutoConfig, AutoTokenizer from transformers import AutoConfig, AutoTokenizer
from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
from sglang.test.test_utils import get_similarities from sglang.test.test_utils import get_similarities, is_in_ci
MODELS = [ MODELS = [
("Alibaba-NLP/gte-Qwen2-1.5B-instruct", 1, 1e-5), ("Alibaba-NLP/gte-Qwen2-1.5B-instruct", 1, 1e-5),
("intfloat/e5-mistral-7b-instruct", 1, 1e-5), ("intfloat/e5-mistral-7b-instruct", 1, 1e-5),
("marco/mcdse-2b-v1", 1, 1e-5), ("marco/mcdse-2b-v1", 1, 1e-5),
("jason9693/Qwen2.5-1.5B-apeach", 1, 1e-5),
] ]
TORCH_DTYPES = [torch.float16] TORCH_DTYPES = [torch.float16]
...@@ -91,7 +93,12 @@ class TestEmbeddingModels(unittest.TestCase): ...@@ -91,7 +93,12 @@ class TestEmbeddingModels(unittest.TestCase):
), "embeddings are not all close" ), "embeddings are not all close"
def test_prefill_logits(self): def test_prefill_logits(self):
for model, tp_size, prefill_tolerance in MODELS: models_to_test = MODELS
if is_in_ci():
models_to_test = [random.choice(MODELS)]
for model, tp_size, prefill_tolerance in models_to_test:
for torch_dtype in TORCH_DTYPES: for torch_dtype in TORCH_DTYPES:
self.assert_close_prefill_logits( self.assert_close_prefill_logits(
DEFAULT_PROMPTS, model, tp_size, torch_dtype, prefill_tolerance DEFAULT_PROMPTS, model, tp_size, torch_dtype, prefill_tolerance
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment