Commit 6d4b207a authored by zhuwenwen's avatar zhuwenwen
Browse files

[Model] support telechat_12B

parent a7e72c36
...@@ -11,15 +11,16 @@ vLLM是一个快速且易于使用的LLM推理和服务库,使用PageAttention ...@@ -11,15 +11,16 @@ vLLM是一个快速且易于使用的LLM推理和服务库,使用PageAttention
| 结构 | 模型 | 模型并行 | FP16 | | 结构 | 模型 | 模型并行 | FP16 |
| :------: | :------: | :------: | :------: | | :------: | :------: | :------: | :------: |
| LlamaForCausalLM | LLaMA、LLaMA-2、LLaMA-3、Codellama、deepseek、Yi | Yes | Yes | | LlamaForCausalLM | LLaMA、LLaMA-2、LLaMA-3、Codellama、deepseek、Yi | Yes | Yes |
| QWenLMHeadModel | QWen | Yes | Yes | | QWenLMHeadModel | QWen、Qwen-VL | Yes | Yes |
| Qwen2ForCausalLM | QWen1.5、CodeQwen1.5、QWen2 | Yes | Yes | | Qwen2ForCausalLM | QWen1.5、CodeQwen1.5、QWen2 | Yes | Yes |
| ChatGLMModel | chatglm2、chatglm3 | Yes | Yes | | ChatGLMModel | chatglm2、chatglm3 | Yes | Yes |
| BaiChuanForCausalLM | Baichuan、Baichuan2 | Yes | Yes | | BaiChuanForCausalLM | Baichuan、Baichuan2 | Yes | Yes |
| BloomForCausalLM     | BLOOM        | Yes | Yes | | BloomForCausalLM | BLOOM | Yes | Yes |
| InternLMForCausalLM | InternLM | Yes | Yes | | InternLMForCausalLM | InternLM | Yes | Yes |
| InternLM2ForCausalLM | InternLM2 | Yes | Yes | | InternLM2ForCausalLM | InternLM2 | Yes | Yes |
| DeepseekV2ForCausalLM | DeepSeek-V2 | Yes | Yes | | DeepseekV2ForCausalLM | DeepSeek-V2 | Yes | Yes |
| MixtralForCausalLM | Mixtral-8x7B | Yes | Yes | | MixtralForCausalLM | Mixtral-8x7B | Yes | Yes |
| TeleChat12BForCausalLM (#TelechatForCausalLM) | TeleChat-12B | Yes | Yes |
## 安装 ## 安装
......
{{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }}
{%- for message in messages -%}
{%- if message['role'] == 'user' -%}
{{- '<_user>' + message['content'] +'<_bot>' -}}
{%- elif message['role'] == 'assistant' -%}
{{- message['content'] + '<_end>' -}}
{%- endif -%}
{%- endfor -%}
...@@ -35,7 +35,7 @@ class PagedAttention: ...@@ -35,7 +35,7 @@ class PagedAttention:
@staticmethod @staticmethod
def get_supported_head_sizes() -> List[int]: def get_supported_head_sizes() -> List[int]:
return [64, 80, 96, 112, 120, 128, 192, 256] return [64, 80, 96, 112, 120, 128, 160, 192, 256]
@staticmethod @staticmethod
def get_kv_cache_shape( def get_kv_cache_shape(
......
...@@ -13,14 +13,10 @@ from vllm.distributed.parallel_state import in_the_same_node_as ...@@ -13,14 +13,10 @@ from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import cuda_device_count_stateless from vllm.utils import cuda_device_count_stateless
from vllm.utils import is_hip
try: try:
if (not is_hip()): ops.meta_size()
ops.meta_size() custom_ar = True
custom_ar = True
else:
custom_ar = False
except Exception: except Exception:
# For AMD GPUs and CPUs # For AMD GPUs and CPUs
......
...@@ -157,7 +157,7 @@ class AWQLinearMethod(LinearMethodBase): ...@@ -157,7 +157,7 @@ class AWQLinearMethod(LinearMethodBase):
zeros_and_scales = GroupQuantScaleParameter(data=torch.empty( zeros_and_scales = GroupQuantScaleParameter(data=torch.empty(
input_size_per_partition // self.quant_config.group_size, input_size_per_partition // self.quant_config.group_size,
output_size_per_partition, output_size_per_partition,
dtype=params_dtype, dtype=torch.int32,
), ),
input_dim=0, input_dim=0,
output_dim=1, output_dim=1,
......
...@@ -61,6 +61,7 @@ _GENERATION_MODELS = { ...@@ -61,6 +61,7 @@ _GENERATION_MODELS = {
"StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"), "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"), "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
"Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"), "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
"TeleChat12BForCausalLM": ("telechat_12B", "TeleChat12BForCausalLM"), # telechat12b
"SolarForCausalLM": ("solar", "SolarForCausalLM"), "SolarForCausalLM": ("solar", "SolarForCausalLM"),
"ArcticForCausalLM": ("arctic", "ArcticForCausalLM"), "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
"XverseForCausalLM": ("xverse", "XverseForCausalLM"), "XverseForCausalLM": ("xverse", "XverseForCausalLM"),
......
# coding=utf-8
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_pp_group, get_pp_indices,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.utils import is_hip, print_warning_once
from .interfaces import SupportsLoRA
class TeleChatMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
) -> None:
super().__init__()
self.gate_proj = ColumnParallelLinear(
input_size=hidden_size,
output_size=intermediate_size,
bias=False,
quant_config=quant_config)
self.up_proj = ColumnParallelLinear(
input_size=hidden_size,
output_size=intermediate_size,
bias=False,
quant_config=quant_config)
self.down_proj = RowParallelLinear(input_size=intermediate_size,
output_size=hidden_size,
# bias=bias,
bias=True,
quant_config=quant_config,
input_is_parallel=True)
self.act_fn = SiluAndMul()
def forward(self, x):
gate_output, _ = self.gate_proj(x)
up_output, _ = self.up_proj(x)
gate_output = self.act_fn(torch.cat([gate_output, up_output], dim=-1))
output, _ = self.down_proj(gate_output)
return output
class TeleChatAttention(nn.Module):
def __init__(
self,
config,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
cache_config: Optional[CacheConfig] = None,
) -> None:
super().__init__()
self.config = config
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.tp_size = tp_size
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.query = ColumnParallelLinear(
input_size=hidden_size,
output_size=hidden_size,
bias=False,
quant_config=quant_config,
gather_output=False)
kv_projection_size = self.head_dim * self.total_num_heads
self.key_value = MergedColumnParallelLinear(
input_size=hidden_size,
output_sizes=[kv_projection_size] * 2,
bias=False,
quant_config=quant_config,
gather_output=False,)
self.dense = RowParallelLinear(
input_size=hidden_size,
output_size=hidden_size,
bias=True,
quant_config=quant_config,
input_is_parallel=True,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
def split_tensor_along_last_dim(self,
tensor: torch.Tensor,
num_partitions: int,
contiguous_split_chunks: bool = False,
):
# Get the size and dimension.
last_dim = tensor.dim() - 1
last_dim_size = tensor.size()[last_dim] // num_partitions
# Split.
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
return tensor_list
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
query_layer, _ = self.query(hidden_states)
mixed_kv_layer, _ = self.key_value(hidden_states)
(key_layer, value_layer) = self.split_tensor_along_last_dim(mixed_kv_layer, 2)
query_layer, key_layer = self.rotary_emb(positions, query_layer, key_layer)
attn_output = self.attn(query_layer, key_layer, value_layer, kv_cache, attn_metadata)
output, _ = self.dense(attn_output)
return output
class TeleChatDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling is not None and getattr(
config, "original_max_position_embeddings", None):
rope_scaling["original_max_position_embeddings"] = (
config.original_max_position_embeddings)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
# Support abacusai/Smaug-72B-v0.1 with attention_bias
# Support internlm/internlm-7b with bias
attention_bias = getattr(config, "attention_bias", False) or getattr(
config, "bias", False)
self.self_attention = TeleChatAttention(
config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=getattr(config, "num_key_value_heads",
config.num_attention_heads),
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=attention_bias,
cache_config=cache_config,
)
self.mlp = TeleChatMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
bias=getattr(config, "mlp_bias", False),
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> Tuple[torch.Tensor, torch.Tensor]:
residual = hidden_states
layernorm_output = self.input_layernorm(hidden_states)
attn_outputs = self.self_attention(
positions=positions,
hidden_states=layernorm_output,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
attn_outputs = residual + attn_outputs
residual = attn_outputs
layernorm_output = self.post_attention_layernorm(attn_outputs)
output = residual + self.mlp(layernorm_output)
return output
class TeleChatModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.word_embeddings = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
)
self.h = nn.ModuleList(
[
TeleChatDecoderLayer(config=config,
cache_config=cache_config,
quant_config=quant_config)
for _ in range(self.config.num_hidden_layers)
]
)
self.ln_f = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.word_embeddings(input_ids)
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.get_input_embeddings(input_ids)
for i in range(self.config.num_hidden_layers):
layer = self.h[i]
hidden_states = layer(
positions,
hidden_states,
kv_caches[i],
attn_metadata,
)
hidden_states = self.ln_f(hidden_states)
return hidden_states
class TeleChat12BForCausalLM(nn.Module, SupportsLoRA):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
"lm_head"
]
embedding_modules = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
config.intermediate_size = config.ffn_hidden_size
config.hidden_act = "silu"
config.rms_norm_eps = config.layer_norm_epsilon
config.tie_word_embeddings = False
self.config = config
self.lora_config = lora_config
self.transformer = TeleChatModel(config,
cache_config,
quant_config,
lora_config=lora_config)
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
bias=False,
quant_config=quant_config,
)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.transformer(input_ids, positions, kv_caches,
attn_metadata)
return model_output
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
return IntermediateTensors({
"hidden_states":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
"residual":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
})
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
num_key_value_heads = self.config.num_attention_heads
head_dim = self.config.hidden_size // num_key_value_heads
for name, loaded_weight in weights:
if "self_attention.key_value" in name:
k_weight = []
v_weight = []
for i in range(num_key_value_heads):
start =i * head_dim * 2
k_weight.append(loaded_weight[start:start+head_dim,:])
v_weight.append(loaded_weight[start+head_dim:start+2*head_dim:])
k_weight = torch.cat(k_weight,dim=0)
v_weight = torch.cat(v_weight,dim=0)
loaded_weight = torch.cat([k_weight, v_weight], dim=0)
try:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
except KeyError:
print("key error")
pass
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
for layer_idx, scaling_factor in kv_cache_scales_loader(
quantization_param_path, tp_rank, tp_size,
self.config.num_hidden_layers,
self.config.__class__.model_type):
if not isinstance(self.model.layers[layer_idx], nn.Identity):
layer_self_attn = self.model.layers[layer_idx].self_attn
if is_hip():
# The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting
# scaling_factor = tensor_amax / FPtype_max
scaling_factor *= 2
if hasattr(layer_self_attn, "kv_scale"):
layer_self_attn.attn._kv_scale = scaling_factor
else:
raise RuntimeError("Self attention has no KV cache scaling "
"factor attribute!")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment