Unverified Commit fad7ca73 authored by Praneth Paruchuri, committed by GitHub

model: support starcoder2 (#10609)

parent 08af8ffb
@@ -55,3 +55,4 @@ in the GitHub search bar.
| **Granite 3.0 MoE** (IBM) | `ibm-granite/granite-3.0-3b-a800m-instruct` | IBM’s Mixture-of-Experts models offering strong performance with cost-efficiency. MoE expert routing designed for enterprise deployment at scale. |
| **Llama Nemotron Super** (v1, v1.5, NVIDIA) | `nvidia/Llama-3_3-Nemotron-Super-49B-v1`, `nvidia/Llama-3_3-Nemotron-Super-49B-v1_5` | The [NVIDIA Nemotron](https://www.nvidia.com/en-us/ai-data-science/foundation-models/nemotron/) family builds on the strongest open models in the ecosystem by enhancing them with greater accuracy, efficiency, and transparency using NVIDIA open synthetic datasets, advanced techniques, and tools. This enables the creation of practical, right-sized, and high-performing AI agents. |
| **Llama Nemotron Ultra** (v1, NVIDIA) | `nvidia/Llama-3_1-Nemotron-Ultra-253B-v1` | The [NVIDIA Nemotron](https://www.nvidia.com/en-us/ai-data-science/foundation-models/nemotron/) family builds on the strongest open models in the ecosystem by enhancing them with greater accuracy, efficiency, and transparency using NVIDIA open synthetic datasets, advanced techniques, and tools. This enables the creation of practical, right-sized, and high-performing AI agents. |
| **StarCoder2** (3B-15B) | `bigcode/starcoder2-7b` | StarCoder2 is a family of open large language models (LLMs) specialized for code generation and understanding. It is the successor to StarCoder, developed by the BigCode project, a collaboration between Hugging Face, ServiceNow Research, and other contributors. |
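With this change, StarCoder2 checkpoints load like any other supported model. Below is a minimal offline-generation sketch, assuming the `sglang` offline `Engine` API; the model path and sampling parameters are illustrative:

```python
# Minimal sketch: offline generation with a StarCoder2 checkpoint via sglang.
# Assumes a GPU install of sglang; sampling parameters are illustrative only.
import sglang as sgl

llm = sgl.Engine(model_path="bigcode/starcoder2-7b")
out = llm.generate("def fibonacci(n):", {"temperature": 0.0, "max_new_tokens": 64})
print(out["text"])
llm.shutdown()
```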
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2024 BigCode and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/starcoder2.py
""" PyTorch Starcoder2 model."""
from collections.abc import Iterable
from typing import Optional, Tuple
import torch
from torch import nn
from transformers import Starcoder2Config
from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import get_act_fn
from sglang.srt.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE,
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix, make_layers


class Starcoder2Attention(nn.Module):
def __init__(
self,
config: Starcoder2Config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
layer_id: int = 0,
):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = config.num_attention_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = config.num_key_value_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
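        # Worked example with hypothetical sizes: 32 query heads and 4 KV heads
        # under tp_size=2 give 16 query heads and 2 KV heads per rank; under
        # tp_size=8 the 4 KV heads are replicated, so each rank holds
        # max(1, 4 // 8) = 1 KV head serving its 4 query heads.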
self.head_dim = self.hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = config.rope_theta
self.max_position_embeddings = config.max_position_embeddings
self.use_bias = config.use_bias
self.qkv_proj = QKVParallelLinear(
self.hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=self.use_bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
self.hidden_size,
bias=self.use_bias,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
base=int(self.rope_theta),
is_neox_style=True,
)
self.attn = RadixAttention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
quant_config=quant_config,
prefix=f"{prefix}.attn",
)

    def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.o_proj(attn_output)
return output


class Starcoder2MLP(nn.Module):
def __init__(
self,
config: Starcoder2Config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.c_fc = ColumnParallelLinear(
config.hidden_size,
config.intermediate_size,
bias=config.use_bias,
quant_config=quant_config,
prefix=f"{prefix}.c_fc",
)
self.c_proj = RowParallelLinear(
config.intermediate_size,
config.hidden_size,
bias=config.use_bias,
quant_config=quant_config,
prefix=f"{prefix}.c_proj",
)
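        # For StarCoder2 checkpoints, hidden_act is typically
        # "gelu_pytorch_tanh", i.e. a plain (non-gated) GELU MLP rather than
        # the SiLU-gated MLPs used by many recent decoder models.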
self.act = get_act_fn(config.hidden_act)

    def forward(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor:
hidden_states, _ = self.c_fc(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states, _ = self.c_proj(hidden_states)
return hidden_states


class Starcoder2DecoderLayer(nn.Module):
def __init__(
self,
config: Starcoder2Config,
layer_id: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Starcoder2Attention(
config=config,
layer_id=layer_id,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
self.mlp = Starcoder2MLP(
config, quant_config=quant_config, prefix=f"{prefix}.mlp"
)
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
self.post_attention_layernorm = nn.LayerNorm(
config.hidden_size, eps=config.norm_epsilon
)

    def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states


class Starcoder2Model(nn.Module):
def __init__(
self,
config: Starcoder2Config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.embed_tokens",
)
pp_group = get_pp_group()
pp_size = pp_group.world_size
pp_rank = pp_group.rank
self.start_layer = pp_rank * config.num_hidden_layers // pp_size
self.end_layer = (pp_rank + 1) * config.num_hidden_layers // pp_size
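        # Example with hypothetical sizes: num_hidden_layers=40 and pp_size=4
        # give rank 0 layers [0, 10), rank 1 layers [10, 20), and so on;
        # forward() only runs this rank's contiguous slice.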
self.layers = make_layers(
config.num_hidden_layers,
lambda idx, prefix: Starcoder2DecoderLayer(
config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
),
prefix=f"{prefix}.layers",
)
self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)

    def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if inputs_embeds is None:
hidden_states = self.embed_tokens(input_ids)
else:
hidden_states = inputs_embeds
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states = layer(
positions,
hidden_states,
forward_batch,
)
hidden_states = self.norm(hidden_states)
return hidden_states


class Starcoder2ForCausalLM(nn.Module):
def __init__(
self,
config: Starcoder2Config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
self.model = Starcoder2Model(
config, quant_config, prefix=add_prefix("model", prefix)
)
self.vocab_size = config.vocab_size
self.unpadded_vocab_size = config.vocab_size
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
quant_config=quant_config,
prefix=f"{prefix}.lm_head",
)
self.logits_processor = LogitsProcessor(config=config)

    def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
forward_batch=forward_batch,
inputs_embeds=inputs_embeds,
)
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
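        # Example: a checkpoint key such as
        # "model.layers.0.self_attn.q_proj.weight" is renamed to
        # "model.layers.0.self_attn.qkv_proj.weight" and loaded into the "q"
        # shard of the fused QKV parameter by the loop below.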
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freqs" in name:
continue
is_stacked = False
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name in name:
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight, shard_id)
is_stacked = True
break
if is_stacked:
continue
param = params_dict.get(name)
if param is None:
continue
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)


EntryClass = Starcoder2ForCausalLM
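For reference, the fused-QKV split used in `Starcoder2Attention.forward` can be sanity-checked in isolation. A minimal sketch in plain PyTorch; the head counts and dimensions below are hypothetical, and only the split logic mirrors the model:

```python
# Standalone sketch of the fused-QKV split from Starcoder2Attention.forward.
# All sizes are made up; only split([q_size, kv_size, kv_size]) mirrors it.
import torch

num_heads, num_kv_heads, head_dim = 8, 2, 64  # per-rank head counts (hypothetical)
q_size = num_heads * head_dim                 # 512
kv_size = num_kv_heads * head_dim             # 128
tokens = 4

qkv = torch.randn(tokens, q_size + 2 * kv_size)          # fused projection output
q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)  # same split as the model
assert q.shape == (tokens, q_size)
assert k.shape == (tokens, kv_size) and v.shape == (tokens, kv_size)
```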