Unverified Commit 63e97e5e authored by Arcmoon, committed by GitHub

Support Qwen model and fix several issues (#75)

parent e08bca28
@@ -316,6 +316,7 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
 - Mixtral
 - LLaVA
   - `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000`
+- Qwen
 - AWQ quantization
 ## Benchmark And Performance
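By analogy with the other entries, a Qwen server could presumably be launched the same way, e.g. `python3 -m sglang.launch_server --model-path Qwen/Qwen-7B-Chat --port 30000`; the checkpoint name here is only an assumed example, with the flags taken from the existing commands.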
@@ -61,7 +61,6 @@ class RadixAttention(nn.Module):
     def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
         o = torch.empty_like(q)
         self.store_kv_cache(k, v, input_metadata)
         extend_attention_fwd(
             q.view(-1, self.tp_q_head_num, self.head_dim),
             k.contiguous(),
@@ -55,6 +55,7 @@ class DetokenizerManager:
             first_token = self.tokenizer.convert_ids_to_tokens(
                 int(output_tokens[i][0])
             )
+            first_token = first_token.decode("utf-8")
             if first_token.startswith("▁"):
                 output_strs[i] = " " + output_strs[i]
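The added decode assumes convert_ids_to_tokens returns raw bytes here, as it does for Qwen's tiktoken-based tokenizer; a more defensive variant of the same step (the isinstance guard is an assumption, not part of the patch) would be:

first_token = self.tokenizer.convert_ids_to_tokens(int(output_tokens[i][0]))
if isinstance(first_token, bytes):
    # Qwen-style tokenizers return bytes; SentencePiece-style ones return str.
    first_token = first_token.decode("utf-8", errors="ignore")
if first_token.startswith("▁"):
    output_strs[i] = " " + output_strs[i]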
@@ -240,6 +240,7 @@ class ModelRunner:
         from sglang.srt.models.llama2 import LlamaForCausalLM
         from sglang.srt.models.llava import LlavaLlamaForCausalLM
         from sglang.srt.models.mixtral import MixtralForCausalLM
+        from sglang.srt.models.qwen import QWenLMHeadModel

         # Select model class
         architectures = getattr(self.model_config.hf_config, "architectures", [])
@@ -258,6 +259,9 @@ class ModelRunner:
             if arch == "MixtralForCausalLM":
                 model_class = MixtralForCausalLM
                 break
+            if arch == "QWenLMHeadModel":
+                model_class = QWenLMHeadModel
+                break
         if model_class is None:
             raise ValueError(f"Unsupported architectures: {architectures}")
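The if-chain grows by three lines per model; a dictionary keyed by the Hugging Face architecture string expresses the same selection more compactly. This is only a sketch of an alternative, not what the patch does:

# Sketch: map hf_config.architectures entries to SGLang model classes.
MODEL_CLASS_REGISTRY = {
    "LlamaForCausalLM": LlamaForCausalLM,
    "LlavaLlamaForCausalLM": LlavaLlamaForCausalLM,
    "MixtralForCausalLM": MixtralForCausalLM,
    "QWenLMHeadModel": QWenLMHeadModel,
}

model_class = None
for arch in architectures:
    model_class = MODEL_CLASS_REGISTRY.get(arch)
    if model_class is not None:
        break
if model_class is None:
    raise ValueError(f"Unsupported architectures: {architectures}")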
@@ -20,8 +20,10 @@ class ModelConfig:
         # Unify the config keys for hf_config
         self.context_len = get_context_length(self.hf_config)
         self.head_dim = self.hf_config.hidden_size // self.hf_config.num_attention_heads
-        self.num_key_value_heads = self.hf_config.num_key_value_heads
         self.num_attention_heads = self.hf_config.num_attention_heads
+        self.num_key_value_heads = getattr(self.hf_config, "num_key_value_heads", None)
+        if self.num_key_value_heads is None:
+            self.num_key_value_heads = self.num_attention_heads
         self.hidden_size = self.hf_config.hidden_size
         self.num_hidden_layers = self.hf_config.num_hidden_layers
         self.vocab_size = self.hf_config.vocab_size
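The fallback matters because Qwen's config does not define num_key_value_heads (it uses plain multi-head attention rather than grouped-query attention), while downstream code sizes the KV cache from that field. A rough, hypothetical sketch of the dependency:

# Hypothetical back-of-the-envelope sizing; it only works if
# num_key_value_heads is populated for every model, MHA or GQA alike.
def kv_cache_bytes_per_token(config, dtype_bytes: int = 2) -> int:
    head_dim = config.hidden_size // config.num_attention_heads
    # The factor 2 accounts for storing both K and V.
    return 2 * config.num_key_value_heads * head_dim * config.num_hidden_layers * dtype_bytes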
from typing import Any, Dict, List, Optional, Tuple

import torch
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
from torch import nn
from vllm.transformers_utils.configs.qwen import QWenConfig
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
    LinearMethodBase,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_world_size,
)
from vllm.model_executor.weight_utils import (
    default_weight_loader,
    hf_model_weights_iterator,
)
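
# Gated feed-forward block (SwiGLU-style): gate_up_proj packs the gate and up
# projections into a single column-parallel matmul, and SiluAndMul computes
# silu(gate) * up before the row-parallel down-projection c_proj.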
class QWenMLP(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str = "silu",
    ):
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            2 * [intermediate_size],
            bias=False,
            gather_output=False,
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            input_is_parallel=True,
        )
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. "
                "Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.c_proj(x)
        return x
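
# Multi-head attention with a fused QKV projection (c_attn, with bias, matching
# the original Qwen checkpoints), rotary position embeddings, and SGLang's
# RadixAttention for prefix-shared KV caching. Qwen uses plain MHA, so
# num_kv_heads equals num_heads.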
class QWenAttention(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        max_position_embeddings: int,
        layer_id: int = 0,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
        self.head_dim = hidden_size // self.total_num_heads

        # pylint: disable=invalid-name
        self.c_attn = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
        )
        self.c_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            input_is_parallel=True,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.scaling = self.head_dim**-0.5
        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_heads,
            layer_id=layer_id,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.c_attn(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, input_metadata)
        output, _ = self.c_proj(attn_output)
        return output
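
# Pre-norm transformer block: RMSNorm -> attention -> residual add, then
# RMSNorm -> MLP -> residual add. Qwen's HF config reports intermediate_size
# as the combined gate+up width, hence the // 2 when sizing QWenMLP.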
class QWenBlock(nn.Module):
    def __init__(self, config: QWenConfig, layer_id):
        super().__init__()
        self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        self.attn = QWenAttention(
            config.hidden_size,
            config.num_attention_heads,
            config.max_position_embeddings,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            layer_id=layer_id,
        )

        self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = QWenMLP(config.hidden_size, config.intermediate_size // 2)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        # Self Attention
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        hidden_states = self.attn(
            positions=positions,
            hidden_states=hidden_states,
            input_metadata=input_metadata,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states
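
# Embedding table (wte), the stack of QWenBlocks, and a final RMSNorm (ln_f).
# The vocabulary is padded up to a multiple of 64 entries so the embedding can
# be sharded evenly under tensor parallelism.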
class QWenModel(nn.Module):
    def __init__(self, config: QWenConfig):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size

        vocab_size = ((config.vocab_size + 63) // 64) * 64
        self.wte = VocabParallelEmbedding(
            vocab_size,
            config.hidden_size,
        )
        self.h = nn.ModuleList(
            [QWenBlock(config, i) for i in range(config.num_hidden_layers)]
        )
        self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.wte(input_ids)
        for i in range(len(self.h)):
            layer = self.h[i]
            hidden_states = layer(
                positions,
                hidden_states,
                input_metadata,
            )
        hidden_states = self.ln_f(hidden_states)
        return hidden_states
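
# Entry point used by SGLang's ModelRunner: wraps QWenModel, adds a parallel LM
# head, and turns hidden states into logits via LogitsProcessor. load_weights
# folds the checkpoint's w1/w2 MLP weights into the merged gate_up_proj
# parameter through stacked_params_mapping.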
class QWenLMHeadModel(nn.Module):
    def __init__(self, config: QWenConfig, linear_method=None):
        super().__init__()
        self.config = config
        self.transformer = QWenModel(config)
        vocab_size = ((config.vocab_size + 63) // 64) * 64
        self.lm_head = ParallelLMHead(
            vocab_size,
            config.hidden_size,
        )
        self.logits_processor = LogitsProcessor(config)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
    ):
        hidden_states = self.transformer(input_ids, positions, input_metadata)
        next_tokens = self.logits_processor(
            input_ids, hidden_states, self.lm_head.weight, input_metadata
        )
        return next_tokens

    _column_parallel_weights = []
    _row_parallel_weights = ["c_proj.weight"]

    def load_weights(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        load_format: str = "auto",
        revision: Optional[str] = None,
    ):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "w2", 0),
            ("gate_up_proj", "w1", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
            model_name_or_path, cache_dir, load_format, revision
        ):
            if "rotary_emb.inv_freq" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
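
The stacked-parameter handling in load_weights is easier to see on a toy input. The sketch below uses made-up checkpoint key names to show how w1/w2 entries are renamed onto the merged gate_up_proj parameter while everything else is loaded directly:

stacked_params_mapping = [
    ("gate_up_proj", "w2", 0),  # w2 fills shard 0 of the merged matrix
    ("gate_up_proj", "w1", 1),  # w1 fills shard 1
]
checkpoint_keys = [  # hypothetical names, for illustration only
    "transformer.h.0.mlp.w1.weight",
    "transformer.h.0.mlp.w2.weight",
    "transformer.h.0.mlp.c_proj.weight",
]
for name in checkpoint_keys:
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            print(name.replace(weight_name, param_name), "-> shard", shard_id)
            break
    else:
        print(name, "-> loaded directly")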
@@ -108,9 +108,11 @@ def get_exception_traceback():
 def get_int_token_logit_bias(tokenizer, vocab_size):
     from transformers import LlamaTokenizer, LlamaTokenizerFast

+    # Work around models whose vocab size is larger than tokenizer.vocab_size.
+    vocab_size = tokenizer.vocab_size
     logit_bias = np.zeros(vocab_size, dtype=np.float32)
     for t_id in range(vocab_size):
-        ss = tokenizer.decode(t_id).strip()
+        ss = tokenizer.decode([t_id]).strip()
         if not (ss.isdigit() or len(ss) == 0 or t_id == tokenizer.eos_token_id):
             logit_bias[t_id] = -1e5
     # else:
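For context, get_int_token_logit_bias builds a bias vector that pushes every token that does not decode to digits (or to the empty string, or EOS) down to a very low logit, so constrained decoding can only emit integers. A rough usage sketch, with tokenizer as an assumed name:

int_bias = get_int_token_logit_bias(tokenizer, tokenizer.vocab_size)
allowed = [t for t in range(tokenizer.vocab_size) if int_bias[t] == 0]
# `allowed` now holds the digit-like tokens, empty-string tokens, and EOS.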
@@ -214,4 +216,4 @@ def load_image(image_file):
     else:
         image = Image.open(BytesIO(base64.b64decode(image_file)))
     return image
\ No newline at end of file