Commit 1fa15099 authored by Lianmin Zheng's avatar Lianmin Zheng
Browse files

Add LlamaForClassification (#559)

parent 303ef888
...@@ -36,7 +36,7 @@ pip install "sglang[all]" ...@@ -36,7 +36,7 @@ pip install "sglang[all]"
### Method 2: From source ### Method 2: From source
``` ```
git clone git@github.com:sgl-project/sglang.git git clone https://github.com/sgl-project/sglang.git
cd sglang cd sglang
pip install --upgrade pip pip install --upgrade pip
......
from typing import Iterable, Optional, Tuple
import torch
import tqdm
from torch import nn
from transformers import LlamaConfig
from vllm.config import CacheConfig
from vllm.distributed import (
get_tensor_model_parallel_rank,
)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.managers.controller.model_runner import InputMetadata
from sglang.srt.layers.logits_processor import LogitProcessorOutput
from sglang.srt.models.llama2 import LlamaModel
class LlamaForClassification(nn.Module):
def __init__(
self,
config: LlamaConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config: Optional[CacheConfig] = None,
) -> None:
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = LlamaModel(config, quant_config=quant_config)
self.classification_head = nn.Linear(config.hidden_size, config.classification_out_size)
self.eos_token_id = config.eos_token_id
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
is_eos_token = input_ids == self.eos_token_id
hidden_states = hidden_states[is_eos_token]
scores = self.classification_head(hidden_states)
if scores.shape[0] != input_metadata.batch_size:
print("Warning: the EOS tokens are missing in some sentences.")
scores = torch.ones((input_metadata.batch_size, self.config.classification_out_size)).to(input_ids.device)
return LogitProcessorOutput(
next_token_logits=scores,
next_token_logprobs=scores,
normalized_prompt_logprobs=scores,
prefill_token_logprobs=torch.ones_like(input_ids),
prefill_top_logprobs=None,
decode_top_logprobs=None,
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
if get_tensor_model_parallel_rank() == 0:
weights = tqdm.tqdm(weights, total=int(len(params_dict) * 1.5))
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name or "projector" in name:
continue
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
if "lm_head" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if name.startswith("model.vision_tower") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if name.startswith("model.vision_tower") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
EntryClass = LlamaForClassification
\ No newline at end of file
"""
Usage:
python3 test_httpserver_classify.py
"""
import argparse
import numpy as np
import requests
def get_logits(url, prompt):
response = requests.post(
url + "/generate",
json={
"text": prompt,
"sampling_params": {
"max_new_tokens": 0,
},
"return_logprob": True,
},
)
return response.json()["meta_info"]["normalized_prompt_logprob"]
def get_logits_batch(url, prompts):
response = requests.post(
url + "/generate",
json={
"text": prompts,
"sampling_params": {
"max_new_tokens": 0,
},
"return_logprob": True,
},
)
ret = response.json()
logits = np.array(list(
ret[i]["meta_info"]["normalized_prompt_logprob"]
for i in range(len(prompts))
))
return logits
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="http://127.0.0.1")
parser.add_argument("--port", type=int, default=30000)
args = parser.parse_args()
url = f"{args.host}:{args.port}"
# A single request
prompt = "This is a test prompt.<|eot_id|>"
logits = get_logits(url, prompt)
print(f"{logits=}")
# A batch of requests
prompts = [
"This is a test prompt.<|eot_id|>",
"This is another test prompt.<|eot_id|>",
"This is a long long long long test prompt.<|eot_id|>",
]
logits = get_logits_batch(url, prompts)
print(f"{logits=}")
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment