"src/array/vscode:/vscode.git/clone" did not exist on "44f0b5fe400c2ff192259bb8f93d6f7913443993"
Commit 16ff3d4b by wenhuipeng, committed via GitHub

Support opt model (#10165)

Parent commit: 83d55ac5
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Inference-only OPT model compatible with HuggingFace weights."""
import logging
from collections.abc import Iterable
from typing import Optional, Union
import torch
from torch import nn
from transformers import OPTConfig
from sglang.srt.distributed import (
get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.utils import get_layer_id
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import (
    default_weight_loader,
    kv_cache_scales_loader,
)
from sglang.srt.utils import add_prefix, get_exception_traceback, make_layers

logger = logging.getLogger(__name__)
def get_activation(name="relu"):
    """Select an activation function by name.

    Args:
        name: activation function name, one of
            ["relu", "gelu", "swish", "sigmoid"]; defaults to "relu".
            Unrecognized names fall back to nn.Identity().
    """
    name = name.lower()
    if name == "relu":
        return nn.ReLU()
    if name == "gelu":
        return nn.GELU()
    if name == "swish":
        return nn.SiLU()
    if name == "sigmoid":
        return nn.Sigmoid()
    return nn.Identity()
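# Note: HF OPTConfig defaults to activation_function="relu", so stock OPT
# checkpoints take the nn.ReLU() branch above.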
class OPTLearnedPositionalEmbedding(nn.Embedding):
def __init__(self, num_embeddings: int, embedding_dim: int):
# OPT is set up so that if padding_idx is specified then offset the
# embedding ids by 2 and adjust num_embeddings appropriately. Other
# models don't have this hack
self.offset = 2
super().__init__(num_embeddings + self.offset, embedding_dim)
def forward(self, positions: torch.Tensor):
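        # With offset=2, position 0 reads row 2 of the embedding table; HF OPT
        # checkpoints reserve the first two rows (see the note in __init__).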
return super().forward(positions + self.offset)
class OPTAttention(nn.Module):
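    """Tensor-parallel multi-head self-attention for OPT.

    The Q/K/V projections are fused into a single QKVParallelLinear; the
    output projection is a RowParallelLinear whose result is all-reduced
    across tensor-parallel ranks.
    """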
def __init__(
self,
embed_dim: int,
num_heads: int,
layer_id: int = 0,
bias: bool = True,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.embed_dim = embed_dim
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
total_num_heads = num_heads
assert num_heads % tensor_model_parallel_world_size == 0
self.num_heads = total_num_heads // tensor_model_parallel_world_size
self.head_dim = embed_dim // total_num_heads
self.scaling = self.head_dim**-0.5
self.qkv_proj = QKVParallelLinear(
embed_dim,
self.head_dim,
total_num_heads,
bias=bias,
quant_config=quant_config,
prefix=add_prefix("qkv_proj", prefix),
)
self.out_proj = RowParallelLinear(
embed_dim,
embed_dim,
bias=bias,
quant_config=quant_config,
prefix=add_prefix("o_proj", prefix),
)
self.attn = RadixAttention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_heads,
layer_id=layer_id,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)
def forward(
self,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
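        # OPT uses plain MHA (num_kv_heads == num_heads), so Q, K, and V have
        # equal per-rank widths and an even 3-way chunk is correct.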
q, k, v = qkv.chunk(chunks=3, dim=-1)
attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.out_proj(attn_output)
return output
class OPTDecoderLayer(nn.Module):
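    """One OPT transformer block: self-attention followed by a 2-layer MLP.

    config.do_layer_norm_before selects pre-LN (most OPT sizes) versus
    post-LN (OPT-350m) placement of both LayerNorms.
    """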
def __init__(
self,
config: OPTConfig,
layer_id: int = 0,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.self_attn = OPTAttention(
embed_dim=self.embed_dim,
num_heads=config.num_attention_heads,
layer_id=layer_id,
bias=config.enable_bias,
quant_config=quant_config,
prefix=add_prefix("self_attn", prefix),
)
self.do_layer_norm_before = config.do_layer_norm_before
self.self_attn_layer_norm = nn.LayerNorm(
self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine
)
self.fc1 = ColumnParallelLinear(
self.embed_dim,
config.ffn_dim,
bias=config.enable_bias,
quant_config=quant_config,
prefix=add_prefix("fc1", prefix),
)
self.activation_fn = get_activation(config.activation_function)
self.fc2 = RowParallelLinear(
config.ffn_dim,
self.embed_dim,
bias=config.enable_bias,
quant_config=quant_config,
prefix=add_prefix("fc2", prefix),
)
self.final_layer_norm = nn.LayerNorm(
self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine
)
def forward(
self,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
# Self Attention
residual = hidden_states
        # Most OPT sizes (125m, 1.3B, ..., 175B) apply layer norm BEFORE attention
if self.do_layer_norm_before:
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states = self.self_attn(
hidden_states=hidden_states, forward_batch=forward_batch
)
hidden_states = residual + hidden_states
# 350m applies layer norm AFTER attention
if not self.do_layer_norm_before:
hidden_states = self.self_attn_layer_norm(hidden_states)
# Fully Connected
residual = hidden_states
        # Most OPT sizes (125m, 1.3B, ..., 175B) apply layer norm BEFORE the MLP
if self.do_layer_norm_before:
hidden_states = self.final_layer_norm(hidden_states)
hidden_states, _ = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states, _ = self.fc2(hidden_states)
hidden_states = residual + hidden_states
        # 350m applies layer norm AFTER the MLP
if not self.do_layer_norm_before:
hidden_states = self.final_layer_norm(hidden_states)
return hidden_states
class OPTDecoder(nn.Module):
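    """OPT decoder stack: token/position embeddings, decoder layers, final norm.

    With pipeline parallelism, only the layers in [start_layer, end_layer) are
    materialized on this rank; make_layers fills the rest with placeholders.
    """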
def __init__(
self,
config: OPTConfig,
layer_id: int = 0,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
self.max_target_positions = config.max_position_embeddings
self.vocab_size = config.vocab_size
self.pp_group = get_pp_group()
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.word_embed_proj_dim,
prefix=add_prefix("embed_tokens", prefix),
)
# Positional embeddings are replicated (not sharded).
self.embed_positions = OPTLearnedPositionalEmbedding(
config.max_position_embeddings, config.hidden_size
)
        # project_in / project_out are replicated (not sharded) and exist only
        # when the word embedding dim differs from the hidden dim.
        if config.word_embed_proj_dim != config.hidden_size:
            self.project_out = ReplicatedLinear(
                config.hidden_size,
                config.word_embed_proj_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("project_out", prefix),
            )
            self.project_in = ReplicatedLinear(
                config.word_embed_proj_dim,
                config.hidden_size,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("project_in", prefix),
            )
        else:
            self.project_out = None
            self.project_in = None
# Note that the only purpose of `config._remove_final_layer_norm` is to
# keep backward compatibility with checkpoints that have been fine-tuned
# before transformers v4.20.1
# see https://github.com/facebookresearch/metaseq/pull/164
if config.do_layer_norm_before and not config._remove_final_layer_norm:
self.final_layer_norm = nn.LayerNorm(
config.hidden_size,
elementwise_affine=config.layer_norm_elementwise_affine,
)
else:
self.final_layer_norm = None
self.layers, self.start_layer, self.end_layer = make_layers(
config.num_hidden_layers,
lambda idx, prefix: OPTDecoderLayer(
config=config, layer_id=idx, quant_config=quant_config, prefix=prefix
),
pp_rank=self.pp_group.rank_in_group,
pp_size=self.pp_group.world_size,
prefix="model.layers",
)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
input_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, PPProxyTensors]:
if self.pp_group.is_first_rank:
if input_embeds is None:
input_embeds = self.embed_tokens(input_ids)
pos_embeds = self.embed_positions(positions)
if self.project_in is not None:
input_embeds, _ = self.project_in(input_embeds)
hidden_states = input_embeds + pos_embeds
else:
assert pp_proxy_tensors is not None
hidden_states = pp_proxy_tensors["hidden_states"]
for layer in self.layers[self.start_layer : self.end_layer]:
hidden_states = layer(
hidden_states=hidden_states, forward_batch=forward_batch
)
if not self.pp_group.is_last_rank:
return PPProxyTensors({"hidden_states": hidden_states})
if self.final_layer_norm is not None:
hidden_states = self.final_layer_norm(hidden_states)
        # Not taken for standard OPT sizes (project_out is None when
        # word_embed_proj_dim == hidden_size).
if self.project_out is not None:
hidden_states, _ = self.project_out(hidden_states)
return hidden_states
class OPTModel(nn.Module):
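    """Thin wrapper exposing the OPT decoder through SGLang's model interface."""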
def __init__(
self,
config: OPTConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.pp_group = get_pp_group()
self.decoder = OPTDecoder(
config=config,
quant_config=quant_config,
prefix=add_prefix("decoder", prefix),
)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
pp_proxy_tensors: Optional[PPProxyTensors],
input_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, PPProxyTensors]:
return self.decoder(
input_ids,
positions,
pp_proxy_tensors=pp_proxy_tensors,
input_embeds=input_embeds,
forward_batch=forward_batch,
)
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
for layer_idx, scaling_factor in kv_cache_scales_loader(
quantization_param_path,
tp_rank,
tp_size,
self.config.num_hidden_layers,
self.config.__class__.model_type,
):
if not isinstance(self.decoder.layers[layer_idx], nn.Identity):
layer_self_attn = self.decoder.layers[layer_idx].self_attn
if hasattr(layer_self_attn.attn, "k_scale"):
layer_self_attn.attn.k_scale = scaling_factor
layer_self_attn.attn.v_scale = scaling_factor
else:
                raise RuntimeError(
                    "Self attention has no KV cache scaling factor attribute!"
                )
class OPTForCausalLM(nn.Module):
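    """OPT for causal language modeling: decoder stack plus a (possibly tied) LM head."""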
    # BitsAndBytes-specific attributes:
    # in TP, these weights are partitioned along the column dimension (dim=-1)
    column_parallel_weights_modules = [".fc2.", ".out_proj."]
def __init__(
self,
config: OPTConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = OPTModel(
config=config, quant_config=quant_config, prefix=add_prefix("model", prefix)
)
if self.config.tie_word_embeddings:
self.lm_head = self.model.decoder.embed_tokens
else:
self.lm_head = ParallelLMHead(
config.vocab_size,
config.word_embed_proj_dim,
prefix=add_prefix("lm_head", prefix),
)
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
self.capture_aux_hidden_states = False
self.pp_group = get_pp_group()
self.stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
]
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
input_embeds: Optional[torch.Tensor] = None,
get_embedding: bool = False,
) -> LogitsProcessorOutput:
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
forward_batch=forward_batch,
input_embeds=input_embeds,
pp_proxy_tensors=pp_proxy_tensors,
)
aux_hidden_states = None
if self.capture_aux_hidden_states:
hidden_states, aux_hidden_states = hidden_states
if self.pp_group.is_last_rank:
if not get_embedding:
return self.logits_processor(
input_ids,
hidden_states,
self.lm_head,
forward_batch,
aux_hidden_states=aux_hidden_states,
)
else:
return self.pooler(hidden_states, forward_batch)
else:
return hidden_states
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> None:
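        """Load HF checkpoint tensors, fusing q/k/v into the stacked qkv_proj
        and skipping layers owned by other pipeline-parallel ranks."""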
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
if name.startswith("decoder"):
name = name.replace("decoder.", "model.decoder.")
layer_id = get_layer_id(name)
if (
layer_id is not None
and hasattr(self.model, "start_layer")
and (
layer_id < self.model.start_layer
or layer_id >= self.model.end_layer
)
):
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
                if name not in params_dict:
                    logger.warning("Parameter %s not found in params_dict", name)
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
@property
def start_layer(self):
return self.model.start_layer
@property
def end_layer(self):
return self.model.end_layer
    def get_input_embeddings(self) -> nn.Module:
        return self.model.decoder.embed_tokens
    def get_module_name_from_weight_name(self, name):
        # Local mapping with shard counts: qkv_proj fuses 3 checkpoint tensors.
        stacked_params_mapping = [
            # (param_name, weight_name, shard_id, num_shard)
            ("qkv_proj", "q_proj", "q", 3),
            ("qkv_proj", "k_proj", "k", 3),
            ("qkv_proj", "v_proj", "v", 3),
        ]
        for param_name, weight_name, shard_id, num_shard in stacked_params_mapping:
            if weight_name in name:
                return (
                    name.replace(weight_name, param_name)[: -len(".weight")],
                    num_shard,
                )
        return name[: -len(".weight")], 1
def get_num_params(self):
params_dict = dict(self.named_parameters())
return len(params_dict)
def get_weights_by_name(
self, name: str, truncate_size: int = 100, tp_size: int = 1
) -> Optional[torch.Tensor]:
"""Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face.
Only used for unit test with an unoptimized performance.
For optimized performance, please use torch.save and torch.load.
"""
try:
if name == "lm_head.weight" and self.config.tie_word_embeddings:
logger.info(
"word embedding is tied for this model, return embed_tokens.weight as lm_head.weight."
)
return (
                    self.model.decoder.embed_tokens.weight.cpu()
.to(torch.float32)
.numpy()
.tolist()[:truncate_size]
)
mapped_name = name
mapped_shard_id = None
for param_name, weight_name, shard_id in self.stacked_params_mapping:
if weight_name in name:
mapped_name = name.replace(weight_name, param_name)
mapped_shard_id = shard_id
break
params_dict = dict(self.named_parameters())
param = params_dict[mapped_name]
if mapped_shard_id is not None:
if mapped_shard_id in ["q", "k", "v"]:
num_heads = self.config.num_attention_heads // tp_size
num_kv_heads = self.config.num_attention_heads // tp_size
head_dim = (
self.config.hidden_size // self.config.num_attention_heads
)
if mapped_shard_id == "q":
offset = 0
size = num_heads * head_dim
elif mapped_shard_id == "k":
offset = num_heads * head_dim
size = num_kv_heads * head_dim
elif mapped_shard_id == "v":
offset = (num_heads + num_kv_heads) * head_dim
size = num_kv_heads * head_dim
weight = param.data.narrow(0, offset, size)
                elif mapped_shard_id in [0, 1]:  # merged gate/up; unused for OPT's q/k/v-only mapping
intermediate_size = self.config.ffn_dim
slice_size = intermediate_size // tp_size
if mapped_shard_id == 0: # gate_proj
offset = 0
size = slice_size
elif mapped_shard_id == 1: # up_proj
offset = slice_size
size = slice_size
weight = param.data.narrow(0, offset, size)
else:
weight = param.data
else:
weight = param.data
if tp_size > 1 and ("o_proj" in name or "down_proj" in name):
gathered_weights = [torch.zeros_like(weight) for _ in range(tp_size)]
torch.distributed.all_gather(gathered_weights, weight)
weight = torch.cat(gathered_weights, dim=1)
return weight.cpu().to(torch.float32).numpy().tolist()[:truncate_size]
except Exception:
logger.error(
f"Error getting weights by name {name} in OPTForCausalLM: {get_exception_traceback()}"
)
return None
    def get_embed_and_head(self):
        return self.model.decoder.embed_tokens.weight, self.lm_head.weight

    def set_embed_and_head(self, embed, head):
        del self.model.decoder.embed_tokens.weight
        del self.lm_head.weight
        self.model.decoder.embed_tokens.weight = embed
        self.lm_head.weight = head
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    def get_embed(self):
        return self.model.decoder.embed_tokens.weight
def set_embed(self, embed):
# NOTE: If draft hidden size != target hidden size, the embed weight cannot be shared for EAGLE3
if (
hasattr(self.config, "target_hidden_size")
and self.config.target_hidden_size != self.config.hidden_size
):
return
        del self.model.decoder.embed_tokens.weight
        self.model.decoder.embed_tokens.weight = embed
torch.cuda.empty_cache()
torch.cuda.synchronize()
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
self.model.load_kv_cache_scales(quantization_param_path)
EntryClass = [OPTForCausalLM]
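# Quick smoke test for this entry class (a minimal sketch, not part of the
# commit; assumes a CUDA machine with SGLang installed and uses SGLang's
# offline `Engine` API):
#
#     import sglang as sgl
#
#     llm = sgl.Engine(model_path="facebook/opt-125m")
#     out = llm.generate(
#         "The capital of France is",
#         {"temperature": 0.0, "max_new_tokens": 8},
#     )
#     print(out["text"])
#     llm.shutdown()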
@@ -77,6 +77,7 @@ ALL_MODELS = [
         trust_remote_code=True,
         skip_long_prompt=True,
     ),
+    ModelCase("facebook/opt-125m", skip_long_prompt=True),
     ModelCase(
         "nvidia/Llama-3_3-Nemotron-Super-49B-v1_5",
         tp_size=2,
...