Unverified Commit 0de7c2d0 authored by Ying Sheng, committed by GitHub

Add e5-mistral modules [unreachable code] - step 1/3 (#983)

parent 6ed4e3b8
# adapted from
# https://github.com/vllm-project/vllm/blob/82a1b1a82b1fbb454c82a9ef95730b929c9b270c/vllm/model_executor/layers/pooler.py
from dataclasses import dataclass
from enum import IntEnum

import torch
import torch.nn as nn

from sglang.srt.model_executor.model_runner import InputMetadata


class PoolingType(IntEnum):
    LAST = 0


@dataclass
class EmbeddingPoolerOutput:
    embeddings: torch.Tensor


class Pooler(nn.Module):
    """A layer that pools specific information from hidden states.

    This layer does the following:
    1. Extracts specific tokens or aggregates data based on the pooling method.
    2. Normalizes the output if specified.
    3. Returns structured results as `EmbeddingPoolerOutput`.

    Attributes:
        pooling_type: The type of pooling to use (currently only LAST is implemented).
        normalize: Whether to L2-normalize the pooled data.
    """

    def __init__(self, pooling_type: PoolingType, normalize: bool):
        super().__init__()
        self.pooling_type = pooling_type
        self.normalize = normalize

    def forward(
        self, hidden_states: torch.Tensor, input_metadata: InputMetadata
    ) -> EmbeddingPoolerOutput:
        if self.pooling_type == PoolingType.LAST:
            # The last token of each request sits at (cumulative length - 1)
            # in the packed hidden_states tensor.
            last_token_indices = torch.cumsum(input_metadata.extend_seq_lens, dim=0) - 1
            pooled_data = hidden_states[last_token_indices]
        else:
            raise ValueError(f"Invalid pooling type: {self.pooling_type}")

        if self.normalize:
            pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1)

        return EmbeddingPoolerOutput(embeddings=pooled_data)
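To make the last-token pooling concrete, here is a minimal usage sketch (not part of the commit). It stands in for `InputMetadata` with a plain `SimpleNamespace` carrying only `extend_seq_lens`, which is an assumption for illustration; in SGLang the real object is built by the model runner. The cumulative sum of the per-request lengths, minus one, picks out the last token of each request from the packed `hidden_states` tensor.

# --- usage sketch (illustration only, not part of the commit) ---
from types import SimpleNamespace

import torch

from sglang.srt.layers.pooler import Pooler, PoolingType

# Two requests of lengths 3 and 2, packed into a single [5, hidden_size] tensor.
hidden_states = torch.randn(5, 8)
# Stand-in for InputMetadata: only the field the pooler reads.
fake_metadata = SimpleNamespace(extend_seq_lens=torch.tensor([3, 2]))

pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
out = pooler(hidden_states, fake_metadata)

# cumsum([3, 2]) - 1 == [2, 4]: rows 2 and 4 are the last tokens of the two requests.
assert out.embeddings.shape == (2, 8)
# normalize=True makes each embedding unit-length.
assert torch.allclose(out.embeddings.norm(dim=1), torch.ones(2))

The second file added by this commit, shown next, wraps `LlamaModel` with this pooler so the forward pass returns embeddings instead of logits.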
from typing import Iterable, Optional, Tuple

import torch
from torch import nn
from transformers import LlamaConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
from sglang.srt.model_executor.model_runner import InputMetadata
from sglang.srt.models.llama2 import LlamaForCausalLM, LlamaModel


class LlamaEmbeddingModel(nn.Module):
    def __init__(
        self,
        config: LlamaConfig,
        quant_config=None,
        cache_config=None,
        efficient_weight_load=False,
    ) -> None:
        super().__init__()
        self.model = LlamaModel(config, quant_config=quant_config)
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        input_embeds: torch.Tensor = None,
    ) -> EmbeddingPoolerOutput:
        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
        return self.pooler(hidden_states, input_metadata)

    def load_weights(
        self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None
    ):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.model.named_parameters())

        def load_weights_per_param(name, loaded_weight):
            if "rotary_emb.inv_freq" in name or "projector" in name:
                return
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                return
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if name.startswith("model.vision_tower") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # No stacked mapping matched; load the weight directly.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    return
                if name.startswith("model.vision_tower") and name not in params_dict:
                    return
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)

        # Two calling conventions: iterate over a whole checkpoint via `weights`,
        # or load a single (name, loaded_weight) pair when both are given.
        if name is None or loaded_weight is None:
            for name, loaded_weight in weights:
                load_weights_per_param(name, loaded_weight)
        else:
            load_weights_per_param(name, loaded_weight)


EntryClass = LlamaEmbeddingModel
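Because the fused QKV and gate/up projections do not exist under their checkpoint names, `load_weights` uses `stacked_params_mapping` to rewrite each name and route the tensor to the right shard of the fused parameter. The short sketch below isolates that string-level rewriting; it is an illustration only, and the checkpoint names are hypothetical examples.

# --- name-remapping sketch (illustration only, not part of the commit) ---
stacked_params_mapping = [
    # (fused_param_name, checkpoint_shard_name, shard_id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]


def remap(name):
    """Return (fused parameter name, shard id), or (name, None) if not stacked."""
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            return name.replace(weight_name, param_name), shard_id
    return name, None


print(remap("model.layers.0.self_attn.q_proj.weight"))
# ('model.layers.0.self_attn.qkv_proj.weight', 'q')
print(remap("model.layers.0.mlp.up_proj.weight"))
# ('model.layers.0.mlp.gate_up_proj.weight', 1)
print(remap("model.embed_tokens.weight"))
# ('model.embed_tokens.weight', None)

In the real `load_weights`, a successful match additionally calls the parameter's `weight_loader` with the shard id so the tensor lands in the correct slice of the fused weight. The `EntryClass = LlamaEmbeddingModel` assignment is what points SGLang's model loader at this class.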