Unverified Commit 94e05770 authored by Lianmin Zheng, committed by GitHub

Fix after QWen support (#82)

parent 63e97e5e
@@ -168,7 +168,10 @@ def match_llama2_chat(model_path: str):
 @register_chat_template_matching_function
 def match_chat_ml(model_path: str):
-    if "tinyllama" in model_path.lower():
+    model_path = model_path.lower()
+    if "tinyllama" in model_path:
+        return get_chat_template("chatml")
+    if "qwen" in model_path and "chat" in model_path:
         return get_chat_template("chatml")
...
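Note: the helper below is not part of sglang; it is a minimal standalone sketch that only mirrors the matching rule added in this hunk. After lowercasing the model path, both TinyLlama and Qwen chat checkpoints now resolve to the ChatML template via `get_chat_template("chatml")`.

```python
# Illustrative sketch only; the function name is hypothetical and simply
# mirrors the updated match_chat_ml logic shown above.
def resolves_to_chatml(model_path: str) -> bool:
    model_path = model_path.lower()
    if "tinyllama" in model_path:
        return True
    # New in this commit: Qwen chat checkpoints also map to ChatML.
    if "qwen" in model_path and "chat" in model_path:
        return True
    return False


assert resolves_to_chatml("TinyLlama/TinyLlama-1.1B-Chat-v0.3")
assert resolves_to_chatml("Qwen/Qwen-7B-Chat")
assert not resolves_to_chatml("Qwen/Qwen-7B")  # base model: not matched here
```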
@@ -55,6 +55,7 @@ class DetokenizerManager:
                 first_token = self.tokenizer.convert_ids_to_tokens(
                     int(output_tokens[i][0])
                 )
-                first_token = first_token.decode("utf-8")
+                if not isinstance(first_token, str):
+                    first_token = first_token.decode("utf-8")
                 if first_token.startswith("▁"):
                     output_strs[i] = " " + output_strs[i]
...
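The added isinstance guard reflects that `convert_ids_to_tokens` can return either `str` or `bytes` depending on the tokenizer; the previous code assumed bytes and decoded unconditionally. A minimal sketch of the normalized behavior, with a hypothetical helper name:

```python
# Hypothetical helper that mirrors the guarded decode above.
def token_to_str(first_token) -> str:
    # Some tokenizers hand back bytes, others already return str;
    # only decode when we actually received bytes.
    if not isinstance(first_token, str):
        first_token = first_token.decode("utf-8")
    return first_token


assert token_to_str("▁Hello") == "▁Hello"                # str passes through
assert token_to_str("Hello".encode("utf-8")) == "Hello"  # bytes get decoded
```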
@@ -5,7 +5,6 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.router.model_runner import InputMetadata
 from torch import nn
-from vllm.transformers_utils.configs.qwen import QWenConfig
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
@@ -26,9 +25,10 @@ from vllm.model_executor.weight_utils import (
     default_weight_loader,
     hf_model_weights_iterator,
 )
+from vllm.transformers_utils.configs.qwen import QWenConfig


 class QWenMLP(nn.Module):
     def __init__(
         self,
         hidden_size: int,
@@ -49,8 +49,10 @@ class QWenMLP(nn.Module):
             input_is_parallel=True,
         )
         if hidden_act != "silu":
-            raise ValueError(f"Unsupported activation: {hidden_act}. "
-                             "Only silu is supported for now.")
+            raise ValueError(
+                f"Unsupported activation: {hidden_act}. "
+                "Only silu is supported for now."
+            )
         self.act_fn = SiluAndMul()

     def forward(self, x):
@@ -59,31 +61,28 @@ class QWenMLP(nn.Module):
         x, _ = self.c_proj(x)
         return x


 class QWenAttention(nn.Module):
-    def __init__(self,
+    def __init__(
+        self,
         hidden_size: int,
         num_heads: int,
         max_position_embeddings: int,
         layer_id: int = 0,
         rope_theta: float = 10000,
-        rope_scaling: Optional[Dict[str, Any]] = None):
+        rope_scaling: Optional[Dict[str, Any]] = None,
+    ):
         super().__init__()
         self.hidden_size = hidden_size
-        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
-        )
+        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
         self.total_num_heads = num_heads
         assert self.total_num_heads % tensor_model_parallel_world_size == 0
-        self.num_heads = (self.total_num_heads //
-                          tensor_model_parallel_world_size)
+        self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
         self.head_dim = hidden_size // self.total_num_heads
         # pylint: disable=invalid-name
         self.c_attn = QKVParallelLinear(
-            hidden_size,
-            self.head_dim,
-            self.total_num_heads,
-            bias=True
+            hidden_size, self.head_dim, self.total_num_heads, bias=True
         )
         self.c_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
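The reformatted constructor keeps the same head-partitioning arithmetic: the total head count must divide evenly by the tensor-parallel world size, each rank keeps `total_num_heads // world_size` heads, and `head_dim` is computed from the full hidden size. A worked sketch with assumed Qwen-7B-like numbers (4096 hidden size, 32 heads) and a hypothetical world size of 4:

```python
# Worked example of the head-partitioning arithmetic above.
# hidden_size and total_num_heads are assumed Qwen-7B-like values;
# the tensor-parallel world size of 4 is hypothetical.
hidden_size = 4096
total_num_heads = 32
world_size = 4

assert total_num_heads % world_size == 0
num_heads_per_rank = total_num_heads // world_size  # 8 heads on each rank
head_dim = hidden_size // total_num_heads           # 128, from the full model width

print(num_heads_per_rank, head_dim)  # 8 128
```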
@@ -120,20 +119,22 @@ class QWenAttention(nn.Module):
         output, _ = self.c_proj(attn_output)
         return output


 class QWenBlock(nn.Module):
-    def __init__(self, config: QWenConfig,layer_id):
+    def __init__(self, config: QWenConfig, layer_id):
         super().__init__()
         self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
-        self.attn = QWenAttention(config.hidden_size,
+        self.attn = QWenAttention(
+            config.hidden_size,
             config.num_attention_heads,
             config.max_position_embeddings,
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
-            layer_id=layer_id)
+            layer_id=layer_id,
+        )

         self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
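Worth noting in this block: `rope_theta` and `rope_scaling` are read with `getattr` defaults, so checkpoints whose config omits those fields still construct cleanly. A small sketch of that defaulting pattern, using a hypothetical bare config object:

```python
from types import SimpleNamespace

# Hypothetical minimal config that lacks the optional RoPE fields.
config = SimpleNamespace(hidden_size=4096, num_attention_heads=32)

# Same defaulting pattern as in the diff: fall back when a field is absent.
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)

print(rope_theta, rope_scaling)  # 10000 None
```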
@@ -162,9 +163,9 @@ class QWenBlock(nn.Module):
         hidden_states = residual + hidden_states
         return hidden_states


 class QWenModel(nn.Module):
-    def __init__(self, config:QWenConfig):
+    def __init__(self, config: QWenConfig):
         super().__init__()
         self.config = config
         self.vocab_size = config.vocab_size
@@ -175,7 +176,8 @@ class QWenModel(nn.Module):
             config.hidden_size,
         )
         self.h = nn.ModuleList(
-            [QWenBlock(config, i) for i in range(config.num_hidden_layers)])
+            [QWenBlock(config, i) for i in range(config.num_hidden_layers)]
+        )
         self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

     def forward(
@@ -195,26 +197,23 @@ class QWenModel(nn.Module):
         hidden_states = self.ln_f(hidden_states)
         return hidden_states


 class QWenLMHeadModel(nn.Module):
-    def __init__(self, config: QWenConfig,linear_method=None):
+    def __init__(self, config: QWenConfig, linear_method=None):
         super().__init__()
         self.config = config
         self.transformer = QWenModel(config)
         vocab_size = ((config.vocab_size + 63) // 64) * 64
-        self.lm_head = ParallelLMHead(
-            vocab_size,
-            config.hidden_size
-        )
+        self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)

     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata
+        input_metadata: InputMetadata,
     ):
-        hidden_states = self.transformer(input_ids, positions,input_metadata)
+        hidden_states = self.transformer(input_ids, positions, input_metadata)
         next_tokens = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
...
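One detail in the last hunk: the LM head is sized to `((config.vocab_size + 63) // 64) * 64`, i.e. the vocabulary is padded up to the next multiple of 64 before `ParallelLMHead` is built. A quick sketch of that rounding (the helper name is hypothetical; Qwen's vocab size of 151936 happens to be aligned already):

```python
# Pad-to-multiple-of-64 rounding used for the LM head size above.
def pad_vocab_size(vocab_size: int, alignment: int = 64) -> int:
    # Equivalent to ((vocab_size + 63) // 64) * 64 when alignment == 64.
    return ((vocab_size + alignment - 1) // alignment) * alignment


assert pad_vocab_size(151936) == 151936  # Qwen: already a multiple of 64
assert pad_vocab_size(32000) == 32000    # LLaMA-style vocab: also aligned
assert pad_vocab_size(32001) == 32064    # otherwise rounded up to the next multiple
```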